xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 7ca8db16b1f641ec2f47938672bae2706e69c374)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-impl.h>
18 #include <string.h>
19 #include "ceed-ref.h"
20 
// Sentinel handles defined by this file with external linkage; every one of
// them is NULL here. The operator code in this file tests field pointers for
// NULL: a NULL Erestrict is handled as "no restriction" and a NULL vec as the
// active vector (invec/outvec). Because CEED_VECTOR_ACTIVE and CEED_VECTOR_NONE
// are both NULL, they compare equal to each other.
// NOTE(review): presumably a NULL basis is meant to signal colocated
// quadrature data -- confirm against the callers that pass it.
21 CeedElemRestriction CEED_RESTRICTION_IDENTITY = NULL;
22 CeedBasis CEED_BASIS_COLOCATED = NULL;
23 CeedVector CEED_VECTOR_ACTIVE = NULL;
24 CeedVector CEED_VECTOR_NONE = NULL;
25 
26 static int CeedOperatorDestroy_Ref(CeedOperator op) {
27   CeedOperator_Ref *impl = op->data;
28   int ierr;
29 
30   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
31     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
32   }
33   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
34   ierr = CeedFree(&impl->edata); CeedChk(ierr);
35 
36   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
37     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
38   }
39   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
40   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
41 
42   ierr = CeedFree(&impl->indata); CeedChk(ierr);
43   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
44 
45   ierr = CeedFree(&op->data); CeedChk(ierr);
46   return 0;
47 }
48 
49 /*
50   Setup infields or outfields
51  */
52 static int CeedOperatorSetupFields_Ref(struct CeedQFunctionField qfields[16],
53                                        struct CeedOperatorField ofields[16],
54                                        CeedVector *evecs, CeedScalar **qdata,
55                                        CeedScalar **qdata_alloc, CeedScalar **indata,
56                                        CeedInt starti, CeedInt starte,
57                                        CeedInt startq, CeedInt numfields, CeedInt Q) {
58   CeedInt dim, ierr, ie=starte, iq=startq, ncomp;
59 
60   // Loop over fields
61   for (CeedInt i=0; i<numfields; i++) {
62     if (ofields[i].Erestrict) {
63       ierr = CeedElemRestrictionCreateVector(ofields[i].Erestrict, NULL, &evecs[ie]);
64       CeedChk(ierr);
65       ie++;
66     }
67     CeedEvalMode emode = qfields[i].emode;
68     switch(emode) {
69     case CEED_EVAL_NONE:
70       break; // No action
71     case CEED_EVAL_INTERP:
72       ncomp = qfields[i].ncomp;
73       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
74       qdata[i + starti] = qdata_alloc[iq];
75       iq++;
76       break;
77     case CEED_EVAL_GRAD:
78       ncomp = qfields[i].ncomp;
79       dim = ofields[i].basis->dim;
80       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
81       qdata[i + starti] = qdata_alloc[iq];
82       iq++;
83       break;
84     case CEED_EVAL_WEIGHT: // Only on input fields
85       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
86       ierr = CeedBasisApply(ofields[iq].basis, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
87                             NULL, qdata_alloc[iq]); CeedChk(ierr);
88       qdata[i] = qdata_alloc[iq];
89       indata[i] = qdata[i];
90       iq++;
91       break;
92     case CEED_EVAL_DIV:
93       break; // Not implimented
94     case CEED_EVAL_CURL:
95       break; // Not implimented
96     }
97   }
98   return 0;
99 }
100 
101 /*
102   CeedOperator needs to connect all the named fields (be they active or passive)
103   to the named inputs and outputs of its CeedQFunction.
104  */
105 static int CeedOperatorSetup_Ref(CeedOperator op) {
106   if (op->setupdone) return 0;
107   CeedOperator_Ref *opref = op->data;
108   CeedQFunction qf = op->qf;
109   CeedInt Q = op->numqpoints;
110   int ierr;
111 
112   // Count infield and outfield array sizes and evectors
113   for (CeedInt i=0; i<qf->numinputfields; i++) {
114     CeedEvalMode emode = qf->inputfields[i].emode;
115     opref->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) + !!
116                      (emode & CEED_EVAL_WEIGHT);
117     opref->numein +=
118       !!op->inputfields[i].Erestrict; // Need E-vector when restriction exists
119   }
120   for (CeedInt i=0; i<qf->numoutputfields; i++) {
121     CeedEvalMode emode = qf->outputfields[i].emode;
122     opref->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
123     opref->numeout += !!op->outputfields[i].Erestrict;
124   }
125 
126   // Allocate
127   ierr = CeedCalloc(opref->numein + opref->numeout, &opref->evecs); CeedChk(ierr);
128   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->edata);
129   CeedChk(ierr);
130 
131   ierr = CeedCalloc(opref->numqin + opref->numqout, &opref->qdata_alloc);
132   CeedChk(ierr);
133   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->qdata);
134   CeedChk(ierr);
135 
136   ierr = CeedCalloc(16, &opref->indata); CeedChk(ierr);
137   ierr = CeedCalloc(16, &opref->outdata); CeedChk(ierr);
138 
139   // Set up infield and outfield pointer arrays
140   // Infields
141   ierr = CeedOperatorSetupFields_Ref(qf->inputfields, op->inputfields,
142                                      opref->evecs, opref->qdata, opref->qdata_alloc,
143                                      opref->indata, 0, 0, 0,
144                                      qf->numinputfields, Q); CeedChk(ierr);
145 
146   // Outfields
147   ierr = CeedOperatorSetupFields_Ref(qf->outputfields, op->outputfields,
148                                      opref->evecs, opref->qdata, opref->qdata_alloc,
149                                      opref->indata, qf->numinputfields, opref->numein,
150                                      opref->numqin, qf->numoutputfields, Q); CeedChk(ierr);
151 
152   // Output Qvecs
153   for (CeedInt i=0; i<qf->numoutputfields; i++) {
154     CeedEvalMode emode = qf->outputfields[i].emode;
155     if (emode != CEED_EVAL_NONE) {
156       opref->outdata[i] =  opref->qdata[i + qf->numinputfields];
157     }
158   }
159 
160   op->setupdone = 1;
161 
162   return 0;
163 }
164 
165 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
166                                  CeedVector outvec, CeedRequest *request) {
167   CeedOperator_Ref *opref = op->data;
168   CeedInt Q = op->numqpoints, elemsize;
169   int ierr;
170   CeedQFunction qf = op->qf;
171   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
172   CeedScalar *vec_temp;
173 
174   // Setup
175   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
176 
177   // Input Evecs and Restriction
178   for (CeedInt i=0,iein=0; i<qf->numinputfields; i++) {
179     // Restriction
180     if (op->inputfields[i].Erestrict) {
181       // Zero evec
182       ierr = CeedVectorGetArray(opref->evecs[iein], CEED_MEM_HOST, &vec_temp); CeedChk(ierr);
183       for (CeedInt j=0; j<opref->evecs[iein]->length; j++)
184         vec_temp[j] = 0.;
185       ierr = CeedVectorRestoreArray(opref->evecs[iein], &vec_temp); CeedChk(ierr);
186       // Passive
187       if (op->inputfields[i].vec) {
188         // Restrict
189         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
190                                         lmode, op->inputfields[i].vec, opref->evecs[iein],
191                                         request); CeedChk(ierr);
192         // Get evec
193         ierr = CeedVectorGetArrayRead(opref->evecs[iein], CEED_MEM_HOST,
194                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
195         iein++;
196       } else {
197         // Active
198         // Restrict
199 
200         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
201                                         lmode, invec, opref->evecs[iein], request); CeedChk(ierr);
202         // Get evec
203         ierr = CeedVectorGetArrayRead(opref->evecs[iein], CEED_MEM_HOST,
204                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
205         iein++;
206       }
207     } else {
208       // No restriction
209       CeedEvalMode emode = qf->inputfields[i].emode;
210       if (emode & CEED_EVAL_WEIGHT) {
211       } else {
212         // Passive
213         if (op->inputfields[i].vec) {
214           ierr = CeedVectorGetArrayRead(op->inputfields[i].vec, CEED_MEM_HOST,
215                                         (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
216         // Active
217         } else {
218           ierr = CeedVectorGetArrayRead(invec, CEED_MEM_HOST,
219                                         (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
220         }
221       }
222     }
223   }
224 
225   // Output Evecs
226   for (CeedInt i=0,ieout=opref->numein; i<qf->numoutputfields; i++) {
227     // Restriction
228     if (op->outputfields[i].Erestrict) {
229       ierr = CeedVectorGetArray(opref->evecs[ieout], CEED_MEM_HOST,
230                                 &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
231       ieout++;
232     } else {
233       // No restriction
234       // Passive
235       if (op->inputfields[i].vec) {
236         ierr = CeedVectorGetArray(op->outputfields[i].vec, CEED_MEM_HOST,
237                                   &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
238       } else {
239         // Active
240         ierr = CeedVectorGetArray(outvec, CEED_MEM_HOST,
241                                   &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
242       }
243     }
244   }
245 
246   // Loop through elements
247   for (CeedInt e=0; e<op->numelements; e++) {
248     // Input basis apply if needed
249     for (CeedInt i=0; i<qf->numinputfields; i++) {
250       // Get elemsize
251       if (op->inputfields[i].Erestrict) {
252         elemsize = op->inputfields[i].Erestrict->elemsize;
253       } else {
254         elemsize = Q;
255       }
256       // Get emode, ncomp
257       CeedEvalMode emode = qf->inputfields[i].emode;
258       CeedInt ncomp = qf->inputfields[i].ncomp;
259       // Basis action
260       switch(emode) {
261       case CEED_EVAL_NONE:
262         opref->indata[i] = &opref->edata[i][e*Q*ncomp];
263         break;
264       case CEED_EVAL_INTERP:
265         ierr = CeedBasisApply(op->inputfields[i].basis, CEED_NOTRANSPOSE,
266                               CEED_EVAL_INTERP, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
267         CeedChk(ierr);
268         opref->indata[i] = opref->qdata[i];
269         break;
270       case CEED_EVAL_GRAD:
271         ierr = CeedBasisApply(op->inputfields[i].basis, CEED_NOTRANSPOSE,
272                               CEED_EVAL_GRAD, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
273         CeedChk(ierr);
274         opref->indata[i] = opref->qdata[i];
275         break;
276       case CEED_EVAL_WEIGHT:
277         break;  // No action
278       case CEED_EVAL_DIV:
279         break; // Not implimented
280       case CEED_EVAL_CURL:
281         break; // Not implimented
282       }
283     }
284     // Output pointers
285     for (CeedInt i=0; i<qf->numoutputfields; i++) {
286       CeedEvalMode emode = qf->outputfields[i].emode;
287       if (emode == CEED_EVAL_NONE) {
288         CeedInt ncomp = qf->outputfields[i].ncomp;
289         opref->outdata[i] = &opref->edata[i + qf->numinputfields][e*Q*ncomp];
290       }
291     }
292     // Q function
293     ierr = CeedQFunctionApply(op->qf, Q, (const CeedScalar * const*) opref->indata,
294                               opref->outdata); CeedChk(ierr);
295 
296     // Output basis apply if needed
297     for (CeedInt i=0; i<qf->numoutputfields; i++) {
298       // Get elemsize
299       if (op->outputfields[i].Erestrict) {
300         elemsize = op->outputfields[i].Erestrict->elemsize;
301       } else {
302         elemsize = Q;
303       }
304       // Get emode, ncomp
305       CeedInt ncomp = qf->outputfields[i].ncomp;
306       CeedEvalMode emode = qf->outputfields[i].emode;
307       // Basis action
308       switch(emode) {
309       case CEED_EVAL_NONE:
310         break; // No action
311       case CEED_EVAL_INTERP:
312         ierr = CeedBasisApply(op->outputfields[i].basis, CEED_TRANSPOSE,
313                               CEED_EVAL_INTERP, opref->outdata[i],
314                               &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]); CeedChk(ierr);
315         break;
316       case CEED_EVAL_GRAD:
317         ierr = CeedBasisApply(op->outputfields[i].basis, CEED_TRANSPOSE, CEED_EVAL_GRAD,
318                               opref->outdata[i], &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]);
319         CeedChk(ierr);
320         break;
321       case CEED_EVAL_WEIGHT:
322         break; // Should not occur
323       case CEED_EVAL_DIV:
324         break; // Not implimented
325       case CEED_EVAL_CURL:
326         break; // Not implimented
327       }
328     }
329   }
330 
331   // Output restriction
332   for (CeedInt i=0,ieout=opref->numein; i<qf->numoutputfields; i++) {
333     // Restriction
334     if (op->outputfields[i].Erestrict) {
335       // Passive
336       if (op->outputfields[i].vec) {
337         // Restore evec
338         ierr = CeedVectorRestoreArray(opref->evecs[ieout],
339                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
340         // Zero lvec
341         ierr = CeedVectorGetArray(op->outputfields[i].vec, CEED_MEM_HOST, &vec_temp); CeedChk(ierr);
342         for (CeedInt j=0; j<op->outputfields[i].vec->length; j++)
343           vec_temp[j] = 0.;
344         ierr = CeedVectorRestoreArray(op->outputfields[i].vec, &vec_temp); CeedChk(ierr);
345         // Restrict
346         ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
347                                         lmode, opref->evecs[ieout], op->outputfields[i].vec, request); CeedChk(ierr);
348         ieout++;
349       } else {
350         // Active
351         // Restore evec
352         ierr = CeedVectorRestoreArray(opref->evecs[ieout],
353                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
354         // Zero lvec
355         ierr = CeedVectorGetArray(outvec, CEED_MEM_HOST, &vec_temp); CeedChk(ierr);
356         for (CeedInt j=0; j<outvec->length; j++)
357           vec_temp[j] = 0.;
358         ierr = CeedVectorRestoreArray(outvec, &vec_temp); CeedChk(ierr);
359         // Restrict
360         ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
361                                         lmode, opref->evecs[ieout], outvec, request); CeedChk(ierr);
362         ieout++;
363       }
364     } else {
365       // No Restriction
366       // Passive
367       if (op->outputfields[i].vec) {
368         ierr = CeedVectorRestoreArray(op->outputfields[i].vec,
369                                       &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
370       } else {
371         // Active
372         ierr = CeedVectorRestoreArray(outvec, &opref->edata[i + qf->numinputfields]);
373         CeedChk(ierr);
374       }
375     }
376   }
377 
378   // Restore input arrays
379   for (CeedInt i=0,iein=0; i<qf->numinputfields; i++) {
380     // Restriction
381     if (op->inputfields[i].Erestrict) {
382       ierr = CeedVectorRestoreArrayRead(opref->evecs[iein],
383                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
384         iein++;
385     } else {
386       // No restriction
387       CeedEvalMode emode = qf->inputfields[i].emode;
388       if (emode & CEED_EVAL_WEIGHT) {
389       } else {
390         // Passive
391         if (op->inputfields[i].vec) {
392           ierr = CeedVectorRestoreArrayRead(op->inputfields[i].vec,
393                                         (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
394         // Active
395         } else {
396           ierr = CeedVectorRestoreArrayRead(invec,
397                                         (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
398 
399         }
400       }
401     }
402   }
403 
404   return 0;
405 }
406 
407 int CeedOperatorCreate_Ref(CeedOperator op) {
408   CeedOperator_Ref *impl;
409   int ierr;
410 
411   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
412   op->data = impl;
413   op->Destroy = CeedOperatorDestroy_Ref;
414   op->Apply = CeedOperatorApply_Ref;
415   return 0;
416 }
417