xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision 91703d3f6e6cf8ee4d2bfaa1ca49093a23af4439)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-blocked.h"
19 #include "../ref/ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
22   int ierr;
23   CeedOperator_Blocked *impl;
24   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
25 
26   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
27     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
28     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
29   }
30   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
31   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
32   ierr = CeedFree(&impl->edata); CeedChk(ierr);
33 
34   for (CeedInt i=0; i<impl->numein; i++) {
35     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
36     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
37   }
38   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
39   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
40 
41   for (CeedInt i=0; i<impl->numeout; i++) {
42     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
43     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
44   }
45   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
46   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
47 
48   ierr = CeedFree(&impl); CeedChk(ierr);
49   return 0;
50 }
51 
52 /*
53   Setup infields or outfields
54  */
55 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op,
56     bool inOrOut,
57     CeedElemRestriction *blkrestr,
58     CeedVector *fullevecs, CeedVector *evecs,
59     CeedVector *qvecs, CeedInt starte,
60     CeedInt numfields, CeedInt Q) {
61   CeedInt dim, ierr, ncomp;
62   Ceed ceed;
63   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
64   CeedBasis basis;
65   CeedElemRestriction r;
66   CeedOperatorField *opfields;
67   CeedQFunctionField *qffields;
68   if (inOrOut) {
69     ierr = CeedOperatorGetFields(op, NULL, &opfields);
70     CeedChk(ierr);
71     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
72     CeedChk(ierr);
73   } else {
74     ierr = CeedOperatorGetFields(op, &opfields, NULL);
75     CeedChk(ierr);
76     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
77     CeedChk(ierr);
78   }
79   const CeedInt blksize = 8;
80 
81   // Loop over fields
82   for (CeedInt i=0; i<numfields; i++) {
83     CeedEvalMode emode;
84     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
85 
86     if (emode != CEED_EVAL_WEIGHT) {
87       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &r);
88       CeedChk(ierr);
89       CeedElemRestriction_Ref *data;
90       ierr = CeedElemRestrictionGetData(r, (void *)&data);
91       Ceed ceed;
92       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
93       CeedInt nelem, elemsize, ndof, ncomp;
94       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
95       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
96       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
97       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
98       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
99                                               blksize, ndof, ncomp,
100                                               CEED_MEM_HOST, CEED_COPY_VALUES,
101                                               data->indices, &blkrestr[i+starte]);
102       CeedChk(ierr);
103       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starte], NULL,
104                                              &fullevecs[i+starte]);
105       CeedChk(ierr);
106     }
107 
108     switch(emode) {
109     case CEED_EVAL_NONE:
110       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
111       CeedChk(ierr);
112       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
113       break;
114     case CEED_EVAL_INTERP:
115       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
116       CeedChk(ierr);
117       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &evecs[i]); CeedChk(ierr);
118       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
119       break;
120     case CEED_EVAL_GRAD:
121       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
122       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
123       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
124       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &evecs[i]); CeedChk(ierr);
125       ierr = CeedVectorCreate(ceed, Q*ncomp*dim*blksize, &qvecs[i]); CeedChk(ierr);
126       break;
127     case CEED_EVAL_WEIGHT: // Only on input fields
128       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
129       ierr = CeedVectorCreate(ceed, Q*blksize, &qvecs[i]); CeedChk(ierr);
130       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
131                             CEED_EVAL_WEIGHT, NULL, qvecs[i]); CeedChk(ierr);
132 
133       break;
134     case CEED_EVAL_DIV:
135       break; // Not implimented
136     case CEED_EVAL_CURL:
137       break; // Not implimented
138     }
139   }
140   return 0;
141 }
142 
143 /*
144   CeedOperator needs to connect all the named fields (be they active or passive)
145   to the named inputs and outputs of its CeedQFunction.
146  */
147 static int CeedOperatorSetup_Blocked(CeedOperator op) {
148   int ierr;
149   bool setupdone;
150   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
151   if (setupdone) return 0;
152   Ceed ceed;
153   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
154   CeedOperator_Blocked *impl;
155   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
156   CeedQFunction qf;
157   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
158   CeedInt Q, numinputfields, numoutputfields;
159   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
160   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
161   CeedChk(ierr);
162   CeedOperatorField *opinputfields, *opoutputfields;
163   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
164   CeedChk(ierr);
165   CeedQFunctionField *qfinputfields, *qfoutputfields;
166   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
167   CeedChk(ierr);
168 
169   // Allocate
170   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->blkrestr);
171   CeedChk(ierr);
172   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
173   CeedChk(ierr);
174   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
175   CeedChk(ierr);
176 
177   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
178   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
179   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
180   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
181 
182   impl->numein = numinputfields; impl->numeout = numoutputfields;
183 
184   // Set up infield and outfield pointer arrays
185   // Infields
186   ierr = CeedOperatorSetupFields_Blocked(qf, op, 0, impl->blkrestr,
187                                          impl->evecs, impl->evecsin,
188                                          impl->qvecsin, 0,
189                                          numinputfields, Q);
190   CeedChk(ierr);
191   // Outfields
192   ierr = CeedOperatorSetupFields_Blocked(qf, op, 1, impl->blkrestr,
193                                          impl->evecs, impl->evecsout,
194                                          impl->qvecsout, numinputfields,
195                                          numoutputfields, Q);
196   CeedChk(ierr);
197 
198   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
199 
200   return 0;
201 }
202 
203 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
204                                      CeedVector outvec, CeedRequest *request) {
205   int ierr;
206   CeedOperator_Blocked *impl;
207   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
208   const CeedInt blksize = 8;
209   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, ncomp;
210   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
211   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
212   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
213   CeedQFunction qf;
214   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
215   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
216   CeedChk(ierr);
217   CeedTransposeMode lmode;
218   CeedOperatorField *opinputfields, *opoutputfields;
219   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
220   CeedChk(ierr);
221   CeedQFunctionField *qfinputfields, *qfoutputfields;
222   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
223   CeedChk(ierr);
224   CeedEvalMode emode;
225   CeedVector vec;
226   CeedBasis basis;
227   CeedElemRestriction Erestrict;
228 
229   // Setup
230   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
231 
232   // Input Evecs and Restriction
233   for (CeedInt i=0; i<numinputfields; i++) {
234     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
235     CeedChk(ierr);
236     if (emode == CEED_EVAL_WEIGHT) { // Skip
237     } else {
238       // Get input vector
239       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
240       if (vec == CEED_VECTOR_ACTIVE)
241         vec = invec;
242       // Restrict
243       ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
244       ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
245                                       lmode, vec, impl->evecs[i],
246                                       request); CeedChk(ierr); CeedChk(ierr);
247       // Get evec
248       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
249                                     (const CeedScalar **) &impl->edata[i]);
250       CeedChk(ierr);
251     }
252   }
253 
254   // Output Evecs
255   for (CeedInt i=0; i<numoutputfields; i++) {
256     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
257                               &impl->edata[i + numinputfields]); CeedChk(ierr);
258   }
259 
260   // Loop through elements
261   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
262     // Input basis apply if needed
263     for (CeedInt i=0; i<numinputfields; i++) {
264       // Get elemsize, emode, ncomp
265       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
266       CeedChk(ierr);
267       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
268       CeedChk(ierr);
269       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
270       CeedChk(ierr);
271       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
272       CeedChk(ierr);
273       // Basis action
274       switch(emode) {
275       case CEED_EVAL_NONE:
276         ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
277                                   CEED_USE_POINTER,
278                                   &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
279         break;
280       case CEED_EVAL_INTERP:
281         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
282         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
283                                   CEED_USE_POINTER,
284                                   &impl->edata[i][e*elemsize*ncomp]);
285         CeedChk(ierr);
286         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
287                               CEED_EVAL_INTERP, impl->evecsin[i],
288                               impl->qvecsin[i]); CeedChk(ierr);
289         break;
290       case CEED_EVAL_GRAD:
291         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis); CeedChk(ierr);
292         ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
293                                   CEED_USE_POINTER,
294                                   &impl->edata[i][e*elemsize*ncomp]);
295         CeedChk(ierr);
296         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
297                               CEED_EVAL_GRAD, impl->evecsin[i],
298                               impl->qvecsin[i]); CeedChk(ierr);
299         break;
300       case CEED_EVAL_WEIGHT:
301         break;  // No action
302       case CEED_EVAL_DIV:
303         break; // Not implimented
304       case CEED_EVAL_CURL:
305         break; // Not implimented
306       }
307     }
308 
309     // Output pointers
310     for (CeedInt i=0; i<numoutputfields; i++) {
311       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
312       CeedChk(ierr);
313       if (emode == CEED_EVAL_NONE) {
314         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
315         CeedChk(ierr);
316         ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
317                                   CEED_USE_POINTER,
318                                   &impl->edata[i + numinputfields][e*Q*ncomp]);
319         CeedChk(ierr);
320       }
321     }
322     // Q function
323     ierr = CeedQFunctionApply(qf, Q*blksize, impl->qvecsin, impl->qvecsout);
324     CeedChk(ierr);
325 
326     // Output basis apply if needed
327     for (CeedInt i=0; i<numoutputfields; i++) {
328       // Get elemsize, emode, ncomp
329       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
330       CeedChk(ierr);
331       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
332       CeedChk(ierr);
333       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
334       CeedChk(ierr);
335       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
336       CeedChk(ierr);
337       // Basis action
338       switch(emode) {
339       case CEED_EVAL_NONE:
340         break; // No action
341       case CEED_EVAL_INTERP:
342         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
343         CeedChk(ierr);
344         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
345                                   CEED_USE_POINTER,
346                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
347         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
348                               CEED_EVAL_INTERP, impl->qvecsout[i],
349                               impl->evecsout[i]); CeedChk(ierr);
350         break;
351       case CEED_EVAL_GRAD:
352         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
353         CeedChk(ierr);
354         ierr = CeedVectorSetArray(impl->evecsout[i], CEED_MEM_HOST,
355                                   CEED_USE_POINTER,
356                                   &impl->edata[i + numinputfields][e*elemsize*ncomp]);
357         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
358                               CEED_EVAL_GRAD, impl->qvecsout[i],
359                               impl->evecsout[i]); CeedChk(ierr);
360         break;
361       case CEED_EVAL_WEIGHT: {
362         Ceed ceed;
363         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
364         return CeedError(ceed, 1,
365                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
366         break; // Should not occur
367       }
368       case CEED_EVAL_DIV:
369         break; // Not implimented
370       case CEED_EVAL_CURL:
371         break; // Not implimented
372       }
373     }
374   }
375 
376   // Zero lvecs
377   for (CeedInt i=0; i<numoutputfields; i++) {
378     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
379     if (vec == CEED_VECTOR_ACTIVE)
380       vec = outvec;
381     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
382   }
383 
384   // Output restriction
385   for (CeedInt i=0; i<numoutputfields; i++) {
386     // Restore evec
387     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
388                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
389     // Get output vector
390     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
391     // Active
392     if (vec == CEED_VECTOR_ACTIVE)
393       vec = outvec;
394     // Restrict
395     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
396     ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
397                                     lmode, impl->evecs[i+impl->numein], vec,
398                                     request); CeedChk(ierr);
399 
400   }
401 
402   // Restore input arrays
403   for (CeedInt i=0; i<numinputfields; i++) {
404     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
405     CeedChk(ierr);
406     if (emode == CEED_EVAL_WEIGHT) { // Skip
407     } else {
408       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
409                                         (const CeedScalar **) &impl->edata[i]);
410       CeedChk(ierr);
411     }
412   }
413 
414   return 0;
415 }
416 
417 int CeedOperatorCreate_Blocked(CeedOperator op) {
418   int ierr;
419   Ceed ceed;
420   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
421   CeedOperator_Blocked *impl;
422 
423   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
424   ierr = CeedOperatorSetData(op, (void *)&impl);
425 
426   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
427                                 CeedOperatorApply_Blocked); CeedChk(ierr);
428   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
429                                 CeedOperatorDestroy_Blocked); CeedChk(ierr);
430   return 0;
431 }
432