1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include "ceed-ref.h" 11 12 //------------------------------------------------------------------------------ 13 // Tensor Contract Apply 14 //------------------------------------------------------------------------------ 15 static int CeedTensorContractApply_Ref(CeedTensorContract contract, CeedInt A, 16 CeedInt B, CeedInt C, CeedInt J, 17 const CeedScalar *restrict t, 18 CeedTransposeMode t_mode, const CeedInt add, 19 const CeedScalar *restrict u, 20 CeedScalar *restrict v) { 21 CeedInt t_stride_0 = B, t_stride_1 = 1; 22 if (t_mode == CEED_TRANSPOSE) { 23 t_stride_0 = 1; t_stride_1 = J; 24 } 25 26 if (!add) 27 for (CeedInt q=0; q<A*J*C; q++) 28 v[q] = (CeedScalar) 0.0; 29 30 for (CeedInt a=0; a<A; a++) 31 for (CeedInt b=0; b<B; b++) 32 for (CeedInt j=0; j<J; j++) { 33 CeedScalar tq = t[j*t_stride_0 + b*t_stride_1]; 34 for (CeedInt c=0; c<C; c++) 35 v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c]; 36 } 37 return CEED_ERROR_SUCCESS; 38 } 39 40 //------------------------------------------------------------------------------ 41 // Tensor Contract Destroy 42 //------------------------------------------------------------------------------ 43 static int CeedTensorContractDestroy_Ref(CeedTensorContract contract) { 44 return CEED_ERROR_SUCCESS; 45 } 46 47 //------------------------------------------------------------------------------ 48 // Tensor Contract Create 49 //------------------------------------------------------------------------------ 50 int CeedTensorContractCreate_Ref(CeedBasis basis, CeedTensorContract contract) { 51 int ierr; 52 Ceed ceed; 53 ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChkBackend(ierr); 54 55 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", 56 CeedTensorContractApply_Ref); CeedChkBackend(ierr); 57 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy", 58 CeedTensorContractDestroy_Ref); CeedChkBackend(ierr); 59 60 return CEED_ERROR_SUCCESS; 61 } 62 //------------------------------------------------------------------------------ 63