xref: /libCEED/interface/ceed-tensor.c (revision d075f50ba6d3b1e38c233860adb1de6c814f0afc)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12 
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15 
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Backend API
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedBasisBackend
20 /// @{
21 
22 /**
23   @brief Create a CeedTensorContract object for a CeedBasis
24 
25   @param[in]  ceed     Ceed object where the CeedTensorContract will be created
26   @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored.
27 
28   @return An error code: 0 - success, otherwise - failure
29 
30   @ref Backend
31 **/
32 int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
33   if (!ceed->TensorContractCreate) {
34     Ceed delegate;
35 
36     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
37     CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
38     CeedCall(CeedTensorContractCreate(delegate, contract));
39     return CEED_ERROR_SUCCESS;
40   }
41 
42   CeedCall(CeedCalloc(1, contract));
43   CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed));
44   CeedCall(ceed->TensorContractCreate(*contract));
45   return CEED_ERROR_SUCCESS;
46 }
47 
48 /**
49   @brief Apply tensor contraction
50 
51   Contracts on the middle index
52   NOTRANSPOSE: v_ajc = t_jb u_abc
53   TRANSPOSE:   v_ajc = t_bj u_abc
54   If add != 0, "=" is replaced by "+="
55 
56   @param[in]  contract CeedTensorContract to use
57   @param[in]  A        First index of u, v
58   @param[in]  B        Middle index of u, one index of t
59   @param[in]  C        Last index of u, v
60   @param[in]  J        Middle index of v, one index of t
61   @param[in]  t        Tensor array to contract against
62   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
63   @param[in]  add      Add mode
64   @param[in]  u        Input array
65   @param[out] v        Output array
66 
67   @return An error code: 0 - success, otherwise - failure
68 
69   @ref Backend
70 **/
71 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
72                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
73   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
74   return CEED_ERROR_SUCCESS;
75 }
76 
77 /**
78   @brief Apply tensor contraction
79 
80   Contracts on the middle index
81   NOTRANSPOSE: v_dajc = t_djb u_abc
82   TRANSPOSE:   v_ajc  = t_dbj u_dabc
83   If add != 0, "=" is replaced by "+="
84 
85   @param[in]  contract CeedTensorContract to use
86   @param[in]  A        First index of u, second index of v
87   @param[in]  B        Middle index of u, one of last two indices of t
88   @param[in]  C        Last index of u, v
89   @param[in]  D        First index of v, first index of t
90   @param[in]  J        Third index of v, one of last two indices of t
91   @param[in]  t        Tensor array to contract against
92   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_djb \ref CEED_TRANSPOSE for t_dbj
93   @param[in]  add      Add mode
94   @param[in]  u        Input array
95   @param[out] v        Output array
96 
97   @return An error code: 0 - success, otherwise - failure
98 
99   @ref Backend
100 **/
101 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
102                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
103   if (t_mode == CEED_TRANSPOSE) {
104     for (CeedInt d = 0; d < D; d++) {
105       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
106     }
107   } else {
108     for (CeedInt d = 0; d < D; d++) {
109       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
110     }
111   }
112   return CEED_ERROR_SUCCESS;
113 }
114 
115 /**
116   @brief Get Ceed associated with a CeedTensorContract
117 
118   @param[in]  contract CeedTensorContract
119   @param[out] ceed     Variable to store Ceed
120 
121   @return An error code: 0 - success, otherwise - failure
122 
123   @ref Backend
124 **/
125 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
126   *ceed = contract->ceed;
127   return CEED_ERROR_SUCCESS;
128 }
129 
130 /**
131   @brief Get backend data of a CeedTensorContract
132 
133   @param[in]  contract CeedTensorContract
134   @param[out] data     Variable to store data
135 
136   @return An error code: 0 - success, otherwise - failure
137 
138   @ref Backend
139 **/
140 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
141   *(void **)data = contract->data;
142   return CEED_ERROR_SUCCESS;
143 }
144 
145 /**
146   @brief Set backend data of a CeedTensorContract
147 
148   @param[in,out] contract CeedTensorContract
149   @param[in]     data     Data to set
150 
151   @return An error code: 0 - success, otherwise - failure
152 
153   @ref Backend
154 **/
155 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
156   contract->data = data;
157   return CEED_ERROR_SUCCESS;
158 }
159 
160 /**
161   @brief Increment the reference counter for a CeedTensorContract
162 
163   @param[in,out] contract CeedTensorContract to increment the reference counter
164 
165   @return An error code: 0 - success, otherwise - failure
166 
167   @ref Backend
168 **/
169 int CeedTensorContractReference(CeedTensorContract contract) {
170   contract->ref_count++;
171   return CEED_ERROR_SUCCESS;
172 }
173 
174 /**
175   @brief Copy the pointer to a CeedTensorContract.
176 
177   Both pointers should be destroyed with `CeedTensorContractDestroy()`.
178 
179   Note: If the value of `tensor_copy` passed to this function is non-NULL, then it is assumed that `tensor_copy` is a pointer to a CeedTensorContract.
180         This CeedTensorContract will be destroyed if `tensor_copy` is the only reference to this CeedVector.
181 
182   @param[in]     tensor      CeedTensorContract to copy reference to
183   @param[in,out] tensor_copy Variable to store copied reference
184 
185   @return An error code: 0 - success, otherwise - failure
186 
187   @ref User
188 **/
189 int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
190   CeedCall(CeedTensorContractReference(tensor));
191   CeedCall(CeedTensorContractDestroy(tensor_copy));
192   *tensor_copy = tensor;
193   return CEED_ERROR_SUCCESS;
194 }
195 
196 /**
197   @brief Destroy a CeedTensorContract
198 
199   @param[in,out] contract CeedTensorContract to destroy
200 
201   @return An error code: 0 - success, otherwise - failure
202 
203   @ref Backend
204 **/
205 int CeedTensorContractDestroy(CeedTensorContract *contract) {
206   if (!*contract || --(*contract)->ref_count > 0) {
207     *contract = NULL;
208     return CEED_ERROR_SUCCESS;
209   }
210   if ((*contract)->Destroy) {
211     CeedCall((*contract)->Destroy(*contract));
212   }
213   CeedCall(CeedDestroy(&(*contract)->ceed));
214   CeedCall(CeedFree(contract));
215   return CEED_ERROR_SUCCESS;
216 }
217 
218 /// @}
219