xref: /libCEED/backends/opt/ceed-opt-operator.c (revision 288c044332e33f37503f09b6484fec9d0a55fba1)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include "ceed-opt.h"
19 #include "../ref/ceed-ref.h"
20 
21 static int CeedOperatorDestroy_Opt(CeedOperator op) {
22   int ierr;
23   CeedOperator_Opt *impl;
24   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
25 
26   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
27     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
28     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
29   }
30   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
31   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
32   ierr = CeedFree(&impl->edata); CeedChk(ierr);
33   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
34 
35   for (CeedInt i=0; i<impl->numein; i++) {
36     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
37     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
38   }
39   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
40   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
41 
42   for (CeedInt i=0; i<impl->numeout; i++) {
43     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
44     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
45   }
46   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
47   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
48 
49   ierr = CeedFree(&impl); CeedChk(ierr);
50   return 0;
51 }
52 
53 /*
54   Setup infields or outfields
55  */
56 static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op,
57                                        bool inOrOut, const CeedInt blksize,
58                                        CeedElemRestriction *blkrestr,
59                                        CeedVector *fullevecs, CeedVector *evecs,
60                                        CeedVector *qvecs, CeedInt starte,
61                                        CeedInt numfields, CeedInt Q) {
62   CeedInt dim, ierr, ncomp, size, P;
63   Ceed ceed;
64   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
65   CeedBasis basis;
66   CeedElemRestriction r;
67   CeedOperatorField *opfields;
68   CeedQFunctionField *qffields;
69   if (inOrOut) {
70     ierr = CeedOperatorGetFields(op, NULL, &opfields);
71     CeedChk(ierr);
72     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
73     CeedChk(ierr);
74   } else {
75     ierr = CeedOperatorGetFields(op, &opfields, NULL);
76     CeedChk(ierr);
77     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
78     CeedChk(ierr);
79   }
80 
81   // Loop over fields
82   for (CeedInt i=0; i<numfields; i++) {
83     CeedEvalMode emode;
84     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
85 
86     if (emode != CEED_EVAL_WEIGHT) {
87       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &r);
88       CeedChk(ierr);
89       CeedElemRestriction_Ref *data;
90       ierr = CeedElemRestrictionGetData(r, (void *)&data); CeedChk(ierr);
91       Ceed ceed;
92       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
93       CeedInt nelem, elemsize, nnodes;
94       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
95       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
96       ierr = CeedElemRestrictionGetNumNodes(r, &nnodes); CeedChk(ierr);
97       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
98       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
99                                               blksize, nnodes, ncomp,
100                                               CEED_MEM_HOST, CEED_COPY_VALUES,
101                                               data->indices, &blkrestr[i+starte]);
102       CeedChk(ierr);
103       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starte], NULL,
104                                              &fullevecs[i+starte]);
105       CeedChk(ierr);
106     }
107 
108     switch(emode) {
109     case CEED_EVAL_NONE:
110       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChk(ierr);
111       ierr = CeedVectorCreate(ceed, Q*size*blksize, &evecs[i]); CeedChk(ierr);
112       ierr = CeedVectorCreate(ceed, Q*size*blksize, &qvecs[i]); CeedChk(ierr);
113       break;
114     case CEED_EVAL_INTERP:
115       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChk(ierr);
116       ierr = CeedElemRestrictionGetElementSize(r, &P);
117       CeedChk(ierr);
118       ierr = CeedVectorCreate(ceed, P*size*blksize, &evecs[i]); CeedChk(ierr);
119       ierr = CeedVectorCreate(ceed, Q*size*blksize, &qvecs[i]); CeedChk(ierr);
120       break;
121     case CEED_EVAL_GRAD:
122       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
123       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChk(ierr);
124       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
125       ierr = CeedElemRestrictionGetElementSize(r, &P);
126       CeedChk(ierr);
127       ierr = CeedVectorCreate(ceed, P*size/dim*blksize, &evecs[i]); CeedChk(ierr);
128       ierr = CeedVectorCreate(ceed, Q*size*blksize, &qvecs[i]); CeedChk(ierr);
129       break;
130     case CEED_EVAL_WEIGHT: // Only on input fields
131       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
132       ierr = CeedVectorCreate(ceed, Q*blksize, &qvecs[i]); CeedChk(ierr);
133       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
134                             CEED_EVAL_WEIGHT, NULL, qvecs[i]); CeedChk(ierr);
135 
136       break;
137     case CEED_EVAL_DIV:
138       break; // Not implimented
139     case CEED_EVAL_CURL:
140       break; // Not implimented
141     }
142   }
143   return 0;
144 }
145 
146 /*
147   CeedOperator needs to connect all the named fields (be they active or passive)
148   to the named inputs and outputs of its CeedQFunction.
149  */
150 static int CeedOperatorSetup_Opt(CeedOperator op) {
151   int ierr;
152   bool setupdone;
153   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
154   if (setupdone) return 0;
155   Ceed ceed;
156   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
157   Ceed_Opt *ceedimpl;
158   ierr = CeedGetData(ceed, (void *)&ceedimpl); CeedChk(ierr);
159   const CeedInt blksize = ceedimpl->blksize;
160   CeedOperator_Opt *impl;
161   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
162   CeedQFunction qf;
163   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
164   CeedInt Q, numinputfields, numoutputfields;
165   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
166   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
167   CeedChk(ierr);
168   CeedOperatorField *opinputfields, *opoutputfields;
169   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
170   CeedChk(ierr);
171   CeedQFunctionField *qfinputfields, *qfoutputfields;
172   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
173   CeedChk(ierr);
174 
175   // Allocate
176   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->blkrestr);
177   CeedChk(ierr);
178   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
179   CeedChk(ierr);
180   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
181   CeedChk(ierr);
182 
183   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
184   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
185   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
186   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
187   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
188 
189   impl->numein = numinputfields; impl->numeout = numoutputfields;
190 
191   // Set up infield and outfield pointer arrays
192   // Infields
193   ierr = CeedOperatorSetupFields_Opt(qf, op, 0, blksize, impl->blkrestr,
194                                      impl->evecs, impl->evecsin,
195                                      impl->qvecsin, 0,
196                                      numinputfields, Q);
197   CeedChk(ierr);
198   // Outfields
199   ierr = CeedOperatorSetupFields_Opt(qf, op, 1, blksize, impl->blkrestr,
200                                      impl->evecs, impl->evecsout,
201                                      impl->qvecsout, numinputfields,
202                                      numoutputfields, Q);
203   CeedChk(ierr);
204 
205   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
206 
207   return 0;
208 }
209 
210 static inline int CeedOperatorApply_Opt(CeedOperator op,
211                                         const CeedInt blksize, CeedVector invec,
212                                         CeedVector outvec,
213                                         CeedRequest *request) {
214   int ierr;
215   CeedOperator_Opt *impl;
216   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
217   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, size, dim;
218   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
219   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
220   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
221   CeedQFunction qf;
222   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
223   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
224   CeedChk(ierr);
225   CeedTransposeMode lmode;
226   CeedOperatorField *opinputfields, *opoutputfields;
227   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
228   CeedChk(ierr);
229   CeedQFunctionField *qfinputfields, *qfoutputfields;
230   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
231   CeedChk(ierr);
232   CeedEvalMode emode;
233   CeedVector vec;
234   CeedBasis basis;
235   CeedElemRestriction Erestrict;
236   uint64_t state;
237 
238   // Setup
239   ierr = CeedOperatorSetup_Opt(op); CeedChk(ierr);
240 
241   // Input Evecs and Restriction
242   for (CeedInt i=0; i<numinputfields; i++) {
243     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
244     CeedChk(ierr);
245     if (emode == CEED_EVAL_WEIGHT) { // Skip
246     } else {
247       // Get input vector
248       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
249       if (vec != CEED_VECTOR_ACTIVE) {
250         // Restrict
251         ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
252         if (state != impl->inputstate[i]) {
253           ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode);
254           CeedChk(ierr);
255           ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
256                                           lmode, vec, impl->evecs[i], request);
257           CeedChk(ierr);
258           impl->inputstate[i] = state;
259         }
260       } else {
261         // Set Qvec for CEED_EVAL_NONE
262         if (emode == CEED_EVAL_NONE) {
263           ierr = CeedVectorGetArray(impl->evecsin[i], CEED_MEM_HOST,
264                                     &impl->edata[i]); CeedChk(ierr);
265           ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
266                                     CEED_USE_POINTER,
267                                     impl->edata[i]); CeedChk(ierr);
268           ierr = CeedVectorRestoreArray(impl->evecsin[i],
269                                         &impl->edata[i]); CeedChk(ierr);
270         }
271       }
272       // Get evec
273       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
274                                     (const CeedScalar **) &impl->edata[i]);
275       CeedChk(ierr);
276     }
277   }
278 
279   // Output Lvecs, Evecs, and Qvecs
280   for (CeedInt i=0; i<numoutputfields; i++) {
281     // Zero Lvecs
282     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
283     if (vec == CEED_VECTOR_ACTIVE) {
284       if (!impl->add) {
285         vec = outvec;
286         ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
287       }
288     } else {
289       ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
290     }
291     // Set Qvec if needed
292     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
293     CeedChk(ierr);
294     if (emode == CEED_EVAL_NONE) {
295       // Set qvec to single block evec
296       ierr = CeedVectorGetArray(impl->evecsout[i], CEED_MEM_HOST,
297                                 &impl->edata[i + numinputfields]);
298       CeedChk(ierr);
299       ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
300                                 CEED_USE_POINTER,
301                                 impl->edata[i + numinputfields]); CeedChk(ierr);
302       ierr = CeedVectorRestoreArray(impl->evecsout[i],
303                                     &impl->edata[i + numinputfields]);
304       CeedChk(ierr);
305     }
306   }
307   impl->add = false;
308 
309   // Loop through elements
310   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
311     // Input basis apply if needed
312     for (CeedInt i=0; i<numinputfields; i++) {
313       CeedInt activein = 0;
314       // Get elemsize, emode, size
315       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
316       CeedChk(ierr);
317       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
318       CeedChk(ierr);
319       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
320       CeedChk(ierr);
321       ierr = CeedQFunctionFieldGetSize(qfinputfields[i], &size); CeedChk(ierr);
322       // Restrict block active input
323       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
324       if (vec == CEED_VECTOR_ACTIVE) {
325         ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode);
326         CeedChk(ierr);
327         ierr = CeedElemRestrictionApplyBlock(impl->blkrestr[i], e/blksize,
328                                              CEED_NOTRANSPOSE, lmode, invec,
329                                              impl->evecsin[i], request);
330         CeedChk(ierr);
331         activein = 1;
332       }
333       // Basis action
334       switch(emode) {
335       case CEED_EVAL_NONE:
336         if (!activein) {
337           ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
338                                     CEED_USE_POINTER,
339                                     &impl->edata[i][e*Q*size]); CeedChk(ierr);
340         }
341         break;
342       case CEED_EVAL_INTERP:
343         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
344         CeedChk(ierr);
345         if (!activein) {
346           ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
347                                     CEED_USE_POINTER,
348                                     &impl->edata[i][e*elemsize*size]);
349           CeedChk(ierr);
350         }
351         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
352                               CEED_EVAL_INTERP, impl->evecsin[i],
353                               impl->qvecsin[i]); CeedChk(ierr);
354         break;
355       case CEED_EVAL_GRAD:
356         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
357         CeedChk(ierr);
358         if (!activein) {
359           ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
360           ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
361                                     CEED_USE_POINTER,
362                                     &impl->edata[i][e*elemsize*size/dim]);
363           CeedChk(ierr);
364         }
365         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
366                               CEED_EVAL_GRAD, impl->evecsin[i],
367                               impl->qvecsin[i]); CeedChk(ierr);
368         break;
369       case CEED_EVAL_WEIGHT:
370         break;  // No action
371       case CEED_EVAL_DIV:
372         break; // Not implimented
373       case CEED_EVAL_CURL:
374         break; // Not implimented
375       }
376     }
377 
378     // Q function
379     ierr = CeedQFunctionApply(qf, Q*blksize, impl->qvecsin, impl->qvecsout);
380     CeedChk(ierr);
381 
382     // Output basis apply and restrict
383     for (CeedInt i=0; i<numoutputfields; i++) {
384       // Get elemsize, emode, size
385       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
386       CeedChk(ierr);
387       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
388       CeedChk(ierr);
389       // Basis action
390       switch(emode) {
391       case CEED_EVAL_NONE:
392         break; // No action
393       case CEED_EVAL_INTERP:
394         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
395         CeedChk(ierr);
396         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
397                               CEED_EVAL_INTERP, impl->qvecsout[i],
398                               impl->evecsout[i]); CeedChk(ierr);
399         break;
400       case CEED_EVAL_GRAD:
401         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
402         CeedChk(ierr);
403         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
404                               CEED_EVAL_GRAD, impl->qvecsout[i],
405                               impl->evecsout[i]); CeedChk(ierr);
406         break;
407       case CEED_EVAL_WEIGHT: {
408         Ceed ceed;
409         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
410         return CeedError(ceed, 1,
411                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
412         break; // Should not occur
413       }
414       case CEED_EVAL_DIV:
415         break; // Not implimented
416       case CEED_EVAL_CURL:
417         break; // Not implimented
418       }
419       // Restrict output block
420       // Get output vector
421       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
422       if (vec == CEED_VECTOR_ACTIVE)
423         vec = outvec;
424       // Restrict
425       ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode);
426       CeedChk(ierr);
427       ierr = CeedElemRestrictionApplyBlock(impl->blkrestr[i+impl->numein],
428                                            e/blksize, CEED_TRANSPOSE,
429                                            lmode, impl->evecsout[i],
430                                            vec, request); CeedChk(ierr);
431     }
432   }
433 
434   // Restore input arrays
435   for (CeedInt i=0; i<numinputfields; i++) {
436     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
437     CeedChk(ierr);
438     if (emode == CEED_EVAL_WEIGHT) { // Skip
439     } else {
440       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
441                                         (const CeedScalar **) &impl->edata[i]);
442       CeedChk(ierr);
443     }
444   }
445 
446   return 0;
447 }
448 
449 int CeedOperatorApply_Opt_1(CeedOperator op, CeedVector invec,
450                             CeedVector outvec, CeedRequest *request) {
451   return CeedOperatorApply_Opt(op, 1, invec, outvec, request);
452 }
453 
454 int CeedOperatorApply_Opt_8(CeedOperator op, CeedVector invec,
455                             CeedVector outvec, CeedRequest *request) {
456   return CeedOperatorApply_Opt(op, 8, invec, outvec, request);
457 }
458 
459 int CeedOperatorCreate_Opt(CeedOperator op) {
460   int ierr;
461   Ceed ceed;
462   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
463   Ceed_Opt *ceedimpl;
464   ierr = CeedGetData(ceed, (void *)&ceedimpl); CeedChk(ierr);
465   CeedInt blksize = ceedimpl->blksize;
466   CeedOperator_Opt *impl;
467 
468   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
469   ierr = CeedOperatorSetData(op, (void *)&impl); CeedChk(ierr);
470 
471   if (blksize == 1) {
472     ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
473                                   CeedOperatorApply_Opt_1); CeedChk(ierr);
474   } else if (blksize == 8) {
475     ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
476                                   CeedOperatorApply_Opt_8); CeedChk(ierr);
477   } else {
478     return CeedError(ceed, 1, "Opt backend cannot use blocksize: %d", blksize);
479   }
480 
481   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
482                                 CeedOperatorDestroy_Opt); CeedChk(ierr);
483   return 0;
484 }
485