xref: /libCEED/interface/ceed-tensor.c (revision c9d5affad74485f8d1e55e6be07e3d9f76bd4cae)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12 
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15 
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Backend API
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedBasisBackend
20 /// @{
21 
22 /**
23   @brief Create a `CeedTensorContract` object for a `CeedBasis`
24 
25   @param[in]  ceed     `Ceed` object used to create the `CeedTensorContract`
26   @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored.
27 
28   @return An error code: 0 - success, otherwise - failure
29 
30   @ref Backend
31 **/
32 int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
33   if (!ceed->TensorContractCreate) {
34     Ceed delegate;
35 
36     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
37     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement CeedTensorContractCreate");
38     CeedCall(CeedTensorContractCreate(delegate, contract));
39     CeedCall(CeedDestroy(&delegate));
40     return CEED_ERROR_SUCCESS;
41   }
42 
43   CeedCall(CeedCalloc(1, contract));
44   CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed));
45   CeedCall(ceed->TensorContractCreate(*contract));
46   return CEED_ERROR_SUCCESS;
47 }
48 
49 /**
50   @brief Apply tensor contraction
51 
52   Contracts on the middle index
53   NOTRANSPOSE: `v_ajc = t_jb u_abc`
54   TRANSPOSE:   `v_ajc = t_bj u_abc`
55   If `add != 0`, `=` is replaced by `+=`
56 
57   @param[in]  contract `CeedTensorContract` to use
58   @param[in]  A        First index of `u`, `v`
59   @param[in]  B        Middle index of `u`, one index of `t`
60   @param[in]  C        Last index of `u`, `v`
61   @param[in]  J        Middle index of `v`, one index of `t`
62   @param[in]  t        Tensor array to contract against
63   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj`
64   @param[in]  add      Add mode
65   @param[in]  u        Input array
66   @param[out] v        Output array
67 
68   @return An error code: 0 - success, otherwise - failure
69 
70   @ref Backend
71 **/
72 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
73                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
74   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
75   return CEED_ERROR_SUCCESS;
76 }
77 
78 /**
79   @brief Apply tensor contraction
80 
81   Contracts on the middle index
82   NOTRANSPOSE: `v_dajc = t_djb u_abc`
83   TRANSPOSE:   `v_ajc  = t_dbj u_dabc`
84   If `add != 0`, `=` is replaced by `+=`
85 
86   @param[in]  contract `CeedTensorContract` to use
87   @param[in]  A        First index of `u`, second index of `v`
88   @param[in]  B        Middle index of `u`, one of last two indices of `t`
89   @param[in]  C        Last index of `u`, `v`
90   @param[in]  D        First index of `v`, first index of `t`
91   @param[in]  J        Third index of `v`, one of last two indices of `t`
92   @param[in]  t        Tensor array to contract against
93   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj`
94   @param[in]  add      Add mode
95   @param[in]  u        Input array
96   @param[out] v        Output array
97 
98   @return An error code: 0 - success, otherwise - failure
99 
100   @ref Backend
101 **/
102 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
103                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
104   if (t_mode == CEED_TRANSPOSE) {
105     for (CeedInt d = 0; d < D; d++) {
106       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
107     }
108   } else {
109     for (CeedInt d = 0; d < D; d++) {
110       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
111     }
112   }
113   return CEED_ERROR_SUCCESS;
114 }
115 
116 /**
117   @brief Get the `Ceed` associated with a `CeedTensorContract`
118 
119   @param[in]  contract `CeedTensorContract`
120   @param[out] ceed     Variable to store `Ceed`
121 
122   @return An error code: 0 - success, otherwise - failure
123 
124   @ref Backend
125 **/
126 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
127   *ceed = NULL;
128   CeedCall(CeedReferenceCopy(CeedTensorContractReturnCeed(contract), ceed));
129   return CEED_ERROR_SUCCESS;
130 }
131 
132 /**
133   @brief Return the `Ceed` associated with a `CeedTensorContract`
134 
135   @param[in]  contract `CeedTensorContract`
136 
137   @return `Ceed` associated with `contract`
138 
139   @ref Backend
140 **/
141 Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return contract->ceed; }
142 
143 /**
144   @brief Get backend data of a `CeedTensorContract`
145 
146   @param[in]  contract `CeedTensorContract`
147   @param[out] data     Variable to store data
148 
149   @return An error code: 0 - success, otherwise - failure
150 
151   @ref Backend
152 **/
153 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
154   *(void **)data = contract->data;
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 /**
159   @brief Set backend data of a `CeedTensorContract`
160 
161   @param[in,out] contract `CeedTensorContract`
162   @param[in]     data     Data to set
163 
164   @return An error code: 0 - success, otherwise - failure
165 
166   @ref Backend
167 **/
168 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
169   contract->data = data;
170   return CEED_ERROR_SUCCESS;
171 }
172 
173 /**
174   @brief Increment the reference counter for a `CeedTensorContract`
175 
176   @param[in,out] contract `CeedTensorContract` to increment the reference counter
177 
178   @return An error code: 0 - success, otherwise - failure
179 
180   @ref Backend
181 **/
182 int CeedTensorContractReference(CeedTensorContract contract) {
183   contract->ref_count++;
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 /**
188   @brief Copy the pointer to a `CeedTensorContract`.
189 
190   Both pointers should be destroyed with @ref CeedTensorContractDestroy().
191 
192   Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`.
193         This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`.
194 
195   @param[in]     tensor      `CeedTensorContract` to copy reference to
196   @param[in,out] tensor_copy Variable to store copied reference
197 
198   @return An error code: 0 - success, otherwise - failure
199 
200   @ref User
201 **/
202 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
203   CeedCall(CeedTensorContractReference(tensor));
204   CeedCall(CeedTensorContractDestroy(tensor_copy));
205   *tensor_copy = tensor;
206   return CEED_ERROR_SUCCESS;
207 }
208 
209 /**
210   @brief Destroy a `CeedTensorContract`
211 
212   @param[in,out] contract `CeedTensorContract` to destroy
213 
214   @return An error code: 0 - success, otherwise - failure
215 
216   @ref Backend
217 **/
218 int CeedTensorContractDestroy(CeedTensorContract *contract) {
219   if (!*contract || --(*contract)->ref_count > 0) {
220     *contract = NULL;
221     return CEED_ERROR_SUCCESS;
222   }
223   if ((*contract)->Destroy) {
224     CeedCall((*contract)->Destroy(*contract));
225   }
226   CeedCall(CeedDestroy(&(*contract)->ceed));
227   CeedCall(CeedFree(contract));
228   return CEED_ERROR_SUCCESS;
229 }
230 
231 /// @}
232