xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 583b1d4c382ee3afcd0623eaf06c15e7df28a1d2)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-impl.h>
18 #include <string.h>
19 #include "ceed-ref.h"
20 
// Sentinel handles; every one of them is defined as NULL in this file.
// The operator code in this file tests field members against NULL: a NULL
// Erestrict is handled as "no restriction" and a NULL vec selects the vector
// passed to Apply. Presumably callers pass these sentinels when setting
// operator fields -- TODO confirm against the field-setting API.
21 CeedElemRestriction CEED_RESTRICTION_IDENTITY = NULL;
22 CeedBasis CEED_BASIS_COLOCATED = NULL;
// NOTE(review): ACTIVE and NONE are both NULL, so the two cannot be told
// apart by comparing handles -- confirm this is intended.
23 CeedVector CEED_VECTOR_ACTIVE = NULL;
24 CeedVector CEED_VECTOR_NONE = NULL;
25 
26 static int CeedOperatorDestroy_Ref(CeedOperator op) {
27   CeedOperator_Ref *impl = op->data;
28   int ierr;
29 
30   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
31     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
32   }
33   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
34   ierr = CeedFree(&impl->edata); CeedChk(ierr);
35 
36   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
37     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
38   }
39   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
40   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
41 
42   ierr = CeedFree(&impl->indata); CeedChk(ierr);
43   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
44 
45   ierr = CeedFree(&op->data); CeedChk(ierr);
46   return 0;
47 }
48 
49 /*
50   Setup infields or outfields
51  */
52 static int CeedOperatorSetupFields_Ref(struct CeedQFunctionField qfields[16],
53                                        struct CeedOperatorField ofields[16],
54                                        CeedVector *evecs, CeedScalar **qdata,
55                                        CeedScalar **qdata_alloc, CeedScalar **indata,
56                                        CeedInt starti, CeedInt starte,
57                                        CeedInt startq, CeedInt numfields, CeedInt Q) {
58   CeedInt dim, ierr, ie=starte, iq=startq, ncomp;
59 
60   // Loop over fields
61   for (CeedInt i=0; i<numfields; i++) {
62     if (ofields[i].Erestrict) {
63       ierr = CeedElemRestrictionCreateVector(ofields[i].Erestrict, NULL, &evecs[ie]);
64       CeedChk(ierr);
65       ie++;
66     }
67     CeedEvalMode emode = qfields[i].emode;
68     switch(emode) {
69     case CEED_EVAL_NONE:
70       break; // No action
71     case CEED_EVAL_INTERP:
72       ncomp = qfields[i].ncomp;
73       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
74       qdata[i + starti] = qdata_alloc[iq];
75       iq++;
76       break;
77     case CEED_EVAL_GRAD:
78       ncomp = qfields[i].ncomp;
79       dim = ofields[i].basis->dim;
80       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
81       qdata[i + starti] = qdata_alloc[iq];
82       iq++;
83       break;
84     case CEED_EVAL_WEIGHT: // Only on input fields
85       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
86       ierr = CeedBasisApply(ofields[iq].basis, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
87                             NULL, qdata_alloc[iq]); CeedChk(ierr);
88       qdata[i] = qdata_alloc[iq];
89       indata[i] = qdata[i];
90       iq++;
91       break;
92     case CEED_EVAL_DIV:
93       break; // Not implimented
94     case CEED_EVAL_CURL:
95       break; // Not implimented
96     }
97   }
98   return 0;
99 }
100 
101 /*
102   CeedOperator needs to connect all the named fields (be they active or passive)
103   to the named inputs and outputs of its CeedQFunction.
104  */
105 static int CeedOperatorSetup_Ref(CeedOperator op) {
106   if (op->setupdone) return 0;
107   CeedOperator_Ref *opref = op->data;
108   CeedQFunction qf = op->qf;
109   CeedInt Q = op->numqpoints;
110   int ierr;
111 
112   // Count infield and outfield array sizes and evectors
113   for (CeedInt i=0; i<qf->numinputfields; i++) {
114     CeedEvalMode emode = qf->inputfields[i].emode;
115     opref->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) + !!
116                      (emode & CEED_EVAL_WEIGHT);
117     opref->numein +=
118       !!op->inputfields[i].Erestrict; // Need E-vector when restriction exists
119   }
120   for (CeedInt i=0; i<qf->numoutputfields; i++) {
121     CeedEvalMode emode = qf->outputfields[i].emode;
122     opref->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
123     opref->numeout += !!op->outputfields[i].Erestrict;
124   }
125 
126   // Allocate
127   ierr = CeedCalloc(opref->numein + opref->numeout, &opref->evecs); CeedChk(ierr);
128   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->edata);
129   CeedChk(ierr);
130 
131   ierr = CeedCalloc(opref->numqin + opref->numqout, &opref->qdata_alloc);
132   CeedChk(ierr);
133   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->qdata);
134   CeedChk(ierr);
135 
136   ierr = CeedCalloc(16, &opref->indata); CeedChk(ierr);
137   ierr = CeedCalloc(16, &opref->outdata); CeedChk(ierr);
138 
139   // Set up infield and outfield pointer arrays
140   // Infields
141   ierr = CeedOperatorSetupFields_Ref(qf->inputfields, op->inputfields,
142                                      opref->evecs, opref->qdata, opref->qdata_alloc,
143                                      opref->indata, 0, 0, 0,
144                                      qf->numinputfields, Q); CeedChk(ierr);
145 
146   // Outfields
147   ierr = CeedOperatorSetupFields_Ref(qf->outputfields, op->outputfields,
148                                      opref->evecs, opref->qdata, opref->qdata_alloc,
149                                      opref->indata, qf->numinputfields, opref->numein,
150                                      opref->numqin, qf->numoutputfields, Q); CeedChk(ierr);
151 
152   op->setupdone = 1;
153 
154   return 0;
155 }
156 
157 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
158                                  CeedVector outvec, CeedRequest *request) {
159   CeedOperator_Ref *opref = op->data;
160   CeedInt Q = op->numqpoints, elemsize;
161   int ierr;
162   CeedQFunction qf = op->qf;
163   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
164 
165   // Setup
166   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
167 
168   // Input Evecs and Restriction
169   for (CeedInt i=0,iein=0; i<qf->numinputfields; i++) {
170     // Restriction
171     if (op->inputfields[i].Erestrict) {
172       // Passive
173       if (op->inputfields[i].vec) {
174         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
175                                         lmode, op->inputfields[i].vec, opref->evecs[iein],
176                                         request); CeedChk(ierr);
177         ierr = CeedVectorGetArrayRead(opref->evecs[iein], CEED_MEM_HOST,
178                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
179         iein++;
180       } else {
181         // Active
182         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
183                                         lmode, invec, opref->evecs[iein], request); CeedChk(ierr);
184         ierr = CeedVectorGetArrayRead(opref->evecs[iein], CEED_MEM_HOST,
185                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
186         iein++;
187       }
188     } else {
189       // No restriction
190       CeedEvalMode emode = qf->inputfields[i].emode;
191       if (emode & CEED_EVAL_WEIGHT) {
192       } else {
193         ierr = CeedVectorGetArrayRead(op->inputfields[i].vec, CEED_MEM_HOST,
194                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
195       }
196     }
197   }
198 
199   // Output Evecs
200   for (CeedInt i=0,ieout=opref->numein; i<qf->numoutputfields; i++) {
201     // Restriction
202     if (op->outputfields[i].Erestrict) {
203       ierr = CeedVectorGetArray(opref->evecs[ieout], CEED_MEM_HOST,
204                                 &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
205       ieout++;
206     } else {
207       // No restriction
208       // Passive
209       if (op->inputfields[i].vec) {
210         ierr = CeedVectorGetArray(op->inputfields[i].vec, CEED_MEM_HOST,
211                                   &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
212       } else {
213         // Active
214         ierr = CeedVectorGetArray(outvec, CEED_MEM_HOST,
215                                   &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
216       }
217     }
218   }
219 
220   // Output Qvecs
221   for (CeedInt i=0; i<qf->numoutputfields; i++) {
222     CeedEvalMode emode = qf->outputfields[i].emode;
223     if (emode != CEED_EVAL_NONE) {
224       opref->outdata[i] =  opref->qdata[i + qf->numinputfields];
225     }
226   }
227 
228   // Loop through elements
229   for (CeedInt e=0; e<op->numelements; e++) {
230     // Input basis apply if needed
231     for (CeedInt i=0; i<qf->numinputfields; i++) {
232       // Get elemsize
233       if (op->inputfields[i].Erestrict) {
234         elemsize = op->inputfields[i].Erestrict->elemsize;
235       } else {
236         elemsize = Q;
237       }
238       // Get emode, ncomp
239       CeedEvalMode emode = qf->inputfields[i].emode;
240       CeedInt ncomp = qf->inputfields[i].ncomp;
241       // Basis action
242       switch(emode) {
243       case CEED_EVAL_NONE:
244         opref->indata[i] = &opref->edata[i][e*Q*ncomp];
245         break;
246       case CEED_EVAL_INTERP:
247         ierr = CeedBasisApply(op->inputfields[i].basis, CEED_NOTRANSPOSE,
248                               CEED_EVAL_INTERP, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
249         CeedChk(ierr);
250         opref->indata[i] = opref->qdata[i];
251         break;
252       case CEED_EVAL_GRAD:
253         ierr = CeedBasisApply(op->inputfields[i].basis, CEED_NOTRANSPOSE,
254                               CEED_EVAL_GRAD, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
255         CeedChk(ierr);
256         opref->indata[i] = opref->qdata[i];
257         break;
258       case CEED_EVAL_WEIGHT:
259         break;  // No action
260       case CEED_EVAL_DIV:
261         break; // Not implimented
262       case CEED_EVAL_CURL:
263         break; // Not implimented
264       }
265     }
266     // Output pointers
267     for (CeedInt i=0; i<qf->numoutputfields; i++) {
268       CeedEvalMode emode = qf->outputfields[i].emode;
269       if (emode == CEED_EVAL_NONE) {
270         CeedInt ncomp = qf->outputfields[i].ncomp;
271         opref->outdata[i] = &opref->edata[i + qf->numinputfields][e*Q*ncomp];
272       }
273     }
274     // Q function
275     ierr = CeedQFunctionApply(op->qf, Q, (const CeedScalar * const*) opref->indata,
276                               opref->outdata); CeedChk(ierr);
277 
278     // Output basis apply if needed
279     for (CeedInt i=0; i<qf->numoutputfields; i++) {
280       // Get elemsize
281       if (op->outputfields[i].Erestrict) {
282         elemsize = op->outputfields[i].Erestrict->elemsize;
283       } else {
284         elemsize = Q;
285       }
286       // Get emode, ncomp
287       CeedInt ncomp = qf->outputfields[i].ncomp;
288       CeedEvalMode emode = qf->outputfields[i].emode;
289       // Basis action
290       switch(emode) {
291       case CEED_EVAL_NONE:
292         break; // No action
293       case CEED_EVAL_INTERP:
294         ierr = CeedBasisApply(op->outputfields[i].basis, CEED_TRANSPOSE,
295                               CEED_EVAL_INTERP, opref->outdata[i],
296                               &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]); CeedChk(ierr);
297         break;
298       case CEED_EVAL_GRAD:
299         ierr = CeedBasisApply(op->outputfields[i].basis, CEED_TRANSPOSE, CEED_EVAL_GRAD,
300                               opref->outdata[i], &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]);
301         CeedChk(ierr);
302         break;
303       case CEED_EVAL_WEIGHT:
304         break; // Should not occur
305       case CEED_EVAL_DIV:
306         break; // Not implimented
307       case CEED_EVAL_CURL:
308         break; // Not implimented
309       }
310     }
311   }
312 
313   // Output restriction
314   for (CeedInt i=0,ieout=opref->numein; i<qf->numoutputfields; i++) {
315     // Restriction
316     if (op->outputfields[i].Erestrict) {
317       // Passive
318       if (op->outputfields[i].vec) {
319         ierr = CeedVectorRestoreArray(opref->evecs[ieout],
320                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
321         ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
322                                         lmode, opref->evecs[ieout], op->outputfields[i].vec, request); CeedChk(ierr);
323         ieout++;
324       } else {
325         // Active
326         ierr = CeedVectorRestoreArray(opref->evecs[ieout],
327                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
328         ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
329                                         lmode, opref->evecs[ieout], outvec, request); CeedChk(ierr);
330         ieout++;
331       }
332     } else {
333       // No Restriction
334       // Passive
335       if (op->outputfields[i].vec) {
336         ierr = CeedVectorRestoreArray(op->outputfields[i].vec,
337                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
338       } else {
339         // Active
340         ierr = CeedVectorRestoreArray(outvec, &opref->edata[i + qf->numinputfields]);
341         CeedChk(ierr);
342       }
343     }
344   }
345 
346   return 0;
347 }
348 
349 int CeedOperatorCreate_Ref(CeedOperator op) {
350   CeedOperator_Ref *impl;
351   int ierr;
352 
353   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
354   op->data = impl;
355   op->Destroy = CeedOperatorDestroy_Ref;
356   op->Apply = CeedOperatorApply_Ref;
357   return 0;
358 }
359