xref: /libCEED/backends/ref/ceed-ref-basis.c (revision 9ba83ac0e4b1fca39d6fa6737a318a9f0cbc172d)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 
14 #include "ceed-ref.h"
15 
16 //------------------------------------------------------------------------------
17 // Basis Apply
18 //------------------------------------------------------------------------------
19 static int CeedBasisApplyCore_Ref(CeedBasis basis, bool apply_add, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector U,
20                                   CeedVector V) {
21   bool               is_tensor_basis, add = apply_add || (t_mode == CEED_TRANSPOSE);
22   CeedInt            dim, num_comp, q_comp, num_nodes, num_qpts;
23   const CeedScalar  *u;
24   CeedScalar        *v;
25   CeedTensorContract contract;
26   CeedBasis_Ref     *impl;
27 
28   CeedCallBackend(CeedBasisGetData(basis, &impl));
29   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
30   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
31   CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
32   CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
33   CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
34   CeedCallBackend(CeedBasisGetTensorContract(basis, &contract));
35   if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_HOST, &u));
36   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, CeedBasisReturnCeed(basis), CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
37   // Clear v if operating in transpose
38   if (apply_add) CeedCallBackend(CeedVectorGetArray(V, CEED_MEM_HOST, &v));
39   else CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_HOST, &v));
40 
41   if (t_mode == CEED_TRANSPOSE && !apply_add) {
42     CeedSize len;
43 
44     CeedCallBackend(CeedVectorGetLength(V, &len));
45     for (CeedInt i = 0; i < len; i++) v[i] = 0.0;
46   }
47 
48   CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor_basis));
49   if (is_tensor_basis) {
50     // Tensor basis
51     CeedInt P_1d, Q_1d;
52 
53     CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
54     CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
55     switch (eval_mode) {
56       // Interpolate to/from quadrature points
57       case CEED_EVAL_INTERP: {
58         if (impl->is_collocated) {
59           memcpy(v, u, num_elem * num_comp * num_nodes * sizeof(u[0]));
60         } else {
61           CeedInt P = P_1d, Q = Q_1d;
62 
63           if (t_mode == CEED_TRANSPOSE) {
64             P = Q_1d;
65             Q = P_1d;
66           }
67           CeedInt           pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
68           CeedScalar        tmp[2][num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
69           const CeedScalar *interp_1d;
70 
71           CeedCallBackend(CeedBasisGetInterp1D(basis, &interp_1d));
72           for (CeedInt d = 0; d < dim; d++) {
73             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, interp_1d, t_mode, add && (d == dim - 1), d == 0 ? u : tmp[d % 2],
74                                                     d == dim - 1 ? v : tmp[(d + 1) % 2]));
75             pre /= P;
76             post *= Q;
77           }
78         }
79       } break;
80       // Evaluate the gradient to/from quadrature points
81       case CEED_EVAL_GRAD: {
82         // In CEED_NOTRANSPOSE mode:
83         // u has shape [dim, num_comp, P^dim, num_elem], row-major layout
84         // v has shape [dim, num_comp, Q^dim, num_elem], row-major layout
85         // In CEED_TRANSPOSE mode, the sizes of u and v are switched.
86         CeedInt P = P_1d, Q = Q_1d;
87 
88         if (t_mode == CEED_TRANSPOSE) {
89           P = Q_1d;
90           Q = Q_1d;
91         }
92         CeedInt           pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
93         const CeedScalar *interp_1d;
94 
95         CeedCallBackend(CeedBasisGetInterp1D(basis, &interp_1d));
96         if (impl->collo_grad_1d) {
97           CeedScalar tmp[2][num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
98           CeedScalar interp[num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
99 
100           // Interpolate to quadrature points (NoTranspose)
101           //  or Grad to quadrature points (Transpose)
102           for (CeedInt d = 0; d < dim; d++) {
103             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, (t_mode == CEED_NOTRANSPOSE ? interp_1d : impl->collo_grad_1d), t_mode,
104                                                     (t_mode == CEED_TRANSPOSE) && (d > 0),
105                                                     (t_mode == CEED_NOTRANSPOSE ? (d == 0 ? u : tmp[d % 2]) : &u[d * num_qpts * num_comp * num_elem]),
106                                                     (t_mode == CEED_NOTRANSPOSE ? (d == dim - 1 ? interp : tmp[(d + 1) % 2]) : interp)));
107             pre /= P;
108             post *= Q;
109           }
110           // Grad to quadrature points (NoTranspose)
111           //  or Interpolate to nodes (Transpose)
112           P = Q_1d, Q = Q_1d;
113           if (t_mode == CEED_TRANSPOSE) {
114             P = Q_1d;
115             Q = P_1d;
116           }
117           pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
118           for (CeedInt d = 0; d < dim; d++) {
119             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, (t_mode == CEED_NOTRANSPOSE ? impl->collo_grad_1d : interp_1d), t_mode,
120                                                     (t_mode == CEED_NOTRANSPOSE && apply_add) || (t_mode == CEED_TRANSPOSE && (d == dim - 1)),
121                                                     (t_mode == CEED_NOTRANSPOSE ? interp : (d == 0 ? interp : tmp[d % 2])),
122                                                     (t_mode == CEED_NOTRANSPOSE ? &v[d * num_qpts * num_comp * num_elem]
123                                                                                 : (d == dim - 1 ? v : tmp[(d + 1) % 2]))));
124             pre /= P;
125             post *= Q;
126           }
127         } else if (impl->is_collocated) {  // Qpts collocated with nodes
128           const CeedScalar *grad_1d;
129 
130           CeedCallBackend(CeedBasisGetGrad1D(basis, &grad_1d));
131 
132           // Dim contractions, identity in other directions
133           CeedInt pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
134 
135           for (CeedInt d = 0; d < dim; d++) {
136             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, grad_1d, t_mode, add && (d > 0),
137                                                     t_mode == CEED_NOTRANSPOSE ? u : &u[d * num_comp * num_qpts * num_elem],
138                                                     t_mode == CEED_TRANSPOSE ? v : &v[d * num_comp * num_qpts * num_elem]));
139             pre /= P;
140             post *= Q;
141           }
142         } else {  // Underintegration, P > Q
143           const CeedScalar *grad_1d;
144 
145           CeedCallBackend(CeedBasisGetGrad1D(basis, &grad_1d));
146 
147           if (t_mode == CEED_TRANSPOSE) {
148             P = Q_1d;
149             Q = P_1d;
150           }
151           CeedScalar tmp[2][num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
152 
153           // Dim**2 contractions, apply grad when pass == dim
154           for (CeedInt p = 0; p < dim; p++) {
155             CeedInt pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
156 
157             for (CeedInt d = 0; d < dim; d++) {
158               CeedCallBackend(CeedTensorContractApply(
159                   contract, pre, P, post, Q, (p == d) ? grad_1d : interp_1d, t_mode, add && (d == dim - 1),
160                   (d == 0 ? (t_mode == CEED_NOTRANSPOSE ? u : &u[p * num_comp * num_qpts * num_elem]) : tmp[d % 2]),
161                   (d == dim - 1 ? (t_mode == CEED_TRANSPOSE ? v : &v[p * num_comp * num_qpts * num_elem]) : tmp[(d + 1) % 2])));
162               pre /= P;
163               post *= Q;
164             }
165           }
166         }
167       } break;
168       // Retrieve interpolation weights
169       case CEED_EVAL_WEIGHT: {
170         CeedInt           Q = Q_1d;
171         const CeedScalar *q_weight_1d;
172 
173         CeedCheck(t_mode == CEED_NOTRANSPOSE, CeedBasisReturnCeed(basis), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
174         CeedCallBackend(CeedBasisGetQWeights(basis, &q_weight_1d));
175         for (CeedInt d = 0; d < dim; d++) {
176           CeedInt pre = CeedIntPow(Q, dim - d - 1), post = CeedIntPow(Q, d);
177 
178           for (CeedInt i = 0; i < pre; i++) {
179             for (CeedInt j = 0; j < Q; j++) {
180               for (CeedInt k = 0; k < post; k++) {
181                 const CeedScalar w = q_weight_1d[j] * (d == 0 ? 1 : v[((i * Q + j) * post + k) * num_elem]);
182 
183                 for (CeedInt e = 0; e < num_elem; e++) v[((i * Q + j) * post + k) * num_elem + e] = w;
184               }
185             }
186           }
187         }
188       } break;
189       // LCOV_EXCL_START
190       case CEED_EVAL_DIV:
191       case CEED_EVAL_CURL:
192         return CeedError(CeedBasisReturnCeed(basis), CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
193       case CEED_EVAL_NONE:
194         return CeedError(CeedBasisReturnCeed(basis), CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
195         // LCOV_EXCL_STOP
196     }
197   } else {
198     // Non-tensor basis
199     CeedInt P = num_nodes, Q = num_qpts;
200 
201     switch (eval_mode) {
202       // Interpolate to/from quadrature points
203       case CEED_EVAL_INTERP: {
204         const CeedScalar *interp;
205 
206         CeedCallBackend(CeedBasisGetInterp(basis, &interp));
207         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, interp, t_mode, add, u, v));
208       } break;
209       // Evaluate the gradient to/from quadrature points
210       case CEED_EVAL_GRAD: {
211         const CeedScalar *grad;
212 
213         CeedCallBackend(CeedBasisGetGrad(basis, &grad));
214         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, grad, t_mode, add, u, v));
215       } break;
216       // Evaluate the divergence to/from the quadrature points
217       case CEED_EVAL_DIV: {
218         const CeedScalar *div;
219 
220         CeedCallBackend(CeedBasisGetDiv(basis, &div));
221         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, div, t_mode, add, u, v));
222       } break;
223       // Evaluate the curl to/from the quadrature points
224       case CEED_EVAL_CURL: {
225         const CeedScalar *curl;
226 
227         CeedCallBackend(CeedBasisGetCurl(basis, &curl));
228         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, curl, t_mode, add, u, v));
229       } break;
230       // Retrieve interpolation weights
231       case CEED_EVAL_WEIGHT: {
232         const CeedScalar *q_weight;
233 
234         CeedCheck(t_mode == CEED_NOTRANSPOSE, CeedBasisReturnCeed(basis), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
235         CeedCallBackend(CeedBasisGetQWeights(basis, &q_weight));
236         for (CeedInt i = 0; i < num_qpts; i++) {
237           for (CeedInt e = 0; e < num_elem; e++) v[i * num_elem + e] = q_weight[i];
238         }
239       } break;
240       // LCOV_EXCL_START
241       case CEED_EVAL_NONE:
242         return CeedError(CeedBasisReturnCeed(basis), CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
243         // LCOV_EXCL_STOP
244     }
245   }
246   if (U != CEED_VECTOR_NONE) {
247     CeedCallBackend(CeedVectorRestoreArrayRead(U, &u));
248   }
249   CeedCallBackend(CeedVectorRestoreArray(V, &v));
250   return CEED_ERROR_SUCCESS;
251 }
252 
253 static int CeedBasisApply_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector U, CeedVector V) {
254   CeedCallBackend(CeedBasisApplyCore_Ref(basis, false, num_elem, t_mode, eval_mode, U, V));
255   return CEED_ERROR_SUCCESS;
256 }
257 
258 static int CeedBasisApplyAdd_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector U, CeedVector V) {
259   CeedCallBackend(CeedBasisApplyCore_Ref(basis, true, num_elem, t_mode, eval_mode, U, V));
260   return CEED_ERROR_SUCCESS;
261 }
262 
263 //------------------------------------------------------------------------------
264 // Basis Destroy Tensor
265 //------------------------------------------------------------------------------
266 static int CeedBasisDestroyTensor_Ref(CeedBasis basis) {
267   CeedBasis_Ref *impl;
268 
269   CeedCallBackend(CeedBasisGetData(basis, &impl));
270   CeedCallBackend(CeedFree(&impl->collo_grad_1d));
271   CeedCallBackend(CeedFree(&impl));
272   return CEED_ERROR_SUCCESS;
273 }
274 
275 //------------------------------------------------------------------------------
276 // Basis Create Tensor
277 //------------------------------------------------------------------------------
278 int CeedBasisCreateTensorH1_Ref(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
279                                 const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
280   Ceed               ceed, ceed_parent;
281   CeedBasis_Ref     *impl;
282   CeedTensorContract contract;
283 
284   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
285   CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
286 
287   CeedCallBackend(CeedCalloc(1, &impl));
288   // Calculate collocated grad
289   CeedCallBackend(CeedBasisIsCollocated(basis, &impl->is_collocated));
290   if (Q_1d >= P_1d && !impl->is_collocated) {
291     CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &impl->collo_grad_1d));
292     CeedCallBackend(CeedBasisGetCollocatedGrad(basis, impl->collo_grad_1d));
293   }
294   CeedCallBackend(CeedBasisSetData(basis, impl));
295 
296   CeedCallBackend(CeedTensorContractCreate(ceed_parent, &contract));
297   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
298 
299   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
300   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAdd_Ref));
301   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyTensor_Ref));
302   CeedCallBackend(CeedDestroy(&ceed));
303   CeedCallBackend(CeedDestroy(&ceed_parent));
304   return CEED_ERROR_SUCCESS;
305 }
306 
307 //------------------------------------------------------------------------------
308 // Basis Create Non-Tensor H^1
309 //------------------------------------------------------------------------------
310 int CeedBasisCreateH1_Ref(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
311                           const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
312   Ceed               ceed, ceed_parent;
313   CeedTensorContract contract;
314 
315   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
316   CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
317 
318   CeedCallBackend(CeedTensorContractCreate(ceed_parent, &contract));
319   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
320 
321   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
322   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAdd_Ref));
323   CeedCallBackend(CeedDestroy(&ceed));
324   CeedCallBackend(CeedDestroy(&ceed_parent));
325   return CEED_ERROR_SUCCESS;
326 }
327 
328 //------------------------------------------------------------------------------
329 // Basis Create Non-Tensor H(div)
330 //------------------------------------------------------------------------------
331 int CeedBasisCreateHdiv_Ref(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *div,
332                             const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
333   Ceed               ceed, ceed_parent;
334   CeedTensorContract contract;
335 
336   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
337   CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
338 
339   CeedCallBackend(CeedTensorContractCreate(ceed_parent, &contract));
340   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
341 
342   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
343   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAdd_Ref));
344   CeedCallBackend(CeedDestroy(&ceed));
345   CeedCallBackend(CeedDestroy(&ceed_parent));
346   return CEED_ERROR_SUCCESS;
347 }
348 
349 //------------------------------------------------------------------------------
350 // Basis Create Non-Tensor H(curl)
351 //------------------------------------------------------------------------------
352 int CeedBasisCreateHcurl_Ref(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
353                              const CeedScalar *curl, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
354   Ceed               ceed, ceed_parent;
355   CeedTensorContract contract;
356 
357   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
358   CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
359 
360   CeedCallBackend(CeedTensorContractCreate(ceed_parent, &contract));
361   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
362 
363   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
364   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAdd_Ref));
365   CeedCallBackend(CeedDestroy(&ceed));
366   CeedCallBackend(CeedDestroy(&ceed_parent));
367   return CEED_ERROR_SUCCESS;
368 }
369 
370 //------------------------------------------------------------------------------
371