xref: /libCEED/interface/ceed-tensor.c (revision d0593705e733b5bdd5e4c173fe0008b11db2ed29)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12 
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15 
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Backend API
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedBasisBackend
20 /// @{
21 
22 /**
23   @brief Create a `CeedTensorContract` object for a `CeedBasis`
24 
25   @param[in]  ceed     `Ceed` object used to create the `CeedTensorContract`
26   @param[out] contract Address of the variable where the newly created `CeedTensorContract` will be stored.
27 
28   @return An error code: 0 - success, otherwise - failure
29 
30   @ref Backend
31 **/
32 int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
33   if (!ceed->TensorContractCreate) {
34     Ceed delegate;
35 
36     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
37     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support CeedTensorContractCreate");
38     CeedCall(CeedTensorContractCreate(delegate, contract));
39     return CEED_ERROR_SUCCESS;
40   }
41 
42   CeedCall(CeedCalloc(1, contract));
43   CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed));
44   CeedCall(ceed->TensorContractCreate(*contract));
45   return CEED_ERROR_SUCCESS;
46 }
47 
48 /**
49   @brief Apply tensor contraction
50 
51   Contracts on the middle index
52   NOTRANSPOSE: `v_ajc = t_jb u_abc`
53   TRANSPOSE:   `v_ajc = t_bj u_abc`
54   If `add != 0`, `=` is replaced by `+=`
55 
56   @param[in]  contract `CeedTensorContract` to use
57   @param[in]  A        First index of `u`, `v`
58   @param[in]  B        Middle index of `u`, one index of `t`
59   @param[in]  C        Last index of `u`, `v`
60   @param[in]  J        Middle index of `v`, one index of `t`
61   @param[in]  t        Tensor array to contract against
62   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_jb` @ref CEED_TRANSPOSE for `t_bj`
63   @param[in]  add      Add mode
64   @param[in]  u        Input array
65   @param[out] v        Output array
66 
67   @return An error code: 0 - success, otherwise - failure
68 
69   @ref Backend
70 **/
71 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
72                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
73   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
74   return CEED_ERROR_SUCCESS;
75 }
76 
77 /**
78   @brief Apply tensor contraction
79 
80   Contracts on the middle index
81   NOTRANSPOSE: `v_dajc = t_djb u_abc`
82   TRANSPOSE:   `v_ajc  = t_dbj u_dabc`
83   If `add != 0`, `=` is replaced by `+=`
84 
85   @param[in]  contract `CeedTensorContract` to use
86   @param[in]  A        First index of `u`, second index of `v`
87   @param[in]  B        Middle index of `u`, one of last two indices of `t`
88   @param[in]  C        Last index of `u`, `v`
89   @param[in]  D        First index of `v`, first index of `t`
90   @param[in]  J        Third index of `v`, one of last two indices of `t`
91   @param[in]  t        Tensor array to contract against
92   @param[in]  t_mode   Transpose mode for `t`, @ref CEED_NOTRANSPOSE for `t_djb` @ref CEED_TRANSPOSE for `t_dbj`
93   @param[in]  add      Add mode
94   @param[in]  u        Input array
95   @param[out] v        Output array
96 
97   @return An error code: 0 - success, otherwise - failure
98 
99   @ref Backend
100 **/
101 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
102                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
103   if (t_mode == CEED_TRANSPOSE) {
104     for (CeedInt d = 0; d < D; d++) {
105       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
106     }
107   } else {
108     for (CeedInt d = 0; d < D; d++) {
109       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
110     }
111   }
112   return CEED_ERROR_SUCCESS;
113 }
114 
115 /**
116   @brief Get the `Ceed` associated with a `CeedTensorContract`
117 
118   @param[in]  contract `CeedTensorContract`
119   @param[out] ceed     Variable to store `Ceed`
120 
121   @return An error code: 0 - success, otherwise - failure
122 
123   @ref Backend
124 **/
125 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
126   *ceed = CeedTensorContractReturnCeed(contract);
127   return CEED_ERROR_SUCCESS;
128 }
129 
130 /**
131   @brief Return the `Ceed` associated with a `CeedTensorContract`
132 
133   @param[in]  contract `CeedTensorContract`
134 
135   @return `Ceed` associated with `contract`
136 
137   @ref Backend
138 **/
139 Ceed CeedTensorContractReturnCeed(CeedTensorContract contract) { return contract->ceed; }
140 
141 /**
142   @brief Get backend data of a `CeedTensorContract`
143 
144   @param[in]  contract `CeedTensorContract`
145   @param[out] data     Variable to store data
146 
147   @return An error code: 0 - success, otherwise - failure
148 
149   @ref Backend
150 **/
151 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
152   *(void **)data = contract->data;
153   return CEED_ERROR_SUCCESS;
154 }
155 
156 /**
157   @brief Set backend data of a `CeedTensorContract`
158 
159   @param[in,out] contract `CeedTensorContract`
160   @param[in]     data     Data to set
161 
162   @return An error code: 0 - success, otherwise - failure
163 
164   @ref Backend
165 **/
166 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
167   contract->data = data;
168   return CEED_ERROR_SUCCESS;
169 }
170 
171 /**
172   @brief Increment the reference counter for a `CeedTensorContract`
173 
174   @param[in,out] contract `CeedTensorContract` to increment the reference counter
175 
176   @return An error code: 0 - success, otherwise - failure
177 
178   @ref Backend
179 **/
180 int CeedTensorContractReference(CeedTensorContract contract) {
181   contract->ref_count++;
182   return CEED_ERROR_SUCCESS;
183 }
184 
185 /**
186   @brief Copy the pointer to a `CeedTensorContract`.
187 
188   Both pointers should be destroyed with @ref CeedTensorContractDestroy().
189 
190   Note: If the value of `*tensor_copy` passed to this function is non-`NULL`, then it is assumed that `*tensor_copy` is a pointer to a `CeedTensorContract`.
191         This `CeedTensorContract` will be destroyed if `*tensor_copy` is the only reference to this `CeedTensorContract`.
192 
193   @param[in]     tensor      `CeedTensorContract` to copy reference to
194   @param[in,out] tensor_copy Variable to store copied reference
195 
196   @return An error code: 0 - success, otherwise - failure
197 
198   @ref User
199 **/
200 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
201   CeedCall(CeedTensorContractReference(tensor));
202   CeedCall(CeedTensorContractDestroy(tensor_copy));
203   *tensor_copy = tensor;
204   return CEED_ERROR_SUCCESS;
205 }
206 
207 /**
208   @brief Destroy a `CeedTensorContract`
209 
210   @param[in,out] contract `CeedTensorContract` to destroy
211 
212   @return An error code: 0 - success, otherwise - failure
213 
214   @ref Backend
215 **/
216 int CeedTensorContractDestroy(CeedTensorContract *contract) {
217   if (!*contract || --(*contract)->ref_count > 0) {
218     *contract = NULL;
219     return CEED_ERROR_SUCCESS;
220   }
221   if ((*contract)->Destroy) {
222     CeedCall((*contract)->Destroy(*contract));
223   }
224   CeedCall(CeedDestroy(&(*contract)->ceed));
225   CeedCall(CeedFree(contract));
226   return CEED_ERROR_SUCCESS;
227 }
228 
229 /// @}
230