xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 418fb8c26cd03fc44256773f44bb9ece8ec63e5f)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-ref.h"
19 
20 static int CeedOperatorDestroy_Ref(CeedOperator op) {
21   int ierr;
22   CeedOperator_Ref *impl;
23   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
27   }
28   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
29   ierr = CeedFree(&impl->edata); CeedChk(ierr);
30 
31   for (CeedInt i=0; i<impl->numein; i++) {
32     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
33     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
34   }
35   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
36   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
37 
38   for (CeedInt i=0; i<impl->numeout; i++) {
39     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
40     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
41   }
42   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
43   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
44 
45   ierr = CeedFree(&impl); CeedChk(ierr);
46   return 0;
47 }
48 
49 /*
50   Setup infields or outfields
51  */
52 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op,
53                                        bool inOrOut,
54                                        CeedVector *fullevecs, CeedVector *evecs,
55                                        CeedVector *qvecs, CeedInt starte,
56                                        CeedInt numfields, CeedInt Q) {
57   CeedInt dim, ierr, ncomp;
58   Ceed ceed;
59   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
60   CeedBasis basis;
61   CeedElemRestriction Erestrict;
62   CeedOperatorField *opfields;
63   CeedQFunctionField *qffields;
64   if (inOrOut) {
65     ierr = CeedOperatorGetFields(op, NULL, &opfields);
66     CeedChk(ierr);
67     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
68     CeedChk(ierr);
69   } else {
70     ierr = CeedOperatorGetFields(op, &opfields, NULL);
71     CeedChk(ierr);
72     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
73     CeedChk(ierr);
74   }
75 
76   // Loop over fields
77   for (CeedInt i=0; i<numfields; i++) {
78     CeedEvalMode emode;
79     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
80 
81     if (emode != CEED_EVAL_WEIGHT) {
82       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &Erestrict);
83       CeedChk(ierr);
84       ierr = CeedElemRestrictionCreateVector(Erestrict, NULL,
85                                              &fullevecs[i+starte]);
86       CeedChk(ierr);
87     }
88 
89     switch(emode) {
90     case CEED_EVAL_NONE:
91       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
92       CeedChk(ierr);
93       ierr = CeedVectorCreate(ceed, Q*ncomp, &qvecs[i]); CeedChk(ierr);
94       break;
95     case CEED_EVAL_INTERP:
96       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
97       CeedChk(ierr);
98       ierr = CeedVectorCreate(ceed, Q*ncomp, &evecs[i]); CeedChk(ierr);
99       ierr = CeedVectorCreate(ceed, Q*ncomp, &qvecs[i]); CeedChk(ierr);
100       break;
101     case CEED_EVAL_GRAD:
102       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
103       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
104       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
105       ierr = CeedVectorCreate(ceed, Q*ncomp, &evecs[i]); CeedChk(ierr);
106       ierr = CeedVectorCreate(ceed, Q*ncomp*dim, &qvecs[i]); CeedChk(ierr);
107       break;
108     case CEED_EVAL_WEIGHT: // Only on input fields
109       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
110       ierr = CeedVectorCreate(ceed, Q, &qvecs[i]); CeedChk(ierr);
111       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
112                             NULL, qvecs[i]); CeedChk(ierr);
113       break;
114     case CEED_EVAL_DIV:
115       break; // Not implimented
116     case CEED_EVAL_CURL:
117       break; // Not implimented
118     }
119   }
120   return 0;
121 }
122 
123 /*
124   CeedOperator needs to connect all the named fields (be they active or passive)
125   to the named inputs and outputs of its CeedQFunction.
126  */
127 static int CeedOperatorSetup_Ref(CeedOperator op) {
128   int ierr;
129   bool setupdone;
130   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
131   if (setupdone) return 0;
132   Ceed ceed;
133   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
134   CeedOperator_Ref *impl;
135   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
136   CeedQFunction qf;
137   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
138   CeedInt Q, numinputfields, numoutputfields;
139   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
140   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
141   CeedChk(ierr);
142   CeedOperatorField *opinputfields, *opoutputfields;
143   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
144   CeedChk(ierr);
145   CeedQFunctionField *qfinputfields, *qfoutputfields;
146   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
147   CeedChk(ierr);
148 
149   // Allocate
150   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
151   CeedChk(ierr);
152   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
153   CeedChk(ierr);
154 
155   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
156   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
157   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
158   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
159 
160   impl->numein = numinputfields; impl->numeout = numoutputfields;
161 
162   // Set up infield and outfield evecs and qvecs
163   // Infields
164   ierr = CeedOperatorSetupFields_Ref(qf, op, 0, impl->evecs,
165                                      impl->evecsin, impl->qvecsin, 0,
166                                      numinputfields, Q);
167   CeedChk(ierr);
168 
169   // Outfields
170   ierr = CeedOperatorSetupFields_Ref(qf, op, 1, impl->evecs,
171                                      impl->evecsout, impl->qvecsout,
172                                      numinputfields, numoutputfields, Q);
173   CeedChk(ierr);
174 
175   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
176 
177   return 0;
178 }
179 
180 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
181                                  CeedVector outvec, CeedRequest *request) {
182   int ierr;
183   CeedOperator_Ref *impl;
184   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
185   CeedQFunction qf;
186   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
187   CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, ncomp;
188   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
189   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
190   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
191   CeedChk(ierr);
192   CeedTransposeMode lmode;
193   CeedOperatorField *opinputfields, *opoutputfields;
194   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
195   CeedChk(ierr);
196   CeedQFunctionField *qfinputfields, *qfoutputfields;
197   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
198   CeedChk(ierr);
199   CeedEvalMode emode;
200   CeedVector vec;
201   CeedBasis basis;
202   CeedElemRestriction Erestrict;
203 
204   // Setup
205   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
206 
207   // Input Evecs and Restriction
208   for (CeedInt i=0; i<numinputfields; i++) {
209     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
210     CeedChk(ierr);
211     if (emode == CEED_EVAL_WEIGHT) { // Skip
212     } else {
213       // Get input vector
214       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
215       if (vec == CEED_VECTOR_ACTIVE)
216         vec = invec;
217       // Restrict
218       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
219       CeedChk(ierr);
220       ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
221       ierr = CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE,
222                                       lmode, vec, impl->evecs[i],
223                                       request); CeedChk(ierr);
224       // Get evec
225       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
226                                     (const CeedScalar **) &impl->edata[i]);
227       CeedChk(ierr);
228     }
229   }
230 
231   // Output Evecs
232   for (CeedInt i=0; i<numoutputfields; i++) {
233     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
234                               &impl->edata[i + numinputfields]); CeedChk(ierr);
235   }
236 
237   // Loop through elements
238   for (CeedInt e=0; e<numelements; e++) {
239     // Input basis apply if needed
240     for (CeedInt i=0; i<numinputfields; i++) {
241       // Get elemsize, emode, ncomp
242       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
243       CeedChk(ierr);
244       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
245       CeedChk(ierr);
246       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
247       CeedChk(ierr);
248       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
249       CeedChk(ierr);
250       // Basis action
251       switch(emode) {
252       case CEED_EVAL_NONE:
253         ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
254                                   CEED_USE_POINTER,
255                                   &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
256         break;
257       case CEED_EVAL_INTERP:
258         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
259         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
260                                   CEED_USE_POINTER,
261                                   &impl->edata[i][e*elemsize*ncomp]);
262         CeedChk(ierr);
263         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
264                               CEED_EVAL_INTERP, impl->evecsin[i],
265                               impl->qvecsin[i]); CeedChk(ierr);
266         break;
267       case CEED_EVAL_GRAD:
268         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
269         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
270                                   CEED_USE_POINTER,
271                                   &impl->edata[i][e*elemsize*ncomp]);
272         CeedChk(ierr);
273         ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
274                               CEED_EVAL_GRAD, impl->evecsin[i],
275                               impl->qvecsin[i]); CeedChk(ierr);
276         break;
277       case CEED_EVAL_WEIGHT:
278         break;  // No action
279       case CEED_EVAL_DIV:
280         break; // Not implimented
281       case CEED_EVAL_CURL:
282         break; // Not implimented
283       }
284     }
285     // Output pointers
286     for (CeedInt i=0; i<numoutputfields; i++) {
287       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
288       CeedChk(ierr);
289       if (emode == CEED_EVAL_NONE) {
290         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
291         CeedChk(ierr);
292         ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
293                                   CEED_USE_POINTER,
294                                   &impl->edata[i + numinputfields][e*Q*ncomp]);
295         CeedChk(ierr);
296       }
297     }
298     // Q function
299     ierr = CeedQFunctionApply(qf, Q, impl->qvecsin, impl->qvecsout); CeedChk(ierr);
300 
301     // Output basis apply if needed
302     for (CeedInt i=0; i<numoutputfields; i++) {
303       // Get elemsize, emode, ncomp
304       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
305       CeedChk(ierr);
306       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
307       CeedChk(ierr);
308       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
309       CeedChk(ierr);
310       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
311       CeedChk(ierr);
312       // Basis action
313       switch(emode) {
314       case CEED_EVAL_NONE:
315         break; // No action
316       case CEED_EVAL_INTERP:
317         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
318         CeedChk(ierr);
319         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
320                                   CEED_USE_POINTER,
321                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
322         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
323                               CEED_EVAL_INTERP, impl->qvecsout[i],
324                               impl->evecsout[i]); CeedChk(ierr);
325         break;
326       case CEED_EVAL_GRAD:
327         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
328         CeedChk(ierr);
329         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
330                                   CEED_USE_POINTER,
331                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
332         ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
333                               CEED_EVAL_GRAD, impl->qvecsout[i],
334                               impl->evecsout[i]); CeedChk(ierr);
335         break;
336       case CEED_EVAL_WEIGHT: {
337         Ceed ceed;
338         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
339         return CeedError(ceed, 1,
340                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
341         break; // Should not occur
342       }
343       case CEED_EVAL_DIV:
344         break; // Not implimented
345       case CEED_EVAL_CURL:
346         break; // Not implimented
347       }
348     }
349   }
350 
351   // Zero lvecs
352   for (CeedInt i=0; i<numoutputfields; i++) {
353     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
354     if (vec == CEED_VECTOR_ACTIVE)
355       vec = outvec;
356     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
357   }
358 
359   // Output restriction
360   for (CeedInt i=0; i<numoutputfields; i++) {
361     // Restore evec
362     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
363                                   &impl->edata[i + numinputfields]);
364     CeedChk(ierr);
365     // Get output vector
366     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
367     // Active
368     if (vec == CEED_VECTOR_ACTIVE)
369       vec = outvec;
370     // Restrict
371     ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
372     CeedChk(ierr);
373     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
374     ierr = CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE,
375                                     lmode, impl->evecs[i+impl->numein], vec,
376                                     request); CeedChk(ierr);
377   }
378 
379   // Restore input arrays
380   for (CeedInt i=0; i<numinputfields; i++) {
381     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
382     CeedChk(ierr);
383     if (emode == CEED_EVAL_WEIGHT) { // Skip
384     } else {
385       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
386                                         (const CeedScalar **) &impl->edata[i]);
387       CeedChk(ierr);
388     }
389   }
390 
391   return 0;
392 }
393 
394 int CeedOperatorCreate_Ref(CeedOperator op) {
395   int ierr;
396   Ceed ceed;
397   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
398   CeedOperator_Ref *impl;
399 
400   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
401   ierr = CeedOperatorSetData(op, (void*)&impl);
402 
403   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
404                                 CeedOperatorApply_Ref); CeedChk(ierr);
405   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
406                                 CeedOperatorDestroy_Ref); CeedChk(ierr);
407   return 0;
408 }
409