1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include "ceed-opt.h" 11 12 //------------------------------------------------------------------------------ 13 // Tensor Contract Core loop 14 //------------------------------------------------------------------------------ 15 static inline int CeedTensorContractApply_Core_Opt(CeedTensorContract contract, 16 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 17 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, 18 CeedScalar *restrict v) { 19 CeedInt t_stride_0 = B, t_stride_1 = 1; 20 if (t_mode == CEED_TRANSPOSE) { 21 t_stride_0 = 1; t_stride_1 = J; 22 } 23 24 for (CeedInt a=0; a<A; a++) 25 for (CeedInt b=0; b<B; b++) 26 for (CeedInt j=0; j<J; j++) { 27 CeedScalar tq = t[j*t_stride_0 + b*t_stride_1]; 28 for (CeedInt c=0; c<C; c++) 29 v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c]; 30 } 31 32 return CEED_ERROR_SUCCESS; 33 } 34 35 //------------------------------------------------------------------------------ 36 // Tensor Contract Apply 37 //------------------------------------------------------------------------------ 38 static int CeedTensorContractApply_Opt(CeedTensorContract contract, CeedInt A, 39 CeedInt B, CeedInt C, CeedInt J, 40 const CeedScalar *restrict t, 41 CeedTransposeMode t_mode, const CeedInt add, 42 const CeedScalar *restrict u, 43 CeedScalar *restrict v) { 44 if (!add) 45 for (CeedInt q=0; q<A*J*C; q++) 46 v[q] = (CeedScalar) 0.0; 47 48 if (C == 1) 49 return CeedTensorContractApply_Core_Opt(contract, A, B, 1, J, t, t_mode, 50 add, u, v); 51 else 52 return CeedTensorContractApply_Core_Opt(contract, A, B, C, J, t, t_mode, 53 add, u, v); 54 55 return CEED_ERROR_SUCCESS; 56 } 57 58 //------------------------------------------------------------------------------ 59 // Tensor Contract Create 60 //------------------------------------------------------------------------------ 61 int CeedTensorContractCreate_Opt(CeedBasis basis, CeedTensorContract contract) { 62 int ierr; 63 Ceed ceed; 64 ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChkBackend(ierr); 65 66 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", 67 CeedTensorContractApply_Opt); CeedChkBackend(ierr); 68 69 return CEED_ERROR_SUCCESS; 70 } 71 //------------------------------------------------------------------------------ 72