xref: /libCEED/backends/ref/ceed-ref-basis.c (revision e64bb3f3ed2986a0c10dec3b47522d734c6e367d)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 
14 #include "ceed-ref.h"
15 
16 //------------------------------------------------------------------------------
17 // Basis Apply
18 //------------------------------------------------------------------------------
19 static int CeedBasisApply_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector U, CeedVector V) {
20   Ceed ceed;
21   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
22   CeedInt dim, num_comp, q_comp, num_nodes, num_qpts;
23   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
24   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
25   CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
26   CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
27   CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
28   CeedTensorContract contract;
29   CeedCallBackend(CeedBasisGetTensorContract(basis, &contract));
30   const CeedInt     add = (t_mode == CEED_TRANSPOSE);
31   const CeedScalar *u;
32   CeedScalar       *v;
33   if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_HOST, &u));
34   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
35   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_HOST, &v));
36 
37   // Clear v if operating in transpose
38   if (t_mode == CEED_TRANSPOSE) {
39     const CeedInt v_size = num_elem * num_comp * num_nodes;
40     for (CeedInt i = 0; i < v_size; i++) v[i] = (CeedScalar)0.0;
41   }
42   bool is_tensor_basis;
43   CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor_basis));
44   if (is_tensor_basis) {
45     // Tensor basis
46     CeedInt P_1d, Q_1d;
47     CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
48     CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
49     switch (eval_mode) {
50       // Interpolate to/from quadrature points
51       case CEED_EVAL_INTERP: {
52         CeedBasis_Ref *impl;
53         CeedCallBackend(CeedBasisGetData(basis, &impl));
54         if (impl->has_collo_interp) {
55           memcpy(v, u, num_elem * num_comp * num_nodes * sizeof(u[0]));
56         } else {
57           CeedInt P = P_1d, Q = Q_1d;
58           if (t_mode == CEED_TRANSPOSE) {
59             P = Q_1d;
60             Q = P_1d;
61           }
62           CeedInt           pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
63           CeedScalar        tmp[2][num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
64           const CeedScalar *interp_1d;
65           CeedCallBackend(CeedBasisGetInterp1D(basis, &interp_1d));
66           for (CeedInt d = 0; d < dim; d++) {
67             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, interp_1d, t_mode, add && (d == dim - 1), d == 0 ? u : tmp[d % 2],
68                                                     d == dim - 1 ? v : tmp[(d + 1) % 2]));
69             pre /= P;
70             post *= Q;
71           }
72         }
73       } break;
74       // Evaluate the gradient to/from quadrature points
75       case CEED_EVAL_GRAD: {
76         // In CEED_NOTRANSPOSE mode:
77         // u has shape [dim, num_comp, P^dim, num_elem], row-major layout
78         // v has shape [dim, num_comp, Q^dim, num_elem], row-major layout
79         // In CEED_TRANSPOSE mode, the sizes of u and v are switched.
80         CeedInt P = P_1d, Q = Q_1d;
81         if (t_mode == CEED_TRANSPOSE) {
82           P = Q_1d;
83           Q = Q_1d;
84         }
85         CeedBasis_Ref *impl;
86         CeedCallBackend(CeedBasisGetData(basis, &impl));
87         CeedInt           pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
88         const CeedScalar *interp_1d;
89         CeedCallBackend(CeedBasisGetInterp1D(basis, &interp_1d));
90         if (impl->collo_grad_1d) {
91           CeedScalar tmp[2][num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
92           CeedScalar interp[num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
93           // Interpolate to quadrature points (NoTranspose)
94           //  or Grad to quadrature points (Transpose)
95           for (CeedInt d = 0; d < dim; d++) {
96             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, (t_mode == CEED_NOTRANSPOSE ? interp_1d : impl->collo_grad_1d), t_mode,
97                                                     add && (d > 0),
98                                                     (t_mode == CEED_NOTRANSPOSE ? (d == 0 ? u : tmp[d % 2]) : u + d * num_qpts * num_comp * num_elem),
99                                                     (t_mode == CEED_NOTRANSPOSE ? (d == dim - 1 ? interp : tmp[(d + 1) % 2]) : interp)));
100             pre /= P;
101             post *= Q;
102           }
103           // Grad to quadrature points (NoTranspose)
104           //  or Interpolate to nodes (Transpose)
105           P = Q_1d, Q = Q_1d;
106           if (t_mode == CEED_TRANSPOSE) {
107             P = Q_1d;
108             Q = P_1d;
109           }
110           pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
111           for (CeedInt d = 0; d < dim; d++) {
112             CeedCallBackend(CeedTensorContractApply(
113                 contract, pre, P, post, Q, (t_mode == CEED_NOTRANSPOSE ? impl->collo_grad_1d : interp_1d), t_mode, add && (d == dim - 1),
114                 (t_mode == CEED_NOTRANSPOSE ? interp : (d == 0 ? interp : tmp[d % 2])),
115                 (t_mode == CEED_NOTRANSPOSE ? v + d * num_qpts * num_comp * num_elem : (d == dim - 1 ? v : tmp[(d + 1) % 2]))));
116             pre /= P;
117             post *= Q;
118           }
119         } else if (impl->has_collo_interp) {  // Qpts collocated with nodes
120           const CeedScalar *grad_1d;
121           CeedCallBackend(CeedBasisGetGrad1D(basis, &grad_1d));
122 
123           // Dim contractions, identity in other directions
124           CeedInt pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
125           for (CeedInt d = 0; d < dim; d++) {
126             CeedCallBackend(CeedTensorContractApply(contract, pre, P, post, Q, grad_1d, t_mode, add && (d > 0),
127                                                     t_mode == CEED_NOTRANSPOSE ? u : u + d * num_comp * num_qpts * num_elem,
128                                                     t_mode == CEED_TRANSPOSE ? v : v + d * num_comp * num_qpts * num_elem));
129             pre /= P;
130             post *= Q;
131           }
132         } else {  // Underintegration, P > Q
133           const CeedScalar *grad_1d;
134           CeedCallBackend(CeedBasisGetGrad1D(basis, &grad_1d));
135 
136           if (t_mode == CEED_TRANSPOSE) {
137             P = Q_1d;
138             Q = P_1d;
139           }
140           CeedScalar tmp[2][num_elem * num_comp * Q * CeedIntPow(P > Q ? P : Q, dim - 1)];
141 
142           // Dim**2 contractions, apply grad when pass == dim
143           for (CeedInt p = 0; p < dim; p++) {
144             CeedInt pre = num_comp * CeedIntPow(P, dim - 1), post = num_elem;
145             for (CeedInt d = 0; d < dim; d++) {
146               CeedCallBackend(CeedTensorContractApply(
147                   contract, pre, P, post, Q, (p == d) ? grad_1d : interp_1d, t_mode, add && (d == dim - 1),
148                   (d == 0 ? (t_mode == CEED_NOTRANSPOSE ? u : u + p * num_comp * num_qpts * num_elem) : tmp[d % 2]),
149                   (d == dim - 1 ? (t_mode == CEED_TRANSPOSE ? v : v + p * num_comp * num_qpts * num_elem) : tmp[(d + 1) % 2])));
150               pre /= P;
151               post *= Q;
152             }
153           }
154         }
155       } break;
156       // Retrieve interpolation weights
157       case CEED_EVAL_WEIGHT: {
158         CeedCheck(t_mode == CEED_NOTRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
159         CeedInt           Q = Q_1d;
160         const CeedScalar *q_weight_1d;
161         CeedCallBackend(CeedBasisGetQWeights(basis, &q_weight_1d));
162         for (CeedInt d = 0; d < dim; d++) {
163           CeedInt pre = CeedIntPow(Q, dim - d - 1), post = CeedIntPow(Q, d);
164           for (CeedInt i = 0; i < pre; i++) {
165             for (CeedInt j = 0; j < Q; j++) {
166               for (CeedInt k = 0; k < post; k++) {
167                 CeedScalar w = q_weight_1d[j] * (d == 0 ? 1 : v[((i * Q + j) * post + k) * num_elem]);
168                 for (CeedInt e = 0; e < num_elem; e++) v[((i * Q + j) * post + k) * num_elem + e] = w;
169               }
170             }
171           }
172         }
173       } break;
174       // LCOV_EXCL_START
175       // Evaluate the divergence to/from the quadrature points
176       case CEED_EVAL_DIV:
177         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
178       // Evaluate the curl to/from the quadrature points
179       case CEED_EVAL_CURL:
180         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
181       // Take no action, BasisApply should not have been called
182       case CEED_EVAL_NONE:
183         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
184         // LCOV_EXCL_STOP
185     }
186   } else {
187     // Non-tensor basis
188     CeedInt P = num_nodes, Q = num_qpts;
189     switch (eval_mode) {
190       // Interpolate to/from quadrature points
191       case CEED_EVAL_INTERP: {
192         const CeedScalar *interp;
193         CeedCallBackend(CeedBasisGetInterp(basis, &interp));
194         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, interp, t_mode, add, u, v));
195       } break;
196       // Evaluate the gradient to/from quadrature points
197       case CEED_EVAL_GRAD: {
198         const CeedScalar *grad;
199         CeedCallBackend(CeedBasisGetGrad(basis, &grad));
200         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, grad, t_mode, add, u, v));
201       } break;
202       // Evaluate the divergence to/from the quadrature points
203       case CEED_EVAL_DIV: {
204         const CeedScalar *div;
205         CeedCallBackend(CeedBasisGetDiv(basis, &div));
206         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, div, t_mode, add, u, v));
207       } break;
208       // Evaluate the curl to/from the quadrature points
209       case CEED_EVAL_CURL: {
210         const CeedScalar *curl;
211         CeedCallBackend(CeedBasisGetCurl(basis, &curl));
212         CeedCallBackend(CeedTensorContractStridedApply(contract, num_comp, P, num_elem, q_comp, Q, curl, t_mode, add, u, v));
213       } break;
214       // Retrieve interpolation weights
215       case CEED_EVAL_WEIGHT: {
216         CeedCheck(t_mode == CEED_NOTRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
217         const CeedScalar *q_weight;
218         CeedCallBackend(CeedBasisGetQWeights(basis, &q_weight));
219         for (CeedInt i = 0; i < num_qpts; i++) {
220           for (CeedInt e = 0; e < num_elem; e++) v[i * num_elem + e] = q_weight[i];
221         }
222       } break;
223       // LCOV_EXCL_START
224       // Take no action, BasisApply should not have been called
225       case CEED_EVAL_NONE:
226         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
227         // LCOV_EXCL_STOP
228     }
229   }
230   if (U != CEED_VECTOR_NONE) {
231     CeedCallBackend(CeedVectorRestoreArrayRead(U, &u));
232   }
233   CeedCallBackend(CeedVectorRestoreArray(V, &v));
234 
235   return CEED_ERROR_SUCCESS;
236 }
237 
238 //------------------------------------------------------------------------------
239 // Basis Destroy Tensor
240 //------------------------------------------------------------------------------
241 static int CeedBasisDestroyTensor_Ref(CeedBasis basis) {
242   CeedBasis_Ref *impl;
243   CeedCallBackend(CeedBasisGetData(basis, &impl));
244   CeedCallBackend(CeedFree(&impl->collo_grad_1d));
245   CeedCallBackend(CeedFree(&impl));
246 
247   return CEED_ERROR_SUCCESS;
248 }
249 
250 //------------------------------------------------------------------------------
251 // Basis Create Tensor
252 //------------------------------------------------------------------------------
253 int CeedBasisCreateTensorH1_Ref(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
254                                 const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
255   Ceed ceed;
256   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
257   CeedBasis_Ref *impl;
258   CeedCallBackend(CeedCalloc(1, &impl));
259   // Check for collocated interp
260   if (Q_1d == P_1d) {
261     bool collocated = 1;
262     for (CeedInt i = 0; i < P_1d; i++) {
263       collocated = collocated && (fabs(interp_1d[i + P_1d * i] - 1.0) < 1e-14);
264       for (CeedInt j = 0; j < P_1d; j++) {
265         if (j != i) collocated = collocated && (fabs(interp_1d[j + P_1d * i]) < 1e-14);
266       }
267     }
268     impl->has_collo_interp = collocated;
269   }
270   // Calculate collocated grad
271   if (Q_1d >= P_1d && !impl->has_collo_interp) {
272     CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &impl->collo_grad_1d));
273     CeedCallBackend(CeedBasisGetCollocatedGrad(basis, impl->collo_grad_1d));
274   }
275   CeedCallBackend(CeedBasisSetData(basis, impl));
276 
277   Ceed parent;
278   CeedCallBackend(CeedGetParent(ceed, &parent));
279   CeedTensorContract contract;
280   CeedCallBackend(CeedTensorContractCreate(parent, basis, &contract));
281   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
282 
283   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
284   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyTensor_Ref));
285 
286   return CEED_ERROR_SUCCESS;
287 }
288 
289 //------------------------------------------------------------------------------
290 // Basis Create Non-Tensor H^1
291 //------------------------------------------------------------------------------
292 int CeedBasisCreateH1_Ref(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
293                           const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
294   Ceed ceed;
295   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
296 
297   Ceed parent;
298   CeedCallBackend(CeedGetParent(ceed, &parent));
299   CeedTensorContract contract;
300   CeedCallBackend(CeedTensorContractCreate(parent, basis, &contract));
301   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
302 
303   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
304 
305   return CEED_ERROR_SUCCESS;
306 }
307 
308 //------------------------------------------------------------------------------
309 // Basis Create Non-Tensor H(div)
310 //------------------------------------------------------------------------------
311 int CeedBasisCreateHdiv_Ref(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *div,
312                             const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
313   Ceed ceed;
314   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
315 
316   Ceed parent;
317   CeedCallBackend(CeedGetParent(ceed, &parent));
318   CeedTensorContract contract;
319   CeedCallBackend(CeedTensorContractCreate(parent, basis, &contract));
320   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
321 
322   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
323 
324   return CEED_ERROR_SUCCESS;
325 }
326 
327 //------------------------------------------------------------------------------
328 // Basis Create Non-Tensor H(curl)
329 //------------------------------------------------------------------------------
330 int CeedBasisCreateHcurl_Ref(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
331                              const CeedScalar *curl, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
332   Ceed ceed;
333   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
334 
335   Ceed parent;
336   CeedCallBackend(CeedGetParent(ceed, &parent));
337   CeedTensorContract contract;
338   CeedCallBackend(CeedTensorContractCreate(parent, basis, &contract));
339   CeedCallBackend(CeedBasisSetTensorContract(basis, contract));
340 
341   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Ref));
342 
343   return CEED_ERROR_SUCCESS;
344 }
345 
346 //------------------------------------------------------------------------------
347