1*3d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*3d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 39b2a10adSJeremy L Thompson // 4*3d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 59b2a10adSJeremy L Thompson // 6*3d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 79b2a10adSJeremy L Thompson 89b2a10adSJeremy L Thompson #include <ceed/ceed.h> 99b2a10adSJeremy L Thompson #include <ceed/backend.h> 109b2a10adSJeremy L Thompson #include "ceed-opt.h" 119b2a10adSJeremy L Thompson 129b2a10adSJeremy L Thompson //------------------------------------------------------------------------------ 139b2a10adSJeremy L Thompson // Tensor Contract Core loop 149b2a10adSJeremy L Thompson //------------------------------------------------------------------------------ 159b2a10adSJeremy L Thompson static inline int CeedTensorContractApply_Core_Opt(CeedTensorContract contract, 169b2a10adSJeremy L Thompson CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 179b2a10adSJeremy L Thompson CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, 189b2a10adSJeremy L Thompson CeedScalar *restrict v) { 199b2a10adSJeremy L Thompson CeedInt t_stride_0 = B, t_stride_1 = 1; 209b2a10adSJeremy L Thompson if (t_mode == CEED_TRANSPOSE) { 219b2a10adSJeremy L Thompson t_stride_0 = 1; t_stride_1 = J; 229b2a10adSJeremy L Thompson } 239b2a10adSJeremy L Thompson 249b2a10adSJeremy L Thompson for (CeedInt a=0; a<A; a++) 259b2a10adSJeremy L Thompson for (CeedInt b=0; b<B; b++) 269b2a10adSJeremy L Thompson for (CeedInt j=0; j<J; j++) { 279b2a10adSJeremy L Thompson CeedScalar tq = t[j*t_stride_0 + b*t_stride_1]; 289b2a10adSJeremy L Thompson for (CeedInt c=0; c<C; c++) 299b2a10adSJeremy L Thompson v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c]; 309b2a10adSJeremy L Thompson } 319b2a10adSJeremy L Thompson 329b2a10adSJeremy L Thompson return CEED_ERROR_SUCCESS; 339b2a10adSJeremy L Thompson } 349b2a10adSJeremy L Thompson 359b2a10adSJeremy L Thompson //------------------------------------------------------------------------------ 369b2a10adSJeremy L Thompson // Tensor Contract Apply 379b2a10adSJeremy L Thompson //------------------------------------------------------------------------------ 389b2a10adSJeremy L Thompson static int CeedTensorContractApply_Opt(CeedTensorContract contract, CeedInt A, 399b2a10adSJeremy L Thompson CeedInt B, CeedInt C, CeedInt J, 409b2a10adSJeremy L Thompson const CeedScalar *restrict t, 419b2a10adSJeremy L Thompson CeedTransposeMode t_mode, const CeedInt add, 429b2a10adSJeremy L Thompson const CeedScalar *restrict u, 439b2a10adSJeremy L Thompson CeedScalar *restrict v) { 449b2a10adSJeremy L Thompson if (!add) 459b2a10adSJeremy L Thompson for (CeedInt q=0; q<A*J*C; q++) 469b2a10adSJeremy L Thompson v[q] = (CeedScalar) 0.0; 479b2a10adSJeremy L Thompson 489b2a10adSJeremy L Thompson if (C == 1) 499b2a10adSJeremy L Thompson return CeedTensorContractApply_Core_Opt(contract, A, B, 1, J, t, t_mode, 509b2a10adSJeremy L Thompson add, u, v); 519b2a10adSJeremy L Thompson else 529b2a10adSJeremy L Thompson return CeedTensorContractApply_Core_Opt(contract, A, B, C, J, t, t_mode, 539b2a10adSJeremy L Thompson add, u, v); 549b2a10adSJeremy L Thompson 559b2a10adSJeremy L Thompson return CEED_ERROR_SUCCESS; 569b2a10adSJeremy L Thompson } 579b2a10adSJeremy L Thompson 589b2a10adSJeremy L Thompson //------------------------------------------------------------------------------ 599b2a10adSJeremy L Thompson // Tensor Contract Create 609b2a10adSJeremy L Thompson //------------------------------------------------------------------------------ 619b2a10adSJeremy L Thompson int CeedTensorContractCreate_Opt(CeedBasis basis, CeedTensorContract contract) { 629b2a10adSJeremy L Thompson int ierr; 639b2a10adSJeremy L Thompson Ceed ceed; 649b2a10adSJeremy L Thompson ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChkBackend(ierr); 659b2a10adSJeremy L Thompson 669b2a10adSJeremy L Thompson ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", 679b2a10adSJeremy L Thompson CeedTensorContractApply_Opt); CeedChkBackend(ierr); 689b2a10adSJeremy L Thompson 699b2a10adSJeremy L Thompson return CEED_ERROR_SUCCESS; 709b2a10adSJeremy L Thompson } 719b2a10adSJeremy L Thompson //------------------------------------------------------------------------------ 72