xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 9832694600dc0a18633050ae506a587ff6bc6365) !
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-ref.h"
19 
20 static int CeedOperatorDestroy_Ref(CeedOperator op) {
21   int ierr;
22   CeedOperator_Ref *impl;
23   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
27   }
28   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
29   ierr = CeedFree(&impl->edata); CeedChk(ierr);
30 
31   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
32     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
33   }
34   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
35   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
36 
37   ierr = CeedFree(&impl->indata); CeedChk(ierr);
38   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
39 
40   ierr = CeedFree(&op->data); CeedChk(ierr);
41   return 0;
42 }
43 
44 /*
45   Setup infields or outfields
46  */
47 static int CeedOperatorSetupFields_Ref(CeedQFunctionField qfields[16],
48                                        CeedOperatorField ofields[16],
49                                        CeedVector *evecs, CeedScalar **qdata,
50                                        CeedScalar **qdata_alloc, CeedScalar **indata,
51                                        CeedInt starti, CeedInt startq,
52                                        CeedInt numfields, CeedInt Q) {
53   CeedInt dim, ierr, iq=startq, ncomp;
54   CeedBasis basis;
55   CeedElemRestriction Erestrict;
56 
57   // Loop over fields
58   for (CeedInt i=0; i<numfields; i++) {
59     CeedEvalMode emode;
60     ierr = CeedQFunctionFieldGetEvalMode(qfields[i], &emode); CeedChk(ierr);
61 
62     if (emode != CEED_EVAL_WEIGHT) {
63       ierr = CeedOperatorFieldGetElemRestriction(ofields[i], &Erestrict);
64       CeedChk(ierr);
65       ierr = CeedElemRestrictionCreateVector(Erestrict, NULL,
66                                              &evecs[i+starti]);
67       CeedChk(ierr);
68     }
69 
70     switch(emode) {
71     case CEED_EVAL_NONE:
72       break; // No action
73     case CEED_EVAL_INTERP:
74       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
75       CeedChk(ierr);
76       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
77       qdata[i + starti] = qdata_alloc[iq];
78       iq++;
79       break;
80     case CEED_EVAL_GRAD:
81       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
82       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
83       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
84       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
85       qdata[i + starti] = qdata_alloc[iq];
86       iq++;
87       break;
88     case CEED_EVAL_WEIGHT: // Only on input fields
89       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
90       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
91       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
92                             NULL, qdata_alloc[iq]); CeedChk(ierr);
93       qdata[i] = qdata_alloc[iq];
94       indata[i] = qdata[i];
95       iq++;
96       break;
97     case CEED_EVAL_DIV:
98       break; // Not implimented
99     case CEED_EVAL_CURL:
100       break; // Not implimented
101     }
102   }
103   return 0;
104 }
105 
106 /*
107   CeedOperator needs to connect all the named fields (be they active or passive)
108   to the named inputs and outputs of its CeedQFunction.
109  */
110 static int CeedOperatorSetup_Ref(CeedOperator op) {
111   int ierr;
112   bool setupdone;
113   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
114   if (setupdone) return 0;
115   CeedOperator_Ref *impl;
116   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
117   CeedQFunction qf;
118   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
119   CeedInt Q, numinputfields, numoutputfields;
120   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
121   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
122   CeedChk(ierr);
123   CeedOperatorField *opinputfields, *opoutputfields;
124   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
125   CeedChk(ierr);
126   CeedQFunctionField *qfinputfields, *qfoutputfields;
127   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
128   CeedChk(ierr);
129   CeedEvalMode emode;
130 
131   // Count infield and outfield array sizes and evectors
132   impl->numein = numinputfields;
133   for (CeedInt i=0; i<numinputfields; i++) {
134     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
135     CeedChk(ierr);
136     impl->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) +
137                     !!(emode & CEED_EVAL_WEIGHT);
138   }
139   impl->numeout = numoutputfields;
140   for (CeedInt i=0; i<numoutputfields; i++) {
141     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
142     CeedChk(ierr);
143     impl->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
144   }
145 
146   // Allocate
147   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->evecs); CeedChk(ierr);
148   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->edata);
149   CeedChk(ierr);
150 
151   ierr = CeedCalloc(impl->numqin + impl->numqout, &impl->qdata_alloc);
152   CeedChk(ierr);
153   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->qdata);
154   CeedChk(ierr);
155 
156   ierr = CeedCalloc(16, &impl->indata); CeedChk(ierr);
157   ierr = CeedCalloc(16, &impl->outdata); CeedChk(ierr);
158 
159   // Set up infield and outfield pointer arrays
160   // Infields
161   ierr = CeedOperatorSetupFields_Ref(qfinputfields, opinputfields,
162                                      impl->evecs, impl->qdata, impl->qdata_alloc,
163                                      impl->indata, 0, 0,
164                                      numinputfields, Q); CeedChk(ierr);
165 
166   // Outfields
167   ierr = CeedOperatorSetupFields_Ref(qfoutputfields, opoutputfields,
168                                      impl->evecs, impl->qdata, impl->qdata_alloc,
169                                      impl->indata, numinputfields,
170                                      impl->numqin, numoutputfields, Q);
171   CeedChk(ierr);
172 
173   // Input Qvecs
174   for (CeedInt i=0; i<numinputfields; i++) {
175     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
176     CeedChk(ierr);
177     if ((emode != CEED_EVAL_NONE) && (emode != CEED_EVAL_WEIGHT))
178       impl->indata[i] =  impl->qdata[i];
179   }
180   // Output Qvecs
181   for (CeedInt i=0; i<numoutputfields; i++) {
182     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
183     CeedChk(ierr);
184     if (emode != CEED_EVAL_NONE)
185       impl->outdata[i] =  impl->qdata[i + numinputfields];
186   }
187 
188   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
189 
190   return 0;
191 }
192 
193 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
194                                  CeedVector outvec, CeedRequest *request) {
195   int ierr;
196   CeedOperator_Ref *impl;
197   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
198   CeedQFunction qf;
199   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
200   CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, ncomp;
201   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
202   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
203   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
204   CeedChk(ierr);
205   CeedTransposeMode lmode;
206   CeedOperatorField *opinputfields, *opoutputfields;
207   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
208   CeedChk(ierr);
209   CeedQFunctionField *qfinputfields, *qfoutputfields;
210   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
211   CeedChk(ierr);
212   CeedEvalMode emode;
213   CeedVector vec;
214   CeedBasis basis;
215   CeedElemRestriction Erestrict;
216 
217   // Setup
218   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
219 
220   // Input Evecs and Restriction
221   for (CeedInt i=0; i<numinputfields; i++) {
222     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
223     CeedChk(ierr);
224     if (emode == CEED_EVAL_WEIGHT) { // Skip
225     } else {
226       // Get input vector
227       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
228       if (vec == CEED_VECTOR_ACTIVE)
229         vec = invec;
230       // Restrict
231       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
232       CeedChk(ierr);
233       ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
234       ierr = CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE,
235                                       lmode, vec, impl->evecs[i],
236                                       request); CeedChk(ierr);
237       // Get evec
238       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
239                                     (const CeedScalar **) &impl->edata[i]);
240       CeedChk(ierr);
241     }
242   }
243 
244   // Output Evecs
245   for (CeedInt i=0; i<numoutputfields; i++) {
246     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
247                               &impl->edata[i + numinputfields]); CeedChk(ierr);
248   }
249 
250   // Loop through elements
251   for (CeedInt e=0; e<numelements; e++) {
252     // Input basis apply if needed
253     for (CeedInt i=0; i<numinputfields; i++) {
254       // Get elemsize, emode, ncomp
255       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
256       CeedChk(ierr);
257       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
258       CeedChk(ierr);
259       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
260       CeedChk(ierr);
261       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
262       CeedChk(ierr);
263       // Basis action
264       switch(emode) {
265       case CEED_EVAL_NONE:
266         impl->indata[i] = &impl->edata[i][e*Q*ncomp];
267         break;
268       case CEED_EVAL_INTERP:
269         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
270         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
271                               CEED_EVAL_INTERP, &impl->edata[i][e*elemsize*ncomp],
272                               impl->qdata[i]); CeedChk(ierr);
273         break;
274       case CEED_EVAL_GRAD:
275         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
276         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
277                               CEED_EVAL_GRAD, &impl->edata[i][e*elemsize*ncomp],
278                               impl->qdata[i]); CeedChk(ierr);
279         break;
280       case CEED_EVAL_WEIGHT:
281         break;  // No action
282       case CEED_EVAL_DIV:
283         break; // Not implimented
284       case CEED_EVAL_CURL:
285         break; // Not implimented
286       }
287     }
288     // Output pointers
289     for (CeedInt i=0; i<numoutputfields; i++) {
290       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
291       CeedChk(ierr);
292       if (emode == CEED_EVAL_NONE) {
293         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
294         CeedChk(ierr);
295         impl->outdata[i] = &impl->edata[i + numinputfields][e*Q*ncomp];
296       }
297     }
298     // Q function
299     ierr = CeedQFunctionApply(qf, Q, (const CeedScalar * const*) impl->indata,
300                               impl->outdata); CeedChk(ierr);
301 
302     // Output basis apply if needed
303     for (CeedInt i=0; i<numoutputfields; i++) {
304       // Get elemsize, emode, ncomp
305       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
306       CeedChk(ierr);
307       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
308       CeedChk(ierr);
309       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
310       CeedChk(ierr);
311       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
312       CeedChk(ierr);
313       // Basis action
314       switch(emode) {
315       case CEED_EVAL_NONE:
316         break; // No action
317       case CEED_EVAL_INTERP:
318         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
319         CeedChk(ierr);
320         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
321                               CEED_EVAL_INTERP, impl->outdata[i],
322                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
323         CeedChk(ierr);
324         break;
325       case CEED_EVAL_GRAD:
326         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
327         CeedChk(ierr);
328         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
329                               CEED_EVAL_GRAD, impl->outdata[i],
330                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
331         CeedChk(ierr);
332         break;
333       case CEED_EVAL_WEIGHT: {
334         Ceed ceed;
335         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
336         return CeedError(ceed, 1,
337                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
338         break; // Should not occur
339       }
340       case CEED_EVAL_DIV:
341         break; // Not implimented
342       case CEED_EVAL_CURL:
343         break; // Not implimented
344       }
345     }
346   }
347 
348   // Zero lvecs
349   for (CeedInt i=0; i<numoutputfields; i++) {
350     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
351     if (vec == CEED_VECTOR_ACTIVE)
352       vec = outvec;
353     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
354     }
355 
356   // Output restriction
357   for (CeedInt i=0; i<numoutputfields; i++) {
358     // Restore evec
359     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
360                                   &impl->edata[i + numinputfields]);
361     CeedChk(ierr);
362     // Get output vector
363     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
364     // Active
365     if (vec == CEED_VECTOR_ACTIVE)
366       vec = outvec;
367     // Restrict
368     ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
369     CeedChk(ierr);
370     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
371     ierr = CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE,
372                                     lmode, impl->evecs[i+impl->numein], vec,
373                                     request); CeedChk(ierr);
374   }
375 
376   // Restore input arrays
377   for (CeedInt i=0; i<numinputfields; i++) {
378     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
379     CeedChk(ierr);
380     if (emode == CEED_EVAL_WEIGHT) { // Skip
381     } else {
382       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
383                                         (const CeedScalar **) &impl->edata[i]);
384       CeedChk(ierr);
385     }
386   }
387 
388   return 0;
389 }
390 
391 int CeedOperatorCreate_Ref(CeedOperator op) {
392   int ierr;
393   CeedOperator_Ref *impl;
394 
395   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
396   op->data = impl;
397   op->Destroy = CeedOperatorDestroy_Ref;
398   op->Apply = CeedOperatorApply_Ref;
399   return 0;
400 }
401