xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 389b3d932a124a0aac28799927fddea4f2230f2a)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-impl.h>
18 #include <string.h>
19 #include "ceed-ref.h"
20 
// Sentinel handles for describing operator fields. All four are defined as
// NULL in this file, so code here recognizes them through NULL tests on the
// corresponding field member rather than by comparing against a unique
// address.

// A NULL Erestrict on a field takes the "No restriction" path in
// CeedOperatorApply_Ref (the field's vector array is used directly).
CeedElemRestriction CEED_RESTRICTION_IDENTITY = NULL;
// NOTE(review): presumably marks a field whose values are already at the
// quadrature points; nothing in this file tests a basis against it - confirm.
CeedBasis CEED_BASIS_COLOCATED = NULL;
// A NULL field vector is treated as "active" by CeedOperatorApply_Ref: the
// invec/outvec arguments of the apply call are used for that field.
CeedVector CEED_VECTOR_ACTIVE = NULL;
// NOTE(review): has the same value as CEED_VECTOR_ACTIVE, so the two cannot
// be told apart by pointer comparison in this backend.
CeedVector CEED_VECTOR_NONE = NULL;
25 
26 static int CeedOperatorDestroy_Ref(CeedOperator op) {
27   CeedOperator_Ref *impl = op->data;
28   int ierr;
29 
30   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
31     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
32   }
33   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
34   ierr = CeedFree(&impl->edata); CeedChk(ierr);
35 
36   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
37     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
38   }
39   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
40   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
41 
42   ierr = CeedFree(&impl->indata); CeedChk(ierr);
43   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
44 
45   ierr = CeedFree(&op->data); CeedChk(ierr);
46   return 0;
47 }
48 
49 /*
50   Setup infields or outfields
51  */
52 static int CeedOperatorSetupFields_Ref(struct CeedQFunctionField qfields[16],
53                                        struct CeedOperatorField ofields[16],
54                                        CeedVector *evecs, CeedScalar **qdata,
55                                        CeedScalar **qdata_alloc, CeedScalar **indata,
56                                        CeedInt starti, CeedInt starte,
57                                        CeedInt startq, CeedInt numfields, CeedInt Q) {
58   CeedInt dim, ierr, ie=starte, iq=startq, ncomp;
59 
60   // Loop over fields
61   for (CeedInt i=0; i<numfields; i++) {
62     if (ofields[i].Erestrict) {
63       ierr = CeedElemRestrictionCreateVector(ofields[i].Erestrict, NULL, &evecs[ie]);
64       CeedChk(ierr);
65       ie++;
66     }
67     CeedEvalMode emode = qfields[i].emode;
68     switch(emode) {
69     case CEED_EVAL_NONE:
70       break; // No action
71     case CEED_EVAL_INTERP:
72       ncomp = qfields[i].ncomp;
73       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
74       qdata[i + starti] = qdata_alloc[iq];
75       iq++;
76       break;
77     case CEED_EVAL_GRAD:
78       ncomp = qfields[i].ncomp;
79       dim = ofields[i].basis->dim;
80       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
81       qdata[i + starti] = qdata_alloc[iq];
82       iq++;
83       break;
84     case CEED_EVAL_WEIGHT: // Only on input fields
85       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
86       ierr = CeedBasisApply(ofields[iq].basis, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
87                             NULL, qdata_alloc[iq]); CeedChk(ierr);
88       qdata[i] = qdata_alloc[iq];
89       indata[i] = qdata[i];
90       iq++;
91       break;
92     case CEED_EVAL_DIV:
93       break; // Not implimented
94     case CEED_EVAL_CURL:
95       break; // Not implimented
96     }
97   }
98   return 0;
99 }
100 
101 /*
102   CeedOperator needs to connect all the named fields (be they active or passive)
103   to the named inputs and outputs of its CeedQFunction.
104  */
105 static int CeedOperatorSetup_Ref(CeedOperator op) {
106   if (op->setupdone) return 0;
107   CeedOperator_Ref *opref = op->data;
108   CeedQFunction qf = op->qf;
109   CeedInt Q = op->numqpoints;
110   int ierr;
111 
112   // Count infield and outfield array sizes and evectors
113   for (CeedInt i=0; i<qf->numinputfields; i++) {
114     CeedEvalMode emode = qf->inputfields[i].emode;
115     opref->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) + !!
116                      (emode & CEED_EVAL_WEIGHT);
117     opref->numein +=
118       !!op->inputfields[i].Erestrict; // Need E-vector when restriction exists
119   }
120   for (CeedInt i=0; i<qf->numoutputfields; i++) {
121     CeedEvalMode emode = qf->outputfields[i].emode;
122     opref->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
123     opref->numeout += !!op->outputfields[i].Erestrict;
124   }
125 
126   // Allocate
127   ierr = CeedCalloc(opref->numein + opref->numeout, &opref->evecs); CeedChk(ierr);
128   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->edata);
129   CeedChk(ierr);
130 
131   ierr = CeedCalloc(opref->numqin + opref->numqout, &opref->qdata_alloc);
132   CeedChk(ierr);
133   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->qdata);
134   CeedChk(ierr);
135 
136   ierr = CeedCalloc(16, &opref->indata); CeedChk(ierr);
137   ierr = CeedCalloc(16, &opref->outdata); CeedChk(ierr);
138 
139   // Set up infield and outfield pointer arrays
140   // Infields
141   ierr = CeedOperatorSetupFields_Ref(qf->inputfields, op->inputfields,
142                                      opref->evecs, opref->qdata, opref->qdata_alloc,
143                                      opref->indata, 0, 0, 0,
144                                      qf->numinputfields, Q); CeedChk(ierr);
145 
146   // Outfields
147   ierr = CeedOperatorSetupFields_Ref(qf->outputfields, op->outputfields,
148                                      opref->evecs, opref->qdata, opref->qdata_alloc,
149                                      opref->indata, qf->numinputfields, opref->numein,
150                                      opref->numqin, qf->numoutputfields, Q); CeedChk(ierr);
151 
152   // Output Qvecs
153   for (CeedInt i=0; i<qf->numoutputfields; i++) {
154     CeedEvalMode emode = qf->outputfields[i].emode;
155     if (emode != CEED_EVAL_NONE) {
156       opref->outdata[i] =  opref->qdata[i + qf->numinputfields];
157     }
158   }
159 
160   op->setupdone = 1;
161 
162   return 0;
163 }
164 
165 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
166                                  CeedVector outvec, CeedRequest *request) {
167   CeedOperator_Ref *opref = op->data;
168   CeedInt Q = op->numqpoints, elemsize;
169   int ierr;
170   CeedQFunction qf = op->qf;
171   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
172   CeedScalar *vec_temp;
173 
174   // Setup
175   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
176 
177   // Input Evecs and Restriction
178   for (CeedInt i=0,iein=0; i<qf->numinputfields; i++) {
179     // Restriction
180     if (op->inputfields[i].Erestrict) {
181       // Zero evec
182       ierr = CeedVectorGetArray(opref->evecs[iein], CEED_MEM_HOST, &vec_temp);
183       CeedChk(ierr);
184       for (CeedInt j=0; j<opref->evecs[iein]->length; j++)
185         vec_temp[j] = 0.;
186       ierr = CeedVectorRestoreArray(opref->evecs[iein], &vec_temp); CeedChk(ierr);
187       // Passive
188       if (op->inputfields[i].vec) {
189         // Restrict
190         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
191                                         lmode, op->inputfields[i].vec, opref->evecs[iein],
192                                         request); CeedChk(ierr);
193         // Get evec
194         ierr = CeedVectorGetArrayRead(opref->evecs[iein], CEED_MEM_HOST,
195                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
196         iein++;
197       } else {
198         // Active
199         // Restrict
200 
201         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
202                                         lmode, invec, opref->evecs[iein], request); CeedChk(ierr);
203         // Get evec
204         ierr = CeedVectorGetArrayRead(opref->evecs[iein], CEED_MEM_HOST,
205                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
206         iein++;
207       }
208     } else {
209       // No restriction
210       CeedEvalMode emode = qf->inputfields[i].emode;
211       if (emode & CEED_EVAL_WEIGHT) {
212       } else {
213         // Passive
214         if (op->inputfields[i].vec) {
215           ierr = CeedVectorGetArrayRead(op->inputfields[i].vec, CEED_MEM_HOST,
216                                         (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
217           // Active
218         } else {
219           ierr = CeedVectorGetArrayRead(invec, CEED_MEM_HOST,
220                                         (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
221         }
222       }
223     }
224   }
225 
226   // Output Evecs
227   for (CeedInt i=0,ieout=opref->numein; i<qf->numoutputfields; i++) {
228     // Restriction
229     if (op->outputfields[i].Erestrict) {
230       ierr = CeedVectorGetArray(opref->evecs[ieout], CEED_MEM_HOST,
231                                 &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
232       ieout++;
233     } else {
234       // No restriction
235       // Passive
236       if (op->outputfields[i].vec) {
237         ierr = CeedVectorGetArray(op->outputfields[i].vec, CEED_MEM_HOST,
238                                   &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
239       } else {
240         // Active
241         ierr = CeedVectorGetArray(outvec, CEED_MEM_HOST,
242                                   &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
243       }
244     }
245   }
246 
247   // Loop through elements
248   for (CeedInt e=0; e<op->numelements; e++) {
249     // Input basis apply if needed
250     for (CeedInt i=0; i<qf->numinputfields; i++) {
251       // Get elemsize
252       if (op->inputfields[i].Erestrict) {
253         elemsize = op->inputfields[i].Erestrict->elemsize;
254       } else {
255         elemsize = Q;
256       }
257       // Get emode, ncomp
258       CeedEvalMode emode = qf->inputfields[i].emode;
259       CeedInt ncomp = qf->inputfields[i].ncomp;
260       // Basis action
261       switch(emode) {
262       case CEED_EVAL_NONE:
263         opref->indata[i] = &opref->edata[i][e*Q*ncomp];
264         break;
265       case CEED_EVAL_INTERP:
266         ierr = CeedBasisApply(op->inputfields[i].basis, CEED_NOTRANSPOSE,
267                               CEED_EVAL_INTERP, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
268         CeedChk(ierr);
269         opref->indata[i] = opref->qdata[i];
270         break;
271       case CEED_EVAL_GRAD:
272         ierr = CeedBasisApply(op->inputfields[i].basis, CEED_NOTRANSPOSE,
273                               CEED_EVAL_GRAD, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
274         CeedChk(ierr);
275         opref->indata[i] = opref->qdata[i];
276         break;
277       case CEED_EVAL_WEIGHT:
278         break;  // No action
279       case CEED_EVAL_DIV:
280         break; // Not implimented
281       case CEED_EVAL_CURL:
282         break; // Not implimented
283       }
284     }
285     // Output pointers
286     for (CeedInt i=0; i<qf->numoutputfields; i++) {
287       CeedEvalMode emode = qf->outputfields[i].emode;
288       if (emode == CEED_EVAL_NONE) {
289         CeedInt ncomp = qf->outputfields[i].ncomp;
290         opref->outdata[i] = &opref->edata[i + qf->numinputfields][e*Q*ncomp];
291       }
292     }
293     // Q function
294     ierr = CeedQFunctionApply(op->qf, Q, (const CeedScalar * const*) opref->indata,
295                               opref->outdata); CeedChk(ierr);
296 
297     // Output basis apply if needed
298     for (CeedInt i=0; i<qf->numoutputfields; i++) {
299       // Get elemsize
300       if (op->outputfields[i].Erestrict) {
301         elemsize = op->outputfields[i].Erestrict->elemsize;
302       } else {
303         elemsize = Q;
304       }
305       // Get emode, ncomp
306       CeedInt ncomp = qf->outputfields[i].ncomp;
307       CeedEvalMode emode = qf->outputfields[i].emode;
308       // Basis action
309       switch(emode) {
310       case CEED_EVAL_NONE:
311         break; // No action
312       case CEED_EVAL_INTERP:
313         ierr = CeedBasisApply(op->outputfields[i].basis, CEED_TRANSPOSE,
314                               CEED_EVAL_INTERP, opref->outdata[i],
315                               &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]); CeedChk(ierr);
316         break;
317       case CEED_EVAL_GRAD:
318         ierr = CeedBasisApply(op->outputfields[i].basis, CEED_TRANSPOSE, CEED_EVAL_GRAD,
319                               opref->outdata[i], &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]);
320         CeedChk(ierr);
321         break;
322       case CEED_EVAL_WEIGHT:
323         break; // Should not occur
324       case CEED_EVAL_DIV:
325         break; // Not implimented
326       case CEED_EVAL_CURL:
327         break; // Not implimented
328       }
329     }
330   }
331 
332   // Output restriction
333   for (CeedInt i=0,ieout=opref->numein; i<qf->numoutputfields; i++) {
334     // Restriction
335     if (op->outputfields[i].Erestrict) {
336       // Passive
337       if (op->outputfields[i].vec) {
338         // Restore evec
339         ierr = CeedVectorRestoreArray(opref->evecs[ieout],
340                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
341         // Zero lvec
342         ierr = CeedVectorGetArray(op->outputfields[i].vec, CEED_MEM_HOST, &vec_temp);
343         CeedChk(ierr);
344         for (CeedInt j=0; j<op->outputfields[i].vec->length; j++)
345           vec_temp[j] = 0.;
346         ierr = CeedVectorRestoreArray(op->outputfields[i].vec, &vec_temp);
347         CeedChk(ierr);
348         // Restrict
349         ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
350                                         lmode, opref->evecs[ieout], op->outputfields[i].vec, request); CeedChk(ierr);
351         ieout++;
352       } else {
353         // Active
354         // Restore evec
355         ierr = CeedVectorRestoreArray(opref->evecs[ieout],
356                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
357         // Zero lvec
358         ierr = CeedVectorGetArray(outvec, CEED_MEM_HOST, &vec_temp); CeedChk(ierr);
359         for (CeedInt j=0; j<outvec->length; j++)
360           vec_temp[j] = 0.;
361         ierr = CeedVectorRestoreArray(outvec, &vec_temp); CeedChk(ierr);
362         // Restrict
363         ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
364                                         lmode, opref->evecs[ieout], outvec, request); CeedChk(ierr);
365         ieout++;
366       }
367     } else {
368       // No Restriction
369       // Passive
370       if (op->outputfields[i].vec) {
371         ierr = CeedVectorRestoreArray(op->outputfields[i].vec,
372                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
373       } else {
374         // Active
375         ierr = CeedVectorRestoreArray(outvec, &opref->edata[i + qf->numinputfields]);
376         CeedChk(ierr);
377       }
378     }
379   }
380 
381   // Restore input arrays
382   for (CeedInt i=0,iein=0; i<qf->numinputfields; i++) {
383     // Restriction
384     if (op->inputfields[i].Erestrict) {
385       ierr = CeedVectorRestoreArrayRead(opref->evecs[iein],
386                                         (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
387       iein++;
388     } else {
389       // No restriction
390       CeedEvalMode emode = qf->inputfields[i].emode;
391       if (emode & CEED_EVAL_WEIGHT) {
392       } else {
393         // Passive
394         if (op->inputfields[i].vec) {
395           ierr = CeedVectorRestoreArrayRead(op->inputfields[i].vec,
396                                             (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
397           // Active
398         } else {
399           ierr = CeedVectorRestoreArrayRead(invec,
400                                             (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
401 
402         }
403       }
404     }
405   }
406 
407   return 0;
408 }
409 
410 int CeedOperatorCreate_Ref(CeedOperator op) {
411   CeedOperator_Ref *impl;
412   int ierr;
413 
414   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
415   op->data = impl;
416   op->Destroy = CeedOperatorDestroy_Ref;
417   op->Apply = CeedOperatorApply_Ref;
418   return 0;
419 }
420