xref: /libCEED/backends/ref/ceed-ref-operator.c (revision c042f62f62e10e1321eb699b116e67a6568d5716)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-ref.h"
18 
19 static int CeedOperatorDestroy_Ref(CeedOperator op) {
20   int ierr;
21   CeedOperator_Ref *impl;
22   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
23 
24   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
25     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
26   }
27   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
28   ierr = CeedFree(&impl->edata); CeedChk(ierr);
29   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
30 
31   for (CeedInt i=0; i<impl->numein; i++) {
32     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
33     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
34   }
35   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
36   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
37 
38   for (CeedInt i=0; i<impl->numeout; i++) {
39     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
40     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
41   }
42   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
43   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
44 
45   ierr = CeedFree(&impl); CeedChk(ierr);
46   return 0;
47 }
48 
49 /*
50   Setup infields or outfields
51  */
52 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op,
53                                        bool inOrOut,
54                                        CeedVector *fullevecs, CeedVector *evecs,
55                                        CeedVector *qvecs, CeedInt starte,
56                                        CeedInt numfields, CeedInt Q) {
57   CeedInt dim, ierr, size, P;
58   Ceed ceed;
59   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
60   CeedBasis basis;
61   CeedElemRestriction Erestrict;
62   CeedOperatorField *opfields;
63   CeedQFunctionField *qffields;
64   if (inOrOut) {
65     ierr = CeedOperatorGetFields(op, NULL, &opfields);
66     CeedChk(ierr);
67     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
68     CeedChk(ierr);
69   } else {
70     ierr = CeedOperatorGetFields(op, &opfields, NULL);
71     CeedChk(ierr);
72     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
73     CeedChk(ierr);
74   }
75 
76   // Loop over fields
77   for (CeedInt i=0; i<numfields; i++) {
78     CeedEvalMode emode;
79     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
80 
81     if (emode != CEED_EVAL_WEIGHT) {
82       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &Erestrict);
83       CeedChk(ierr);
84       ierr = CeedElemRestrictionCreateVector(Erestrict, NULL,
85                                              &fullevecs[i+starte]);
86       CeedChk(ierr);
87     }
88 
89     switch(emode) {
90     case CEED_EVAL_NONE:
91       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChk(ierr);
92       ierr = CeedVectorCreate(ceed, Q*size, &qvecs[i]); CeedChk(ierr);
93       break;
94     case CEED_EVAL_INTERP:
95       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChk(ierr);
96       ierr = CeedElemRestrictionGetElementSize(Erestrict, &P);
97       CeedChk(ierr);
98       ierr = CeedVectorCreate(ceed, P*size, &evecs[i]); CeedChk(ierr);
99       ierr = CeedVectorCreate(ceed, Q*size, &qvecs[i]); CeedChk(ierr);
100       break;
101     case CEED_EVAL_GRAD:
102       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
103       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChk(ierr);
104       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
105       ierr = CeedElemRestrictionGetElementSize(Erestrict, &P);
106       CeedChk(ierr);
107       ierr = CeedVectorCreate(ceed, P*size/dim, &evecs[i]); CeedChk(ierr);
108       ierr = CeedVectorCreate(ceed, Q*size, &qvecs[i]); CeedChk(ierr);
109       break;
110     case CEED_EVAL_WEIGHT: // Only on input fields
111       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
112       ierr = CeedVectorCreate(ceed, Q, &qvecs[i]); CeedChk(ierr);
113       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
114                             NULL, qvecs[i]); CeedChk(ierr);
115       break;
116     case CEED_EVAL_DIV:
117       break; // Not implemented
118     case CEED_EVAL_CURL:
119       break; // Not implemented
120     }
121   }
122   return 0;
123 }
124 
125 /*
126   CeedOperator needs to connect all the named fields (be they active or passive)
127   to the named inputs and outputs of its CeedQFunction.
128  */
129 static int CeedOperatorSetup_Ref(CeedOperator op) {
130   int ierr;
131   bool setupdone;
132   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
133   if (setupdone) return 0;
134   Ceed ceed;
135   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
136   CeedOperator_Ref *impl;
137   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
138   CeedQFunction qf;
139   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
140   CeedInt Q, numinputfields, numoutputfields;
141   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
142   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
143   CeedChk(ierr);
144   CeedOperatorField *opinputfields, *opoutputfields;
145   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
146   CeedChk(ierr);
147   CeedQFunctionField *qfinputfields, *qfoutputfields;
148   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
149   CeedChk(ierr);
150 
151   // Allocate
152   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
153   CeedChk(ierr);
154   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
155   CeedChk(ierr);
156 
157   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
158   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
159   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
160   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
161   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
162 
163   impl->numein = numinputfields; impl->numeout = numoutputfields;
164 
165   // Set up infield and outfield evecs and qvecs
166   // Infields
167   ierr = CeedOperatorSetupFields_Ref(qf, op, 0, impl->evecs,
168                                      impl->evecsin, impl->qvecsin, 0,
169                                      numinputfields, Q);
170   CeedChk(ierr);
171 
172   // Outfields
173   ierr = CeedOperatorSetupFields_Ref(qf, op, 1, impl->evecs,
174                                      impl->evecsout, impl->qvecsout,
175                                      numinputfields, numoutputfields, Q);
176   CeedChk(ierr);
177 
178   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
179 
180   return 0;
181 }
182 
/*
  Apply the operator to invec, writing (or, in add mode, accumulating) the
  active output into outvec.

  Steps:
    1. Run the one-time setup (CeedOperatorSetup_Ref).
    2. Restrict each non-weight input L-vector into its E-vector. Passive
       inputs are re-restricted only when their vector state has changed since
       the last apply; the active input (invec) is always restricted. Then take
       read access to each input E-vector.
    3. Take write access to every output E-vector.
    4. For each element:
       - apply the input basis into the input Q-vectors, or for
         CEED_EVAL_NONE point the Q-vector directly at the E-vector data;
       - point CEED_EVAL_NONE output Q-vectors directly at output E-vector data;
       - run the QFunction;
       - apply the transposed output basis into the output E-vector.
    5. Zero the output L-vectors. The active one is zeroed only when impl->add
       is false; CeedCompositeOperatorApply_Ref sets it to request
       accumulation. Then apply the transposed restriction from each output
       E-vector.
    6. Release read access to the input E-vectors.

  impl->add is reset to false right after the zeroing step, so an accumulate
  request applies to a single call only.
*/
static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
                                 CeedVector outvec, CeedRequest *request) {
  int ierr;
  CeedOperator_Ref *impl;
  ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
  CeedQFunction qf;
  ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
  CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, size, dim;
  ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
  ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
  ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
  CeedChk(ierr);
  CeedTransposeMode lmode;
  CeedOperatorField *opinputfields, *opoutputfields;
  ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
  CeedChk(ierr);
  CeedQFunctionField *qfinputfields, *qfoutputfields;
  ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
  CeedChk(ierr);
  CeedEvalMode emode;
  CeedVector vec;
  CeedBasis basis;
  CeedElemRestriction Erestrict;
  uint64_t state;

  // Setup (no-op after the first call)
  ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);

  // Input Evecs and Restriction
  for (CeedInt i=0; i<numinputfields; i++) {
    ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
    CeedChk(ierr);
    if (emode == CEED_EVAL_WEIGHT) { // Skip: weights were computed in setup
    } else {
      // Get input vector
      ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
      if (vec == CEED_VECTOR_ACTIVE)
        vec = invec;
      // Restrict
      ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
      // Skip restriction if input is unchanged. The active input is always
      // restricted, regardless of its state.
      if (state != impl->inputstate[i] || vec == invec) {
        ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
        CeedChk(ierr);
        ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
        ierr = CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE,
                                        lmode, vec, impl->evecs[i],
                                        request); CeedChk(ierr);
        impl->inputstate[i] = state;
      }
      // Get evec; read access is held until "Restore input arrays" below
      ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
                                    (const CeedScalar **) &impl->edata[i]);
      CeedChk(ierr);
    }
  }

  // Output Evecs: output E-vector slots follow the inputs in impl->evecs/edata
  for (CeedInt i=0; i<numoutputfields; i++) {
    ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
                              &impl->edata[i + numinputfields]); CeedChk(ierr);
  }

  // Loop through elements
  for (CeedInt e=0; e<numelements; e++) {
    // Input basis apply if needed
    for (CeedInt i=0; i<numinputfields; i++) {
      // Get elemsize, emode, size
      ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
      CeedChk(ierr);
      ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
      CeedChk(ierr);
      ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
      CeedChk(ierr);
      ierr = CeedQFunctionFieldGetSize(qfinputfields[i], &size); CeedChk(ierr);
      // Basis action
      switch(emode) {
      case CEED_EVAL_NONE:
        // No basis: the QFunction reads this element's slice of the E-vector
        // in place (element stride Q*size)
        ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
                                  CEED_USE_POINTER,
                                  &impl->edata[i][e*Q*size]); CeedChk(ierr);
        break;
      case CEED_EVAL_INTERP:
        // Wrap this element's nodal values, then interpolate to quadrature
        ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
        ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
                                  CEED_USE_POINTER,
                                  &impl->edata[i][e*elemsize*size]);
        CeedChk(ierr);
        ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
                              CEED_EVAL_INTERP, impl->evecsin[i],
                              impl->qvecsin[i]); CeedChk(ierr);
        break;
      case CEED_EVAL_GRAD:
        // The QFunction field size counts all dim derivative directions, so
        // the nodal stride per element is elemsize*size/dim (matching the
        // P*size/dim buffer created in CeedOperatorSetupFields_Ref)
        ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
        ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
        ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
                                  CEED_USE_POINTER,
                                  &impl->edata[i][e*elemsize*size/dim]);
        CeedChk(ierr);
        ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
                              CEED_EVAL_GRAD, impl->evecsin[i],
                              impl->qvecsin[i]); CeedChk(ierr);
        break;
      case CEED_EVAL_WEIGHT:
        break;  // No action: qvecsin[i] already holds the weights
      case CEED_EVAL_DIV:
        break; // Not implemented
      case CEED_EVAL_CURL:
        break; // Not implemented
      }
    }
    // Output pointers: CEED_EVAL_NONE outputs are written by the QFunction
    // straight into this element's slice of the output E-vector
    for (CeedInt i=0; i<numoutputfields; i++) {
      ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
      CeedChk(ierr);
      if (emode == CEED_EVAL_NONE) {
        ierr = CeedQFunctionFieldGetSize(qfoutputfields[i], &size);
        CeedChk(ierr);
        ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
                                  CEED_USE_POINTER,
                                  &impl->edata[i + numinputfields][e*Q*size]);
        CeedChk(ierr);
      }
    }
    // Q function
    ierr = CeedQFunctionApply(qf, Q, impl->qvecsin, impl->qvecsout); CeedChk(ierr);

    // Output basis apply if needed
    for (CeedInt i=0; i<numoutputfields; i++) {
      // Get elemsize, emode, size
      ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
      CeedChk(ierr);
      ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
      CeedChk(ierr);
      ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
      CeedChk(ierr);
      ierr = CeedQFunctionFieldGetSize(qfoutputfields[i], &size); CeedChk(ierr);
      // Basis action
      switch(emode) {
      case CEED_EVAL_NONE:
        break; // No action: the QFunction already wrote into the E-vector
      case CEED_EVAL_INTERP:
        // Transposed interpolation from quadrature points into this
        // element's slice of the output E-vector
        ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
        CeedChk(ierr);
        ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
                                  CEED_USE_POINTER,
                                  &impl->edata[i + numinputfields][e*elemsize*size]);
        CeedChk(ierr);
        ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
                              CEED_EVAL_INTERP, impl->qvecsout[i],
                              impl->evecsout[i]); CeedChk(ierr);
        break;
      case CEED_EVAL_GRAD:
        // Same size/dim element stride as for GRAD inputs
        ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
        CeedChk(ierr);
        ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
        ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
                                  CEED_USE_POINTER,
                                  &impl->edata[i + numinputfields][e*elemsize*size/dim]);
        CeedChk(ierr);
        ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
                              CEED_EVAL_GRAD, impl->qvecsout[i],
                              impl->evecsout[i]); CeedChk(ierr);
        break;
      case CEED_EVAL_WEIGHT: {
        // LCOV_EXCL_START
        // NOTE(review): this error path returns without restoring the input
        // and output E-vector arrays acquired above.
        Ceed ceed;
        ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
        return CeedError(ceed, 1,
                         "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
        // LCOV_EXCL_STOP
        break; // Should not occur
      }
      case CEED_EVAL_DIV:
        break; // Not implemented
      case CEED_EVAL_CURL:
        break; // Not implemented
      }
    }
  }

  // Zero lvecs before the transposed restriction below. NOTE(review): zeroing
  // first suggests CeedElemRestrictionApply(CEED_TRANSPOSE, ...) sums into its
  // output vector; confirm. In add mode the active output is left untouched so
  // the result accumulates into outvec.
  for (CeedInt i=0; i<numoutputfields; i++) {
    ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
    if (vec == CEED_VECTOR_ACTIVE) {
      if (!impl->add) {
        vec = outvec;
        ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
      }
    } else {
      ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
    }
  }
  // Accumulate mode is a one-shot request; clear it for the next apply
  impl->add = false;

  // Output restriction
  for (CeedInt i=0; i<numoutputfields; i++) {
    // Restore evec (write access must be released before it is read below)
    ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
                                  &impl->edata[i + numinputfields]);
    CeedChk(ierr);
    // Get output vector
    ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
    // Active
    if (vec == CEED_VECTOR_ACTIVE)
      vec = outvec;
    // Restrict
    ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
    CeedChk(ierr);
    ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
    ierr = CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE,
                                    lmode, impl->evecs[i+impl->numein], vec,
                                    request); CeedChk(ierr);
  }

  // Restore input arrays acquired in the input restriction loop
  for (CeedInt i=0; i<numinputfields; i++) {
    ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
    CeedChk(ierr);
    if (emode == CEED_EVAL_WEIGHT) { // Skip: no array was taken for weights
    } else {
      ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
                                        (const CeedScalar **) &impl->edata[i]);
      CeedChk(ierr);
    }
  }

  return 0;
}
412 
413 static int CeedCompositeOperatorApply_Ref(CeedOperator op, CeedVector invec,
414     CeedVector outvec,
415     CeedRequest *request) {
416   int ierr;
417   CeedInt numsub;
418   CeedOperator_Ref *impl;
419   CeedOperator *suboperators;
420   ierr = CeedOperatorGetNumSub(op, &numsub); CeedChk(ierr);
421   ierr = CeedOperatorGetSubList(op, &suboperators); CeedChk(ierr);
422 
423   // Overwrite outvec with first output
424   ierr = CeedOperatorApply(suboperators[0], invec, outvec, request);
425   CeedChk(ierr);
426   // Add to outvec with subsequent outputs
427   for (CeedInt i=1; i<numsub; i++) {
428     ierr = CeedOperatorGetData(suboperators[i], (void *)&impl); CeedChk(ierr);
429     impl->add = true;
430     ierr = CeedOperatorApply(suboperators[i], invec, outvec, request);
431     CeedChk(ierr);
432   }
433 
434   return 0;
435 }
436 
437 int CeedOperatorCreate_Ref(CeedOperator op) {
438   int ierr;
439   Ceed ceed;
440   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
441   CeedOperator_Ref *impl;
442 
443   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
444   impl->add = false;
445   ierr = CeedOperatorSetData(op, (void *)&impl); CeedChk(ierr);
446 
447   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
448                                 CeedOperatorApply_Ref); CeedChk(ierr);
449   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
450                                 CeedOperatorDestroy_Ref); CeedChk(ierr);
451   return 0;
452 }
453 
454 int CeedCompositeOperatorCreate_Ref(CeedOperator op) {
455   int ierr;
456   Ceed ceed;
457   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
458 
459   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
460                                 CeedCompositeOperatorApply_Ref); CeedChk(ierr);
461   return 0;
462 }
463