xref: /libCEED/interface/ceed-tensor.c (revision b0f67a9c1aeeb4d82b4724afaae1227ff4e81f15)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12 
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15 
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Backend API
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedBasisBackend
20 /// @{
21 
22 /**
23   @brief Create a `CeedTensorContract` object for a `CeedBasis`
24 
25   @param[in]  ceed     `Ceed` object used to create the `CeedTensorContract`
26   @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored.
27 
28   @return An error code: 0 - success, otherwise - failure
29 
30   @ref Backend
31 **/
32 int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
33   if (!ceed->TensorContractCreate) {
34     Ceed delegate;
35 
36     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
37     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement CeedTensorContractCreate");
38     CeedCall(CeedTensorContractCreate(delegate, contract));
39     CeedCall(CeedDestroy(&delegate));
40     return CEED_ERROR_SUCCESS;
41   }
42 
43   CeedCall(CeedCalloc(1, contract));
44   CeedCall(CeedObjectCreate(ceed, NULL, &(*contract)->obj));
45   CeedCall(ceed->TensorContractCreate(*contract));
46   return CEED_ERROR_SUCCESS;
47 }
48 
49 /**
50   @brief Apply tensor contraction
51 
52   Contracts on the middle index
53   NOTRANSPOSE: `v_ajc = t_jb u_abc`
54   TRANSPOSE:   `v_ajc = t_bj u_abc`
55   If `add != 0`, `=` is replaced by `+=`
56 
57   @param[in]  contract `CeedTensorContract` to use
58   @param[in]  A        First index of `u`, `v`
59   @param[in]  B        Middle index of `u`, one index of `t`
60   @param[in]  C        Last index of `u`, `v`
61   @param[in]  J        Middle index of `v`, one index of `t`
62   @param[in]  t        Tensor array to contract against
63   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj`
64   @param[in]  add      Add mode
65   @param[in]  u        Input array
66   @param[out] v        Output array
67 
68   @return An error code: 0 - success, otherwise - failure
69 
70   @ref Backend
71 **/
72 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
73                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
74   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
75   return CEED_ERROR_SUCCESS;
76 }
77 
78 /**
79   @brief Apply tensor contraction
80 
81   Contracts on the middle index
82   NOTRANSPOSE: `v_dajc = t_djb u_abc`
83   TRANSPOSE:   `v_ajc  = t_dbj u_dabc`
84   If `add != 0`, `=` is replaced by `+=`
85 
86   @param[in]  contract `CeedTensorContract` to use
87   @param[in]  A        First index of `u`, second index of `v`
88   @param[in]  B        Middle index of `u`, one of last two indices of `t`
89   @param[in]  C        Last index of `u`, `v`
90   @param[in]  D        First index of `v`, first index of `t`
91   @param[in]  J        Third index of `v`, one of last two indices of `t`
92   @param[in]  t        Tensor array to contract against
93   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj`
94   @param[in]  add      Add mode
95   @param[in]  u        Input array
96   @param[out] v        Output array
97 
98   @return An error code: 0 - success, otherwise - failure
99 
100   @ref Backend
101 **/
102 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
103                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
104   if (t_mode == CEED_TRANSPOSE) {
105     for (CeedInt d = 0; d < D; d++) {
106       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
107     }
108   } else {
109     for (CeedInt d = 0; d < D; d++) {
110       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
111     }
112   }
113   return CEED_ERROR_SUCCESS;
114 }
115 
116 /**
117   @brief Get the `Ceed` associated with a `CeedTensorContract`
118 
119   @param[in]  contract `CeedTensorContract`
120   @param[out] ceed     Variable to store `Ceed`
121 
122   @return An error code: 0 - success, otherwise - failure
123 
124   @ref Backend
125 **/
126 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
127   CeedCall(CeedObjectGetCeed((CeedObject)contract, ceed));
128   return CEED_ERROR_SUCCESS;
129 }
130 
131 /**
132   @brief Return the `Ceed` associated with a `CeedTensorContract`
133 
134   @param[in]  contract `CeedTensorContract`
135 
136   @return `Ceed` associated with `contract`
137 
138   @ref Backend
139 **/
140 Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return CeedObjectReturnCeed((CeedObject)contract); }
141 
142 /**
143   @brief Get backend data of a `CeedTensorContract`
144 
145   @param[in]  contract `CeedTensorContract`
146   @param[out] data     Variable to store data
147 
148   @return An error code: 0 - success, otherwise - failure
149 
150   @ref Backend
151 **/
152 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
153   *(void **)data = contract->data;
154   return CEED_ERROR_SUCCESS;
155 }
156 
157 /**
158   @brief Set backend data of a `CeedTensorContract`
159 
160   @param[in,out] contract `CeedTensorContract`
161   @param[in]     data     Data to set
162 
163   @return An error code: 0 - success, otherwise - failure
164 
165   @ref Backend
166 **/
167 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
168   contract->data = data;
169   return CEED_ERROR_SUCCESS;
170 }
171 
172 /**
173   @brief Increment the reference counter for a `CeedTensorContract`
174 
175   @param[in,out] contract `CeedTensorContract` to increment the reference counter
176 
177   @return An error code: 0 - success, otherwise - failure
178 
179   @ref Backend
180 **/
181 int CeedTensorContractReference(CeedTensorContract contract) {
182   CeedCall(CeedObjectReference((CeedObject)contract));
183   return CEED_ERROR_SUCCESS;
184 }
185 
186 /**
187   @brief Copy the pointer to a `CeedTensorContract`.
188 
189   Both pointers should be destroyed with @ref CeedTensorContractDestroy().
190 
191   Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`.
192         This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`.
193 
194   @param[in]     tensor      `CeedTensorContract` to copy reference to
195   @param[in,out] tensor_copy Variable to store copied reference
196 
197   @return An error code: 0 - success, otherwise - failure
198 
199   @ref User
200 **/
201 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
202   CeedCall(CeedTensorContractReference(tensor));
203   CeedCall(CeedTensorContractDestroy(tensor_copy));
204   *tensor_copy = tensor;
205   return CEED_ERROR_SUCCESS;
206 }
207 
208 /**
209   @brief Destroy a `CeedTensorContract`
210 
211   @param[in,out] contract `CeedTensorContract` to destroy
212 
213   @return An error code: 0 - success, otherwise - failure
214 
215   @ref Backend
216 **/
217 int CeedTensorContractDestroy(CeedTensorContract *contract) {
218   if (!*contract || CeedObjectDereference((CeedObject)*contract) > 0) {
219     *contract = NULL;
220     return CEED_ERROR_SUCCESS;
221   }
222   if ((*contract)->Destroy) {
223     CeedCall((*contract)->Destroy(*contract));
224   }
225   CeedCall(CeedObjectDestroy(&(*contract)->obj));
226   CeedCall(CeedFree(contract));
227   return CEED_ERROR_SUCCESS;
228 }
229 
230 /// @}
231