xref: /libCEED/backends/ref/ceed-ref-basis.c (revision a1dbd226388b1aa52ab1a7f0d0bca8c2c2ac2292)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-ref.h"
18 
19 static int CeedBasisApply_Ref(CeedBasis basis, CeedInt nelem,
20                               CeedTransposeMode tmode, CeedEvalMode emode,
21                               CeedVector U, CeedVector V) {
22   int ierr;
23   Ceed ceed;
24   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
25   CeedInt dim, ncomp, nnodes, nqpt;
26   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
27   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
28   ierr = CeedBasisGetNumNodes(basis, &nnodes); CeedChk(ierr);
29   ierr = CeedBasisGetNumQuadraturePoints(basis, &nqpt); CeedChk(ierr);
30   CeedTensorContract contract;
31   ierr = CeedBasisGetTensorContract(basis, &contract); CeedChk(ierr);
32   const CeedInt add = (tmode == CEED_TRANSPOSE);
33   const CeedScalar *u;
34   CeedScalar *v;
35   if (U) {
36     ierr = CeedVectorGetArrayRead(U, CEED_MEM_HOST, &u); CeedChk(ierr);
37   } else if (emode != CEED_EVAL_WEIGHT) {
38     return CeedError(ceed, 1,
39                      "An input vector is required for this CeedEvalMode");
40   }
41   ierr = CeedVectorGetArray(V, CEED_MEM_HOST, &v); CeedChk(ierr);
42 
43   // Clear v if operating in transpose
44   if (tmode == CEED_TRANSPOSE) {
45     const CeedInt vsize = nelem*ncomp*nnodes;
46     for (CeedInt i = 0; i < vsize; i++)
47       v[i] = (CeedScalar) 0.0;
48   }
49   bool tensorbasis;
50   ierr = CeedBasisGetTensorStatus(basis, &tensorbasis); CeedChk(ierr);
51   // Tensor basis
52   if (tensorbasis) {
53     CeedInt P1d, Q1d;
54     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
55     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
56     switch (emode) {
57     // Interpolate to/from quadrature points
58     case CEED_EVAL_INTERP: {
59       CeedBasis_Ref *impl;
60       ierr = CeedBasisGetData(basis, (void *)&impl); CeedChk(ierr);
61       if (impl->collointerp) {
62         memcpy(v, u, nelem*ncomp*nnodes*sizeof(u[0]));
63       } else {
64         CeedInt P = P1d, Q = Q1d;
65         if (tmode == CEED_TRANSPOSE) {
66           P = Q1d; Q = P1d;
67         }
68         CeedInt pre = ncomp*CeedIntPow(P, dim-1), post = nelem;
69         CeedScalar tmp[2][nelem*ncomp*Q*CeedIntPow(P>Q?P:Q, dim-1)];
70         CeedScalar *interp1d;
71         ierr = CeedBasisGetInterp(basis, &interp1d); CeedChk(ierr);
72         for (CeedInt d=0; d<dim; d++) {
73           ierr = CeedTensorContractApply(contract, pre, P, post, Q,
74                                          interp1d, tmode, add&&(d==dim-1),
75                                          d==0?u:tmp[d%2],
76                                          d==dim-1?v:tmp[(d+1)%2]);
77           CeedChk(ierr);
78           pre /= P;
79           post *= Q;
80         }
81       }
82     } break;
83     // Evaluate the gradient to/from quadrature points
84     case CEED_EVAL_GRAD: {
85       // In CEED_NOTRANSPOSE mode:
86       // u has shape [dim, ncomp, P^dim, nelem], row-major layout
87       // v has shape [dim, ncomp, Q^dim, nelem], row-major layout
88       // In CEED_TRANSPOSE mode, the sizes of u and v are switched.
89       CeedInt P = P1d, Q = Q1d;
90       if (tmode == CEED_TRANSPOSE) {
91         P = Q1d, Q = Q1d;
92       }
93       CeedBasis_Ref *impl;
94       ierr = CeedBasisGetData(basis, (void *)&impl); CeedChk(ierr);
95       CeedInt pre = ncomp*CeedIntPow(P, dim-1), post = nelem;
96       CeedScalar *interp1d;
97       ierr = CeedBasisGetInterp(basis, &interp1d); CeedChk(ierr);
98       if (impl->collograd1d) {
99         CeedScalar tmp[2][nelem*ncomp*Q*CeedIntPow(P>Q?P:Q, dim-1)];
100         CeedScalar interp[nelem*ncomp*Q*CeedIntPow(P>Q?P:Q, dim-1)];
101         // Interpolate to quadrature points (NoTranspose)
102         //  or Grad to quadrature points (Transpose)
103         for (CeedInt d=0; d<dim; d++) {
104           ierr = CeedTensorContractApply(contract, pre, P, post, Q,
105                                          (tmode == CEED_NOTRANSPOSE
106                                           ? interp1d
107                                           : impl->collograd1d),
108                                          tmode, add&&(d>0),
109                                          (tmode == CEED_NOTRANSPOSE
110                                           ? (d==0?u:tmp[d%2])
111                                           : u + d*nqpt*ncomp*nelem),
112                                          (tmode == CEED_NOTRANSPOSE
113                                           ? (d==dim-1?interp:tmp[(d+1)%2])
114                                           : interp));
115           CeedChk(ierr);
116           pre /= P;
117           post *= Q;
118         }
119         // Grad to quadrature points (NoTranspose)
120         //  or Interpolate to nodes (Transpose)
121         P = Q1d, Q = Q1d;
122         if (tmode == CEED_TRANSPOSE) {
123           P = Q1d, Q = P1d;
124         }
125         pre = ncomp*CeedIntPow(P, dim-1), post = nelem;
126         for (CeedInt d=0; d<dim; d++) {
127           ierr = CeedTensorContractApply(contract, pre, P, post, Q,
128                                          (tmode == CEED_NOTRANSPOSE
129                                           ? impl->collograd1d
130                                           : interp1d),
131                                          tmode, add&&(d==dim-1),
132                                          (tmode == CEED_NOTRANSPOSE
133                                           ? interp
134                                           : (d==0?interp:tmp[d%2])),
135                                          (tmode == CEED_NOTRANSPOSE
136                                           ? v + d*nqpt*ncomp*nelem
137                                           : (d==dim-1?v:tmp[(d+1)%2])));
138           CeedChk(ierr);
139           pre /= P;
140           post *= Q;
141         }
142       } else if (impl->collointerp) { // Qpts collocated with nodes
143         CeedScalar *grad1d;
144         ierr = CeedBasisGetGrad(basis, &grad1d); CeedChk(ierr);
145 
146         // Dim contractions, identity in other directions
147         for (CeedInt d=0; d<dim; d++) {
148           CeedInt pre = ncomp*CeedIntPow(P, dim-1), post = nelem;
149           ierr = CeedTensorContractApply(contract, pre, P, post, Q,
150                                          grad1d, tmode, add&&(d>0),
151                                          tmode == CEED_NOTRANSPOSE
152                                            ? u : u+d*ncomp*nqpt*nelem,
153                                          tmode == CEED_TRANSPOSE
154                                            ? v : v+d*ncomp*nqpt*nelem);
155           CeedChk(ierr);
156         }
157       } else { // Underintegration, P > Q
158         CeedScalar *grad1d;
159         ierr = CeedBasisGetGrad(basis, &grad1d); CeedChk(ierr);
160 
161         if (tmode == CEED_TRANSPOSE) {
162           P = Q1d, Q = P1d;
163         }
164         CeedScalar tmp[2][nelem*ncomp*Q*CeedIntPow(P>Q?P:Q, dim-1)];
165 
166         // Dim**2 contractions, apply grad when pass == dim
167         for (CeedInt p=0; p<dim; p++) {
168           CeedInt pre = ncomp*CeedIntPow(P, dim-1), post = nelem;
169           for (CeedInt d=0; d<dim; d++) {
170             ierr = CeedTensorContractApply(contract, pre, P, post, Q,
171                                            (p==d)? grad1d : interp1d,
172                                            tmode, add&&(d==dim-1),
173                                            (d == 0
174                                             ? (tmode == CEED_NOTRANSPOSE
175                                                ? u : u+p*ncomp*nqpt*nelem)
176                                             : tmp[d%2]),
177                                            (d == dim-1
178                                             ? (tmode == CEED_TRANSPOSE
179                                                ? v : v+p*ncomp*nqpt*nelem)
180                                             : tmp[(d+1)%2]));
181             CeedChk(ierr);
182             pre /= P;
183             post *= Q;
184           }
185         }
186       }
187     } break;
188     // Retrieve interpolation weights
189     case CEED_EVAL_WEIGHT: {
190       if (tmode == CEED_TRANSPOSE)
191         return CeedError(ceed, 1,
192                          "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
193       CeedInt Q = Q1d;
194       CeedScalar *qweight1d;
195       ierr = CeedBasisGetQWeights(basis, &qweight1d); CeedChk(ierr);
196       for (CeedInt d=0; d<dim; d++) {
197         CeedInt pre = CeedIntPow(Q, dim-d-1), post = CeedIntPow(Q, d);
198         for (CeedInt i=0; i<pre; i++)
199           for (CeedInt j=0; j<Q; j++)
200             for (CeedInt k=0; k<post; k++) {
201               CeedScalar w = qweight1d[j]
202                              * (d == 0 ? 1 : v[((i*Q + j)*post + k)*nelem]);
203               for (CeedInt e=0; e<nelem; e++)
204                 v[((i*Q + j)*post + k)*nelem + e] = w;
205             }
206       }
207     } break;
208     // Evaluate the divergence to/from the quadrature points
209     case CEED_EVAL_DIV:
210       return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
211     // Evaluate the curl to/from the quadrature points
212     case CEED_EVAL_CURL:
213       return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
214     // Take no action, BasisApply should not have been called
215     case CEED_EVAL_NONE:
216       return CeedError(ceed, 1,
217                        "CEED_EVAL_NONE does not make sense in this context");
218     }
219   } else {
220     // Non-tensor basis
221     switch (emode) {
222     // Interpolate to/from quadrature points
223     case CEED_EVAL_INTERP: {
224       CeedInt P = nnodes, Q = nqpt;
225       CeedScalar *interp;
226       ierr = CeedBasisGetInterp(basis, &interp); CeedChk(ierr);
227       if (tmode == CEED_TRANSPOSE) {
228         P = nqpt; Q = nnodes;
229       }
230       ierr = CeedTensorContractApply(contract, ncomp, P, nelem, Q,
231                                      interp, tmode, add, u, v);
232       CeedChk(ierr);
233     }
234     break;
235     // Evaluate the gradient to/from quadrature points
236     case CEED_EVAL_GRAD: {
237       CeedInt P = nnodes, Q = dim*nqpt;
238       CeedScalar *grad;
239       ierr = CeedBasisGetGrad(basis, &grad); CeedChk(ierr);
240       if (tmode == CEED_TRANSPOSE) {
241         P = dim*nqpt; Q = nnodes;
242       }
243       ierr = CeedTensorContractApply(contract, ncomp, P, nelem, Q,
244                                      grad, tmode, add, u, v);
245       CeedChk(ierr);
246     }
247     break;
248     // Retrieve interpolation weights
249     case CEED_EVAL_WEIGHT: {
250       if (tmode == CEED_TRANSPOSE)
251         return CeedError(ceed, 1,
252                          "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
253       CeedScalar *qweight;
254       ierr = CeedBasisGetQWeights(basis, &qweight); CeedChk(ierr);
255       for (CeedInt i=0; i<nqpt; i++)
256         for (CeedInt e=0; e<nelem; e++)
257           v[i*nelem + e] = qweight[i];
258     } break;
259     // Evaluate the divergence to/from the quadrature points
260     case CEED_EVAL_DIV:
261       return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
262     // Evaluate the curl to/from the quadrature points
263     case CEED_EVAL_CURL:
264       return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
265     // Take no action, BasisApply should not have been called
266     case CEED_EVAL_NONE:
267       return CeedError(ceed, 1,
268                        "CEED_EVAL_NONE does not make sense in this context");
269     }
270   }
271   if (U) {
272     ierr = CeedVectorRestoreArrayRead(U, &u); CeedChk(ierr);
273   }
274   ierr = CeedVectorRestoreArray(V, &v); CeedChk(ierr);
275   return 0;
276 }
277 
278 static int CeedBasisDestroyNonTensor_Ref(CeedBasis basis) {
279   int ierr;
280   CeedTensorContract contract;
281   ierr = CeedBasisGetTensorContract(basis, &contract); CeedChk(ierr);
282   ierr = CeedTensorContractDestroy(&contract); CeedChk(ierr);
283   return 0;
284 }
285 
286 static int CeedBasisDestroyTensor_Ref(CeedBasis basis) {
287   int ierr;
288   CeedTensorContract contract;
289   ierr = CeedBasisGetTensorContract(basis, &contract); CeedChk(ierr);
290   ierr = CeedTensorContractDestroy(&contract); CeedChk(ierr);
291 
292   CeedBasis_Ref *impl;
293   ierr = CeedBasisGetData(basis, (void *)&impl); CeedChk(ierr);
294   ierr = CeedFree(&impl->collograd1d); CeedChk(ierr);
295   ierr = CeedFree(&impl); CeedChk(ierr);
296 
297   return 0;
298 }
299 
300 int CeedBasisCreateTensorH1_Ref(CeedInt dim, CeedInt P1d,
301                                 CeedInt Q1d, const CeedScalar *interp1d,
302                                 const CeedScalar *grad1d,
303                                 const CeedScalar *qref1d,
304                                 const CeedScalar *qweight1d,
305                                 CeedBasis basis) {
306   int ierr;
307   Ceed ceed;
308   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
309   CeedBasis_Ref *impl;
310   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
311   // Check for collocated interp
312   if (Q1d == P1d) {
313     bool collocated = 1;
314     for (CeedInt i=0; i<P1d; i++) {
315       collocated = collocated && (fabs(interp1d[i+P1d*i] - 1.0) < 1e-14);
316       for (CeedInt j=0; j<P1d; j++)
317         if (j != i)
318           collocated = collocated && (fabs(interp1d[j+P1d*i]) < 1e-14);
319     }
320     impl->collointerp = collocated;
321   }
322   // Calculate collocated grad
323   if (Q1d >= P1d && !impl->collointerp) {
324     ierr = CeedMalloc(Q1d*Q1d, &impl->collograd1d); CeedChk(ierr);
325     ierr = CeedBasisGetCollocatedGrad(basis, impl->collograd1d); CeedChk(ierr);
326   }
327   ierr = CeedBasisSetData(basis, (void *)&impl); CeedChk(ierr);
328 
329   Ceed parent;
330   ierr = CeedGetParent(ceed, &parent); CeedChk(ierr);
331   CeedTensorContract contract;
332   ierr = CeedTensorContractCreate(parent, basis, &contract); CeedChk(ierr);
333   ierr = CeedBasisSetTensorContract(basis, &contract); CeedChk(ierr);
334 
335   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
336                                 CeedBasisApply_Ref); CeedChk(ierr);
337   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
338                                 CeedBasisDestroyTensor_Ref); CeedChk(ierr);
339   return 0;
340 }
341 
342 
343 
344 int CeedBasisCreateH1_Ref(CeedElemTopology topo, CeedInt dim,
345                           CeedInt nnodes, CeedInt nqpts,
346                           const CeedScalar *interp,
347                           const CeedScalar *grad,
348                           const CeedScalar *qref,
349                           const CeedScalar *qweight,
350                           CeedBasis basis) {
351   int ierr;
352   Ceed ceed;
353   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
354 
355   Ceed parent;
356   ierr = CeedGetParent(ceed, &parent); CeedChk(ierr);
357   CeedTensorContract contract;
358   ierr = CeedTensorContractCreate(parent, basis, &contract); CeedChk(ierr);
359   ierr = CeedBasisSetTensorContract(basis, &contract); CeedChk(ierr);
360 
361   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
362                                 CeedBasisApply_Ref); CeedChk(ierr);
363   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
364                                 CeedBasisDestroyNonTensor_Ref); CeedChk(ierr);
365 
366   return 0;
367 }
368