xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 135a076eadd049721da56d757e4ca648632e94c7)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-impl.h>
18 #include <string.h>
19 #include "ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Ref(CeedOperator op) {
22   CeedOperator_Ref *impl = op->data;
23   int ierr;
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     if (impl->evecs[i]) {
27       ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
28     }
29   }
30   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
31   ierr = CeedFree(&impl->edata); CeedChk(ierr);
32 
33   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
34     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
35   }
36   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
37   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
38 
39   ierr = CeedFree(&impl->indata); CeedChk(ierr);
40   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
41 
42   ierr = CeedFree(&op->data); CeedChk(ierr);
43   return 0;
44 }
45 
46 /*
47   Setup infields or outfields
48  */
49 static int CeedOperatorSetupFields_Ref(struct CeedQFunctionField qfields[16],
50                                        struct CeedOperatorField ofields[16],
51                                        CeedVector *evecs, CeedScalar **qdata,
52                                        CeedScalar **qdata_alloc, CeedScalar **indata,
53                                        CeedInt starti, CeedInt startq,
54                                        CeedInt numfields, CeedInt Q) {
55   CeedInt dim, ierr, iq=startq, ncomp;
56 
57   // Loop over fields
58   for (CeedInt i=0; i<numfields; i++) {
59     CeedEvalMode emode = qfields[i].emode;
60 
61     if (emode != CEED_EVAL_WEIGHT) {
62       ierr = CeedElemRestrictionCreateVector(ofields[i].Erestrict, NULL, &evecs[i+starti]);
63       CeedChk(ierr);
64     }
65 
66     switch(emode) {
67     case CEED_EVAL_NONE:
68       break; // No action
69     case CEED_EVAL_INTERP:
70       ncomp = qfields[i].ncomp;
71       ierr = CeedMalloc(Q*ncomp, &qdata_alloc[iq]); CeedChk(ierr);
72       qdata[i + starti] = qdata_alloc[iq];
73       iq++;
74       break;
75     case CEED_EVAL_GRAD:
76       ncomp = qfields[i].ncomp;
77       dim = ofields[i].basis->dim;
78       ierr = CeedMalloc(Q*ncomp*dim, &qdata_alloc[iq]); CeedChk(ierr);
79       qdata[i + starti] = qdata_alloc[iq];
80       iq++;
81       break;
82     case CEED_EVAL_WEIGHT: // Only on input fields
83       ierr = CeedMalloc(Q, &qdata_alloc[iq]); CeedChk(ierr);
84       ierr = CeedBasisApply(ofields[iq].basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
85                             NULL, qdata_alloc[iq]); CeedChk(ierr);
86       qdata[i] = qdata_alloc[iq];
87       indata[i] = qdata[i];
88       iq++;
89       break;
90     case CEED_EVAL_DIV:
91       break; // Not implimented
92     case CEED_EVAL_CURL:
93       break; // Not implimented
94     }
95   }
96   return 0;
97 }
98 
99 /*
100   CeedOperator needs to connect all the named fields (be they active or passive)
101   to the named inputs and outputs of its CeedQFunction.
102  */
103 static int CeedOperatorSetup_Ref(CeedOperator op) {
104   if (op->setupdone) return 0;
105   CeedOperator_Ref *opref = op->data;
106   CeedQFunction qf = op->qf;
107   CeedInt Q = op->numqpoints;
108   int ierr;
109 
110   // Count infield and outfield array sizes and evectors
111   opref->numein = qf->numinputfields;
112   for (CeedInt i=0; i<qf->numinputfields; i++) {
113     CeedEvalMode emode = qf->inputfields[i].emode;
114     opref->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) + !!
115                      (emode & CEED_EVAL_WEIGHT);
116   }
117   opref->numeout = qf->numoutputfields;
118   for (CeedInt i=0; i<qf->numoutputfields; i++) {
119     CeedEvalMode emode = qf->outputfields[i].emode;
120     opref->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
121   }
122 
123   // Allocate
124   ierr = CeedCalloc(opref->numein + opref->numeout, &opref->evecs); CeedChk(ierr);
125   ierr = CeedCalloc(opref->numein + opref->numeout, &opref->edata);
126   CeedChk(ierr);
127 
128   ierr = CeedCalloc(opref->numqin + opref->numqout, &opref->qdata_alloc);
129   CeedChk(ierr);
130   ierr = CeedCalloc(qf->numinputfields + qf->numoutputfields, &opref->qdata);
131   CeedChk(ierr);
132 
133   ierr = CeedCalloc(16, &opref->indata); CeedChk(ierr);
134   ierr = CeedCalloc(16, &opref->outdata); CeedChk(ierr);
135 
136   // Set up infield and outfield pointer arrays
137   // Infields
138   ierr = CeedOperatorSetupFields_Ref(qf->inputfields, op->inputfields,
139                                      opref->evecs, opref->qdata, opref->qdata_alloc,
140                                      opref->indata, 0, 0,
141                                      qf->numinputfields, Q); CeedChk(ierr);
142 
143   // Outfields
144   ierr = CeedOperatorSetupFields_Ref(qf->outputfields, op->outputfields,
145                                      opref->evecs, opref->qdata, opref->qdata_alloc,
146                                      opref->indata, qf->numinputfields,
147                                      opref->numqin, qf->numoutputfields, Q); CeedChk(ierr);
148 
149   // Output Qvecs
150   for (CeedInt i=0; i<qf->numoutputfields; i++) {
151     CeedEvalMode emode = qf->outputfields[i].emode;
152     if (emode != CEED_EVAL_NONE) {
153       opref->outdata[i] =  opref->qdata[i + qf->numinputfields];
154     }
155   }
156 
157   op->setupdone = 1;
158 
159   return 0;
160 }
161 
162 static int CeedOperatorApply_Ref(CeedOperator op, CeedVector invec,
163                                  CeedVector outvec, CeedRequest *request) {
164   CeedOperator_Ref *opref = op->data;
165   CeedInt Q = op->numqpoints, elemsize;
166   int ierr;
167   CeedQFunction qf = op->qf;
168   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
169   CeedScalar *vec_temp;
170 
171   // Setup
172   ierr = CeedOperatorSetup_Ref(op); CeedChk(ierr);
173 
174   // Input Evecs and Restriction
175   for (CeedInt i=0; i<qf->numinputfields; i++) {
176     CeedEvalMode emode = qf->inputfields[i].emode;
177     if (emode == CEED_EVAL_WEIGHT) { // Skip
178     } else {
179       // Zero evec
180       ierr = CeedVectorGetArray(opref->evecs[i], CEED_MEM_HOST, &vec_temp);
181       CeedChk(ierr);
182       for (CeedInt j=0; j<opref->evecs[i]->length; j++)
183         vec_temp[j] = 0.;
184       ierr = CeedVectorRestoreArray(opref->evecs[i], &vec_temp); CeedChk(ierr);
185       // Active
186       if (op->inputfields[i].vec == CEED_VECTOR_ACTIVE) {
187         // Restrict
188         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
189                                         lmode, invec, opref->evecs[i],
190                                         request); CeedChk(ierr);
191         // Get evec
192         ierr = CeedVectorGetArrayRead(opref->evecs[i], CEED_MEM_HOST,
193                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
194       } else {
195         // Passive
196         // Restrict
197         ierr = CeedElemRestrictionApply(op->inputfields[i].Erestrict, CEED_NOTRANSPOSE,
198                                         lmode, op->inputfields[i].vec, opref->evecs[i],
199                                         request); CeedChk(ierr);
200         // Get evec
201         ierr = CeedVectorGetArrayRead(opref->evecs[i], CEED_MEM_HOST,
202                                       (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
203       }
204     }
205   }
206 
207   // Output Evecs
208   for (CeedInt i=0; i<qf->numoutputfields; i++) {
209     ierr = CeedVectorGetArray(opref->evecs[i+opref->numein], CEED_MEM_HOST,
210                               &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
211   }
212 
213   // Loop through elements
214   for (CeedInt e=0; e<op->numelements; e++) {
215     // Input basis apply if needed
216     for (CeedInt i=0; i<qf->numinputfields; i++) {
217       // Get elemsize, emode, ncomp
218       elemsize = op->inputfields[i].Erestrict->elemsize;
219       CeedEvalMode emode = qf->inputfields[i].emode;
220       CeedInt ncomp = qf->inputfields[i].ncomp;
221       // Basis action
222       switch(emode) {
223       case CEED_EVAL_NONE:
224         opref->indata[i] = &opref->edata[i][e*Q*ncomp];
225         break;
226       case CEED_EVAL_INTERP:
227         ierr = CeedBasisApply(op->inputfields[i].basis, 1, CEED_NOTRANSPOSE,
228                               CEED_EVAL_INTERP, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
229         CeedChk(ierr);
230         opref->indata[i] = opref->qdata[i];
231         break;
232       case CEED_EVAL_GRAD:
233         ierr = CeedBasisApply(op->inputfields[i].basis, 1, CEED_NOTRANSPOSE,
234                               CEED_EVAL_GRAD, &opref->edata[i][e*elemsize*ncomp], opref->qdata[i]);
235         CeedChk(ierr);
236         opref->indata[i] = opref->qdata[i];
237         break;
238       case CEED_EVAL_WEIGHT:
239         break;  // No action
240       case CEED_EVAL_DIV:
241         break; // Not implimented
242       case CEED_EVAL_CURL:
243         break; // Not implimented
244       }
245     }
246     // Output pointers
247     for (CeedInt i=0; i<qf->numoutputfields; i++) {
248       CeedEvalMode emode = qf->outputfields[i].emode;
249       if (emode == CEED_EVAL_NONE) {
250         CeedInt ncomp = qf->outputfields[i].ncomp;
251         opref->outdata[i] = &opref->edata[i + qf->numinputfields][e*Q*ncomp];
252       }
253     }
254     // Q function
255     ierr = CeedQFunctionApply(op->qf, Q, (const CeedScalar * const*) opref->indata,
256                               opref->outdata); CeedChk(ierr);
257 
258     // Output basis apply if needed
259     for (CeedInt i=0; i<qf->numoutputfields; i++) {
260       // Get elemsize, emode, ncomp
261       elemsize = op->outputfields[i].Erestrict->elemsize;
262       CeedInt ncomp = qf->outputfields[i].ncomp;
263       CeedEvalMode emode = qf->outputfields[i].emode;
264       // Basis action
265       switch(emode) {
266       case CEED_EVAL_NONE:
267         break; // No action
268       case CEED_EVAL_INTERP:
269         ierr = CeedBasisApply(op->outputfields[i].basis, 1, CEED_TRANSPOSE,
270                               CEED_EVAL_INTERP, opref->outdata[i],
271                               &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]); CeedChk(ierr);
272         break;
273       case CEED_EVAL_GRAD:
274         ierr = CeedBasisApply(op->outputfields[i].basis, 1, CEED_TRANSPOSE,
275                               CEED_EVAL_GRAD,
276                               opref->outdata[i], &opref->edata[i + qf->numinputfields][e*elemsize*ncomp]);
277         CeedChk(ierr);
278         break;
279       case CEED_EVAL_WEIGHT:
280         break; // Should not occur
281       case CEED_EVAL_DIV:
282         break; // Not implimented
283       case CEED_EVAL_CURL:
284         break; // Not implimented
285       }
286     }
287   }
288 
289   // Output restriction
290   for (CeedInt i=0; i<qf->numoutputfields; i++) {
291     // Active
292     if (op->outputfields[i].vec == CEED_VECTOR_ACTIVE) {
293       // Restore evec
294       ierr = CeedVectorRestoreArray(opref->evecs[i+opref->numein],
295                                     &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
296       // Zero lvec
297       ierr = CeedVectorGetArray(outvec, CEED_MEM_HOST, &vec_temp); CeedChk(ierr);
298       for (CeedInt j=0; j<outvec->length; j++)
299         vec_temp[j] = 0.;
300       ierr = CeedVectorRestoreArray(outvec, &vec_temp); CeedChk(ierr);
301       // Restrict
302       ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
303                                       lmode, opref->evecs[i+opref->numein], outvec, request); CeedChk(ierr);
304     } else {
305       // Passive
306       // Restore evec
307       ierr = CeedVectorRestoreArray(opref->evecs[i+opref->numein],
308                                     &opref->edata[i + qf->numinputfields]); CeedChk(ierr);
309       // Zero lvec
310       ierr = CeedVectorGetArray(op->outputfields[i].vec, CEED_MEM_HOST, &vec_temp);
311       CeedChk(ierr);
312       for (CeedInt j=0; j<op->outputfields[i].vec->length; j++)
313         vec_temp[j] = 0.;
314       ierr = CeedVectorRestoreArray(op->outputfields[i].vec, &vec_temp);
315       CeedChk(ierr);
316       // Restrict
317       ierr = CeedElemRestrictionApply(op->outputfields[i].Erestrict, CEED_TRANSPOSE,
318                                       lmode, opref->evecs[i+opref->numein], op->outputfields[i].vec,
319                                       request); CeedChk(ierr);
320     }
321   }
322 
323   // Restore input arrays
324   for (CeedInt i=0; i<qf->numinputfields; i++) {
325     CeedEvalMode emode = qf->inputfields[i].emode;
326     if (emode == CEED_EVAL_WEIGHT) { // Skip
327     } else {
328       ierr = CeedVectorRestoreArrayRead(opref->evecs[i],
329                                         (const CeedScalar **) &opref->edata[i]); CeedChk(ierr);
330     }
331   }
332 
333   return 0;
334 }
335 
336 int CeedOperatorCreate_Ref(CeedOperator op) {
337   CeedOperator_Ref *impl;
338   int ierr;
339 
340   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
341   op->data = impl;
342   op->Destroy = CeedOperatorDestroy_Ref;
343   op->Apply = CeedOperatorApply_Ref;
344   return 0;
345 }
346