1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed-impl.h> 9 #include <ceed.h> 10 #include <ceed/backend.h> 11 #include <stddef.h> 12 13 /// @file 14 /// Implementation of CeedTensorContract interfaces 15 16 /// ---------------------------------------------------------------------------- 17 /// CeedTensorContract Library Internal Functions 18 /// ---------------------------------------------------------------------------- 19 /// @addtogroup CeedTensorContractDeveloper 20 /// @{ 21 22 /** 23 @brief Destroy a `CeedTensorContract` passed as a `CeedObject` 24 25 @param[in,out] contract Address of `CeedTensorContract` to destroy 26 27 @return An error code: 0 - success, otherwise - failure 28 29 @ref Developer 30 **/ 31 static int CeedTensorContractDestroy_Object(CeedObject *contract) { 32 CeedCall(CeedTensorContractDestroy((CeedTensorContract *)contract)); 33 return CEED_ERROR_SUCCESS; 34 } 35 36 /// @} 37 38 /// ---------------------------------------------------------------------------- 39 /// CeedTensorContract Backend API 40 /// ---------------------------------------------------------------------------- 41 /// @addtogroup CeedBasisBackend 42 /// @{ 43 44 /** 45 @brief Create a `CeedTensorContract` object for a `CeedBasis` 46 47 @param[in] ceed `Ceed` object used to create the `CeedTensorContract` 48 @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored. 49 50 @return An error code: 0 - success, otherwise - failure 51 52 @ref Backend 53 **/ 54 int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) { 55 if (!ceed->TensorContractCreate) { 56 Ceed delegate; 57 58 CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract")); 59 CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement CeedTensorContractCreate"); 60 CeedCall(CeedTensorContractCreate(delegate, contract)); 61 CeedCall(CeedDestroy(&delegate)); 62 return CEED_ERROR_SUCCESS; 63 } 64 65 CeedCall(CeedCalloc(1, contract)); 66 CeedCall(CeedObjectCreate(ceed, NULL, CeedTensorContractDestroy_Object, &(*contract)->obj)); 67 CeedCall(ceed->TensorContractCreate(*contract)); 68 return CEED_ERROR_SUCCESS; 69 } 70 71 /** 72 @brief Apply tensor contraction 73 74 Contracts on the middle index 75 NOTRANSPOSE: `v_ajc = t_jb u_abc` 76 TRANSPOSE: `v_ajc = t_bj u_abc` 77 If `add != 0`, `=` is replaced by `+=` 78 79 @param[in] contract `CeedTensorContract` to use 80 @param[in] A First index of `u`, `v` 81 @param[in] B Middle index of `u`, one index of `t` 82 @param[in] C Last index of `u`, `v` 83 @param[in] J Middle index of `v`, one index of `t` 84 @param[in] t Tensor array to contract against 85 @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj` 86 @param[in] add Add mode 87 @param[in] u Input array 88 @param[out] v Output array 89 90 @return An error code: 0 - success, otherwise - failure 91 92 @ref Backend 93 **/ 94 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 95 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 96 CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v)); 97 return CEED_ERROR_SUCCESS; 98 } 99 100 /** 101 @brief Apply tensor contraction 102 103 Contracts on the middle index 104 NOTRANSPOSE: `v_dajc = t_djb u_abc` 105 TRANSPOSE: `v_ajc = t_dbj u_dabc` 106 If `add != 0`, `=` is replaced by `+=` 107 108 @param[in] contract `CeedTensorContract` to use 109 @param[in] A First index of `u`, second index of `v` 110 @param[in] B Middle index of `u`, one of last two indices of `t` 111 @param[in] C Last index of `u`, `v` 112 @param[in] D First index of `v`, first index of `t` 113 @param[in] J Third index of `v`, one of last two indices of `t` 114 @param[in] t Tensor array to contract against 115 @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj` 116 @param[in] add Add mode 117 @param[in] u Input array 118 @param[out] v Output array 119 120 @return An error code: 0 - success, otherwise - failure 121 122 @ref Backend 123 **/ 124 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t, 125 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 126 if (t_mode == CEED_TRANSPOSE) { 127 for (CeedInt d = 0; d < D; d++) { 128 CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v)); 129 } 130 } else { 131 for (CeedInt d = 0; d < D; d++) { 132 CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C)); 133 } 134 } 135 return CEED_ERROR_SUCCESS; 136 } 137 138 /** 139 @brief Get the `Ceed` associated with a `CeedTensorContract` 140 141 @param[in] contract `CeedTensorContract` 142 @param[out] ceed Variable to store `Ceed` 143 144 @return An error code: 0 - success, otherwise - failure 145 146 @ref Backend 147 **/ 148 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) { 149 CeedCall(CeedObjectGetCeed((CeedObject)contract, ceed)); 150 return CEED_ERROR_SUCCESS; 151 } 152 153 /** 154 @brief Return the `Ceed` associated with a `CeedTensorContract` 155 156 @param[in] contract `CeedTensorContract` 157 158 @return `Ceed` associated with `contract` 159 160 @ref Backend 161 **/ 162 Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return CeedObjectReturnCeed((CeedObject)contract); } 163 164 /** 165 @brief Get backend data of a `CeedTensorContract` 166 167 @param[in] contract `CeedTensorContract` 168 @param[out] data Variable to store data 169 170 @return An error code: 0 - success, otherwise - failure 171 172 @ref Backend 173 **/ 174 int CeedTensorContractGetData(CeedTensorContract contract, void *data) { 175 *(void **)data = contract->data; 176 return CEED_ERROR_SUCCESS; 177 } 178 179 /** 180 @brief Set backend data of a `CeedTensorContract` 181 182 @param[in,out] contract `CeedTensorContract` 183 @param[in] data Data to set 184 185 @return An error code: 0 - success, otherwise - failure 186 187 @ref Backend 188 **/ 189 int CeedTensorContractSetData(CeedTensorContract contract, void *data) { 190 contract->data = data; 191 return CEED_ERROR_SUCCESS; 192 } 193 194 /** 195 @brief Increment the reference counter for a `CeedTensorContract` 196 197 @param[in,out] contract `CeedTensorContract` to increment the reference counter 198 199 @return An error code: 0 - success, otherwise - failure 200 201 @ref Backend 202 **/ 203 int CeedTensorContractReference(CeedTensorContract contract) { 204 CeedCall(CeedObjectReference((CeedObject)contract)); 205 return CEED_ERROR_SUCCESS; 206 } 207 208 /** 209 @brief Copy the pointer to a `CeedTensorContract`. 210 211 Both pointers should be destroyed with @ref CeedTensorContractDestroy(). 212 213 Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`. 214 This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`. 215 216 @param[in] tensor `CeedTensorContract` to copy reference to 217 @param[in,out] tensor_copy Variable to store copied reference 218 219 @return An error code: 0 - success, otherwise - failure 220 221 @ref User 222 **/ 223 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) { 224 CeedCall(CeedTensorContractReference(tensor)); 225 CeedCall(CeedTensorContractDestroy(tensor_copy)); 226 *tensor_copy = tensor; 227 return CEED_ERROR_SUCCESS; 228 } 229 230 /** 231 @brief Destroy a `CeedTensorContract` 232 233 @param[in,out] contract `CeedTensorContract` to destroy 234 235 @return An error code: 0 - success, otherwise - failure 236 237 @ref Backend 238 **/ 239 int CeedTensorContractDestroy(CeedTensorContract *contract) { 240 if (!*contract || CeedObjectDereference((CeedObject)*contract) > 0) { 241 *contract = NULL; 242 return CEED_ERROR_SUCCESS; 243 } 244 if ((*contract)->Destroy) { 245 CeedCall((*contract)->Destroy(*contract)); 246 } 247 CeedCall(CeedObjectDestroy_Private(&(*contract)->obj)); 248 CeedCall(CeedFree(contract)); 249 return CEED_ERROR_SUCCESS; 250 } 251 252 /// @} 253