xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision 4ce2993fee6e7bc9b86526ee90098d0dc489fc60)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-impl.h>
18 #include <string.h>
19 #include "ceed-blocked.h"
20 #include "../ref/ceed-ref.h"
21 
22 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
23   int ierr;
24   CeedOperator_Blocked *impl;
25   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
26 
27   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
28     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
29     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
30   }
31   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
32   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
33   ierr = CeedFree(&impl->edata); CeedChk(ierr);
34 
35   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
36     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
37   }
38   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
39   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
40 
41   ierr = CeedFree(&impl->indata); CeedChk(ierr);
42   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
43 
44   ierr = CeedFree(&op->data); CeedChk(ierr);
45   return 0;
46 }
47 
48 /*
49   Setup infields or outfields
50  */
51 static int CeedOperatorSetupFields_Blocked(struct CeedQFunctionField qfields[16],
52                                        struct CeedOperatorField ofields[16],
53                                        CeedElemRestriction *blkrestr,
54                                        CeedVector *evecs, CeedScalar **qdata,
55                                        CeedScalar **qdata_alloc, CeedScalar **indata,
56                                        CeedInt starti, CeedInt startq,
57                                        CeedInt numfields, CeedInt Q) {
58   CeedInt dim, ierr, iq=startq, ncomp;
59   const CeedInt blksize = 8;
60 
61   // Loop over fields
62   for (CeedInt i=0; i<numfields; i++) {
63     CeedEvalMode emode = qfields[i].emode;
64 
65     if (emode != CEED_EVAL_WEIGHT) {
66       CeedElemRestriction r = ofields[i].Erestrict;
67       CeedElemRestriction_Ref *data = r->data;
68       Ceed ceed;
69       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
70       CeedInt nelem, elemsize, ndof, ncomp;
71       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
72       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
73       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
74       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
75       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
76                                               blksize, ndof, ncomp,
77                                               CEED_MEM_HOST, CEED_COPY_VALUES,
78                                               data->indices, &blkrestr[i+starti]);
79       CeedChk(ierr);
80       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starti], NULL,
81                                              &evecs[i+starti]);
82       CeedChk(ierr);
83     }
84 
85     switch(emode) {
86     case CEED_EVAL_NONE:
87       break; // No action
88     case CEED_EVAL_INTERP:
89       ncomp = qfields[i].ncomp;
90       ierr = CeedMalloc(Q*ncomp*blksize, &qdata_alloc[iq]); CeedChk(ierr);
91       qdata[i + starti] = qdata_alloc[iq];
92       iq++;
93       break;
94     case CEED_EVAL_GRAD:
95       ncomp = qfields[i].ncomp;
96       ierr = CeedBasisGetDimension(ofields[i].basis, &dim); CeedChk(ierr);
97       ierr = CeedMalloc(Q*ncomp*dim*blksize, &qdata_alloc[iq]); CeedChk(ierr);
98       qdata[i + starti] = qdata_alloc[iq];
99       iq++;
100       break;
101     case CEED_EVAL_WEIGHT: // Only on input fields
102       ierr = CeedMalloc(Q*blksize, &qdata_alloc[iq]); CeedChk(ierr);
103       ierr = CeedBasisApply(ofields[iq].basis, blksize, CEED_NOTRANSPOSE,
104                             CEED_EVAL_WEIGHT, NULL, qdata_alloc[iq]); CeedChk(ierr);
105       qdata[i] = qdata_alloc[iq];
106       indata[i] = qdata[i];
107       iq++;
108       break;
109     case CEED_EVAL_DIV:
110       break; // Not implimented
111     case CEED_EVAL_CURL:
112       break; // Not implimented
113     }
114   }
115   return 0;
116 }
117 
118 /*
119   CeedOperator needs to connect all the named fields (be they active or passive)
120   to the named inputs and outputs of its CeedQFunction.
121  */
122 static int CeedOperatorSetup_Blocked(CeedOperator op) {
123   int ierr;
124   bool setupdone;
125   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
126   if (setupdone) return 0;
127   CeedOperator_Blocked *impl;
128   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
129   CeedQFunction qf;
130   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
131   CeedInt Q, numinputfields, numoutputfields;
132   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
133   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
134   CeedChk(ierr);
135 
136   // Count infield and outfield array sizes and evectors
137   impl->numein = numinputfields;
138   for (CeedInt i=0; i<numinputfields; i++) {
139     CeedEvalMode emode = qf->inputfields[i].emode;
140     impl->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) +
141                     !!(emode & CEED_EVAL_WEIGHT);
142   }
143   impl->numeout = numoutputfields;
144   for (CeedInt i=0; i<numoutputfields; i++) {
145     CeedEvalMode emode = qf->outputfields[i].emode;
146     impl->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
147   }
148 
149   // Allocate
150   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->blkrestr);
151   CeedChk(ierr);
152   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->evecs);
153   CeedChk(ierr);
154   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->edata);
155   CeedChk(ierr);
156 
157   ierr = CeedCalloc(impl->numqin + impl->numqout, &impl->qdata_alloc);
158   CeedChk(ierr);
159   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->qdata);
160   CeedChk(ierr);
161 
162   ierr = CeedCalloc(16, &impl->indata); CeedChk(ierr);
163   ierr = CeedCalloc(16, &impl->outdata); CeedChk(ierr);
164   // Set up infield and outfield pointer arrays
165   // Infields
166   ierr = CeedOperatorSetupFields_Blocked(qf->inputfields, op->inputfields,
167                                      impl->blkrestr, impl->evecs,
168                                      impl->qdata, impl->qdata_alloc,
169                                      impl->indata, 0,
170                                      0, numinputfields, Q);
171   CeedChk(ierr);
172   // Outfields
173   ierr = CeedOperatorSetupFields_Blocked(qf->outputfields, op->outputfields,
174                                      impl->blkrestr, impl->evecs,
175                                      impl->qdata, impl->qdata_alloc,
176                                      impl->indata, numinputfields,
177                                      impl->numqin, numoutputfields, Q);
178   CeedChk(ierr);
179   // Input Qvecs
180   for (CeedInt i=0; i<numinputfields; i++) {
181     CeedEvalMode emode = qf->inputfields[i].emode;
182     if ((emode != CEED_EVAL_NONE) && (emode != CEED_EVAL_WEIGHT))
183       impl->indata[i] =  impl->qdata[i];
184   }
185   // Output Qvecs
186   for (CeedInt i=0; i<numoutputfields; i++) {
187     CeedEvalMode emode = qf->outputfields[i].emode;
188     if (emode != CEED_EVAL_NONE)
189       impl->outdata[i] =  impl->qdata[i + numinputfields];
190   }
191 
192   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
193 
194   return 0;
195 }
196 
197 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
198                                  CeedVector outvec, CeedRequest *request) {
199   int ierr;
200   CeedOperator_Blocked *impl;
201   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
202   const CeedInt blksize = 8;
203   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements;
204   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
205   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
206   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
207   CeedQFunction qf;
208   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
209   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
210   CeedChk(ierr);
211   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
212 
213   // Setup
214   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
215 
216   // Input Evecs and Restriction
217   for (CeedInt i=0; i<numinputfields; i++) {
218     CeedEvalMode emode = qf->inputfields[i].emode;
219     if (emode == CEED_EVAL_WEIGHT) { // Skip
220     } else {
221       // Active
222       // Restrict
223       if (op->inputfields[i].vec == CEED_VECTOR_ACTIVE) {
224         ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
225                                         lmode, invec, impl->evecs[i],
226                                         request); CeedChk(ierr); CeedChk(ierr);
227       } else {
228         // Passive
229         // Restrict
230         ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
231                                         lmode, op->inputfields[i].vec, impl->evecs[i],
232                                         request); CeedChk(ierr);
233       }
234       // Get evec
235       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
236                                     (const CeedScalar **) &impl->edata[i]);
237       CeedChk(ierr);
238     }
239   }
240 
241   // Output Evecs
242   for (CeedInt i=0; i<numoutputfields; i++) {
243     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
244                               &impl->edata[i + numinputfields]); CeedChk(ierr);
245   }
246 
247   // Loop through elements
248   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
249     // Input basis apply if needed
250     for (CeedInt i=0; i<numinputfields; i++) {
251       // Get elemsize, emode, ncomp
252       ierr = CeedElemRestrictionGetElementSize(
253                op->inputfields[i].Erestrict, &elemsize); CeedChk(ierr);
254       CeedEvalMode emode = qf->inputfields[i].emode;
255       CeedInt ncomp = qf->inputfields[i].ncomp;
256       // Basis action
257       switch(emode) {
258       case CEED_EVAL_NONE:
259         impl->indata[i] = &impl->edata[i][e*Q*ncomp];
260         break;
261       case CEED_EVAL_INTERP:
262         ierr = CeedBasisApply(op->inputfields[i].basis, blksize, CEED_NOTRANSPOSE,
263                               CEED_EVAL_INTERP, &impl->edata[i][e*elemsize*ncomp], impl->qdata[i]);
264         CeedChk(ierr);
265         break;
266       case CEED_EVAL_GRAD:
267         ierr = CeedBasisApply(op->inputfields[i].basis, blksize, CEED_NOTRANSPOSE,
268                               CEED_EVAL_GRAD, &impl->edata[i][e*elemsize*ncomp], impl->qdata[i]);
269         CeedChk(ierr);
270         break;
271       case CEED_EVAL_WEIGHT:
272         break;  // No action
273       case CEED_EVAL_DIV:
274         break; // Not implimented
275       case CEED_EVAL_CURL:
276         break; // Not implimented
277       }
278     }
279 
280     // Output pointers
281     for (CeedInt i=0; i<numoutputfields; i++) {
282       CeedEvalMode emode = qf->outputfields[i].emode;
283       if (emode == CEED_EVAL_NONE) {
284         CeedInt ncomp = qf->outputfields[i].ncomp;
285         impl->outdata[i] = &impl->edata[i + numinputfields][e*Q*ncomp];
286       }
287     }
288     // Q function
289     ierr = CeedQFunctionApply(qf, Q*blksize,
290                               (const CeedScalar * const*) impl->indata,
291                               impl->outdata); CeedChk(ierr);
292 
293     // Output basis apply if needed
294     for (CeedInt i=0; i<numoutputfields; i++) {
295       // Get elemsize, emode, ncomp
296       ierr = CeedElemRestrictionGetElementSize(
297                op->outputfields[i].Erestrict, &elemsize); CeedChk(ierr);
298       CeedInt ncomp = qf->outputfields[i].ncomp;
299       CeedEvalMode emode = qf->outputfields[i].emode;
300       // Basis action
301       switch(emode) {
302       case CEED_EVAL_NONE:
303         break; // No action
304       case CEED_EVAL_INTERP:
305         ierr = CeedBasisApply(op->outputfields[i].basis, blksize, CEED_TRANSPOSE,
306                               CEED_EVAL_INTERP, impl->outdata[i],
307                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
308         CeedChk(ierr);
309         break;
310       case CEED_EVAL_GRAD:
311         ierr = CeedBasisApply(op->outputfields[i].basis, blksize, CEED_TRANSPOSE,
312                               CEED_EVAL_GRAD,
313                               impl->outdata[i], &impl->edata[i + numinputfields][e*elemsize*ncomp]);
314         CeedChk(ierr);
315         break;
316       case CEED_EVAL_WEIGHT: {
317         Ceed ceed;
318         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
319         return CeedError(ceed, 1,
320                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
321         break; // Should not occur
322       }
323       case CEED_EVAL_DIV:
324         break; // Not implimented
325       case CEED_EVAL_CURL:
326         break; // Not implimented
327       }
328     }
329   }
330 
331   // Zero lvecs
332   ierr = CeedVectorSetValue(outvec, 0.0); CeedChk(ierr);
333   for (CeedInt i=0; i<numoutputfields; i++)
334     if (op->outputfields[i].vec != CEED_VECTOR_ACTIVE) {
335       ierr = CeedVectorSetValue(op->outputfields[i].vec, 0.0); CeedChk(ierr);
336     }
337 
338   // Output restriction
339   for (CeedInt i=0; i<numoutputfields; i++) {
340     // Restore evec
341     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
342                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
343     // Active
344     if (op->outputfields[i].vec == CEED_VECTOR_ACTIVE) {
345       // Restrict
346       ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
347                                       lmode, impl->evecs[i+impl->numein], outvec, request); CeedChk(ierr);
348     } else {
349       // Passive
350       // Restrict
351       ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
352                                       lmode, impl->evecs[i+impl->numein], op->outputfields[i].vec,
353                                       request); CeedChk(ierr);
354     }
355   }
356 
357   // Restore input arrays
358   for (CeedInt i=0; i<numinputfields; i++) {
359     CeedEvalMode emode = qf->inputfields[i].emode;
360     if (emode == CEED_EVAL_WEIGHT) { // Skip
361     } else {
362       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
363                                         (const CeedScalar **) &impl->edata[i]); CeedChk(ierr);
364     }
365   }
366 
367   return 0;
368 }
369 
370 int CeedOperatorCreate_Blocked(CeedOperator op) {
371   int ierr;
372   CeedOperator_Blocked *impl;
373 
374   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
375   op->data = impl;
376   op->Destroy = CeedOperatorDestroy_Blocked;
377   op->Apply = CeedOperatorApply_Blocked;
378   return 0;
379 }
380