13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 32f86a920SJeremy L Thompson // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 52f86a920SJeremy L Thompson // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 72f86a920SJeremy L Thompson 83d576824SJeremy L Thompson #include <ceed-impl.h> 949aac155SJeremy L Thompson #include <ceed.h> 102b730f8bSJeremy L Thompson #include <ceed/backend.h> 1149aac155SJeremy L Thompson #include <stddef.h> 122f86a920SJeremy L Thompson 132f86a920SJeremy L Thompson /// @file 147a982d89SJeremy L. Thompson /// Implementation of CeedTensorContract interfaces 157a982d89SJeremy L. Thompson 167a982d89SJeremy L. Thompson /// ---------------------------------------------------------------------------- 177a982d89SJeremy L. Thompson /// CeedTensorContract Backend API 187a982d89SJeremy L. Thompson /// ---------------------------------------------------------------------------- 197a982d89SJeremy L. Thompson /// @addtogroup CeedBasisBackend 202f86a920SJeremy L Thompson /// @{ 212f86a920SJeremy L Thompson 222f86a920SJeremy L Thompson /** 23*ca94c3ddSJeremy L Thompson @brief Create a `CeedTensorContract` object for a `CeedBasis` 242f86a920SJeremy L Thompson 25*ca94c3ddSJeremy L Thompson @param[in] ceed `Ceed` object used to create the `CeedTensorContract` 26*ca94c3ddSJeremy L Thompson @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored. 272f86a920SJeremy L Thompson 282f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 292f86a920SJeremy L Thompson 307a982d89SJeremy L. Thompson @ref Backend 312f86a920SJeremy L Thompson **/ 32a71faab1SSebastian Grimberg int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) { 332f86a920SJeremy L Thompson if (!ceed->TensorContractCreate) { 342f86a920SJeremy L Thompson Ceed delegate; 356574a04fSJeremy L Thompson 362b730f8bSJeremy L Thompson CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract")); 37*ca94c3ddSJeremy L Thompson CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support CeedTensorContractCreate"); 38a71faab1SSebastian Grimberg CeedCall(CeedTensorContractCreate(delegate, contract)); 39e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 402f86a920SJeremy L Thompson } 412f86a920SJeremy L Thompson 422b730f8bSJeremy L Thompson CeedCall(CeedCalloc(1, contract)); 43db002c03SJeremy L Thompson CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed)); 44a71faab1SSebastian Grimberg CeedCall(ceed->TensorContractCreate(*contract)); 45e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 467a982d89SJeremy L. Thompson } 472f86a920SJeremy L Thompson 482f86a920SJeremy L Thompson /** 492f86a920SJeremy L Thompson @brief Apply tensor contraction 502f86a920SJeremy L Thompson 512f86a920SJeremy L Thompson Contracts on the middle index 52*ca94c3ddSJeremy L Thompson NOTRANSPOSE: `v_ajc = t_jb u_abc` 53*ca94c3ddSJeremy L Thompson TRANSPOSE: `v_ajc = t_bj u_abc` 54*ca94c3ddSJeremy L Thompson If `add != 0`, `=` is replaced by `+=` 552f86a920SJeremy L Thompson 56*ca94c3ddSJeremy L Thompson @param[in] contract `CeedTensorContract` to use 57*ca94c3ddSJeremy L Thompson @param[in] A First index of `u`, `v` 58*ca94c3ddSJeremy L Thompson @param[in] B Middle index of `u`, one index of `t` 59*ca94c3ddSJeremy L Thompson @param[in] C Last index of `u`, `v` 60*ca94c3ddSJeremy L Thompson @param[in] J Middle index of `v`, one index of `t` 612f86a920SJeremy L Thompson @param[in] t Tensor array to contract against 62*ca94c3ddSJeremy L Thompson @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj` 63ea61e9acSJeremy L Thompson @param[in] add Add mode 642f86a920SJeremy L Thompson @param[in] u Input array 652f86a920SJeremy L Thompson @param[out] v Output array 662f86a920SJeremy L Thompson 672f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 682f86a920SJeremy L Thompson 697a982d89SJeremy L. Thompson @ref Backend 702f86a920SJeremy L Thompson **/ 712b730f8bSJeremy L Thompson int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 722b730f8bSJeremy L Thompson CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 732b730f8bSJeremy L Thompson CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v)); 74e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 757a982d89SJeremy L. Thompson } 762f86a920SJeremy L Thompson 772f86a920SJeremy L Thompson /** 78c4e3f59bSSebastian Grimberg @brief Apply tensor contraction 79c4e3f59bSSebastian Grimberg 80c4e3f59bSSebastian Grimberg Contracts on the middle index 81*ca94c3ddSJeremy L Thompson NOTRANSPOSE: `v_dajc = t_djb u_abc` 82*ca94c3ddSJeremy L Thompson TRANSPOSE: `v_ajc = t_dbj u_dabc` 83*ca94c3ddSJeremy L Thompson If `add != 0`, `=` is replaced by `+=` 84c4e3f59bSSebastian Grimberg 85*ca94c3ddSJeremy L Thompson @param[in] contract `CeedTensorContract` to use 86*ca94c3ddSJeremy L Thompson @param[in] A First index of `u`, second index of `v` 87*ca94c3ddSJeremy L Thompson @param[in] B Middle index of `u`, one of last two indices of `t` 88*ca94c3ddSJeremy L Thompson @param[in] C Last index of `u`, `v` 89*ca94c3ddSJeremy L Thompson @param[in] D First index of `v`, first index of `t` 90*ca94c3ddSJeremy L Thompson @param[in] J Third index of `v`, one of last two indices of `t` 91c4e3f59bSSebastian Grimberg @param[in] t Tensor array to contract against 92*ca94c3ddSJeremy L Thompson @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj` 93c4e3f59bSSebastian Grimberg @param[in] add Add mode 94c4e3f59bSSebastian Grimberg @param[in] u Input array 95c4e3f59bSSebastian Grimberg @param[out] v Output array 96c4e3f59bSSebastian Grimberg 97c4e3f59bSSebastian Grimberg @return An error code: 0 - success, otherwise - failure 98c4e3f59bSSebastian Grimberg 99c4e3f59bSSebastian Grimberg @ref Backend 100c4e3f59bSSebastian Grimberg **/ 101c4e3f59bSSebastian Grimberg int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t, 102c4e3f59bSSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 103c4e3f59bSSebastian Grimberg if (t_mode == CEED_TRANSPOSE) { 104c4e3f59bSSebastian Grimberg for (CeedInt d = 0; d < D; d++) { 105c4e3f59bSSebastian Grimberg CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v)); 106c4e3f59bSSebastian Grimberg } 107c4e3f59bSSebastian Grimberg } else { 108c4e3f59bSSebastian Grimberg for (CeedInt d = 0; d < D; d++) { 109c4e3f59bSSebastian Grimberg CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C)); 110c4e3f59bSSebastian Grimberg } 111c4e3f59bSSebastian Grimberg } 112c4e3f59bSSebastian Grimberg return CEED_ERROR_SUCCESS; 113c4e3f59bSSebastian Grimberg } 114c4e3f59bSSebastian Grimberg 115c4e3f59bSSebastian Grimberg /** 116*ca94c3ddSJeremy L Thompson @brief Get `Ceed` associated with a `CeedTensorContract` 1172f86a920SJeremy L Thompson 118*ca94c3ddSJeremy L Thompson @param[in] contract `CeedTensorContract` 119*ca94c3ddSJeremy L Thompson @param[out] ceed Variable to store `Ceed` 1202f86a920SJeremy L Thompson 1212f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 1222f86a920SJeremy L Thompson 1237a982d89SJeremy L. Thompson @ref Backend 1242f86a920SJeremy L Thompson **/ 1252f86a920SJeremy L Thompson int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) { 1262f86a920SJeremy L Thompson *ceed = contract->ceed; 127e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1287a982d89SJeremy L. Thompson } 1292f86a920SJeremy L Thompson 1302f86a920SJeremy L Thompson /** 131*ca94c3ddSJeremy L Thompson @brief Get backend data of a `CeedTensorContract` 1322f86a920SJeremy L Thompson 133*ca94c3ddSJeremy L Thompson @param[in] contract `CeedTensorContract` 1342f86a920SJeremy L Thompson @param[out] data Variable to store data 1352f86a920SJeremy L Thompson 1362f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 1372f86a920SJeremy L Thompson 1387a982d89SJeremy L. Thompson @ref Backend 1392f86a920SJeremy L Thompson **/ 140777ff853SJeremy L Thompson int CeedTensorContractGetData(CeedTensorContract contract, void *data) { 141777ff853SJeremy L Thompson *(void **)data = contract->data; 142e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1432f86a920SJeremy L Thompson } 1442f86a920SJeremy L Thompson 1452f86a920SJeremy L Thompson /** 146*ca94c3ddSJeremy L Thompson @brief Set backend data of a `CeedTensorContract` 1472f86a920SJeremy L Thompson 148*ca94c3ddSJeremy L Thompson @param[in,out] contract `CeedTensorContract` 149ea61e9acSJeremy L Thompson @param[in] data Data to set 1502f86a920SJeremy L Thompson 1512f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 1522f86a920SJeremy L Thompson 1537a982d89SJeremy L. Thompson @ref Backend 1542f86a920SJeremy L Thompson **/ 155777ff853SJeremy L Thompson int CeedTensorContractSetData(CeedTensorContract contract, void *data) { 156777ff853SJeremy L Thompson contract->data = data; 157e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1582f86a920SJeremy L Thompson } 1592f86a920SJeremy L Thompson 1602f86a920SJeremy L Thompson /** 161*ca94c3ddSJeremy L Thompson @brief Increment the reference counter for a `CeedTensorContract` 16234359f16Sjeremylt 163*ca94c3ddSJeremy L Thompson @param[in,out] contract `CeedTensorContract` to increment the reference counter 16434359f16Sjeremylt 16534359f16Sjeremylt @return An error code: 0 - success, otherwise - failure 16634359f16Sjeremylt 16734359f16Sjeremylt @ref Backend 16834359f16Sjeremylt **/ 1699560d06aSjeremylt int CeedTensorContractReference(CeedTensorContract contract) { 17034359f16Sjeremylt contract->ref_count++; 17134359f16Sjeremylt return CEED_ERROR_SUCCESS; 17234359f16Sjeremylt } 17334359f16Sjeremylt 17434359f16Sjeremylt /** 175*ca94c3ddSJeremy L Thompson @brief Copy the pointer to a `CeedTensorContract`. 176585a562dSJeremy L Thompson 177*ca94c3ddSJeremy L Thompson Both pointers should be destroyed with @ref CeedTensorContractDestroy(). 178585a562dSJeremy L Thompson 179*ca94c3ddSJeremy L Thompson Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`. 180*ca94c3ddSJeremy L Thompson This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`. 181585a562dSJeremy L Thompson 182*ca94c3ddSJeremy L Thompson @param[in] tensor `CeedTensorContract` to copy reference to 183585a562dSJeremy L Thompson @param[in,out] tensor_copy Variable to store copied reference 184585a562dSJeremy L Thompson 185585a562dSJeremy L Thompson @return An error code: 0 - success, otherwise - failure 186585a562dSJeremy L Thompson 187585a562dSJeremy L Thompson @ref User 188585a562dSJeremy L Thompson **/ 189585a562dSJeremy L Thompson int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) { 190585a562dSJeremy L Thompson CeedCall(CeedTensorContractReference(tensor)); 191585a562dSJeremy L Thompson CeedCall(CeedTensorContractDestroy(tensor_copy)); 192585a562dSJeremy L Thompson *tensor_copy = tensor; 193585a562dSJeremy L Thompson return CEED_ERROR_SUCCESS; 194585a562dSJeremy L Thompson } 195585a562dSJeremy L Thompson 196585a562dSJeremy L Thompson /** 197*ca94c3ddSJeremy L Thompson @brief Destroy a `CeedTensorContract` 1982f86a920SJeremy L Thompson 199*ca94c3ddSJeremy L Thompson @param[in,out] contract `CeedTensorContract` to destroy 2002f86a920SJeremy L Thompson 2012f86a920SJeremy L Thompson @return An error code: 0 - success, otherwise - failure 2022f86a920SJeremy L Thompson 2037a982d89SJeremy L. Thompson @ref Backend 2042f86a920SJeremy L Thompson **/ 2052f86a920SJeremy L Thompson int CeedTensorContractDestroy(CeedTensorContract *contract) { 206ad6481ceSJeremy L Thompson if (!*contract || --(*contract)->ref_count > 0) { 207ad6481ceSJeremy L Thompson *contract = NULL; 208ad6481ceSJeremy L Thompson return CEED_ERROR_SUCCESS; 209ad6481ceSJeremy L Thompson } 2102f86a920SJeremy L Thompson if ((*contract)->Destroy) { 2112b730f8bSJeremy L Thompson CeedCall((*contract)->Destroy(*contract)); 2122f86a920SJeremy L Thompson } 2132b730f8bSJeremy L Thompson CeedCall(CeedDestroy(&(*contract)->ceed)); 2142b730f8bSJeremy L Thompson CeedCall(CeedFree(contract)); 215e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2167a982d89SJeremy L. Thompson } 2172f86a920SJeremy L Thompson 2182f86a920SJeremy L Thompson /// @} 219