1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed-impl.h> 9 #include <ceed/backend.h> 10 #include <ceed/ceed.h> 11 12 /// @file 13 /// Implementation of CeedTensorContract interfaces 14 15 /// ---------------------------------------------------------------------------- 16 /// CeedTensorContract Backend API 17 /// ---------------------------------------------------------------------------- 18 /// @addtogroup CeedBasisBackend 19 /// @{ 20 21 /** 22 @brief Create a CeedTensorContract object for a CeedBasis 23 24 @param ceed A Ceed object where the CeedTensorContract will be created 25 @param basis CeedBasis for which the tensor contraction will be used 26 @param[out] contract Address of the variable where the newly created 27 CeedTensorContract will be stored. 28 29 @return An error code: 0 - success, otherwise - failure 30 31 @ref Backend 32 **/ 33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) { 34 if (!ceed->TensorContractCreate) { 35 Ceed delegate; 36 CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract")); 37 38 if (!delegate) { 39 // LCOV_EXCL_START 40 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate"); 41 // LCOV_EXCL_STOP 42 } 43 44 CeedCall(CeedTensorContractCreate(delegate, basis, contract)); 45 return CEED_ERROR_SUCCESS; 46 } 47 48 CeedCall(CeedCalloc(1, contract)); 49 50 (*contract)->ceed = ceed; 51 CeedCall(CeedReference(ceed)); 52 CeedCall(ceed->TensorContractCreate(basis, *contract)); 53 return CEED_ERROR_SUCCESS; 54 } 55 56 /** 57 @brief Apply tensor contraction 58 59 Contracts on the middle index 60 NOTRANSPOSE: v_ajc = t_jb u_abc 61 TRANSPOSE: v_ajc = t_bj u_abc 62 If add != 0, "=" is replaced by "+=" 63 64 @param contract CeedTensorContract to use 65 @param A First index of u, v 66 @param B Middle index of u, one index of t 67 @param C Last index of u, v 68 @param J Middle index of v, one index of t 69 @param[in] t Tensor array to contract against 70 @param t_mode Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb 71 \ref CEED_TRANSPOSE for t_bj 72 @param add Add mode 73 @param[in] u Input array 74 @param[out] v Output array 75 76 @return An error code: 0 - success, otherwise - failure 77 78 @ref Backend 79 **/ 80 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 81 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 82 CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v)); 83 return CEED_ERROR_SUCCESS; 84 } 85 86 /** 87 @brief Get Ceed associated with a CeedTensorContract 88 89 @param contract CeedTensorContract 90 @param[out] ceed Variable to store Ceed 91 92 @return An error code: 0 - success, otherwise - failure 93 94 @ref Backend 95 **/ 96 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) { 97 *ceed = contract->ceed; 98 return CEED_ERROR_SUCCESS; 99 } 100 101 /** 102 @brief Get backend data of a CeedTensorContract 103 104 @param contract CeedTensorContract 105 @param[out] data Variable to store data 106 107 @return An error code: 0 - success, otherwise - failure 108 109 @ref Backend 110 **/ 111 int CeedTensorContractGetData(CeedTensorContract contract, void *data) { 112 *(void **)data = contract->data; 113 return CEED_ERROR_SUCCESS; 114 } 115 116 /** 117 @brief Set backend data of a CeedTensorContract 118 119 @param[out] contract CeedTensorContract 120 @param data Data to set 121 122 @return An error code: 0 - success, otherwise - failure 123 124 @ref Backend 125 **/ 126 int CeedTensorContractSetData(CeedTensorContract contract, void *data) { 127 contract->data = data; 128 return CEED_ERROR_SUCCESS; 129 } 130 131 /** 132 @brief Increment the reference counter for a CeedTensorContract 133 134 @param contract CeedTensorContract to increment the reference counter 135 136 @return An error code: 0 - success, otherwise - failure 137 138 @ref Backend 139 **/ 140 int CeedTensorContractReference(CeedTensorContract contract) { 141 contract->ref_count++; 142 return CEED_ERROR_SUCCESS; 143 } 144 145 /** 146 @brief Destroy a CeedTensorContract 147 148 @param contract CeedTensorContract to destroy 149 150 @return An error code: 0 - success, otherwise - failure 151 152 @ref Backend 153 **/ 154 int CeedTensorContractDestroy(CeedTensorContract *contract) { 155 if (!*contract || --(*contract)->ref_count > 0) return CEED_ERROR_SUCCESS; 156 if ((*contract)->Destroy) { 157 CeedCall((*contract)->Destroy(*contract)); 158 } 159 CeedCall(CeedDestroy(&(*contract)->ceed)); 160 CeedCall(CeedFree(contract)); 161 return CEED_ERROR_SUCCESS; 162 } 163 164 /// @} 165