xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 9ad453579b4f8e35becf925ceedf519a25efe885)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-ref.h"
18 
19 static int CeedOperatorDestroy_Ref(CeedOperator op) {
20   int ierr;
21   CeedOperator_Ref *impl;
22   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
23 
24   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
25     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
26   }
27   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
28   ierr = CeedFree(&impl->edata); CeedChk(ierr);
29   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
30 
31   for (CeedInt i=0; i<impl->numein; i++) {
32     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
33     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
34   }
35   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
36   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
37 
38   for (CeedInt i=0; i<impl->numeout; i++) {
39     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
40     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
41   }
42   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
43   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
44 
45   ierr = CeedFree(&impl); CeedChk(ierr);
46   return 0;
47 }
48 
49 /*
50   Setup infields or outfields
51  */
52 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op,
53                                        bool inOrOut,
54                                        CeedVector *fullevecs, CeedVector *evecs,
55                                        CeedVector *qvecs, CeedInt starte,
56                                        CeedInt numfields, CeedInt Q) {
57   CeedInt dim, ierr, ncomp, P;
58   Ceed ceed;
59   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
60   CeedBasis basis;
61   CeedElemRestriction Erestrict;
62   CeedOperatorField *opfields;
63   CeedQFunctionField *qffields;
64   if (inOrOut) {
65     ierr = CeedOperatorGetFields(op, NULL, &opfields);
66     CeedChk(ierr);
67     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
68     CeedChk(ierr);
69   } else {
70     ierr = CeedOperatorGetFields(op, &opfields, NULL);
71     CeedChk(ierr);
72     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
73     CeedChk(ierr);
74   }
75 
76   // Loop over fields
77   for (CeedInt i=0; i<numfields; i++) {
78     CeedEvalMode emode;
79     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
80 
81     if (emode != CEED_EVAL_WEIGHT) {
82       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &Erestrict);
83       CeedChk(ierr);
84       ierr = CeedElemRestrictionCreateVector(Erestrict, NULL,
85                                              &fullevecs[i+starte]);
86       CeedChk(ierr);
87     }
88 
89     switch(emode) {
90     case CEED_EVAL_NONE:
91       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
92       CeedChk(ierr);
93       ierr = CeedVectorCreate(ceed, Q*ncomp, &qvecs[i]); CeedChk(ierr);
94       break;
95     case CEED_EVAL_INTERP:
96       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
97       CeedChk(ierr);
98       ierr = CeedElemRestrictionGetElementSize(Erestrict, &P);
99       CeedChk(ierr);
100       ierr = CeedVectorCreate(ceed, P*ncomp, &evecs[i]); CeedChk(ierr);
101       ierr = CeedVectorCreate(ceed, Q*ncomp, &qvecs[i]); CeedChk(ierr);
102       break;
103     case CEED_EVAL_GRAD:
104       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
105       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
106       CeedChk(ierr);
107       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
108       ierr = CeedElemRestrictionGetElementSize(Erestrict, &P);
109       CeedChk(ierr);
110       ierr = CeedVectorCreate(ceed, P*ncomp, &evecs[i]); CeedChk(ierr);
111       ierr = CeedVectorCreate(ceed, Q*ncomp*dim, &qvecs[i]); CeedChk(ierr);
112       break;
113     case CEED_EVAL_WEIGHT: // Only on input fields
114       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
115       ierr = CeedVectorCreate(ceed, Q, &qvecs[i]); CeedChk(ierr);
116       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
117                             NULL, qvecs[i]); CeedChk(ierr);
118       break;
119     case CEED_EVAL_DIV:
120       break; // Not implimented
121     case CEED_EVAL_CURL:
122       break; // Not implimented
123     }
124   }
125   return 0;
126 }
127 
128 /*
129   CeedOperator needs to connect all the named fields (be they active or passive)
130   to the named inputs and outputs of its CeedQFunction.
131  */
132 static int CeedOperatorSetup_Ref(CeedOperator op) {
133   int ierr;
134   bool setupdone;
135   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
136   if (setupdone) return 0;
137   Ceed ceed;
138   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
139   CeedOperator_Ref *impl;
140   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
141   CeedQFunction qf;
142   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
143   CeedInt Q, numinputfields, numoutputfields;
144   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
145   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
146   CeedChk(ierr);
147   CeedOperatorField *opinputfields, *opoutputfields;
148   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
149   CeedChk(ierr);
150   CeedQFunctionField *qfinputfields, *qfoutputfields;
151   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
152   CeedChk(ierr);
153 
154   // Allocate
155   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
156   CeedChk(ierr);
157   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
158   CeedChk(ierr);
159 
160   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
161   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
162   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
163   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
164   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
165 
166   impl->numein = numinputfields; impl->numeout = numoutputfields;
167 
168   // Set up infield and outfield evecs and qvecs
169   // Infields
170   ierr = CeedOperatorSetupFields_Ref(qf, op, 0, impl->evecs,
171                                      impl->evecsin, impl->qvecsin, 0,
172                                      numinputfields, Q);
173   CeedChk(ierr);
174 
175   // Outfields
176   ierr = CeedOperatorSetupFields_Ref(qf, op, 1, impl->evecs,
177                                      impl->evecsout, impl->qvecsout,
178                                      numinputfields, numoutputfields, Q);
179   CeedChk(ierr);
180 
181   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
182 
183   return 0;
184 }
185 
186 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
187                                  CeedVector outvec, CeedRequest *request) {
188   int ierr;
189   CeedOperator_Ref *impl;
190   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
191   CeedQFunction qf;
192   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
193   CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, ncomp;
194   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
195   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
196   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
197   CeedChk(ierr);
198   CeedTransposeMode lmode;
199   CeedOperatorField *opinputfields, *opoutputfields;
200   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
201   CeedChk(ierr);
202   CeedQFunctionField *qfinputfields, *qfoutputfields;
203   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
204   CeedChk(ierr);
205   CeedEvalMode emode;
206   CeedVector vec;
207   CeedBasis basis;
208   CeedElemRestriction Erestrict;
209   uint64_t state;
210 
211   // Setup
212   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
213 
214   // Input Evecs and Restriction
215   for (CeedInt i=0; i<numinputfields; i++) {
216     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
217     CeedChk(ierr);
218     if (emode == CEED_EVAL_WEIGHT) { // Skip
219     } else {
220       // Get input vector
221       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
222       if (vec == CEED_VECTOR_ACTIVE)
223         vec = invec;
224       // Restrict
225       ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
226       // Skip restriction if input is unchanged
227       if (state != impl->inputstate[i] || vec == invec) {
228         ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
229         CeedChk(ierr);
230         ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
231         ierr = CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE,
232                                         lmode, vec, impl->evecs[i],
233                                         request); CeedChk(ierr);
234         impl->inputstate[i] = state;
235       }
236       // Get evec
237       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
238                                     (const CeedScalar **) &impl->edata[i]);
239       CeedChk(ierr);
240     }
241   }
242 
243   // Output Evecs
244   for (CeedInt i=0; i<numoutputfields; i++) {
245     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
246                               &impl->edata[i + numinputfields]); CeedChk(ierr);
247   }
248 
249   // Loop through elements
250   for (CeedInt e=0; e<numelements; e++) {
251     // Input basis apply if needed
252     for (CeedInt i=0; i<numinputfields; i++) {
253       // Get elemsize, emode, ncomp
254       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
255       CeedChk(ierr);
256       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
257       CeedChk(ierr);
258       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
259       CeedChk(ierr);
260       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
261       CeedChk(ierr);
262       // Basis action
263       switch(emode) {
264       case CEED_EVAL_NONE:
265         ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
266                                   CEED_USE_POINTER,
267                                   &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
268         break;
269       case CEED_EVAL_INTERP:
270         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
271         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
272                                   CEED_USE_POINTER,
273                                   &impl->edata[i][e*elemsize*ncomp]);
274         CeedChk(ierr);
275         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
276                               CEED_EVAL_INTERP, impl->evecsin[i],
277                               impl->qvecsin[i]); CeedChk(ierr);
278         break;
279       case CEED_EVAL_GRAD:
280         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
281         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
282                                   CEED_USE_POINTER,
283                                   &impl->edata[i][e*elemsize*ncomp]);
284         CeedChk(ierr);
285         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
286                               CEED_EVAL_GRAD, impl->evecsin[i],
287                               impl->qvecsin[i]); CeedChk(ierr);
288         break;
289       case CEED_EVAL_WEIGHT:
290         break;  // No action
291       case CEED_EVAL_DIV:
292         break; // Not implimented
293       case CEED_EVAL_CURL:
294         break; // Not implimented
295       }
296     }
297     // Output pointers
298     for (CeedInt i=0; i<numoutputfields; i++) {
299       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
300       CeedChk(ierr);
301       if (emode == CEED_EVAL_NONE) {
302         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
303         CeedChk(ierr);
304         ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
305                                   CEED_USE_POINTER,
306                                   &impl->edata[i + numinputfields][e*Q*ncomp]);
307         CeedChk(ierr);
308       }
309     }
310     // Q function
311     ierr = CeedQFunctionApply(qf, Q, impl->qvecsin, impl->qvecsout); CeedChk(ierr);
312 
313     // Output basis apply if needed
314     for (CeedInt i=0; i<numoutputfields; i++) {
315       // Get elemsize, emode, ncomp
316       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
317       CeedChk(ierr);
318       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
319       CeedChk(ierr);
320       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
321       CeedChk(ierr);
322       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
323       CeedChk(ierr);
324       // Basis action
325       switch(emode) {
326       case CEED_EVAL_NONE:
327         break; // No action
328       case CEED_EVAL_INTERP:
329         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
330         CeedChk(ierr);
331         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
332                                   CEED_USE_POINTER,
333                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
334         CeedChk(ierr);
335         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
336                               CEED_EVAL_INTERP, impl->qvecsout[i],
337                               impl->evecsout[i]); CeedChk(ierr);
338         break;
339       case CEED_EVAL_GRAD:
340         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
341         CeedChk(ierr);
342         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
343                                   CEED_USE_POINTER,
344                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
345         CeedChk(ierr);
346         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
347                               CEED_EVAL_GRAD, impl->qvecsout[i],
348                               impl->evecsout[i]); CeedChk(ierr);
349         break;
350       case CEED_EVAL_WEIGHT: {
351         Ceed ceed;
352         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
353         return CeedError(ceed, 1,
354                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
355         break; // Should not occur
356       }
357       case CEED_EVAL_DIV:
358         break; // Not implimented
359       case CEED_EVAL_CURL:
360         break; // Not implimented
361       }
362     }
363   }
364 
365   // Zero lvecs
366   for (CeedInt i=0; i<numoutputfields; i++) {
367     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
368     if (vec == CEED_VECTOR_ACTIVE) {
369       if (!impl->add) {
370         vec = outvec;
371         ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
372       }
373     } else {
374       ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
375     }
376   }
377   impl->add = false;
378 
379   // Output restriction
380   for (CeedInt i=0; i<numoutputfields; i++) {
381     // Restore evec
382     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
383                                   &impl->edata[i + numinputfields]);
384     CeedChk(ierr);
385     // Get output vector
386     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
387     // Active
388     if (vec == CEED_VECTOR_ACTIVE)
389       vec = outvec;
390     // Restrict
391     ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
392     CeedChk(ierr);
393     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
394     ierr = CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE,
395                                     lmode, impl->evecs[i+impl->numein], vec,
396                                     request); CeedChk(ierr);
397   }
398 
399   // Restore input arrays
400   for (CeedInt i=0; i<numinputfields; i++) {
401     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
402     CeedChk(ierr);
403     if (emode == CEED_EVAL_WEIGHT) { // Skip
404     } else {
405       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
406                                         (const CeedScalar **) &impl->edata[i]);
407       CeedChk(ierr);
408     }
409   }
410 
411   return 0;
412 }
413 
414 static int CeedCompositeOperatorApply_Ref(CeedOperator op, CeedVector invec,
415     CeedVector outvec,
416     CeedRequest *request) {
417   int ierr;
418   CeedInt numsub;
419   CeedOperator_Ref *impl;
420   CeedOperator *suboperators;
421   ierr = CeedOperatorGetNumSub(op, &numsub); CeedChk(ierr);
422   ierr = CeedOperatorGetSubList(op, &suboperators); CeedChk(ierr);
423 
424   // Overwrite outvec with first output
425   ierr = CeedOperatorApply(suboperators[0], invec, outvec, request);
426   CeedChk(ierr);
427   // Add to outvec with subsequent outputs
428   for (CeedInt i=1; i<numsub; i++) {
429     ierr = CeedOperatorGetData(suboperators[i], (void *)&impl); CeedChk(ierr);
430     impl->add = true;
431     ierr = CeedOperatorApply(suboperators[i], invec, outvec, request);
432     CeedChk(ierr);
433   }
434 
435   return 0;
436 }
437 
438 int CeedOperatorCreate_Ref(CeedOperator op) {
439   int ierr;
440   Ceed ceed;
441   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
442   CeedOperator_Ref *impl;
443 
444   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
445   impl->add = false;
446   ierr = CeedOperatorSetData(op, (void *)&impl); CeedChk(ierr);
447 
448   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
449                                 CeedOperatorApply_Ref); CeedChk(ierr);
450   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
451                                 CeedOperatorDestroy_Ref); CeedChk(ierr);
452   return 0;
453 }
454 
455 int CeedCompositeOperatorCreate_Ref(CeedOperator op) {
456   int ierr;
457   Ceed ceed;
458   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
459 
460   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
461                                 CeedCompositeOperatorApply_Ref); CeedChk(ierr);
462   return 0;
463 }
464