xref: /libCEED/interface/ceed-tensor.c (revision a71faab149de599cd7784196409e851b67144f0a)
13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
32f86a920SJeremy L Thompson //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
52f86a920SJeremy L Thompson //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
72f86a920SJeremy L Thompson 
83d576824SJeremy L Thompson #include <ceed-impl.h>
949aac155SJeremy L Thompson #include <ceed.h>
102b730f8bSJeremy L Thompson #include <ceed/backend.h>
1149aac155SJeremy L Thompson #include <stddef.h>
122f86a920SJeremy L Thompson 
132f86a920SJeremy L Thompson /// @file
147a982d89SJeremy L. Thompson /// Implementation of CeedTensorContract interfaces
157a982d89SJeremy L. Thompson 
167a982d89SJeremy L. Thompson /// ----------------------------------------------------------------------------
177a982d89SJeremy L. Thompson /// CeedTensorContract Backend API
187a982d89SJeremy L. Thompson /// ----------------------------------------------------------------------------
197a982d89SJeremy L. Thompson /// @addtogroup CeedBasisBackend
202f86a920SJeremy L Thompson /// @{
212f86a920SJeremy L Thompson 
222f86a920SJeremy L Thompson /**
232f86a920SJeremy L Thompson   @brief Create a CeedTensorContract object for a CeedBasis
242f86a920SJeremy L Thompson 
25ea61e9acSJeremy L Thompson   @param[in]  ceed     Ceed object where the CeedTensorContract will be created
26ea61e9acSJeremy L Thompson   @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored.
272f86a920SJeremy L Thompson 
282f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
292f86a920SJeremy L Thompson 
307a982d89SJeremy L. Thompson   @ref Backend
312f86a920SJeremy L Thompson **/
32*a71faab1SSebastian Grimberg int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
332f86a920SJeremy L Thompson   if (!ceed->TensorContractCreate) {
342f86a920SJeremy L Thompson     Ceed delegate;
356574a04fSJeremy L Thompson 
362b730f8bSJeremy L Thompson     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
376574a04fSJeremy L Thompson     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
38*a71faab1SSebastian Grimberg     CeedCall(CeedTensorContractCreate(delegate, contract));
39e15f9bd0SJeremy L Thompson     return CEED_ERROR_SUCCESS;
402f86a920SJeremy L Thompson   }
412f86a920SJeremy L Thompson 
422b730f8bSJeremy L Thompson   CeedCall(CeedCalloc(1, contract));
43db002c03SJeremy L Thompson   CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed));
44*a71faab1SSebastian Grimberg   CeedCall(ceed->TensorContractCreate(*contract));
45e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
467a982d89SJeremy L. Thompson }
472f86a920SJeremy L Thompson 
482f86a920SJeremy L Thompson /**
492f86a920SJeremy L Thompson   @brief Apply tensor contraction
502f86a920SJeremy L Thompson 
512f86a920SJeremy L Thompson   Contracts on the middle index
522f86a920SJeremy L Thompson   NOTRANSPOSE: v_ajc = t_jb u_abc
532f86a920SJeremy L Thompson   TRANSPOSE:   v_ajc = t_bj u_abc
542f86a920SJeremy L Thompson   If add != 0, "=" is replaced by "+="
552f86a920SJeremy L Thompson 
56ea61e9acSJeremy L Thompson   @param[in]  contract CeedTensorContract to use
57ea61e9acSJeremy L Thompson   @param[in]  A        First index of u, v
58ea61e9acSJeremy L Thompson   @param[in]  B        Middle index of u, one index of t
59ea61e9acSJeremy L Thompson   @param[in]  C        Last index of u, v
60ea61e9acSJeremy L Thompson   @param[in]  J        Middle index of v, one index of t
612f86a920SJeremy L Thompson   @param[in]  t        Tensor array to contract against
62ea61e9acSJeremy L Thompson   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
63ea61e9acSJeremy L Thompson   @param[in]  add      Add mode
642f86a920SJeremy L Thompson   @param[in]  u        Input array
652f86a920SJeremy L Thompson   @param[out] v        Output array
662f86a920SJeremy L Thompson 
672f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
682f86a920SJeremy L Thompson 
697a982d89SJeremy L. Thompson   @ref Backend
702f86a920SJeremy L Thompson **/
712b730f8bSJeremy L Thompson int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
722b730f8bSJeremy L Thompson                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
732b730f8bSJeremy L Thompson   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
74e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
757a982d89SJeremy L. Thompson }
762f86a920SJeremy L Thompson 
772f86a920SJeremy L Thompson /**
78c4e3f59bSSebastian Grimberg   @brief Apply tensor contraction
79c4e3f59bSSebastian Grimberg 
80c4e3f59bSSebastian Grimberg   Contracts on the middle index
81c4e3f59bSSebastian Grimberg   NOTRANSPOSE: v_dajc = t_djb u_abc
82c4e3f59bSSebastian Grimberg   TRANSPOSE:   v_ajc  = t_dbj u_dabc
83c4e3f59bSSebastian Grimberg   If add != 0, "=" is replaced by "+="
84c4e3f59bSSebastian Grimberg 
85c4e3f59bSSebastian Grimberg   @param[in]  contract CeedTensorContract to use
86c4e3f59bSSebastian Grimberg   @param[in]  A        First index of u, second index of v
87c4e3f59bSSebastian Grimberg   @param[in]  B        Middle index of u, one of last two indices of t
88c4e3f59bSSebastian Grimberg   @param[in]  C        Last index of u, v
89c4e3f59bSSebastian Grimberg   @param[in]  D        First index of v, first index of t
90c4e3f59bSSebastian Grimberg   @param[in]  J        Third index of v, one of last two indices of t
91c4e3f59bSSebastian Grimberg   @param[in]  t        Tensor array to contract against
924548da4eSSebastian Grimberg   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_djb \ref CEED_TRANSPOSE for t_dbj
93c4e3f59bSSebastian Grimberg   @param[in]  add      Add mode
94c4e3f59bSSebastian Grimberg   @param[in]  u        Input array
95c4e3f59bSSebastian Grimberg   @param[out] v        Output array
96c4e3f59bSSebastian Grimberg 
97c4e3f59bSSebastian Grimberg   @return An error code: 0 - success, otherwise - failure
98c4e3f59bSSebastian Grimberg 
99c4e3f59bSSebastian Grimberg   @ref Backend
100c4e3f59bSSebastian Grimberg **/
101c4e3f59bSSebastian Grimberg int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
102c4e3f59bSSebastian Grimberg                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
103c4e3f59bSSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
104c4e3f59bSSebastian Grimberg     for (CeedInt d = 0; d < D; d++) {
105c4e3f59bSSebastian Grimberg       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
106c4e3f59bSSebastian Grimberg     }
107c4e3f59bSSebastian Grimberg   } else {
108c4e3f59bSSebastian Grimberg     for (CeedInt d = 0; d < D; d++) {
109c4e3f59bSSebastian Grimberg       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
110c4e3f59bSSebastian Grimberg     }
111c4e3f59bSSebastian Grimberg   }
112c4e3f59bSSebastian Grimberg   return CEED_ERROR_SUCCESS;
113c4e3f59bSSebastian Grimberg }
114c4e3f59bSSebastian Grimberg 
115c4e3f59bSSebastian Grimberg /**
1162f86a920SJeremy L Thompson   @brief Get Ceed associated with a CeedTensorContract
1172f86a920SJeremy L Thompson 
118ea61e9acSJeremy L Thompson   @param[in]  contract CeedTensorContract
1192f86a920SJeremy L Thompson   @param[out] ceed     Variable to store Ceed
1202f86a920SJeremy L Thompson 
1212f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1222f86a920SJeremy L Thompson 
1237a982d89SJeremy L. Thompson   @ref Backend
1242f86a920SJeremy L Thompson **/
1252f86a920SJeremy L Thompson int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
1262f86a920SJeremy L Thompson   *ceed = contract->ceed;
127e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1287a982d89SJeremy L. Thompson }
1292f86a920SJeremy L Thompson 
1302f86a920SJeremy L Thompson /**
1312f86a920SJeremy L Thompson   @brief Get backend data of a CeedTensorContract
1322f86a920SJeremy L Thompson 
133ea61e9acSJeremy L Thompson   @param[in]  contract CeedTensorContract
1342f86a920SJeremy L Thompson   @param[out] data     Variable to store data
1352f86a920SJeremy L Thompson 
1362f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1372f86a920SJeremy L Thompson 
1387a982d89SJeremy L. Thompson   @ref Backend
1392f86a920SJeremy L Thompson **/
140777ff853SJeremy L Thompson int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
141777ff853SJeremy L Thompson   *(void **)data = contract->data;
142e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1432f86a920SJeremy L Thompson }
1442f86a920SJeremy L Thompson 
1452f86a920SJeremy L Thompson /**
1462f86a920SJeremy L Thompson   @brief Set backend data of a CeedTensorContract
1472f86a920SJeremy L Thompson 
148ea61e9acSJeremy L Thompson   @param[in,out] contract CeedTensorContract
149ea61e9acSJeremy L Thompson   @param[in]     data     Data to set
1502f86a920SJeremy L Thompson 
1512f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1522f86a920SJeremy L Thompson 
1537a982d89SJeremy L. Thompson   @ref Backend
1542f86a920SJeremy L Thompson **/
155777ff853SJeremy L Thompson int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
156777ff853SJeremy L Thompson   contract->data = data;
157e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1582f86a920SJeremy L Thompson }
1592f86a920SJeremy L Thompson 
1602f86a920SJeremy L Thompson /**
16134359f16Sjeremylt   @brief Increment the reference counter for a CeedTensorContract
16234359f16Sjeremylt 
163ea61e9acSJeremy L Thompson   @param[in,out] contract CeedTensorContract to increment the reference counter
16434359f16Sjeremylt 
16534359f16Sjeremylt   @return An error code: 0 - success, otherwise - failure
16634359f16Sjeremylt 
16734359f16Sjeremylt   @ref Backend
16834359f16Sjeremylt **/
1699560d06aSjeremylt int CeedTensorContractReference(CeedTensorContract contract) {
17034359f16Sjeremylt   contract->ref_count++;
17134359f16Sjeremylt   return CEED_ERROR_SUCCESS;
17234359f16Sjeremylt }
17334359f16Sjeremylt 
17434359f16Sjeremylt /**
175585a562dSJeremy L Thompson   @brief Copy the pointer to a CeedTensorContract.
176585a562dSJeremy L Thompson 
177585a562dSJeremy L Thompson   Both pointers should be destroyed with `CeedTensorContractDestroy()`.
178585a562dSJeremy L Thompson 
179585a562dSJeremy L Thompson   Note: If the value of `tensor_copy` passed to this function is non-NULL, then it is assumed that `tensor_copy` is a pointer to a CeedTensorContract.
180585a562dSJeremy L Thompson         This CeedTensorContract will be destroyed if `tensor_copy` is the only reference to this CeedVector.
181585a562dSJeremy L Thompson 
182585a562dSJeremy L Thompson   @param[in]     tensor      CeedTensorContract to copy reference to
183585a562dSJeremy L Thompson   @param[in,out] tensor_copy Variable to store copied reference
184585a562dSJeremy L Thompson 
185585a562dSJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
186585a562dSJeremy L Thompson 
187585a562dSJeremy L Thompson   @ref User
188585a562dSJeremy L Thompson **/
189585a562dSJeremy L Thompson int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
190585a562dSJeremy L Thompson   CeedCall(CeedTensorContractReference(tensor));
191585a562dSJeremy L Thompson   CeedCall(CeedTensorContractDestroy(tensor_copy));
192585a562dSJeremy L Thompson   *tensor_copy = tensor;
193585a562dSJeremy L Thompson   return CEED_ERROR_SUCCESS;
194585a562dSJeremy L Thompson }
195585a562dSJeremy L Thompson 
196585a562dSJeremy L Thompson /**
1972f86a920SJeremy L Thompson   @brief Destroy a CeedTensorContract
1982f86a920SJeremy L Thompson 
199ea61e9acSJeremy L Thompson   @param[in,out] contract CeedTensorContract to destroy
2002f86a920SJeremy L Thompson 
2012f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
2022f86a920SJeremy L Thompson 
2037a982d89SJeremy L. Thompson   @ref Backend
2042f86a920SJeremy L Thompson **/
2052f86a920SJeremy L Thompson int CeedTensorContractDestroy(CeedTensorContract *contract) {
206ad6481ceSJeremy L Thompson   if (!*contract || --(*contract)->ref_count > 0) {
207ad6481ceSJeremy L Thompson     *contract = NULL;
208ad6481ceSJeremy L Thompson     return CEED_ERROR_SUCCESS;
209ad6481ceSJeremy L Thompson   }
2102f86a920SJeremy L Thompson   if ((*contract)->Destroy) {
2112b730f8bSJeremy L Thompson     CeedCall((*contract)->Destroy(*contract));
2122f86a920SJeremy L Thompson   }
2132b730f8bSJeremy L Thompson   CeedCall(CeedDestroy(&(*contract)->ceed));
2142b730f8bSJeremy L Thompson   CeedCall(CeedFree(contract));
215e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2167a982d89SJeremy L. Thompson }
2172f86a920SJeremy L Thompson 
2182f86a920SJeremy L Thompson /// @}
219