xref: /libCEED/interface/ceed-tensor.c (revision 2c2ea1dbee80fceecd2c97f30b09f8c87820a53e)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12 
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15 
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Backend API
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedBasisBackend
20 /// @{
21 
22 /**
23   @brief Create a CeedTensorContract object for a CeedBasis
24 
25   @param[in]  ceed     Ceed object where the CeedTensorContract will be created
26   @param[in]  basis    CeedBasis for which the tensor contraction will be used
27   @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored.
28 
29   @return An error code: 0 - success, otherwise - failure
30 
31   @ref Backend
32 **/
33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) {
34   if (!ceed->TensorContractCreate) {
35     Ceed delegate;
36     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
37 
38     if (!delegate) {
39       // LCOV_EXCL_START
40       return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
41       // LCOV_EXCL_STOP
42     }
43 
44     CeedCall(CeedTensorContractCreate(delegate, basis, contract));
45     return CEED_ERROR_SUCCESS;
46   }
47 
48   CeedCall(CeedCalloc(1, contract));
49 
50   (*contract)->ceed = ceed;
51   CeedCall(CeedReference(ceed));
52   CeedCall(ceed->TensorContractCreate(basis, *contract));
53   return CEED_ERROR_SUCCESS;
54 }
55 
56 /**
57   @brief Apply tensor contraction
58 
59     Contracts on the middle index
60     NOTRANSPOSE: v_ajc = t_jb u_abc
61     TRANSPOSE:   v_ajc = t_bj u_abc
62     If add != 0, "=" is replaced by "+="
63 
64   @param[in]  contract CeedTensorContract to use
65   @param[in]  A        First index of u, v
66   @param[in]  B        Middle index of u, one index of t
67   @param[in]  C        Last index of u, v
68   @param[in]  J        Middle index of v, one index of t
69   @param[in]  t        Tensor array to contract against
70   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
71   @param[in]  add      Add mode
72   @param[in]  u        Input array
73   @param[out] v        Output array
74 
75   @return An error code: 0 - success, otherwise - failure
76 
77   @ref Backend
78 **/
79 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
80                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
81   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
82   return CEED_ERROR_SUCCESS;
83 }
84 
85 /**
86   @brief Apply tensor contraction
87 
88     Contracts on the middle index
89     NOTRANSPOSE: v_dajc = t_djb u_abc
90     TRANSPOSE:   v_ajc  = t_dbj u_dabc
91     If add != 0, "=" is replaced by "+="
92 
93   @param[in]  contract CeedTensorContract to use
94   @param[in]  A        First index of u, second index of v
95   @param[in]  B        Middle index of u, one of last two indices of t
96   @param[in]  C        Last index of u, v
97   @param[in]  D        First index of v, first index of t
98   @param[in]  J        Third index of v, one of last two indices of t
99   @param[in]  t        Tensor array to contract against
100   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
101   @param[in]  add      Add mode
102   @param[in]  u        Input array
103   @param[out] v        Output array
104 
105   @return An error code: 0 - success, otherwise - failure
106 
107   @ref Backend
108 **/
109 int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
110                                    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
111   if (t_mode == CEED_TRANSPOSE) {
112     for (CeedInt d = 0; d < D; d++) {
113       CeedCall(contract->Apply(contract, A, J, C, B, t + d * B * J, t_mode, add, u + d * A * J * C, v));
114     }
115   } else {
116     for (CeedInt d = 0; d < D; d++) {
117       CeedCall(contract->Apply(contract, A, B, C, J, t + d * B * J, t_mode, add, u, v + d * A * J * C));
118     }
119   }
120   return CEED_ERROR_SUCCESS;
121 }
122 
123 /**
124   @brief Get Ceed associated with a CeedTensorContract
125 
126   @param[in]  contract CeedTensorContract
127   @param[out] ceed     Variable to store Ceed
128 
129   @return An error code: 0 - success, otherwise - failure
130 
131   @ref Backend
132 **/
133 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
134   *ceed = contract->ceed;
135   return CEED_ERROR_SUCCESS;
136 }
137 
138 /**
139   @brief Get backend data of a CeedTensorContract
140 
141   @param[in]  contract CeedTensorContract
142   @param[out] data     Variable to store data
143 
144   @return An error code: 0 - success, otherwise - failure
145 
146   @ref Backend
147 **/
148 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
149   *(void **)data = contract->data;
150   return CEED_ERROR_SUCCESS;
151 }
152 
153 /**
154   @brief Set backend data of a CeedTensorContract
155 
156   @param[in,out] contract CeedTensorContract
157   @param[in]     data     Data to set
158 
159   @return An error code: 0 - success, otherwise - failure
160 
161   @ref Backend
162 **/
163 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
164   contract->data = data;
165   return CEED_ERROR_SUCCESS;
166 }
167 
168 /**
169   @brief Increment the reference counter for a CeedTensorContract
170 
171   @param[in,out] contract CeedTensorContract to increment the reference counter
172 
173   @return An error code: 0 - success, otherwise - failure
174 
175   @ref Backend
176 **/
177 int CeedTensorContractReference(CeedTensorContract contract) {
178   contract->ref_count++;
179   return CEED_ERROR_SUCCESS;
180 }
181 
182 /**
183   @brief Destroy a CeedTensorContract
184 
185   @param[in,out] contract CeedTensorContract to destroy
186 
187   @return An error code: 0 - success, otherwise - failure
188 
189   @ref Backend
190 **/
191 int CeedTensorContractDestroy(CeedTensorContract *contract) {
192   if (!*contract || --(*contract)->ref_count > 0) {
193     *contract = NULL;
194     return CEED_ERROR_SUCCESS;
195   }
196   if ((*contract)->Destroy) {
197     CeedCall((*contract)->Destroy(*contract));
198   }
199   CeedCall(CeedDestroy(&(*contract)->ceed));
200   CeedCall(CeedFree(contract));
201   return CEED_ERROR_SUCCESS;
202 }
203 
204 /// @}
205