xref: /libCEED/interface/ceed-tensor.c (revision c4e3f59b2ea5a0c95cc0118aa5026c447cce3092)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
72f86a920SJeremy L Thompson 
83d576824SJeremy L Thompson #include <ceed-impl.h>
949aac155SJeremy L Thompson #include <ceed.h>
102b730f8bSJeremy L Thompson #include <ceed/backend.h>
1149aac155SJeremy L Thompson #include <stddef.h>
122f86a920SJeremy L Thompson 
/// @file
/// Implementation of CeedTensorContract interfaces

/// ----------------------------------------------------------------------------
/// CeedTensorContract Backend API
/// ----------------------------------------------------------------------------
/// @addtogroup CeedBasisBackend
/// @{
212f86a920SJeremy L Thompson 
222f86a920SJeremy L Thompson /**
232f86a920SJeremy L Thompson   @brief Create a CeedTensorContract object for a CeedBasis
242f86a920SJeremy L Thompson 
25ea61e9acSJeremy L Thompson   @param[in]  ceed     Ceed object where the CeedTensorContract will be created
26ea61e9acSJeremy L Thompson   @param[in]  basis    CeedBasis for which the tensor contraction will be used
27ea61e9acSJeremy L Thompson   @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored.
282f86a920SJeremy L Thompson 
292f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
302f86a920SJeremy L Thompson 
317a982d89SJeremy L. Thompson   @ref Backend
322f86a920SJeremy L Thompson **/
332b730f8bSJeremy L Thompson int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) {
342f86a920SJeremy L Thompson   if (!ceed->TensorContractCreate) {
352f86a920SJeremy L Thompson     Ceed delegate;
362b730f8bSJeremy L Thompson     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
372f86a920SJeremy L Thompson 
382b730f8bSJeremy L Thompson     if (!delegate) {
39c042f62fSJeremy L Thompson       // LCOV_EXCL_START
402b730f8bSJeremy L Thompson       return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
41c042f62fSJeremy L Thompson       // LCOV_EXCL_STOP
422b730f8bSJeremy L Thompson     }
432f86a920SJeremy L Thompson 
442b730f8bSJeremy L Thompson     CeedCall(CeedTensorContractCreate(delegate, basis, contract));
45e15f9bd0SJeremy L Thompson     return CEED_ERROR_SUCCESS;
462f86a920SJeremy L Thompson   }
472f86a920SJeremy L Thompson 
482b730f8bSJeremy L Thompson   CeedCall(CeedCalloc(1, contract));
492f86a920SJeremy L Thompson 
502f86a920SJeremy L Thompson   (*contract)->ceed = ceed;
512b730f8bSJeremy L Thompson   CeedCall(CeedReference(ceed));
522b730f8bSJeremy L Thompson   CeedCall(ceed->TensorContractCreate(basis, *contract));
53e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
547a982d89SJeremy L. Thompson }
552f86a920SJeremy L Thompson 
562f86a920SJeremy L Thompson /**
572f86a920SJeremy L Thompson   @brief Apply tensor contraction
582f86a920SJeremy L Thompson 
592f86a920SJeremy L Thompson     Contracts on the middle index
602f86a920SJeremy L Thompson     NOTRANSPOSE: v_ajc = t_jb u_abc
612f86a920SJeremy L Thompson     TRANSPOSE:   v_ajc = t_bj u_abc
622f86a920SJeremy L Thompson     If add != 0, "=" is replaced by "+="
632f86a920SJeremy L Thompson 
64ea61e9acSJeremy L Thompson   @param[in]  contract CeedTensorContract to use
65ea61e9acSJeremy L Thompson   @param[in]  A        First index of u, v
66ea61e9acSJeremy L Thompson   @param[in]  B        Middle index of u, one index of t
67ea61e9acSJeremy L Thompson   @param[in]  C        Last index of u, v
68ea61e9acSJeremy L Thompson   @param[in]  J        Middle index of v, one index of t
692f86a920SJeremy L Thompson   @param[in]  t        Tensor array to contract against
70ea61e9acSJeremy L Thompson   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
71ea61e9acSJeremy L Thompson   @param[in]  add      Add mode
722f86a920SJeremy L Thompson   @param[in]  u        Input array
732f86a920SJeremy L Thompson   @param[out] v        Output array
742f86a920SJeremy L Thompson 
752f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
762f86a920SJeremy L Thompson 
777a982d89SJeremy L. Thompson   @ref Backend
782f86a920SJeremy L Thompson **/
792b730f8bSJeremy L Thompson int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
802b730f8bSJeremy L Thompson                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
812b730f8bSJeremy L Thompson   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
82e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
837a982d89SJeremy L. Thompson }
842f86a920SJeremy L Thompson 
852f86a920SJeremy L Thompson /**
86*c4e3f59bSSebastian Grimberg   @brief Apply tensor contraction
87*c4e3f59bSSebastian Grimberg 
88*c4e3f59bSSebastian Grimberg     Contracts on the middle index
89*c4e3f59bSSebastian Grimberg     NOTRANSPOSE: v_dajc = t_djb u_abc
90*c4e3f59bSSebastian Grimberg     TRANSPOSE:   v_ajc  = t_dbj u_dabc
91*c4e3f59bSSebastian Grimberg     If add != 0, "=" is replaced by "+="
92*c4e3f59bSSebastian Grimberg 
93*c4e3f59bSSebastian Grimberg   @param[in]  contract CeedTensorContract to use
94*c4e3f59bSSebastian Grimberg   @param[in]  A        First index of u, second index of v
95*c4e3f59bSSebastian Grimberg   @param[in]  B        Middle index of u, one of last two indices of t
96*c4e3f59bSSebastian Grimberg   @param[in]  C        Last index of u, v
97*c4e3f59bSSebastian Grimberg   @param[in]  D        First index of v, first index of t
98*c4e3f59bSSebastian Grimberg   @param[in]  J        Third index of v, one of last two indices of t
99*c4e3f59bSSebastian Grimberg   @param[in]  t        Tensor array to contract against
100*c4e3f59bSSebastian Grimberg   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
101*c4e3f59bSSebastian Grimberg   @param[in]  add      Add mode
102*c4e3f59bSSebastian Grimberg   @param[in]  u        Input array
103*c4e3f59bSSebastian Grimberg   @param[out] v        Output array
104*c4e3f59bSSebastian Grimberg 
105*c4e3f59bSSebastian Grimberg   @return An error code: 0 - success, otherwise - failure
106*c4e3f59bSSebastian Grimberg 
107*c4e3f59bSSebastian Grimberg   @ref Backend
108*c4e3f59bSSebastian Grimberg **/
109*c4e3f59bSSebastian Grimberg int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
110*c4e3f59bSSebastian Grimberg                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
111*c4e3f59bSSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
112*c4e3f59bSSebastian Grimberg     for (CeedInt d = 0; d < D; d++) {
113*c4e3f59bSSebastian Grimberg       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
114*c4e3f59bSSebastian Grimberg     }
115*c4e3f59bSSebastian Grimberg   } else {
116*c4e3f59bSSebastian Grimberg     for (CeedInt d = 0; d < D; d++) {
117*c4e3f59bSSebastian Grimberg       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
118*c4e3f59bSSebastian Grimberg     }
119*c4e3f59bSSebastian Grimberg   }
120*c4e3f59bSSebastian Grimberg   return CEED_ERROR_SUCCESS;
121*c4e3f59bSSebastian Grimberg }
122*c4e3f59bSSebastian Grimberg 
123*c4e3f59bSSebastian Grimberg /**
1242f86a920SJeremy L Thompson   @brief Get Ceed associated with a CeedTensorContract
1252f86a920SJeremy L Thompson 
126ea61e9acSJeremy L Thompson   @param[in]  contract CeedTensorContract
1272f86a920SJeremy L Thompson   @param[out] ceed     Variable to store Ceed
1282f86a920SJeremy L Thompson 
1292f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1302f86a920SJeremy L Thompson 
1317a982d89SJeremy L. Thompson   @ref Backend
1322f86a920SJeremy L Thompson **/
1332f86a920SJeremy L Thompson int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
1342f86a920SJeremy L Thompson   *ceed = contract->ceed;
135e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1367a982d89SJeremy L. Thompson }
1372f86a920SJeremy L Thompson 
1382f86a920SJeremy L Thompson /**
1392f86a920SJeremy L Thompson   @brief Get backend data of a CeedTensorContract
1402f86a920SJeremy L Thompson 
141ea61e9acSJeremy L Thompson   @param[in]  contract CeedTensorContract
1422f86a920SJeremy L Thompson   @param[out] data     Variable to store data
1432f86a920SJeremy L Thompson 
1442f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1452f86a920SJeremy L Thompson 
1467a982d89SJeremy L. Thompson   @ref Backend
1472f86a920SJeremy L Thompson **/
148777ff853SJeremy L Thompson int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
149777ff853SJeremy L Thompson   *(void **)data = contract->data;
150e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1512f86a920SJeremy L Thompson }
1522f86a920SJeremy L Thompson 
1532f86a920SJeremy L Thompson /**
1542f86a920SJeremy L Thompson   @brief Set backend data of a CeedTensorContract
1552f86a920SJeremy L Thompson 
156ea61e9acSJeremy L Thompson   @param[in,out] contract CeedTensorContract
157ea61e9acSJeremy L Thompson   @param[in]     data     Data to set
1582f86a920SJeremy L Thompson 
1592f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1602f86a920SJeremy L Thompson 
1617a982d89SJeremy L. Thompson   @ref Backend
1622f86a920SJeremy L Thompson **/
163777ff853SJeremy L Thompson int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
164777ff853SJeremy L Thompson   contract->data = data;
165e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1662f86a920SJeremy L Thompson }
1672f86a920SJeremy L Thompson 
1682f86a920SJeremy L Thompson /**
16934359f16Sjeremylt   @brief Increment the reference counter for a CeedTensorContract
17034359f16Sjeremylt 
171ea61e9acSJeremy L Thompson   @param[in,out] contract CeedTensorContract to increment the reference counter
17234359f16Sjeremylt 
17334359f16Sjeremylt   @return An error code: 0 - success, otherwise - failure
17434359f16Sjeremylt 
17534359f16Sjeremylt   @ref Backend
17634359f16Sjeremylt **/
1779560d06aSjeremylt int CeedTensorContractReference(CeedTensorContract contract) {
17834359f16Sjeremylt   contract->ref_count++;
17934359f16Sjeremylt   return CEED_ERROR_SUCCESS;
18034359f16Sjeremylt }
18134359f16Sjeremylt 
18234359f16Sjeremylt /**
1832f86a920SJeremy L Thompson   @brief Destroy a CeedTensorContract
1842f86a920SJeremy L Thompson 
185ea61e9acSJeremy L Thompson   @param[in,out] contract CeedTensorContract to destroy
1862f86a920SJeremy L Thompson 
1872f86a920SJeremy L Thompson   @return An error code: 0 - success, otherwise - failure
1882f86a920SJeremy L Thompson 
1897a982d89SJeremy L. Thompson   @ref Backend
1902f86a920SJeremy L Thompson **/
1912f86a920SJeremy L Thompson int CeedTensorContractDestroy(CeedTensorContract *contract) {
192ad6481ceSJeremy L Thompson   if (!*contract || --(*contract)->ref_count > 0) {
193ad6481ceSJeremy L Thompson     *contract = NULL;
194ad6481ceSJeremy L Thompson     return CEED_ERROR_SUCCESS;
195ad6481ceSJeremy L Thompson   }
1962f86a920SJeremy L Thompson   if ((*contract)->Destroy) {
1972b730f8bSJeremy L Thompson     CeedCall((*contract)->Destroy(*contract));
1982f86a920SJeremy L Thompson   }
1992b730f8bSJeremy L Thompson   CeedCall(CeedDestroy(&(*contract)->ceed));
2002b730f8bSJeremy L Thompson   CeedCall(CeedFree(contract));
201e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2027a982d89SJeremy L. Thompson }
2032f86a920SJeremy L Thompson 
/// @}