xref: /libCEED/rust/libceed-sys/c-src/interface/ceed-tensor.c (revision 585a562d532a0ab5cf252f239fa5829eab5ad152)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12 
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15 
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Backend API
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedBasisBackend
20 /// @{
21 
22 /**
23   @brief Create a CeedTensorContract object for a CeedBasis
24 
25   @param[in]  ceed     Ceed object where the CeedTensorContract will be created
26   @param[in]  basis    CeedBasis for which the tensor contraction will be used
27   @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored.
28 
29   @return An error code: 0 - success, otherwise - failure
30 
31   @ref Backend
32 **/
33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) {
34   if (!ceed->TensorContractCreate) {
35     Ceed delegate;
36 
37     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
38     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
39     CeedCall(CeedTensorContractCreate(delegate, basis, contract));
40     return CEED_ERROR_SUCCESS;
41   }
42 
43   CeedCall(CeedCalloc(1, contract));
44   CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed));
45   CeedCall(ceed->TensorContractCreate(basis, *contract));
46   return CEED_ERROR_SUCCESS;
47 }
48 
49 /**
50   @brief Apply tensor contraction
51 
52   Contracts on the middle index
53   NOTRANSPOSE: v_ajc = t_jb u_abc
54   TRANSPOSE:   v_ajc = t_bj u_abc
55   If add != 0, "=" is replaced by "+="
56 
57   @param[in]  contract CeedTensorContract to use
58   @param[in]  A        First index of u, v
59   @param[in]  B        Middle index of u, one index of t
60   @param[in]  C        Last index of u, v
61   @param[in]  J        Middle index of v, one index of t
62   @param[in]  t        Tensor array to contract against
63   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
64   @param[in]  add      Add mode
65   @param[in]  u        Input array
66   @param[out] v        Output array
67 
68   @return An error code: 0 - success, otherwise - failure
69 
70   @ref Backend
71 **/
72 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
73                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
74   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
75   return CEED_ERROR_SUCCESS;
76 }
77 
78 /**
79   @brief Apply tensor contraction
80 
81     Contracts on the middle index
82     NOTRANSPOSE: v_dajc = t_djb u_abc
83     TRANSPOSE:   v_ajc  = t_dbj u_dabc
84     If add != 0, "=" is replaced by "+="
85 
86   @param[in]  contract CeedTensorContract to use
87   @param[in]  A        First index of u, second index of v
88   @param[in]  B        Middle index of u, one of last two indices of t
89   @param[in]  C        Last index of u, v
90   @param[in]  D        First index of v, first index of t
91   @param[in]  J        Third index of v, one of last two indices of t
92   @param[in]  t        Tensor array to contract against
93   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_djb \ref CEED_TRANSPOSE for t_dbj
94   @param[in]  add      Add mode
95   @param[in]  u        Input array
96   @param[out] v        Output array
97 
98   @return An error code: 0 - success, otherwise - failure
99 
100   @ref Backend
101 **/
102 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
103                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
104   if (t_mode == CEED_TRANSPOSE) {
105     for (CeedInt d = 0; d < D; d++) {
106       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
107     }
108   } else {
109     for (CeedInt d = 0; d < D; d++) {
110       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
111     }
112   }
113   return CEED_ERROR_SUCCESS;
114 }
115 
116 /**
117   @brief Get Ceed associated with a CeedTensorContract
118 
119   @param[in]  contract CeedTensorContract
120   @param[out] ceed     Variable to store Ceed
121 
122   @return An error code: 0 - success, otherwise - failure
123 
124   @ref Backend
125 **/
126 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
127   *ceed = contract->ceed;
128   return CEED_ERROR_SUCCESS;
129 }
130 
131 /**
132   @brief Get backend data of a CeedTensorContract
133 
134   @param[in]  contract CeedTensorContract
135   @param[out] data     Variable to store data
136 
137   @return An error code: 0 - success, otherwise - failure
138 
139   @ref Backend
140 **/
141 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
142   *(void **)data = contract->data;
143   return CEED_ERROR_SUCCESS;
144 }
145 
146 /**
147   @brief Set backend data of a CeedTensorContract
148 
149   @param[in,out] contract CeedTensorContract
150   @param[in]     data     Data to set
151 
152   @return An error code: 0 - success, otherwise - failure
153 
154   @ref Backend
155 **/
156 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
157   contract->data = data;
158   return CEED_ERROR_SUCCESS;
159 }
160 
161 /**
162   @brief Increment the reference counter for a CeedTensorContract
163 
164   @param[in,out] contract CeedTensorContract to increment the reference counter
165 
166   @return An error code: 0 - success, otherwise - failure
167 
168   @ref Backend
169 **/
170 int CeedTensorContractReference(CeedTensorContract contract) {
171   contract->ref_count++;
172   return CEED_ERROR_SUCCESS;
173 }
174 
175 /**
176   @brief Copy the pointer to a CeedTensorContract.
177 
178   Both pointers should be destroyed with `CeedTensorContractDestroy()`.
179 
180   Note: If the value of `tensor_copy` passed to this function is non-NULL, then it is assumed that `tensor_copy` is a pointer to a CeedTensorContract.
181         This CeedTensorContract will be destroyed if `tensor_copy` is the only reference to this CeedVector.
182 
183   @param[in]     tensor      CeedTensorContract to copy reference to
184   @param[in,out] tensor_copy Variable to store copied reference
185 
186   @return An error code: 0 - success, otherwise - failure
187 
188   @ref User
189 **/
190 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
191   CeedCall(CeedTensorContractReference(tensor));
192   CeedCall(CeedTensorContractDestroy(tensor_copy));
193   *tensor_copy = tensor;
194   return CEED_ERROR_SUCCESS;
195 }
196 
197 /**
198   @brief Destroy a CeedTensorContract
199 
200   @param[in,out] contract CeedTensorContract to destroy
201 
202   @return An error code: 0 - success, otherwise - failure
203 
204   @ref Backend
205 **/
206 int CeedTensorContractDestroy(CeedTensorContract *contract) {
207   if (!*contract || --(*contract)->ref_count > 0) {
208     *contract = NULL;
209     return CEED_ERROR_SUCCESS;
210   }
211   if ((*contract)->Destroy) {
212     CeedCall((*contract)->Destroy(*contract));
213   }
214   CeedCall(CeedDestroy(&(*contract)->ceed));
215   CeedCall(CeedFree(contract));
216   return CEED_ERROR_SUCCESS;
217 }
218 
219 /// @}
220