xref: /libCEED/rust/libceed-sys/c-src/backends/ref/ceed-ref-basis.c (revision 3c708d80bcce8388043253c45b9dea101d51a504)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 
14 #include "ceed-ref.h"
15 
16 //------------------------------------------------------------------------------
17 // Basis Apply
18 //------------------------------------------------------------------------------
19 static int CeedBasisApply_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector U, CeedVector V) {
20   Ceed ceed;
21   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
22   CeedInt dim, num_comp, q_comp, num_nodes, num_qpts;
23   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
24   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
25   CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
26   CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
27   CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
28   CeedTensorContract contract;
29   CeedCallBackend(CeedBasisGetTensorContract(basis, &contract));
30   const CeedInt     add = (t_mode == CEED_TRANSPOSE);
31   const CeedScalar *u;
32   CeedScalar       *v;
33   if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_HOST, &u));
34   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
35   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_HOST, &v));
36 
37   // Clear v if operating in transpose
38   if (t_mode == CEED_TRANSPOSE) {
39     const CeedInt v_size = num_elem * num_comp * num_nodes;
40     for (CeedInt i = 0; i < v_size; i++) v[i] = (CeedScalar)0.0;
41   }
42   bool is_tensor_basis;
43   CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor_basis));
44   if (is_tensor_basis) {
45     // Tensor basis
46     CeedInt P_1d, Q_1d;
47     CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
48     CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
49     switch (eval_mode) {
50       // Interpolate to/from quadrature points
51       case CEED_EVAL_INTERP: {
52         CeedBasis_Ref *impl;
53         CeedCallBackend(CeedBasisGetData(basis, &impl));
54         if (impl->has_collo_interp) {
55           memcpy(v, u, num_elem * num_comp * num_nodes * sizeof(u[0]));
56         } else {
57           CeedInt P = P_1d, Q = Q_1d;
58           if (t_mode == CEED_TRANSPOSE) {
59             P = Q_1d;
60             Q = P_1d;
61           }
62           CeedInt           pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
63           CeedScalar        tmp[2][num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
64           const CeedScalar *interp_1d;
65           CeedCallBackend(CeedBasisGetInterp1D(basis, &interp_1d));
66           for (CeedInt d = 0; d < dim; d++) {
67             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, interp_1d, t_mode, add && (d == dim - 1), d == 0 ? u : tmp[d % 2],
68                                                     d == dim - 1 ? v : tmp[(d + 1) % 2]));
69             pre /= P;
70             post *= Q;
71           }
72         }
73       } break;
74       // Evaluate the gradient to/from quadrature points
75       case CEED_EVAL_GRAD: {
76         // In CEED_NOTRANSPOSE mode:
77         // u has shape [dim, num_comp, P^dim, num_elem], row-major layout
78         // v has shape [dim, num_comp, Q^dim, num_elem], row-major layout
79         // In CEED_TRANSPOSE mode, the sizes of u and v are switched.
80         CeedInt P = P_1d, Q = Q_1d;
81         if (t_mode == CEED_TRANSPOSE) {
82           P = Q_1d, Q = Q_1d;
83         }
84         CeedBasis_Ref *impl;
85         CeedCallBackend(CeedBasisGetData(basis, &impl));
86         CeedInt           pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
87         const CeedScalar *interp_1d;
88         CeedCallBackend(CeedBasisGetInterp1D(basis, &interp_1d));
89         if (impl->collo_grad_1d) {
90           CeedScalar tmp[2][num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
91           CeedScalar interp[num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
92           // Interpolate to quadrature points (NoTranspose)
93           //  or Grad to quadrature points (Transpose)
94           for (CeedInt d = 0; d < dim; d++) {
95             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, (t_mode == CEED_NOTRANSPOSE ? interp_1d : impl->collo_grad_1d), t_mode,
96                                                     add && (d > 0),
97                                                     (t_mode == CEED_NOTRANSPOSE ? (d == 0 ? u : tmp[d % 2]) : u + d * num_qpts * num_comp * num_elem),
98                                                     (t_mode == CEED_NOTRANSPOSE ? (d == dim - 1 ? interp : tmp[(d + 1) % 2]) : interp)));
99             pre /= P;
100             post *= Q;
101           }
102           // Grad to quadrature points (NoTranspose)
103           //  or Interpolate to nodes (Transpose)
104           P = Q_1d, Q = Q_1d;
105           if (t_mode == CEED_TRANSPOSE) {
106             P = Q_1d, Q = P_1d;
107           }
108           pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
109           for (CeedInt d = 0; d < dim; d++) {
110             CeedCallBackend(CeedTensorContractApply(
111                 contract, pre, P, post, Q, (t_mode == CEED_NOTRANSPOSE ? impl->collo_grad_1d : interp_1d), t_mode, add && (d == dim - 1),
112                 (t_mode == CEED_NOTRANSPOSE ? interp : (d == 0 ? interp : tmp[d % 2])),
113                 (t_mode == CEED_NOTRANSPOSE ? v + d * num_qpts * num_comp * num_elem : (d == dim - 1 ? v : tmp[(d + 1) % 2]))));
114             pre /= P;
115             post *= Q;
116           }
117         } else if (impl->has_collo_interp) {  // Qpts collocated with nodes
118           const CeedScalar *grad_1d;
119           CeedCallBackend(CeedBasisGetGrad1D(basis, &grad_1d));
120 
121           // Dim contractions, identity in other directions
122           CeedInt pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
123           for (CeedInt d = 0; d < dim; d++) {
124             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, grad_1d, t_mode, add && (d > 0),
125                                                     t_mode == CEED_NOTRANSPOSE ? u : u + d * num_comp * num_qpts * num_elem,
126                                                     t_mode == CEED_TRANSPOSE ? v : v + d * num_comp * num_qpts * num_elem));
127             pre /= P;
128             post *= Q;
129           }
130         } else {  // Underintegration, P > Q
131           const CeedScalar *grad_1d;
132           CeedCallBackend(CeedBasisGetGrad1D(basis, &grad_1d));
133 
134           if (t_mode == CEED_TRANSPOSE) {
135             P = Q_1d, Q = P_1d;
136           }
137           CeedScalar tmp[2][num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
138 
139           // Dim**2 contractions, apply grad when pass == dim
140           for (CeedInt p = 0; p < dim; p++) {
141             CeedInt pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
142             for (CeedInt d = 0; d < dim; d++) {
143               CeedCallBackend(CeedTensorContractApply(
144                   contract, pre, P, post, Q, (p == d) ? grad_1d : interp_1d, t_mode, add && (d == dim - 1),
145                   (d == 0 ? (t_mode == CEED_NOTRANSPOSE ? u : u + p * num_comp * num_qpts * num_elem) : tmp[d % 2]),
146                   (d == dim - 1 ? (t_mode == CEED_TRANSPOSE ? v : v + p * num_comp * num_qpts * num_elem) : tmp[(d + 1) % 2])));
147               pre /= P;
148               post *= Q;
149             }
150           }
151         }
152       } break;
153       // Retrieve interpolation weights
154       case CEED_EVAL_WEIGHT: {
155         CeedCheck(t_mode == CEED_NOTRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
156         CeedInt           Q = Q_1d;
157         const CeedScalar *q_weight_1d;
158         CeedCallBackend(CeedBasisGetQWeights(basis, &q_weight_1d));
159         for (CeedInt d = 0; d < dim; d++) {
160           CeedInt pre = CeedIntPow(Q, dim - d - 1), post = CeedIntPow(Q, d);
161           for (CeedInt i = 0; i < pre; i++) {
162             for (CeedInt j = 0; j < Q; j++) {
163               for (CeedInt k = 0; k < post; k++) {
164                 CeedScalar w = q_weight_1d[j] * (d == 0 ? 1 : v[((i * Q + j) * post + k) * num_elem]);
165                 for (CeedInt e = 0; e < num_elem; e++) v[((i * Q + j) * post + k) * num_elem + e] = w;
166               }
167             }
168           }
169         }
170       } break;
171       // LCOV_EXCL_START
172       // Evaluate the divergence to/from the quadrature points
173       case CEED_EVAL_DIV:
174         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
175       // Evaluate the curl to/from the quadrature points
176       case CEED_EVAL_CURL:
177         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
178       // Take no action, BasisApply should not have been called
179       case CEED_EVAL_NONE:
180         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
181         // LCOV_EXCL_STOP
182     }
183   } else {
184     // Non-tensor basis
185     CeedInt P = num_nodes, Q = num_qpts;
186     switch (eval_mode) {
187       // Interpolate to/from quadrature points
188       case CEED_EVAL_INTERP: {
189         const CeedScalar *interp;
190         CeedCallBackend(CeedBasisGetInterp(basis, &interp));
191         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, interp, t_mode, add, u, v));
192       } break;
193       // Evaluate the gradient to/from quadrature points
194       case CEED_EVAL_GRAD: {
195         const CeedScalar *grad;
196         CeedCallBackend(CeedBasisGetGrad(basis, &grad));
197         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, grad, t_mode, add, u, v));
198       } break;
199       // Evaluate the divergence to/from the quadrature points
200       case CEED_EVAL_DIV: {
201         const CeedScalar *div;
202         CeedCallBackend(CeedBasisGetDiv(basis, &div));
203         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, div, t_mode, add, u, v));
204       } break;
205       // Evaluate the curl to/from the quadrature points
206       case CEED_EVAL_CURL: {
207         const CeedScalar *curl;
208         CeedCallBackend(CeedBasisGetCurl(basis, &curl));
209         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, curl, t_mode, add, u, v));
210       } break;
211       // Retrieve interpolation weights
212       case CEED_EVAL_WEIGHT: {
213         CeedCheck(t_mode == CEED_NOTRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
214         const CeedScalar *q_weight;
215         CeedCallBackend(CeedBasisGetQWeights(basis, &q_weight));
216         for (CeedInt i = 0; i < num_qpts; i++) {
217           for (CeedInt e = 0; e < num_elem; e++) v[i * num_elem + e] = q_weight[i];
218         }
219       } break;
220       // LCOV_EXCL_START
221       // Take no action, BasisApply should not have been called
222       case CEED_EVAL_NONE:
223         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
224         // LCOV_EXCL_STOP
225     }
226   }
227   if (U != CEED_VECTOR_NONE) {
228     CeedCallBackend(CeedVectorRestoreArrayRead(U, &u));
229   }
230   CeedCallBackend(CeedVectorRestoreArray(V, &v));
231 
232   return CEED_ERROR_SUCCESS;
233 }
234 
235 //------------------------------------------------------------------------------
236 // Basis Destroy Tensor
237 //------------------------------------------------------------------------------
238 static int CeedBasisDestroyTensor_Ref(CeedBasis basis) {
239   CeedBasis_Ref *impl;
240   CeedCallBackend(CeedBasisGetData(basis, &impl));
241   CeedCallBackend(CeedFree(&impl->collo_grad_1d));
242   CeedCallBackend(CeedFree(&impl));
243 
244   return CEED_ERROR_SUCCESS;
245 }
246 
247 //------------------------------------------------------------------------------
248 // Basis Create Tensor
249 //------------------------------------------------------------------------------
250 int CeedBasisCreateTensorH1_Ref(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
251                                 const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
252   Ceed ceed;
253   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
254   CeedBasis_Ref *impl;
255   CeedCallBackend(CeedCalloc(1, &impl));
256   // Check for collocated interp
257   if (Q_1d == P_1d) {
258     bool collocated = 1;
259     for (CeedInt i = 0; i < P_1d; i++) {
260       collocated = collocated && (fabs(interp_1d[i + P_1d * i] - 1.0) < 1e-14);
261       for (CeedInt j = 0; j < P_1d; j++) {
262         if (j != i) collocated = collocated && (fabs(interp_1d[j + P_1d * i]) < 1e-14);
263       }
264     }
265     impl->has_collo_interp = collocated;
266   }
267   // Calculate collocated grad
268   if (Q_1d >= P_1d && !impl->has_collo_interp) {
269     CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &impl->collo_grad_1d));
270     CeedCallBackend(CeedBasisGetCollocatedGrad(basis, impl->collo_grad_1d));
271   }
272   CeedCallBackend(CeedBasisSetData(basis, impl));
273 
274   Ceed parent;
275   CeedCallBackend(CeedGetParent(ceed, &parent));
276   CeedTensorContract contract;
277   CeedCallBackend(CeedTensorContractCreate(parent, basis, &contract));
278   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
279 
280   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
281   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyTensor_Ref));
282 
283   return CEED_ERROR_SUCCESS;
284 }
285 
286 //------------------------------------------------------------------------------
287 // Basis Create Non-Tensor H^1
288 //------------------------------------------------------------------------------
289 int CeedBasisCreateH1_Ref(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
290                           const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
291   Ceed ceed;
292   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
293 
294   Ceed parent;
295   CeedCallBackend(CeedGetParent(ceed, &parent));
296   CeedTensorContract contract;
297   CeedCallBackend(CeedTensorContractCreate(parent, basis, &contract));
298   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
299 
300   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
301 
302   return CEED_ERROR_SUCCESS;
303 }
304 
305 //------------------------------------------------------------------------------
306 // Basis Create Non-Tensor H(div)
307 //------------------------------------------------------------------------------
308 int CeedBasisCreateHdiv_Ref(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *div,
309                             const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
310   Ceed ceed;
311   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
312 
313   Ceed parent;
314   CeedCallBackend(CeedGetParent(ceed, &parent));
315   CeedTensorContract contract;
316   CeedCallBackend(CeedTensorContractCreate(parent, basis, &contract));
317   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
318 
319   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
320 
321   return CEED_ERROR_SUCCESS;
322 }
323 
324 //------------------------------------------------------------------------------
325 // Basis Create Non-Tensor H(curl)
326 //------------------------------------------------------------------------------
327 int CeedBasisCreateHcurl_Ref(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
328                              const CeedScalar *curl, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
329   Ceed ceed;
330   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
331 
332   Ceed parent;
333   CeedCallBackend(CeedGetParent(ceed, &parent));
334   CeedTensorContract contract;
335   CeedCallBackend(CeedTensorContractCreate(parent, basis, &contract));
336   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
337 
338   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
339 
340   return CEED_ERROR_SUCCESS;
341 }
342 
343 //------------------------------------------------------------------------------
344