xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 5c32accb25e185c12b735e364209d30808de63d6) !
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-ref.h"
19 
20 static int CeedOperatorDestroy_Ref(CeedOperator op) {
21   int ierr;
22   CeedOperator_Ref *impl;
23   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
27   }
28   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
29   ierr = CeedFree(&impl->edata); CeedChk(ierr);
30 
31   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
32     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
33   }
34   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
35   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
36 
37   ierr = CeedFree(&impl->indata); CeedChk(ierr);
38   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
39 
40   ierr = CeedFree(&impl); CeedChk(ierr);
41   return 0;
42 }
43 
44 /*
45   Setup infields or outfields
46  */
47 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op,
48                                        bool inOrOut,
49                                        CeedVector *evecs, CeedScalar **qdata,
50                                        CeedScalar **qdata_alloc, CeedScalar **indata,
51                                        CeedInt starti, CeedInt startq,
52                                        CeedInt numfields, CeedInt Q) {
53   CeedInt dim, ierr, iq=startq, ncomp;
54   CeedBasis basis;
55   CeedElemRestriction Erestrict;
56   CeedOperatorField *ofields;
57   CeedQFunctionField *qfields;
58   if (inOrOut) {
59     ierr = CeedOperatorGetFields(op, NULL, &ofields);
60     CeedChk(ierr);
61     ierr = CeedQFunctionGetFields(qf, NULL, &qfields);
62     CeedChk(ierr);
63   } else {
64     ierr = CeedOperatorGetFields(op, &ofields, NULL);
65     CeedChk(ierr);
66     ierr = CeedQFunctionGetFields(qf, &qfields, NULL);
67     CeedChk(ierr);
68   }
69 
70   // Loop over fields
71   for (CeedInt i=0; i<numfields; i++) {
72     CeedEvalMode emode;
73     ierr = CeedQFunctionFieldGetEvalMode(qfields[i], &emode); CeedChk(ierr);
74 
75     if (emode != CEED_EVAL_WEIGHT) {
76       ierr = CeedOperatorFieldGetElemRestriction(ofields[i], &Erestrict);
77       CeedChk(ierr);
78       ierr = CeedElemRestrictionCreateVector(Erestrict, NULL,
79                                              &evecs[i+starti]);
80       CeedChk(ierr);
81     }
82 
83     switch(emode) {
84     case CEED_EVAL_NONE:
85       break; // No action
86     case CEED_EVAL_INTERP:
87       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
88       CeedChk(ierr);
89       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
90       qdata[i + starti] = qdata_alloc[iq];
91       iq++;
92       break;
93     case CEED_EVAL_GRAD:
94       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
95       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
96       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
97       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
98       qdata[i + starti] = qdata_alloc[iq];
99       iq++;
100       break;
101     case CEED_EVAL_WEIGHT: // Only on input fields
102       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
103       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
104       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
105                             NULL, qdata_alloc[iq]); CeedChk(ierr);
106       qdata[i] = qdata_alloc[iq];
107       indata[i] = qdata[i];
108       iq++;
109       break;
110     case CEED_EVAL_DIV:
111       break; // Not implimented
112     case CEED_EVAL_CURL:
113       break; // Not implimented
114     }
115   }
116   return 0;
117 }
118 
119 /*
120   CeedOperator needs to connect all the named fields (be they active or passive)
121   to the named inputs and outputs of its CeedQFunction.
122  */
123 static int CeedOperatorSetup_Ref(CeedOperator op) {
124   int ierr;
125   bool setupdone;
126   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
127   if (setupdone) return 0;
128   CeedOperator_Ref *impl;
129   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
130   CeedQFunction qf;
131   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
132   CeedInt Q, numinputfields, numoutputfields;
133   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
134   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
135   CeedChk(ierr);
136   CeedOperatorField *opinputfields, *opoutputfields;
137   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
138   CeedChk(ierr);
139   CeedQFunctionField *qfinputfields, *qfoutputfields;
140   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
141   CeedChk(ierr);
142   CeedEvalMode emode;
143 
144   // Count infield and outfield array sizes and evectors
145   impl->numein = numinputfields;
146   for (CeedInt i=0; i<numinputfields; i++) {
147     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
148     CeedChk(ierr);
149     impl->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) +
150                     !!(emode & CEED_EVAL_WEIGHT);
151   }
152   impl->numeout = numoutputfields;
153   for (CeedInt i=0; i<numoutputfields; i++) {
154     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
155     CeedChk(ierr);
156     impl->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
157   }
158 
159   // Allocate
160   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->evecs); CeedChk(ierr);
161   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->edata);
162   CeedChk(ierr);
163 
164   ierr = CeedCalloc(impl->numqin + impl->numqout, &impl->qdata_alloc);
165   CeedChk(ierr);
166   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->qdata);
167   CeedChk(ierr);
168 
169   ierr = CeedCalloc(16, &impl->indata); CeedChk(ierr);
170   ierr = CeedCalloc(16, &impl->outdata); CeedChk(ierr);
171 
172   // Set up infield and outfield pointer arrays
173   // Infields
174   ierr = CeedOperatorSetupFields_Ref(qf, op, 0,
175                                      impl->evecs, impl->qdata, impl->qdata_alloc,
176                                      impl->indata, 0, 0,
177                                      numinputfields, Q); CeedChk(ierr);
178 
179   // Outfields
180   ierr = CeedOperatorSetupFields_Ref(qf, op, 1,
181                                      impl->evecs, impl->qdata, impl->qdata_alloc,
182                                      impl->indata, numinputfields,
183                                      impl->numqin, numoutputfields, Q);
184   CeedChk(ierr);
185 
186   // Input Qvecs
187   for (CeedInt i=0; i<numinputfields; i++) {
188     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
189     CeedChk(ierr);
190     if ((emode != CEED_EVAL_NONE) && (emode != CEED_EVAL_WEIGHT))
191       impl->indata[i] =  impl->qdata[i];
192   }
193   // Output Qvecs
194   for (CeedInt i=0; i<numoutputfields; i++) {
195     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
196     CeedChk(ierr);
197     if (emode != CEED_EVAL_NONE)
198       impl->outdata[i] =  impl->qdata[i + numinputfields];
199   }
200 
201   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
202 
203   return 0;
204 }
205 
206 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
207                                  CeedVector outvec, CeedRequest *request) {
208   int ierr;
209   CeedOperator_Ref *impl;
210   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
211   CeedQFunction qf;
212   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
213   CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, ncomp;
214   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
215   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
216   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
217   CeedChk(ierr);
218   CeedTransposeMode lmode;
219   CeedOperatorField *opinputfields, *opoutputfields;
220   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
221   CeedChk(ierr);
222   CeedQFunctionField *qfinputfields, *qfoutputfields;
223   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
224   CeedChk(ierr);
225   CeedEvalMode emode;
226   CeedVector vec;
227   CeedBasis basis;
228   CeedElemRestriction Erestrict;
229 
230   // Setup
231   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
232 
233   // Input Evecs and Restriction
234   for (CeedInt i=0; i<numinputfields; i++) {
235     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
236     CeedChk(ierr);
237     if (emode == CEED_EVAL_WEIGHT) { // Skip
238     } else {
239       // Get input vector
240       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
241       if (vec == CEED_VECTOR_ACTIVE)
242         vec = invec;
243       // Restrict
244       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
245       CeedChk(ierr);
246       ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
247       ierr = CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE,
248                                       lmode, vec, impl->evecs[i],
249                                       request); CeedChk(ierr);
250       // Get evec
251       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
252                                     (const CeedScalar **) &impl->edata[i]);
253       CeedChk(ierr);
254     }
255   }
256 
257   // Output Evecs
258   for (CeedInt i=0; i<numoutputfields; i++) {
259     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
260                               &impl->edata[i + numinputfields]); CeedChk(ierr);
261   }
262 
263   // Loop through elements
264   for (CeedInt e=0; e<numelements; e++) {
265     // Input basis apply if needed
266     for (CeedInt i=0; i<numinputfields; i++) {
267       // Get elemsize, emode, ncomp
268       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
269       CeedChk(ierr);
270       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
271       CeedChk(ierr);
272       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
273       CeedChk(ierr);
274       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
275       CeedChk(ierr);
276       // Basis action
277       switch(emode) {
278       case CEED_EVAL_NONE:
279         impl->indata[i] = &impl->edata[i][e*Q*ncomp];
280         break;
281       case CEED_EVAL_INTERP:
282         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
283         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
284                               CEED_EVAL_INTERP, &impl->edata[i][e*elemsize*ncomp],
285                               impl->qdata[i]); CeedChk(ierr);
286         break;
287       case CEED_EVAL_GRAD:
288         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
289         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
290                               CEED_EVAL_GRAD, &impl->edata[i][e*elemsize*ncomp],
291                               impl->qdata[i]); CeedChk(ierr);
292         break;
293       case CEED_EVAL_WEIGHT:
294         break;  // No action
295       case CEED_EVAL_DIV:
296         break; // Not implimented
297       case CEED_EVAL_CURL:
298         break; // Not implimented
299       }
300     }
301     // Output pointers
302     for (CeedInt i=0; i<numoutputfields; i++) {
303       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
304       CeedChk(ierr);
305       if (emode == CEED_EVAL_NONE) {
306         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
307         CeedChk(ierr);
308         impl->outdata[i] = &impl->edata[i + numinputfields][e*Q*ncomp];
309       }
310     }
311     // Q function
312     ierr = CeedQFunctionApply(qf, Q, (const CeedScalar * const*) impl->indata,
313                               impl->outdata); CeedChk(ierr);
314 
315     // Output basis apply if needed
316     for (CeedInt i=0; i<numoutputfields; i++) {
317       // Get elemsize, emode, ncomp
318       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
319       CeedChk(ierr);
320       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
321       CeedChk(ierr);
322       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
323       CeedChk(ierr);
324       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
325       CeedChk(ierr);
326       // Basis action
327       switch(emode) {
328       case CEED_EVAL_NONE:
329         break; // No action
330       case CEED_EVAL_INTERP:
331         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
332         CeedChk(ierr);
333         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
334                               CEED_EVAL_INTERP, impl->outdata[i],
335                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
336         CeedChk(ierr);
337         break;
338       case CEED_EVAL_GRAD:
339         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
340         CeedChk(ierr);
341         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
342                               CEED_EVAL_GRAD, impl->outdata[i],
343                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
344         CeedChk(ierr);
345         break;
346       case CEED_EVAL_WEIGHT: {
347         Ceed ceed;
348         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
349         return CeedError(ceed, 1,
350                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
351         break; // Should not occur
352       }
353       case CEED_EVAL_DIV:
354         break; // Not implimented
355       case CEED_EVAL_CURL:
356         break; // Not implimented
357       }
358     }
359   }
360 
361   // Zero lvecs
362   for (CeedInt i=0; i<numoutputfields; i++) {
363     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
364     if (vec == CEED_VECTOR_ACTIVE)
365       vec = outvec;
366     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
367   }
368 
369   // Output restriction
370   for (CeedInt i=0; i<numoutputfields; i++) {
371     // Restore evec
372     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
373                                   &impl->edata[i + numinputfields]);
374     CeedChk(ierr);
375     // Get output vector
376     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
377     // Active
378     if (vec == CEED_VECTOR_ACTIVE)
379       vec = outvec;
380     // Restrict
381     ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
382     CeedChk(ierr);
383     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
384     ierr = CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE,
385                                     lmode, impl->evecs[i+impl->numein], vec,
386                                     request); CeedChk(ierr);
387   }
388 
389   // Restore input arrays
390   for (CeedInt i=0; i<numinputfields; i++) {
391     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
392     CeedChk(ierr);
393     if (emode == CEED_EVAL_WEIGHT) { // Skip
394     } else {
395       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
396                                         (const CeedScalar **) &impl->edata[i]);
397       CeedChk(ierr);
398     }
399   }
400 
401   return 0;
402 }
403 
404 int CeedOperatorCreate_Ref(CeedOperator op) {
405   int ierr;
406   Ceed ceed;
407   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
408   CeedOperator_Ref *impl;
409 
410   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
411   ierr = CeedOperatorSetData(op, (void*)&impl);
412 
413   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
414                                 CeedOperatorApply_Ref); CeedChk(ierr);
415   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
416                                 CeedOperatorDestroy_Ref); CeedChk(ierr);
417   return 0;
418 }
419