1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed-impl.h> 11 12 /// @file 13 /// Implementation of CeedTensorContract interfaces 14 15 /// ---------------------------------------------------------------------------- 16 /// CeedTensorContract Backend API 17 /// ---------------------------------------------------------------------------- 18 /// @addtogroup CeedBasisBackend 19 /// @{ 20 21 /** 22 @brief Create a CeedTensorContract object for a CeedBasis 23 24 @param ceed A Ceed object where the CeedTensorContract will be created 25 @param basis CeedBasis for which the tensor contraction will be used 26 @param[out] contract Address of the variable where the newly created 27 CeedTensorContract will be stored. 28 29 @return An error code: 0 - success, otherwise - failure 30 31 @ref Backend 32 **/ 33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, 34 CeedTensorContract *contract) { 35 int ierr; 36 37 if (!ceed->TensorContractCreate) { 38 Ceed delegate; 39 ierr = CeedGetObjectDelegate(ceed, &delegate, "TensorContract"); 40 CeedChk(ierr); 41 42 if (!delegate) 43 // LCOV_EXCL_START 44 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, 45 "Backend does not support TensorContractCreate"); 46 // LCOV_EXCL_STOP 47 48 ierr = CeedTensorContractCreate(delegate, basis, contract); 49 CeedChk(ierr); 50 return CEED_ERROR_SUCCESS; 51 } 52 53 ierr = CeedCalloc(1, contract); CeedChk(ierr); 54 55 (*contract)->ceed = ceed; 56 ierr = CeedReference(ceed); CeedChk(ierr); 57 ierr = ceed->TensorContractCreate(basis, *contract); 58 CeedChk(ierr); 59 return CEED_ERROR_SUCCESS; 60 } 61 62 /** 63 @brief Apply tensor contraction 64 65 Contracts on the middle index 66 NOTRANSPOSE: v_ajc = t_jb u_abc 67 TRANSPOSE: v_ajc = t_bj u_abc 68 If add != 0, "=" is replaced by "+=" 69 70 @param contract CeedTensorContract to use 71 @param A First index of u, v 72 @param B Middle index of u, one index of t 73 @param C Last index of u, v 74 @param J Middle index of v, one index of t 75 @param[in] t Tensor array to contract against 76 @param t_mode Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb 77 \ref CEED_TRANSPOSE for t_bj 78 @param add Add mode 79 @param[in] u Input array 80 @param[out] v Output array 81 82 @return An error code: 0 - success, otherwise - failure 83 84 @ref Backend 85 **/ 86 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, 87 CeedInt C, CeedInt J, const CeedScalar *restrict t, 88 CeedTransposeMode t_mode, const CeedInt add, 89 const CeedScalar *restrict u, 90 CeedScalar *restrict v) { 91 int ierr; 92 93 ierr = contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v); 94 CeedChk(ierr); 95 return CEED_ERROR_SUCCESS; 96 } 97 98 /** 99 @brief Get Ceed associated with a CeedTensorContract 100 101 @param contract CeedTensorContract 102 @param[out] ceed Variable to store Ceed 103 104 @return An error code: 0 - success, otherwise - failure 105 106 @ref Backend 107 **/ 108 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) { 109 *ceed = contract->ceed; 110 return CEED_ERROR_SUCCESS; 111 } 112 113 /** 114 @brief Get backend data of a CeedTensorContract 115 116 @param contract CeedTensorContract 117 @param[out] data Variable to store data 118 119 @return An error code: 0 - success, otherwise - failure 120 121 @ref Backend 122 **/ 123 int CeedTensorContractGetData(CeedTensorContract contract, void *data) { 124 *(void **)data = contract->data; 125 return CEED_ERROR_SUCCESS; 126 } 127 128 /** 129 @brief Set backend data of a CeedTensorContract 130 131 @param[out] contract CeedTensorContract 132 @param data Data to set 133 134 @return An error code: 0 - success, otherwise - failure 135 136 @ref Backend 137 **/ 138 int CeedTensorContractSetData(CeedTensorContract contract, void *data) { 139 contract->data = data; 140 return CEED_ERROR_SUCCESS; 141 } 142 143 /** 144 @brief Increment the reference counter for a CeedTensorContract 145 146 @param contract CeedTensorContract to increment the reference counter 147 148 @return An error code: 0 - success, otherwise - failure 149 150 @ref Backend 151 **/ 152 int CeedTensorContractReference(CeedTensorContract contract) { 153 contract->ref_count++; 154 return CEED_ERROR_SUCCESS; 155 } 156 157 /** 158 @brief Destroy a CeedTensorContract 159 160 @param contract CeedTensorContract to destroy 161 162 @return An error code: 0 - success, otherwise - failure 163 164 @ref Backend 165 **/ 166 int CeedTensorContractDestroy(CeedTensorContract *contract) { 167 int ierr; 168 169 if (!*contract || --(*contract)->ref_count > 0) return CEED_ERROR_SUCCESS; 170 if ((*contract)->Destroy) { 171 ierr = (*contract)->Destroy(*contract); CeedChk(ierr); 172 } 173 ierr = CeedDestroy(&(*contract)->ceed); CeedChk(ierr); 174 ierr = CeedFree(contract); CeedChk(ierr); 175 return CEED_ERROR_SUCCESS; 176 } 177 178 /// @} 179