xref: /libCEED/backends/hip-ref/ceed-hip-ref-operator.c (revision 3d8e882215d238700cdceb37404f76ca7fa24eaa)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <hip/hip_runtime.h>
11 #include <assert.h>
12 #include <stdbool.h>
13 #include <string.h>
14 #include "ceed-hip-ref.h"
15 #include "../hip/ceed-hip-compile.h"
16 
17 //------------------------------------------------------------------------------
18 // Destroy operator
19 //------------------------------------------------------------------------------
20 static int CeedOperatorDestroy_Hip(CeedOperator op) {
21   int ierr;
22   CeedOperator_Hip *impl;
23   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
24 
25   // Apply data
26   for (CeedInt i = 0; i < impl->numein + impl->numeout; i++) {
27     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChkBackend(ierr);
28   }
29   ierr = CeedFree(&impl->evecs); CeedChkBackend(ierr);
30 
31   for (CeedInt i = 0; i < impl->numein; i++) {
32     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChkBackend(ierr);
33   }
34   ierr = CeedFree(&impl->qvecsin); CeedChkBackend(ierr);
35 
36   for (CeedInt i = 0; i < impl->numeout; i++) {
37     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChkBackend(ierr);
38   }
39   ierr = CeedFree(&impl->qvecsout); CeedChkBackend(ierr);
40 
41   // QFunction diagonal assembly data
42   for (CeedInt i=0; i<impl->qfnumactivein; i++) {
43     ierr = CeedVectorDestroy(&impl->qfactivein[i]); CeedChkBackend(ierr);
44   }
45   ierr = CeedFree(&impl->qfactivein); CeedChkBackend(ierr);
46 
47   // Diag data
48   if (impl->diag) {
49     Ceed ceed;
50     ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
51     CeedChk_Hip(ceed, hipModuleUnload(impl->diag->module));
52     ierr = CeedFree(&impl->diag->h_emodein); CeedChkBackend(ierr);
53     ierr = CeedFree(&impl->diag->h_emodeout); CeedChkBackend(ierr);
54     ierr = hipFree(impl->diag->d_emodein); CeedChk_Hip(ceed, ierr);
55     ierr = hipFree(impl->diag->d_emodeout); CeedChk_Hip(ceed, ierr);
56     ierr = hipFree(impl->diag->d_identity); CeedChk_Hip(ceed, ierr);
57     ierr = hipFree(impl->diag->d_interpin); CeedChk_Hip(ceed, ierr);
58     ierr = hipFree(impl->diag->d_interpout); CeedChk_Hip(ceed, ierr);
59     ierr = hipFree(impl->diag->d_gradin); CeedChk_Hip(ceed, ierr);
60     ierr = hipFree(impl->diag->d_gradout); CeedChk_Hip(ceed, ierr);
61     ierr = CeedElemRestrictionDestroy(&impl->diag->pbdiagrstr);
62     CeedChkBackend(ierr);
63     ierr = CeedVectorDestroy(&impl->diag->elemdiag); CeedChkBackend(ierr);
64     ierr = CeedVectorDestroy(&impl->diag->pbelemdiag); CeedChkBackend(ierr);
65   }
66   ierr = CeedFree(&impl->diag); CeedChkBackend(ierr);
67 
68   ierr = CeedFree(&impl); CeedChkBackend(ierr);
69   return CEED_ERROR_SUCCESS;
70 }
71 
72 //------------------------------------------------------------------------------
73 // Setup infields or outfields
74 //------------------------------------------------------------------------------
75 static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op,
76                                        bool isinput, CeedVector *evecs,
77                                        CeedVector *qvecs, CeedInt starte,
78                                        CeedInt numfields, CeedInt Q,
79                                        CeedInt numelements) {
80   CeedInt dim, ierr, size;
81   Ceed ceed;
82   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
83   CeedBasis basis;
84   CeedElemRestriction Erestrict;
85   CeedOperatorField *opfields;
86   CeedQFunctionField *qffields;
87   CeedVector fieldvec;
88   bool strided;
89   bool skiprestrict;
90 
91   if (isinput) {
92     ierr = CeedOperatorGetFields(op, NULL, &opfields, NULL, NULL);
93     CeedChkBackend(ierr);
94     ierr = CeedQFunctionGetFields(qf, NULL, &qffields, NULL, NULL);
95     CeedChkBackend(ierr);
96   } else {
97     ierr = CeedOperatorGetFields(op, NULL, NULL, NULL, &opfields);
98     CeedChkBackend(ierr);
99     ierr = CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qffields);
100     CeedChkBackend(ierr);
101   }
102 
103   // Loop over fields
104   for (CeedInt i = 0; i < numfields; i++) {
105     CeedEvalMode emode;
106     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChkBackend(ierr);
107 
108     strided = false;
109     skiprestrict = false;
110     if (emode != CEED_EVAL_WEIGHT) {
111       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &Erestrict);
112       CeedChkBackend(ierr);
113 
114       // Check whether this field can skip the element restriction:
115       // must be passive input, with emode NONE, and have a strided restriction with
116       // CEED_STRIDES_BACKEND.
117 
118       // First, check whether the field is input or output:
119       if (isinput) {
120         // Check for passive input:
121         ierr = CeedOperatorFieldGetVector(opfields[i], &fieldvec); CeedChkBackend(ierr);
122         if (fieldvec != CEED_VECTOR_ACTIVE) {
123           // Check emode
124           if (emode == CEED_EVAL_NONE) {
125             // Check for strided restriction
126             ierr = CeedElemRestrictionIsStrided(Erestrict, &strided);
127             CeedChkBackend(ierr);
128             if (strided) {
129               // Check if vector is already in preferred backend ordering
130               ierr = CeedElemRestrictionHasBackendStrides(Erestrict,
131                      &skiprestrict); CeedChkBackend(ierr);
132             }
133           }
134         }
135       }
136       if (skiprestrict) {
137         // We do not need an E-Vector, but will use the input field vector's data
138         // directly in the operator application.
139         evecs[i + starte] = NULL;
140       } else {
141         ierr = CeedElemRestrictionCreateVector(Erestrict, NULL,
142                                                &evecs[i + starte]);
143         CeedChkBackend(ierr);
144       }
145     }
146 
147     switch (emode) {
148     case CEED_EVAL_NONE:
149       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChkBackend(ierr);
150       ierr = CeedVectorCreate(ceed, numelements * Q * size, &qvecs[i]);
151       CeedChkBackend(ierr);
152       break;
153     case CEED_EVAL_INTERP:
154       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChkBackend(ierr);
155       ierr = CeedVectorCreate(ceed, numelements * Q * size, &qvecs[i]);
156       CeedChkBackend(ierr);
157       break;
158     case CEED_EVAL_GRAD:
159       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChkBackend(ierr);
160       ierr = CeedQFunctionFieldGetSize(qffields[i], &size); CeedChkBackend(ierr);
161       ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
162       ierr = CeedVectorCreate(ceed, numelements * Q * size, &qvecs[i]);
163       CeedChkBackend(ierr);
164       break;
165     case CEED_EVAL_WEIGHT: // Only on input fields
166       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChkBackend(ierr);
167       ierr = CeedVectorCreate(ceed, numelements * Q, &qvecs[i]); CeedChkBackend(ierr);
168       ierr = CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE,
169                             CEED_EVAL_WEIGHT, NULL, qvecs[i]); CeedChkBackend(ierr);
170       break;
171     case CEED_EVAL_DIV:
172       break; // TODO: Not implemented
173     case CEED_EVAL_CURL:
174       break; // TODO: Not implemented
175     }
176   }
177   return CEED_ERROR_SUCCESS;
178 }
179 
180 //------------------------------------------------------------------------------
181 // CeedOperator needs to connect all the named fields (be they active or passive)
182 //   to the named inputs and outputs of its CeedQFunction.
183 //------------------------------------------------------------------------------
184 static int CeedOperatorSetup_Hip(CeedOperator op) {
185   int ierr;
186   bool setupdone;
187   ierr = CeedOperatorIsSetupDone(op, &setupdone); CeedChkBackend(ierr);
188   if (setupdone)
189     return CEED_ERROR_SUCCESS;
190   Ceed ceed;
191   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
192   CeedOperator_Hip *impl;
193   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
194   CeedQFunction qf;
195   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
196   CeedInt Q, numelements, numinputfields, numoutputfields;
197   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
198   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChkBackend(ierr);
199   CeedOperatorField *opinputfields, *opoutputfields;
200   ierr = CeedOperatorGetFields(op, &numinputfields, &opinputfields,
201                                &numoutputfields, &opoutputfields);
202   CeedChkBackend(ierr);
203   CeedQFunctionField *qfinputfields, *qfoutputfields;
204   ierr = CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields);
205   CeedChkBackend(ierr);
206 
207   // Allocate
208   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
209   CeedChkBackend(ierr);
210 
211   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->qvecsin); CeedChkBackend(ierr);
212   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->qvecsout); CeedChkBackend(ierr);
213 
214   impl->numein = numinputfields; impl->numeout = numoutputfields;
215 
216   // Set up infield and outfield evecs and qvecs
217   // Infields
218   ierr = CeedOperatorSetupFields_Hip(qf, op, true,
219                                      impl->evecs, impl->qvecsin, 0,
220                                      numinputfields, Q, numelements);
221   CeedChkBackend(ierr);
222 
223   // Outfields
224   ierr = CeedOperatorSetupFields_Hip(qf, op, false,
225                                      impl->evecs, impl->qvecsout,
226                                      numinputfields, numoutputfields, Q,
227                                      numelements); CeedChkBackend(ierr);
228 
229   ierr = CeedOperatorSetSetupDone(op); CeedChkBackend(ierr);
230   return CEED_ERROR_SUCCESS;
231 }
232 
233 //------------------------------------------------------------------------------
234 // Setup Operator Inputs
235 //------------------------------------------------------------------------------
236 static inline int CeedOperatorSetupInputs_Hip(CeedInt numinputfields,
237     CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields,
238     CeedVector invec, const bool skipactive, CeedScalar *edata[2*CEED_FIELD_MAX],
239     CeedOperator_Hip *impl, CeedRequest *request) {
240   CeedInt ierr;
241   CeedEvalMode emode;
242   CeedVector vec;
243   CeedElemRestriction Erestrict;
244 
245   for (CeedInt i = 0; i < numinputfields; i++) {
246     // Get input vector
247     ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
248     if (vec == CEED_VECTOR_ACTIVE) {
249       if (skipactive)
250         continue;
251       else
252         vec = invec;
253     }
254 
255     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
256     CeedChkBackend(ierr);
257     if (emode == CEED_EVAL_WEIGHT) { // Skip
258     } else {
259       // Get input vector
260       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
261       // Get input element restriction
262       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
263       CeedChkBackend(ierr);
264       if (vec == CEED_VECTOR_ACTIVE)
265         vec = invec;
266       // Restrict, if necessary
267       if (!impl->evecs[i]) {
268         // No restriction for this field; read data directly from vec.
269         ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE,
270                                       (const CeedScalar **) &edata[i]);
271         CeedChkBackend(ierr);
272       } else {
273         ierr = CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE, vec,
274                                         impl->evecs[i], request); CeedChkBackend(ierr);
275         // Get evec
276         ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_DEVICE,
277                                       (const CeedScalar **) &edata[i]);
278         CeedChkBackend(ierr);
279       }
280     }
281   }
282   return CEED_ERROR_SUCCESS;
283 }
284 
285 //------------------------------------------------------------------------------
286 // Input Basis Action
287 //------------------------------------------------------------------------------
288 static inline int CeedOperatorInputBasis_Hip(CeedInt numelements,
289     CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields,
290     CeedInt numinputfields, const bool skipactive,
291     CeedScalar *edata[2*CEED_FIELD_MAX], CeedOperator_Hip *impl) {
292   CeedInt ierr;
293   CeedInt elemsize, size;
294   CeedElemRestriction Erestrict;
295   CeedEvalMode emode;
296   CeedBasis basis;
297 
298   for (CeedInt i=0; i<numinputfields; i++) {
299     // Skip active input
300     if (skipactive) {
301       CeedVector vec;
302       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
303       if (vec == CEED_VECTOR_ACTIVE)
304         continue;
305     }
306     // Get elemsize, emode, size
307     ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
308     CeedChkBackend(ierr);
309     ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
310     CeedChkBackend(ierr);
311     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
312     CeedChkBackend(ierr);
313     ierr = CeedQFunctionFieldGetSize(qfinputfields[i], &size); CeedChkBackend(ierr);
314     // Basis action
315     switch (emode) {
316     case CEED_EVAL_NONE:
317       ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_DEVICE,
318                                 CEED_USE_POINTER, edata[i]); CeedChkBackend(ierr);
319       break;
320     case CEED_EVAL_INTERP:
321       ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
322       CeedChkBackend(ierr);
323       ierr = CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE,
324                             CEED_EVAL_INTERP, impl->evecs[i],
325                             impl->qvecsin[i]); CeedChkBackend(ierr);
326       break;
327     case CEED_EVAL_GRAD:
328       ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
329       CeedChkBackend(ierr);
330       ierr = CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE,
331                             CEED_EVAL_GRAD, impl->evecs[i],
332                             impl->qvecsin[i]); CeedChkBackend(ierr);
333       break;
334     case CEED_EVAL_WEIGHT:
335       break; // No action
336     case CEED_EVAL_DIV:
337       break; // TODO: Not implemented
338     case CEED_EVAL_CURL:
339       break; // TODO: Not implemented
340     }
341   }
342   return CEED_ERROR_SUCCESS;
343 }
344 
345 //------------------------------------------------------------------------------
346 // Restore Input Vectors
347 //------------------------------------------------------------------------------
348 static inline int CeedOperatorRestoreInputs_Hip(CeedInt numinputfields,
349     CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields,
350     const bool skipactive, CeedScalar *edata[2*CEED_FIELD_MAX],
351     CeedOperator_Hip *impl) {
352   CeedInt ierr;
353   CeedEvalMode emode;
354   CeedVector vec;
355 
356   for (CeedInt i = 0; i < numinputfields; i++) {
357     // Skip active input
358     if (skipactive) {
359       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
360       if (vec == CEED_VECTOR_ACTIVE)
361         continue;
362     }
363     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
364     CeedChkBackend(ierr);
365     if (emode == CEED_EVAL_WEIGHT) { // Skip
366     } else {
367       if (!impl->evecs[i]) {  // This was a skiprestrict case
368         ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
369         ierr = CeedVectorRestoreArrayRead(vec,
370                                           (const CeedScalar **)&edata[i]);
371         CeedChkBackend(ierr);
372       } else {
373         ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
374                                           (const CeedScalar **) &edata[i]);
375         CeedChkBackend(ierr);
376       }
377     }
378   }
379   return CEED_ERROR_SUCCESS;
380 }
381 
382 //------------------------------------------------------------------------------
383 // Apply and add to output
384 //------------------------------------------------------------------------------
385 static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector invec,
386                                     CeedVector outvec, CeedRequest *request) {
387   int ierr;
388   CeedOperator_Hip *impl;
389   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
390   CeedQFunction qf;
391   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
392   CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, size;
393   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
394   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChkBackend(ierr);
395   CeedOperatorField *opinputfields, *opoutputfields;
396   ierr = CeedOperatorGetFields(op, &numinputfields, &opinputfields,
397                                &numoutputfields, &opoutputfields);
398   CeedChkBackend(ierr);
399   CeedQFunctionField *qfinputfields, *qfoutputfields;
400   ierr = CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields);
401   CeedChkBackend(ierr);
402   CeedEvalMode emode;
403   CeedVector vec;
404   CeedBasis basis;
405   CeedElemRestriction Erestrict;
406   CeedScalar *edata[2*CEED_FIELD_MAX];
407 
408   // Setup
409   ierr = CeedOperatorSetup_Hip(op); CeedChkBackend(ierr);
410 
411   // Input Evecs and Restriction
412   ierr = CeedOperatorSetupInputs_Hip(numinputfields, qfinputfields,
413                                      opinputfields, invec, false, edata,
414                                      impl, request); CeedChkBackend(ierr);
415 
416   // Input basis apply if needed
417   ierr = CeedOperatorInputBasis_Hip(numelements, qfinputfields, opinputfields,
418                                     numinputfields, false, edata, impl);
419   CeedChkBackend(ierr);
420 
421   // Output pointers, as necessary
422   for (CeedInt i = 0; i < numoutputfields; i++) {
423     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
424     CeedChkBackend(ierr);
425     if (emode == CEED_EVAL_NONE) {
426       // Set the output Q-Vector to use the E-Vector data directly.
427       ierr = CeedVectorGetArrayWrite(impl->evecs[i + impl->numein], CEED_MEM_DEVICE,
428                                      &edata[i + numinputfields]); CeedChkBackend(ierr);
429       ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_DEVICE,
430                                 CEED_USE_POINTER, edata[i + numinputfields]);
431       CeedChkBackend(ierr);
432     }
433   }
434 
435   // Q function
436   ierr = CeedQFunctionApply(qf, numelements * Q, impl->qvecsin, impl->qvecsout);
437   CeedChkBackend(ierr);
438 
439   // Output basis apply if needed
440   for (CeedInt i = 0; i < numoutputfields; i++) {
441     // Get elemsize, emode, size
442     ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
443     CeedChkBackend(ierr);
444     ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
445     CeedChkBackend(ierr);
446     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
447     CeedChkBackend(ierr);
448     ierr = CeedQFunctionFieldGetSize(qfoutputfields[i], &size);
449     CeedChkBackend(ierr);
450     // Basis action
451     switch (emode) {
452     case CEED_EVAL_NONE:
453       break;
454     case CEED_EVAL_INTERP:
455       ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
456       CeedChkBackend(ierr);
457       ierr = CeedBasisApply(basis, numelements, CEED_TRANSPOSE,
458                             CEED_EVAL_INTERP, impl->qvecsout[i],
459                             impl->evecs[i + impl->numein]); CeedChkBackend(ierr);
460       break;
461     case CEED_EVAL_GRAD:
462       ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
463       CeedChkBackend(ierr);
464       ierr = CeedBasisApply(basis, numelements, CEED_TRANSPOSE,
465                             CEED_EVAL_GRAD, impl->qvecsout[i],
466                             impl->evecs[i + impl->numein]); CeedChkBackend(ierr);
467       break;
468     // LCOV_EXCL_START
469     case CEED_EVAL_WEIGHT: {
470       Ceed ceed;
471       ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
472       return CeedError(ceed, CEED_ERROR_BACKEND,
473                        "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
474       break; // Should not occur
475     }
476     case CEED_EVAL_DIV:
477       break; // TODO: Not implemented
478     case CEED_EVAL_CURL:
479       break; // TODO: Not implemented
480       // LCOV_EXCL_STOP
481     }
482   }
483 
484   // Output restriction
485   for (CeedInt i = 0; i < numoutputfields; i++) {
486     // Restore evec
487     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
488     CeedChkBackend(ierr);
489     if (emode == CEED_EVAL_NONE) {
490       ierr = CeedVectorRestoreArray(impl->evecs[i+impl->numein],
491                                     &edata[i + numinputfields]);
492       CeedChkBackend(ierr);
493     }
494     // Get output vector
495     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec);
496     CeedChkBackend(ierr);
497     // Restrict
498     ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
499     CeedChkBackend(ierr);
500     // Active
501     if (vec == CEED_VECTOR_ACTIVE)
502       vec = outvec;
503 
504     ierr = CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE,
505                                     impl->evecs[i + impl->numein], vec,
506                                     request); CeedChkBackend(ierr);
507   }
508 
509   // Restore input arrays
510   ierr = CeedOperatorRestoreInputs_Hip(numinputfields, qfinputfields,
511                                        opinputfields, false, edata, impl);
512   CeedChkBackend(ierr);
513   return CEED_ERROR_SUCCESS;
514 }
515 
516 //------------------------------------------------------------------------------
517 // Core code for assembling linear QFunction
518 //------------------------------------------------------------------------------
519 static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op,
520     bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
521     CeedRequest *request) {
522   int ierr;
523   CeedOperator_Hip *impl;
524   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
525   CeedQFunction qf;
526   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
527   CeedInt Q, numelements, numinputfields, numoutputfields, size;
528   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
529   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChkBackend(ierr);
530   CeedOperatorField *opinputfields, *opoutputfields;
531   ierr = CeedOperatorGetFields(op, &numinputfields, &opinputfields,
532                                &numoutputfields, &opoutputfields);
533   CeedChkBackend(ierr);
534   CeedQFunctionField *qfinputfields, *qfoutputfields;
535   ierr = CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields);
536   CeedChkBackend(ierr);
537   CeedVector vec;
538   CeedInt numactivein = impl->qfnumactivein, numactiveout = impl->qfnumactiveout;
539   CeedVector *activein = impl->qfactivein;
540   CeedScalar *a, *tmp;
541   Ceed ceed, ceedparent;
542   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
543   ierr = CeedGetOperatorFallbackParentCeed(ceed, &ceedparent);
544   CeedChkBackend(ierr);
545   ceedparent = ceedparent ? ceedparent : ceed;
546   CeedScalar *edata[2*CEED_FIELD_MAX];
547 
548   // Setup
549   ierr = CeedOperatorSetup_Hip(op); CeedChkBackend(ierr);
550 
551   // Check for identity
552   bool identityqf;
553   ierr = CeedQFunctionIsIdentity(qf, &identityqf); CeedChkBackend(ierr);
554   if (identityqf)
555     // LCOV_EXCL_START
556     return CeedError(ceed, CEED_ERROR_BACKEND,
557                      "Assembling identity QFunctions not supported");
558   // LCOV_EXCL_STOP
559 
560   // Input Evecs and Restriction
561   ierr = CeedOperatorSetupInputs_Hip(numinputfields, qfinputfields,
562                                      opinputfields, NULL, true, edata,
563                                      impl, request); CeedChkBackend(ierr);
564 
565   // Count number of active input fields
566   if (!numactivein) {
567     for (CeedInt i=0; i<numinputfields; i++) {
568       // Get input vector
569       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
570       // Check if active input
571       if (vec == CEED_VECTOR_ACTIVE) {
572         ierr = CeedQFunctionFieldGetSize(qfinputfields[i], &size); CeedChkBackend(ierr);
573         ierr = CeedVectorSetValue(impl->qvecsin[i], 0.0); CeedChkBackend(ierr);
574         ierr = CeedVectorGetArray(impl->qvecsin[i], CEED_MEM_DEVICE, &tmp);
575         CeedChkBackend(ierr);
576         ierr = CeedRealloc(numactivein + size, &activein); CeedChkBackend(ierr);
577         for (CeedInt field = 0; field < size; field++) {
578           ierr = CeedVectorCreate(ceed, Q*numelements,
579                                   &activein[numactivein+field]); CeedChkBackend(ierr);
580           ierr = CeedVectorSetArray(activein[numactivein+field], CEED_MEM_DEVICE,
581                                     CEED_USE_POINTER, &tmp[field*Q*numelements]);
582           CeedChkBackend(ierr);
583         }
584         numactivein += size;
585         ierr = CeedVectorRestoreArray(impl->qvecsin[i], &tmp); CeedChkBackend(ierr);
586       }
587     }
588     impl->qfnumactivein = numactivein;
589     impl->qfactivein = activein;
590   }
591 
592   // Count number of active output fields
593   if (!numactiveout) {
594     for (CeedInt i=0; i<numoutputfields; i++) {
595       // Get output vector
596       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec);
597       CeedChkBackend(ierr);
598       // Check if active output
599       if (vec == CEED_VECTOR_ACTIVE) {
600         ierr = CeedQFunctionFieldGetSize(qfoutputfields[i], &size);
601         CeedChkBackend(ierr);
602         numactiveout += size;
603       }
604     }
605     impl->qfnumactiveout = numactiveout;
606   }
607 
608   // Check sizes
609   if (!numactivein || !numactiveout)
610     // LCOV_EXCL_START
611     return CeedError(ceed, CEED_ERROR_BACKEND,
612                      "Cannot assemble QFunction without active inputs "
613                      "and outputs");
614   // LCOV_EXCL_STOP
615 
616   // Build objects if needed
617   if (build_objects) {
618     // Create output restriction
619     CeedInt strides[3] = {1, numelements*Q, Q}; /* *NOPAD* */
620     ierr = CeedElemRestrictionCreateStrided(ceedparent, numelements, Q,
621                                             numactivein*numactiveout,
622                                             numactivein*numactiveout*numelements*Q,
623                                             strides, rstr); CeedChkBackend(ierr);
624     // Create assembled vector
625     ierr = CeedVectorCreate(ceedparent, numelements*Q*numactivein*numactiveout,
626                             assembled); CeedChkBackend(ierr);
627   }
628   ierr = CeedVectorSetValue(*assembled, 0.0); CeedChkBackend(ierr);
629   ierr = CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &a);
630   CeedChkBackend(ierr);
631 
632   // Input basis apply
633   ierr = CeedOperatorInputBasis_Hip(numelements, qfinputfields, opinputfields,
634                                     numinputfields, true, edata, impl);
635   CeedChkBackend(ierr);
636 
637   // Assemble QFunction
638   for (CeedInt in=0; in<numactivein; in++) {
639     // Set Inputs
640     ierr = CeedVectorSetValue(activein[in], 1.0); CeedChkBackend(ierr);
641     if (numactivein > 1) {
642       ierr = CeedVectorSetValue(activein[(in+numactivein-1)%numactivein],
643                                 0.0); CeedChkBackend(ierr);
644     }
645     // Set Outputs
646     for (CeedInt out=0; out<numoutputfields; out++) {
647       // Get output vector
648       ierr = CeedOperatorFieldGetVector(opoutputfields[out], &vec);
649       CeedChkBackend(ierr);
650       // Check if active output
651       if (vec == CEED_VECTOR_ACTIVE) {
652         CeedVectorSetArray(impl->qvecsout[out], CEED_MEM_DEVICE,
653                            CEED_USE_POINTER, a); CeedChkBackend(ierr);
654         ierr = CeedQFunctionFieldGetSize(qfoutputfields[out], &size);
655         CeedChkBackend(ierr);
656         a += size*Q*numelements; // Advance the pointer by the size of the output
657       }
658     }
659     // Apply QFunction
660     ierr = CeedQFunctionApply(qf, Q*numelements, impl->qvecsin, impl->qvecsout);
661     CeedChkBackend(ierr);
662   }
663 
664   // Un-set output Qvecs to prevent accidental overwrite of Assembled
665   for (CeedInt out=0; out<numoutputfields; out++) {
666     // Get output vector
667     ierr = CeedOperatorFieldGetVector(opoutputfields[out], &vec);
668     CeedChkBackend(ierr);
669     // Check if active output
670     if (vec == CEED_VECTOR_ACTIVE) {
671       ierr = CeedVectorTakeArray(impl->qvecsout[out], CEED_MEM_DEVICE, NULL);
672       CeedChkBackend(ierr);
673     }
674   }
675 
676   // Restore input arrays
677   ierr = CeedOperatorRestoreInputs_Hip(numinputfields, qfinputfields,
678                                        opinputfields, true, edata, impl);
679   CeedChkBackend(ierr);
680 
681   // Restore output
682   ierr = CeedVectorRestoreArray(*assembled, &a); CeedChkBackend(ierr);
683 
684   return CEED_ERROR_SUCCESS;
685 }
686 
687 //------------------------------------------------------------------------------
688 // Assemble Linear QFunction
689 //------------------------------------------------------------------------------
690 static int CeedOperatorLinearAssembleQFunction_Hip(CeedOperator op,
691     CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
692   return CeedOperatorLinearAssembleQFunctionCore_Hip(op, true, assembled, rstr,
693          request);
694 }
695 
696 //------------------------------------------------------------------------------
697 // Update Assembled Linear QFunction
698 //------------------------------------------------------------------------------
699 static int CeedOperatorLinearAssembleQFunctionUpdate_Hip(CeedOperator op,
700     CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
701   return CeedOperatorLinearAssembleQFunctionCore_Hip(op, false, &assembled, &rstr,
702          request);
703 }
704 
//------------------------------------------------------------------------------
// Diagonal assembly kernels
//
// The text below is device source code turned into a C string by QUOTE. The
// setup routine in this file hands it to CeedCompileHip together with values
// for NUMEMODEIN, NUMEMODEOUT, NNODES, NQPTS and NCOMP, which the kernels use
// as compile-time constants. Comments inside QUOTE are removed by the
// preprocessor before stringification, so they do not alter the string.
//------------------------------------------------------------------------------
// *INDENT-OFF*
static const char *diagonalkernels = QUOTE(

// Device-side definition of the eval mode enum used by the kernels below.
// NOTE(review): the host copies its CeedEvalMode arrays to the device
// byte-for-byte, so these values presumably must match the host enum - confirm.
typedef enum {
  /// Perform no evaluation (either because there is no data or it is already at
  /// quadrature points)
  CEED_EVAL_NONE   = 0,
  /// Interpolate from nodes to quadrature points
  CEED_EVAL_INTERP = 1,
  /// Evaluate gradients at quadrature points from input in a nodal basis
  CEED_EVAL_GRAD   = 2,
  /// Evaluate divergence at quadrature points from input in a nodal basis
  CEED_EVAL_DIV    = 4,
  /// Evaluate curl at quadrature points from input in a nodal basis
  CEED_EVAL_CURL   = 8,
  /// Using no input, evaluate quadrature weights on the reference element
  CEED_EVAL_WEIGHT = 16,
} CeedEvalMode;

//------------------------------------------------------------------------------
// Get Basis Emode Pointer
//
// Stores in *basisptr the matrix that corresponds to emode: identity for NONE,
// interp for INTERP, grad for GRAD. For WEIGHT, DIV and CURL *basisptr is left
// unchanged, so it keeps whatever value the caller initialized it with.
//------------------------------------------------------------------------------
extern "C" __device__ void CeedOperatorGetBasisPointer_Hip(const CeedScalar **basisptr,
    CeedEvalMode emode, const CeedScalar *identity, const CeedScalar *interp,
    const CeedScalar *grad) {
  switch (emode) {
  case CEED_EVAL_NONE:
    *basisptr = identity;
    break;
  case CEED_EVAL_INTERP:
    *basisptr = interp;
    break;
  case CEED_EVAL_GRAD:
    *basisptr = grad;
    break;
  case CEED_EVAL_WEIGHT:
  case CEED_EVAL_DIV:
  case CEED_EVAL_CURL:
    break; // Caught by QF Assembly
  }
}

//------------------------------------------------------------------------------
// Core code for diagonal assembly
//
// For every element e, every pair of output/input eval mode entries and every
// component (pair), accumulates sum over q of bt[q,tid] * qfvalue * b[q,tid]
// into elemdiagarray, where tid is threadIdx.x.
//
//   nelem             number of elements to process
//   maxnorm           scales the skip threshold: entries of assembledqfarray
//                     with magnitude not above maxnorm*1e-12 are ignored
//   pointBlock        true: write all NCOMP x NCOMP component pairs,
//                     false: write only matching component pairs
//   identity, interp*, grad*   basis matrices indexed as [q*NNODES + node]
//   emodein, emodeout arrays of NUMEMODEIN / NUMEMODEOUT eval modes
//   assembledqfarray  read-only input, indexed as
//                     [ein][compIn][eout][compOut][e][q]
//   elemdiagarray     output that is accumulated into (+=), indexed as
//                     [compOut][compIn][e][node] or [compOut][e][node]
//------------------------------------------------------------------------------
__device__ void diagonalCore(const CeedInt nelem,
    const CeedScalar maxnorm, const bool pointBlock,
    const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const int tid = threadIdx.x; // running with P threads, tid is evec node
  const CeedScalar qfvaluebound = maxnorm*1e-12;

  // Compute the diagonal of B^T D B
  // Each element
  // Elements are spread over blockIdx.x and threadIdx.z; the stride covers all
  // blocks times the z extent of a block
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z) {
    // dout counts the GRAD entries seen so far in emodeout, so each GRAD entry
    // addresses its own NQPTS*NNODES slab of gradout
    CeedInt dout = -1;
    // Each basis eval mode pair
    for (CeedInt eout = 0; eout < NUMEMODEOUT; eout++) {
      const CeedScalar *bt = NULL;
      if (emodeout[eout] == CEED_EVAL_GRAD)
        dout += 1;
      // The grad argument is only stored when the mode is GRAD, in which case
      // dout has already been advanced to a non-negative value
      CeedOperatorGetBasisPointer_Hip(&bt, emodeout[eout], identity, interpout,
                                      &gradout[dout*NQPTS*NNODES]);
      // din plays the same role for the input eval modes and gradin
      CeedInt din = -1;
      for (CeedInt ein = 0; ein < NUMEMODEIN; ein++) {
        const CeedScalar *b = NULL;
        if (emodein[ein] == CEED_EVAL_GRAD)
          din += 1;
        CeedOperatorGetBasisPointer_Hip(&b, emodein[ein], identity, interpin,
                                        &gradin[din*NQPTS*NNODES]);
        // Each component
        for (CeedInt compOut = 0; compOut < NCOMP; compOut++) {
          // Each qpoint/node pair
          if (pointBlock) {
            // Point Block Diagonal
            for (CeedInt compIn = 0; compIn < NCOMP; compIn++) {
              CeedScalar evalue = 0.;
              for (CeedInt q = 0; q < NQPTS; q++) {
                const CeedScalar qfvalue =
                  assembledqfarray[((((ein*NCOMP+compIn)*NUMEMODEOUT+eout)*
                                     NCOMP+compOut)*nelem+e)*NQPTS+q];
                // Skip entries that are negligible relative to maxnorm
                if (abs(qfvalue) > qfvaluebound)
                  evalue += bt[q*NNODES+tid] * qfvalue * b[q*NNODES+tid];
              }
              elemdiagarray[((compOut*NCOMP+compIn)*nelem+e)*NNODES+tid] += evalue;
            }
          } else {
            // Diagonal Only
            // Same sum as above with the input component fixed to compOut
            CeedScalar evalue = 0.;
            for (CeedInt q = 0; q < NQPTS; q++) {
              const CeedScalar qfvalue =
                assembledqfarray[((((ein*NCOMP+compOut)*NUMEMODEOUT+eout)*
                                   NCOMP+compOut)*nelem+e)*NQPTS+q];
              if (abs(qfvalue) > qfvaluebound)
                evalue += bt[q*NNODES+tid] * qfvalue * b[q*NNODES+tid];
            }
            elemdiagarray[(compOut*nelem+e)*NNODES+tid] += evalue;
          }
        }
      }
    }
  }
}

//------------------------------------------------------------------------------
// Linear diagonal
//
// Kernel entry point: runs diagonalCore with pointBlock = false.
//------------------------------------------------------------------------------
extern "C" __global__ void linearDiagonal(const CeedInt nelem,
    const CeedScalar maxnorm, const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  diagonalCore(nelem, maxnorm, false, identity, interpin, gradin, interpout,
               gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
}

//------------------------------------------------------------------------------
// Linear point block diagonal
//
// Kernel entry point: runs diagonalCore with pointBlock = true.
//------------------------------------------------------------------------------
extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem,
    const CeedScalar maxnorm, const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  diagonalCore(nelem, maxnorm, true, identity, interpin, gradin, interpout,
               gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
}

);
// *INDENT-ON*
847 
848 //------------------------------------------------------------------------------
849 // Create point block restriction
850 //------------------------------------------------------------------------------
851 static int CreatePBRestriction(CeedElemRestriction rstr,
852                                CeedElemRestriction *pbRstr) {
853   int ierr;
854   Ceed ceed;
855   ierr = CeedElemRestrictionGetCeed(rstr, &ceed); CeedChkBackend(ierr);
856   const CeedInt *offsets;
857   ierr = CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets);
858   CeedChkBackend(ierr);
859 
860   // Expand offsets
861   CeedInt nelem, ncomp, elemsize, compstride, max = 1, *pbOffsets;
862   ierr = CeedElemRestrictionGetNumElements(rstr, &nelem); CeedChkBackend(ierr);
863   ierr = CeedElemRestrictionGetNumComponents(rstr, &ncomp); CeedChkBackend(ierr);
864   ierr = CeedElemRestrictionGetElementSize(rstr, &elemsize); CeedChkBackend(ierr);
865   ierr = CeedElemRestrictionGetCompStride(rstr, &compstride);
866   CeedChkBackend(ierr);
867   CeedInt shift = ncomp;
868   if (compstride != 1)
869     shift *= ncomp;
870   ierr = CeedCalloc(nelem*elemsize, &pbOffsets); CeedChkBackend(ierr);
871   for (CeedInt i = 0; i < nelem*elemsize; i++) {
872     pbOffsets[i] = offsets[i]*shift;
873     if (pbOffsets[i] > max)
874       max = pbOffsets[i];
875   }
876 
877   // Create new restriction
878   ierr = CeedElemRestrictionCreate(ceed, nelem, elemsize, ncomp*ncomp, 1,
879                                    max + ncomp*ncomp, CEED_MEM_HOST,
880                                    CEED_OWN_POINTER, pbOffsets, pbRstr);
881   CeedChkBackend(ierr);
882 
883   // Cleanup
884   ierr = CeedElemRestrictionRestoreOffsets(rstr, &offsets); CeedChkBackend(ierr);
885 
886   return CEED_ERROR_SUCCESS;
887 }
888 
889 //------------------------------------------------------------------------------
890 // Assemble diagonal setup
891 //------------------------------------------------------------------------------
892 static inline int CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op,
893     const bool pointBlock) {
894   int ierr;
895   Ceed ceed;
896   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
897   CeedQFunction qf;
898   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
899   CeedInt numinputfields, numoutputfields;
900   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
901   CeedChkBackend(ierr);
902 
903   // Determine active input basis
904   CeedOperatorField *opfields;
905   CeedQFunctionField *qffields;
906   ierr = CeedOperatorGetFields(op, NULL, &opfields, NULL, NULL);
907   CeedChkBackend(ierr);
908   ierr = CeedQFunctionGetFields(qf, NULL, &qffields, NULL, NULL);
909   CeedChkBackend(ierr);
910   CeedInt numemodein = 0, ncomp = 0, dim = 1;
911   CeedEvalMode *emodein = NULL;
912   CeedBasis basisin = NULL;
913   CeedElemRestriction rstrin = NULL;
914   for (CeedInt i = 0; i < numinputfields; i++) {
915     CeedVector vec;
916     ierr = CeedOperatorFieldGetVector(opfields[i], &vec); CeedChkBackend(ierr);
917     if (vec == CEED_VECTOR_ACTIVE) {
918       CeedElemRestriction rstr;
919       ierr = CeedOperatorFieldGetBasis(opfields[i], &basisin); CeedChkBackend(ierr);
920       ierr = CeedBasisGetNumComponents(basisin, &ncomp); CeedChkBackend(ierr);
921       ierr = CeedBasisGetDimension(basisin, &dim); CeedChkBackend(ierr);
922       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &rstr);
923       CeedChkBackend(ierr);
924       if (rstrin && rstrin != rstr)
925         // LCOV_EXCL_START
926         return CeedError(ceed, CEED_ERROR_BACKEND,
927                          "Multi-field non-composite operator diagonal assembly not supported");
928       // LCOV_EXCL_STOP
929       rstrin = rstr;
930       CeedEvalMode emode;
931       ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode);
932       CeedChkBackend(ierr);
933       switch (emode) {
934       case CEED_EVAL_NONE:
935       case CEED_EVAL_INTERP:
936         ierr = CeedRealloc(numemodein + 1, &emodein); CeedChkBackend(ierr);
937         emodein[numemodein] = emode;
938         numemodein += 1;
939         break;
940       case CEED_EVAL_GRAD:
941         ierr = CeedRealloc(numemodein + dim, &emodein); CeedChkBackend(ierr);
942         for (CeedInt d = 0; d < dim; d++)
943           emodein[numemodein+d] = emode;
944         numemodein += dim;
945         break;
946       case CEED_EVAL_WEIGHT:
947       case CEED_EVAL_DIV:
948       case CEED_EVAL_CURL:
949         break; // Caught by QF Assembly
950       }
951     }
952   }
953 
954   // Determine active output basis
955   ierr = CeedOperatorGetFields(op, NULL, NULL, NULL, &opfields);
956   CeedChkBackend(ierr);
957   ierr = CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qffields);
958   CeedChkBackend(ierr);
959   CeedInt numemodeout = 0;
960   CeedEvalMode *emodeout = NULL;
961   CeedBasis basisout = NULL;
962   CeedElemRestriction rstrout = NULL;
963   for (CeedInt i = 0; i < numoutputfields; i++) {
964     CeedVector vec;
965     ierr = CeedOperatorFieldGetVector(opfields[i], &vec); CeedChkBackend(ierr);
966     if (vec == CEED_VECTOR_ACTIVE) {
967       CeedElemRestriction rstr;
968       ierr = CeedOperatorFieldGetBasis(opfields[i], &basisout); CeedChkBackend(ierr);
969       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &rstr);
970       CeedChkBackend(ierr);
971       if (rstrout && rstrout != rstr)
972         // LCOV_EXCL_START
973         return CeedError(ceed, CEED_ERROR_BACKEND,
974                          "Multi-field non-composite operator diagonal assembly not supported");
975       // LCOV_EXCL_STOP
976       rstrout = rstr;
977       CeedEvalMode emode;
978       ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChkBackend(ierr);
979       switch (emode) {
980       case CEED_EVAL_NONE:
981       case CEED_EVAL_INTERP:
982         ierr = CeedRealloc(numemodeout + 1, &emodeout); CeedChkBackend(ierr);
983         emodeout[numemodeout] = emode;
984         numemodeout += 1;
985         break;
986       case CEED_EVAL_GRAD:
987         ierr = CeedRealloc(numemodeout + dim, &emodeout); CeedChkBackend(ierr);
988         for (CeedInt d = 0; d < dim; d++)
989           emodeout[numemodeout+d] = emode;
990         numemodeout += dim;
991         break;
992       case CEED_EVAL_WEIGHT:
993       case CEED_EVAL_DIV:
994       case CEED_EVAL_CURL:
995         break; // Caught by QF Assembly
996       }
997     }
998   }
999 
1000   // Operator data struct
1001   CeedOperator_Hip *impl;
1002   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
1003   ierr = CeedCalloc(1, &impl->diag); CeedChkBackend(ierr);
1004   CeedOperatorDiag_Hip *diag = impl->diag;
1005   diag->basisin = basisin;
1006   diag->basisout = basisout;
1007   diag->h_emodein = emodein;
1008   diag->h_emodeout = emodeout;
1009   diag->numemodein = numemodein;
1010   diag->numemodeout = numemodeout;
1011 
1012   // Assemble kernel
1013   CeedInt nnodes, nqpts;
1014   ierr = CeedBasisGetNumNodes(basisin, &nnodes); CeedChkBackend(ierr);
1015   ierr = CeedBasisGetNumQuadraturePoints(basisin, &nqpts); CeedChkBackend(ierr);
1016   diag->nnodes = nnodes;
1017   ierr = CeedCompileHip(ceed, diagonalkernels, &diag->module, 5,
1018                         "NUMEMODEIN", numemodein,
1019                         "NUMEMODEOUT", numemodeout,
1020                         "NNODES", nnodes,
1021                         "NQPTS", nqpts,
1022                         "NCOMP", ncomp
1023                        ); CeedChk_Hip(ceed, ierr);
1024   ierr = CeedGetKernelHip(ceed, diag->module, "linearDiagonal",
1025                           &diag->linearDiagonal); CeedChk_Hip(ceed, ierr);
1026   ierr = CeedGetKernelHip(ceed, diag->module, "linearPointBlockDiagonal",
1027                           &diag->linearPointBlock);
1028   CeedChk_Hip(ceed, ierr);
1029 
1030   // Basis matrices
1031   const CeedInt qBytes = nqpts * sizeof(CeedScalar);
1032   const CeedInt iBytes = qBytes * nnodes;
1033   const CeedInt gBytes = qBytes * nnodes * dim;
1034   const CeedInt eBytes = sizeof(CeedEvalMode);
1035   const CeedScalar *interpin, *interpout, *gradin, *gradout;
1036 
1037   // CEED_EVAL_NONE
1038   CeedScalar *identity = NULL;
1039   bool evalNone = false;
1040   for (CeedInt i=0; i<numemodein; i++)
1041     evalNone = evalNone || (emodein[i] == CEED_EVAL_NONE);
1042   for (CeedInt i=0; i<numemodeout; i++)
1043     evalNone = evalNone || (emodeout[i] == CEED_EVAL_NONE);
1044   if (evalNone) {
1045     ierr = CeedCalloc(nqpts*nnodes, &identity); CeedChkBackend(ierr);
1046     for (CeedInt i=0; i<(nnodes<nqpts?nnodes:nqpts); i++)
1047       identity[i*nnodes+i] = 1.0;
1048     ierr = hipMalloc((void **)&diag->d_identity, iBytes); CeedChk_Hip(ceed, ierr);
1049     ierr = hipMemcpy(diag->d_identity, identity, iBytes,
1050                      hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1051   }
1052 
1053   // CEED_EVAL_INTERP
1054   ierr = CeedBasisGetInterp(basisin, &interpin); CeedChkBackend(ierr);
1055   ierr = hipMalloc((void **)&diag->d_interpin, iBytes); CeedChk_Hip(ceed, ierr);
1056   ierr = hipMemcpy(diag->d_interpin, interpin, iBytes,
1057                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1058   ierr = CeedBasisGetInterp(basisout, &interpout); CeedChkBackend(ierr);
1059   ierr = hipMalloc((void **)&diag->d_interpout, iBytes); CeedChk_Hip(ceed, ierr);
1060   ierr = hipMemcpy(diag->d_interpout, interpout, iBytes,
1061                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1062 
1063   // CEED_EVAL_GRAD
1064   ierr = CeedBasisGetGrad(basisin, &gradin); CeedChkBackend(ierr);
1065   ierr = hipMalloc((void **)&diag->d_gradin, gBytes); CeedChk_Hip(ceed, ierr);
1066   ierr = hipMemcpy(diag->d_gradin, gradin, gBytes,
1067                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1068   ierr = CeedBasisGetGrad(basisout, &gradout); CeedChkBackend(ierr);
1069   ierr = hipMalloc((void **)&diag->d_gradout, gBytes); CeedChk_Hip(ceed, ierr);
1070   ierr = hipMemcpy(diag->d_gradout, gradout, gBytes,
1071                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1072 
1073   // Arrays of emodes
1074   ierr = hipMalloc((void **)&diag->d_emodein, numemodein * eBytes);
1075   CeedChk_Hip(ceed, ierr);
1076   ierr = hipMemcpy(diag->d_emodein, emodein, numemodein * eBytes,
1077                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1078   ierr = hipMalloc((void **)&diag->d_emodeout, numemodeout * eBytes);
1079   CeedChk_Hip(ceed, ierr);
1080   ierr = hipMemcpy(diag->d_emodeout, emodeout, numemodeout * eBytes,
1081                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1082 
1083   // Restriction
1084   diag->diagrstr = rstrout;
1085 
1086   return CEED_ERROR_SUCCESS;
1087 }
1088 
1089 //------------------------------------------------------------------------------
1090 // Assemble diagonal common code
1091 //------------------------------------------------------------------------------
1092 static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op,
1093     CeedVector assembled, CeedRequest *request, const bool pointBlock) {
1094   int ierr;
1095   Ceed ceed;
1096   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
1097   CeedOperator_Hip *impl;
1098   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
1099 
1100   // Assemble QFunction
1101   CeedVector assembledqf;
1102   CeedElemRestriction rstr;
1103   ierr = CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembledqf,
1104          &rstr, request); CeedChkBackend(ierr);
1105   ierr = CeedElemRestrictionDestroy(&rstr); CeedChkBackend(ierr);
1106   CeedScalar maxnorm = 0;
1107   ierr = CeedVectorNorm(assembledqf, CEED_NORM_MAX, &maxnorm);
1108   CeedChkBackend(ierr);
1109 
1110   // Setup
1111   if (!impl->diag) {
1112     ierr = CeedOperatorAssembleDiagonalSetup_Hip(op, pointBlock);
1113     CeedChkBackend(ierr);
1114   }
1115   CeedOperatorDiag_Hip *diag = impl->diag;
1116   assert(diag != NULL);
1117 
1118   // Restriction
1119   if (pointBlock && !diag->pbdiagrstr) {
1120     CeedElemRestriction pbdiagrstr;
1121     ierr = CreatePBRestriction(diag->diagrstr, &pbdiagrstr); CeedChkBackend(ierr);
1122     diag->pbdiagrstr = pbdiagrstr;
1123   }
1124   CeedElemRestriction diagrstr = pointBlock ? diag->pbdiagrstr : diag->diagrstr;
1125 
1126   // Create diagonal vector
1127   CeedVector elemdiag = pointBlock ? diag->pbelemdiag : diag->elemdiag;
1128   if (!elemdiag) {
1129     // Element diagonal vector
1130     ierr = CeedElemRestrictionCreateVector(diagrstr, NULL, &elemdiag);
1131     CeedChkBackend(ierr);
1132     if (pointBlock)
1133       diag->pbelemdiag = elemdiag;
1134     else
1135       diag->elemdiag = elemdiag;
1136   }
1137   ierr = CeedVectorSetValue(elemdiag, 0.0); CeedChkBackend(ierr);
1138 
1139   // Assemble element operator diagonals
1140   CeedScalar *elemdiagarray;
1141   const CeedScalar *assembledqfarray;
1142   ierr = CeedVectorGetArray(elemdiag, CEED_MEM_DEVICE, &elemdiagarray);
1143   CeedChkBackend(ierr);
1144   ierr = CeedVectorGetArrayRead(assembledqf, CEED_MEM_DEVICE, &assembledqfarray);
1145   CeedChkBackend(ierr);
1146   CeedInt nelem;
1147   ierr = CeedElemRestrictionGetNumElements(diagrstr, &nelem);
1148   CeedChkBackend(ierr);
1149 
1150   // Compute the diagonal of B^T D B
1151   int elemsPerBlock = 1;
1152   int grid = nelem/elemsPerBlock+((nelem/elemsPerBlock*elemsPerBlock<nelem)?1:0);
1153   void *args[] = {(void *) &nelem, (void *) &maxnorm, &diag->d_identity,
1154                   &diag->d_interpin, &diag->d_gradin, &diag->d_interpout,
1155                   &diag->d_gradout, &diag->d_emodein, &diag->d_emodeout,
1156                   &assembledqfarray, &elemdiagarray
1157                  };
1158   if (pointBlock) {
1159     ierr = CeedRunKernelDimHip(ceed, diag->linearPointBlock, grid,
1160                                diag->nnodes, 1, elemsPerBlock, args);
1161     CeedChkBackend(ierr);
1162   } else {
1163     ierr = CeedRunKernelDimHip(ceed, diag->linearDiagonal, grid,
1164                                diag->nnodes, 1, elemsPerBlock, args);
1165     CeedChkBackend(ierr);
1166   }
1167 
1168   // Restore arrays
1169   ierr = CeedVectorRestoreArray(elemdiag, &elemdiagarray); CeedChkBackend(ierr);
1170   ierr = CeedVectorRestoreArrayRead(assembledqf, &assembledqfarray);
1171   CeedChkBackend(ierr);
1172 
1173   // Assemble local operator diagonal
1174   ierr = CeedElemRestrictionApply(diagrstr, CEED_TRANSPOSE, elemdiag,
1175                                   assembled, request); CeedChkBackend(ierr);
1176 
1177   // Cleanup
1178   ierr = CeedVectorDestroy(&assembledqf); CeedChkBackend(ierr);
1179 
1180   return CEED_ERROR_SUCCESS;
1181 }
1182 
1183 //------------------------------------------------------------------------------
1184 // Assemble composite diagonal common code
1185 //------------------------------------------------------------------------------
1186 static inline int CeedOperatorLinearAssembleAddDiagonalCompositeCore_Hip(
1187   CeedOperator op, CeedVector assembled, CeedRequest *request,
1188   const bool pointBlock) {
1189   int ierr;
1190   CeedInt numSub;
1191   CeedOperator *subOperators;
1192   ierr = CeedOperatorGetNumSub(op, &numSub); CeedChkBackend(ierr);
1193   ierr = CeedOperatorGetSubList(op, &subOperators); CeedChkBackend(ierr);
1194   for (CeedInt i = 0; i < numSub; i++) {
1195     ierr = CeedOperatorAssembleDiagonalCore_Hip(subOperators[i], assembled,
1196            request, pointBlock); CeedChkBackend(ierr);
1197   }
1198   return CEED_ERROR_SUCCESS;
1199 }
1200 
1201 //------------------------------------------------------------------------------
1202 // Assemble Linear Diagonal
1203 //------------------------------------------------------------------------------
1204 static int CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op,
1205     CeedVector assembled, CeedRequest *request) {
1206   int ierr;
1207   bool isComposite;
1208   ierr = CeedOperatorIsComposite(op, &isComposite); CeedChkBackend(ierr);
1209   if (isComposite) {
1210     return CeedOperatorLinearAssembleAddDiagonalCompositeCore_Hip(op, assembled,
1211            request, false);
1212   } else {
1213     return CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, false);
1214   }
1215 }
1216 
1217 //------------------------------------------------------------------------------
1218 // Assemble Linear Point Block Diagonal
1219 //------------------------------------------------------------------------------
1220 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op,
1221     CeedVector assembled, CeedRequest *request) {
1222   int ierr;
1223   bool isComposite;
1224   ierr = CeedOperatorIsComposite(op, &isComposite); CeedChkBackend(ierr);
1225   if (isComposite) {
1226     return CeedOperatorLinearAssembleAddDiagonalCompositeCore_Hip(op, assembled,
1227            request, true);
1228   } else {
1229     return CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, true);
1230   }
1231 }
1232 
1233 
1234 //------------------------------------------------------------------------------
1235 // Create operator
1236 //------------------------------------------------------------------------------
1237 int CeedOperatorCreate_Hip(CeedOperator op) {
1238   int ierr;
1239   Ceed ceed;
1240   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
1241   CeedOperator_Hip *impl;
1242 
1243   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
1244   ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr);
1245 
1246   ierr = CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction",
1247                                 CeedOperatorLinearAssembleQFunction_Hip);
1248   CeedChkBackend(ierr);
1249   ierr = CeedSetBackendFunction(ceed, "Operator", op,
1250                                 "LinearAssembleQFunctionUpdate",
1251                                 CeedOperatorLinearAssembleQFunctionUpdate_Hip);
1252   CeedChkBackend(ierr);
1253   ierr = CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal",
1254                                 CeedOperatorLinearAssembleAddDiagonal_Hip);
1255   CeedChkBackend(ierr);
1256   ierr = CeedSetBackendFunction(ceed, "Operator", op,
1257                                 "LinearAssembleAddPointBlockDiagonal",
1258                                 CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip);
1259   CeedChkBackend(ierr);
1260   ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd",
1261                                 CeedOperatorApplyAdd_Hip); CeedChkBackend(ierr);
1262   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
1263                                 CeedOperatorDestroy_Hip); CeedChkBackend(ierr);
1264   return CEED_ERROR_SUCCESS;
1265 }
1266 
1267 //------------------------------------------------------------------------------
1268 // Composite Operator Create
1269 //------------------------------------------------------------------------------
1270 int CeedCompositeOperatorCreate_Hip(CeedOperator op) {
1271   int ierr;
1272   Ceed ceed;
1273   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
1274 
1275   ierr = CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal",
1276                                 CeedOperatorLinearAssembleAddDiagonal_Hip);
1277   CeedChkBackend(ierr);
1278   ierr = CeedSetBackendFunction(ceed, "Operator", op,
1279                                 "LinearAssembleAddPointBlockDiagonal",
1280                                 CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip);
1281   CeedChkBackend(ierr);
1282   return CEED_ERROR_SUCCESS;
1283 }
1284 //------------------------------------------------------------------------------
1285