1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed-impl.h> 9 #include <ceed.h> 10 #include <ceed/backend.h> 11 #include <stddef.h> 12 13 /// @file 14 /// Implementation of CeedTensorContract interfaces 15 16 /// ---------------------------------------------------------------------------- 17 /// CeedTensorContract Backend API 18 /// ---------------------------------------------------------------------------- 19 /// @addtogroup CeedBasisBackend 20 /// @{ 21 22 /** 23 @brief Create a `CeedTensorContract` object for a `CeedBasis` 24 25 @param[in] ceed `Ceed` object used to create the `CeedTensorContract` 26 @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored. 27 28 @return An error code: 0 - success, otherwise - failure 29 30 @ref Backend 31 **/ 32 int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) { 33 if (!ceed->TensorContractCreate) { 34 Ceed delegate; 35 36 CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract")); 37 CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support CeedTensorContractCreate"); 38 CeedCall(CeedTensorContractCreate(delegate, contract)); 39 return CEED_ERROR_SUCCESS; 40 } 41 42 CeedCall(CeedCalloc(1, contract)); 43 CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed)); 44 CeedCall(ceed->TensorContractCreate(*contract)); 45 return CEED_ERROR_SUCCESS; 46 } 47 48 /** 49 @brief Apply tensor contraction 50 51 Contracts on the middle index 52 NOTRANSPOSE: `v_ajc = t_jb u_abc` 53 TRANSPOSE: `v_ajc = t_bj u_abc` 54 If `add != 0`, `=` is replaced by `+=` 55 56 @param[in] contract `CeedTensorContract` to use 57 @param[in] A First index of `u`, `v` 58 @param[in] B Middle index of `u`, one index of `t` 59 @param[in] C Last index of `u`, `v` 60 @param[in] J Middle index of `v`, one index of `t` 61 @param[in] t Tensor array to contract against 62 @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj` 63 @param[in] add Add mode 64 @param[in] u Input array 65 @param[out] v Output array 66 67 @return An error code: 0 - success, otherwise - failure 68 69 @ref Backend 70 **/ 71 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 72 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 73 CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v)); 74 return CEED_ERROR_SUCCESS; 75 } 76 77 /** 78 @brief Apply tensor contraction 79 80 Contracts on the middle index 81 NOTRANSPOSE: `v_dajc = t_djb u_abc` 82 TRANSPOSE: `v_ajc = t_dbj u_dabc` 83 If `add != 0`, `=` is replaced by `+=` 84 85 @param[in] contract `CeedTensorContract` to use 86 @param[in] A First index of `u`, second index of `v` 87 @param[in] B Middle index of `u`, one of last two indices of `t` 88 @param[in] C Last index of `u`, `v` 89 @param[in] D First index of `v`, first index of `t` 90 @param[in] J Third index of `v`, one of last two indices of `t` 91 @param[in] t Tensor array to contract against 92 @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj` 93 @param[in] add Add mode 94 @param[in] u Input array 95 @param[out] v Output array 96 97 @return An error code: 0 - success, otherwise - failure 98 99 @ref Backend 100 **/ 101 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t, 102 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 103 if (t_mode == CEED_TRANSPOSE) { 104 for (CeedInt d = 0; d < D; d++) { 105 CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v)); 106 } 107 } else { 108 for (CeedInt d = 0; d < D; d++) { 109 CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C)); 110 } 111 } 112 return CEED_ERROR_SUCCESS; 113 } 114 115 /** 116 @brief Get `Ceed` associated with a `CeedTensorContract` 117 118 @param[in] contract `CeedTensorContract` 119 @param[out] ceed Variable to store `Ceed` 120 121 @return An error code: 0 - success, otherwise - failure 122 123 @ref Backend 124 **/ 125 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) { 126 *ceed = contract->ceed; 127 return CEED_ERROR_SUCCESS; 128 } 129 130 /** 131 @brief Get backend data of a `CeedTensorContract` 132 133 @param[in] contract `CeedTensorContract` 134 @param[out] data Variable to store data 135 136 @return An error code: 0 - success, otherwise - failure 137 138 @ref Backend 139 **/ 140 int CeedTensorContractGetData(CeedTensorContract contract, void *data) { 141 *(void **)data = contract->data; 142 return CEED_ERROR_SUCCESS; 143 } 144 145 /** 146 @brief Set backend data of a `CeedTensorContract` 147 148 @param[in,out] contract `CeedTensorContract` 149 @param[in] data Data to set 150 151 @return An error code: 0 - success, otherwise - failure 152 153 @ref Backend 154 **/ 155 int CeedTensorContractSetData(CeedTensorContract contract, void *data) { 156 contract->data = data; 157 return CEED_ERROR_SUCCESS; 158 } 159 160 /** 161 @brief Increment the reference counter for a `CeedTensorContract` 162 163 @param[in,out] contract `CeedTensorContract` to increment the reference counter 164 165 @return An error code: 0 - success, otherwise - failure 166 167 @ref Backend 168 **/ 169 int CeedTensorContractReference(CeedTensorContract contract) { 170 contract->ref_count++; 171 return CEED_ERROR_SUCCESS; 172 } 173 174 /** 175 @brief Copy the pointer to a `CeedTensorContract`. 176 177 Both pointers should be destroyed with @ref CeedTensorContractDestroy(). 178 179 Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`. 180 This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`. 181 182 @param[in] tensor `CeedTensorContract` to copy reference to 183 @param[in,out] tensor_copy Variable to store copied reference 184 185 @return An error code: 0 - success, otherwise - failure 186 187 @ref User 188 **/ 189 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) { 190 CeedCall(CeedTensorContractReference(tensor)); 191 CeedCall(CeedTensorContractDestroy(tensor_copy)); 192 *tensor_copy = tensor; 193 return CEED_ERROR_SUCCESS; 194 } 195 196 /** 197 @brief Destroy a `CeedTensorContract` 198 199 @param[in,out] contract `CeedTensorContract` to destroy 200 201 @return An error code: 0 - success, otherwise - failure 202 203 @ref Backend 204 **/ 205 int CeedTensorContractDestroy(CeedTensorContract *contract) { 206 if (!*contract || --(*contract)->ref_count > 0) { 207 *contract = NULL; 208 return CEED_ERROR_SUCCESS; 209 } 210 if ((*contract)->Destroy) { 211 CeedCall((*contract)->Destroy(*contract)); 212 } 213 CeedCall(CeedDestroy(&(*contract)->ceed)); 214 CeedCall(CeedFree(contract)); 215 return CEED_ERROR_SUCCESS; 216 } 217 218 /// @} 219