1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed-impl.h> 9 #include <ceed.h> 10 #include <ceed/backend.h> 11 #include <stddef.h> 12 13 /// @file 14 /// Implementation of CeedTensorContract interfaces 15 16 /// ---------------------------------------------------------------------------- 17 /// CeedTensorContract Backend API 18 /// ---------------------------------------------------------------------------- 19 /// @addtogroup CeedBasisBackend 20 /// @{ 21 22 /** 23 @brief Create a CeedTensorContract object for a CeedBasis 24 25 @param[in] ceed Ceed object where the CeedTensorContract will be created 26 @param[in] basis CeedBasis for which the tensor contraction will be used 27 @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored. 28 29 @return An error code: 0 - success, otherwise - failure 30 31 @ref Backend 32 **/ 33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) { 34 if (!ceed->TensorContractCreate) { 35 Ceed delegate; 36 37 CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract")); 38 CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate"); 39 CeedCall(CeedTensorContractCreate(delegate, basis, contract)); 40 return CEED_ERROR_SUCCESS; 41 } 42 43 CeedCall(CeedCalloc(1, contract)); 44 CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed)); 45 CeedCall(ceed->TensorContractCreate(basis, *contract)); 46 return CEED_ERROR_SUCCESS; 47 } 48 49 /** 50 @brief Apply tensor contraction 51 52 Contracts on the middle index 53 NOTRANSPOSE: v_ajc = t_jb u_abc 54 TRANSPOSE: v_ajc = t_bj u_abc 55 If add != 0, "=" is replaced by "+=" 56 57 @param[in] contract CeedTensorContract to use 58 @param[in] A First index of u, v 59 @param[in] B Middle index of u, one index of t 60 @param[in] C Last index of u, v 61 @param[in] J Middle index of v, one index of t 62 @param[in] t Tensor array to contract against 63 @param[in] t_mode Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj 64 @param[in] add Add mode 65 @param[in] u Input array 66 @param[out] v Output array 67 68 @return An error code: 0 - success, otherwise - failure 69 70 @ref Backend 71 **/ 72 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 73 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 74 CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v)); 75 return CEED_ERROR_SUCCESS; 76 } 77 78 /** 79 @brief Apply tensor contraction 80 81 Contracts on the middle index 82 NOTRANSPOSE: v_dajc = t_djb u_abc 83 TRANSPOSE: v_ajc = t_dbj u_dabc 84 If add != 0, "=" is replaced by "+=" 85 86 @param[in] contract CeedTensorContract to use 87 @param[in] A First index of u, second index of v 88 @param[in] B Middle index of u, one of last two indices of t 89 @param[in] C Last index of u, v 90 @param[in] D First index of v, first index of t 91 @param[in] J Third index of v, one of last two indices of t 92 @param[in] t Tensor array to contract against 93 @param[in] t_mode Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj 94 @param[in] add Add mode 95 @param[in] u Input array 96 @param[out] v Output array 97 98 @return An error code: 0 - success, otherwise - failure 99 100 @ref Backend 101 **/ 102 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t, 103 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 104 if (t_mode == CEED_TRANSPOSE) { 105 for (CeedInt d = 0; d < D; d++) { 106 CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v)); 107 } 108 } else { 109 for (CeedInt d = 0; d < D; d++) { 110 CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C)); 111 } 112 } 113 return CEED_ERROR_SUCCESS; 114 } 115 116 /** 117 @brief Get Ceed associated with a CeedTensorContract 118 119 @param[in] contract CeedTensorContract 120 @param[out] ceed Variable to store Ceed 121 122 @return An error code: 0 - success, otherwise - failure 123 124 @ref Backend 125 **/ 126 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) { 127 *ceed = contract->ceed; 128 return CEED_ERROR_SUCCESS; 129 } 130 131 /** 132 @brief Get backend data of a CeedTensorContract 133 134 @param[in] contract CeedTensorContract 135 @param[out] data Variable to store data 136 137 @return An error code: 0 - success, otherwise - failure 138 139 @ref Backend 140 **/ 141 int CeedTensorContractGetData(CeedTensorContract contract, void *data) { 142 *(void **)data = contract->data; 143 return CEED_ERROR_SUCCESS; 144 } 145 146 /** 147 @brief Set backend data of a CeedTensorContract 148 149 @param[in,out] contract CeedTensorContract 150 @param[in] data Data to set 151 152 @return An error code: 0 - success, otherwise - failure 153 154 @ref Backend 155 **/ 156 int CeedTensorContractSetData(CeedTensorContract contract, void *data) { 157 contract->data = data; 158 return CEED_ERROR_SUCCESS; 159 } 160 161 /** 162 @brief Increment the reference counter for a CeedTensorContract 163 164 @param[in,out] contract CeedTensorContract to increment the reference counter 165 166 @return An error code: 0 - success, otherwise - failure 167 168 @ref Backend 169 **/ 170 int CeedTensorContractReference(CeedTensorContract contract) { 171 contract->ref_count++; 172 return CEED_ERROR_SUCCESS; 173 } 174 175 /** 176 @brief Destroy a CeedTensorContract 177 178 @param[in,out] contract CeedTensorContract to destroy 179 180 @return An error code: 0 - success, otherwise - failure 181 182 @ref Backend 183 **/ 184 int CeedTensorContractDestroy(CeedTensorContract *contract) { 185 if (!*contract || --(*contract)->ref_count > 0) { 186 *contract = NULL; 187 return CEED_ERROR_SUCCESS; 188 } 189 if ((*contract)->Destroy) { 190 CeedCall((*contract)->Destroy(*contract)); 191 } 192 CeedCall(CeedDestroy(&(*contract)->ceed)); 193 CeedCall(CeedFree(contract)); 194 return CEED_ERROR_SUCCESS; 195 } 196 197 /// @} 198