xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision 0219ea01e2c00bd70a330a05b50ef0218d6ddcb0)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-blocked.h"
18 #include "../ref/ceed-ref.h"
19 
20 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
21   int ierr;
22   CeedOperator_Blocked *impl;
23   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
27     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
28   }
29   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
30   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
31   ierr = CeedFree(&impl->edata); CeedChk(ierr);
32   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
33 
34   for (CeedInt i=0; i<impl->numein; i++) {
35     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
36     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
37   }
38   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
39   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
40 
41   for (CeedInt i=0; i<impl->numeout; i++) {
42     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
43     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
44   }
45   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
46   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
47 
48   ierr = CeedFree(&impl); CeedChk(ierr);
49   return 0;
50 }
51 
52 /*
53   Setup infields or outfields
54  */
55 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf,
56     CeedOperator op, bool inOrOut,
57     CeedElemRestriction *blkrestr,
58     CeedVector *fullevecs, CeedVector *evecs,
59     CeedVector *qvecs, CeedInt starte,
60     CeedInt numfields, CeedInt Q) {
61   CeedInt dim, ierr, ncomp, size, P;
62   Ceed ceed;
63   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
64   CeedBasis basis;
65   CeedElemRestriction r;
66   CeedOperatorField *opfields;
67   CeedQFunctionField *qffields;
68   if (inOrOut) {
69     ierr = CeedOperatorGetFields(op, NULL, &opfields);
70     CeedChk(ierr);
71     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
72     CeedChk(ierr);
73   } else {
74     ierr = CeedOperatorGetFields(op, &opfields, NULL);
75     CeedChk(ierr);
76     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
77     CeedChk(ierr);
78   }
79   const CeedInt blksize = 8;
80 
81   // Loop over fields
82   for (CeedInt i=0; i<numfields; i++) {
83     CeedEvalMode emode;
84     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
85 
86     if (emode != CEED_EVAL_WEIGHT) {
87       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &r);
88       CeedChk(ierr);
89       CeedElemRestriction_Ref *data;
90       ierr = CeedElemRestrictionGetData(r, (void *)&data); CeedChk(ierr);
91       Ceed ceed;
92       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
93       CeedInt nelem, elemsize, nnodes;
94       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
95       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
96       ierr = CeedElemRestrictionGetNumNodes(r, &nnodes); CeedChk(ierr);
97       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
98       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
99                                               blksize, nnodes, ncomp,
100                                               CEED_MEM_HOST, CEED_COPY_VALUES,
101                                               data->indices, &blkrestr[i+starte]);
102       CeedChk(ierr);
103       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starte], NULL,
104                                              &fullevecs[i+starte]);
105       CeedChk(ierr);
106     }
107 
108     switch(emode) {
109     case CEED_EVAL_NONE:
110       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChk(ierr);
111       ierr = CeedVectorCreate(ceed, Q*size*blksize, &qvecs[i]); CeedChk(ierr);
112       break;
113     case CEED_EVAL_INTERP:
114       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChk(ierr);
115       ierr = CeedElemRestrictionGetElementSize(r, &P);
116       CeedChk(ierr);
117       ierr = CeedVectorCreate(ceed, P*size*blksize, &evecs[i]); CeedChk(ierr);
118       ierr = CeedVectorCreate(ceed, Q*size*blksize, &qvecs[i]); CeedChk(ierr);
119       break;
120     case CEED_EVAL_GRAD:
121       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
122       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChk(ierr);
123       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
124       ierr = CeedElemRestrictionGetElementSize(r, &P);
125       CeedChk(ierr);
126       ierr = CeedVectorCreate(ceed, P*size/dim*blksize, &evecs[i]); CeedChk(ierr);
127       ierr = CeedVectorCreate(ceed, Q*size*blksize, &qvecs[i]); CeedChk(ierr);
128       break;
129     case CEED_EVAL_WEIGHT: // Only on input fields
130       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
131       ierr = CeedVectorCreate(ceed, Q*blksize, &qvecs[i]); CeedChk(ierr);
132       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
133                             CEED_EVAL_WEIGHT, NULL, qvecs[i]); CeedChk(ierr);
134 
135       break;
136     case CEED_EVAL_DIV:
137       break; // Not implemented
138     case CEED_EVAL_CURL:
139       break; // Not implemented
140     }
141   }
142   return 0;
143 }
144 
145 /*
146   CeedOperator needs to connect all the named fields (be they active or passive)
147   to the named inputs and outputs of its CeedQFunction.
148  */
149 static int CeedOperatorSetup_Blocked(CeedOperator op) {
150   int ierr;
151   bool setupdone;
152   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
153   if (setupdone) return 0;
154   Ceed ceed;
155   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
156   CeedOperator_Blocked *impl;
157   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
158   CeedQFunction qf;
159   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
160   CeedInt Q, numinputfields, numoutputfields;
161   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
162   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
163   CeedChk(ierr);
164   CeedOperatorField *opinputfields, *opoutputfields;
165   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
166   CeedChk(ierr);
167   CeedQFunctionField *qfinputfields, *qfoutputfields;
168   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
169   CeedChk(ierr);
170 
171   // Allocate
172   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->blkrestr);
173   CeedChk(ierr);
174   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
175   CeedChk(ierr);
176   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
177   CeedChk(ierr);
178 
179   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
180   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
181   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
182   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
183   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
184 
185   impl->numein = numinputfields; impl->numeout = numoutputfields;
186 
187   // Set up infield and outfield pointer arrays
188   // Infields
189   ierr = CeedOperatorSetupFields_Blocked(qf, op, 0, impl->blkrestr,
190                                          impl->evecs, impl->evecsin,
191                                          impl->qvecsin, 0,
192                                          numinputfields, Q);
193   CeedChk(ierr);
194   // Outfields
195   ierr = CeedOperatorSetupFields_Blocked(qf, op, 1, impl->blkrestr,
196                                          impl->evecs, impl->evecsout,
197                                          impl->qvecsout, numinputfields,
198                                          numoutputfields, Q);
199   CeedChk(ierr);
200 
201   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
202 
203   return 0;
204 }
205 
206 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
207                                      CeedVector outvec,
208                                      CeedRequest *request) {
209   int ierr;
210   CeedOperator_Blocked *impl;
211   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
212   const CeedInt blksize = 8;
213   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, size, dim;
214   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
215   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
216   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
217   CeedQFunction qf;
218   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
219   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
220   CeedChk(ierr);
221   CeedTransposeMode lmode;
222   CeedOperatorField *opinputfields, *opoutputfields;
223   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
224   CeedChk(ierr);
225   CeedQFunctionField *qfinputfields, *qfoutputfields;
226   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
227   CeedChk(ierr);
228   CeedEvalMode emode;
229   CeedVector vec;
230   CeedBasis basis;
231   CeedElemRestriction Erestrict;
232   uint64_t state;
233 
234   // Setup
235   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
236 
237   // Input Evecs and Restriction
238   for (CeedInt i=0; i<numinputfields; i++) {
239     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
240     CeedChk(ierr);
241     if (emode == CEED_EVAL_WEIGHT) { // Skip
242     } else {
243       // Get input vector
244       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
245       if (vec == CEED_VECTOR_ACTIVE)
246         vec = invec;
247       // Restrict
248       ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
249       if (state != impl->inputstate[i] || vec == invec) {
250         ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
251         ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
252                                         lmode, vec, impl->evecs[i],
253                                         request); CeedChk(ierr); CeedChk(ierr);
254         impl->inputstate[i] = state;
255       }
256       // Get evec
257       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
258                                     (const CeedScalar **) &impl->edata[i]);
259       CeedChk(ierr);
260     }
261   }
262 
263   // Output Evecs
264   for (CeedInt i=0; i<numoutputfields; i++) {
265     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
266                               &impl->edata[i + numinputfields]); CeedChk(ierr);
267   }
268 
269   // Loop through elements
270   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
271     // Input basis apply if needed
272     for (CeedInt i=0; i<numinputfields; i++) {
273       // Get elemsize, emode, size
274       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
275       CeedChk(ierr);
276       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
277       CeedChk(ierr);
278       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
279       CeedChk(ierr);
280       ierr = CeedQFunctionFieldGetSize(qfinputfields[i], &size); CeedChk(ierr);
281       // Basis action
282       switch(emode) {
283       case CEED_EVAL_NONE:
284         ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
285                                   CEED_USE_POINTER,
286                                   &impl->edata[i][e*Q*size]); CeedChk(ierr);
287         break;
288       case CEED_EVAL_INTERP:
289         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
290         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
291                                   CEED_USE_POINTER,
292                                   &impl->edata[i][e*elemsize*size]);
293         CeedChk(ierr);
294         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
295                               CEED_EVAL_INTERP, impl->evecsin[i],
296                               impl->qvecsin[i]); CeedChk(ierr);
297         break;
298       case CEED_EVAL_GRAD:
299         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
300         ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
301         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
302                                   CEED_USE_POINTER,
303                                   &impl->edata[i][e*elemsize*size/dim]);
304         CeedChk(ierr);
305         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
306                               CEED_EVAL_GRAD, impl->evecsin[i],
307                               impl->qvecsin[i]); CeedChk(ierr);
308         break;
309       case CEED_EVAL_WEIGHT:
310         break;  // No action
311       case CEED_EVAL_DIV:
312         break; // Not implemented
313       case CEED_EVAL_CURL:
314         break; // Not implemented
315       }
316     }
317 
318     // Output pointers
319     for (CeedInt i=0; i<numoutputfields; i++) {
320       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
321       CeedChk(ierr);
322       if (emode == CEED_EVAL_NONE) {
323         ierr = CeedQFunctionFieldGetSize(qfoutputfields[i], &size);
324         CeedChk(ierr);
325         ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
326                                   CEED_USE_POINTER,
327                                   &impl->edata[i + numinputfields][e*Q*size]);
328         CeedChk(ierr);
329       }
330     }
331     // Q function
332     ierr = CeedQFunctionApply(qf, Q*blksize, impl->qvecsin, impl->qvecsout);
333     CeedChk(ierr);
334 
335     // Output basis apply if needed
336     for (CeedInt i=0; i<numoutputfields; i++) {
337       // Get elemsize, emode, size
338       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
339       CeedChk(ierr);
340       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
341       CeedChk(ierr);
342       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
343       CeedChk(ierr);
344       ierr = CeedQFunctionFieldGetSize(qfoutputfields[i], &size); CeedChk(ierr);
345       // Basis action
346       switch(emode) {
347       case CEED_EVAL_NONE:
348         break; // No action
349       case CEED_EVAL_INTERP:
350         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
351         CeedChk(ierr);
352         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
353                                   CEED_USE_POINTER,
354                                   &impl->edata[i + numinputfields][e*elemsize*size]);
355         CeedChk(ierr);
356         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
357                               CEED_EVAL_INTERP, impl->qvecsout[i],
358                               impl->evecsout[i]); CeedChk(ierr);
359         break;
360       case CEED_EVAL_GRAD:
361         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
362         CeedChk(ierr);
363         ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
364         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
365                                   CEED_USE_POINTER,
366                                   &impl->edata[i + numinputfields][e*elemsize*size/dim]);
367         CeedChk(ierr);
368         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
369                               CEED_EVAL_GRAD, impl->qvecsout[i],
370                               impl->evecsout[i]); CeedChk(ierr);
371         break;
372       case CEED_EVAL_WEIGHT: {
373         Ceed ceed;
374         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
375         return CeedError(ceed, 1,
376                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
377         break; // Should not occur
378       }
379       case CEED_EVAL_DIV:
380         break; // Not implemented
381       case CEED_EVAL_CURL:
382         break; // Not implemented
383       }
384     }
385   }
386 
387   // Zero lvecs
388   for (CeedInt i=0; i<numoutputfields; i++) {
389     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
390     if (vec == CEED_VECTOR_ACTIVE) {
391       if (!impl->add) {
392         vec = outvec;
393         ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
394       }
395     } else {
396       ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
397     }
398   }
399   impl->add = false;
400 
401   // Output restriction
402   for (CeedInt i=0; i<numoutputfields; i++) {
403     // Restore evec
404     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
405                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
406     // Get output vector
407     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
408     // Active
409     if (vec == CEED_VECTOR_ACTIVE)
410       vec = outvec;
411     // Restrict
412     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
413     ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
414                                     lmode, impl->evecs[i+impl->numein], vec,
415                                     request); CeedChk(ierr);
416 
417   }
418 
419   // Restore input arrays
420   for (CeedInt i=0; i<numinputfields; i++) {
421     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
422     CeedChk(ierr);
423     if (emode == CEED_EVAL_WEIGHT) { // Skip
424     } else {
425       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
426                                         (const CeedScalar **) &impl->edata[i]);
427       CeedChk(ierr);
428     }
429   }
430 
431   return 0;
432 }
433 
434 int CeedOperatorCreate_Blocked(CeedOperator op) {
435   int ierr;
436   Ceed ceed;
437   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
438   CeedOperator_Blocked *impl;
439 
440   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
441   ierr = CeedOperatorSetData(op, (void *)&impl); CeedChk(ierr);
442 
443   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
444                                 CeedOperatorApply_Blocked); CeedChk(ierr);
445   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
446                                 CeedOperatorDestroy_Blocked); CeedChk(ierr);
447   return 0;
448 }
449