xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision 9ad453579b4f8e35becf925ceedf519a25efe885)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-blocked.h"
18 #include "../ref/ceed-ref.h"
19 
20 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
21   int ierr;
22   CeedOperator_Blocked *impl;
23   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
27     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
28   }
29   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
30   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
31   ierr = CeedFree(&impl->edata); CeedChk(ierr);
32   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
33 
34   for (CeedInt i=0; i<impl->numein; i++) {
35     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
36     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
37   }
38   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
39   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
40 
41   for (CeedInt i=0; i<impl->numeout; i++) {
42     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
43     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
44   }
45   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
46   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
47 
48   ierr = CeedFree(&impl); CeedChk(ierr);
49   return 0;
50 }
51 
52 /*
53   Setup infields or outfields
54  */
55 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op,
56     bool inOrOut,
57     CeedElemRestriction *blkrestr,
58     CeedVector *fullevecs, CeedVector *evecs,
59     CeedVector *qvecs, CeedInt starte,
60     CeedInt numfields, CeedInt Q) {
61   CeedInt dim, ierr, ncomp, P;
62   Ceed ceed;
63   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
64   CeedBasis basis;
65   CeedElemRestriction r;
66   CeedOperatorField *opfields;
67   CeedQFunctionField *qffields;
68   if (inOrOut) {
69     ierr = CeedOperatorGetFields(op, NULL, &opfields);
70     CeedChk(ierr);
71     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
72     CeedChk(ierr);
73   } else {
74     ierr = CeedOperatorGetFields(op, &opfields, NULL);
75     CeedChk(ierr);
76     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
77     CeedChk(ierr);
78   }
79   const CeedInt blksize = 8;
80 
81   // Loop over fields
82   for (CeedInt i=0; i<numfields; i++) {
83     CeedEvalMode emode;
84     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
85 
86     if (emode != CEED_EVAL_WEIGHT) {
87       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &r);
88       CeedChk(ierr);
89       CeedElemRestriction_Ref *data;
90       ierr = CeedElemRestrictionGetData(r, (void *)&data); CeedChk(ierr);
91       Ceed ceed;
92       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
93       CeedInt nelem, elemsize, ndof;
94       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
95       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
96       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
97       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
98       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
99                                               blksize, ndof, ncomp,
100                                               CEED_MEM_HOST, CEED_COPY_VALUES,
101                                               data->indices, &blkrestr[i+starte]);
102       CeedChk(ierr);
103       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starte], NULL,
104                                              &fullevecs[i+starte]);
105       CeedChk(ierr);
106     }
107 
108     switch(emode) {
109     case CEED_EVAL_NONE:
110       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
111       CeedChk(ierr);
112       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
113       break;
114     case CEED_EVAL_INTERP:
115       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
116       CeedChk(ierr);
117       ierr = CeedElemRestrictionGetElementSize(r, &P);
118       CeedChk(ierr);
119       ierr = CeedVectorCreate(ceed, P*ncomp*blksize, &evecs[i]); CeedChk(ierr);
120       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
121       break;
122     case CEED_EVAL_GRAD:
123       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
124       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
125       CeedChk(ierr);
126       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
127       ierr = CeedElemRestrictionGetElementSize(r, &P);
128       CeedChk(ierr);
129       ierr = CeedVectorCreate(ceed, P*ncomp*blksize, &evecs[i]); CeedChk(ierr);
130       ierr = CeedVectorCreate(ceed, Q*ncomp*dim*blksize, &qvecs[i]); CeedChk(ierr);
131       break;
132     case CEED_EVAL_WEIGHT: // Only on input fields
133       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
134       ierr = CeedVectorCreate(ceed, Q*blksize, &qvecs[i]); CeedChk(ierr);
135       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
136                             CEED_EVAL_WEIGHT, NULL, qvecs[i]); CeedChk(ierr);
137 
138       break;
139     case CEED_EVAL_DIV:
140       break; // Not implimented
141     case CEED_EVAL_CURL:
142       break; // Not implimented
143     }
144   }
145   return 0;
146 }
147 
148 /*
149   CeedOperator needs to connect all the named fields (be they active or passive)
150   to the named inputs and outputs of its CeedQFunction.
151  */
152 static int CeedOperatorSetup_Blocked(CeedOperator op) {
153   int ierr;
154   bool setupdone;
155   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
156   if (setupdone) return 0;
157   Ceed ceed;
158   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
159   CeedOperator_Blocked *impl;
160   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
161   CeedQFunction qf;
162   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
163   CeedInt Q, numinputfields, numoutputfields;
164   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
165   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
166   CeedChk(ierr);
167   CeedOperatorField *opinputfields, *opoutputfields;
168   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
169   CeedChk(ierr);
170   CeedQFunctionField *qfinputfields, *qfoutputfields;
171   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
172   CeedChk(ierr);
173 
174   // Allocate
175   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->blkrestr);
176   CeedChk(ierr);
177   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
178   CeedChk(ierr);
179   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
180   CeedChk(ierr);
181 
182   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
183   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
184   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
185   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
186   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
187 
188   impl->numein = numinputfields; impl->numeout = numoutputfields;
189 
190   // Set up infield and outfield pointer arrays
191   // Infields
192   ierr = CeedOperatorSetupFields_Blocked(qf, op, 0, impl->blkrestr,
193                                          impl->evecs, impl->evecsin,
194                                          impl->qvecsin, 0,
195                                          numinputfields, Q);
196   CeedChk(ierr);
197   // Outfields
198   ierr = CeedOperatorSetupFields_Blocked(qf, op, 1, impl->blkrestr,
199                                          impl->evecs, impl->evecsout,
200                                          impl->qvecsout, numinputfields,
201                                          numoutputfields, Q);
202   CeedChk(ierr);
203 
204   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
205 
206   return 0;
207 }
208 
209 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
210                                      CeedVector outvec, CeedRequest *request) {
211   int ierr;
212   CeedOperator_Blocked *impl;
213   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
214   const CeedInt blksize = 8;
215   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, ncomp;
216   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
217   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
218   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
219   CeedQFunction qf;
220   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
221   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
222   CeedChk(ierr);
223   CeedTransposeMode lmode;
224   CeedOperatorField *opinputfields, *opoutputfields;
225   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
226   CeedChk(ierr);
227   CeedQFunctionField *qfinputfields, *qfoutputfields;
228   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
229   CeedChk(ierr);
230   CeedEvalMode emode;
231   CeedVector vec;
232   CeedBasis basis;
233   CeedElemRestriction Erestrict;
234   uint64_t state;
235 
236   // Setup
237   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
238 
239   // Input Evecs and Restriction
240   for (CeedInt i=0; i<numinputfields; i++) {
241     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
242     CeedChk(ierr);
243     if (emode == CEED_EVAL_WEIGHT) { // Skip
244     } else {
245       // Get input vector
246       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
247       if (vec == CEED_VECTOR_ACTIVE)
248         vec = invec;
249       // Restrict
250       ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
251       if (state != impl->inputstate[i] || vec == invec) {
252         ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
253         ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
254                                         lmode, vec, impl->evecs[i],
255                                         request); CeedChk(ierr); CeedChk(ierr);
256         impl->inputstate[i] = state;
257       }
258       // Get evec
259       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
260                                     (const CeedScalar **) &impl->edata[i]);
261       CeedChk(ierr);
262     }
263   }
264 
265   // Output Evecs
266   for (CeedInt i=0; i<numoutputfields; i++) {
267     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
268                               &impl->edata[i + numinputfields]); CeedChk(ierr);
269   }
270 
271   // Loop through elements
272   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
273     // Input basis apply if needed
274     for (CeedInt i=0; i<numinputfields; i++) {
275       // Get elemsize, emode, ncomp
276       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
277       CeedChk(ierr);
278       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
279       CeedChk(ierr);
280       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
281       CeedChk(ierr);
282       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
283       CeedChk(ierr);
284       // Basis action
285       switch(emode) {
286       case CEED_EVAL_NONE:
287         ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
288                                   CEED_USE_POINTER,
289                                   &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
290         break;
291       case CEED_EVAL_INTERP:
292         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
293         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
294                                   CEED_USE_POINTER,
295                                   &impl->edata[i][e*elemsize*ncomp]);
296         CeedChk(ierr);
297         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
298                               CEED_EVAL_INTERP, impl->evecsin[i],
299                               impl->qvecsin[i]); CeedChk(ierr);
300         break;
301       case CEED_EVAL_GRAD:
302         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
303         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
304                                   CEED_USE_POINTER,
305                                   &impl->edata[i][e*elemsize*ncomp]);
306         CeedChk(ierr);
307         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
308                               CEED_EVAL_GRAD, impl->evecsin[i],
309                               impl->qvecsin[i]); CeedChk(ierr);
310         break;
311       case CEED_EVAL_WEIGHT:
312         break;  // No action
313       case CEED_EVAL_DIV:
314         break; // Not implimented
315       case CEED_EVAL_CURL:
316         break; // Not implimented
317       }
318     }
319 
320     // Output pointers
321     for (CeedInt i=0; i<numoutputfields; i++) {
322       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
323       CeedChk(ierr);
324       if (emode == CEED_EVAL_NONE) {
325         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
326         CeedChk(ierr);
327         ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
328                                   CEED_USE_POINTER,
329                                   &impl->edata[i + numinputfields][e*Q*ncomp]);
330         CeedChk(ierr);
331       }
332     }
333     // Q function
334     ierr = CeedQFunctionApply(qf, Q*blksize, impl->qvecsin, impl->qvecsout);
335     CeedChk(ierr);
336 
337     // Output basis apply if needed
338     for (CeedInt i=0; i<numoutputfields; i++) {
339       // Get elemsize, emode, ncomp
340       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
341       CeedChk(ierr);
342       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
343       CeedChk(ierr);
344       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
345       CeedChk(ierr);
346       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
347       CeedChk(ierr);
348       // Basis action
349       switch(emode) {
350       case CEED_EVAL_NONE:
351         break; // No action
352       case CEED_EVAL_INTERP:
353         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
354         CeedChk(ierr);
355         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
356                                   CEED_USE_POINTER,
357                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
358         CeedChk(ierr);
359         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
360                               CEED_EVAL_INTERP, impl->qvecsout[i],
361                               impl->evecsout[i]); CeedChk(ierr);
362         break;
363       case CEED_EVAL_GRAD:
364         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
365         CeedChk(ierr);
366         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
367                                   CEED_USE_POINTER,
368                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
369         CeedChk(ierr);
370         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
371                               CEED_EVAL_GRAD, impl->qvecsout[i],
372                               impl->evecsout[i]); CeedChk(ierr);
373         break;
374       case CEED_EVAL_WEIGHT: {
375         Ceed ceed;
376         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
377         return CeedError(ceed, 1,
378                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
379         break; // Should not occur
380       }
381       case CEED_EVAL_DIV:
382         break; // Not implimented
383       case CEED_EVAL_CURL:
384         break; // Not implimented
385       }
386     }
387   }
388 
389   // Zero lvecs
390   for (CeedInt i=0; i<numoutputfields; i++) {
391     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
392     if (vec == CEED_VECTOR_ACTIVE) {
393       if (!impl->add) {
394         vec = outvec;
395         ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
396       }
397     } else {
398       ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
399     }
400   }
401   impl->add = false;
402 
403   // Output restriction
404   for (CeedInt i=0; i<numoutputfields; i++) {
405     // Restore evec
406     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
407                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
408     // Get output vector
409     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
410     // Active
411     if (vec == CEED_VECTOR_ACTIVE)
412       vec = outvec;
413     // Restrict
414     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
415     ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
416                                     lmode, impl->evecs[i+impl->numein], vec,
417                                     request); CeedChk(ierr);
418 
419   }
420 
421   // Restore input arrays
422   for (CeedInt i=0; i<numinputfields; i++) {
423     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
424     CeedChk(ierr);
425     if (emode == CEED_EVAL_WEIGHT) { // Skip
426     } else {
427       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
428                                         (const CeedScalar **) &impl->edata[i]);
429       CeedChk(ierr);
430     }
431   }
432 
433   return 0;
434 }
435 
436 int CeedOperatorCreate_Blocked(CeedOperator op) {
437   int ierr;
438   Ceed ceed;
439   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
440   CeedOperator_Blocked *impl;
441 
442   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
443   ierr = CeedOperatorSetData(op, (void *)&impl); CeedChk(ierr);
444 
445   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
446                                 CeedOperatorApply_Blocked); CeedChk(ierr);
447   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
448                                 CeedOperatorDestroy_Blocked); CeedChk(ierr);
449   return 0;
450 }
451