xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision 4dccadb61a9bb3ddd06b933b05f6f28773cf32d8)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-blocked.h"
19 #include "../ref/ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
22   int ierr;
23   CeedOperator_Blocked *impl;
24   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
25 
26   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
27     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
28     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
29   }
30   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
31   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
32   ierr = CeedFree(&impl->edata); CeedChk(ierr);
33 
34   for (CeedInt i=0; i<impl->numqin+impl->numqout; i++) {
35     ierr = CeedFree(&impl->qdata_alloc[i]); CeedChk(ierr);
36   }
37   ierr = CeedFree(&impl->qdata_alloc); CeedChk(ierr);
38   ierr = CeedFree(&impl->qdata); CeedChk(ierr);
39 
40   ierr = CeedFree(&impl->indata); CeedChk(ierr);
41   ierr = CeedFree(&impl->outdata); CeedChk(ierr);
42 
43   ierr = CeedFree(&op->data); CeedChk(ierr);
44   return 0;
45 }
46 
47 /*
48   Setup infields or outfields
49  */
50 static int CeedOperatorSetupFields_Blocked(CeedQFunctionField qfields[16],
51                                        CeedOperatorField ofields[16],
52                                        CeedElemRestriction *blkrestr,
53                                        CeedVector *evecs, CeedScalar **qdata,
54                                        CeedScalar **qdata_alloc, CeedScalar **indata,
55                                        CeedInt starti, CeedInt startq,
56                                        CeedInt numfields, CeedInt Q) {
57   CeedInt dim, ierr, iq=startq, ncomp;
58   CeedBasis basis;
59   CeedElemRestriction r;
60   const CeedInt blksize = 8;
61 
62   // Loop over fields
63   for (CeedInt i=0; i<numfields; i++) {
64     CeedEvalMode emode;
65     ierr = CeedQFunctionFieldGetEvalMode(qfields[i], &emode); CeedChk(ierr);
66 
67     if (emode != CEED_EVAL_WEIGHT) {
68       ierr = CeedOperatorFieldGetElemRestriction(ofields[i], &r);
69       CeedChk(ierr);
70       CeedElemRestriction_Ref *data = r->data;
71       Ceed ceed;
72       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
73       CeedInt nelem, elemsize, ndof, ncomp;
74       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
75       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
76       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
77       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
78       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
79                                               blksize, ndof, ncomp,
80                                               CEED_MEM_HOST, CEED_COPY_VALUES,
81                                               data->indices, &blkrestr[i+starti]);
82       CeedChk(ierr);
83       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starti], NULL,
84                                              &evecs[i+starti]);
85       CeedChk(ierr);
86     }
87 
88     switch(emode) {
89     case CEED_EVAL_NONE:
90       break; // No action
91     case CEED_EVAL_INTERP:
92       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
93       CeedChk(ierr);
94       ierr = CeedMalloc(Q*ncomp*blksize, &qdata_alloc[iq]); CeedChk(ierr);
95       qdata[i + starti] = qdata_alloc[iq];
96       iq++;
97       break;
98     case CEED_EVAL_GRAD:
99       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
100       ierr = CeedQFunctionFieldGetNumComponents(qfields[i], &ncomp);
101       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
102       ierr = CeedMalloc(Q*ncomp*dim*blksize, &qdata_alloc[iq]); CeedChk(ierr);
103       qdata[i + starti] = qdata_alloc[iq];
104       iq++;
105       break;
106     case CEED_EVAL_WEIGHT: // Only on input fields
107       ierr = CeedOperatorFieldGetBasis(ofields[i], &basis); CeedChk(ierr);
108       ierr = CeedMalloc(Q*blksize, &qdata_alloc[iq]); CeedChk(ierr);
109       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
110                             CEED_EVAL_WEIGHT, NULL, qdata_alloc[iq]); CeedChk(ierr);
111       qdata[i] = qdata_alloc[iq];
112       indata[i] = qdata[i];
113       iq++;
114       break;
115     case CEED_EVAL_DIV:
116       break; // Not implimented
117     case CEED_EVAL_CURL:
118       break; // Not implimented
119     }
120   }
121   return 0;
122 }
123 
124 /*
125   CeedOperator needs to connect all the named fields (be they active or passive)
126   to the named inputs and outputs of its CeedQFunction.
127  */
128 static int CeedOperatorSetup_Blocked(CeedOperator op) {
129   int ierr;
130   bool setupdone;
131   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
132   if (setupdone) return 0;
133   CeedOperator_Blocked *impl;
134   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
135   CeedQFunction qf;
136   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
137   CeedInt Q, numinputfields, numoutputfields;
138   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
139   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
140   CeedChk(ierr);
141   CeedOperatorField *opinputfields, *opoutputfields;
142   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
143   CeedChk(ierr);
144   CeedQFunctionField *qfinputfields, *qfoutputfields;
145   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
146   CeedChk(ierr);
147   CeedEvalMode emode;
148 
149   // Count infield and outfield array sizes and evectors
150   impl->numein = numinputfields;
151   for (CeedInt i=0; i<numinputfields; i++) {
152     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
153     CeedChk(ierr);
154     impl->numqin += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD) +
155                     !!(emode & CEED_EVAL_WEIGHT);
156   }
157   impl->numeout = numoutputfields;
158   for (CeedInt i=0; i<numoutputfields; i++) {
159     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
160     CeedChk(ierr);
161     impl->numqout += !!(emode & CEED_EVAL_INTERP) + !!(emode & CEED_EVAL_GRAD);
162   }
163 
164   // Allocate
165   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->blkrestr);
166   CeedChk(ierr);
167   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->evecs);
168   CeedChk(ierr);
169   ierr = CeedCalloc(impl->numein + impl->numeout, &impl->edata);
170   CeedChk(ierr);
171 
172   ierr = CeedCalloc(impl->numqin + impl->numqout, &impl->qdata_alloc);
173   CeedChk(ierr);
174   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->qdata);
175   CeedChk(ierr);
176 
177   ierr = CeedCalloc(16, &impl->indata); CeedChk(ierr);
178   ierr = CeedCalloc(16, &impl->outdata); CeedChk(ierr);
179   // Set up infield and outfield pointer arrays
180   // Infields
181   ierr = CeedOperatorSetupFields_Blocked(qfinputfields, opinputfields,
182                                      impl->blkrestr, impl->evecs,
183                                      impl->qdata, impl->qdata_alloc,
184                                      impl->indata, 0,
185                                      0, numinputfields, Q);
186   CeedChk(ierr);
187   // Outfields
188   ierr = CeedOperatorSetupFields_Blocked(qfoutputfields, opoutputfields,
189                                      impl->blkrestr, impl->evecs,
190                                      impl->qdata, impl->qdata_alloc,
191                                      impl->indata, numinputfields,
192                                      impl->numqin, numoutputfields, Q);
193   CeedChk(ierr);
194   // Input Qvecs
195   for (CeedInt i=0; i<numinputfields; i++) {
196     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
197     CeedChk(ierr);
198     if ((emode != CEED_EVAL_NONE) && (emode != CEED_EVAL_WEIGHT))
199       impl->indata[i] =  impl->qdata[i];
200   }
201   // Output Qvecs
202   for (CeedInt i=0; i<numoutputfields; i++) {
203     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
204     CeedChk(ierr);
205     if (emode != CEED_EVAL_NONE)
206       impl->outdata[i] =  impl->qdata[i + numinputfields];
207   }
208 
209   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
210 
211   return 0;
212 }
213 
214 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
215                                  CeedVector outvec, CeedRequest *request) {
216   int ierr;
217   CeedOperator_Blocked *impl;
218   ierr = CeedOperatorGetData(op, (void*)&impl); CeedChk(ierr);
219   const CeedInt blksize = 8;
220   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, ncomp;
221   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
222   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
223   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
224   CeedQFunction qf;
225   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
226   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
227   CeedChk(ierr);
228   CeedTransposeMode lmode;
229   CeedOperatorField *opinputfields, *opoutputfields;
230   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
231   CeedChk(ierr);
232   CeedQFunctionField *qfinputfields, *qfoutputfields;
233   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
234   CeedChk(ierr);
235   CeedEvalMode emode;
236   CeedVector vec;
237   CeedBasis basis;
238   CeedElemRestriction Erestrict;
239 
240   // Setup
241   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
242 
243   // Input Evecs and Restriction
244   for (CeedInt i=0; i<numinputfields; i++) {
245     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
246     CeedChk(ierr);
247     if (emode == CEED_EVAL_WEIGHT) { // Skip
248     } else {
249       // Get input vector
250       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
251       if (vec == CEED_VECTOR_ACTIVE)
252         vec = invec;
253       // Restrict
254       ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode); CeedChk(ierr);
255       ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
256                                       lmode, vec, impl->evecs[i],
257                                       request); CeedChk(ierr); CeedChk(ierr);
258       // Get evec
259       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
260                                     (const CeedScalar **) &impl->edata[i]);
261       CeedChk(ierr);
262     }
263   }
264 
265   // Output Evecs
266   for (CeedInt i=0; i<numoutputfields; i++) {
267     ierr = CeedVectorGetArray(impl->evecs[i+impl->numein], CEED_MEM_HOST,
268                               &impl->edata[i + numinputfields]); CeedChk(ierr);
269   }
270 
271   // Loop through elements
272   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
273     // Input basis apply if needed
274     for (CeedInt i=0; i<numinputfields; i++) {
275       // Get elemsize, emode, ncomp
276       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
277       CeedChk(ierr);
278       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
279       CeedChk(ierr);
280       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
281       CeedChk(ierr);
282       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
283       CeedChk(ierr);
284       // Basis action
285       switch(emode) {
286       case CEED_EVAL_NONE:
287         impl->indata[i] = &impl->edata[i][e*Q*ncomp];
288         break;
289       case CEED_EVAL_INTERP:
290         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
291         CeedChk(ierr);
292         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
293                               CEED_EVAL_INTERP, &impl->edata[i][e*elemsize*ncomp],
294                               impl->qdata[i]); CeedChk(ierr);
295         break;
296       case CEED_EVAL_GRAD:
297         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
298         CeedChk(ierr);
299         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
300                               CEED_EVAL_GRAD, &impl->edata[i][e*elemsize*ncomp],
301                               impl->qdata[i]); CeedChk(ierr);
302         break;
303       case CEED_EVAL_WEIGHT:
304         break;  // No action
305       case CEED_EVAL_DIV:
306         break; // Not implimented
307       case CEED_EVAL_CURL:
308         break; // Not implimented
309       }
310     }
311 
312     // Output pointers
313     for (CeedInt i=0; i<numoutputfields; i++) {
314       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
315       CeedChk(ierr);
316       if (emode == CEED_EVAL_NONE) {
317         ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
318         CeedChk(ierr);
319         impl->outdata[i] = &impl->edata[i + numinputfields][e*Q*ncomp];
320       }
321     }
322     // Q function
323     ierr = CeedQFunctionApply(qf, Q*blksize,
324                               (const CeedScalar * const*) impl->indata,
325                               impl->outdata); CeedChk(ierr);
326 
327     // Output basis apply if needed
328     for (CeedInt i=0; i<numoutputfields; i++) {
329       // Get elemsize, emode, ncomp
330       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
331       CeedChk(ierr);
332       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
333       CeedChk(ierr);
334       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
335       CeedChk(ierr);
336       ierr = CeedQFunctionFieldGetNumComponents(qfoutputfields[i], &ncomp);
337       CeedChk(ierr);
338       // Basis action
339       switch(emode) {
340       case CEED_EVAL_NONE:
341         break; // No action
342       case CEED_EVAL_INTERP:
343         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
344         CeedChk(ierr);
345         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
346                               CEED_EVAL_INTERP, impl->outdata[i],
347                               &impl->edata[i + numinputfields][e*elemsize*ncomp]);
348         CeedChk(ierr);
349         break;
350       case CEED_EVAL_GRAD:
351         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
352         CeedChk(ierr);
353         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
354                               CEED_EVAL_GRAD,
355                               impl->outdata[i], &impl->edata[i + numinputfields][e*elemsize*ncomp]);
356         CeedChk(ierr);
357         break;
358       case CEED_EVAL_WEIGHT: {
359         Ceed ceed;
360         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
361         return CeedError(ceed, 1,
362                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
363         break; // Should not occur
364       }
365       case CEED_EVAL_DIV:
366         break; // Not implimented
367       case CEED_EVAL_CURL:
368         break; // Not implimented
369       }
370     }
371   }
372 
373   // Zero lvecs
374   for (CeedInt i=0; i<numoutputfields; i++) {
375     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
376     if (vec == CEED_VECTOR_ACTIVE)
377       vec = outvec;
378     ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
379     }
380 
381   // Output restriction
382   for (CeedInt i=0; i<numoutputfields; i++) {
383     // Restore evec
384     ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
385                                   &impl->edata[i + numinputfields]); CeedChk(ierr);
386     // Get output vector
387     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
388     // Active
389     if (vec == CEED_VECTOR_ACTIVE)
390       vec = outvec;
391     // Restrict
392     ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode); CeedChk(ierr);
393     ierr = CeedElemRestrictionApply(impl->blkrestr[i+impl->numein], CEED_TRANSPOSE,
394                                       lmode, impl->evecs[i+impl->numein], vec,
395                                       request); CeedChk(ierr);
396 
397   }
398 
399   // Restore input arrays
400   for (CeedInt i=0; i<numinputfields; i++) {
401     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
402     CeedChk(ierr);
403     if (emode == CEED_EVAL_WEIGHT) { // Skip
404     } else {
405       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
406                                         (const CeedScalar **) &impl->edata[i]);
407       CeedChk(ierr);
408     }
409   }
410 
411   return 0;
412 }
413 
414 int CeedOperatorCreate_Blocked(CeedOperator op) {
415   int ierr;
416   CeedOperator_Blocked *impl;
417 
418   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
419   op->data = impl;
420   op->Destroy = CeedOperatorDestroy_Blocked;
421   op->Apply = CeedOperatorApply_Blocked;
422   return 0;
423 }
424