xref: /libCEED/rust/libceed-sys/c-src/interface/ceed-tensor.c (revision c52157525785c51910c8fadbbfd077683f26126a)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12 
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15 
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Backend API
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedBasisBackend
20 /// @{
21 
22 /**
23   @brief Create a CeedTensorContract object for a CeedBasis
24 
25   @param[in]  ceed     Ceed object where the CeedTensorContract will be created
26   @param[in]  basis    CeedBasis for which the tensor contraction will be used
27   @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored.
28 
29   @return An error code: 0 - success, otherwise - failure
30 
31   @ref Backend
32 **/
33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) {
34   if (!ceed->TensorContractCreate) {
35     Ceed delegate;
36 
37     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
38     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
39     CeedCall(CeedTensorContractCreate(delegate, basis, contract));
40     return CEED_ERROR_SUCCESS;
41   }
42 
43   CeedCall(CeedCalloc(1, contract));
44 
45   (*contract)->ceed = ceed;
46   CeedCall(CeedReference(ceed));
47   CeedCall(ceed->TensorContractCreate(basis, *contract));
48   return CEED_ERROR_SUCCESS;
49 }
50 
51 /**
52   @brief Apply tensor contraction
53 
54   Contracts on the middle index
55   NOTRANSPOSE: v_ajc = t_jb u_abc
56   TRANSPOSE:   v_ajc = t_bj u_abc
57   If add != 0, "=" is replaced by "+="
58 
59   @param[in]  contract CeedTensorContract to use
60   @param[in]  A        First index of u, v
61   @param[in]  B        Middle index of u, one index of t
62   @param[in]  C        Last index of u, v
63   @param[in]  J        Middle index of v, one index of t
64   @param[in]  t        Tensor array to contract against
65   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
66   @param[in]  add      Add mode
67   @param[in]  u        Input array
68   @param[out] v        Output array
69 
70   @return An error code: 0 - success, otherwise - failure
71 
72   @ref Backend
73 **/
74 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
75                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
76   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
77   return CEED_ERROR_SUCCESS;
78 }
79 
80 /**
81   @brief Apply tensor contraction
82 
83     Contracts on the middle index
84     NOTRANSPOSE: v_dajc = t_djb u_abc
85     TRANSPOSE:   v_ajc  = t_dbj u_dabc
86     If add != 0, "=" is replaced by "+="
87 
88   @param[in]  contract CeedTensorContract to use
89   @param[in]  A        First index of u, second index of v
90   @param[in]  B        Middle index of u, one of last two indices of t
91   @param[in]  C        Last index of u, v
92   @param[in]  D        First index of v, first index of t
93   @param[in]  J        Third index of v, one of last two indices of t
94   @param[in]  t        Tensor array to contract against
95   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
96   @param[in]  add      Add mode
97   @param[in]  u        Input array
98   @param[out] v        Output array
99 
100   @return An error code: 0 - success, otherwise - failure
101 
102   @ref Backend
103 **/
104 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
105                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
106   if (t_mode == CEED_TRANSPOSE) {
107     for (CeedInt d = 0; d < D; d++) {
108       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
109     }
110   } else {
111     for (CeedInt d = 0; d < D; d++) {
112       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
113     }
114   }
115   return CEED_ERROR_SUCCESS;
116 }
117 
118 /**
119   @brief Get Ceed associated with a CeedTensorContract
120 
121   @param[in]  contract CeedTensorContract
122   @param[out] ceed     Variable to store Ceed
123 
124   @return An error code: 0 - success, otherwise - failure
125 
126   @ref Backend
127 **/
128 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
129   *ceed = contract->ceed;
130   return CEED_ERROR_SUCCESS;
131 }
132 
133 /**
134   @brief Get backend data of a CeedTensorContract
135 
136   @param[in]  contract CeedTensorContract
137   @param[out] data     Variable to store data
138 
139   @return An error code: 0 - success, otherwise - failure
140 
141   @ref Backend
142 **/
143 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
144   *(void **)data = contract->data;
145   return CEED_ERROR_SUCCESS;
146 }
147 
148 /**
149   @brief Set backend data of a CeedTensorContract
150 
151   @param[in,out] contract CeedTensorContract
152   @param[in]     data     Data to set
153 
154   @return An error code: 0 - success, otherwise - failure
155 
156   @ref Backend
157 **/
158 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
159   contract->data = data;
160   return CEED_ERROR_SUCCESS;
161 }
162 
163 /**
164   @brief Increment the reference counter for a CeedTensorContract
165 
166   @param[in,out] contract CeedTensorContract to increment the reference counter
167 
168   @return An error code: 0 - success, otherwise - failure
169 
170   @ref Backend
171 **/
172 int CeedTensorContractReference(CeedTensorContract contract) {
173   contract->ref_count++;
174   return CEED_ERROR_SUCCESS;
175 }
176 
177 /**
178   @brief Destroy a CeedTensorContract
179 
180   @param[in,out] contract CeedTensorContract to destroy
181 
182   @return An error code: 0 - success, otherwise - failure
183 
184   @ref Backend
185 **/
186 int CeedTensorContractDestroy(CeedTensorContract *contract) {
187   if (!*contract || --(*contract)->ref_count > 0) {
188     *contract = NULL;
189     return CEED_ERROR_SUCCESS;
190   }
191   if ((*contract)->Destroy) {
192     CeedCall((*contract)->Destroy(*contract));
193   }
194   CeedCall(CeedDestroy(&(*contract)->ceed));
195   CeedCall(CeedFree(contract));
196   return CEED_ERROR_SUCCESS;
197 }
198 
199 /// @}
200