xref: /libCEED/interface/ceed-tensor.c (revision f2989f2b3b8649d855bed22b6730ee0ecfa6b31b)
19ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
32f86a920SJeremy L Thompson //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
52f86a920SJeremy L Thompson //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
72f86a920SJeremy L Thompson 
83d576824SJeremy L Thompson #include <ceed-impl.h>
949aac155SJeremy L Thompson #include <ceed.h>
102b730f8bSJeremy L Thompson #include <ceed/backend.h>
1149aac155SJeremy L Thompson #include <stddef.h>
122f86a920SJeremy L Thompson 
/// @file
/// Implementation of CeedTensorContract interfaces

/// ----------------------------------------------------------------------------
/// CeedTensorContract Library Internal Functions
/// ----------------------------------------------------------------------------
/// @addtogroup CeedTensorContractDeveloper
/// @{
21*6c328a79SJeremy L Thompson 
22*6c328a79SJeremy L Thompson /**
23*6c328a79SJeremy L Thompson   @brief Destroy a `CeedTensorContract` passed as a `CeedObject`
24*6c328a79SJeremy L Thompson 
25*6c328a79SJeremy L Thompson   @param[in,out] contract Address of `CeedTensorContract` to destroy
26*6c328a79SJeremy L Thompson 
27*6c328a79SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
28*6c328a79SJeremy L Thompson 
29*6c328a79SJeremy L Thompson   @ref Developer
30*6c328a79SJeremy L Thompson **/
CeedTensorContractDestroy_Object(CeedObject * contract)31*6c328a79SJeremy L Thompson static int CeedTensorContractDestroy_Object(CeedObject *contract) {
32*6c328a79SJeremy L Thompson   CeedCall(CeedTensorContractDestroy((CeedTensorContract *)contract));
33*6c328a79SJeremy L Thompson   return CEED_ERROR_SUCCESS;
34*6c328a79SJeremy L Thompson }
35*6c328a79SJeremy L Thompson 
/// @}

/// ----------------------------------------------------------------------------
/// CeedTensorContract Backend API
/// ----------------------------------------------------------------------------
/// @addtogroup CeedBasisBackend
/// @{
432f86a920SJeremy L Thompson 
442f86a920SJeremy L Thompson /**
45ca94c3ddSJeremy L Thompson   @brief Create a `CeedTensorContract` object for a `CeedBasis`
462f86a920SJeremy L Thompson 
47ca94c3ddSJeremy L Thompson   @param[in]  ceed     `Ceed` object used to create the `CeedTensorContract`
48ca94c3ddSJeremy L Thompson   @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored.
492f86a920SJeremy L Thompson 
502f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
512f86a920SJeremy L Thompson 
527a982d89SJeremy L. Thompson   @ref Backend
532f86a920SJeremy L Thompson **/
CeedTensorContractCreate(Ceed ceed,CeedTensorContract * contract)54a71faab1SSebastian Grimberg int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
552f86a920SJeremy L Thompson   if (!ceed->TensorContractCreate) {
562f86a920SJeremy L Thompson     Ceed delegate;
576574a04fSJeremy L Thompson 
582b730f8bSJeremy L Thompson     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
591ef3a2a9SJeremy L Thompson     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement CeedTensorContractCreate");
60a71faab1SSebastian Grimberg     CeedCall(CeedTensorContractCreate(delegate, contract));
619bc66399SJeremy L Thompson     CeedCall(CeedDestroy(&delegate));
62e15f9bd0SJeremy L Thompson     return CEED_ERROR_SUCCESS;
632f86a920SJeremy L Thompson   }
642f86a920SJeremy L Thompson 
652b730f8bSJeremy L Thompson   CeedCall(CeedCalloc(1, contract));
66*6c328a79SJeremy L Thompson   CeedCall(CeedObjectCreate(ceed, NULL, CeedTensorContractDestroy_Object, &(*contract)->obj));
67a71faab1SSebastian Grimberg   CeedCall(ceed->TensorContractCreate(*contract));
68e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
697a982d89SJeremy L. Thompson }
702f86a920SJeremy L Thompson 
712f86a920SJeremy L Thompson /**
722f86a920SJeremy L Thompson   @brief Apply tensor contraction
732f86a920SJeremy L Thompson 
742f86a920SJeremy L Thompson   Contracts on the middle index
75ca94c3ddSJeremy L Thompson   NOTRANSPOSE: `v_ajc = t_jb u_abc`
76ca94c3ddSJeremy L Thompson   TRANSPOSE:   `v_ajc = t_bj u_abc`
77ca94c3ddSJeremy L Thompson   If `add != 0`, `=` is replaced by `+=`
782f86a920SJeremy L Thompson 
79ca94c3ddSJeremy L Thompson   @param[in]  contract `CeedTensorContract` to use
80ca94c3ddSJeremy L Thompson   @param[in]  A        First index of `u`, `v`
81ca94c3ddSJeremy L Thompson   @param[in]  B        Middle index of `u`, one index of `t`
82ca94c3ddSJeremy L Thompson   @param[in]  C        Last index of `u`, `v`
83ca94c3ddSJeremy L Thompson   @param[in]  J        Middle index of `v`, one index of `t`
842f86a920SJeremy L Thompson   @param[in]  t        Tensor array to contract against
85ca94c3ddSJeremy L Thompson   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj`
86ea61e9acSJeremy L Thompson   @param[in]  add      Add mode
872f86a920SJeremy L Thompson   @param[in]  u        Input array
882f86a920SJeremy L Thompson   @param[out] v        Output array
892f86a920SJeremy L Thompson 
902f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
912f86a920SJeremy L Thompson 
927a982d89SJeremy L. Thompson   @ref Backend
932f86a920SJeremy L Thompson **/
CeedTensorContractApply(CeedTensorContract contract,CeedInt A,CeedInt B,CeedInt C,CeedInt J,const CeedScalar * restrict t,CeedTransposeMode t_mode,const CeedInt add,const CeedScalar * restrict u,CeedScalar * restrict v)942b730f8bSJeremy L Thompson int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
952b730f8bSJeremy L Thompson                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
962b730f8bSJeremy L Thompson   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
97e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
987a982d89SJeremy L. Thompson }
992f86a920SJeremy L Thompson 
1002f86a920SJeremy L Thompson /**
101c4e3f59bSSebastian Grimberg   @brief Apply tensor contraction
102c4e3f59bSSebastian Grimberg 
103c4e3f59bSSebastian Grimberg   Contracts on the middle index
104ca94c3ddSJeremy L Thompson   NOTRANSPOSE: `v_dajc = t_djb u_abc`
105ca94c3ddSJeremy L Thompson   TRANSPOSE:   `v_ajc  = t_dbj u_dabc`
106ca94c3ddSJeremy L Thompson   If `add != 0`, `=` is replaced by `+=`
107c4e3f59bSSebastian Grimberg 
108ca94c3ddSJeremy L Thompson   @param[in]  contract `CeedTensorContract` to use
109ca94c3ddSJeremy L Thompson   @param[in]  A        First index of `u`, second index of `v`
110ca94c3ddSJeremy L Thompson   @param[in]  B        Middle index of `u`, one of last two indices of `t`
111ca94c3ddSJeremy L Thompson   @param[in]  C        Last index of `u`, `v`
112ca94c3ddSJeremy L Thompson   @param[in]  D        First index of `v`, first index of `t`
113ca94c3ddSJeremy L Thompson   @param[in]  J        Third index of `v`, one of last two indices of `t`
114c4e3f59bSSebastian Grimberg   @param[in]  t        Tensor array to contract against
115ca94c3ddSJeremy L Thompson   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj`
116c4e3f59bSSebastian Grimberg   @param[in]  add      Add mode
117c4e3f59bSSebastian Grimberg   @param[in]  u        Input array
118c4e3f59bSSebastian Grimberg   @param[out] v        Output array
119c4e3f59bSSebastian Grimberg 
120c4e3f59bSSebastian Grimberg   @return An error code: 0 - success, otherwise - failure
121c4e3f59bSSebastian Grimberg 
122c4e3f59bSSebastian Grimberg   @ref Backend
123c4e3f59bSSebastian Grimberg **/
CeedTensorContractStridedApply(CeedTensorContract contract,CeedInt A,CeedInt B,CeedInt C,CeedInt D,CeedInt J,const CeedScalar * restrict t,CeedTransposeMode t_mode,const CeedInt add,const CeedScalar * restrict u,CeedScalar * restrict v)124c4e3f59bSSebastian Grimberg int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
125c4e3f59bSSebastian Grimberg                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
126c4e3f59bSSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
127c4e3f59bSSebastian Grimberg     for (CeedInt d = 0; d < D; d++) {
128c4e3f59bSSebastian Grimberg       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
129c4e3f59bSSebastian Grimberg     }
130c4e3f59bSSebastian Grimberg   } else {
131c4e3f59bSSebastian Grimberg     for (CeedInt d = 0; d < D; d++) {
132c4e3f59bSSebastian Grimberg       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
133c4e3f59bSSebastian Grimberg     }
134c4e3f59bSSebastian Grimberg   }
135c4e3f59bSSebastian Grimberg   return CEED_ERROR_SUCCESS;
136c4e3f59bSSebastian Grimberg }
137c4e3f59bSSebastian Grimberg 
138c4e3f59bSSebastian Grimberg /**
1396e536b99SJeremy L Thompson   @brief Get the `Ceed` associated with a `CeedTensorContract`
1402f86a920SJeremy L Thompson 
141ca94c3ddSJeremy L Thompson   @param[in]  contract `CeedTensorContract`
142ca94c3ddSJeremy L Thompson   @param[out] ceed     Variable to store `Ceed`
1432f86a920SJeremy L Thompson 
1442f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1452f86a920SJeremy L Thompson 
1467a982d89SJeremy L. Thompson   @ref Backend
1472f86a920SJeremy L Thompson **/
CeedTensorContractGetCeed(CeedTensorContract contract,Ceed * ceed)1482f86a920SJeremy L Thompson int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
149b0f67a9cSJeremy L Thompson   CeedCall(CeedObjectGetCeed((CeedObject)contract, ceed));
150e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1517a982d89SJeremy L. Thompson }
1522f86a920SJeremy L Thompson 
1532f86a920SJeremy L Thompson /**
1546e536b99SJeremy L Thompson   @brief Return the `Ceed` associated with a `CeedTensorContract`
1556e536b99SJeremy L Thompson 
1566e536b99SJeremy L Thompson   @param[in]  contract `CeedTensorContract`
1576e536b99SJeremy L Thompson 
1586e536b99SJeremy L Thompson   @return `Ceed` associated with `contract`
1596e536b99SJeremy L Thompson 
1606e536b99SJeremy L Thompson   @ref Backend
1616e536b99SJeremy L Thompson **/
CeedTensorContractReturnCeed(CeedTensorContract contract)162b0f67a9cSJeremy L Thompson Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return CeedObjectReturnCeed((CeedObject)contract); }
1636e536b99SJeremy L Thompson 
1646e536b99SJeremy L Thompson /**
165ca94c3ddSJeremy L Thompson   @brief Get backend data of a `CeedTensorContract`
1662f86a920SJeremy L Thompson 
167ca94c3ddSJeremy L Thompson   @param[in]  contract `CeedTensorContract`
1682f86a920SJeremy L Thompson   @param[out] data     Variable to store data
1692f86a920SJeremy L Thompson 
1702f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1712f86a920SJeremy L Thompson 
1727a982d89SJeremy L. Thompson   @ref Backend
1732f86a920SJeremy L Thompson **/
CeedTensorContractGetData(CeedTensorContract contract,void * data)174777ff853SJeremy L Thompson int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
175777ff853SJeremy L Thompson   *(void **)data = contract->data;
176e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1772f86a920SJeremy L Thompson }
1782f86a920SJeremy L Thompson 
1792f86a920SJeremy L Thompson /**
180ca94c3ddSJeremy L Thompson   @brief Set backend data of a `CeedTensorContract`
1812f86a920SJeremy L Thompson 
182ca94c3ddSJeremy L Thompson   @param[in,out] contract `CeedTensorContract`
183ea61e9acSJeremy L Thompson   @param[in]     data     Data to set
1842f86a920SJeremy L Thompson 
1852f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1862f86a920SJeremy L Thompson 
1877a982d89SJeremy L. Thompson   @ref Backend
1882f86a920SJeremy L Thompson **/
CeedTensorContractSetData(CeedTensorContract contract,void * data)189777ff853SJeremy L Thompson int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
190777ff853SJeremy L Thompson   contract->data = data;
191e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1922f86a920SJeremy L Thompson }
1932f86a920SJeremy L Thompson 
1942f86a920SJeremy L Thompson /**
195ca94c3ddSJeremy L Thompson   @brief Increment the reference counter for a `CeedTensorContract`
19634359f16Sjeremylt 
197ca94c3ddSJeremy L Thompson   @param[in,out] contract `CeedTensorContract` to increment the reference counter
19834359f16Sjeremylt 
19934359f16Sjeremylt   @return An error code: 0 - success, otherwise - failure
20034359f16Sjeremylt 
20134359f16Sjeremylt   @ref Backend
20234359f16Sjeremylt **/
CeedTensorContractReference(CeedTensorContract contract)2039560d06aSjeremylt int CeedTensorContractReference(CeedTensorContract contract) {
204b0f67a9cSJeremy L Thompson   CeedCall(CeedObjectReference((CeedObject)contract));
20534359f16Sjeremylt   return CEED_ERROR_SUCCESS;
20634359f16Sjeremylt }
20734359f16Sjeremylt 
20834359f16Sjeremylt /**
209ca94c3ddSJeremy L Thompson   @brief Copy the pointer to a `CeedTensorContract`.
210585a562dSJeremy L Thompson 
211ca94c3ddSJeremy L Thompson   Both pointers should be destroyed with @ref CeedTensorContractDestroy().
212585a562dSJeremy L Thompson 
213ca94c3ddSJeremy L Thompson   Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`.
214ca94c3ddSJeremy L Thompson         This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`.
215585a562dSJeremy L Thompson 
216ca94c3ddSJeremy L Thompson   @param[in]     tensor      `CeedTensorContract` to copy reference to
217585a562dSJeremy L Thompson   @param[in,out] tensor_copy Variable to store copied reference
218585a562dSJeremy L Thompson 
219585a562dSJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
220585a562dSJeremy L Thompson 
221585a562dSJeremy L Thompson   @ref User
222585a562dSJeremy L Thompson **/
CeedTensorContractReferenceCopy(CeedTensorContract tensor,CeedTensorContract * tensor_copy)223585a562dSJeremy L Thompson int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
224585a562dSJeremy L Thompson   CeedCall(CeedTensorContractReference(tensor));
225585a562dSJeremy L Thompson   CeedCall(CeedTensorContractDestroy(tensor_copy));
226585a562dSJeremy L Thompson   *tensor_copy = tensor;
227585a562dSJeremy L Thompson   return CEED_ERROR_SUCCESS;
228585a562dSJeremy L Thompson }
229585a562dSJeremy L Thompson 
230585a562dSJeremy L Thompson /**
231ca94c3ddSJeremy L Thompson   @brief Destroy a `CeedTensorContract`
2322f86a920SJeremy L Thompson 
233ca94c3ddSJeremy L Thompson   @param[in,out] contract `CeedTensorContract` to destroy
2342f86a920SJeremy L Thompson 
2352f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
2362f86a920SJeremy L Thompson 
2377a982d89SJeremy L. Thompson   @ref Backend
2382f86a920SJeremy L Thompson **/
CeedTensorContractDestroy(CeedTensorContract * contract)2392f86a920SJeremy L Thompson int CeedTensorContractDestroy(CeedTensorContract *contract) {
240b0f67a9cSJeremy L Thompson   if (!*contract || CeedObjectDereference((CeedObject)*contract) > 0) {
241ad6481ceSJeremy L Thompson     *contract = NULL;
242ad6481ceSJeremy L Thompson     return CEED_ERROR_SUCCESS;
243ad6481ceSJeremy L Thompson   }
2442f86a920SJeremy L Thompson   if ((*contract)->Destroy) {
2452b730f8bSJeremy L Thompson     CeedCall((*contract)->Destroy(*contract));
2462f86a920SJeremy L Thompson   }
247*6c328a79SJeremy L Thompson   CeedCall(CeedObjectDestroy_Private(&(*contract)->obj));
2482b730f8bSJeremy L Thompson   CeedCall(CeedFree(contract));
249e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2507a982d89SJeremy L. Thompson }
2512f86a920SJeremy L Thompson 
/// @}
253