1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 11 #include "ceed-opt.h" 12 13 //------------------------------------------------------------------------------ 14 // Tensor Contract Core loop 15 //------------------------------------------------------------------------------ 16 static inline int CeedTensorContractApply_Core_Opt(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, 17 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add, 18 const CeedScalar *restrict u, CeedScalar *restrict v) { 19 CeedInt t_stride_0 = B, t_stride_1 = 1; 20 if (t_mode == CEED_TRANSPOSE) { 21 t_stride_0 = 1; 22 t_stride_1 = J; 23 } 24 25 for (CeedInt a = 0; a < A; a++) { 26 for (CeedInt b = 0; b < B; b++) { 27 for (CeedInt j = 0; j < J; j++) { 28 CeedScalar tq = t[j * t_stride_0 + b * t_stride_1]; 29 for (CeedInt c = 0; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c]; 30 } 31 } 32 } 33 34 return CEED_ERROR_SUCCESS; 35 } 36 37 //------------------------------------------------------------------------------ 38 // Tensor Contract Apply 39 //------------------------------------------------------------------------------ 40 static int CeedTensorContractApply_Opt(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 41 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 42 if (!add) { 43 for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0; 44 } 45 46 if (C == 1) return CeedTensorContractApply_Core_Opt(contract, A, B, 1, J, t, t_mode, add, u, v); 47 else return CeedTensorContractApply_Core_Opt(contract, A, B, C, J, t, t_mode, add, u, v); 48 49 return CEED_ERROR_SUCCESS; 50 } 51 52 //------------------------------------------------------------------------------ 53 // Tensor Contract Create 54 //------------------------------------------------------------------------------ 55 int CeedTensorContractCreate_Opt(CeedBasis basis, CeedTensorContract contract) { 56 Ceed ceed; 57 CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); 58 59 CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Opt)); 60 61 return CEED_ERROR_SUCCESS; 62 } 63 //------------------------------------------------------------------------------ 64