1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed-impl.h> 9 #include <ceed.h> 10 #include <ceed/backend.h> 11 #include <stddef.h> 12 13 /// @file 14 /// Implementation of CeedTensorContract interfaces 15 16 /// ---------------------------------------------------------------------------- 17 /// CeedTensorContract Backend API 18 /// ---------------------------------------------------------------------------- 19 /// @addtogroup CeedBasisBackend 20 /// @{ 21 22 /** 23 @brief Create a CeedTensorContract object for a CeedBasis 24 25 @param[in] ceed Ceed object where the CeedTensorContract will be created 26 @param[in] basis CeedBasis for which the tensor contraction will be used 27 @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored. 28 29 @return An error code: 0 - success, otherwise - failure 30 31 @ref Backend 32 **/ 33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) { 34 if (!ceed->TensorContractCreate) { 35 Ceed delegate; 36 CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract")); 37 38 if (!delegate) { 39 // LCOV_EXCL_START 40 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate"); 41 // LCOV_EXCL_STOP 42 } 43 44 CeedCall(CeedTensorContractCreate(delegate, basis, contract)); 45 return CEED_ERROR_SUCCESS; 46 } 47 48 CeedCall(CeedCalloc(1, contract)); 49 50 (*contract)->ceed = ceed; 51 CeedCall(CeedReference(ceed)); 52 CeedCall(ceed->TensorContractCreate(basis, *contract)); 53 return CEED_ERROR_SUCCESS; 54 } 55 56 /** 57 @brief Apply tensor contraction 58 59 Contracts on the middle index 60 NOTRANSPOSE: v_ajc = t_jb u_abc 61 TRANSPOSE: v_ajc = t_bj u_abc 62 If add != 0, "=" is replaced by "+=" 63 64 @param[in] contract CeedTensorContract to use 65 @param[in] A First index of u, v 66 @param[in] B Middle index of u, one index of t 67 @param[in] C Last index of u, v 68 @param[in] J Middle index of v, one index of t 69 @param[in] t Tensor array to contract against 70 @param[in] t_mode Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj 71 @param[in] add Add mode 72 @param[in] u Input array 73 @param[out] v Output array 74 75 @return An error code: 0 - success, otherwise - failure 76 77 @ref Backend 78 **/ 79 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 80 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 81 CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v)); 82 return CEED_ERROR_SUCCESS; 83 } 84 85 /** 86 @brief Apply tensor contraction 87 88 Contracts on the middle index 89 NOTRANSPOSE: v_dajc = t_djb u_abc 90 TRANSPOSE: v_ajc = t_dbj u_dabc 91 If add != 0, "=" is replaced by "+=" 92 93 @param[in] contract CeedTensorContract to use 94 @param[in] A First index of u, second index of v 95 @param[in] B Middle index of u, one of last two indices of t 96 @param[in] C Last index of u, v 97 @param[in] D First index of v, first index of t 98 @param[in] J Third index of v, one of last two indices of t 99 @param[in] t Tensor array to contract against 100 @param[in] t_mode Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj 101 @param[in] add Add mode 102 @param[in] u Input array 103 @param[out] v Output array 104 105 @return An error code: 0 - success, otherwise - failure 106 107 @ref Backend 108 **/ 109 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t, 110 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 111 if (t_mode == CEED_TRANSPOSE) { 112 for (CeedInt d = 0; d < D; d++) { 113 CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v)); 114 } 115 } else { 116 for (CeedInt d = 0; d < D; d++) { 117 CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C)); 118 } 119 } 120 return CEED_ERROR_SUCCESS; 121 } 122 123 /** 124 @brief Get Ceed associated with a CeedTensorContract 125 126 @param[in] contract CeedTensorContract 127 @param[out] ceed Variable to store Ceed 128 129 @return An error code: 0 - success, otherwise - failure 130 131 @ref Backend 132 **/ 133 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) { 134 *ceed = contract->ceed; 135 return CEED_ERROR_SUCCESS; 136 } 137 138 /** 139 @brief Get backend data of a CeedTensorContract 140 141 @param[in] contract CeedTensorContract 142 @param[out] data Variable to store data 143 144 @return An error code: 0 - success, otherwise - failure 145 146 @ref Backend 147 **/ 148 int CeedTensorContractGetData(CeedTensorContract contract, void *data) { 149 *(void **)data = contract->data; 150 return CEED_ERROR_SUCCESS; 151 } 152 153 /** 154 @brief Set backend data of a CeedTensorContract 155 156 @param[in,out] contract CeedTensorContract 157 @param[in] data Data to set 158 159 @return An error code: 0 - success, otherwise - failure 160 161 @ref Backend 162 **/ 163 int CeedTensorContractSetData(CeedTensorContract contract, void *data) { 164 contract->data = data; 165 return CEED_ERROR_SUCCESS; 166 } 167 168 /** 169 @brief Increment the reference counter for a CeedTensorContract 170 171 @param[in,out] contract CeedTensorContract to increment the reference counter 172 173 @return An error code: 0 - success, otherwise - failure 174 175 @ref Backend 176 **/ 177 int CeedTensorContractReference(CeedTensorContract contract) { 178 contract->ref_count++; 179 return CEED_ERROR_SUCCESS; 180 } 181 182 /** 183 @brief Destroy a CeedTensorContract 184 185 @param[in,out] contract CeedTensorContract to destroy 186 187 @return An error code: 0 - success, otherwise - failure 188 189 @ref Backend 190 **/ 191 int CeedTensorContractDestroy(CeedTensorContract *contract) { 192 if (!*contract || --(*contract)->ref_count > 0) { 193 *contract = NULL; 194 return CEED_ERROR_SUCCESS; 195 } 196 if ((*contract)->Destroy) { 197 CeedCall((*contract)->Destroy(*contract)); 198 } 199 CeedCall(CeedDestroy(&(*contract)->ceed)); 200 CeedCall(CeedFree(contract)); 201 return CEED_ERROR_SUCCESS; 202 } 203 204 /// @} 205