1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed-impl.h> 9 #include <ceed.h> 10 #include <ceed/backend.h> 11 #include <stddef.h> 12 13 /// @file 14 /// Implementation of CeedTensorContract interfaces 15 16 /// ---------------------------------------------------------------------------- 17 /// CeedTensorContract Backend API 18 /// ---------------------------------------------------------------------------- 19 /// @addtogroup CeedBasisBackend 20 /// @{ 21 22 /** 23 @brief Create a `CeedTensorContract` object for a `CeedBasis` 24 25 @param[in] ceed `Ceed` object used to create the `CeedTensorContract` 26 @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored. 27 28 @return An error code: 0 - success, otherwise - failure 29 30 @ref Backend 31 **/ 32 int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) { 33 if (!ceed->TensorContractCreate) { 34 Ceed delegate; 35 36 CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract")); 37 CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement CeedTensorContractCreate"); 38 CeedCall(CeedTensorContractCreate(delegate, contract)); 39 CeedCall(CeedDestroy(&delegate)); 40 return CEED_ERROR_SUCCESS; 41 } 42 43 CeedCall(CeedCalloc(1, contract)); 44 CeedCall(CeedObjectCreate(ceed, NULL, &(*contract)->obj)); 45 CeedCall(ceed->TensorContractCreate(*contract)); 46 return CEED_ERROR_SUCCESS; 47 } 48 49 /** 50 @brief Apply tensor contraction 51 52 Contracts on the middle index 53 NOTRANSPOSE: `v_ajc = t_jb u_abc` 54 TRANSPOSE: `v_ajc = t_bj u_abc` 55 If `add != 0`, `=` is replaced by `+=` 56 57 @param[in] contract `CeedTensorContract` to use 58 @param[in] A First index of `u`, `v` 59 @param[in] B Middle index of `u`, one index of `t` 60 @param[in] C Last index of `u`, `v` 61 @param[in] J Middle index of `v`, one index of `t` 62 @param[in] t Tensor array to contract against 63 @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj` 64 @param[in] add Add mode 65 @param[in] u Input array 66 @param[out] v Output array 67 68 @return An error code: 0 - success, otherwise - failure 69 70 @ref Backend 71 **/ 72 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 73 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 74 CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v)); 75 return CEED_ERROR_SUCCESS; 76 } 77 78 /** 79 @brief Apply tensor contraction 80 81 Contracts on the middle index 82 NOTRANSPOSE: `v_dajc = t_djb u_abc` 83 TRANSPOSE: `v_ajc = t_dbj u_dabc` 84 If `add != 0`, `=` is replaced by `+=` 85 86 @param[in] contract `CeedTensorContract` to use 87 @param[in] A First index of `u`, second index of `v` 88 @param[in] B Middle index of `u`, one of last two indices of `t` 89 @param[in] C Last index of `u`, `v` 90 @param[in] D First index of `v`, first index of `t` 91 @param[in] J Third index of `v`, one of last two indices of `t` 92 @param[in] t Tensor array to contract against 93 @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj` 94 @param[in] add Add mode 95 @param[in] u Input array 96 @param[out] v Output array 97 98 @return An error code: 0 - success, otherwise - failure 99 100 @ref Backend 101 **/ 102 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t, 103 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 104 if (t_mode == CEED_TRANSPOSE) { 105 for (CeedInt d = 0; d < D; d++) { 106 CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v)); 107 } 108 } else { 109 for (CeedInt d = 0; d < D; d++) { 110 CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C)); 111 } 112 } 113 return CEED_ERROR_SUCCESS; 114 } 115 116 /** 117 @brief Get the `Ceed` associated with a `CeedTensorContract` 118 119 @param[in] contract `CeedTensorContract` 120 @param[out] ceed Variable to store `Ceed` 121 122 @return An error code: 0 - success, otherwise - failure 123 124 @ref Backend 125 **/ 126 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) { 127 CeedCall(CeedObjectGetCeed((CeedObject)contract, ceed)); 128 return CEED_ERROR_SUCCESS; 129 } 130 131 /** 132 @brief Return the `Ceed` associated with a `CeedTensorContract` 133 134 @param[in] contract `CeedTensorContract` 135 136 @return `Ceed` associated with `contract` 137 138 @ref Backend 139 **/ 140 Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return CeedObjectReturnCeed((CeedObject)contract); } 141 142 /** 143 @brief Get backend data of a `CeedTensorContract` 144 145 @param[in] contract `CeedTensorContract` 146 @param[out] data Variable to store data 147 148 @return An error code: 0 - success, otherwise - failure 149 150 @ref Backend 151 **/ 152 int CeedTensorContractGetData(CeedTensorContract contract, void *data) { 153 *(void **)data = contract->data; 154 return CEED_ERROR_SUCCESS; 155 } 156 157 /** 158 @brief Set backend data of a `CeedTensorContract` 159 160 @param[in,out] contract `CeedTensorContract` 161 @param[in] data Data to set 162 163 @return An error code: 0 - success, otherwise - failure 164 165 @ref Backend 166 **/ 167 int CeedTensorContractSetData(CeedTensorContract contract, void *data) { 168 contract->data = data; 169 return CEED_ERROR_SUCCESS; 170 } 171 172 /** 173 @brief Increment the reference counter for a `CeedTensorContract` 174 175 @param[in,out] contract `CeedTensorContract` to increment the reference counter 176 177 @return An error code: 0 - success, otherwise - failure 178 179 @ref Backend 180 **/ 181 int CeedTensorContractReference(CeedTensorContract contract) { 182 CeedCall(CeedObjectReference((CeedObject)contract)); 183 return CEED_ERROR_SUCCESS; 184 } 185 186 /** 187 @brief Copy the pointer to a `CeedTensorContract`. 188 189 Both pointers should be destroyed with @ref CeedTensorContractDestroy(). 190 191 Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`. 192 This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`. 193 194 @param[in] tensor `CeedTensorContract` to copy reference to 195 @param[in,out] tensor_copy Variable to store copied reference 196 197 @return An error code: 0 - success, otherwise - failure 198 199 @ref User 200 **/ 201 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) { 202 CeedCall(CeedTensorContractReference(tensor)); 203 CeedCall(CeedTensorContractDestroy(tensor_copy)); 204 *tensor_copy = tensor; 205 return CEED_ERROR_SUCCESS; 206 } 207 208 /** 209 @brief Destroy a `CeedTensorContract` 210 211 @param[in,out] contract `CeedTensorContract` to destroy 212 213 @return An error code: 0 - success, otherwise - failure 214 215 @ref Backend 216 **/ 217 int CeedTensorContractDestroy(CeedTensorContract *contract) { 218 if (!*contract || CeedObjectDereference((CeedObject)*contract) > 0) { 219 *contract = NULL; 220 return CEED_ERROR_SUCCESS; 221 } 222 if ((*contract)->Destroy) { 223 CeedCall((*contract)->Destroy(*contract)); 224 } 225 CeedCall(CeedObjectDestroy(&(*contract)->obj)); 226 CeedCall(CeedFree(contract)); 227 return CEED_ERROR_SUCCESS; 228 } 229 230 /// @} 231