xref: /libCEED/backends/ref/ceed-ref-operator.c (revision d1bcdac99090efedac6b43fd437855ade149b9cc)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-ref.h"
19 
20 static int CeedOperatorDestroy_Ref(CeedOperator op) {
21   int ierr;
22   CeedOperator_Ref *impl;
23   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
27   }
28   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
29   ierr = CeedFree(&impl->edata); CeedChk(ierr);
30 
31   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
32     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
33   }
34   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
35   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
36 
37   ierr = CeedFree(&impl->indata); CeedChk(ierr);
38   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
39 
40   ierr = CeedFree(&op->data); CeedChk(ierr);
41   return 0;
42 }
43 
44 /*
45   Setup infields or outfields
46  */
47 static int CeedOperatorSetupFields_Ref(CeedQFunctionField qfields[16],
48                                        CeedOperatorField ofields[16],
49                                        CeedVector *evecs, CeedScalar **qdata,
50                                        CeedScalar **qdata_alloc, CeedScalar **indata,
51                                        CeedInt starti, CeedInt startq,
52                                        CeedInt numfields, CeedInt Q) {
53   CeedInt dim, ierr, iq=startq, ncomp;
54   CeedBasis basis;
55   CeedElemRestriction Erestrict;
56 
57   // Loop over fields
58   for (CeedInt i=0; i<numfields; i++) {
59     CeedEvalMode emode;
60     ierr = CeedQFunctionFieldGetEvalMode(qfields[i], &emode); CeedChk(ierr);
61 
62     if (emode != CEED_EVAL_WEIGHT) {
63       ierr = CeedOperatorFieldGetElemRestriction(ofields[i], &Erestrict);
64       CeedChk(ierr);
65       ierr = CeedElemRestrictionCreateVector(Erestrict, NULL,
66                                              &evecs[i+starti]);
67       CeedChk(ierr);
68     }
69 
70     switch(emode) {
71     case CEED_EVAL_NONE:
72       break; // No action
73     case CEED_EVAL_INTERP:
74       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
75       CeedChk(ierr);
76       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
77       qdata[i + starti] = qdata_alloc[iq];
78       iq++;
79       break;
80     case CEED_EVAL_GRAD:
81       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
82       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
83       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
84       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
85       qdata[i + starti] = qdata_alloc[iq];
86       iq++;
87       break;
88     case CEED_EVAL_WEIGHT: // Only on input fields
89       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
90       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
91       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
92                             NULL, qdata_alloc[iq]); CeedChk(ierr);
93       qdata[i] = qdata_alloc[iq];
94       indata[i] = qdata[i];
95       iq++;
96       break;
97     case CEED_EVAL_DIV:
98       break; // Not implimented
99     case CEED_EVAL_CURL:
100       break; // Not implimented
101     }
102   }
103   return 0;
104 }
105 
106 /*
107   CeedOperator needs to connect all the named fields (be they active or passive)
108   to the named inputs and outputs of its CeedQFunction.
109  */
110 static int CeedOperatorSetup_Ref(CeedOperator op) {
111   int ierr;
112   bool setupdone;
113   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
114   if (setupdone) return 0;
115   CeedOperator_Ref *impl;
116   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
117   CeedQFunction qf;
118   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
119   CeedInt Q, numinputfields, numoutputfields;
120   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
121   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
122   CeedChk(ierr);
123   CeedOperatorField *opinputfields, *opoutputfields;
124   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
125   CeedChk(ierr);
126   CeedQFunctionField *qfinputfields, *qfoutputfields;
127   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
128   CeedChk(ierr);
129   CeedEvalMode emode;
130 
131   // Count infield and outfield array sizes and evectors
132   impl->numein = numinputfields;
133   for (CeedInt i=0; i<numinputfields; i++) {
134     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
135     CeedChk(ierr);
136     impl->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) +
137                     !!(emode & CEED_EVAL_WEIGHT);
138   }
139   impl->numeout = numoutputfields;
140   for (CeedInt i=0; i<numoutputfields; i++) {
141     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
142     CeedChk(ierr);
143     impl->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
144   }
145 
146   // Allocate
147   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->evecs); CeedChk(ierr);
148   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->edata);
149   CeedChk(ierr);
150 
151   ierr = CeedCalloc(impl->numqin + impl->numqout, &impl->qdata_alloc);
152   CeedChk(ierr);
153   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->qdata);
154   CeedChk(ierr);
155 
156   ierr = CeedCalloc(16, &impl->indata); CeedChk(ierr);
157   ierr = CeedCalloc(16, &impl->outdata); CeedChk(ierr);
158 
159   // Set up infield and outfield pointer arrays
160   // Infields
161   ierr = CeedOperatorSetupFields_Ref(qfinputfields, opinputfields,
162                                      impl->evecs, impl->qdata, impl->qdata_alloc,
163                                      impl->indata, 0, 0,
164                                      numinputfields, Q); CeedChk(ierr);
165 
166   // Outfields
167   ierr = CeedOperatorSetupFields_Ref(qfoutputfields, opoutputfields,
168                                      impl->evecs, impl->qdata, impl->qdata_alloc,
169                                      impl->indata, numinputfields,
170                                      impl->numqin, numoutputfields, Q);
171   CeedChk(ierr);
172 
173   // Input Qvecs
174   for (CeedInt i=0; i<numinputfields; i++) {
175     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
176     CeedChk(ierr);
177     if ((emode != CEED_EVAL_NONE) && (emode != CEED_EVAL_WEIGHT))
178       impl->indata[i] =  impl->qdata[i];
179   }
180   // Output Qvecs
181   for (CeedInt i=0; i<numoutputfields; i++) {
182     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
183     CeedChk(ierr);
184     if (emode != CEED_EVAL_NONE)
185       impl->outdata[i] =  impl->qdata[i + numinputfields];
186   }
187 
188   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
189 
190   return 0;
191 }
192 
193 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
194                                  CeedVector outvec, CeedRequest *request) {
195   int ierr;
196   CeedOperator_Ref *impl;
197   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
198   CeedQFunction qf;
199   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
200   CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, ncomp;
201   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
202   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
203   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
204   CeedChk(ierr);
205   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
206   CeedOperatorField *opinputfields, *opoutputfields;
207   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
208   CeedChk(ierr);
209   CeedQFunctionField *qfinputfields, *qfoutputfields;
210   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
211   CeedChk(ierr);
212   CeedEvalMode emode;
213   CeedVector vec;
214   CeedBasis basis;
215   CeedElemRestriction Erestrict;
216 
217   // Setup
218   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
219 
220   // Input Evecs and Restriction
221   for (CeedInt i=0; i<numinputfields; i++) {
222     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
223     CeedChk(ierr);
224     if (emode == CEED_EVAL_WEIGHT) { // Skip
225     } else {
226       // Get input vector
227       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
228       if (vec == CEED_VECTOR_ACTIVE)
229         vec = invec;
230       // Restrict
231       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
232       CeedChk(ierr);
233       ierr = CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE,
234                                       lmode, vec, impl->evecs[i],
235                                       request); CeedChk(ierr);
236       // Get evec
237       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
238                                     (const CeedScalar **) &impl->edata[i]);
239       CeedChk(ierr);
240     }
241   }
242 
243   // Output Evecs
244   for (CeedInt i=0; i<numoutputfields; i++) {
245     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
246                               &impl->edata[i + numinputfields]); CeedChk(ierr);
247   }
248 
249   // Loop through elements
250   for (CeedInt e=0; e<numelements; e++) {
251     // Input basis apply if needed
252     for (CeedInt i=0; i<numinputfields; i++) {
253       // Get elemsize, emode, ncomp
254       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
255       CeedChk(ierr);
256       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
257       CeedChk(ierr);
258       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
259       CeedChk(ierr);
260       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
261       CeedChk(ierr);
262       // Basis action
263       switch(emode) {
264       case CEED_EVAL_NONE:
265         impl->indata[i] = &impl->edata[i][e*Q*ncomp];
266         break;
267       case CEED_EVAL_INTERP:
268         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
269         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
270                               CEED_EVAL_INTERP, &impl->edata[i][e*elemsize*ncomp],
271                               impl->qdata[i]); CeedChk(ierr);
272         break;
273       case CEED_EVAL_GRAD:
274         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
275         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
276                               CEED_EVAL_GRAD, &impl->edata[i][e*elemsize*ncomp],
277                               impl->qdata[i]); CeedChk(ierr);
278         break;
279       case CEED_EVAL_WEIGHT:
280         break;  // No action
281       case CEED_EVAL_DIV:
282         break; // Not implimented
283       case CEED_EVAL_CURL:
284         break; // Not implimented
285       }
286     }
287     // Output pointers
288     for (CeedInt i=0; i<numoutputfields; i++) {
289       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
290       CeedChk(ierr);
291       if (emode == CEED_EVAL_NONE) {
292         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
293         CeedChk(ierr);
294         impl->outdata[i] = &impl->edata[i + numinputfields][e*Q*ncomp];
295       }
296     }
297     // Q function
298     ierr = CeedQFunctionApply(qf, Q, (const CeedScalar * const*) impl->indata,
299                               impl->outdata); CeedChk(ierr);
300 
301     // Output basis apply if needed
302     for (CeedInt i=0; i<numoutputfields; i++) {
303       // Get elemsize, emode, ncomp
304       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
305       CeedChk(ierr);
306       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
307       CeedChk(ierr);
308       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
309       CeedChk(ierr);
310       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
311       CeedChk(ierr);
312       // Basis action
313       switch(emode) {
314       case CEED_EVAL_NONE:
315         break; // No action
316       case CEED_EVAL_INTERP:
317         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
318         CeedChk(ierr);
319         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
320                               CEED_EVAL_INTERP, impl->outdata[i],
321                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
322         CeedChk(ierr);
323         break;
324       case CEED_EVAL_GRAD:
325         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
326         CeedChk(ierr);
327         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
328                               CEED_EVAL_GRAD, impl->outdata[i],
329                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
330         CeedChk(ierr);
331         break;
332       case CEED_EVAL_WEIGHT: {
333         Ceed ceed;
334         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
335         return CeedError(ceed, 1,
336                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
337         break; // Should not occur
338       }
339       case CEED_EVAL_DIV:
340         break; // Not implimented
341       case CEED_EVAL_CURL:
342         break; // Not implimented
343       }
344     }
345   }
346 
347   // Zero lvecs
348   for (CeedInt i=0; i<numoutputfields; i++) {
349     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
350     if (vec == CEED_VECTOR_ACTIVE)
351       vec = outvec;
352     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
353     }
354 
355   // Output restriction
356   for (CeedInt i=0; i<numoutputfields; i++) {
357     // Restore evec
358     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
359                                   &impl->edata[i + numinputfields]);
360     CeedChk(ierr);
361     // Get output vector
362     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
363     // Active
364     if (vec == CEED_VECTOR_ACTIVE)
365       vec = outvec;
366     // Restrict
367     ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
368     CeedChk(ierr);
369     ierr = CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE,
370                                     lmode, impl->evecs[i+impl->numein], vec,
371                                     request); CeedChk(ierr);
372   }
373 
374   // Restore input arrays
375   for (CeedInt i=0; i<numinputfields; i++) {
376     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
377     CeedChk(ierr);
378     if (emode == CEED_EVAL_WEIGHT) { // Skip
379     } else {
380       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
381                                         (const CeedScalar **) &impl->edata[i]);
382       CeedChk(ierr);
383     }
384   }
385 
386   return 0;
387 }
388 
389 int CeedOperatorCreate_Ref(CeedOperator op) {
390   int ierr;
391   CeedOperator_Ref *impl;
392 
393   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
394   op->data = impl;
395   op->Destroy = CeedOperatorDestroy_Ref;
396   op->Apply = CeedOperatorApply_Ref;
397   return 0;
398 }
399