13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 32f86a920SJeremy L Thompson // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 52f86a920SJeremy L Thompson // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 72f86a920SJeremy L Thompson 83d576824SJeremy L Thompson #include <ceed-impl.h> 949aac155SJeremy L Thompson #include <ceed.h> 102b730f8bSJeremy L Thompson #include <ceed/backend.h> 1149aac155SJeremy L Thompson #include <stddef.h> 122f86a920SJeremy L Thompson 132f86a920SJeremy L Thompson /// @file 147a982d89SJeremy L. Thompson /// Implementation of CeedTensorContract interfaces 157a982d89SJeremy L. Thompson 167a982d89SJeremy L. Thompson /// ---------------------------------------------------------------------------- 177a982d89SJeremy L. Thompson /// CeedTensorContract Backend API 187a982d89SJeremy L. Thompson /// ---------------------------------------------------------------------------- 197a982d89SJeremy L. Thompson /// @addtogroup CeedBasisBackend 202f86a920SJeremy L Thompson /// @{ 212f86a920SJeremy L Thompson 222f86a920SJeremy L Thompson /** 23ca94c3ddSJeremy L Thompson @brief Create a `CeedTensorContract` object for a `CeedBasis` 242f86a920SJeremy L Thompson 25ca94c3ddSJeremy L Thompson @param[in] ceed `Ceed` object used to create the `CeedTensorContract` 26ca94c3ddSJeremy L Thompson @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored. 272f86a920SJeremy L Thompson 282f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 292f86a920SJeremy L Thompson 307a982d89SJeremy L. Thompson @ref Backend 312f86a920SJeremy L Thompson **/ 32a71faab1SSebastian Grimberg int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) { 332f86a920SJeremy L Thompson if (!ceed->TensorContractCreate) { 342f86a920SJeremy L Thompson Ceed delegate; 356574a04fSJeremy L Thompson 362b730f8bSJeremy L Thompson CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract")); 37ca94c3ddSJeremy L Thompson CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support CeedTensorContractCreate"); 38a71faab1SSebastian Grimberg CeedCall(CeedTensorContractCreate(delegate, contract)); 39e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 402f86a920SJeremy L Thompson } 412f86a920SJeremy L Thompson 422b730f8bSJeremy L Thompson CeedCall(CeedCalloc(1, contract)); 43db002c03SJeremy L Thompson CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed)); 44a71faab1SSebastian Grimberg CeedCall(ceed->TensorContractCreate(*contract)); 45e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 467a982d89SJeremy L. Thompson } 472f86a920SJeremy L Thompson 482f86a920SJeremy L Thompson /** 492f86a920SJeremy L Thompson @brief Apply tensor contraction 502f86a920SJeremy L Thompson 512f86a920SJeremy L Thompson Contracts on the middle index 52ca94c3ddSJeremy L Thompson NOTRANSPOSE: `v_ajc = t_jb u_abc` 53ca94c3ddSJeremy L Thompson TRANSPOSE: `v_ajc = t_bj u_abc` 54ca94c3ddSJeremy L Thompson If `add != 0`, `=` is replaced by `+=` 552f86a920SJeremy L Thompson 56ca94c3ddSJeremy L Thompson @param[in] contract `CeedTensorContract` to use 57ca94c3ddSJeremy L Thompson @param[in] A First index of `u`, `v` 58ca94c3ddSJeremy L Thompson @param[in] B Middle index of `u`, one index of `t` 59ca94c3ddSJeremy L Thompson @param[in] C Last index of `u`, `v` 60ca94c3ddSJeremy L Thompson @param[in] J Middle index of `v`, one index of `t` 612f86a920SJeremy L Thompson @param[in] t Tensor array to contract against 62ca94c3ddSJeremy L Thompson @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj` 63ea61e9acSJeremy L Thompson @param[in] add Add mode 642f86a920SJeremy L Thompson @param[in] u Input array 652f86a920SJeremy L Thompson @param[out] v Output array 662f86a920SJeremy L Thompson 672f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 682f86a920SJeremy L Thompson 697a982d89SJeremy L. Thompson @ref Backend 702f86a920SJeremy L Thompson **/ 712b730f8bSJeremy L Thompson int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 722b730f8bSJeremy L Thompson CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 732b730f8bSJeremy L Thompson CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v)); 74e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 757a982d89SJeremy L. Thompson } 762f86a920SJeremy L Thompson 772f86a920SJeremy L Thompson /** 78c4e3f59bSSebastian Grimberg @brief Apply tensor contraction 79c4e3f59bSSebastian Grimberg 80c4e3f59bSSebastian Grimberg Contracts on the middle index 81ca94c3ddSJeremy L Thompson NOTRANSPOSE: `v_dajc = t_djb u_abc` 82ca94c3ddSJeremy L Thompson TRANSPOSE: `v_ajc = t_dbj u_dabc` 83ca94c3ddSJeremy L Thompson If `add != 0`, `=` is replaced by `+=` 84c4e3f59bSSebastian Grimberg 85ca94c3ddSJeremy L Thompson @param[in] contract `CeedTensorContract` to use 86ca94c3ddSJeremy L Thompson @param[in] A First index of `u`, second index of `v` 87ca94c3ddSJeremy L Thompson @param[in] B Middle index of `u`, one of last two indices of `t` 88ca94c3ddSJeremy L Thompson @param[in] C Last index of `u`, `v` 89ca94c3ddSJeremy L Thompson @param[in] D First index of `v`, first index of `t` 90ca94c3ddSJeremy L Thompson @param[in] J Third index of `v`, one of last two indices of `t` 91c4e3f59bSSebastian Grimberg @param[in] t Tensor array to contract against 92ca94c3ddSJeremy L Thompson @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj` 93c4e3f59bSSebastian Grimberg @param[in] add Add mode 94c4e3f59bSSebastian Grimberg @param[in] u Input array 95c4e3f59bSSebastian Grimberg @param[out] v Output array 96c4e3f59bSSebastian Grimberg 97c4e3f59bSSebastian Grimberg @return An error code: 0 - success, otherwise - failure 98c4e3f59bSSebastian Grimberg 99c4e3f59bSSebastian Grimberg @ref Backend 100c4e3f59bSSebastian Grimberg **/ 101c4e3f59bSSebastian Grimberg int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t, 102c4e3f59bSSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 103c4e3f59bSSebastian Grimberg if (t_mode == CEED_TRANSPOSE) { 104c4e3f59bSSebastian Grimberg for (CeedInt d = 0; d < D; d++) { 105c4e3f59bSSebastian Grimberg CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v)); 106c4e3f59bSSebastian Grimberg } 107c4e3f59bSSebastian Grimberg } else { 108c4e3f59bSSebastian Grimberg for (CeedInt d = 0; d < D; d++) { 109c4e3f59bSSebastian Grimberg CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C)); 110c4e3f59bSSebastian Grimberg } 111c4e3f59bSSebastian Grimberg } 112c4e3f59bSSebastian Grimberg return CEED_ERROR_SUCCESS; 113c4e3f59bSSebastian Grimberg } 114c4e3f59bSSebastian Grimberg 115c4e3f59bSSebastian Grimberg /** 116*6e536b99SJeremy L Thompson @brief Get the `Ceed` associated with a `CeedTensorContract` 1172f86a920SJeremy L Thompson 118ca94c3ddSJeremy L Thompson @param[in] contract `CeedTensorContract` 119ca94c3ddSJeremy L Thompson @param[out] ceed Variable to store `Ceed` 1202f86a920SJeremy L Thompson 1212f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 1222f86a920SJeremy L Thompson 1237a982d89SJeremy L. Thompson @ref Backend 1242f86a920SJeremy L Thompson **/ 1252f86a920SJeremy L Thompson int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) { 126*6e536b99SJeremy L Thompson *ceed = CeedTensorContractReturnCeed(contract); 127e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1287a982d89SJeremy L. Thompson } 1292f86a920SJeremy L Thompson 1302f86a920SJeremy L Thompson /** 131*6e536b99SJeremy L Thompson @brief Return the `Ceed` associated with a `CeedTensorContract` 132*6e536b99SJeremy L Thompson 133*6e536b99SJeremy L Thompson @param[in] contract `CeedTensorContract` 134*6e536b99SJeremy L Thompson 135*6e536b99SJeremy L Thompson @return `Ceed` associated with `contract` 136*6e536b99SJeremy L Thompson 137*6e536b99SJeremy L Thompson @ref Backend 138*6e536b99SJeremy L Thompson **/ 139*6e536b99SJeremy L Thompson Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return contract->ceed; } 140*6e536b99SJeremy L Thompson 141*6e536b99SJeremy L Thompson /** 142ca94c3ddSJeremy L Thompson @brief Get backend data of a `CeedTensorContract` 1432f86a920SJeremy L Thompson 144ca94c3ddSJeremy L Thompson @param[in] contract `CeedTensorContract` 1452f86a920SJeremy L Thompson @param[out] data Variable to store data 1462f86a920SJeremy L Thompson 1472f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 1482f86a920SJeremy L Thompson 1497a982d89SJeremy L. Thompson @ref Backend 1502f86a920SJeremy L Thompson **/ 151777ff853SJeremy L Thompson int CeedTensorContractGetData(CeedTensorContract contract, void *data) { 152777ff853SJeremy L Thompson *(void **)data = contract->data; 153e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1542f86a920SJeremy L Thompson } 1552f86a920SJeremy L Thompson 1562f86a920SJeremy L Thompson /** 157ca94c3ddSJeremy L Thompson @brief Set backend data of a `CeedTensorContract` 1582f86a920SJeremy L Thompson 159ca94c3ddSJeremy L Thompson @param[in,out] contract `CeedTensorContract` 160ea61e9acSJeremy L Thompson @param[in] data Data to set 1612f86a920SJeremy L Thompson 1622f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 1632f86a920SJeremy L Thompson 1647a982d89SJeremy L. Thompson @ref Backend 1652f86a920SJeremy L Thompson **/ 166777ff853SJeremy L Thompson int CeedTensorContractSetData(CeedTensorContract contract, void *data) { 167777ff853SJeremy L Thompson contract->data = data; 168e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1692f86a920SJeremy L Thompson } 1702f86a920SJeremy L Thompson 1712f86a920SJeremy L Thompson /** 172ca94c3ddSJeremy L Thompson @brief Increment the reference counter for a `CeedTensorContract` 17334359f16Sjeremylt 174ca94c3ddSJeremy L Thompson @param[in,out] contract `CeedTensorContract` to increment the reference counter 17534359f16Sjeremylt 17634359f16Sjeremylt @return An error code: 0 - success, otherwise - failure 17734359f16Sjeremylt 17834359f16Sjeremylt @ref Backend 17934359f16Sjeremylt **/ 1809560d06aSjeremylt int CeedTensorContractReference(CeedTensorContract contract) { 18134359f16Sjeremylt contract->ref_count++; 18234359f16Sjeremylt return CEED_ERROR_SUCCESS; 18334359f16Sjeremylt } 18434359f16Sjeremylt 18534359f16Sjeremylt /** 186ca94c3ddSJeremy L Thompson @brief Copy the pointer to a `CeedTensorContract`. 187585a562dSJeremy L Thompson 188ca94c3ddSJeremy L Thompson Both pointers should be destroyed with @ref CeedTensorContractDestroy(). 189585a562dSJeremy L Thompson 190ca94c3ddSJeremy L Thompson Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`. 191ca94c3ddSJeremy L Thompson This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`. 192585a562dSJeremy L Thompson 193ca94c3ddSJeremy L Thompson @param[in] tensor `CeedTensorContract` to copy reference to 194585a562dSJeremy L Thompson @param[in,out] tensor_copy Variable to store copied reference 195585a562dSJeremy L Thompson 196585a562dSJeremy L Thompson @return An error code: 0 - success, otherwise - failure 197585a562dSJeremy L Thompson 198585a562dSJeremy L Thompson @ref User 199585a562dSJeremy L Thompson **/ 200585a562dSJeremy L Thompson int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) { 201585a562dSJeremy L Thompson CeedCall(CeedTensorContractReference(tensor)); 202585a562dSJeremy L Thompson CeedCall(CeedTensorContractDestroy(tensor_copy)); 203585a562dSJeremy L Thompson *tensor_copy = tensor; 204585a562dSJeremy L Thompson return CEED_ERROR_SUCCESS; 205585a562dSJeremy L Thompson } 206585a562dSJeremy L Thompson 207585a562dSJeremy L Thompson /** 208ca94c3ddSJeremy L Thompson @brief Destroy a `CeedTensorContract` 2092f86a920SJeremy L Thompson 210ca94c3ddSJeremy L Thompson @param[in,out] contract `CeedTensorContract` to destroy 2112f86a920SJeremy L Thompson 2122f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 2132f86a920SJeremy L Thompson 2147a982d89SJeremy L. Thompson @ref Backend 2152f86a920SJeremy L Thompson **/ 2162f86a920SJeremy L Thompson int CeedTensorContractDestroy(CeedTensorContract *contract) { 217ad6481ceSJeremy L Thompson if (!*contract || --(*contract)->ref_count > 0) { 218ad6481ceSJeremy L Thompson *contract = NULL; 219ad6481ceSJeremy L Thompson return CEED_ERROR_SUCCESS; 220ad6481ceSJeremy L Thompson } 2212f86a920SJeremy L Thompson if ((*contract)->Destroy) { 2222b730f8bSJeremy L Thompson CeedCall((*contract)->Destroy(*contract)); 2232f86a920SJeremy L Thompson } 2242b730f8bSJeremy L Thompson CeedCall(CeedDestroy(&(*contract)->ceed)); 2252b730f8bSJeremy L Thompson CeedCall(CeedFree(contract)); 226e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2277a982d89SJeremy L. Thompson } 2282f86a920SJeremy L Thompson 2292f86a920SJeremy L Thompson /// @} 230