xref: /libCEED/rust/libceed-sys/c-src/backends/ref/ceed-ref.c (revision 5e0799b17881e9614ad208aeee2ae82ae7f08ba3)
1ae3cba82Scamierjs // Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
2ae3cba82Scamierjs // the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
3ae3cba82Scamierjs // reserved. See files LICENSE and NOTICE for details.
4ae3cba82Scamierjs //
5ae3cba82Scamierjs // This file is part of CEED, a collection of benchmarks, miniapps, software
6ae3cba82Scamierjs // libraries and APIs for efficient high-order finite element and spectral
7ae3cba82Scamierjs // element discretizations for exascale applications. For more information and
8ae3cba82Scamierjs // source code availability see http://github.com/ceed.
9ae3cba82Scamierjs //
10ae3cba82Scamierjs // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11ae3cba82Scamierjs // a collaborative effort of two U.S. Department of Energy organizations (Office
12ae3cba82Scamierjs // of Science and the National Nuclear Security Administration) responsible for
13ae3cba82Scamierjs // the planning and preparation of a capable exascale ecosystem, including
14ae3cba82Scamierjs // software, applications, hardware, advanced system engineering and early
15ae3cba82Scamierjs // testbed platforms, in support of the nation's exascale computing imperative.
16ae3cba82Scamierjs 
17ae3cba82Scamierjs #include <ceed-impl.h>
18ae3cba82Scamierjs #include <string.h>
19ae3cba82Scamierjs 
20ae3cba82Scamierjs typedef struct {
21ae3cba82Scamierjs   CeedScalar *array;
22ae3cba82Scamierjs   CeedScalar *array_allocated;
23ae3cba82Scamierjs } CeedVector_Ref;
24ae3cba82Scamierjs 
25ae3cba82Scamierjs typedef struct {
26ae3cba82Scamierjs   const CeedInt *indices;
27ae3cba82Scamierjs   CeedInt *indices_allocated;
28ae3cba82Scamierjs } CeedElemRestriction_Ref;
29ae3cba82Scamierjs 
30ae3cba82Scamierjs typedef struct {
31ae3cba82Scamierjs   CeedVector etmp;
32ae3cba82Scamierjs   CeedVector qdata;
33ae3cba82Scamierjs } CeedOperator_Ref;
34ae3cba82Scamierjs 
35ae3cba82Scamierjs static int CeedVectorSetArray_Ref(CeedVector vec, CeedMemType mtype,
36ae3cba82Scamierjs                                   CeedCopyMode cmode, CeedScalar *array) {
37ae3cba82Scamierjs   CeedVector_Ref *impl = vec->data;
38ae3cba82Scamierjs   int ierr;
39ae3cba82Scamierjs 
40ae3cba82Scamierjs   if (mtype != CEED_MEM_HOST)
41ae3cba82Scamierjs     return CeedError(vec->ceed, 1, "Only MemType = HOST supported");
42ae3cba82Scamierjs   ierr = CeedFree(&impl->array_allocated); CeedChk(ierr);
43ae3cba82Scamierjs   switch (cmode) {
44ae3cba82Scamierjs   case CEED_COPY_VALUES:
45ae3cba82Scamierjs     ierr = CeedMalloc(vec->length, &impl->array_allocated); CeedChk(ierr);
46ae3cba82Scamierjs     impl->array = impl->array_allocated;
47ae3cba82Scamierjs     if (array) memcpy(impl->array, array, vec->length * sizeof(array[0]));
48ae3cba82Scamierjs     break;
49ae3cba82Scamierjs   case CEED_OWN_POINTER:
50ae3cba82Scamierjs     impl->array_allocated = array;
51ae3cba82Scamierjs     impl->array = array;
52ae3cba82Scamierjs     break;
53ae3cba82Scamierjs   case CEED_USE_POINTER:
54ae3cba82Scamierjs     impl->array = array;
55ae3cba82Scamierjs   }
56ae3cba82Scamierjs   return 0;
57ae3cba82Scamierjs }
58ae3cba82Scamierjs 
59ae3cba82Scamierjs static int CeedVectorGetArray_Ref(CeedVector vec, CeedMemType mtype,
60ae3cba82Scamierjs                                   CeedScalar **array) {
61ae3cba82Scamierjs   CeedVector_Ref *impl = vec->data;
62ae3cba82Scamierjs   int ierr;
63ae3cba82Scamierjs 
64ae3cba82Scamierjs   if (mtype != CEED_MEM_HOST)
65ae3cba82Scamierjs     return CeedError(vec->ceed, 1, "Can only provide to HOST memory");
66ae3cba82Scamierjs   if (!impl->array) { // Allocate if array is not yet allocated
67ae3cba82Scamierjs     ierr = CeedVectorSetArray(vec, CEED_MEM_HOST, CEED_COPY_VALUES, NULL);
68ae3cba82Scamierjs     CeedChk(ierr);
69ae3cba82Scamierjs   }
70ae3cba82Scamierjs   *array = impl->array;
71ae3cba82Scamierjs   return 0;
72ae3cba82Scamierjs }
73ae3cba82Scamierjs 
74ae3cba82Scamierjs static int CeedVectorGetArrayRead_Ref(CeedVector vec, CeedMemType mtype,
75ae3cba82Scamierjs                                       const CeedScalar **array) {
76ae3cba82Scamierjs   CeedVector_Ref *impl = vec->data;
77ae3cba82Scamierjs   int ierr;
78ae3cba82Scamierjs 
79ae3cba82Scamierjs   if (mtype != CEED_MEM_HOST)
80ae3cba82Scamierjs     return CeedError(vec->ceed, 1, "Can only provide to HOST memory");
81ae3cba82Scamierjs   if (!impl->array) { // Allocate if array is not yet allocated
82ae3cba82Scamierjs     ierr = CeedVectorSetArray(vec, CEED_MEM_HOST, CEED_COPY_VALUES, NULL);
83ae3cba82Scamierjs     CeedChk(ierr);
84ae3cba82Scamierjs   }
85ae3cba82Scamierjs   *array = impl->array;
86ae3cba82Scamierjs   return 0;
87ae3cba82Scamierjs }
88ae3cba82Scamierjs 
89ae3cba82Scamierjs static int CeedVectorRestoreArray_Ref(CeedVector vec, CeedScalar **array) {
90ae3cba82Scamierjs   *array = NULL;
91ae3cba82Scamierjs   return 0;
92ae3cba82Scamierjs }
93ae3cba82Scamierjs 
94ae3cba82Scamierjs static int CeedVectorRestoreArrayRead_Ref(CeedVector vec,
95ae3cba82Scamierjs     const CeedScalar **array) {
96ae3cba82Scamierjs   *array = NULL;
97ae3cba82Scamierjs   return 0;
98ae3cba82Scamierjs }
99ae3cba82Scamierjs 
100ae3cba82Scamierjs static int CeedVectorDestroy_Ref(CeedVector vec) {
101ae3cba82Scamierjs   CeedVector_Ref *impl = vec->data;
102ae3cba82Scamierjs   int ierr;
103ae3cba82Scamierjs 
104ae3cba82Scamierjs   ierr = CeedFree(&impl->array_allocated); CeedChk(ierr);
105ae3cba82Scamierjs   ierr = CeedFree(&vec->data); CeedChk(ierr);
106ae3cba82Scamierjs   return 0;
107ae3cba82Scamierjs }
108ae3cba82Scamierjs 
109ae3cba82Scamierjs static int CeedVectorCreate_Ref(Ceed ceed, CeedInt n, CeedVector vec) {
110ae3cba82Scamierjs   CeedVector_Ref *impl;
111ae3cba82Scamierjs   int ierr;
112ae3cba82Scamierjs 
113ae3cba82Scamierjs   vec->SetArray = CeedVectorSetArray_Ref;
114ae3cba82Scamierjs   vec->GetArray = CeedVectorGetArray_Ref;
115ae3cba82Scamierjs   vec->GetArrayRead = CeedVectorGetArrayRead_Ref;
116ae3cba82Scamierjs   vec->RestoreArray = CeedVectorRestoreArray_Ref;
117ae3cba82Scamierjs   vec->RestoreArrayRead = CeedVectorRestoreArrayRead_Ref;
118ae3cba82Scamierjs   vec->Destroy = CeedVectorDestroy_Ref;
119ae3cba82Scamierjs   ierr = CeedCalloc(1,&impl); CeedChk(ierr);
120ae3cba82Scamierjs   vec->data = impl;
121ae3cba82Scamierjs   return 0;
122ae3cba82Scamierjs }
123ae3cba82Scamierjs 
124ae3cba82Scamierjs static int CeedElemRestrictionApply_Ref(CeedElemRestriction r,
12585558d50Scamierjs                                         CeedTransposeMode tmode, CeedInt ncomp,
12685558d50Scamierjs                                         CeedTransposeMode lmode, CeedVector u,
127ae3cba82Scamierjs                                         CeedVector v, CeedRequest *request) {
128ae3cba82Scamierjs   CeedElemRestriction_Ref *impl = r->data;
129ae3cba82Scamierjs   int ierr;
130ae3cba82Scamierjs   const CeedScalar *uu;
131ae3cba82Scamierjs   CeedScalar *vv;
13285558d50Scamierjs   CeedInt esize = r->nelem*r->elemsize;
133ae3cba82Scamierjs 
134ae3cba82Scamierjs   ierr = CeedVectorGetArrayRead(u, CEED_MEM_HOST, &uu); CeedChk(ierr);
135ae3cba82Scamierjs   ierr = CeedVectorGetArray(v, CEED_MEM_HOST, &vv); CeedChk(ierr);
136ae3cba82Scamierjs   if (tmode == CEED_NOTRANSPOSE) {
13785558d50Scamierjs     // Perform: v = r * u
13885558d50Scamierjs     if (ncomp == 1) {
13985558d50Scamierjs       for (CeedInt i=0; i<esize; i++) vv[i] = uu[impl->indices[i]];
140ae3cba82Scamierjs     } else {
14185558d50Scamierjs       // vv is (elemsize x ncomp x nelem), column-major
14285558d50Scamierjs       if (lmode == CEED_NOTRANSPOSE) { // u is (ndof x ncomp), column-major
14385558d50Scamierjs         for (CeedInt e = 0; e < r->nelem; e++)
14485558d50Scamierjs           for (CeedInt d = 0; d < ncomp; d++)
14585558d50Scamierjs             for (CeedInt i=0; i<r->elemsize; i++) {
14685558d50Scamierjs               vv[i+r->elemsize*(d+ncomp*e)] =
14785558d50Scamierjs                 uu[impl->indices[i+r->elemsize*e]+r->ndof*d];
14885558d50Scamierjs             }
14985558d50Scamierjs       } else { // u is (ncomp x ndof), column-major
15085558d50Scamierjs         for (CeedInt e = 0; e < r->nelem; e++)
15185558d50Scamierjs           for (CeedInt d = 0; d < ncomp; d++)
15285558d50Scamierjs             for (CeedInt i=0; i<r->elemsize; i++) {
15385558d50Scamierjs               vv[i+r->elemsize*(d+ncomp*e)] =
15485558d50Scamierjs                 uu[d+ncomp*impl->indices[i+r->elemsize*e]];
15585558d50Scamierjs             }
15685558d50Scamierjs       }
15785558d50Scamierjs     }
15885558d50Scamierjs   } else {
15985558d50Scamierjs     // Note: in transpose mode, we perform: v += r^t * u
16085558d50Scamierjs     if (ncomp == 1) {
16185558d50Scamierjs       for (CeedInt i=0; i<esize; i++) vv[impl->indices[i]] += uu[i];
16285558d50Scamierjs     } else {
16385558d50Scamierjs       // u is (elemsize x ncomp x nelem)
16485558d50Scamierjs       if (lmode == CEED_NOTRANSPOSE) { // vv is (ndof x ncomp), column-major
16585558d50Scamierjs         for (CeedInt e = 0; e < r->nelem; e++)
16685558d50Scamierjs           for (CeedInt d = 0; d < ncomp; d++)
16785558d50Scamierjs             for (CeedInt i=0; i<r->elemsize; i++) {
16885558d50Scamierjs               vv[impl->indices[i+r->elemsize*e]+r->ndof*d] +=
16985558d50Scamierjs                 uu[i+r->elemsize*(d+e*ncomp)];
17085558d50Scamierjs             }
17185558d50Scamierjs       } else { // vv is (ncomp x ndof), column-major
17285558d50Scamierjs         for (CeedInt e = 0; e < r->nelem; e++)
17385558d50Scamierjs           for (CeedInt d = 0; d < ncomp; d++)
17485558d50Scamierjs             for (CeedInt i=0; i<r->elemsize; i++) {
17585558d50Scamierjs               vv[d+ncomp*impl->indices[i+r->elemsize*e]] +=
17685558d50Scamierjs                 uu[i+r->elemsize*(d+e*ncomp)];
17785558d50Scamierjs             }
17885558d50Scamierjs       }
17985558d50Scamierjs     }
180ae3cba82Scamierjs   }
181ae3cba82Scamierjs   ierr = CeedVectorRestoreArrayRead(u, &uu); CeedChk(ierr);
182ae3cba82Scamierjs   ierr = CeedVectorRestoreArray(v, &vv); CeedChk(ierr);
183*5e0799b1SVeselin Dobrev   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_NULL)
184*5e0799b1SVeselin Dobrev     *request = NULL;
185ae3cba82Scamierjs   return 0;
186ae3cba82Scamierjs }
187ae3cba82Scamierjs 
188ae3cba82Scamierjs static int CeedElemRestrictionDestroy_Ref(CeedElemRestriction r) {
189ae3cba82Scamierjs   CeedElemRestriction_Ref *impl = r->data;
190ae3cba82Scamierjs   int ierr;
191ae3cba82Scamierjs 
192ae3cba82Scamierjs   ierr = CeedFree(&impl->indices_allocated); CeedChk(ierr);
193ae3cba82Scamierjs   ierr = CeedFree(&r->data); CeedChk(ierr);
194ae3cba82Scamierjs   return 0;
195ae3cba82Scamierjs }
196ae3cba82Scamierjs 
197ae3cba82Scamierjs static int CeedElemRestrictionCreate_Ref(CeedElemRestriction r,
198ae3cba82Scamierjs     CeedMemType mtype,
199ae3cba82Scamierjs     CeedCopyMode cmode, const CeedInt *indices) {
200ae3cba82Scamierjs   int ierr;
201ae3cba82Scamierjs   CeedElemRestriction_Ref *impl;
202ae3cba82Scamierjs 
203ae3cba82Scamierjs   if (mtype != CEED_MEM_HOST)
204ae3cba82Scamierjs     return CeedError(r->ceed, 1, "Only MemType = HOST supported");
205ae3cba82Scamierjs   ierr = CeedCalloc(1,&impl); CeedChk(ierr);
206ae3cba82Scamierjs   switch (cmode) {
207ae3cba82Scamierjs   case CEED_COPY_VALUES:
208ae3cba82Scamierjs     ierr = CeedMalloc(r->nelem*r->elemsize, &impl->indices_allocated);
209ae3cba82Scamierjs     CeedChk(ierr);
210ae3cba82Scamierjs     memcpy(impl->indices_allocated, indices,
211ae3cba82Scamierjs            r->nelem * r->elemsize * sizeof(indices[0]));
212ae3cba82Scamierjs     impl->indices = impl->indices_allocated;
213ae3cba82Scamierjs     break;
214ae3cba82Scamierjs   case CEED_OWN_POINTER:
215ae3cba82Scamierjs     impl->indices_allocated = (CeedInt *)indices;
216ae3cba82Scamierjs     impl->indices = impl->indices_allocated;
217ae3cba82Scamierjs     break;
218ae3cba82Scamierjs   case CEED_USE_POINTER:
219ae3cba82Scamierjs     impl->indices = indices;
220ae3cba82Scamierjs   }
221ae3cba82Scamierjs   r->data = impl;
222ae3cba82Scamierjs   r->Apply = CeedElemRestrictionApply_Ref;
223ae3cba82Scamierjs   r->Destroy = CeedElemRestrictionDestroy_Ref;
224ae3cba82Scamierjs   return 0;
225ae3cba82Scamierjs }
226ae3cba82Scamierjs 
227ae3cba82Scamierjs // Contracts on the middle index
228ae3cba82Scamierjs // NOTRANSPOSE: V_ajc = T_jb U_abc
229ae3cba82Scamierjs // TRANSPOSE:   V_ajc = T_bj U_abc
230*5e0799b1SVeselin Dobrev // If Add != 0, "=" is replaced by "+="
231ae3cba82Scamierjs static int CeedTensorContract_Ref(Ceed ceed,
232ae3cba82Scamierjs                                   CeedInt A, CeedInt B, CeedInt C, CeedInt J,
233ae3cba82Scamierjs                                   const CeedScalar *t, CeedTransposeMode tmode,
234*5e0799b1SVeselin Dobrev                                   const CeedInt Add,
235ae3cba82Scamierjs                                   const CeedScalar *u, CeedScalar *v) {
236ae3cba82Scamierjs   CeedInt tstride0 = B, tstride1 = 1;
237ae3cba82Scamierjs   if (tmode == CEED_TRANSPOSE) {
238ae3cba82Scamierjs     tstride0 = 1; tstride1 = J;
239ae3cba82Scamierjs   }
240ae3cba82Scamierjs 
241ae3cba82Scamierjs   for (CeedInt a=0; a<A; a++) {
242ae3cba82Scamierjs     for (CeedInt j=0; j<J; j++) {
243*5e0799b1SVeselin Dobrev       if (!Add) {
244ae3cba82Scamierjs         for (CeedInt c=0; c<C; c++)
245ae3cba82Scamierjs           v[(a*J+j)*C+c] = 0;
246*5e0799b1SVeselin Dobrev       }
247ae3cba82Scamierjs       for (CeedInt b=0; b<B; b++) {
248ae3cba82Scamierjs         for (CeedInt c=0; c<C; c++) {
249ae3cba82Scamierjs           v[(a*J+j)*C+c] += t[j*tstride0 + b*tstride1] * u[(a*B+b)*C+c];
250ae3cba82Scamierjs         }
251ae3cba82Scamierjs       }
252ae3cba82Scamierjs     }
253ae3cba82Scamierjs   }
254ae3cba82Scamierjs   return 0;
255ae3cba82Scamierjs }
256ae3cba82Scamierjs 
257ae3cba82Scamierjs static int CeedBasisApply_Ref(CeedBasis basis, CeedTransposeMode tmode,
258ae3cba82Scamierjs                               CeedEvalMode emode,
259ae3cba82Scamierjs                               const CeedScalar *u, CeedScalar *v) {
260ae3cba82Scamierjs   int ierr;
261ae3cba82Scamierjs   const CeedInt dim = basis->dim;
262ae3cba82Scamierjs   const CeedInt ndof = basis->ndof;
263*5e0799b1SVeselin Dobrev   const CeedInt nqpt = ndof*CeedPowInt(basis->Q1d, dim);
264*5e0799b1SVeselin Dobrev   const CeedInt add = (tmode == CEED_TRANSPOSE);
265ae3cba82Scamierjs 
266*5e0799b1SVeselin Dobrev   if (tmode == CEED_TRANSPOSE) {
267*5e0799b1SVeselin Dobrev     const CeedInt vsize = ndof*CeedPowInt(basis->P1d, dim);
268*5e0799b1SVeselin Dobrev     for (CeedInt i = 0; i < vsize; i++)
269*5e0799b1SVeselin Dobrev       v[i] = (CeedScalar) 0;
270*5e0799b1SVeselin Dobrev   }
27185558d50Scamierjs   if (emode & CEED_EVAL_INTERP) {
272ae3cba82Scamierjs     CeedInt P = basis->P1d, Q = basis->Q1d;
273ae3cba82Scamierjs     if (tmode == CEED_TRANSPOSE) {
274ae3cba82Scamierjs       P = basis->Q1d; Q = basis->P1d;
275ae3cba82Scamierjs     }
276ae3cba82Scamierjs     CeedInt pre = ndof*CeedPowInt(P, dim-1), post = 1;
27706320e62SVeselin Dobrev     CeedScalar tmp[2][ndof*Q*CeedPowInt(P>Q?P:Q, dim-1)];
278ae3cba82Scamierjs     for (CeedInt d=0; d<dim; d++) {
279ae3cba82Scamierjs       ierr = CeedTensorContract_Ref(basis->ceed, pre, P, post, Q, basis->interp1d,
280*5e0799b1SVeselin Dobrev                                     tmode, add&&(d==dim-1),
281*5e0799b1SVeselin Dobrev                                     d==0?u:tmp[d%2], d==dim-1?v:tmp[(d+1)%2]);
28285558d50Scamierjs       CeedChk(ierr);
283ae3cba82Scamierjs       pre /= P;
284ae3cba82Scamierjs       post *= Q;
285ae3cba82Scamierjs     }
28685558d50Scamierjs     if (tmode == CEED_NOTRANSPOSE) {
287*5e0799b1SVeselin Dobrev       v += nqpt;
28885558d50Scamierjs     } else {
289*5e0799b1SVeselin Dobrev       u += nqpt;
29085558d50Scamierjs     }
29185558d50Scamierjs   }
29285558d50Scamierjs   if (emode & CEED_EVAL_GRAD) {
29385558d50Scamierjs     CeedInt P = basis->P1d, Q = basis->Q1d;
294*5e0799b1SVeselin Dobrev     // In CEED_NOTRANSPOSE mode:
29585558d50Scamierjs     // u is (P^dim x nc), column-major layout (nc = ndof)
29685558d50Scamierjs     // v is (Q^dim x nc x dim), column-major layout (nc = ndof)
297*5e0799b1SVeselin Dobrev     // In CEED_TRANSPOSE mode, the sizes of u and v are switched.
298*5e0799b1SVeselin Dobrev     if (tmode == CEED_TRANSPOSE) {
299*5e0799b1SVeselin Dobrev       P = basis->Q1d, Q = basis->P1d;
300*5e0799b1SVeselin Dobrev     }
30106320e62SVeselin Dobrev     CeedScalar tmp[2][ndof*Q*CeedPowInt(P>Q?P:Q, dim-1)];
30285558d50Scamierjs     for (CeedInt p = 0; p < dim; p++) {
30385558d50Scamierjs       CeedInt pre = ndof*CeedPowInt(P, dim-1), post = 1;
30485558d50Scamierjs       for (CeedInt d=0; d<dim; d++) {
30585558d50Scamierjs         ierr = CeedTensorContract_Ref(basis->ceed, pre, P, post, Q,
30685558d50Scamierjs                                       (p==d)?basis->grad1d:basis->interp1d,
307*5e0799b1SVeselin Dobrev                                       tmode, add&&(d==dim-1),
308*5e0799b1SVeselin Dobrev                                       d==0?u:tmp[d%2], d==dim-1?v:tmp[(d+1)%2]);
309*5e0799b1SVeselin Dobrev         CeedChk(ierr);
31085558d50Scamierjs         pre /= P;
31185558d50Scamierjs         post *= Q;
31285558d50Scamierjs       }
313*5e0799b1SVeselin Dobrev       if (tmode == CEED_NOTRANSPOSE) {
314*5e0799b1SVeselin Dobrev         v += nqpt;
31585558d50Scamierjs       } else {
316*5e0799b1SVeselin Dobrev         u += nqpt;
317*5e0799b1SVeselin Dobrev       }
31885558d50Scamierjs     }
31985558d50Scamierjs   }
32085558d50Scamierjs   if (emode & CEED_EVAL_WEIGHT) {
321ae3cba82Scamierjs     if (tmode == CEED_TRANSPOSE)
322e3df9412Scamierjs       return CeedError(basis->ceed, 1,
323e3df9412Scamierjs                        "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
324ae3cba82Scamierjs     CeedInt Q = basis->Q1d;
325ae3cba82Scamierjs     for (CeedInt d=0; d<dim; d++) {
326ae3cba82Scamierjs       CeedInt pre = CeedPowInt(Q, dim-d-1), post = CeedPowInt(Q, d);
327ae3cba82Scamierjs       for (CeedInt i=0; i<pre; i++) {
328ae3cba82Scamierjs         for (CeedInt j=0; j<Q; j++) {
329ae3cba82Scamierjs           for (CeedInt k=0; k<post; k++) {
330ae3cba82Scamierjs             v[(i*Q + j)*post + k] = basis->qweight1d[j]
331ae3cba82Scamierjs                                     * (d == 0 ? 1 : v[(i*Q + j)*post + k]);
332ae3cba82Scamierjs           }
333ae3cba82Scamierjs         }
334ae3cba82Scamierjs       }
335ae3cba82Scamierjs     }
336ae3cba82Scamierjs   }
337ae3cba82Scamierjs   return 0;
338ae3cba82Scamierjs }
339ae3cba82Scamierjs 
340ae3cba82Scamierjs static int CeedBasisDestroy_Ref(CeedBasis basis) {
341ae3cba82Scamierjs   return 0;
342ae3cba82Scamierjs }
343ae3cba82Scamierjs 
344ae3cba82Scamierjs static int CeedBasisCreateTensorH1_Ref(Ceed ceed, CeedInt dim, CeedInt P1d,
345ae3cba82Scamierjs                                        CeedInt Q1d, const CeedScalar *interp1d,
346ae3cba82Scamierjs                                        const CeedScalar *grad1d,
347ae3cba82Scamierjs                                        const CeedScalar *qref1d,
348ae3cba82Scamierjs                                        const CeedScalar *qweight1d,
349ae3cba82Scamierjs                                        CeedBasis basis) {
350ae3cba82Scamierjs   basis->Apply = CeedBasisApply_Ref;
351ae3cba82Scamierjs   basis->Destroy = CeedBasisDestroy_Ref;
352ae3cba82Scamierjs   return 0;
353ae3cba82Scamierjs }
354ae3cba82Scamierjs 
355ae3cba82Scamierjs static int CeedQFunctionApply_Ref(CeedQFunction qf, void *qdata, CeedInt Q,
356ae3cba82Scamierjs                                   const CeedScalar *const *u,
357ae3cba82Scamierjs                                   CeedScalar *const *v) {
358ae3cba82Scamierjs   int ierr;
359ae3cba82Scamierjs   ierr = qf->function(qf->ctx, qdata, Q, u, v); CeedChk(ierr);
360ae3cba82Scamierjs   return 0;
361ae3cba82Scamierjs }
362ae3cba82Scamierjs 
363ae3cba82Scamierjs static int CeedQFunctionDestroy_Ref(CeedQFunction qf) {
364ae3cba82Scamierjs   return 0;
365ae3cba82Scamierjs }
366ae3cba82Scamierjs 
367ae3cba82Scamierjs static int CeedQFunctionCreate_Ref(CeedQFunction qf) {
368ae3cba82Scamierjs   qf->Apply = CeedQFunctionApply_Ref;
369ae3cba82Scamierjs   qf->Destroy = CeedQFunctionDestroy_Ref;
370ae3cba82Scamierjs   return 0;
371ae3cba82Scamierjs }
372ae3cba82Scamierjs 
373ae3cba82Scamierjs static int CeedOperatorDestroy_Ref(CeedOperator op) {
374ae3cba82Scamierjs   CeedOperator_Ref *impl = op->data;
375ae3cba82Scamierjs   int ierr;
376ae3cba82Scamierjs 
377ae3cba82Scamierjs   ierr = CeedVectorDestroy(&impl->etmp); CeedChk(ierr);
378ae3cba82Scamierjs   ierr = CeedVectorDestroy(&impl->qdata); CeedChk(ierr);
379ae3cba82Scamierjs   ierr = CeedFree(&op->data); CeedChk(ierr);
380ae3cba82Scamierjs   return 0;
381ae3cba82Scamierjs }
382ae3cba82Scamierjs 
383ae3cba82Scamierjs static int CeedOperatorApply_Ref(CeedOperator op, CeedVector qdata,
384ae3cba82Scamierjs                                  CeedVector ustate,
385ae3cba82Scamierjs                                  CeedVector residual, CeedRequest *request) {
386ae3cba82Scamierjs   CeedOperator_Ref *impl = op->data;
387ae3cba82Scamierjs   CeedVector etmp;
388ae3cba82Scamierjs   CeedInt Q;
38985558d50Scamierjs   const CeedInt nc = op->basis->ndof, dim = op->basis->dim;
390ae3cba82Scamierjs   CeedScalar *Eu;
391ae3cba82Scamierjs   char *qd;
392ae3cba82Scamierjs   int ierr;
39385558d50Scamierjs   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
394ae3cba82Scamierjs 
395ae3cba82Scamierjs   if (!impl->etmp) {
396ae3cba82Scamierjs     ierr = CeedVectorCreate(op->ceed,
39785558d50Scamierjs                             nc * op->Erestrict->nelem * op->Erestrict->elemsize,
398ae3cba82Scamierjs                             &impl->etmp); CeedChk(ierr);
39985558d50Scamierjs     // etmp is allocated when CeedVectorGetArray is called below
400ae3cba82Scamierjs   }
401ae3cba82Scamierjs   etmp = impl->etmp;
40285558d50Scamierjs   if (op->qf->inmode & ~CEED_EVAL_WEIGHT) {
40385558d50Scamierjs     ierr = CeedElemRestrictionApply(op->Erestrict, CEED_NOTRANSPOSE,
40485558d50Scamierjs                                     nc, lmode, ustate, etmp,
405ae3cba82Scamierjs                                     CEED_REQUEST_IMMEDIATE); CeedChk(ierr);
406ae3cba82Scamierjs   }
407ae3cba82Scamierjs   ierr = CeedBasisGetNumQuadraturePoints(op->basis, &Q); CeedChk(ierr);
408ae3cba82Scamierjs   ierr = CeedVectorGetArray(etmp, CEED_MEM_HOST, &Eu); CeedChk(ierr);
409e3df9412Scamierjs   ierr = CeedVectorGetArray(qdata, CEED_MEM_HOST, (CeedScalar**)&qd);
410e3df9412Scamierjs   CeedChk(ierr);
411ae3cba82Scamierjs   for (CeedInt e=0; e<op->Erestrict->nelem; e++) {
41285558d50Scamierjs     CeedScalar BEu[Q*nc*(dim+2)], BEv[Q*nc*(dim+2)], *out[5] = {0,0,0,0,0};
41385558d50Scamierjs     const CeedScalar *in[5] = {0,0,0,0,0};
41485558d50Scamierjs     // TODO: quadrature weights can be computed just once
415ae3cba82Scamierjs     ierr = CeedBasisApply(op->basis, CEED_NOTRANSPOSE, op->qf->inmode,
41685558d50Scamierjs                           &Eu[e*op->Erestrict->elemsize*nc], BEu);
41785558d50Scamierjs     CeedChk(ierr);
41885558d50Scamierjs     CeedScalar *u_ptr = BEu, *v_ptr = BEv;
41985558d50Scamierjs     if (op->qf->inmode & CEED_EVAL_INTERP) { in[0] = u_ptr; u_ptr += Q*nc; }
42085558d50Scamierjs     if (op->qf->inmode & CEED_EVAL_GRAD) { in[1] = u_ptr; u_ptr += Q*nc*dim; }
42185558d50Scamierjs     if (op->qf->inmode & CEED_EVAL_WEIGHT) { in[4] = u_ptr; u_ptr += Q; }
42285558d50Scamierjs     if (op->qf->outmode & CEED_EVAL_INTERP) { out[0] = v_ptr; v_ptr += Q*nc; }
42385558d50Scamierjs     if (op->qf->outmode & CEED_EVAL_GRAD) { out[1] = v_ptr; v_ptr += Q*nc*dim; }
424ae3cba82Scamierjs     ierr = CeedQFunctionApply(op->qf, &qd[e*Q*op->qf->qdatasize], Q, in, out);
425ae3cba82Scamierjs     CeedChk(ierr);
426ae3cba82Scamierjs     ierr = CeedBasisApply(op->basis, CEED_TRANSPOSE, op->qf->outmode, BEv,
42785558d50Scamierjs                           &Eu[e*op->Erestrict->elemsize*nc]);
42885558d50Scamierjs     CeedChk(ierr);
429ae3cba82Scamierjs   }
430ae3cba82Scamierjs   ierr = CeedVectorRestoreArray(etmp, &Eu); CeedChk(ierr);
431ae3cba82Scamierjs   if (residual) {
43285558d50Scamierjs     CeedScalar *res;
43385558d50Scamierjs     CeedVectorGetArray(residual, CEED_MEM_HOST, &res);
43485558d50Scamierjs     for (int i = 0; i < residual->length; i++)
43585558d50Scamierjs       res[i] = (CeedScalar)0;
43685558d50Scamierjs     ierr = CeedElemRestrictionApply(op->Erestrict, CEED_TRANSPOSE,
43785558d50Scamierjs                                     nc, lmode, etmp, residual,
438ae3cba82Scamierjs                                     CEED_REQUEST_IMMEDIATE); CeedChk(ierr);
439ae3cba82Scamierjs   }
440*5e0799b1SVeselin Dobrev   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_NULL)
441*5e0799b1SVeselin Dobrev     *request = NULL;
442ae3cba82Scamierjs   return 0;
443ae3cba82Scamierjs }
444ae3cba82Scamierjs 
445ae3cba82Scamierjs static int CeedOperatorGetQData_Ref(CeedOperator op, CeedVector *qdata) {
446ae3cba82Scamierjs   CeedOperator_Ref *impl = op->data;
447ae3cba82Scamierjs   int ierr;
448ae3cba82Scamierjs 
449ae3cba82Scamierjs   if (!impl->qdata) {
450ae3cba82Scamierjs     CeedInt Q;
451ae3cba82Scamierjs     ierr = CeedBasisGetNumQuadraturePoints(op->basis, &Q); CeedChk(ierr);
452ae3cba82Scamierjs     ierr = CeedVectorCreate(op->ceed,
4530f253a1aSTzanio                             op->Erestrict->nelem * Q
4540f253a1aSTzanio                             * op->qf->qdatasize / sizeof(CeedScalar),
455ae3cba82Scamierjs                             &impl->qdata); CeedChk(ierr);
456ae3cba82Scamierjs   }
457ae3cba82Scamierjs   *qdata = impl->qdata;
458ae3cba82Scamierjs   return 0;
459ae3cba82Scamierjs }
460ae3cba82Scamierjs 
461ae3cba82Scamierjs static int CeedOperatorCreate_Ref(CeedOperator op) {
462ae3cba82Scamierjs   CeedOperator_Ref *impl;
463ae3cba82Scamierjs   int ierr;
464ae3cba82Scamierjs 
465ae3cba82Scamierjs   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
466ae3cba82Scamierjs   op->data = impl;
467ae3cba82Scamierjs   op->Destroy = CeedOperatorDestroy_Ref;
468ae3cba82Scamierjs   op->Apply = CeedOperatorApply_Ref;
469ae3cba82Scamierjs   op->GetQData = CeedOperatorGetQData_Ref;
470ae3cba82Scamierjs   return 0;
471ae3cba82Scamierjs }
472ae3cba82Scamierjs 
473ae3cba82Scamierjs static int CeedInit_Ref(const char *resource, Ceed ceed) {
474ae3cba82Scamierjs   if (strcmp(resource, "/cpu/self")
475ae3cba82Scamierjs       && strcmp(resource, "/cpu/self/ref"))
476ae3cba82Scamierjs     return CeedError(ceed, 1, "Ref backend cannot use resource: %s", resource);
477ae3cba82Scamierjs   ceed->VecCreate = CeedVectorCreate_Ref;
478ae3cba82Scamierjs   ceed->BasisCreateTensorH1 = CeedBasisCreateTensorH1_Ref;
479ae3cba82Scamierjs   ceed->ElemRestrictionCreate = CeedElemRestrictionCreate_Ref;
480ae3cba82Scamierjs   ceed->QFunctionCreate = CeedQFunctionCreate_Ref;
481ae3cba82Scamierjs   ceed->OperatorCreate = CeedOperatorCreate_Ref;
482ae3cba82Scamierjs   return 0;
483ae3cba82Scamierjs }
484ae3cba82Scamierjs 
485ae3cba82Scamierjs __attribute__((constructor))
486ae3cba82Scamierjs static void Register(void) {
487ae3cba82Scamierjs   CeedRegister("/cpu/self/ref", CeedInit_Ref);
488ae3cba82Scamierjs }
489