xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 2774d5cba221882c6b4100a2009c80f26d4c6f5c)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-ref.h"
19 
20 static int CeedOperatorDestroy_Ref(CeedOperator op) {
21   int ierr;
22   CeedOperator_Ref *impl;
23   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
27   }
28   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
29   ierr = CeedFree(&impl->edata); CeedChk(ierr);
30   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
31 
32   for (CeedInt i=0; i<impl->numein; i++) {
33     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
34     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
35   }
36   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
37   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
38 
39   for (CeedInt i=0; i<impl->numeout; i++) {
40     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
41     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
42   }
43   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
44   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
45 
46   ierr = CeedFree(&impl); CeedChk(ierr);
47   return 0;
48 }
49 
50 /*
51   Setup infields or outfields
52  */
53 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op,
54                                        bool inOrOut,
55                                        CeedVector *fullevecs, CeedVector *evecs,
56                                        CeedVector *qvecs, CeedInt starte,
57                                        CeedInt numfields, CeedInt Q) {
58   CeedInt dim, ierr, ncomp, P;
59   Ceed ceed;
60   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
61   CeedBasis basis;
62   CeedElemRestriction Erestrict;
63   CeedOperatorField *opfields;
64   CeedQFunctionField *qffields;
65   if (inOrOut) {
66     ierr = CeedOperatorGetFields(op, NULL, &opfields);
67     CeedChk(ierr);
68     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
69     CeedChk(ierr);
70   } else {
71     ierr = CeedOperatorGetFields(op, &opfields, NULL);
72     CeedChk(ierr);
73     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
74     CeedChk(ierr);
75   }
76 
77   // Loop over fields
78   for (CeedInt i=0; i<numfields; i++) {
79     CeedEvalMode emode;
80     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
81 
82     if (emode != CEED_EVAL_WEIGHT) {
83       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &Erestrict);
84       CeedChk(ierr);
85       ierr = CeedElemRestrictionCreateVector(Erestrict, NULL,
86                                              &fullevecs[i+starte]);
87       CeedChk(ierr);
88     }
89 
90     switch(emode) {
91     case CEED_EVAL_NONE:
92       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
93       CeedChk(ierr);
94       ierr = CeedVectorCreate(ceed, Q*ncomp, &qvecs[i]); CeedChk(ierr);
95       break;
96     case CEED_EVAL_INTERP:
97       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
98       CeedChk(ierr);
99       ierr = CeedElemRestrictionGetElementSize(Erestrict, &P);
100       CeedChk(ierr);
101       ierr = CeedVectorCreate(ceed, P*ncomp, &evecs[i]); CeedChk(ierr);
102       ierr = CeedVectorCreate(ceed, Q*ncomp, &qvecs[i]); CeedChk(ierr);
103       break;
104     case CEED_EVAL_GRAD:
105       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
106       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
107       CeedChk(ierr);
108       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
109       ierr = CeedElemRestrictionGetElementSize(Erestrict, &P);
110       CeedChk(ierr);
111       ierr = CeedVectorCreate(ceed, P*ncomp, &evecs[i]); CeedChk(ierr);
112       ierr = CeedVectorCreate(ceed, Q*ncomp*dim, &qvecs[i]); CeedChk(ierr);
113       break;
114     case CEED_EVAL_WEIGHT: // Only on input fields
115       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
116       ierr = CeedVectorCreate(ceed, Q, &qvecs[i]); CeedChk(ierr);
117       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
118                             NULL, qvecs[i]); CeedChk(ierr);
119       break;
120     case CEED_EVAL_DIV:
121       break; // Not implimented
122     case CEED_EVAL_CURL:
123       break; // Not implimented
124     }
125   }
126   return 0;
127 }
128 
129 /*
130   CeedOperator needs to connect all the named fields (be they active or passive)
131   to the named inputs and outputs of its CeedQFunction.
132  */
133 static int CeedOperatorSetup_Ref(CeedOperator op) {
134   int ierr;
135   bool setupdone;
136   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
137   if (setupdone) return 0;
138   Ceed ceed;
139   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
140   CeedOperator_Ref *impl;
141   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
142   CeedQFunction qf;
143   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
144   CeedInt Q, numinputfields, numoutputfields;
145   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
146   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
147   CeedChk(ierr);
148   CeedOperatorField *opinputfields, *opoutputfields;
149   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
150   CeedChk(ierr);
151   CeedQFunctionField *qfinputfields, *qfoutputfields;
152   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
153   CeedChk(ierr);
154 
155   // Allocate
156   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
157   CeedChk(ierr);
158   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
159   CeedChk(ierr);
160 
161   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
162   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
163   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
164   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
165   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
166 
167   impl->numein = numinputfields; impl->numeout = numoutputfields;
168 
169   // Set up infield and outfield evecs and qvecs
170   // Infields
171   ierr = CeedOperatorSetupFields_Ref(qf, op, 0, impl->evecs,
172                                      impl->evecsin, impl->qvecsin, 0,
173                                      numinputfields, Q);
174   CeedChk(ierr);
175 
176   // Outfields
177   ierr = CeedOperatorSetupFields_Ref(qf, op, 1, impl->evecs,
178                                      impl->evecsout, impl->qvecsout,
179                                      numinputfields, numoutputfields, Q);
180   CeedChk(ierr);
181 
182   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
183 
184   return 0;
185 }
186 
187 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
188                                  CeedVector outvec, CeedRequest *request) {
189   int ierr;
190   CeedOperator_Ref *impl;
191   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
192   CeedQFunction qf;
193   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
194   CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, ncomp;
195   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
196   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
197   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
198   CeedChk(ierr);
199   CeedTransposeMode lmode;
200   CeedOperatorField *opinputfields, *opoutputfields;
201   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
202   CeedChk(ierr);
203   CeedQFunctionField *qfinputfields, *qfoutputfields;
204   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
205   CeedChk(ierr);
206   CeedEvalMode emode;
207   CeedVector vec;
208   CeedBasis basis;
209   CeedElemRestriction Erestrict;
210   uint64_t state;
211 
212   // Setup
213   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
214 
215   // Input Evecs and Restriction
216   for (CeedInt i=0; i<numinputfields; i++) {
217     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
218     CeedChk(ierr);
219     if (emode == CEED_EVAL_WEIGHT) { // Skip
220     } else {
221       // Get input vector
222       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
223       if (vec == CEED_VECTOR_ACTIVE)
224         vec = invec;
225       // Restrict
226       ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
227       // Skip restriction if input is unchanged
228       if (state != impl->inputstate[i] || vec == invec) {
229         ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
230         CeedChk(ierr);
231         ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
232         ierr = CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE,
233                                         lmode, vec, impl->evecs[i],
234                                         request); CeedChk(ierr);
235         impl->inputstate[i] = state;
236       }
237       // Get evec
238       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
239                                     (const CeedScalar **) &impl->edata[i]);
240       CeedChk(ierr);
241     }
242   }
243 
244   // Output Evecs
245   for (CeedInt i=0; i<numoutputfields; i++) {
246     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
247                               &impl->edata[i + numinputfields]); CeedChk(ierr);
248   }
249 
250   // Loop through elements
251   for (CeedInt e=0; e<numelements; e++) {
252     // Input basis apply if needed
253     for (CeedInt i=0; i<numinputfields; i++) {
254       // Get elemsize, emode, ncomp
255       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
256       CeedChk(ierr);
257       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
258       CeedChk(ierr);
259       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
260       CeedChk(ierr);
261       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
262       CeedChk(ierr);
263       // Basis action
264       switch(emode) {
265       case CEED_EVAL_NONE:
266         ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
267                                   CEED_USE_POINTER,
268                                   &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
269         break;
270       case CEED_EVAL_INTERP:
271         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
272         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
273                                   CEED_USE_POINTER,
274                                   &impl->edata[i][e*elemsize*ncomp]);
275         CeedChk(ierr);
276         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
277                               CEED_EVAL_INTERP, impl->evecsin[i],
278                               impl->qvecsin[i]); CeedChk(ierr);
279         break;
280       case CEED_EVAL_GRAD:
281         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
282         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
283                                   CEED_USE_POINTER,
284                                   &impl->edata[i][e*elemsize*ncomp]);
285         CeedChk(ierr);
286         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
287                               CEED_EVAL_GRAD, impl->evecsin[i],
288                               impl->qvecsin[i]); CeedChk(ierr);
289         break;
290       case CEED_EVAL_WEIGHT:
291         break;  // No action
292       case CEED_EVAL_DIV:
293         break; // Not implimented
294       case CEED_EVAL_CURL:
295         break; // Not implimented
296       }
297     }
298     // Output pointers
299     for (CeedInt i=0; i<numoutputfields; i++) {
300       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
301       CeedChk(ierr);
302       if (emode == CEED_EVAL_NONE) {
303         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
304         CeedChk(ierr);
305         ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
306                                   CEED_USE_POINTER,
307                                   &impl->edata[i + numinputfields][e*Q*ncomp]);
308         CeedChk(ierr);
309       }
310     }
311     // Q function
312     ierr = CeedQFunctionApply(qf, Q, impl->qvecsin, impl->qvecsout); CeedChk(ierr);
313 
314     // Output basis apply if needed
315     for (CeedInt i=0; i<numoutputfields; i++) {
316       // Get elemsize, emode, ncomp
317       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
318       CeedChk(ierr);
319       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
320       CeedChk(ierr);
321       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
322       CeedChk(ierr);
323       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
324       CeedChk(ierr);
325       // Basis action
326       switch(emode) {
327       case CEED_EVAL_NONE:
328         break; // No action
329       case CEED_EVAL_INTERP:
330         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
331         CeedChk(ierr);
332         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
333                                   CEED_USE_POINTER,
334                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
335         CeedChk(ierr);
336         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
337                               CEED_EVAL_INTERP, impl->qvecsout[i],
338                               impl->evecsout[i]); CeedChk(ierr);
339         break;
340       case CEED_EVAL_GRAD:
341         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
342         CeedChk(ierr);
343         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
344                                   CEED_USE_POINTER,
345                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
346         CeedChk(ierr);
347         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
348                               CEED_EVAL_GRAD, impl->qvecsout[i],
349                               impl->evecsout[i]); CeedChk(ierr);
350         break;
351       case CEED_EVAL_WEIGHT: {
352         Ceed ceed;
353         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
354         return CeedError(ceed, 1,
355                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
356         break; // Should not occur
357       }
358       case CEED_EVAL_DIV:
359         break; // Not implimented
360       case CEED_EVAL_CURL:
361         break; // Not implimented
362       }
363     }
364   }
365 
366   // Zero lvecs
367   for (CeedInt i=0; i<numoutputfields; i++) {
368     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
369     if (vec == CEED_VECTOR_ACTIVE) {
370       if (!impl->add) {
371         vec = outvec;
372         ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
373       }
374     } else {
375       ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
376     }
377   }
378   impl->add = false;
379 
380   // Output restriction
381   for (CeedInt i=0; i<numoutputfields; i++) {
382     // Restore evec
383     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
384                                   &impl->edata[i + numinputfields]);
385     CeedChk(ierr);
386     // Get output vector
387     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
388     // Active
389     if (vec == CEED_VECTOR_ACTIVE)
390       vec = outvec;
391     // Restrict
392     ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
393     CeedChk(ierr);
394     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
395     ierr = CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE,
396                                     lmode, impl->evecs[i+impl->numein], vec,
397                                     request); CeedChk(ierr);
398   }
399 
400   // Restore input arrays
401   for (CeedInt i=0; i<numinputfields; i++) {
402     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
403     CeedChk(ierr);
404     if (emode == CEED_EVAL_WEIGHT) { // Skip
405     } else {
406       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
407                                         (const CeedScalar **) &impl->edata[i]);
408       CeedChk(ierr);
409     }
410   }
411 
412   return 0;
413 }
414 
415 static int CeedCompositeOperatorApply_Ref(CeedOperator op, CeedVector invec,
416     CeedVector outvec,
417     CeedRequest *request) {
418   int ierr;
419   CeedInt numsub;
420   CeedOperator_Ref *impl;
421   CeedOperator *suboperators;
422   ierr = CeedOperatorGetNumSub(op, &numsub); CeedChk(ierr);
423   ierr = CeedOperatorGetSubList(op, &suboperators); CeedChk(ierr);
424 
425   // Overwrite outvec with first output
426   ierr = CeedOperatorApply(suboperators[0], invec, outvec, request);
427   CeedChk(ierr);
428   // Add to outvec with subsequent outputs
429   for (CeedInt i=1; i<numsub; i++) {
430     ierr = CeedOperatorGetData(suboperators[i], (void *)&impl); CeedChk(ierr);
431     impl->add = true;
432     ierr = CeedOperatorApply(suboperators[i], invec, outvec, request);
433     CeedChk(ierr);
434   }
435 
436   return 0;
437 }
438 
439 int CeedOperatorCreate_Ref(CeedOperator op) {
440   int ierr;
441   Ceed ceed;
442   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
443   CeedOperator_Ref *impl;
444 
445   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
446   impl->add = false;
447   ierr = CeedOperatorSetData(op, (void *)&impl); CeedChk(ierr);
448 
449   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
450                                 CeedOperatorApply_Ref); CeedChk(ierr);
451   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
452                                 CeedOperatorDestroy_Ref); CeedChk(ierr);
453   return 0;
454 }
455 
456 int CeedCompositeOperatorCreate_Ref(CeedOperator op) {
457   int ierr;
458   Ceed ceed;
459   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
460 
461   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
462                                 CeedCompositeOperatorApply_Ref); CeedChk(ierr);
463   return 0;
464 }
465