xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 2f4d9adb86426a8397f9d81f8b0a0e6244dbfdbd)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-ref.h"
19 
20 static int CeedOperatorDestroy_Ref(CeedOperator op) {
21   int ierr;
22   CeedOperator_Ref *impl;
23   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
27   }
28   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
29   ierr = CeedFree(&impl->edata); CeedChk(ierr);
30   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
31 
32   for (CeedInt i=0; i<impl->numein; i++) {
33     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
34     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
35   }
36   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
37   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
38 
39   for (CeedInt i=0; i<impl->numeout; i++) {
40     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
41     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
42   }
43   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
44   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
45 
46   ierr = CeedFree(&impl); CeedChk(ierr);
47   return 0;
48 }
49 
50 /*
51   Setup infields or outfields
52  */
53 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op,
54                                        bool inOrOut,
55                                        CeedVector *fullevecs, CeedVector *evecs,
56                                        CeedVector *qvecs, CeedInt starte,
57                                        CeedInt numfields, CeedInt Q) {
58   CeedInt dim, ierr, ncomp;
59   Ceed ceed;
60   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
61   CeedBasis basis;
62   CeedElemRestriction Erestrict;
63   CeedOperatorField *opfields;
64   CeedQFunctionField *qffields;
65   if (inOrOut) {
66     ierr = CeedOperatorGetFields(op, NULL, &opfields);
67     CeedChk(ierr);
68     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
69     CeedChk(ierr);
70   } else {
71     ierr = CeedOperatorGetFields(op, &opfields, NULL);
72     CeedChk(ierr);
73     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
74     CeedChk(ierr);
75   }
76 
77   // Loop over fields
78   for (CeedInt i=0; i<numfields; i++) {
79     CeedEvalMode emode;
80     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
81 
82     if (emode != CEED_EVAL_WEIGHT) {
83       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &Erestrict);
84       CeedChk(ierr);
85       ierr = CeedElemRestrictionCreateVector(Erestrict, NULL,
86                                              &fullevecs[i+starte]);
87       CeedChk(ierr);
88     }
89 
90     switch(emode) {
91     case CEED_EVAL_NONE:
92       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
93       CeedChk(ierr);
94       ierr = CeedVectorCreate(ceed, Q*ncomp, &qvecs[i]); CeedChk(ierr);
95       break;
96     case CEED_EVAL_INTERP:
97       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
98       CeedChk(ierr);
99       ierr = CeedVectorCreate(ceed, Q*ncomp, &evecs[i]); CeedChk(ierr);
100       ierr = CeedVectorCreate(ceed, Q*ncomp, &qvecs[i]); CeedChk(ierr);
101       break;
102     case CEED_EVAL_GRAD:
103       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
104       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
105       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
106       ierr = CeedVectorCreate(ceed, Q*ncomp, &evecs[i]); CeedChk(ierr);
107       ierr = CeedVectorCreate(ceed, Q*ncomp*dim, &qvecs[i]); CeedChk(ierr);
108       break;
109     case CEED_EVAL_WEIGHT: // Only on input fields
110       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
111       ierr = CeedVectorCreate(ceed, Q, &qvecs[i]); CeedChk(ierr);
112       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
113                             NULL, qvecs[i]); CeedChk(ierr);
114       break;
115     case CEED_EVAL_DIV:
116       break; // Not implimented
117     case CEED_EVAL_CURL:
118       break; // Not implimented
119     }
120   }
121   return 0;
122 }
123 
124 /*
125   CeedOperator needs to connect all the named fields (be they active or passive)
126   to the named inputs and outputs of its CeedQFunction.
127  */
128 static int CeedOperatorSetup_Ref(CeedOperator op) {
129   int ierr;
130   bool setupdone;
131   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
132   if (setupdone) return 0;
133   Ceed ceed;
134   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
135   CeedOperator_Ref *impl;
136   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
137   CeedQFunction qf;
138   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
139   CeedInt Q, numinputfields, numoutputfields;
140   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
141   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
142   CeedChk(ierr);
143   CeedOperatorField *opinputfields, *opoutputfields;
144   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
145   CeedChk(ierr);
146   CeedQFunctionField *qfinputfields, *qfoutputfields;
147   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
148   CeedChk(ierr);
149 
150   // Allocate
151   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
152   CeedChk(ierr);
153   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
154   CeedChk(ierr);
155 
156   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
157   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
158   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
159   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
160   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
161 
162   impl->numein = numinputfields; impl->numeout = numoutputfields;
163 
164   // Set up infield and outfield evecs and qvecs
165   // Infields
166   ierr = CeedOperatorSetupFields_Ref(qf, op, 0, impl->evecs,
167                                      impl->evecsin, impl->qvecsin, 0,
168                                      numinputfields, Q);
169   CeedChk(ierr);
170 
171   // Outfields
172   ierr = CeedOperatorSetupFields_Ref(qf, op, 1, impl->evecs,
173                                      impl->evecsout, impl->qvecsout,
174                                      numinputfields, numoutputfields, Q);
175   CeedChk(ierr);
176 
177   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
178 
179   return 0;
180 }
181 
182 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
183                                  CeedVector outvec, CeedRequest *request) {
184   int ierr;
185   CeedOperator_Ref *impl;
186   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
187   CeedQFunction qf;
188   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
189   CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, ncomp;
190   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
191   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
192   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
193   CeedChk(ierr);
194   CeedTransposeMode lmode;
195   CeedOperatorField *opinputfields, *opoutputfields;
196   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
197   CeedChk(ierr);
198   CeedQFunctionField *qfinputfields, *qfoutputfields;
199   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
200   CeedChk(ierr);
201   CeedEvalMode emode;
202   CeedVector vec;
203   CeedBasis basis;
204   CeedElemRestriction Erestrict;
205   uint64_t state;
206 
207   // Setup
208   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
209 
210   // Input Evecs and Restriction
211   for (CeedInt i=0; i<numinputfields; i++) {
212     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
213     CeedChk(ierr);
214     if (emode == CEED_EVAL_WEIGHT) { // Skip
215     } else {
216       // Get input vector
217       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
218       if (vec == CEED_VECTOR_ACTIVE)
219         vec = invec;
220       // Restrict
221       ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
222       // Skip restriction if input is unchanged
223       if (state != impl->inputstate[i] || vec == invec) {
224         ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
225         CeedChk(ierr);
226         ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
227         ierr = CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE,
228                                         lmode, vec, impl->evecs[i],
229                                         request); CeedChk(ierr);
230         impl->inputstate[i] = state;
231       }
232       // Get evec
233       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
234                                     (const CeedScalar **) &impl->edata[i]);
235       CeedChk(ierr);
236     }
237   }
238 
239   // Output Evecs
240   for (CeedInt i=0; i<numoutputfields; i++) {
241     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
242                               &impl->edata[i + numinputfields]); CeedChk(ierr);
243   }
244 
245   // Loop through elements
246   for (CeedInt e=0; e<numelements; e++) {
247     // Input basis apply if needed
248     for (CeedInt i=0; i<numinputfields; i++) {
249       // Get elemsize, emode, ncomp
250       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
251       CeedChk(ierr);
252       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
253       CeedChk(ierr);
254       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
255       CeedChk(ierr);
256       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
257       CeedChk(ierr);
258       // Basis action
259       switch(emode) {
260       case CEED_EVAL_NONE:
261         ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
262                                   CEED_USE_POINTER,
263                                   &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
264         break;
265       case CEED_EVAL_INTERP:
266         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
267         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
268                                   CEED_USE_POINTER,
269                                   &impl->edata[i][e*elemsize*ncomp]);
270         CeedChk(ierr);
271         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
272                               CEED_EVAL_INTERP, impl->evecsin[i],
273                               impl->qvecsin[i]); CeedChk(ierr);
274         break;
275       case CEED_EVAL_GRAD:
276         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
277         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
278                                   CEED_USE_POINTER,
279                                   &impl->edata[i][e*elemsize*ncomp]);
280         CeedChk(ierr);
281         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
282                               CEED_EVAL_GRAD, impl->evecsin[i],
283                               impl->qvecsin[i]); CeedChk(ierr);
284         break;
285       case CEED_EVAL_WEIGHT:
286         break;  // No action
287       case CEED_EVAL_DIV:
288         break; // Not implimented
289       case CEED_EVAL_CURL:
290         break; // Not implimented
291       }
292     }
293     // Output pointers
294     for (CeedInt i=0; i<numoutputfields; i++) {
295       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
296       CeedChk(ierr);
297       if (emode == CEED_EVAL_NONE) {
298         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
299         CeedChk(ierr);
300         ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
301                                   CEED_USE_POINTER,
302                                   &impl->edata[i + numinputfields][e*Q*ncomp]);
303         CeedChk(ierr);
304       }
305     }
306     // Q function
307     ierr = CeedQFunctionApply(qf, Q, impl->qvecsin, impl->qvecsout); CeedChk(ierr);
308 
309     // Output basis apply if needed
310     for (CeedInt i=0; i<numoutputfields; i++) {
311       // Get elemsize, emode, ncomp
312       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
313       CeedChk(ierr);
314       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
315       CeedChk(ierr);
316       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
317       CeedChk(ierr);
318       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
319       CeedChk(ierr);
320       // Basis action
321       switch(emode) {
322       case CEED_EVAL_NONE:
323         break; // No action
324       case CEED_EVAL_INTERP:
325         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
326         CeedChk(ierr);
327         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
328                                   CEED_USE_POINTER,
329                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
330         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
331                               CEED_EVAL_INTERP, impl->qvecsout[i],
332                               impl->evecsout[i]); CeedChk(ierr);
333         break;
334       case CEED_EVAL_GRAD:
335         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
336         CeedChk(ierr);
337         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
338                                   CEED_USE_POINTER,
339                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
340         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
341                               CEED_EVAL_GRAD, impl->qvecsout[i],
342                               impl->evecsout[i]); CeedChk(ierr);
343         break;
344       case CEED_EVAL_WEIGHT: {
345         Ceed ceed;
346         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
347         return CeedError(ceed, 1,
348                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
349         break; // Should not occur
350       }
351       case CEED_EVAL_DIV:
352         break; // Not implimented
353       case CEED_EVAL_CURL:
354         break; // Not implimented
355       }
356     }
357   }
358 
359   // Zero lvecs
360   for (CeedInt i=0; i<numoutputfields; i++) {
361     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
362     if (vec == CEED_VECTOR_ACTIVE)
363       vec = outvec;
364     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
365   }
366 
367   // Output restriction
368   for (CeedInt i=0; i<numoutputfields; i++) {
369     // Restore evec
370     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
371                                   &impl->edata[i + numinputfields]);
372     CeedChk(ierr);
373     // Get output vector
374     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
375     // Active
376     if (vec == CEED_VECTOR_ACTIVE)
377       vec = outvec;
378     // Restrict
379     ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
380     CeedChk(ierr);
381     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
382     ierr = CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE,
383                                     lmode, impl->evecs[i+impl->numein], vec,
384                                     request); CeedChk(ierr);
385   }
386 
387   // Restore input arrays
388   for (CeedInt i=0; i<numinputfields; i++) {
389     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
390     CeedChk(ierr);
391     if (emode == CEED_EVAL_WEIGHT) { // Skip
392     } else {
393       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
394                                         (const CeedScalar **) &impl->edata[i]);
395       CeedChk(ierr);
396     }
397   }
398 
399   return 0;
400 }
401 
402 int CeedOperatorCreate_Ref(CeedOperator op) {
403   int ierr;
404   Ceed ceed;
405   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
406   CeedOperator_Ref *impl;
407 
408   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
409   ierr = CeedOperatorSetData(op, (void*)&impl);
410 
411   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
412                                 CeedOperatorApply_Ref); CeedChk(ierr);
413   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
414                                 CeedOperatorDestroy_Ref); CeedChk(ierr);
415   return 0;
416 }
417