1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Library Internal Functions
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedTensorContractDeveloper
20 /// @{
21
22 /**
23 @brief Destroy a `CeedTensorContract` passed as a `CeedObject`
24
25 @param[in,out] contract Address of `CeedTensorContract` to destroy
26
27 @return An error code: 0 - success, otherwise - failure
28
29 @ref Developer
30 **/
CeedTensorContractDestroy_Object(CeedObject * contract)31 static int CeedTensorContractDestroy_Object(CeedObject *contract) {
32 CeedCall(CeedTensorContractDestroy((CeedTensorContract *)contract));
33 return CEED_ERROR_SUCCESS;
34 }
35
36 /// @}
37
38 /// ----------------------------------------------------------------------------
39 /// CeedTensorContract Backend API
40 /// ----------------------------------------------------------------------------
41 /// @addtogroup CeedBasisBackend
42 /// @{
43
44 /**
45 @brief Create a `CeedTensorContract` object for a `CeedBasis`
46
47 @param[in] ceed `Ceed` object used to create the `CeedTensorContract`
48 @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored.
49
50 @return An error code: 0 - success, otherwise - failure
51
52 @ref Backend
53 **/
CeedTensorContractCreate(Ceed ceed,CeedTensorContract * contract)54 int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
55 if (!ceed->TensorContractCreate) {
56 Ceed delegate;
57
58 CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
59 CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement CeedTensorContractCreate");
60 CeedCall(CeedTensorContractCreate(delegate, contract));
61 CeedCall(CeedDestroy(&delegate));
62 return CEED_ERROR_SUCCESS;
63 }
64
65 CeedCall(CeedCalloc(1, contract));
66 CeedCall(CeedObjectCreate(ceed, NULL, CeedTensorContractDestroy_Object, &(*contract)->obj));
67 CeedCall(ceed->TensorContractCreate(*contract));
68 return CEED_ERROR_SUCCESS;
69 }
70
71 /**
72 @brief Apply tensor contraction
73
74 Contracts on the middle index
75 NOTRANSPOSE: `v_ajc = t_jb u_abc`
76 TRANSPOSE: `v_ajc = t_bj u_abc`
77 If `add != 0`, `=` is replaced by `+=`
78
79 @param[in] contract `CeedTensorContract` to use
80 @param[in] A First index of `u`, `v`
81 @param[in] B Middle index of `u`, one index of `t`
82 @param[in] C Last index of `u`, `v`
83 @param[in] J Middle index of `v`, one index of `t`
84 @param[in] t Tensor array to contract against
85 @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj`
86 @param[in] add Add mode
87 @param[in] u Input array
88 @param[out] v Output array
89
90 @return An error code: 0 - success, otherwise - failure
91
92 @ref Backend
93 **/
CeedTensorContractApply(CeedTensorContract contract,CeedInt A,CeedInt B,CeedInt C,CeedInt J,const CeedScalar * restrict t,CeedTransposeMode t_mode,const CeedInt add,const CeedScalar * restrict u,CeedScalar * restrict v)94 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
95 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
96 CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
97 return CEED_ERROR_SUCCESS;
98 }
99
100 /**
101 @brief Apply tensor contraction
102
103 Contracts on the middle index
104 NOTRANSPOSE: `v_dajc = t_djb u_abc`
105 TRANSPOSE: `v_ajc = t_dbj u_dabc`
106 If `add != 0`, `=` is replaced by `+=`
107
108 @param[in] contract `CeedTensorContract` to use
109 @param[in] A First index of `u`, second index of `v`
110 @param[in] B Middle index of `u`, one of last two indices of `t`
111 @param[in] C Last index of `u`, `v`
112 @param[in] D First index of `v`, first index of `t`
113 @param[in] J Third index of `v`, one of last two indices of `t`
114 @param[in] t Tensor array to contract against
115 @param[in] t_mode Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj`
116 @param[in] add Add mode
117 @param[in] u Input array
118 @param[out] v Output array
119
120 @return An error code: 0 - success, otherwise - failure
121
122 @ref Backend
123 **/
CeedTensorContractStridedApply(CeedTensorContract contract,CeedInt A,CeedInt B,CeedInt C,CeedInt D,CeedInt J,const CeedScalar * restrict t,CeedTransposeMode t_mode,const CeedInt add,const CeedScalar * restrict u,CeedScalar * restrict v)124 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
125 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
126 if (t_mode == CEED_TRANSPOSE) {
127 for (CeedInt d = 0; d < D; d++) {
128 CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
129 }
130 } else {
131 for (CeedInt d = 0; d < D; d++) {
132 CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
133 }
134 }
135 return CEED_ERROR_SUCCESS;
136 }
137
138 /**
139 @brief Get the `Ceed` associated with a `CeedTensorContract`
140
141 @param[in] contract `CeedTensorContract`
142 @param[out] ceed Variable to store `Ceed`
143
144 @return An error code: 0 - success, otherwise - failure
145
146 @ref Backend
147 **/
CeedTensorContractGetCeed(CeedTensorContract contract,Ceed * ceed)148 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
149 CeedCall(CeedObjectGetCeed((CeedObject)contract, ceed));
150 return CEED_ERROR_SUCCESS;
151 }
152
153 /**
154 @brief Return the `Ceed` associated with a `CeedTensorContract`
155
156 @param[in] contract `CeedTensorContract`
157
158 @return `Ceed` associated with `contract`
159
160 @ref Backend
161 **/
CeedTensorContractReturnCeed(CeedTensorContract contract)162 Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return CeedObjectReturnCeed((CeedObject)contract); }
163
164 /**
165 @brief Get backend data of a `CeedTensorContract`
166
167 @param[in] contract `CeedTensorContract`
168 @param[out] data Variable to store data
169
170 @return An error code: 0 - success, otherwise - failure
171
172 @ref Backend
173 **/
CeedTensorContractGetData(CeedTensorContract contract,void * data)174 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
175 *(void **)data = contract->data;
176 return CEED_ERROR_SUCCESS;
177 }
178
179 /**
180 @brief Set backend data of a `CeedTensorContract`
181
182 @param[in,out] contract `CeedTensorContract`
183 @param[in] data Data to set
184
185 @return An error code: 0 - success, otherwise - failure
186
187 @ref Backend
188 **/
CeedTensorContractSetData(CeedTensorContract contract,void * data)189 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
190 contract->data = data;
191 return CEED_ERROR_SUCCESS;
192 }
193
194 /**
195 @brief Increment the reference counter for a `CeedTensorContract`
196
197 @param[in,out] contract `CeedTensorContract` to increment the reference counter
198
199 @return An error code: 0 - success, otherwise - failure
200
201 @ref Backend
202 **/
CeedTensorContractReference(CeedTensorContract contract)203 int CeedTensorContractReference(CeedTensorContract contract) {
204 CeedCall(CeedObjectReference((CeedObject)contract));
205 return CEED_ERROR_SUCCESS;
206 }
207
208 /**
209 @brief Copy the pointer to a `CeedTensorContract`.
210
211 Both pointers should be destroyed with @ref CeedTensorContractDestroy().
212
213 Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`.
214 This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`.
215
216 @param[in] tensor `CeedTensorContract` to copy reference to
217 @param[in,out] tensor_copy Variable to store copied reference
218
219 @return An error code: 0 - success, otherwise - failure
220
221 @ref User
222 **/
CeedTensorContractReferenceCopy(CeedTensorContract tensor,CeedTensorContract * tensor_copy)223 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
224 CeedCall(CeedTensorContractReference(tensor));
225 CeedCall(CeedTensorContractDestroy(tensor_copy));
226 *tensor_copy = tensor;
227 return CEED_ERROR_SUCCESS;
228 }
229
230 /**
231 @brief Destroy a `CeedTensorContract`
232
233 @param[in,out] contract `CeedTensorContract` to destroy
234
235 @return An error code: 0 - success, otherwise - failure
236
237 @ref Backend
238 **/
CeedTensorContractDestroy(CeedTensorContract * contract)239 int CeedTensorContractDestroy(CeedTensorContract *contract) {
240 if (!*contract || CeedObjectDereference((CeedObject)*contract) > 0) {
241 *contract = NULL;
242 return CEED_ERROR_SUCCESS;
243 }
244 if ((*contract)->Destroy) {
245 CeedCall((*contract)->Destroy(*contract));
246 }
247 CeedCall(CeedObjectDestroy_Private(&(*contract)->obj));
248 CeedCall(CeedFree(contract));
249 return CEED_ERROR_SUCCESS;
250 }
251
252 /// @}
253