xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 5a0a874ce977dcfdcf940eedc0869ca9f0da1624)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-impl.h>
18 #include <string.h>
19 #include "ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Ref(CeedOperator op) {
22   CeedOperator_Ref *impl = op->data;
23   int ierr;
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
27   }
28   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
29   ierr = CeedFree(&impl->edata); CeedChk(ierr);
30 
31   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
32     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
33   }
34   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
35   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
36 
37   ierr = CeedFree(&impl->indata); CeedChk(ierr);
38   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
39 
40   ierr = CeedFree(&op->data); CeedChk(ierr);
41   return 0;
42 }
43 
44 /*
45   Setup infields or outfields
46  */
47 static int CeedOperatorSetupFields_Ref(struct CeedQFunctionField qfields[16],
48                                        struct CeedOperatorField ofields[16],
49                                        CeedVector *evecs, CeedScalar **qdata,
50                                        CeedScalar **qdata_alloc, CeedScalar **indata,
51                                        CeedInt starti, CeedInt startq,
52                                        CeedInt numfields, CeedInt Q) {
53   CeedInt dim, ierr, iq=startq, ncomp;
54 
55   // Loop over fields
56   for (CeedInt i=0; i<numfields; i++) {
57     CeedEvalMode emode = qfields[i].emode;
58 
59     if (emode != CEED_EVAL_WEIGHT) {
60       ierr = CeedElemRestrictionCreateVector(ofields[i].Erestrict, NULL, &evecs[i+starti]);
61       CeedChk(ierr);
62     }
63 
64     switch(emode) {
65     case CEED_EVAL_NONE:
66       break; // No action
67     case CEED_EVAL_INTERP:
68       ncomp = qfields[i].ncomp;
69       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
70       qdata[i + starti] = qdata_alloc[iq];
71       iq++;
72       break;
73     case CEED_EVAL_GRAD:
74       ncomp = qfields[i].ncomp;
75       dim = ofields[i].basis->dim;
76       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
77       qdata[i + starti] = qdata_alloc[iq];
78       iq++;
79       break;
80     case CEED_EVAL_WEIGHT: // Only on input fields
81       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
82       ierr = CeedBasisApply(ofields[iq].basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
83                             NULL, qdata_alloc[iq]); CeedChk(ierr);
84       qdata[i] = qdata_alloc[iq];
85       indata[i] = qdata[i];
86       iq++;
87       break;
88     case CEED_EVAL_DIV:
89       break; // Not implimented
90     case CEED_EVAL_CURL:
91       break; // Not implimented
92     }
93   }
94   return 0;
95 }
96 
97 /*
98   CeedOperator needs to connect all the named fields (be they active or passive)
99   to the named inputs and outputs of its CeedQFunction.
100  */
101 static int CeedOperatorSetup_Ref(CeedOperator op) {
102   if (op->setupdone) return 0;
103   CeedOperator_Ref *opref = op->data;
104   CeedQFunction qf = op->qf;
105   CeedInt Q = op->numqpoints;
106   int ierr;
107 
108   // Count infield and outfield array sizes and evectors
109   opref->numein = qf->numinputfields;
110   for (CeedInt i=0; i<qf->numinputfields; i++) {
111     CeedEvalMode emode = qf->inputfields[i].emode;
112     opref->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) + !!
113                      (emode & CEED_EVAL_WEIGHT);
114   }
115   opref->numeout = qf->numoutputfields;
116   for (CeedInt i=0; i<qf->numoutputfields; i++) {
117     CeedEvalMode emode = qf->outputfields[i].emode;
118     opref->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
119   }
120 
121   // Allocate
122   ierr = CeedCalloc(opref->numein + opref->numeout, &opref->evecs); CeedChk(ierr);
123   ierr = CeedCalloc(opref->numein + opref->numeout, &opref->edata);
124   CeedChk(ierr);
125 
126   ierr = CeedCalloc(opref->numqin + opref->numqout, &opref->qdata_alloc);
127   CeedChk(ierr);
128   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->qdata);
129   CeedChk(ierr);
130 
131   ierr = CeedCalloc(16, &opref->indata); CeedChk(ierr);
132   ierr = CeedCalloc(16, &opref->outdata); CeedChk(ierr);
133 
134   // Set up infield and outfield pointer arrays
135   // Infields
136   ierr = CeedOperatorSetupFields_Ref(qf->inputfields, op->inputfields,
137                                      opref->evecs, opref->qdata, opref->qdata_alloc,
138                                      opref->indata, 0, 0,
139                                      qf->numinputfields, Q); CeedChk(ierr);
140 
141   // Outfields
142   ierr = CeedOperatorSetupFields_Ref(qf->outputfields, op->outputfields,
143                                      opref->evecs, opref->qdata, opref->qdata_alloc,
144                                      opref->indata, qf->numinputfields,
145                                      opref->numqin, qf->numoutputfields, Q); CeedChk(ierr);
146 
147   // Output Qvecs
148   for (CeedInt i=0; i<qf->numoutputfields; i++) {
149     CeedEvalMode emode = qf->outputfields[i].emode;
150     if (emode != CEED_EVAL_NONE) {
151       opref->outdata[i] =  opref->qdata[i + qf->numinputfields];
152     }
153   }
154 
155   op->setupdone = 1;
156 
157   return 0;
158 }
159 
160 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
161                                  CeedVector outvec, CeedRequest *request) {
162   CeedOperator_Ref *opref = op->data;
163   CeedInt Q = op->numqpoints, elemsize;
164   int ierr;
165   CeedQFunction qf = op->qf;
166   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
167   CeedScalar *vec_temp;
168 
169   // Setup
170   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
171 
172   // Input Evecs and Restriction
173   for (CeedInt i=0; i<qf->numinputfields; i++) {
174     CeedEvalMode emode = qf->inputfields[i].emode;
175     if (emode == CEED_EVAL_WEIGHT) { // Skip
176     } else {
177       // Zero evec
178       ierr = CeedVectorGetArray(opref->evecs[i], CEED_MEM_HOST, &vec_temp);
179       CeedChk(ierr);
180       for (CeedInt j=0; j<opref->evecs[i]->length; j++)
181         vec_temp[j] = 0.;
182       ierr = CeedVectorRestoreArray(opref->evecs[i], &vec_temp); CeedChk(ierr);
183       // Active
184       if (op->inputfields[i].vec == CEED_VECTOR_ACTIVE) {
185         // Restrict
186         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
187                                         lmode, invec, opref->evecs[i],
188                                         request); CeedChk(ierr);
189         // Get evec
190         ierr = CeedVectorGetArrayRead(opref->evecs[i], CEED_MEM_HOST,
191                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
192       } else {
193         // Passive
194         // Restrict
195         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
196                                         lmode, op->inputfields[i].vec, opref->evecs[i],
197                                         request); CeedChk(ierr);
198         // Get evec
199         ierr = CeedVectorGetArrayRead(opref->evecs[i], CEED_MEM_HOST,
200                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
201       }
202     }
203   }
204 
205   // Output Evecs
206   for (CeedInt i=0; i<qf->numoutputfields; i++) {
207     ierr = CeedVectorGetArray(opref->evecs[i+opref->numein], CEED_MEM_HOST,
208                               &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
209   }
210 
211   // Loop through elements
212   for (CeedInt e=0; e<op->numelements; e++) {
213     // Input basis apply if needed
214     for (CeedInt i=0; i<qf->numinputfields; i++) {
215       // Get elemsize, emode, ncomp
216       elemsize = op->inputfields[i].Erestrict->elemsize;
217       CeedEvalMode emode = qf->inputfields[i].emode;
218       CeedInt ncomp = qf->inputfields[i].ncomp;
219       // Basis action
220       switch(emode) {
221       case CEED_EVAL_NONE:
222         opref->indata[i] = &opref->edata[i][e*Q*ncomp];
223         break;
224       case CEED_EVAL_INTERP:
225         ierr = CeedBasisApply(op->inputfields[i].basis, 1, CEED_NOTRANSPOSE,
226                               CEED_EVAL_INTERP, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
227         CeedChk(ierr);
228         opref->indata[i] = opref->qdata[i];
229         break;
230       case CEED_EVAL_GRAD:
231         ierr = CeedBasisApply(op->inputfields[i].basis, 1, CEED_NOTRANSPOSE,
232                               CEED_EVAL_GRAD, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
233         CeedChk(ierr);
234         opref->indata[i] = opref->qdata[i];
235         break;
236       case CEED_EVAL_WEIGHT:
237         break;  // No action
238       case CEED_EVAL_DIV:
239         break; // Not implimented
240       case CEED_EVAL_CURL:
241         break; // Not implimented
242       }
243     }
244     // Output pointers
245     for (CeedInt i=0; i<qf->numoutputfields; i++) {
246       CeedEvalMode emode = qf->outputfields[i].emode;
247       if (emode == CEED_EVAL_NONE) {
248         CeedInt ncomp = qf->outputfields[i].ncomp;
249         opref->outdata[i] = &opref->edata[i + qf->numinputfields][e*Q*ncomp];
250       }
251     }
252     // Q function
253     ierr = CeedQFunctionApply(op->qf, Q, (const CeedScalar * const*) opref->indata,
254                               opref->outdata); CeedChk(ierr);
255 
256     // Output basis apply if needed
257     for (CeedInt i=0; i<qf->numoutputfields; i++) {
258       // Get elemsize, emode, ncomp
259       elemsize = op->outputfields[i].Erestrict->elemsize;
260       CeedInt ncomp = qf->outputfields[i].ncomp;
261       CeedEvalMode emode = qf->outputfields[i].emode;
262       // Basis action
263       switch(emode) {
264       case CEED_EVAL_NONE:
265         break; // No action
266       case CEED_EVAL_INTERP:
267         ierr = CeedBasisApply(op->outputfields[i].basis, 1, CEED_TRANSPOSE,
268                               CEED_EVAL_INTERP, opref->outdata[i],
269                               &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]); CeedChk(ierr);
270         break;
271       case CEED_EVAL_GRAD:
272         ierr = CeedBasisApply(op->outputfields[i].basis, 1, CEED_TRANSPOSE,
273                               CEED_EVAL_GRAD,
274                               opref->outdata[i], &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]);
275         CeedChk(ierr);
276         break;
277       case CEED_EVAL_WEIGHT:
278         break; // Should not occur
279       case CEED_EVAL_DIV:
280         break; // Not implimented
281       case CEED_EVAL_CURL:
282         break; // Not implimented
283       }
284     }
285   }
286 
287   // Output restriction
288   for (CeedInt i=0; i<qf->numoutputfields; i++) {
289     // Active
290     if (op->outputfields[i].vec == CEED_VECTOR_ACTIVE) {
291       // Restore evec
292       ierr = CeedVectorRestoreArray(opref->evecs[i+opref->numein],
293                                     &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
294       // Zero lvec
295       ierr = CeedVectorGetArray(outvec, CEED_MEM_HOST, &vec_temp); CeedChk(ierr);
296       for (CeedInt j=0; j<outvec->length; j++)
297         vec_temp[j] = 0.;
298       ierr = CeedVectorRestoreArray(outvec, &vec_temp); CeedChk(ierr);
299       // Restrict
300       ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
301                                       lmode, opref->evecs[i+opref->numein], outvec, request); CeedChk(ierr);
302     } else {
303       // Passive
304       // Restore evec
305       ierr = CeedVectorRestoreArray(opref->evecs[i+opref->numein],
306                                     &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
307       // Zero lvec
308       ierr = CeedVectorGetArray(op->outputfields[i].vec, CEED_MEM_HOST, &vec_temp);
309       CeedChk(ierr);
310       for (CeedInt j=0; j<op->outputfields[i].vec->length; j++)
311         vec_temp[j] = 0.;
312       ierr = CeedVectorRestoreArray(op->outputfields[i].vec, &vec_temp);
313       CeedChk(ierr);
314       // Restrict
315       ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
316                                       lmode, opref->evecs[i+opref->numein], op->outputfields[i].vec,
317                                       request); CeedChk(ierr);
318     }
319   }
320 
321   // Restore input arrays
322   for (CeedInt i=0; i<qf->numinputfields; i++) {
323     CeedEvalMode emode = qf->inputfields[i].emode;
324     if (emode == CEED_EVAL_WEIGHT) { // Skip
325     } else {
326       ierr = CeedVectorRestoreArrayRead(opref->evecs[i],
327                                         (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
328     }
329   }
330 
331   return 0;
332 }
333 
334 int CeedOperatorCreate_Ref(CeedOperator op) {
335   CeedOperator_Ref *impl;
336   int ierr;
337 
338   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
339   op->data = impl;
340   op->Destroy = CeedOperatorDestroy_Ref;
341   op->Apply = CeedOperatorApply_Ref;
342   return 0;
343 }
344