xref: /libCEED/backends/opt/ceed-opt-operator.c (revision 1f37b403ef08d286079519fbc830f486b0f4359b)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-opt.h"
19 #include "../ref/ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Opt(CeedOperator op) {
22   int ierr;
23   CeedOperator_Opt *impl;
24   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
25 
26   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
27     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
28     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
29   }
30   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
31   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
32   ierr = CeedFree(&impl->edata); CeedChk(ierr);
33   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
34 
35   for (CeedInt i=0; i<impl->numein; i++) {
36     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
37     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
38   }
39   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
40   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
41 
42   for (CeedInt i=0; i<impl->numeout; i++) {
43     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
44     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
45   }
46   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
47   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
48 
49   ierr = CeedFree(&impl); CeedChk(ierr);
50   return 0;
51 }
52 
53 /*
54   Setup infields or outfields
55  */
56 static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op,
57     bool inOrOut, const CeedInt blksize,
58     CeedElemRestriction *blkrestr,
59     CeedVector *fullevecs, CeedVector *evecs,
60     CeedVector *qvecs, CeedInt starte,
61     CeedInt numfields, CeedInt Q) {
62   CeedInt dim, ierr, ncomp, P;
63   Ceed ceed;
64   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
65   CeedBasis basis;
66   CeedElemRestriction r;
67   CeedOperatorField *opfields;
68   CeedQFunctionField *qffields;
69   if (inOrOut) {
70     ierr = CeedOperatorGetFields(op, NULL, &opfields);
71     CeedChk(ierr);
72     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
73     CeedChk(ierr);
74   } else {
75     ierr = CeedOperatorGetFields(op, &opfields, NULL);
76     CeedChk(ierr);
77     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
78     CeedChk(ierr);
79   }
80 
81   // Loop over fields
82   for (CeedInt i=0; i<numfields; i++) {
83     CeedEvalMode emode;
84     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
85 
86     if (emode != CEED_EVAL_WEIGHT) {
87       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &r);
88       CeedChk(ierr);
89       CeedElemRestriction_Ref *data;
90       ierr = CeedElemRestrictionGetData(r, (void *)&data); CeedChk(ierr);
91       Ceed ceed;
92       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
93       CeedInt nelem, elemsize, ndof;
94       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
95       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
96       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
97       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
98       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
99                                               blksize, ndof, ncomp,
100                                               CEED_MEM_HOST, CEED_COPY_VALUES,
101                                               data->indices, &blkrestr[i+starte]);
102       CeedChk(ierr);
103       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starte], NULL,
104                                              &fullevecs[i+starte]);
105       CeedChk(ierr);
106     }
107 
108     switch(emode) {
109     case CEED_EVAL_NONE:
110       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
111       CeedChk(ierr);
112       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &evecs[i]); CeedChk(ierr);
113       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
114       break;
115     case CEED_EVAL_INTERP:
116       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
117       CeedChk(ierr);
118       ierr = CeedElemRestrictionGetElementSize(r, &P);
119       CeedChk(ierr);
120       ierr = CeedVectorCreate(ceed, P*ncomp*blksize, &evecs[i]); CeedChk(ierr);
121       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
122       break;
123     case CEED_EVAL_GRAD:
124       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
125       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
126       CeedChk(ierr);
127       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
128       ierr = CeedElemRestrictionGetElementSize(r, &P);
129       CeedChk(ierr);
130       ierr = CeedVectorCreate(ceed, P*ncomp*blksize, &evecs[i]); CeedChk(ierr);
131       ierr = CeedVectorCreate(ceed, Q*ncomp*dim*blksize, &qvecs[i]); CeedChk(ierr);
132       break;
133     case CEED_EVAL_WEIGHT: // Only on input fields
134       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
135       ierr = CeedVectorCreate(ceed, Q*blksize, &qvecs[i]); CeedChk(ierr);
136       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
137                             CEED_EVAL_WEIGHT, NULL, qvecs[i]); CeedChk(ierr);
138 
139       break;
140     case CEED_EVAL_DIV:
141       break; // Not implimented
142     case CEED_EVAL_CURL:
143       break; // Not implimented
144     }
145   }
146   return 0;
147 }
148 
149 /*
150   CeedOperator needs to connect all the named fields (be they active or passive)
151   to the named inputs and outputs of its CeedQFunction.
152  */
153 static int CeedOperatorSetup_Opt(CeedOperator op) {
154   int ierr;
155   bool setupdone;
156   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
157   if (setupdone) return 0;
158   Ceed ceed;
159   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
160   Ceed_Opt *ceedimpl;
161   ierr = CeedGetData(ceed, (void *)&ceedimpl); CeedChk(ierr);
162   const CeedInt blksize = ceedimpl->blksize;
163   CeedOperator_Opt *impl;
164   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
165   CeedQFunction qf;
166   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
167   CeedInt Q, numinputfields, numoutputfields;
168   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
169   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
170   CeedChk(ierr);
171   CeedOperatorField *opinputfields, *opoutputfields;
172   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
173   CeedChk(ierr);
174   CeedQFunctionField *qfinputfields, *qfoutputfields;
175   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
176   CeedChk(ierr);
177 
178   // Allocate
179   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->blkrestr);
180   CeedChk(ierr);
181   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
182   CeedChk(ierr);
183   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
184   CeedChk(ierr);
185 
186   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
187   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
188   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
189   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
190   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
191 
192   impl->numein = numinputfields; impl->numeout = numoutputfields;
193 
194   // Set up infield and outfield pointer arrays
195   // Infields
196   ierr = CeedOperatorSetupFields_Opt(qf, op, 0, blksize, impl->blkrestr,
197                                      impl->evecs, impl->evecsin,
198                                      impl->qvecsin, 0,
199                                      numinputfields, Q);
200   CeedChk(ierr);
201   // Outfields
202   ierr = CeedOperatorSetupFields_Opt(qf, op, 1, blksize, impl->blkrestr,
203                                      impl->evecs, impl->evecsout,
204                                      impl->qvecsout, numinputfields,
205                                      numoutputfields, Q);
206   CeedChk(ierr);
207 
208   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
209 
210   return 0;
211 }
212 
213 static inline int CeedOperatorApply_Opt(CeedOperator op,
214                                         const CeedInt blksize, CeedVector invec,
215                                         CeedVector outvec,
216                                         CeedRequest *request) {
217   int ierr;
218   CeedOperator_Opt *impl;
219   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
220   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, ncomp;
221   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
222   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
223   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
224   CeedQFunction qf;
225   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
226   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
227   CeedChk(ierr);
228   CeedTransposeMode lmode;
229   CeedOperatorField *opinputfields, *opoutputfields;
230   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
231   CeedChk(ierr);
232   CeedQFunctionField *qfinputfields, *qfoutputfields;
233   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
234   CeedChk(ierr);
235   CeedEvalMode emode;
236   CeedVector vec;
237   CeedBasis basis;
238   CeedElemRestriction Erestrict;
239   uint64_t state;
240 
241   // Setup
242   ierr = CeedOperatorSetup_Opt(op); CeedChk(ierr);
243 
244   // Input Evecs and Restriction
245   for (CeedInt i=0; i<numinputfields; i++) {
246     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
247     CeedChk(ierr);
248     if (emode == CEED_EVAL_WEIGHT) { // Skip
249     } else {
250       // Get input vector
251       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
252       if (vec != CEED_VECTOR_ACTIVE) {
253         // Restrict
254         ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
255         if (state != impl->inputstate[i]) {
256           ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode);
257           CeedChk(ierr);
258           ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
259                                           lmode, vec, impl->evecs[i], request);
260           CeedChk(ierr);
261           impl->inputstate[i] = state;
262         }
263       } else {
264         // Set Qvec for CEED_EVAL_NONE
265         if (emode == CEED_EVAL_NONE) {
266           ierr = CeedVectorGetArray(impl->evecsin[i], CEED_MEM_HOST,
267                                     &impl->edata[i]); CeedChk(ierr);
268           ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
269                                     CEED_USE_POINTER,
270                                     impl->edata[i]); CeedChk(ierr);
271           ierr = CeedVectorRestoreArray(impl->evecsin[i],
272                                         &impl->edata[i]); CeedChk(ierr);
273         }
274       }
275       // Get evec
276       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
277                                     (const CeedScalar **) &impl->edata[i]);
278       CeedChk(ierr);
279     }
280   }
281 
282   // Output Lvecs, Evecs, and Qvecs
283   for (CeedInt i=0; i<numoutputfields; i++) {
284     // Zero Lvecs
285     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
286     if (vec == CEED_VECTOR_ACTIVE) {
287       if (!impl->add) {
288         vec = outvec;
289         ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
290       }
291     } else {
292       ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
293     }
294     // Set Qvec if needed
295     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
296     CeedChk(ierr);
297     if (emode == CEED_EVAL_NONE) {
298       // Set qvec to single block evec
299       ierr = CeedVectorGetArray(impl->evecsout[i], CEED_MEM_HOST,
300                                 &impl->edata[i + numinputfields]);
301       CeedChk(ierr);
302       ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
303                                 CEED_USE_POINTER,
304                                 impl->edata[i + numinputfields]); CeedChk(ierr);
305       ierr = CeedVectorRestoreArray(impl->evecsout[i],
306                                     &impl->edata[i + numinputfields]);
307       CeedChk(ierr);
308     }
309   }
310   impl->add = false;
311 
312   // Loop through elements
313   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
314     // Input basis apply if needed
315     for (CeedInt i=0; i<numinputfields; i++) {
316       CeedInt activein = 0;
317       // Get elemsize, emode, ncomp
318       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
319       CeedChk(ierr);
320       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
321       CeedChk(ierr);
322       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
323       CeedChk(ierr);
324       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
325       CeedChk(ierr);
326       // Restrict block active input
327       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
328       if (vec == CEED_VECTOR_ACTIVE) {
329         ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode);
330         CeedChk(ierr);
331         ierr = CeedElemRestrictionApplyBlock(impl->blkrestr[i], e/blksize,
332                                              CEED_NOTRANSPOSE, lmode, invec,
333                                              impl->evecsin[i], request);
334         CeedChk(ierr);
335         activein = 1;
336       }
337       // Basis action
338       switch(emode) {
339       case CEED_EVAL_NONE:
340         if (!activein) {
341           ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
342                                     CEED_USE_POINTER,
343                                     &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
344         }
345         break;
346       case CEED_EVAL_INTERP:
347         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
348         CeedChk(ierr);
349         if (!activein) {
350           ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
351                                     CEED_USE_POINTER,
352                                     &impl->edata[i][e*elemsize*ncomp]);
353           CeedChk(ierr);
354         }
355         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
356                               CEED_EVAL_INTERP, impl->evecsin[i],
357                               impl->qvecsin[i]); CeedChk(ierr);
358         break;
359       case CEED_EVAL_GRAD:
360         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
361         CeedChk(ierr);
362          if (!activein) {
363           ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
364                                     CEED_USE_POINTER,
365                                     &impl->edata[i][e*elemsize*ncomp]);
366           CeedChk(ierr);
367         }
368         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
369                               CEED_EVAL_GRAD, impl->evecsin[i],
370                               impl->qvecsin[i]); CeedChk(ierr);
371         break;
372       case CEED_EVAL_WEIGHT:
373         break;  // No action
374       case CEED_EVAL_DIV:
375         break; // Not implimented
376       case CEED_EVAL_CURL:
377         break; // Not implimented
378       }
379     }
380 
381     // Q function
382     ierr = CeedQFunctionApply(qf, Q*blksize, impl->qvecsin, impl->qvecsout);
383     CeedChk(ierr);
384 
385     // Output basis apply and restrict
386     for (CeedInt i=0; i<numoutputfields; i++) {
387       // Get elemsize, emode, ncomp
388       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
389       CeedChk(ierr);
390       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
391       CeedChk(ierr);
392       // Basis action
393       switch(emode) {
394       case CEED_EVAL_NONE:
395         break; // No action
396       case CEED_EVAL_INTERP:
397         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
398         CeedChk(ierr);
399         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
400                               CEED_EVAL_INTERP, impl->qvecsout[i],
401                               impl->evecsout[i]); CeedChk(ierr);
402         break;
403       case CEED_EVAL_GRAD:
404         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
405         CeedChk(ierr);
406         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
407                               CEED_EVAL_GRAD, impl->qvecsout[i],
408                               impl->evecsout[i]); CeedChk(ierr);
409         break;
410       case CEED_EVAL_WEIGHT: {
411         Ceed ceed;
412         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
413         return CeedError(ceed, 1,
414                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
415         break; // Should not occur
416       }
417       case CEED_EVAL_DIV:
418         break; // Not implimented
419       case CEED_EVAL_CURL:
420         break; // Not implimented
421       }
422       // Restrict output block
423       // Get output vector
424       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
425       if (vec == CEED_VECTOR_ACTIVE)
426         vec = outvec;
427       // Restrict
428       ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode);
429       CeedChk(ierr);
430       ierr = CeedElemRestrictionApplyBlock(impl->blkrestr[i+impl->numein],
431                                            e/blksize, CEED_TRANSPOSE,
432                                            lmode, impl->evecsout[i],
433                                            vec, request); CeedChk(ierr);
434     }
435   }
436 
437   // Restore input arrays
438   for (CeedInt i=0; i<numinputfields; i++) {
439     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
440     CeedChk(ierr);
441     if (emode == CEED_EVAL_WEIGHT) { // Skip
442     } else {
443       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
444                                         (const CeedScalar **) &impl->edata[i]);
445       CeedChk(ierr);
446     }
447   }
448 
449   return 0;
450 }
451 
452 int CeedOperatorApply_Opt_1(CeedOperator op, CeedVector invec,
453                             CeedVector outvec, CeedRequest *request) {
454   return CeedOperatorApply_Opt(op, 1, invec, outvec, request);
455 }
456 
457 int CeedOperatorApply_Opt_8(CeedOperator op, CeedVector invec,
458                             CeedVector outvec, CeedRequest *request) {
459   return CeedOperatorApply_Opt(op, 8, invec, outvec, request);
460 }
461 
462 int CeedOperatorCreate_Opt(CeedOperator op) {
463   int ierr;
464   Ceed ceed;
465   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
466   Ceed_Opt *ceedimpl;
467   ierr = CeedGetData(ceed, (void*)&ceedimpl); CeedChk(ierr);
468   CeedInt blksize = ceedimpl->blksize;
469   CeedOperator_Opt *impl;
470 
471   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
472   ierr = CeedOperatorSetData(op, (void *)&impl); CeedChk(ierr);
473 
474   if (blksize == 1) {
475     ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
476                                   CeedOperatorApply_Opt_1); CeedChk(ierr);
477   } else if (blksize == 8) {
478     ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
479                                   CeedOperatorApply_Opt_8); CeedChk(ierr);
480   } else {
481     return CeedError(ceed, 1, "Opt backend cannot use blocksize: %d", blksize);
482   }
483 
484   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
485                                 CeedOperatorDestroy_Opt); CeedChk(ierr);
486   return 0;
487 }
488