xref: /libCEED/rust/libceed-sys/c-src/interface/ceed-tensor.c (revision f2989f2b3b8649d855bed22b6730ee0ecfa6b31b)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12 
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15 
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Library Internal Functions
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedTensorContractDeveloper
20 /// @{
21 
22 /**
23   @brief Destroy a `CeedTensorContract` passed as a `CeedObject`
24 
25   @param[in,out] contract Address of `CeedTensorContract` to destroy
26 
27   @return An error code: 0 - success, otherwise - failure
28 
29   @ref Developer
30 **/
CeedTensorContractDestroy_Object(CeedObject * contract)31 static int CeedTensorContractDestroy_Object(CeedObject *contract) {
32   CeedCall(CeedTensorContractDestroy((CeedTensorContract *)contract));
33   return CEED_ERROR_SUCCESS;
34 }
35 
36 /// @}
37 
38 /// ----------------------------------------------------------------------------
39 /// CeedTensorContract Backend API
40 /// ----------------------------------------------------------------------------
41 /// @addtogroup CeedBasisBackend
42 /// @{
43 
44 /**
45   @brief Create a `CeedTensorContract` object for a `CeedBasis`
46 
47   @param[in]  ceed     `Ceed` object used to create the `CeedTensorContract`
48   @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored.
49 
50   @return An error code: 0 - success, otherwise - failure
51 
52   @ref Backend
53 **/
CeedTensorContractCreate(Ceed ceed,CeedTensorContract * contract)54 int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
55   if (!ceed->TensorContractCreate) {
56     Ceed delegate;
57 
58     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
59     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement CeedTensorContractCreate");
60     CeedCall(CeedTensorContractCreate(delegate, contract));
61     CeedCall(CeedDestroy(&delegate));
62     return CEED_ERROR_SUCCESS;
63   }
64 
65   CeedCall(CeedCalloc(1, contract));
66   CeedCall(CeedObjectCreate(ceed, NULL, CeedTensorContractDestroy_Object, &(*contract)->obj));
67   CeedCall(ceed->TensorContractCreate(*contract));
68   return CEED_ERROR_SUCCESS;
69 }
70 
71 /**
72   @brief Apply tensor contraction
73 
74   Contracts on the middle index
75   NOTRANSPOSE: `v_ajc = t_jb u_abc`
76   TRANSPOSE:   `v_ajc = t_bj u_abc`
77   If `add != 0`, `=` is replaced by `+=`
78 
79   @param[in]  contract `CeedTensorContract` to use
80   @param[in]  A        First index of `u`, `v`
81   @param[in]  B        Middle index of `u`, one index of `t`
82   @param[in]  C        Last index of `u`, `v`
83   @param[in]  J        Middle index of `v`, one index of `t`
84   @param[in]  t        Tensor array to contract against
85   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj`
86   @param[in]  add      Add mode
87   @param[in]  u        Input array
88   @param[out] v        Output array
89 
90   @return An error code: 0 - success, otherwise - failure
91 
92   @ref Backend
93 **/
CeedTensorContractApply(CeedTensorContract contract,CeedInt A,CeedInt B,CeedInt C,CeedInt J,const CeedScalar * restrict t,CeedTransposeMode t_mode,const CeedInt add,const CeedScalar * restrict u,CeedScalar * restrict v)94 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
95                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
96   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
97   return CEED_ERROR_SUCCESS;
98 }
99 
100 /**
101   @brief Apply tensor contraction
102 
103   Contracts on the middle index
104   NOTRANSPOSE: `v_dajc = t_djb u_abc`
105   TRANSPOSE:   `v_ajc  = t_dbj u_dabc`
106   If `add != 0`, `=` is replaced by `+=`
107 
108   @param[in]  contract `CeedTensorContract` to use
109   @param[in]  A        First index of `u`, second index of `v`
110   @param[in]  B        Middle index of `u`, one of last two indices of `t`
111   @param[in]  C        Last index of `u`, `v`
112   @param[in]  D        First index of `v`, first index of `t`
113   @param[in]  J        Third index of `v`, one of last two indices of `t`
114   @param[in]  t        Tensor array to contract against
115   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj`
116   @param[in]  add      Add mode
117   @param[in]  u        Input array
118   @param[out] v        Output array
119 
120   @return An error code: 0 - success, otherwise - failure
121 
122   @ref Backend
123 **/
CeedTensorContractStridedApply(CeedTensorContract contract,CeedInt A,CeedInt B,CeedInt C,CeedInt D,CeedInt J,const CeedScalar * restrict t,CeedTransposeMode t_mode,const CeedInt add,const CeedScalar * restrict u,CeedScalar * restrict v)124 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
125                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
126   if (t_mode == CEED_TRANSPOSE) {
127     for (CeedInt d = 0; d < D; d++) {
128       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
129     }
130   } else {
131     for (CeedInt d = 0; d < D; d++) {
132       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
133     }
134   }
135   return CEED_ERROR_SUCCESS;
136 }
137 
138 /**
139   @brief Get the `Ceed` associated with a `CeedTensorContract`
140 
141   @param[in]  contract `CeedTensorContract`
142   @param[out] ceed     Variable to store `Ceed`
143 
144   @return An error code: 0 - success, otherwise - failure
145 
146   @ref Backend
147 **/
CeedTensorContractGetCeed(CeedTensorContract contract,Ceed * ceed)148 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
149   CeedCall(CeedObjectGetCeed((CeedObject)contract, ceed));
150   return CEED_ERROR_SUCCESS;
151 }
152 
153 /**
154   @brief Return the `Ceed` associated with a `CeedTensorContract`
155 
156   @param[in]  contract `CeedTensorContract`
157 
158   @return `Ceed` associated with `contract`
159 
160   @ref Backend
161 **/
CeedTensorContractReturnCeed(CeedTensorContract contract)162 Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return CeedObjectReturnCeed((CeedObject)contract); }
163 
164 /**
165   @brief Get backend data of a `CeedTensorContract`
166 
167   @param[in]  contract `CeedTensorContract`
168   @param[out] data     Variable to store data
169 
170   @return An error code: 0 - success, otherwise - failure
171 
172   @ref Backend
173 **/
CeedTensorContractGetData(CeedTensorContract contract,void * data)174 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
175   *(void **)data = contract->data;
176   return CEED_ERROR_SUCCESS;
177 }
178 
179 /**
180   @brief Set backend data of a `CeedTensorContract`
181 
182   @param[in,out] contract `CeedTensorContract`
183   @param[in]     data     Data to set
184 
185   @return An error code: 0 - success, otherwise - failure
186 
187   @ref Backend
188 **/
CeedTensorContractSetData(CeedTensorContract contract,void * data)189 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
190   contract->data = data;
191   return CEED_ERROR_SUCCESS;
192 }
193 
194 /**
195   @brief Increment the reference counter for a `CeedTensorContract`
196 
197   @param[in,out] contract `CeedTensorContract` to increment the reference counter
198 
199   @return An error code: 0 - success, otherwise - failure
200 
201   @ref Backend
202 **/
CeedTensorContractReference(CeedTensorContract contract)203 int CeedTensorContractReference(CeedTensorContract contract) {
204   CeedCall(CeedObjectReference((CeedObject)contract));
205   return CEED_ERROR_SUCCESS;
206 }
207 
208 /**
209   @brief Copy the pointer to a `CeedTensorContract`.
210 
211   Both pointers should be destroyed with @ref CeedTensorContractDestroy().
212 
213   Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`.
214         This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`.
215 
216   @param[in]     tensor      `CeedTensorContract` to copy reference to
217   @param[in,out] tensor_copy Variable to store copied reference
218 
219   @return An error code: 0 - success, otherwise - failure
220 
221   @ref User
222 **/
CeedTensorContractReferenceCopy(CeedTensorContract tensor,CeedTensorContract * tensor_copy)223 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
224   CeedCall(CeedTensorContractReference(tensor));
225   CeedCall(CeedTensorContractDestroy(tensor_copy));
226   *tensor_copy = tensor;
227   return CEED_ERROR_SUCCESS;
228 }
229 
230 /**
231   @brief Destroy a `CeedTensorContract`
232 
233   @param[in,out] contract `CeedTensorContract` to destroy
234 
235   @return An error code: 0 - success, otherwise - failure
236 
237   @ref Backend
238 **/
CeedTensorContractDestroy(CeedTensorContract * contract)239 int CeedTensorContractDestroy(CeedTensorContract *contract) {
240   if (!*contract || CeedObjectDereference((CeedObject)*contract) > 0) {
241     *contract = NULL;
242     return CEED_ERROR_SUCCESS;
243   }
244   if ((*contract)->Destroy) {
245     CeedCall((*contract)->Destroy(*contract));
246   }
247   CeedCall(CeedObjectDestroy_Private(&(*contract)->obj));
248   CeedCall(CeedFree(contract));
249   return CEED_ERROR_SUCCESS;
250 }
251 
252 /// @}
253