xref: /libCEED/backends/ref/ceed-ref-basis.c (revision be9261b744a4f5557a9c58123cb6ee5694c9a713)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-ref.h"
18 
19 static int CeedBasisApply_Ref(CeedBasis basis, CeedInt nelem,
20                               CeedTransposeMode tmode, CeedEvalMode emode,
21                               CeedVector U, CeedVector V) {
22   int ierr;
23   Ceed ceed;
24   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
25   CeedInt dim, ncomp, ndof, nqpt;
26   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
27   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
28   ierr = CeedBasisGetNumNodes(basis, &ndof); CeedChk(ierr);
29   ierr = CeedBasisGetNumQuadraturePoints(basis, &nqpt); CeedChk(ierr);
30   CeedTensorContract contract;
31   ierr = CeedBasisGetTensorContract(basis, &contract); CeedChk(ierr);
32   const CeedInt add = (tmode == CEED_TRANSPOSE);
33   const CeedScalar *u;
34   CeedScalar *v;
35   if (U) {
36     ierr = CeedVectorGetArrayRead(U, CEED_MEM_HOST, &u); CeedChk(ierr);
37   } else if (emode != CEED_EVAL_WEIGHT) {
38     return CeedError(ceed, 1,
39                      "An input vector is required for this CeedEvalMode");
40   }
41   ierr = CeedVectorGetArray(V, CEED_MEM_HOST, &v); CeedChk(ierr);
42 
43   // Clear v if operating in transpose
44   if (tmode == CEED_TRANSPOSE) {
45     const CeedInt vsize = nelem*ncomp*ndof;
46     for (CeedInt i = 0; i < vsize; i++)
47       v[i] = (CeedScalar) 0.0;
48   }
49   bool tensorbasis;
50   ierr = CeedBasisGetTensorStatus(basis, &tensorbasis); CeedChk(ierr);
51   // Tensor basis
52   if (tensorbasis) {
53     CeedInt P1d, Q1d;
54     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
55     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
56     switch (emode) {
57     // Interpolate to/from quadrature points
58     case CEED_EVAL_INTERP: {
59       CeedInt P = P1d, Q = Q1d;
60       if (tmode == CEED_TRANSPOSE) {
61         P = Q1d; Q = P1d;
62       }
63       CeedInt pre = ncomp*CeedIntPow(P, dim-1), post = nelem;
64       CeedScalar tmp[2][nelem*ncomp*Q*CeedIntPow(P>Q?P:Q, dim-1)];
65       CeedScalar *interp1d;
66       ierr = CeedBasisGetInterp(basis, &interp1d); CeedChk(ierr);
67       for (CeedInt d=0; d<dim; d++) {
68         ierr = CeedTensorContractApply(contract, pre, P, post, Q,
69                                        interp1d, tmode, add&&(d==dim-1),
70                                        d==0?u:tmp[d%2], d==dim-1?v:tmp[(d+1)%2]);
71         CeedChk(ierr);
72         pre /= P;
73         post *= Q;
74       }
75     } break;
76     // Evaluate the gradient to/from quadrature points
77     case CEED_EVAL_GRAD: {
78       // In CEED_NOTRANSPOSE mode:
79       // u has shape [dim, ncomp, P^dim, nelem], row-major layout
80       // v has shape [dim, ncomp, Q^dim, nelem], row-major layout
81       // In CEED_TRANSPOSE mode, the sizes of u and v are switched.
82       CeedInt P = P1d, Q = Q1d;
83       if (tmode == CEED_TRANSPOSE) {
84         P = Q1d, Q = Q1d;
85       }
86       CeedBasis_Ref *impl;
87       ierr = CeedBasisGetData(basis, (void *)&impl); CeedChk(ierr);
88       CeedScalar interp[nelem*ncomp*Q*CeedIntPow(P>Q?P:Q, dim-1)];
89       CeedInt pre = ncomp*CeedIntPow(P, dim-1), post = nelem;
90       CeedScalar tmp[2][nelem*ncomp*Q*CeedIntPow(P>Q?P:Q, dim-1)];
91       CeedScalar *interp1d;
92       ierr = CeedBasisGetInterp(basis, &interp1d); CeedChk(ierr);
93       // Interpolate to quadrature points (NoTranspose)
94       //  or Grad to quadrature points (Transpose)
95       for (CeedInt d=0; d<dim; d++) {
96         ierr = CeedTensorContractApply(contract, pre, P, post, Q,
97                                        (tmode == CEED_NOTRANSPOSE
98                                         ? interp1d
99                                         : impl->colograd1d),
100                                        tmode, add&&(d>0),
101                                        (tmode == CEED_NOTRANSPOSE
102                                         ? (d==0?u:tmp[d%2])
103                                         : u + d*nqpt*ncomp*nelem),
104                                        (tmode == CEED_NOTRANSPOSE
105                                         ? (d==dim-1?interp:tmp[(d+1)%2])
106                                         : interp));
107         CeedChk(ierr);
108         pre /= P;
109         post *= Q;
110       }
111       // Grad to quadrature points (NoTranspose)
112       //  or Interpolate to dofs (Transpose)
113       P = Q1d, Q = Q1d;
114       if (tmode == CEED_TRANSPOSE) {
115         P = Q1d, Q = P1d;
116       }
117       pre = ncomp*CeedIntPow(P, dim-1), post = nelem;
118       for (CeedInt d=0; d<dim; d++) {
119         ierr = CeedTensorContractApply(contract, pre, P, post, Q,
120                                        (tmode == CEED_NOTRANSPOSE
121                                         ? impl->colograd1d
122                                         : interp1d),
123                                        tmode, add&&(d==dim-1),
124                                        (tmode == CEED_NOTRANSPOSE
125                                         ? interp
126                                         : (d==0?interp:tmp[d%2])),
127                                        (tmode == CEED_NOTRANSPOSE
128                                         ? v + d*nqpt*ncomp*nelem
129                                         : (d==dim-1?v:tmp[(d+1)%2])));
130         CeedChk(ierr);
131         pre /= P;
132         post *= Q;
133       }
134     } break;
135     // Retrieve interpolation weights
136     case CEED_EVAL_WEIGHT: {
137       if (tmode == CEED_TRANSPOSE)
138         return CeedError(ceed, 1,
139                          "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
140       CeedInt Q = Q1d;
141       CeedScalar *qweight1d;
142       ierr = CeedBasisGetQWeights(basis, &qweight1d); CeedChk(ierr);
143       for (CeedInt d=0; d<dim; d++) {
144         CeedInt pre = CeedIntPow(Q, dim-d-1), post = CeedIntPow(Q, d);
145         for (CeedInt i=0; i<pre; i++)
146           for (CeedInt j=0; j<Q; j++)
147             for (CeedInt k=0; k<post; k++) {
148               CeedScalar w = qweight1d[j]
149                              * (d == 0 ? 1 : v[((i*Q + j)*post + k)*nelem]);
150               for (CeedInt e=0; e<nelem; e++)
151                 v[((i*Q + j)*post + k)*nelem + e] = w;
152             }
153       }
154     } break;
155     // Evaluate the divergence to/from the quadrature points
156     case CEED_EVAL_DIV:
157       return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
158     // Evaluate the curl to/from the quadrature points
159     case CEED_EVAL_CURL:
160       return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
161     // Take no action, BasisApply should not have been called
162     case CEED_EVAL_NONE:
163       return CeedError(ceed, 1,
164                        "CEED_EVAL_NONE does not make sense in this context");
165     }
166   } else {
167     // Non-tensor basis
168     switch (emode) {
169     // Interpolate to/from quadrature points
170     case CEED_EVAL_INTERP: {
171       CeedInt P = ndof, Q = nqpt;
172       CeedScalar *interp;
173       ierr = CeedBasisGetInterp(basis, &interp); CeedChk(ierr);
174       if (tmode == CEED_TRANSPOSE) {
175         P = nqpt; Q = ndof;
176       }
177       ierr = CeedTensorContractApply(contract, ncomp, P, nelem, Q,
178                                      interp, tmode, add, u, v);
179       CeedChk(ierr);
180     }
181     break;
182     // Evaluate the gradient to/from quadrature points
183     case CEED_EVAL_GRAD: {
184       CeedInt P = ndof, Q = dim*nqpt;
185       CeedScalar *grad;
186       ierr = CeedBasisGetGrad(basis, &grad); CeedChk(ierr);
187       if (tmode == CEED_TRANSPOSE) {
188         P = dim*nqpt; Q = ndof;
189       }
190       ierr = CeedTensorContractApply(contract, ncomp, P, nelem, Q,
191                                      grad, tmode, add, u, v);
192       CeedChk(ierr);
193     }
194     break;
195     // Retrieve interpolation weights
196     case CEED_EVAL_WEIGHT: {
197       if (tmode == CEED_TRANSPOSE)
198         return CeedError(ceed, 1,
199                          "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
200       CeedScalar *qweight;
201       ierr = CeedBasisGetQWeights(basis, &qweight); CeedChk(ierr);
202       for (CeedInt i=0; i<nqpt; i++)
203         for (CeedInt e=0; e<nelem; e++)
204           v[i*nelem + e] = qweight[i];
205     } break;
206     // Evaluate the divergence to/from the quadrature points
207     case CEED_EVAL_DIV:
208       return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
209     // Evaluate the curl to/from the quadrature points
210     case CEED_EVAL_CURL:
211       return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
212     // Take no action, BasisApply should not have been called
213     case CEED_EVAL_NONE:
214       return CeedError(ceed, 1,
215                        "CEED_EVAL_NONE does not make sense in this context");
216     }
217   }
218   if (U) {
219     ierr = CeedVectorRestoreArrayRead(U, &u); CeedChk(ierr);
220   }
221   ierr = CeedVectorRestoreArray(V, &v); CeedChk(ierr);
222   return 0;
223 }
224 
225 static int CeedBasisDestroyNonTensor_Ref(CeedBasis basis) {
226   int ierr;
227   CeedTensorContract contract;
228   ierr = CeedBasisGetTensorContract(basis, &contract); CeedChk(ierr);
229   ierr = CeedTensorContractDestroy(&contract); CeedChk(ierr);
230   return 0;
231 }
232 
233 static int CeedBasisDestroyTensor_Ref(CeedBasis basis) {
234   int ierr;
235   CeedTensorContract contract;
236   ierr = CeedBasisGetTensorContract(basis, &contract); CeedChk(ierr);
237   ierr = CeedTensorContractDestroy(&contract); CeedChk(ierr);
238 
239   CeedBasis_Ref *impl;
240   ierr = CeedBasisGetData(basis, (void *)&impl); CeedChk(ierr);
241   ierr = CeedFree(&impl->colograd1d); CeedChk(ierr);
242   ierr = CeedFree(&impl); CeedChk(ierr);
243 
244   return 0;
245 }
246 
247 int CeedBasisCreateTensorH1_Ref(CeedInt dim, CeedInt P1d,
248                                 CeedInt Q1d, const CeedScalar *interp1d,
249                                 const CeedScalar *grad1d,
250                                 const CeedScalar *qref1d,
251                                 const CeedScalar *qweight1d,
252                                 CeedBasis basis) {
253   int ierr;
254   Ceed ceed;
255   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
256   CeedBasis_Ref *impl;
257   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
258   ierr = CeedMalloc(Q1d*Q1d, &impl->colograd1d); CeedChk(ierr);
259   ierr = CeedBasisGetCollocatedGrad(basis, impl->colograd1d); CeedChk(ierr);
260   ierr = CeedBasisSetData(basis, (void *)&impl); CeedChk(ierr);
261 
262   Ceed parent;
263   ierr = CeedGetParent(ceed, &parent); CeedChk(ierr);
264   CeedTensorContract contract;
265   ierr = CeedTensorContractCreate(parent, basis, &contract); CeedChk(ierr);
266   ierr = CeedBasisSetTensorContract(basis, &contract); CeedChk(ierr);
267 
268   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
269                                 CeedBasisApply_Ref); CeedChk(ierr);
270   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
271                                 CeedBasisDestroyTensor_Ref); CeedChk(ierr);
272   return 0;
273 }
274 
275 
276 
277 int CeedBasisCreateH1_Ref(CeedElemTopology topo, CeedInt dim,
278                           CeedInt ndof, CeedInt nqpts,
279                           const CeedScalar *interp,
280                           const CeedScalar *grad,
281                           const CeedScalar *qref,
282                           const CeedScalar *qweight,
283                           CeedBasis basis) {
284   int ierr;
285   Ceed ceed;
286   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
287 
288   Ceed parent;
289   ierr = CeedGetParent(ceed, &parent); CeedChk(ierr);
290   CeedTensorContract contract;
291   ierr = CeedTensorContractCreate(parent, basis, &contract); CeedChk(ierr);
292   ierr = CeedBasisSetTensorContract(basis, &contract); CeedChk(ierr);
293 
294   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
295                                 CeedBasisApply_Ref); CeedChk(ierr);
296   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
297                                 CeedBasisDestroyNonTensor_Ref); CeedChk(ierr);
298 
299   return 0;
300 }
301