xref: /libCEED/rust/libceed-sys/c-src/interface/ceed-tensor.c (revision 9bc663991d6482bcb1d60b1f116148f11db83fa1)
15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
32f86a920SJeremy L Thompson //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
52f86a920SJeremy L Thompson //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
72f86a920SJeremy L Thompson 
83d576824SJeremy L Thompson #include <ceed-impl.h>
949aac155SJeremy L Thompson #include <ceed.h>
102b730f8bSJeremy L Thompson #include <ceed/backend.h>
1149aac155SJeremy L Thompson #include <stddef.h>
122f86a920SJeremy L Thompson 
132f86a920SJeremy L Thompson /// @file
147a982d89SJeremy L. Thompson /// Implementation of CeedTensorContract interfaces
157a982d89SJeremy L. Thompson 
167a982d89SJeremy L. Thompson /// ----------------------------------------------------------------------------
177a982d89SJeremy L. Thompson /// CeedTensorContract Backend API
187a982d89SJeremy L. Thompson /// ----------------------------------------------------------------------------
197a982d89SJeremy L. Thompson /// @addtogroup CeedBasisBackend
202f86a920SJeremy L Thompson /// @{
212f86a920SJeremy L Thompson 
222f86a920SJeremy L Thompson /**
23ca94c3ddSJeremy L Thompson   @brief Create a `CeedTensorContract` object for a `CeedBasis`
242f86a920SJeremy L Thompson 
25ca94c3ddSJeremy L Thompson   @param[in]  ceed     `Ceed` object used to create the `CeedTensorContract`
26ca94c3ddSJeremy L Thompson   @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored.
272f86a920SJeremy L Thompson 
282f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
292f86a920SJeremy L Thompson 
307a982d89SJeremy L. Thompson   @ref Backend
312f86a920SJeremy L Thompson **/
32a71faab1SSebastian Grimberg int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
332f86a920SJeremy L Thompson   if (!ceed->TensorContractCreate) {
342f86a920SJeremy L Thompson     Ceed delegate;
356574a04fSJeremy L Thompson 
362b730f8bSJeremy L Thompson     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
371ef3a2a9SJeremy L Thompson     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement CeedTensorContractCreate");
38a71faab1SSebastian Grimberg     CeedCall(CeedTensorContractCreate(delegate, contract));
39*9bc66399SJeremy L Thompson     CeedCall(CeedDestroy(&delegate));
40e15f9bd0SJeremy L Thompson     return CEED_ERROR_SUCCESS;
412f86a920SJeremy L Thompson   }
422f86a920SJeremy L Thompson 
432b730f8bSJeremy L Thompson   CeedCall(CeedCalloc(1, contract));
44db002c03SJeremy L Thompson   CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed));
45a71faab1SSebastian Grimberg   CeedCall(ceed->TensorContractCreate(*contract));
46e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
477a982d89SJeremy L. Thompson }
482f86a920SJeremy L Thompson 
492f86a920SJeremy L Thompson /**
502f86a920SJeremy L Thompson   @brief Apply tensor contraction
512f86a920SJeremy L Thompson 
522f86a920SJeremy L Thompson   Contracts on the middle index
53ca94c3ddSJeremy L Thompson   NOTRANSPOSE: `v_ajc = t_jb u_abc`
54ca94c3ddSJeremy L Thompson   TRANSPOSE:   `v_ajc = t_bj u_abc`
55ca94c3ddSJeremy L Thompson   If `add != 0`, `=` is replaced by `+=`
562f86a920SJeremy L Thompson 
57ca94c3ddSJeremy L Thompson   @param[in]  contract `CeedTensorContract` to use
58ca94c3ddSJeremy L Thompson   @param[in]  A        First index of `u`, `v`
59ca94c3ddSJeremy L Thompson   @param[in]  B        Middle index of `u`, one index of `t`
60ca94c3ddSJeremy L Thompson   @param[in]  C        Last index of `u`, `v`
61ca94c3ddSJeremy L Thompson   @param[in]  J        Middle index of `v`, one index of `t`
622f86a920SJeremy L Thompson   @param[in]  t        Tensor array to contract against
63ca94c3ddSJeremy L Thompson   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj`
64ea61e9acSJeremy L Thompson   @param[in]  add      Add mode
652f86a920SJeremy L Thompson   @param[in]  u        Input array
662f86a920SJeremy L Thompson   @param[out] v        Output array
672f86a920SJeremy L Thompson 
682f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
692f86a920SJeremy L Thompson 
707a982d89SJeremy L. Thompson   @ref Backend
712f86a920SJeremy L Thompson **/
722b730f8bSJeremy L Thompson int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
732b730f8bSJeremy L Thompson                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
742b730f8bSJeremy L Thompson   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
75e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
767a982d89SJeremy L. Thompson }
772f86a920SJeremy L Thompson 
782f86a920SJeremy L Thompson /**
79c4e3f59bSSebastian Grimberg   @brief Apply tensor contraction
80c4e3f59bSSebastian Grimberg 
81c4e3f59bSSebastian Grimberg   Contracts on the middle index
82ca94c3ddSJeremy L Thompson   NOTRANSPOSE: `v_dajc = t_djb u_abc`
83ca94c3ddSJeremy L Thompson   TRANSPOSE:   `v_ajc  = t_dbj u_dabc`
84ca94c3ddSJeremy L Thompson   If `add != 0`, `=` is replaced by `+=`
85c4e3f59bSSebastian Grimberg 
86ca94c3ddSJeremy L Thompson   @param[in]  contract `CeedTensorContract` to use
87ca94c3ddSJeremy L Thompson   @param[in]  A        First index of `u`, second index of `v`
88ca94c3ddSJeremy L Thompson   @param[in]  B        Middle index of `u`, one of last two indices of `t`
89ca94c3ddSJeremy L Thompson   @param[in]  C        Last index of `u`, `v`
90ca94c3ddSJeremy L Thompson   @param[in]  D        First index of `v`, first index of `t`
91ca94c3ddSJeremy L Thompson   @param[in]  J        Third index of `v`, one of last two indices of `t`
92c4e3f59bSSebastian Grimberg   @param[in]  t        Tensor array to contract against
93ca94c3ddSJeremy L Thompson   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj`
94c4e3f59bSSebastian Grimberg   @param[in]  add      Add mode
95c4e3f59bSSebastian Grimberg   @param[in]  u        Input array
96c4e3f59bSSebastian Grimberg   @param[out] v        Output array
97c4e3f59bSSebastian Grimberg 
98c4e3f59bSSebastian Grimberg   @return An error code: 0 - success, otherwise - failure
99c4e3f59bSSebastian Grimberg 
100c4e3f59bSSebastian Grimberg   @ref Backend
101c4e3f59bSSebastian Grimberg **/
102c4e3f59bSSebastian Grimberg int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
103c4e3f59bSSebastian Grimberg                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
104c4e3f59bSSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
105c4e3f59bSSebastian Grimberg     for (CeedInt d = 0; d < D; d++) {
106c4e3f59bSSebastian Grimberg       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
107c4e3f59bSSebastian Grimberg     }
108c4e3f59bSSebastian Grimberg   } else {
109c4e3f59bSSebastian Grimberg     for (CeedInt d = 0; d < D; d++) {
110c4e3f59bSSebastian Grimberg       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
111c4e3f59bSSebastian Grimberg     }
112c4e3f59bSSebastian Grimberg   }
113c4e3f59bSSebastian Grimberg   return CEED_ERROR_SUCCESS;
114c4e3f59bSSebastian Grimberg }
115c4e3f59bSSebastian Grimberg 
116c4e3f59bSSebastian Grimberg /**
1176e536b99SJeremy L Thompson   @brief Get the `Ceed` associated with a `CeedTensorContract`
1182f86a920SJeremy L Thompson 
119ca94c3ddSJeremy L Thompson   @param[in]  contract `CeedTensorContract`
120ca94c3ddSJeremy L Thompson   @param[out] ceed     Variable to store `Ceed`
1212f86a920SJeremy L Thompson 
1222f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1232f86a920SJeremy L Thompson 
1247a982d89SJeremy L. Thompson   @ref Backend
1252f86a920SJeremy L Thompson **/
1262f86a920SJeremy L Thompson int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
127*9bc66399SJeremy L Thompson   *ceed = NULL;
128*9bc66399SJeremy L Thompson   CeedCall(CeedReferenceCopy(CeedTensorContractReturnCeed(contract), ceed));
129e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1307a982d89SJeremy L. Thompson }
1312f86a920SJeremy L Thompson 
1322f86a920SJeremy L Thompson /**
1336e536b99SJeremy L Thompson   @brief Return the `Ceed` associated with a `CeedTensorContract`
1346e536b99SJeremy L Thompson 
1356e536b99SJeremy L Thompson   @param[in]  contract `CeedTensorContract`
1366e536b99SJeremy L Thompson 
1376e536b99SJeremy L Thompson   @return `Ceed` associated with `contract`
1386e536b99SJeremy L Thompson 
1396e536b99SJeremy L Thompson   @ref Backend
1406e536b99SJeremy L Thompson **/
1416e536b99SJeremy L Thompson Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return contract->ceed; }
1426e536b99SJeremy L Thompson 
1436e536b99SJeremy L Thompson /**
144ca94c3ddSJeremy L Thompson   @brief Get backend data of a `CeedTensorContract`
1452f86a920SJeremy L Thompson 
146ca94c3ddSJeremy L Thompson   @param[in]  contract `CeedTensorContract`
1472f86a920SJeremy L Thompson   @param[out] data     Variable to store data
1482f86a920SJeremy L Thompson 
1492f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1502f86a920SJeremy L Thompson 
1517a982d89SJeremy L. Thompson   @ref Backend
1522f86a920SJeremy L Thompson **/
153777ff853SJeremy L Thompson int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
154777ff853SJeremy L Thompson   *(void **)data = contract->data;
155e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1562f86a920SJeremy L Thompson }
1572f86a920SJeremy L Thompson 
1582f86a920SJeremy L Thompson /**
159ca94c3ddSJeremy L Thompson   @brief Set backend data of a `CeedTensorContract`
1602f86a920SJeremy L Thompson 
161ca94c3ddSJeremy L Thompson   @param[in,out] contract `CeedTensorContract`
162ea61e9acSJeremy L Thompson   @param[in]     data     Data to set
1632f86a920SJeremy L Thompson 
1642f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1652f86a920SJeremy L Thompson 
1667a982d89SJeremy L. Thompson   @ref Backend
1672f86a920SJeremy L Thompson **/
168777ff853SJeremy L Thompson int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
169777ff853SJeremy L Thompson   contract->data = data;
170e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1712f86a920SJeremy L Thompson }
1722f86a920SJeremy L Thompson 
1732f86a920SJeremy L Thompson /**
174ca94c3ddSJeremy L Thompson   @brief Increment the reference counter for a `CeedTensorContract`
17534359f16Sjeremylt 
176ca94c3ddSJeremy L Thompson   @param[in,out] contract `CeedTensorContract` to increment the reference counter
17734359f16Sjeremylt 
17834359f16Sjeremylt   @return An error code: 0 - success, otherwise - failure
17934359f16Sjeremylt 
18034359f16Sjeremylt   @ref Backend
18134359f16Sjeremylt **/
1829560d06aSjeremylt int CeedTensorContractReference(CeedTensorContract contract) {
18334359f16Sjeremylt   contract->ref_count++;
18434359f16Sjeremylt   return CEED_ERROR_SUCCESS;
18534359f16Sjeremylt }
18634359f16Sjeremylt 
18734359f16Sjeremylt /**
188ca94c3ddSJeremy L Thompson   @brief Copy the pointer to a `CeedTensorContract`.
189585a562dSJeremy L Thompson 
190ca94c3ddSJeremy L Thompson   Both pointers should be destroyed with @ref CeedTensorContractDestroy().
191585a562dSJeremy L Thompson 
192ca94c3ddSJeremy L Thompson   Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`.
193ca94c3ddSJeremy L Thompson         This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`.
194585a562dSJeremy L Thompson 
195ca94c3ddSJeremy L Thompson   @param[in]     tensor      `CeedTensorContract` to copy reference to
196585a562dSJeremy L Thompson   @param[in,out] tensor_copy Variable to store copied reference
197585a562dSJeremy L Thompson 
198585a562dSJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
199585a562dSJeremy L Thompson 
200585a562dSJeremy L Thompson   @ref User
201585a562dSJeremy L Thompson **/
202585a562dSJeremy L Thompson int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
203585a562dSJeremy L Thompson   CeedCall(CeedTensorContractReference(tensor));
204585a562dSJeremy L Thompson   CeedCall(CeedTensorContractDestroy(tensor_copy));
205585a562dSJeremy L Thompson   *tensor_copy = tensor;
206585a562dSJeremy L Thompson   return CEED_ERROR_SUCCESS;
207585a562dSJeremy L Thompson }
208585a562dSJeremy L Thompson 
209585a562dSJeremy L Thompson /**
210ca94c3ddSJeremy L Thompson   @brief Destroy a `CeedTensorContract`
2112f86a920SJeremy L Thompson 
212ca94c3ddSJeremy L Thompson   @param[in,out] contract `CeedTensorContract` to destroy
2132f86a920SJeremy L Thompson 
2142f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
2152f86a920SJeremy L Thompson 
2167a982d89SJeremy L. Thompson   @ref Backend
2172f86a920SJeremy L Thompson **/
2182f86a920SJeremy L Thompson int CeedTensorContractDestroy(CeedTensorContract *contract) {
219ad6481ceSJeremy L Thompson   if (!*contract || --(*contract)->ref_count > 0) {
220ad6481ceSJeremy L Thompson     *contract = NULL;
221ad6481ceSJeremy L Thompson     return CEED_ERROR_SUCCESS;
222ad6481ceSJeremy L Thompson   }
2232f86a920SJeremy L Thompson   if ((*contract)->Destroy) {
2242b730f8bSJeremy L Thompson     CeedCall((*contract)->Destroy(*contract));
2252f86a920SJeremy L Thompson   }
2262b730f8bSJeremy L Thompson   CeedCall(CeedDestroy(&(*contract)->ceed));
2272b730f8bSJeremy L Thompson   CeedCall(CeedFree(contract));
228e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2297a982d89SJeremy L. Thompson }
2302f86a920SJeremy L Thompson 
2312f86a920SJeremy L Thompson /// @}
232