xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision aedaa0e5fd3a3e03ad33ad8a6308ac527f4f900e)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-blocked.h"
19 #include "../ref/ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
22   int ierr;
23   CeedOperator_Blocked *impl;
24   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
25 
26   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
27     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
28     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
29   }
30   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
31   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
32   ierr = CeedFree(&impl->edata); CeedChk(ierr);
33 
34   for (CeedInt i=0; i<impl->numein; i++) {
35     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
36   }
37   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
38 
39   for (CeedInt i=0; i<impl->numeout; i++) {
40     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
41   }
42   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
43 
44   ierr = CeedFree(&impl); CeedChk(ierr);
45   return 0;
46 }
47 
48 /*
49   Setup infields or outfields
50  */
51 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op,
52     bool inOrOut,
53     CeedElemRestriction *blkrestr,
54     CeedVector *evecs,
55     CeedVector *qvecs, CeedInt starte,
56     CeedInt numfields, CeedInt Q) {
57   CeedInt dim, ierr, ncomp;
58   Ceed ceed;
59   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
60   CeedBasis basis;
61   CeedElemRestriction r;
62   CeedOperatorField *opfields;
63   CeedQFunctionField *qffields;
64   if (inOrOut) {
65     ierr = CeedOperatorGetFields(op, NULL, &opfields);
66     CeedChk(ierr);
67     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
68     CeedChk(ierr);
69   } else {
70     ierr = CeedOperatorGetFields(op, &opfields, NULL);
71     CeedChk(ierr);
72     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
73     CeedChk(ierr);
74   }
75   const CeedInt blksize = 8;
76 
77   // Loop over fields
78   for (CeedInt i=0; i<numfields; i++) {
79     CeedEvalMode emode;
80     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
81 
82     if (emode != CEED_EVAL_WEIGHT) {
83       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &r);
84       CeedChk(ierr);
85       CeedElemRestriction_Ref *data;
86       ierr = CeedElemRestrictionGetData(r, (void *)&data);
87       Ceed ceed;
88       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
89       CeedInt nelem, elemsize, ndof, ncomp;
90       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
91       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
92       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
93       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
94       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
95                                               blksize, ndof, ncomp,
96                                               CEED_MEM_HOST, CEED_COPY_VALUES,
97                                               data->indices, &blkrestr[i+starte]);
98       CeedChk(ierr);
99       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starte], NULL,
100                                              &evecs[i+starte]);
101       CeedChk(ierr);
102     }
103 
104     switch(emode) {
105     case CEED_EVAL_NONE:
106       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
107       CeedChk(ierr);
108       ierr = CeedVectorCreate(ceed, Q*ncomp, &qvecs[i]); CeedChk(ierr);
109       break;
110     case CEED_EVAL_INTERP:
111       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
112       CeedChk(ierr);
113       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
114       break;
115     case CEED_EVAL_GRAD:
116       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
117       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
118       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
119       ierr = CeedVectorCreate(ceed, Q*ncomp*dim*blksize, &qvecs[i]); CeedChk(ierr);
120       break;
121     case CEED_EVAL_WEIGHT: // Only on input fields
122       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
123       ierr = CeedVectorCreate(ceed, Q*blksize, &qvecs[i]); CeedChk(ierr);
124       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
125                             CEED_EVAL_WEIGHT, NULL, qvecs[i]); CeedChk(ierr);
126 
127       break;
128     case CEED_EVAL_DIV:
129       break; // Not implimented
130     case CEED_EVAL_CURL:
131       break; // Not implimented
132     }
133   }
134   return 0;
135 }
136 
137 /*
138   CeedOperator needs to connect all the named fields (be they active or passive)
139   to the named inputs and outputs of its CeedQFunction.
140  */
141 static int CeedOperatorSetup_Blocked(CeedOperator op) {
142   int ierr;
143   bool setupdone;
144   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
145   if (setupdone) return 0;
146   Ceed ceed;
147   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
148   CeedOperator_Blocked *impl;
149   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
150   CeedQFunction qf;
151   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
152   CeedInt Q, numinputfields, numoutputfields;
153   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
154   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
155   CeedChk(ierr);
156   CeedOperatorField *opinputfields, *opoutputfields;
157   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
158   CeedChk(ierr);
159   CeedQFunctionField *qfinputfields, *qfoutputfields;
160   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
161   CeedChk(ierr);
162 
163   // Allocate
164   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->blkrestr);
165   CeedChk(ierr);
166   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
167   CeedChk(ierr);
168   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
169   CeedChk(ierr);
170 
171   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
172   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
173 
174   impl->numein = numinputfields; impl->numeout = numoutputfields;
175 
176   // Set up infield and outfield pointer arrays
177   // Infields
178   ierr = CeedOperatorSetupFields_Blocked(qf, op, 0, impl->blkrestr,
179                                          impl->evecs, impl->qvecsin, 0,
180                                          numinputfields, Q);
181   CeedChk(ierr);
182   // Outfields
183   ierr = CeedOperatorSetupFields_Blocked(qf, op, 1, impl->blkrestr,
184                                          impl->evecs, impl->qvecsout,
185                                          numinputfields, numoutputfields, Q);
186   CeedChk(ierr);
187 
188   // Temporary Vector
189   ierr = CeedVectorCreate(ceed, 0, &impl->tempvec); CeedChk(ierr);
190 
191   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
192 
193   return 0;
194 }
195 
196 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
197                                      CeedVector outvec, CeedRequest *request) {
198   int ierr;
199   CeedOperator_Blocked *impl;
200   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
201   const CeedInt blksize = 8;
202   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, ncomp;
203   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
204   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
205   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
206   CeedQFunction qf;
207   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
208   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
209   CeedChk(ierr);
210   CeedTransposeMode lmode;
211   CeedOperatorField *opinputfields, *opoutputfields;
212   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
213   CeedChk(ierr);
214   CeedQFunctionField *qfinputfields, *qfoutputfields;
215   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
216   CeedChk(ierr);
217   CeedEvalMode emode;
218   CeedVector vec;
219   CeedBasis basis;
220   CeedElemRestriction Erestrict;
221 
222   // Setup
223   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
224 
225   // Input Evecs and Restriction
226   for (CeedInt i=0; i<numinputfields; i++) {
227     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
228     CeedChk(ierr);
229     if (emode == CEED_EVAL_WEIGHT) { // Skip
230     } else {
231       // Get input vector
232       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
233       if (vec == CEED_VECTOR_ACTIVE)
234         vec = invec;
235       // Restrict
236       ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
237       ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
238                                       lmode, vec, impl->evecs[i],
239                                       request); CeedChk(ierr); CeedChk(ierr);
240       // Get evec
241       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
242                                     (const CeedScalar **) &impl->edata[i]);
243       CeedChk(ierr);
244     }
245   }
246 
247   // Output Evecs
248   for (CeedInt i=0; i<numoutputfields; i++) {
249     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
250                               &impl->edata[i + numinputfields]); CeedChk(ierr);
251   }
252 
253   // Loop through elements
254   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
255     // Input basis apply if needed
256     for (CeedInt i=0; i<numinputfields; i++) {
257       // Get elemsize, emode, ncomp
258       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
259       CeedChk(ierr);
260       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
261       CeedChk(ierr);
262       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
263       CeedChk(ierr);
264       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
265       CeedChk(ierr);
266       // Basis action
267       switch(emode) {
268       case CEED_EVAL_NONE:
269         ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
270                                   CEED_USE_POINTER,
271                                   &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
272         break;
273       case CEED_EVAL_INTERP:
274         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
275         ierr = CeedVectorSetArray(impl->tempvec, CEED_MEM_HOST,
276                                   CEED_USE_POINTER,
277                                   &impl->edata[i][e*elemsize*ncomp]);
278         CeedChk(ierr);
279         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
280                               CEED_EVAL_INTERP, impl->tempvec,
281                               impl->qvecsin[i]); CeedChk(ierr);
282         break;
283       case CEED_EVAL_GRAD:
284         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
285         ierr = CeedVectorSetArray(impl->tempvec, CEED_MEM_HOST,
286                                   CEED_USE_POINTER,
287                                   &impl->edata[i][e*elemsize*ncomp]);
288         CeedChk(ierr);
289         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
290                               CEED_EVAL_GRAD, impl->tempvec,
291                               impl->qvecsin[i]); CeedChk(ierr);
292         break;
293       case CEED_EVAL_WEIGHT:
294         break;  // No action
295       case CEED_EVAL_DIV:
296         break; // Not implimented
297       case CEED_EVAL_CURL:
298         break; // Not implimented
299       }
300     }
301 
302     // Output pointers
303     for (CeedInt i=0; i<numoutputfields; i++) {
304       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
305       CeedChk(ierr);
306       if (emode == CEED_EVAL_NONE) {
307         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
308         CeedChk(ierr);
309         ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
310                                   CEED_USE_POINTER,
311                                   &impl->edata[i + numinputfields][e*Q*ncomp]);
312         CeedChk(ierr);
313       }
314     }
315     // Q function
316     ierr = CeedQFunctionApply(qf, Q*blksize, impl->qvecsin, impl->qvecsout);
317     CeedChk(ierr);
318 
319     // Output basis apply if needed
320     for (CeedInt i=0; i<numoutputfields; i++) {
321       // Get elemsize, emode, ncomp
322       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
323       CeedChk(ierr);
324       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
325       CeedChk(ierr);
326       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
327       CeedChk(ierr);
328       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
329       CeedChk(ierr);
330       // Basis action
331       switch(emode) {
332       case CEED_EVAL_NONE:
333         break; // No action
334       case CEED_EVAL_INTERP:
335         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
336         CeedChk(ierr);
337         ierr = CeedVectorSetArray(impl->tempvec, CEED_MEM_HOST,
338                                   CEED_USE_POINTER,
339                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
340         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
341                               CEED_EVAL_INTERP, impl->qvecsout[i],
342                               impl->tempvec); CeedChk(ierr);
343         break;
344       case CEED_EVAL_GRAD:
345         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
346         CeedChk(ierr);
347         ierr = CeedVectorSetArray(impl->tempvec, CEED_MEM_HOST,
348                                   CEED_USE_POINTER,
349                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
350         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
351                               CEED_EVAL_GRAD, impl->qvecsout[i],
352                               impl->tempvec); CeedChk(ierr);
353         break;
354       case CEED_EVAL_WEIGHT: {
355         Ceed ceed;
356         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
357         return CeedError(ceed, 1,
358                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
359         break; // Should not occur
360       }
361       case CEED_EVAL_DIV:
362         break; // Not implimented
363       case CEED_EVAL_CURL:
364         break; // Not implimented
365       }
366     }
367   }
368 
369   // Zero lvecs
370   for (CeedInt i=0; i<numoutputfields; i++) {
371     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
372     if (vec == CEED_VECTOR_ACTIVE)
373       vec = outvec;
374     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
375   }
376 
377   // Output restriction
378   for (CeedInt i=0; i<numoutputfields; i++) {
379     // Restore evec
380     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
381                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
382     // Get output vector
383     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
384     // Active
385     if (vec == CEED_VECTOR_ACTIVE)
386       vec = outvec;
387     // Restrict
388     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
389     ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
390                                     lmode, impl->evecs[i+impl->numein], vec,
391                                     request); CeedChk(ierr);
392 
393   }
394 
395   // Restore input arrays
396   for (CeedInt i=0; i<numinputfields; i++) {
397     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
398     CeedChk(ierr);
399     if (emode == CEED_EVAL_WEIGHT) { // Skip
400     } else {
401       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
402                                         (const CeedScalar **) &impl->edata[i]);
403       CeedChk(ierr);
404     }
405   }
406 
407   return 0;
408 }
409 
410 int CeedOperatorCreate_Blocked(CeedOperator op) {
411   int ierr;
412   Ceed ceed;
413   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
414   CeedOperator_Blocked *impl;
415 
416   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
417   ierr = CeedOperatorSetData(op, (void *)&impl);
418 
419   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
420                                 CeedOperatorApply_Blocked); CeedChk(ierr);
421   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
422                                 CeedOperatorDestroy_Blocked); CeedChk(ierr);
423   return 0;
424 }
425