xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision fe2413ff2a5f6e0d0808f191532abb277c4b8bc7)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-blocked.h"
19 #include "../ref/ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
22   int ierr;
23   CeedOperator_Blocked *impl;
24   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
25 
26   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
27     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
28     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
29   }
30   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
31   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
32   ierr = CeedFree(&impl->edata); CeedChk(ierr);
33 
34   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
35     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
36   }
37   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
38   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
39 
40   ierr = CeedFree(&impl->indata); CeedChk(ierr);
41   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
42 
43   ierr = CeedFree(&impl); CeedChk(ierr);
44   return 0;
45 }
46 
47 /*
48   Setup infields or outfields
49  */
50 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op,
51                                        bool inOrOut,
52                                        CeedElemRestriction *blkrestr,
53                                        CeedVector *evecs, CeedScalar **qdata,
54                                        CeedScalar **qdata_alloc, CeedScalar **indata,
55                                        CeedInt starti, CeedInt startq,
56                                        CeedInt numfields, CeedInt Q) {
57   CeedInt dim, ierr, iq=startq, ncomp;
58   CeedBasis basis;
59   CeedElemRestriction r;
60   CeedOperatorField *ofields;
61   CeedQFunctionField *qfields;
62   if (inOrOut) {
63     ierr = CeedOperatorGetFields(op, NULL, &ofields);
64     CeedChk(ierr);
65     ierr = CeedQFunctionGetFields(qf, NULL, &qfields);
66     CeedChk(ierr);
67   } else {
68     ierr = CeedOperatorGetFields(op, &ofields, NULL);
69     CeedChk(ierr);
70     ierr = CeedQFunctionGetFields(qf, &qfields, NULL);
71     CeedChk(ierr);
72   }
73   const CeedInt blksize = 8;
74 
75   // Loop over fields
76   for (CeedInt i=0; i<numfields; i++) {
77     CeedEvalMode emode;
78     ierr = CeedQFunctionFieldGetEvalMode(qfields[i], &emode); CeedChk(ierr);
79 
80     if (emode != CEED_EVAL_WEIGHT) {
81       ierr = CeedOperatorFieldGetElemRestriction(ofields[i], &r);
82       CeedChk(ierr);
83       CeedElemRestriction_Ref *data;
84       ierr = CeedElemRestrictionGetData(r, (void *)&data);
85       Ceed ceed;
86       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
87       CeedInt nelem, elemsize, ndof, ncomp;
88       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
89       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
90       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
91       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
92       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
93                                               blksize, ndof, ncomp,
94                                               CEED_MEM_HOST, CEED_COPY_VALUES,
95                                               data->indices, &blkrestr[i+starti]);
96       CeedChk(ierr);
97       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starti], NULL,
98                                              &evecs[i+starti]);
99       CeedChk(ierr);
100     }
101 
102     switch(emode) {
103     case CEED_EVAL_NONE:
104       break; // No action
105     case CEED_EVAL_INTERP:
106       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
107       CeedChk(ierr);
108       ierr = CeedMalloc(Q*ncomp*blksize, &qdata_alloc[iq]); CeedChk(ierr);
109       qdata[i + starti] = qdata_alloc[iq];
110       iq++;
111       break;
112     case CEED_EVAL_GRAD:
113       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
114       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
115       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
116       ierr = CeedMalloc(Q*ncomp*dim*blksize, &qdata_alloc[iq]); CeedChk(ierr);
117       qdata[i + starti] = qdata_alloc[iq];
118       iq++;
119       break;
120     case CEED_EVAL_WEIGHT: // Only on input fields
121       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
122       ierr = CeedMalloc(Q*blksize, &qdata_alloc[iq]); CeedChk(ierr);
123       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
124                             CEED_EVAL_WEIGHT, NULL, qdata_alloc[iq]); CeedChk(ierr);
125       qdata[i] = qdata_alloc[iq];
126       indata[i] = qdata[i];
127       iq++;
128       break;
129     case CEED_EVAL_DIV:
130       break; // Not implimented
131     case CEED_EVAL_CURL:
132       break; // Not implimented
133     }
134   }
135   return 0;
136 }
137 
138 /*
139   CeedOperator needs to connect all the named fields (be they active or passive)
140   to the named inputs and outputs of its CeedQFunction.
141  */
142 static int CeedOperatorSetup_Blocked(CeedOperator op) {
143   int ierr;
144   bool setupdone;
145   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
146   if (setupdone) return 0;
147   CeedOperator_Blocked *impl;
148   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
149   CeedQFunction qf;
150   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
151   CeedInt Q, numinputfields, numoutputfields;
152   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
153   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
154   CeedChk(ierr);
155   CeedOperatorField *opinputfields, *opoutputfields;
156   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
157   CeedChk(ierr);
158   CeedQFunctionField *qfinputfields, *qfoutputfields;
159   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
160   CeedChk(ierr);
161   CeedEvalMode emode;
162 
163   // Count infield and outfield array sizes and evectors
164   impl->numein = numinputfields;
165   for (CeedInt i=0; i<numinputfields; i++) {
166     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
167     CeedChk(ierr);
168     impl->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) +
169                     !!(emode & CEED_EVAL_WEIGHT);
170   }
171   impl->numeout = numoutputfields;
172   for (CeedInt i=0; i<numoutputfields; i++) {
173     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
174     CeedChk(ierr);
175     impl->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
176   }
177 
178   // Allocate
179   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->blkrestr);
180   CeedChk(ierr);
181   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->evecs);
182   CeedChk(ierr);
183   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->edata);
184   CeedChk(ierr);
185 
186   ierr = CeedCalloc(impl->numqin + impl->numqout, &impl->qdata_alloc);
187   CeedChk(ierr);
188   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->qdata);
189   CeedChk(ierr);
190 
191   ierr = CeedCalloc(16, &impl->indata); CeedChk(ierr);
192   ierr = CeedCalloc(16, &impl->outdata); CeedChk(ierr);
193   // Set up infield and outfield pointer arrays
194   // Infields
195   ierr = CeedOperatorSetupFields_Blocked(qf, op, 0,
196                                      impl->blkrestr, impl->evecs,
197                                      impl->qdata, impl->qdata_alloc,
198                                      impl->indata, 0,
199                                      0, numinputfields, Q);
200   CeedChk(ierr);
201   // Outfields
202   ierr = CeedOperatorSetupFields_Blocked(qf, op, 1,
203                                      impl->blkrestr, impl->evecs,
204                                      impl->qdata, impl->qdata_alloc,
205                                      impl->indata, numinputfields,
206                                      impl->numqin, numoutputfields, Q);
207   CeedChk(ierr);
208   // Input Qvecs
209   for (CeedInt i=0; i<numinputfields; i++) {
210     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
211     CeedChk(ierr);
212     if ((emode != CEED_EVAL_NONE) && (emode != CEED_EVAL_WEIGHT))
213       impl->indata[i] =  impl->qdata[i];
214   }
215   // Output Qvecs
216   for (CeedInt i=0; i<numoutputfields; i++) {
217     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
218     CeedChk(ierr);
219     if (emode != CEED_EVAL_NONE)
220       impl->outdata[i] =  impl->qdata[i + numinputfields];
221   }
222 
223   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
224 
225   return 0;
226 }
227 
228 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
229                                  CeedVector outvec, CeedRequest *request) {
230   int ierr;
231   CeedOperator_Blocked *impl;
232   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
233   const CeedInt blksize = 8;
234   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, ncomp;
235   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
236   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
237   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
238   CeedQFunction qf;
239   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
240   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
241   CeedChk(ierr);
242   CeedTransposeMode lmode;
243   CeedOperatorField *opinputfields, *opoutputfields;
244   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
245   CeedChk(ierr);
246   CeedQFunctionField *qfinputfields, *qfoutputfields;
247   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
248   CeedChk(ierr);
249   CeedEvalMode emode;
250   CeedVector vec;
251   CeedBasis basis;
252   CeedElemRestriction Erestrict;
253 
254   // Setup
255   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
256 
257   // Input Evecs and Restriction
258   for (CeedInt i=0; i<numinputfields; i++) {
259     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
260     CeedChk(ierr);
261     if (emode == CEED_EVAL_WEIGHT) { // Skip
262     } else {
263       // Get input vector
264       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
265       if (vec == CEED_VECTOR_ACTIVE)
266         vec = invec;
267       // Restrict
268       ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
269       ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
270                                       lmode, vec, impl->evecs[i],
271                                       request); CeedChk(ierr); CeedChk(ierr);
272       // Get evec
273       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
274                                     (const CeedScalar **) &impl->edata[i]);
275       CeedChk(ierr);
276     }
277   }
278 
279   // Output Evecs
280   for (CeedInt i=0; i<numoutputfields; i++) {
281     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
282                               &impl->edata[i + numinputfields]); CeedChk(ierr);
283   }
284 
285   // Loop through elements
286   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
287     // Input basis apply if needed
288     for (CeedInt i=0; i<numinputfields; i++) {
289       // Get elemsize, emode, ncomp
290       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
291       CeedChk(ierr);
292       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
293       CeedChk(ierr);
294       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
295       CeedChk(ierr);
296       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
297       CeedChk(ierr);
298       // Basis action
299       switch(emode) {
300       case CEED_EVAL_NONE:
301         impl->indata[i] = &impl->edata[i][e*Q*ncomp];
302         break;
303       case CEED_EVAL_INTERP:
304         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
305         CeedChk(ierr);
306         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
307                               CEED_EVAL_INTERP, &impl->edata[i][e*elemsize*ncomp],
308                               impl->qdata[i]); CeedChk(ierr);
309         break;
310       case CEED_EVAL_GRAD:
311         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
312         CeedChk(ierr);
313         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
314                               CEED_EVAL_GRAD, &impl->edata[i][e*elemsize*ncomp],
315                               impl->qdata[i]); CeedChk(ierr);
316         break;
317       case CEED_EVAL_WEIGHT:
318         break;  // No action
319       case CEED_EVAL_DIV:
320         break; // Not implimented
321       case CEED_EVAL_CURL:
322         break; // Not implimented
323       }
324     }
325 
326     // Output pointers
327     for (CeedInt i=0; i<numoutputfields; i++) {
328       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
329       CeedChk(ierr);
330       if (emode == CEED_EVAL_NONE) {
331         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
332         CeedChk(ierr);
333         impl->outdata[i] = &impl->edata[i + numinputfields][e*Q*ncomp];
334       }
335     }
336     // Q function
337     ierr = CeedQFunctionApply(qf, Q*blksize,
338                               (const CeedScalar * const*) impl->indata,
339                               impl->outdata); CeedChk(ierr);
340 
341     // Output basis apply if needed
342     for (CeedInt i=0; i<numoutputfields; i++) {
343       // Get elemsize, emode, ncomp
344       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
345       CeedChk(ierr);
346       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
347       CeedChk(ierr);
348       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
349       CeedChk(ierr);
350       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
351       CeedChk(ierr);
352       // Basis action
353       switch(emode) {
354       case CEED_EVAL_NONE:
355         break; // No action
356       case CEED_EVAL_INTERP:
357         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
358         CeedChk(ierr);
359         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
360                               CEED_EVAL_INTERP, impl->outdata[i],
361                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
362         CeedChk(ierr);
363         break;
364       case CEED_EVAL_GRAD:
365         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
366         CeedChk(ierr);
367         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
368                               CEED_EVAL_GRAD,
369                               impl->outdata[i], &impl->edata[i + numinputfields][e*elemsize*ncomp]);
370         CeedChk(ierr);
371         break;
372       case CEED_EVAL_WEIGHT: {
373         Ceed ceed;
374         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
375         return CeedError(ceed, 1,
376                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
377         break; // Should not occur
378       }
379       case CEED_EVAL_DIV:
380         break; // Not implimented
381       case CEED_EVAL_CURL:
382         break; // Not implimented
383       }
384     }
385   }
386 
387   // Zero lvecs
388   for (CeedInt i=0; i<numoutputfields; i++) {
389     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
390     if (vec == CEED_VECTOR_ACTIVE)
391       vec = outvec;
392     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
393     }
394 
395   // Output restriction
396   for (CeedInt i=0; i<numoutputfields; i++) {
397     // Restore evec
398     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
399                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
400     // Get output vector
401     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
402     // Active
403     if (vec == CEED_VECTOR_ACTIVE)
404       vec = outvec;
405     // Restrict
406     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
407     ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
408                                       lmode, impl->evecs[i+impl->numein], vec,
409                                       request); CeedChk(ierr);
410 
411   }
412 
413   // Restore input arrays
414   for (CeedInt i=0; i<numinputfields; i++) {
415     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
416     CeedChk(ierr);
417     if (emode == CEED_EVAL_WEIGHT) { // Skip
418     } else {
419       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
420                                         (const CeedScalar **) &impl->edata[i]);
421       CeedChk(ierr);
422     }
423   }
424 
425   return 0;
426 }
427 
428 int CeedOperatorCreate_Blocked(CeedOperator op) {
429   int ierr;
430   Ceed ceed;
431   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
432   CeedOperator_Blocked *impl;
433 
434   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
435   ierr = CeedOperatorSetData(op, (void *)&impl);
436 
437   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
438                                 CeedOperatorApply_Blocked); CeedChk(ierr);
439   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
440                                 CeedOperatorDestroy_Blocked); CeedChk(ierr);
441   return 0;
442 }
443