xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision 418fb8c26cd03fc44256773f44bb9ece8ec63e5f)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-blocked.h"
19 #include "../ref/ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
22   int ierr;
23   CeedOperator_Blocked *impl;
24   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
25 
26   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
27     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
28     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
29   }
30   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
31   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
32   ierr = CeedFree(&impl->edata); CeedChk(ierr);
33   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
34 
35   for (CeedInt i=0; i<impl->numein; i++) {
36     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
37     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
38   }
39   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
40   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
41 
42   for (CeedInt i=0; i<impl->numeout; i++) {
43     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
44     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
45   }
46   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
47   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
48 
49   ierr = CeedFree(&impl); CeedChk(ierr);
50   return 0;
51 }
52 
53 /*
54   Setup infields or outfields
55  */
56 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op,
57     bool inOrOut,
58     CeedElemRestriction *blkrestr,
59     CeedVector *fullevecs, CeedVector *evecs,
60     CeedVector *qvecs, CeedInt starte,
61     CeedInt numfields, CeedInt Q) {
62   CeedInt dim, ierr, ncomp;
63   Ceed ceed;
64   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
65   CeedBasis basis;
66   CeedElemRestriction r;
67   CeedOperatorField *opfields;
68   CeedQFunctionField *qffields;
69   if (inOrOut) {
70     ierr = CeedOperatorGetFields(op, NULL, &opfields);
71     CeedChk(ierr);
72     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
73     CeedChk(ierr);
74   } else {
75     ierr = CeedOperatorGetFields(op, &opfields, NULL);
76     CeedChk(ierr);
77     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
78     CeedChk(ierr);
79   }
80   const CeedInt blksize = 8;
81 
82   // Loop over fields
83   for (CeedInt i=0; i<numfields; i++) {
84     CeedEvalMode emode;
85     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
86 
87     if (emode != CEED_EVAL_WEIGHT) {
88       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &r);
89       CeedChk(ierr);
90       CeedElemRestriction_Ref *data;
91       ierr = CeedElemRestrictionGetData(r, (void *)&data);
92       Ceed ceed;
93       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
94       CeedInt nelem, elemsize, ndof, ncomp;
95       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
96       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
97       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
98       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
99       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
100                                               blksize, ndof, ncomp,
101                                               CEED_MEM_HOST, CEED_COPY_VALUES,
102                                               data->indices, &blkrestr[i+starte]);
103       CeedChk(ierr);
104       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starte], NULL,
105                                              &fullevecs[i+starte]);
106       CeedChk(ierr);
107     }
108 
109     switch(emode) {
110     case CEED_EVAL_NONE:
111       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
112       CeedChk(ierr);
113       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
114       break;
115     case CEED_EVAL_INTERP:
116       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
117       CeedChk(ierr);
118       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &evecs[i]); CeedChk(ierr);
119       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
120       break;
121     case CEED_EVAL_GRAD:
122       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
123       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
124       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
125       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &evecs[i]); CeedChk(ierr);
126       ierr = CeedVectorCreate(ceed, Q*ncomp*dim*blksize, &qvecs[i]); CeedChk(ierr);
127       break;
128     case CEED_EVAL_WEIGHT: // Only on input fields
129       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
130       ierr = CeedVectorCreate(ceed, Q*blksize, &qvecs[i]); CeedChk(ierr);
131       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
132                             CEED_EVAL_WEIGHT, NULL, qvecs[i]); CeedChk(ierr);
133 
134       break;
135     case CEED_EVAL_DIV:
136       break; // Not implimented
137     case CEED_EVAL_CURL:
138       break; // Not implimented
139     }
140   }
141   return 0;
142 }
143 
144 /*
145   CeedOperator needs to connect all the named fields (be they active or passive)
146   to the named inputs and outputs of its CeedQFunction.
147  */
148 static int CeedOperatorSetup_Blocked(CeedOperator op) {
149   int ierr;
150   bool setupdone;
151   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
152   if (setupdone) return 0;
153   Ceed ceed;
154   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
155   CeedOperator_Blocked *impl;
156   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
157   CeedQFunction qf;
158   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
159   CeedInt Q, numinputfields, numoutputfields;
160   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
161   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
162   CeedChk(ierr);
163   CeedOperatorField *opinputfields, *opoutputfields;
164   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
165   CeedChk(ierr);
166   CeedQFunctionField *qfinputfields, *qfoutputfields;
167   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
168   CeedChk(ierr);
169 
170   // Allocate
171   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->blkrestr);
172   CeedChk(ierr);
173   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
174   CeedChk(ierr);
175   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
176   CeedChk(ierr);
177 
178   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
179   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
180   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
181   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
182   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
183 
184   impl->numein = numinputfields; impl->numeout = numoutputfields;
185 
186   // Set up infield and outfield pointer arrays
187   // Infields
188   ierr = CeedOperatorSetupFields_Blocked(qf, op, 0, impl->blkrestr,
189                                          impl->evecs, impl->evecsin,
190                                          impl->qvecsin, 0,
191                                          numinputfields, Q);
192   CeedChk(ierr);
193   // Outfields
194   ierr = CeedOperatorSetupFields_Blocked(qf, op, 1, impl->blkrestr,
195                                          impl->evecs, impl->evecsout,
196                                          impl->qvecsout, numinputfields,
197                                          numoutputfields, Q);
198   CeedChk(ierr);
199 
200   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
201 
202   return 0;
203 }
204 
205 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
206                                      CeedVector outvec, CeedRequest *request) {
207   int ierr;
208   CeedOperator_Blocked *impl;
209   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
210   const CeedInt blksize = 8;
211   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, ncomp;
212   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
213   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
214   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
215   CeedQFunction qf;
216   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
217   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
218   CeedChk(ierr);
219   CeedTransposeMode lmode;
220   CeedOperatorField *opinputfields, *opoutputfields;
221   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
222   CeedChk(ierr);
223   CeedQFunctionField *qfinputfields, *qfoutputfields;
224   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
225   CeedChk(ierr);
226   CeedEvalMode emode;
227   CeedVector vec;
228   CeedBasis basis;
229   CeedElemRestriction Erestrict;
230   uint64_t state;
231 
232   // Setup
233   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
234 
235   // Input Evecs and Restriction
236   for (CeedInt i=0; i<numinputfields; i++) {
237     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
238     CeedChk(ierr);
239     if (emode == CEED_EVAL_WEIGHT) { // Skip
240     } else {
241       // Get input vector
242       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
243       if (vec == CEED_VECTOR_ACTIVE)
244         vec = invec;
245       // Restrict
246       ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
247       if (state != impl->inputstate[i]) {
248         ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
249         ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
250                                         lmode, vec, impl->evecs[i],
251                                         request); CeedChk(ierr); CeedChk(ierr);
252         impl->inputstate[i] = state;
253       }
254       // Get evec
255       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
256                                     (const CeedScalar **) &impl->edata[i]);
257       CeedChk(ierr);
258     }
259   }
260 
261   // Output Evecs
262   for (CeedInt i=0; i<numoutputfields; i++) {
263     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
264                               &impl->edata[i + numinputfields]); CeedChk(ierr);
265   }
266 
267   // Loop through elements
268   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
269     // Input basis apply if needed
270     for (CeedInt i=0; i<numinputfields; i++) {
271       // Get elemsize, emode, ncomp
272       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
273       CeedChk(ierr);
274       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
275       CeedChk(ierr);
276       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
277       CeedChk(ierr);
278       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
279       CeedChk(ierr);
280       // Basis action
281       switch(emode) {
282       case CEED_EVAL_NONE:
283         ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
284                                   CEED_USE_POINTER,
285                                   &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
286         break;
287       case CEED_EVAL_INTERP:
288         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
289         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
290                                   CEED_USE_POINTER,
291                                   &impl->edata[i][e*elemsize*ncomp]);
292         CeedChk(ierr);
293         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
294                               CEED_EVAL_INTERP, impl->evecsin[i],
295                               impl->qvecsin[i]); CeedChk(ierr);
296         break;
297       case CEED_EVAL_GRAD:
298         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
299         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
300                                   CEED_USE_POINTER,
301                                   &impl->edata[i][e*elemsize*ncomp]);
302         CeedChk(ierr);
303         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
304                               CEED_EVAL_GRAD, impl->evecsin[i],
305                               impl->qvecsin[i]); CeedChk(ierr);
306         break;
307       case CEED_EVAL_WEIGHT:
308         break;  // No action
309       case CEED_EVAL_DIV:
310         break; // Not implimented
311       case CEED_EVAL_CURL:
312         break; // Not implimented
313       }
314     }
315 
316     // Output pointers
317     for (CeedInt i=0; i<numoutputfields; i++) {
318       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
319       CeedChk(ierr);
320       if (emode == CEED_EVAL_NONE) {
321         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
322         CeedChk(ierr);
323         ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
324                                   CEED_USE_POINTER,
325                                   &impl->edata[i + numinputfields][e*Q*ncomp]);
326         CeedChk(ierr);
327       }
328     }
329     // Q function
330     ierr = CeedQFunctionApply(qf, Q*blksize, impl->qvecsin, impl->qvecsout);
331     CeedChk(ierr);
332 
333     // Output basis apply if needed
334     for (CeedInt i=0; i<numoutputfields; i++) {
335       // Get elemsize, emode, ncomp
336       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
337       CeedChk(ierr);
338       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
339       CeedChk(ierr);
340       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
341       CeedChk(ierr);
342       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
343       CeedChk(ierr);
344       // Basis action
345       switch(emode) {
346       case CEED_EVAL_NONE:
347         break; // No action
348       case CEED_EVAL_INTERP:
349         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
350         CeedChk(ierr);
351         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
352                                   CEED_USE_POINTER,
353                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
354         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
355                               CEED_EVAL_INTERP, impl->qvecsout[i],
356                               impl->evecsout[i]); CeedChk(ierr);
357         break;
358       case CEED_EVAL_GRAD:
359         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
360         CeedChk(ierr);
361         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
362                                   CEED_USE_POINTER,
363                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
364         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
365                               CEED_EVAL_GRAD, impl->qvecsout[i],
366                               impl->evecsout[i]); CeedChk(ierr);
367         break;
368       case CEED_EVAL_WEIGHT: {
369         Ceed ceed;
370         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
371         return CeedError(ceed, 1,
372                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
373         break; // Should not occur
374       }
375       case CEED_EVAL_DIV:
376         break; // Not implimented
377       case CEED_EVAL_CURL:
378         break; // Not implimented
379       }
380     }
381   }
382 
383   // Zero lvecs
384   for (CeedInt i=0; i<numoutputfields; i++) {
385     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
386     if (vec == CEED_VECTOR_ACTIVE)
387       vec = outvec;
388     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
389   }
390 
391   // Output restriction
392   for (CeedInt i=0; i<numoutputfields; i++) {
393     // Restore evec
394     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
395                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
396     // Get output vector
397     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
398     // Active
399     if (vec == CEED_VECTOR_ACTIVE)
400       vec = outvec;
401     // Restrict
402     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
403     ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
404                                     lmode, impl->evecs[i+impl->numein], vec,
405                                     request); CeedChk(ierr);
406 
407   }
408 
409   // Restore input arrays
410   for (CeedInt i=0; i<numinputfields; i++) {
411     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
412     CeedChk(ierr);
413     if (emode == CEED_EVAL_WEIGHT) { // Skip
414     } else {
415       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
416                                         (const CeedScalar **) &impl->edata[i]);
417       CeedChk(ierr);
418     }
419   }
420 
421   return 0;
422 }
423 
424 int CeedOperatorCreate_Blocked(CeedOperator op) {
425   int ierr;
426   Ceed ceed;
427   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
428   CeedOperator_Blocked *impl;
429 
430   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
431   ierr = CeedOperatorSetData(op, (void *)&impl);
432 
433   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
434                                 CeedOperatorApply_Blocked); CeedChk(ierr);
435   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
436                                 CeedOperatorDestroy_Blocked); CeedChk(ierr);
437   return 0;
438 }
439