xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision d1bcdac99090efedac6b43fd437855ade149b9cc)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-blocked.h"
19 #include "../ref/ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
22   int ierr;
23   CeedOperator_Blocked *impl;
24   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
25 
26   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
27     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
28     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
29   }
30   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
31   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
32   ierr = CeedFree(&impl->edata); CeedChk(ierr);
33 
34   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
35     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
36   }
37   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
38   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
39 
40   ierr = CeedFree(&impl->indata); CeedChk(ierr);
41   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
42 
43   ierr = CeedFree(&op->data); CeedChk(ierr);
44   return 0;
45 }
46 
47 /*
48   Setup infields or outfields
49  */
50 static int CeedOperatorSetupFields_Blocked(CeedQFunctionField qfields[16],
51                                        CeedOperatorField ofields[16],
52                                        CeedElemRestriction *blkrestr,
53                                        CeedVector *evecs, CeedScalar **qdata,
54                                        CeedScalar **qdata_alloc, CeedScalar **indata,
55                                        CeedInt starti, CeedInt startq,
56                                        CeedInt numfields, CeedInt Q) {
57   CeedInt dim, ierr, iq=startq, ncomp;
58   CeedBasis basis;
59   CeedElemRestriction r;
60   const CeedInt blksize = 8;
61 
62   // Loop over fields
63   for (CeedInt i=0; i<numfields; i++) {
64     CeedEvalMode emode;
65     ierr = CeedQFunctionFieldGetEvalMode(qfields[i], &emode); CeedChk(ierr);
66 
67     if (emode != CEED_EVAL_WEIGHT) {
68       ierr = CeedOperatorFieldGetElemRestriction(ofields[i], &r);
69       CeedChk(ierr);
70       CeedElemRestriction_Ref *data = r->data;
71       Ceed ceed;
72       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
73       CeedInt nelem, elemsize, ndof, ncomp;
74       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
75       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
76       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
77       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
78       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
79                                               blksize, ndof, ncomp,
80                                               CEED_MEM_HOST, CEED_COPY_VALUES,
81                                               data->indices, &blkrestr[i+starti]);
82       CeedChk(ierr);
83       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starti], NULL,
84                                              &evecs[i+starti]);
85       CeedChk(ierr);
86     }
87 
88     switch(emode) {
89     case CEED_EVAL_NONE:
90       break; // No action
91     case CEED_EVAL_INTERP:
92       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
93       CeedChk(ierr);
94       ierr = CeedMalloc(Q*ncomp*blksize, &qdata_alloc[iq]); CeedChk(ierr);
95       qdata[i + starti] = qdata_alloc[iq];
96       iq++;
97       break;
98     case CEED_EVAL_GRAD:
99       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
100       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
101       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
102       ierr = CeedMalloc(Q*ncomp*dim*blksize, &qdata_alloc[iq]); CeedChk(ierr);
103       qdata[i + starti] = qdata_alloc[iq];
104       iq++;
105       break;
106     case CEED_EVAL_WEIGHT: // Only on input fields
107       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
108       ierr = CeedMalloc(Q*blksize, &qdata_alloc[iq]); CeedChk(ierr);
109       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
110                             CEED_EVAL_WEIGHT, NULL, qdata_alloc[iq]); CeedChk(ierr);
111       qdata[i] = qdata_alloc[iq];
112       indata[i] = qdata[i];
113       iq++;
114       break;
115     case CEED_EVAL_DIV:
116       break; // Not implimented
117     case CEED_EVAL_CURL:
118       break; // Not implimented
119     }
120   }
121   return 0;
122 }
123 
124 /*
125   CeedOperator needs to connect all the named fields (be they active or passive)
126   to the named inputs and outputs of its CeedQFunction.
127  */
128 static int CeedOperatorSetup_Blocked(CeedOperator op) {
129   int ierr;
130   bool setupdone;
131   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
132   if (setupdone) return 0;
133   CeedOperator_Blocked *impl;
134   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
135   CeedQFunction qf;
136   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
137   CeedInt Q, numinputfields, numoutputfields;
138   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
139   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
140   CeedChk(ierr);
141   CeedOperatorField *opinputfields, *opoutputfields;
142   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
143   CeedChk(ierr);
144   CeedQFunctionField *qfinputfields, *qfoutputfields;
145   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
146   CeedChk(ierr);
147   CeedEvalMode emode;
148 
149   // Count infield and outfield array sizes and evectors
150   impl->numein = numinputfields;
151   for (CeedInt i=0; i<numinputfields; i++) {
152     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
153     CeedChk(ierr);
154     impl->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) +
155                     !!(emode & CEED_EVAL_WEIGHT);
156   }
157   impl->numeout = numoutputfields;
158   for (CeedInt i=0; i<numoutputfields; i++) {
159     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
160     CeedChk(ierr);
161     impl->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
162   }
163 
164   // Allocate
165   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->blkrestr);
166   CeedChk(ierr);
167   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->evecs);
168   CeedChk(ierr);
169   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->edata);
170   CeedChk(ierr);
171 
172   ierr = CeedCalloc(impl->numqin + impl->numqout, &impl->qdata_alloc);
173   CeedChk(ierr);
174   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->qdata);
175   CeedChk(ierr);
176 
177   ierr = CeedCalloc(16, &impl->indata); CeedChk(ierr);
178   ierr = CeedCalloc(16, &impl->outdata); CeedChk(ierr);
179   // Set up infield and outfield pointer arrays
180   // Infields
181   ierr = CeedOperatorSetupFields_Blocked(qfinputfields, opinputfields,
182                                      impl->blkrestr, impl->evecs,
183                                      impl->qdata, impl->qdata_alloc,
184                                      impl->indata, 0,
185                                      0, numinputfields, Q);
186   CeedChk(ierr);
187   // Outfields
188   ierr = CeedOperatorSetupFields_Blocked(qfoutputfields, opoutputfields,
189                                      impl->blkrestr, impl->evecs,
190                                      impl->qdata, impl->qdata_alloc,
191                                      impl->indata, numinputfields,
192                                      impl->numqin, numoutputfields, Q);
193   CeedChk(ierr);
194   // Input Qvecs
195   for (CeedInt i=0; i<numinputfields; i++) {
196     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
197     CeedChk(ierr);
198     if ((emode != CEED_EVAL_NONE) && (emode != CEED_EVAL_WEIGHT))
199       impl->indata[i] =  impl->qdata[i];
200   }
201   // Output Qvecs
202   for (CeedInt i=0; i<numoutputfields; i++) {
203     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
204     CeedChk(ierr);
205     if (emode != CEED_EVAL_NONE)
206       impl->outdata[i] =  impl->qdata[i + numinputfields];
207   }
208 
209   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
210 
211   return 0;
212 }
213 
214 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
215                                  CeedVector outvec, CeedRequest *request) {
216   int ierr;
217   CeedOperator_Blocked *impl;
218   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
219   const CeedInt blksize = 8;
220   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, ncomp;
221   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
222   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
223   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
224   CeedQFunction qf;
225   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
226   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
227   CeedChk(ierr);
228   CeedTransposeMode lmode = CEED_NOTRANSPOSE;
229   CeedOperatorField *opinputfields, *opoutputfields;
230   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
231   CeedChk(ierr);
232   CeedQFunctionField *qfinputfields, *qfoutputfields;
233   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
234   CeedChk(ierr);
235   CeedEvalMode emode;
236   CeedVector vec;
237   CeedBasis basis;
238   CeedElemRestriction Erestrict;
239 
240   // Setup
241   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
242 
243   // Input Evecs and Restriction
244   for (CeedInt i=0; i<numinputfields; i++) {
245     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
246     CeedChk(ierr);
247     if (emode == CEED_EVAL_WEIGHT) { // Skip
248     } else {
249       // Get input vector
250       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
251       if (vec == CEED_VECTOR_ACTIVE)
252         vec = invec;
253       // Restrict
254       ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
255                                       lmode, vec, impl->evecs[i],
256                                       request); CeedChk(ierr); CeedChk(ierr);
257       // Get evec
258       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
259                                     (const CeedScalar **) &impl->edata[i]);
260       CeedChk(ierr);
261     }
262   }
263 
264   // Output Evecs
265   for (CeedInt i=0; i<numoutputfields; i++) {
266     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
267                               &impl->edata[i + numinputfields]); CeedChk(ierr);
268   }
269 
270   // Loop through elements
271   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
272     // Input basis apply if needed
273     for (CeedInt i=0; i<numinputfields; i++) {
274       // Get elemsize, emode, ncomp
275       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
276       CeedChk(ierr);
277       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
278       CeedChk(ierr);
279       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
280       CeedChk(ierr);
281       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
282       CeedChk(ierr);
283       // Basis action
284       switch(emode) {
285       case CEED_EVAL_NONE:
286         impl->indata[i] = &impl->edata[i][e*Q*ncomp];
287         break;
288       case CEED_EVAL_INTERP:
289         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
290         CeedChk(ierr);
291         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
292                               CEED_EVAL_INTERP, &impl->edata[i][e*elemsize*ncomp],
293                               impl->qdata[i]); CeedChk(ierr);
294         break;
295       case CEED_EVAL_GRAD:
296         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
297         CeedChk(ierr);
298         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
299                               CEED_EVAL_GRAD, &impl->edata[i][e*elemsize*ncomp],
300                               impl->qdata[i]); CeedChk(ierr);
301         break;
302       case CEED_EVAL_WEIGHT:
303         break;  // No action
304       case CEED_EVAL_DIV:
305         break; // Not implimented
306       case CEED_EVAL_CURL:
307         break; // Not implimented
308       }
309     }
310 
311     // Output pointers
312     for (CeedInt i=0; i<numoutputfields; i++) {
313       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
314       CeedChk(ierr);
315       if (emode == CEED_EVAL_NONE) {
316         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
317         CeedChk(ierr);
318         impl->outdata[i] = &impl->edata[i + numinputfields][e*Q*ncomp];
319       }
320     }
321     // Q function
322     ierr = CeedQFunctionApply(qf, Q*blksize,
323                               (const CeedScalar * const*) impl->indata,
324                               impl->outdata); CeedChk(ierr);
325 
326     // Output basis apply if needed
327     for (CeedInt i=0; i<numoutputfields; i++) {
328       // Get elemsize, emode, ncomp
329       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
330       CeedChk(ierr);
331       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
332       CeedChk(ierr);
333       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
334       CeedChk(ierr);
335       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
336       CeedChk(ierr);
337       // Basis action
338       switch(emode) {
339       case CEED_EVAL_NONE:
340         break; // No action
341       case CEED_EVAL_INTERP:
342         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
343         CeedChk(ierr);
344         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
345                               CEED_EVAL_INTERP, impl->outdata[i],
346                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
347         CeedChk(ierr);
348         break;
349       case CEED_EVAL_GRAD:
350         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
351         CeedChk(ierr);
352         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
353                               CEED_EVAL_GRAD,
354                               impl->outdata[i], &impl->edata[i + numinputfields][e*elemsize*ncomp]);
355         CeedChk(ierr);
356         break;
357       case CEED_EVAL_WEIGHT: {
358         Ceed ceed;
359         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
360         return CeedError(ceed, 1,
361                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
362         break; // Should not occur
363       }
364       case CEED_EVAL_DIV:
365         break; // Not implimented
366       case CEED_EVAL_CURL:
367         break; // Not implimented
368       }
369     }
370   }
371 
372   // Zero lvecs
373   for (CeedInt i=0; i<numoutputfields; i++) {
374     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
375     if (vec == CEED_VECTOR_ACTIVE)
376       vec = outvec;
377     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
378     }
379 
380   // Output restriction
381   for (CeedInt i=0; i<numoutputfields; i++) {
382     // Restore evec
383     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
384                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
385     // Get output vector
386     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
387     // Active
388     if (vec == CEED_VECTOR_ACTIVE)
389       vec = outvec;
390     // Restrict
391     ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
392                                       lmode, impl->evecs[i+impl->numein], vec,
393                                       request); CeedChk(ierr);
394 
395   }
396 
397   // Restore input arrays
398   for (CeedInt i=0; i<numinputfields; i++) {
399     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
400     CeedChk(ierr);
401     if (emode == CEED_EVAL_WEIGHT) { // Skip
402     } else {
403       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
404                                         (const CeedScalar **) &impl->edata[i]);
405       CeedChk(ierr);
406     }
407   }
408 
409   return 0;
410 }
411 
412 int CeedOperatorCreate_Blocked(CeedOperator op) {
413   int ierr;
414   CeedOperator_Blocked *impl;
415 
416   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
417   op->data = impl;
418   op->Destroy = CeedOperatorDestroy_Blocked;
419   op->Apply = CeedOperatorApply_Blocked;
420   return 0;
421 }
422