xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 885ac19c71404eabce0bdaffc144734a44d512ec)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-impl.h>
18 #include <string.h>
19 #include "ceed-ref.h"
20 
// Sentinel handles; all three are defined as NULL in this file.
// The operator code below treats a NULL Erestrict as "no restriction".
// NOTE(review): presumably callers pass these to request "no restriction",
// "no basis" and "no vector" respectively -- confirm against the public header.
21 CeedElemRestriction CEED_RESTRICTION_IDENTITY = NULL;
22 CeedBasis CEED_BASIS_COLOCATED = NULL;
23 CeedVector CEED_QDATA_NONE = NULL;
24 
25 static int CeedOperatorDestroy_Ref(CeedOperator op) {
26   CeedOperator_Ref *impl = op->data;
27   int ierr;
28 
29   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
30     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
31   }
32   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
33   ierr = CeedFree(&impl->edata); CeedChk(ierr);
34 
35   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
36     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
37   }
38   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
39   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
40 
41   ierr = CeedFree(&impl->indata); CeedChk(ierr);
42   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
43 
44   ierr = CeedFree(&op->data); CeedChk(ierr);
45   return 0;
46 }
47 
48 /*
49   Setup infields or outfields
50  */
51 static int CeedOperatorSetupFields_Ref(struct CeedQFunctionField qfields[16],
52                                        struct CeedOperatorField ofields[16],
53                                        CeedVector *evecs, CeedScalar **qdata,
54                                        CeedScalar **qdata_alloc, CeedScalar **indata,
55                                        CeedInt starti, CeedInt starte,
56                                        CeedInt startq, CeedInt numfields, CeedInt Q) {
57   CeedInt dim, ierr, ie=starte, iq=startq, ncomp;
58 
59   // Loop over fields
60   for (CeedInt i=0; i<numfields; i++) {
61     if (ofields[i].Erestrict) {
62       ierr = CeedElemRestrictionCreateVector(ofields[i].Erestrict, NULL, &evecs[ie]);
63       CeedChk(ierr);
64       ie++;
65     }
66     CeedEvalMode emode = qfields[i].emode;
67     switch(emode) {
68     case CEED_EVAL_NONE:
69       break; // No action
70     case CEED_EVAL_INTERP:
71       ncomp = qfields[i].ncomp;
72       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
73       qdata[i + starti] = qdata_alloc[iq];
74       iq++;
75       break;
76     case CEED_EVAL_GRAD:
77       ncomp = qfields[i].ncomp;
78       dim = ofields[i].basis->dim;
79       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
80       qdata[i + starti] = qdata_alloc[iq];
81       iq++;
82       break;
83     case CEED_EVAL_WEIGHT: // Only on input fields
84       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
85       ierr = CeedBasisApply(ofields[iq].basis, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
86                             NULL, qdata_alloc[iq]); CeedChk(ierr);
87       qdata[i] = qdata_alloc[iq];
88       indata[i] = qdata[i];
89       iq++;
90       break;
91     case CEED_EVAL_DIV:
92       break; // Not implimented
93     case CEED_EVAL_CURL:
94       break; // Not implimented
95     }
96   }
97   return 0;
98 }
99 
100 /*
101   CeedOperator needs to connect all the named fields (be they active or passive)
102   to the named inputs and outputs of its CeedQFunction.
103  */
104 static int CeedOperatorSetup_Ref(CeedOperator op) {
105   if (op->setupdone) return 0;
106   CeedOperator_Ref *opref = op->data;
107   CeedQFunction qf = op->qf;
108   CeedInt Q = op->numqpoints;
109   int ierr;
110 
111   // Count infield and outfield array sizes and evectors
112   for (CeedInt i=0; i<qf->numinputfields; i++) {
113     CeedEvalMode emode = qf->inputfields[i].emode;
114     opref->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) + !!
115                      (emode & CEED_EVAL_WEIGHT);
116     opref->numein +=
117       !!op->inputfields[i].Erestrict; // Need E-vector when restriction exists
118   }
119   for (CeedInt i=0; i<qf->numoutputfields; i++) {
120     CeedEvalMode emode = qf->outputfields[i].emode;
121     opref->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
122     opref->numeout += !!op->outputfields[i].Erestrict;
123   }
124 
125   // Allocate
126   ierr = CeedCalloc(opref->numein + opref->numeout, &opref->evecs); CeedChk(ierr);
127   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->edata);
128   CeedChk(ierr);
129 
130   ierr = CeedCalloc(opref->numqin + opref->numqout, &opref->qdata_alloc);
131   CeedChk(ierr);
132   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->qdata);
133   CeedChk(ierr);
134 
135   ierr = CeedCalloc(16, &opref->indata); CeedChk(ierr);
136   ierr = CeedCalloc(16, &opref->outdata); CeedChk(ierr);
137 
138   // Set up infield and outfield pointer arrays
139   // Infields
140   ierr = CeedOperatorSetupFields_Ref(qf->inputfields, op->inputfields,
141                                      opref->evecs, opref->qdata, opref->qdata_alloc,
142                                      opref->indata, 0, 0, 0,
143                                      qf->numinputfields, Q); CeedChk(ierr);
144 
145   // Outfields
146   ierr = CeedOperatorSetupFields_Ref(qf->outputfields, op->outputfields,
147                                      opref->evecs, opref->qdata, opref->qdata_alloc,
148                                      opref->indata, qf->numinputfields, opref->numein,
149                                      opref->numqin, qf->numoutputfields, Q); CeedChk(ierr);
150 
151   op->setupdone = 1;
152 
153   return 0;
154 }
155 
156 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
157                                  CeedVector outvec, CeedRequest *request) {
158   CeedOperator_Ref *opref = op->data;
159   CeedInt Q = op->numqpoints, elemsize;
160   int ierr;
161   CeedQFunction qf = op->qf;
162   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
163 
164   // Setup
165   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
166 
167   // Input Evecs and Restriction
168   for (CeedInt i=0,iein=0; i<qf->numinputfields; i++) {
169     // Restriction
170     if (op->inputfields[i].Erestrict) {
171       // Passive
172       if (op->inputfields[i].vec) {
173         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
174                                         lmode, op->inputfields[i].vec, opref->evecs[iein],
175                                         request); CeedChk(ierr);
176         ierr = CeedVectorGetArrayRead(opref->evecs[iein], CEED_MEM_HOST,
177                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
178         iein++;
179       } else {
180         // Active
181         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
182                                         lmode, invec, opref->evecs[iein], request); CeedChk(ierr);
183         ierr = CeedVectorGetArrayRead(opref->evecs[iein], CEED_MEM_HOST,
184                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
185         iein++;
186       }
187     } else {
188       // No restriction
189       CeedEvalMode emode = qf->inputfields[i].emode;
190       if (emode & CEED_EVAL_WEIGHT) {
191       } else {
192         ierr = CeedVectorGetArrayRead(op->inputfields[i].vec, CEED_MEM_HOST,
193                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
194       }
195     }
196   }
197 
198   // Output Evecs
199   for (CeedInt i=0,ieout=opref->numein; i<qf->numoutputfields; i++) {
200     // Restriction
201     if (op->outputfields[i].Erestrict) {
202       ierr = CeedVectorGetArray(opref->evecs[ieout], CEED_MEM_HOST,
203                                 &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
204       ieout++;
205     } else {
206       // No restriction
207       // Passive
208       if (op->inputfields[i].vec) {
209         ierr = CeedVectorGetArray(op->inputfields[i].vec, CEED_MEM_HOST,
210                                   &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
211       } else {
212         // Active
213         ierr = CeedVectorGetArray(outvec, CEED_MEM_HOST,
214                                   &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
215       }
216     }
217   }
218 
219   // Output Qvecs
220   for (CeedInt i=0; i<qf->numoutputfields; i++) {
221     CeedEvalMode emode = qf->outputfields[i].emode;
222     if (emode != CEED_EVAL_NONE) {
223       opref->outdata[i] =  opref->qdata[i + qf->numinputfields];
224     }
225   }
226 
227   // Loop through elements
228   for (CeedInt e=0; e<op->numelements; e++) {
229     // Input basis apply if needed
230     for (CeedInt i=0; i<qf->numinputfields; i++) {
231       // Get elemsize
232       if (op->inputfields[i].Erestrict) {
233         elemsize = op->inputfields[i].Erestrict->elemsize;
234       } else {
235         elemsize = Q;
236       }
237       // Get emode, ncomp
238       CeedEvalMode emode = qf->inputfields[i].emode;
239       CeedInt ncomp = qf->inputfields[i].ncomp;
240       // Basis action
241       switch(emode) {
242       case CEED_EVAL_NONE:
243         opref->indata[i] = &opref->edata[i][e*Q*ncomp];
244         break;
245       case CEED_EVAL_INTERP:
246         ierr = CeedBasisApply(op->inputfields[i].basis, CEED_NOTRANSPOSE,
247                               CEED_EVAL_INTERP, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
248         CeedChk(ierr);
249         opref->indata[i] = opref->qdata[i];
250         break;
251       case CEED_EVAL_GRAD:
252         ierr = CeedBasisApply(op->inputfields[i].basis, CEED_NOTRANSPOSE,
253                               CEED_EVAL_GRAD, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
254         CeedChk(ierr);
255         opref->indata[i] = opref->qdata[i];
256         break;
257       case CEED_EVAL_WEIGHT:
258         break;  // No action
259       case CEED_EVAL_DIV:
260         break; // Not implimented
261       case CEED_EVAL_CURL:
262         break; // Not implimented
263       }
264     }
265     // Output pointers
266     for (CeedInt i=0; i<qf->numoutputfields; i++) {
267       CeedEvalMode emode = qf->outputfields[i].emode;
268       if (emode == CEED_EVAL_NONE) {
269         CeedInt ncomp = qf->outputfields[i].ncomp;
270         opref->outdata[i] = &opref->edata[i + qf->numinputfields][e*Q*ncomp];
271       }
272     }
273     // Q function
274     ierr = CeedQFunctionApply(op->qf, Q, (const CeedScalar * const*) opref->indata,
275                               opref->outdata); CeedChk(ierr);
276 
277     // Output basis apply if needed
278     for (CeedInt i=0; i<qf->numoutputfields; i++) {
279       // Get elemsize
280       if (op->outputfields[i].Erestrict) {
281         elemsize = op->outputfields[i].Erestrict->elemsize;
282       } else {
283         elemsize = Q;
284       }
285       // Get emode, ncomp
286       CeedInt ncomp = qf->outputfields[i].ncomp;
287       CeedEvalMode emode = qf->outputfields[i].emode;
288       // Basis action
289       switch(emode) {
290       case CEED_EVAL_NONE:
291         break; // No action
292       case CEED_EVAL_INTERP:
293         ierr = CeedBasisApply(op->outputfields[i].basis, CEED_TRANSPOSE,
294                               CEED_EVAL_INTERP, opref->outdata[i],
295                               &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]); CeedChk(ierr);
296         break;
297       case CEED_EVAL_GRAD:
298         ierr = CeedBasisApply(op->outputfields[i].basis, CEED_TRANSPOSE, CEED_EVAL_GRAD,
299                               opref->outdata[i], &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]);
300         CeedChk(ierr);
301         break;
302       case CEED_EVAL_WEIGHT:
303         break; // Should not occur
304       case CEED_EVAL_DIV:
305         break; // Not implimented
306       case CEED_EVAL_CURL:
307         break; // Not implimented
308       }
309     }
310   }
311 
312   // Output restriction
313   for (CeedInt i=0,ieout=opref->numein; i<qf->numoutputfields; i++) {
314     // Restriction
315     if (op->outputfields[i].Erestrict) {
316       // Passive
317       if (op->outputfields[i].vec) {
318         ierr = CeedVectorRestoreArray(opref->evecs[ieout],
319                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
320         ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
321                                         lmode, opref->evecs[ieout], op->outputfields[i].vec, request); CeedChk(ierr);
322         ieout++;
323       } else {
324         // Active
325         ierr = CeedVectorRestoreArray(opref->evecs[ieout],
326                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
327         ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
328                                         lmode, opref->evecs[ieout], outvec, request); CeedChk(ierr);
329         ieout++;
330       }
331     } else {
332       // No Restriction
333       // Passive
334       if (op->outputfields[i].vec) {
335         ierr = CeedVectorRestoreArray(op->outputfields[i].vec,
336                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
337       } else {
338         // Active
339         ierr = CeedVectorRestoreArray(outvec, &opref->edata[i + qf->numinputfields]);
340         CeedChk(ierr);
341       }
342     }
343   }
344 
345   return 0;
346 }
347 
348 int CeedOperatorCreate_Ref(CeedOperator op) {
349   CeedOperator_Ref *impl;
350   int ierr;
351 
352   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
353   op->data = impl;
354   op->Destroy = CeedOperatorDestroy_Ref;
355   op->Apply = CeedOperatorApply_Ref;
356   return 0;
357 }
358