xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 4ce2993fee6e7bc9b86526ee90098d0dc489fc60)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-impl.h>
18 #include <string.h>
19 #include "ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Ref(CeedOperator op) {
22   int ierr;
23   CeedOperator_Ref *impl;
24   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
25 
26   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
27     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
28   }
29   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
30   ierr = CeedFree(&impl->edata); CeedChk(ierr);
31 
32   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
33     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
34   }
35   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
36   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
37 
38   ierr = CeedFree(&impl->indata); CeedChk(ierr);
39   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
40 
41   ierr = CeedFree(&op->data); CeedChk(ierr);
42   return 0;
43 }
44 
45 /*
46   Setup infields or outfields
47  */
48 static int CeedOperatorSetupFields_Ref(struct CeedQFunctionField qfields[16],
49                                        struct CeedOperatorField ofields[16],
50                                        CeedVector *evecs, CeedScalar **qdata,
51                                        CeedScalar **qdata_alloc, CeedScalar **indata,
52                                        CeedInt starti, CeedInt startq,
53                                        CeedInt numfields, CeedInt Q) {
54   CeedInt dim, ierr, iq=startq, ncomp;
55 
56   // Loop over fields
57   for (CeedInt i=0; i<numfields; i++) {
58     CeedEvalMode emode = qfields[i].emode;
59 
60     if (emode != CEED_EVAL_WEIGHT) {
61       ierr = CeedElemRestrictionCreateVector(ofields[i].Erestrict, NULL,
62                                              &evecs[i+starti]);
63       CeedChk(ierr);
64     }
65 
66     switch(emode) {
67     case CEED_EVAL_NONE:
68       break; // No action
69     case CEED_EVAL_INTERP:
70       ncomp = qfields[i].ncomp;
71       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
72       qdata[i + starti] = qdata_alloc[iq];
73       iq++;
74       break;
75     case CEED_EVAL_GRAD:
76       ncomp = qfields[i].ncomp;
77       ierr = CeedBasisGetDimension(ofields[i].basis, &dim); CeedChk(ierr);
78       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
79       qdata[i + starti] = qdata_alloc[iq];
80       iq++;
81       break;
82     case CEED_EVAL_WEIGHT: // Only on input fields
83       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
84       ierr = CeedBasisApply(ofields[iq].basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
85                             NULL, qdata_alloc[iq]); CeedChk(ierr);
86       qdata[i] = qdata_alloc[iq];
87       indata[i] = qdata[i];
88       iq++;
89       break;
90     case CEED_EVAL_DIV:
91       break; // Not implimented
92     case CEED_EVAL_CURL:
93       break; // Not implimented
94     }
95   }
96   return 0;
97 }
98 
99 /*
100   CeedOperator needs to connect all the named fields (be they active or passive)
101   to the named inputs and outputs of its CeedQFunction.
102  */
103 static int CeedOperatorSetup_Ref(CeedOperator op) {
104   int ierr;
105   bool setupdone;
106   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
107   if (setupdone) return 0;
108   CeedOperator_Ref *impl;
109   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
110   CeedQFunction qf;
111   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
112   CeedInt Q, numinputfields, numoutputfields;
113   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
114   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
115   CeedChk(ierr);
116 
117   // Count infield and outfield array sizes and evectors
118   impl->numein = numinputfields;
119   for (CeedInt i=0; i<numinputfields; i++) {
120     CeedEvalMode emode = qf->inputfields[i].emode;
121     impl->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) +
122                     !!(emode & CEED_EVAL_WEIGHT);
123   }
124   impl->numeout = numoutputfields;
125   for (CeedInt i=0; i<numoutputfields; i++) {
126     CeedEvalMode emode = qf->outputfields[i].emode;
127     impl->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
128   }
129 
130   // Allocate
131   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->evecs); CeedChk(ierr);
132   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->edata);
133   CeedChk(ierr);
134 
135   ierr = CeedCalloc(impl->numqin + impl->numqout, &impl->qdata_alloc);
136   CeedChk(ierr);
137   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->qdata);
138   CeedChk(ierr);
139 
140   ierr = CeedCalloc(16, &impl->indata); CeedChk(ierr);
141   ierr = CeedCalloc(16, &impl->outdata); CeedChk(ierr);
142 
143   // Set up infield and outfield pointer arrays
144   // Infields
145   ierr = CeedOperatorSetupFields_Ref(qf->inputfields, op->inputfields,
146                                      impl->evecs, impl->qdata, impl->qdata_alloc,
147                                      impl->indata, 0, 0,
148                                      numinputfields, Q); CeedChk(ierr);
149 
150   // Outfields
151   ierr = CeedOperatorSetupFields_Ref(qf->outputfields, op->outputfields,
152                                      impl->evecs, impl->qdata, impl->qdata_alloc,
153                                      impl->indata, numinputfields,
154                                      impl->numqin, numoutputfields, Q); CeedChk(ierr);
155 
156   // Input Qvecs
157   for (CeedInt i=0; i<numinputfields; i++) {
158     CeedEvalMode emode = qf->inputfields[i].emode;
159     if ((emode != CEED_EVAL_NONE) && (emode != CEED_EVAL_WEIGHT))
160       impl->indata[i] =  impl->qdata[i];
161   }
162   // Output Qvecs
163   for (CeedInt i=0; i<numoutputfields; i++) {
164     CeedEvalMode emode = qf->outputfields[i].emode;
165     if (emode != CEED_EVAL_NONE)
166       impl->outdata[i] =  impl->qdata[i + numinputfields];
167   }
168 
169   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
170 
171   return 0;
172 }
173 
174 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
175                                  CeedVector outvec, CeedRequest *request) {
176   int ierr;
177   CeedOperator_Ref *impl;
178   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
179   CeedQFunction qf;
180   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
181   CeedInt Q, numelements, elemsize, numinputfields, numoutputfields;
182   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
183   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
184   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
185   CeedChk(ierr);
186   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
187 
188   // Setup
189   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
190 
191   // Input Evecs and Restriction
192   for (CeedInt i=0; i<numinputfields; i++) {
193     CeedEvalMode emode = qf->inputfields[i].emode;
194     if (emode == CEED_EVAL_WEIGHT) { // Skip
195     } else {
196       // Active
197       if (op->inputfields[i].vec == CEED_VECTOR_ACTIVE) {
198         // Restrict
199         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
200                                         lmode, invec, impl->evecs[i],
201                                         request); CeedChk(ierr);
202         // Get evec
203         ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
204                                       (const CeedScalar **) &impl->edata[i]); CeedChk(ierr);
205       } else {
206         // Passive
207         // Restrict
208         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
209                                         lmode, op->inputfields[i].vec, impl->evecs[i],
210                                         request); CeedChk(ierr);
211         // Get evec
212         ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
213                                       (const CeedScalar **) &impl->edata[i]); CeedChk(ierr);
214       }
215     }
216   }
217 
218   // Output Evecs
219   for (CeedInt i=0; i<numoutputfields; i++) {
220     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
221                               &impl->edata[i + numinputfields]); CeedChk(ierr);
222   }
223 
224   // Loop through elements
225   for (CeedInt e=0; e<numelements; e++) {
226     // Input basis apply if needed
227     for (CeedInt i=0; i<numinputfields; i++) {
228       // Get elemsize, emode, ncomp
229       ierr = CeedElemRestrictionGetElementSize(
230                op->inputfields[i].Erestrict, &elemsize); CeedChk(ierr);
231       CeedEvalMode emode = qf->inputfields[i].emode;
232       CeedInt ncomp = qf->inputfields[i].ncomp;
233       // Basis action
234       switch(emode) {
235       case CEED_EVAL_NONE:
236         impl->indata[i] = &impl->edata[i][e*Q*ncomp];
237         break;
238       case CEED_EVAL_INTERP:
239         ierr = CeedBasisApply(op->inputfields[i].basis, 1, CEED_NOTRANSPOSE,
240                               CEED_EVAL_INTERP, &impl->edata[i][e*elemsize*ncomp], impl->qdata[i]);
241         CeedChk(ierr);
242         break;
243       case CEED_EVAL_GRAD:
244         ierr = CeedBasisApply(op->inputfields[i].basis, 1, CEED_NOTRANSPOSE,
245                               CEED_EVAL_GRAD, &impl->edata[i][e*elemsize*ncomp], impl->qdata[i]);
246         CeedChk(ierr);
247         break;
248       case CEED_EVAL_WEIGHT:
249         break;  // No action
250       case CEED_EVAL_DIV:
251         break; // Not implimented
252       case CEED_EVAL_CURL:
253         break; // Not implimented
254       }
255     }
256     // Output pointers
257     for (CeedInt i=0; i<numoutputfields; i++) {
258       CeedEvalMode emode = qf->outputfields[i].emode;
259       if (emode == CEED_EVAL_NONE) {
260         CeedInt ncomp = qf->outputfields[i].ncomp;
261         impl->outdata[i] = &impl->edata[i + numinputfields][e*Q*ncomp];
262       }
263     }
264     // Q function
265     ierr = CeedQFunctionApply(qf, Q, (const CeedScalar * const*) impl->indata,
266                               impl->outdata); CeedChk(ierr);
267 
268     // Output basis apply if needed
269     for (CeedInt i=0; i<numoutputfields; i++) {
270       // Get elemsize, emode, ncomp
271       ierr = CeedElemRestrictionGetElementSize(
272                op->outputfields[i].Erestrict, &elemsize); CeedChk(ierr);
273       CeedInt ncomp = qf->outputfields[i].ncomp;
274       CeedEvalMode emode = qf->outputfields[i].emode;
275       // Basis action
276       switch(emode) {
277       case CEED_EVAL_NONE:
278         break; // No action
279       case CEED_EVAL_INTERP:
280         ierr = CeedBasisApply(op->outputfields[i].basis, 1, CEED_TRANSPOSE,
281                               CEED_EVAL_INTERP, impl->outdata[i],
282                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
283         CeedChk(ierr);
284         break;
285       case CEED_EVAL_GRAD:
286         ierr = CeedBasisApply(op->outputfields[i].basis, 1, CEED_TRANSPOSE,
287                               CEED_EVAL_GRAD,
288                               impl->outdata[i], &impl->edata[i + numinputfields][e*elemsize*ncomp]);
289         CeedChk(ierr);
290         break;
291       case CEED_EVAL_WEIGHT: {
292         Ceed ceed;
293         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
294         return CeedError(ceed, 1,
295                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
296         break; // Should not occur
297       }
298       case CEED_EVAL_DIV:
299         break; // Not implimented
300       case CEED_EVAL_CURL:
301         break; // Not implimented
302       }
303     }
304   }
305 
306   // Zero lvecs
307   ierr = CeedVectorSetValue(outvec, 0.0); CeedChk(ierr);
308   for (CeedInt i=0; i<numoutputfields; i++)
309     if (op->outputfields[i].vec != CEED_VECTOR_ACTIVE) {
310       ierr = CeedVectorSetValue(op->outputfields[i].vec, 0.0); CeedChk(ierr);
311     }
312 
313   // Output restriction
314   for (CeedInt i=0; i<numoutputfields; i++) {
315     // Restore evec
316     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
317                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
318     // Active
319     if (op->outputfields[i].vec == CEED_VECTOR_ACTIVE) {
320       // Restrict
321       ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
322                                       lmode, impl->evecs[i+impl->numein], outvec, request); CeedChk(ierr);
323     } else {
324       // Passive
325       // Restrict
326       ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
327                                       lmode, impl->evecs[i+impl->numein], op->outputfields[i].vec,
328                                       request); CeedChk(ierr);
329     }
330   }
331 
332   // Restore input arrays
333   for (CeedInt i=0; i<numinputfields; i++) {
334     CeedEvalMode emode = qf->inputfields[i].emode;
335     if (emode == CEED_EVAL_WEIGHT) { // Skip
336     } else {
337       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
338                                         (const CeedScalar **) &impl->edata[i]); CeedChk(ierr);
339     }
340   }
341 
342   return 0;
343 }
344 
345 int CeedOperatorCreate_Ref(CeedOperator op) {
346   int ierr;
347   CeedOperator_Ref *impl;
348 
349   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
350   op->data = impl;
351   op->Destroy = CeedOperatorDestroy_Ref;
352   op->Apply = CeedOperatorApply_Ref;
353   return 0;
354 }
355