xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-operator.c (revision b7453713e95c1c6eb59ce174cbcb87227e92884e)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <assert.h>
12 #include <cuda.h>
13 #include <cuda_runtime.h>
14 #include <stdbool.h>
15 #include <string.h>
16 
17 #include "../cuda/ceed-cuda-common.h"
18 #include "../cuda/ceed-cuda-compile.h"
19 #include "ceed-cuda-ref.h"
20 
21 //------------------------------------------------------------------------------
22 // Destroy operator
23 //------------------------------------------------------------------------------
24 static int CeedOperatorDestroy_Cuda(CeedOperator op) {
25   CeedOperator_Cuda *impl;
26 
27   CeedCallBackend(CeedOperatorGetData(op, &impl));
28 
29   // Apply data
30   for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) {
31     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i]));
32   }
33   CeedCallBackend(CeedFree(&impl->e_vecs));
34 
35   for (CeedInt i = 0; i < impl->num_inputs; i++) {
36     CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i]));
37   }
38   CeedCallBackend(CeedFree(&impl->q_vecs_in));
39 
40   for (CeedInt i = 0; i < impl->num_outputs; i++) {
41     CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i]));
42   }
43   CeedCallBackend(CeedFree(&impl->q_vecs_out));
44 
45   // QFunction assembly data
46   for (CeedInt i = 0; i < impl->num_active_in; i++) {
47     CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i]));
48   }
49   CeedCallBackend(CeedFree(&impl->qf_active_in));
50 
51   // Diag data
52   if (impl->diag) {
53     Ceed ceed;
54 
55     CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
56     CeedCallCuda(ceed, cuModuleUnload(impl->diag->module));
57     CeedCallBackend(CeedFree(&impl->diag->h_e_mode_in));
58     CeedCallBackend(CeedFree(&impl->diag->h_e_mode_out));
59     CeedCallCuda(ceed, cudaFree(impl->diag->d_e_mode_in));
60     CeedCallCuda(ceed, cudaFree(impl->diag->d_e_mode_out));
61     CeedCallCuda(ceed, cudaFree(impl->diag->d_identity));
62     CeedCallCuda(ceed, cudaFree(impl->diag->d_interp_in));
63     CeedCallCuda(ceed, cudaFree(impl->diag->d_interp_out));
64     CeedCallCuda(ceed, cudaFree(impl->diag->d_grad_in));
65     CeedCallCuda(ceed, cudaFree(impl->diag->d_grad_out));
66     CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_rstr));
67     CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag));
68     CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag));
69   }
70   CeedCallBackend(CeedFree(&impl->diag));
71 
72   if (impl->asmb) {
73     Ceed ceed;
74 
75     CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
76     CeedCallCuda(ceed, cuModuleUnload(impl->asmb->module));
77     CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_in));
78     CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_out));
79   }
80   CeedCallBackend(CeedFree(&impl->asmb));
81 
82   CeedCallBackend(CeedFree(&impl));
83   return CEED_ERROR_SUCCESS;
84 }
85 
86 //------------------------------------------------------------------------------
87 // Setup infields or outfields
88 //------------------------------------------------------------------------------
89 static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool is_input, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt e_start,
90                                         CeedInt num_fields, CeedInt Q, CeedInt num_elem) {
91   Ceed                ceed;
92   bool                is_strided, skip_restriction;
93   CeedSize            q_size;
94   CeedInt             dim, size;
95   CeedQFunctionField *qf_fields;
96   CeedOperatorField  *op_fields;
97 
98   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
99 
100   if (is_input) {
101     CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
102     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
103   } else {
104     CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
105     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
106   }
107 
108   // Loop over fields
109   for (CeedInt i = 0; i < num_fields; i++) {
110     CeedEvalMode e_mode;
111     CeedBasis    basis;
112 
113     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &e_mode));
114 
115     is_strided       = false;
116     skip_restriction = false;
117     if (e_mode != CEED_EVAL_WEIGHT) {
118       CeedElemRestriction elem_restr;
119 
120       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_restr));
121 
122       // Check whether this field can skip the element restriction:
123       // must be passive input, with e_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND.
124 
125       // First, check whether the field is input or output:
126       if (is_input) {
127         CeedVector vec;
128 
129         // Check for passive input:
130         CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
131         if (vec != CEED_VECTOR_ACTIVE) {
132           // Check e_mode
133           if (e_mode == CEED_EVAL_NONE) {
134             // Check for strided restriction
135             CeedCallBackend(CeedElemRestrictionIsStrided(elem_restr, &is_strided));
136             if (is_strided) {
137               // Check if vector is already in preferred backend ordering
138               CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_restr, &skip_restriction));
139             }
140           }
141         }
142       }
143       if (skip_restriction) {
144         // We do not need an E-Vector, but will use the input field vector's data directly in the operator application.
145         e_vecs[i + e_start] = NULL;
146       } else {
147         CeedCallBackend(CeedElemRestrictionCreateVector(elem_restr, NULL, &e_vecs[i + e_start]));
148       }
149     }
150 
151     switch (e_mode) {
152       case CEED_EVAL_NONE:
153         CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
154         q_size = (CeedSize)num_elem * Q * size;
155         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
156         break;
157       case CEED_EVAL_INTERP:
158         CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
159         q_size = (CeedSize)num_elem * Q * size;
160         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
161         break;
162       case CEED_EVAL_GRAD:
163         CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
164         CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
165         CeedCallBackend(CeedBasisGetDimension(basis, &dim));
166         q_size = (CeedSize)num_elem * Q * size;
167         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
168         break;
169       case CEED_EVAL_WEIGHT:  // Only on input fields
170         CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
171         q_size = (CeedSize)num_elem * Q;
172         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
173         CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
174         break;
175       case CEED_EVAL_DIV:
176         break;  // TODO: Not implemented
177       case CEED_EVAL_CURL:
178         break;  // TODO: Not implemented
179     }
180   }
181   return CEED_ERROR_SUCCESS;
182 }
183 
184 //------------------------------------------------------------------------------
185 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction.
186 //------------------------------------------------------------------------------
187 static int CeedOperatorSetup_Cuda(CeedOperator op) {
188   Ceed                ceed;
189   bool                is_setup_done;
190   CeedInt             Q, num_elem, num_input_fields, num_output_fields;
191   CeedQFunctionField *qf_input_fields, *qf_output_fields;
192   CeedQFunction       qf;
193   CeedOperatorField  *op_input_fields, *op_output_fields;
194   CeedOperator_Cuda  *impl;
195 
196   CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
197   if (is_setup_done) return CEED_ERROR_SUCCESS;
198 
199   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
200   CeedCallBackend(CeedOperatorGetData(op, &impl));
201   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
202   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
203   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
204   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
205   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
206 
207   // Allocate
208   CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs));
209 
210   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
211   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
212 
213   impl->num_inputs  = num_input_fields;
214   impl->num_outputs = num_output_fields;
215 
216   // Set up infield and outfield e_vecs and q_vecs
217   // Infields
218   CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, true, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem));
219   // Outfields
220   CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem));
221 
222   CeedCallBackend(CeedOperatorSetSetupDone(op));
223   return CEED_ERROR_SUCCESS;
224 }
225 
226 //------------------------------------------------------------------------------
227 // Setup Operator Inputs
228 //------------------------------------------------------------------------------
229 static inline int CeedOperatorSetupInputs_Cuda(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
230                                                CeedVector in_vec, const bool skip_active_in, CeedScalar *e_data[2 * CEED_FIELD_MAX],
231                                                CeedOperator_Cuda *impl, CeedRequest *request) {
232   for (CeedInt i = 0; i < num_input_fields; i++) {
233     CeedEvalMode        e_mode;
234     CeedVector          vec;
235     CeedElemRestriction elem_restr;
236 
237     // Get input vector
238     CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
239     if (vec == CEED_VECTOR_ACTIVE) {
240       if (skip_active_in) continue;
241       else vec = in_vec;
242     }
243 
244     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &e_mode));
245     if (e_mode == CEED_EVAL_WEIGHT) {  // Skip
246     } else {
247       // Get input element restriction
248       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr));
249       if (vec == CEED_VECTOR_ACTIVE) vec = in_vec;
250       // Restrict, if necessary
251       if (!impl->e_vecs[i]) {
252         // No restriction for this field; read data directly from vec.
253         CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i]));
254       } else {
255         CeedCallBackend(CeedElemRestrictionApply(elem_restr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request));
256         // Get evec
257         CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i]));
258       }
259     }
260   }
261   return CEED_ERROR_SUCCESS;
262 }
263 
264 //------------------------------------------------------------------------------
265 // Input Basis Action
266 //------------------------------------------------------------------------------
267 static inline int CeedOperatorInputBasis_Cuda(CeedInt num_elem, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
268                                               CeedInt num_input_fields, const bool skip_active_in, CeedScalar *e_data[2 * CEED_FIELD_MAX],
269                                               CeedOperator_Cuda *impl) {
270   for (CeedInt i = 0; i < num_input_fields; i++) {
271     CeedInt             elem_size, size;
272     CeedEvalMode        e_mode;
273     CeedElemRestriction elem_restr;
274     CeedBasis           basis;
275 
276     // Skip active input
277     if (skip_active_in) {
278       CeedVector vec;
279 
280       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
281       if (vec == CEED_VECTOR_ACTIVE) continue;
282     }
283     // Get elem_size, e_mode, size
284     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr));
285     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size));
286     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &e_mode));
287     CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
288     // Basis action
289     switch (e_mode) {
290       case CEED_EVAL_NONE:
291         CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i]));
292         break;
293       case CEED_EVAL_INTERP:
294         CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
295         CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_INTERP, impl->e_vecs[i], impl->q_vecs_in[i]));
296         break;
297       case CEED_EVAL_GRAD:
298         CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
299         CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_GRAD, impl->e_vecs[i], impl->q_vecs_in[i]));
300         break;
301       case CEED_EVAL_WEIGHT:
302         break;  // No action
303       case CEED_EVAL_DIV:
304         break;  // TODO: Not implemented
305       case CEED_EVAL_CURL:
306         break;  // TODO: Not implemented
307     }
308   }
309   return CEED_ERROR_SUCCESS;
310 }
311 
312 //------------------------------------------------------------------------------
313 // Restore Input Vectors
314 //------------------------------------------------------------------------------
315 static inline int CeedOperatorRestoreInputs_Cuda(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
316                                                  const bool skip_active_in, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Cuda *impl) {
317   for (CeedInt i = 0; i < num_input_fields; i++) {
318     CeedEvalMode e_mode;
319     CeedVector   vec;
320 
321     // Skip active input
322     if (skip_active_in) {
323       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
324       if (vec == CEED_VECTOR_ACTIVE) continue;
325     }
326     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &e_mode));
327     if (e_mode == CEED_EVAL_WEIGHT) {  // Skip
328     } else {
329       if (!impl->e_vecs[i]) {  // This was a skip_restriction case
330         CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
331         CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&e_data[i]));
332       } else {
333         CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs[i], (const CeedScalar **)&e_data[i]));
334       }
335     }
336   }
337   return CEED_ERROR_SUCCESS;
338 }
339 
340 //------------------------------------------------------------------------------
341 // Apply and add to output
342 //------------------------------------------------------------------------------
343 static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) {
344   CeedOperator_Cuda  *impl;
345   CeedInt             Q, num_elem, elem_size, num_input_fields, num_output_fields, size;
346   CeedEvalMode        e_mode;
347   CeedScalar         *e_data[2 * CEED_FIELD_MAX] = {NULL};
348   CeedOperatorField  *op_input_fields, *op_output_fields;
349   CeedQFunctionField *qf_input_fields, *qf_output_fields;
350   CeedQFunction       qf;
351 
352   CeedCallBackend(CeedOperatorGetData(op, &impl));
353   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
354   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
355   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
356   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
357   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
358 
359   // Setup
360   CeedCallBackend(CeedOperatorSetup_Cuda(op));
361 
362   // Input e_vecs and Restriction
363   CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request));
364 
365   // Input basis apply if needed
366   CeedCallBackend(CeedOperatorInputBasis_Cuda(num_elem, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl));
367 
368   // Output pointers, as necessary
369   for (CeedInt i = 0; i < num_output_fields; i++) {
370     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &e_mode));
371     if (e_mode == CEED_EVAL_NONE) {
372       // Set the output Q-Vector to use the E-Vector data directly.
373       CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields]));
374       CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields]));
375     }
376   }
377 
378   // Q function
379   CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out));
380 
381   // Output basis apply if needed
382   for (CeedInt i = 0; i < num_output_fields; i++) {
383     CeedElemRestriction elem_restr;
384     CeedBasis           basis;
385 
386     // Get elem_size, e_mode, size
387     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr));
388     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size));
389     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &e_mode));
390     CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size));
391     // Basis action
392     switch (e_mode) {
393       case CEED_EVAL_NONE:
394         break;
395       case CEED_EVAL_INTERP:
396         CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
397         CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, CEED_EVAL_INTERP, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
398         break;
399       case CEED_EVAL_GRAD:
400         CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
401         CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, CEED_EVAL_GRAD, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
402         break;
403       // LCOV_EXCL_START
404       case CEED_EVAL_WEIGHT: {
405         Ceed ceed;
406         CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
407         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
408         break;  // Should not occur
409       }
410       case CEED_EVAL_DIV:
411         break;  // TODO: Not implemented
412       case CEED_EVAL_CURL:
413         break;  // TODO: Not implemented
414                 // LCOV_EXCL_STOP
415     }
416   }
417 
418   // Output restriction
419   for (CeedInt i = 0; i < num_output_fields; i++) {
420     CeedVector          vec;
421     CeedElemRestriction elem_restr;
422 
423     // Restore evec
424     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &e_mode));
425     if (e_mode == CEED_EVAL_NONE) {
426       CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields]));
427     }
428     // Get output vector
429     CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
430     // Restrict
431     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr));
432     // Active
433     if (vec == CEED_VECTOR_ACTIVE) vec = out_vec;
434 
435     CeedCallBackend(CeedElemRestrictionApply(elem_restr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request));
436   }
437 
438   // Restore input arrays
439   CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl));
440   return CEED_ERROR_SUCCESS;
441 }
442 
443 //------------------------------------------------------------------------------
444 // Core code for assembling linear QFunction
445 //------------------------------------------------------------------------------
446 static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
447                                                                CeedRequest *request) {
448   Ceed                ceed, ceed_parent;
449   bool                is_identity_qf;
450   CeedInt             num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size;
451   CeedSize            q_size;
452   CeedScalar         *assembled_array, *e_data[2 * CEED_FIELD_MAX] = {NULL};
453   CeedVector         *active_inputs;
454   CeedQFunctionField *qf_input_fields, *qf_output_fields;
455   CeedQFunction       qf;
456   CeedOperatorField  *op_input_fields, *op_output_fields;
457   CeedOperator_Cuda  *impl;
458 
459   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
460   CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent));
461   CeedCallBackend(CeedOperatorGetData(op, &impl));
462   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
463   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
464   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
465   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
466   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
467   active_inputs = impl->qf_active_in;
468   num_active_in = impl->num_active_in, num_active_out = impl->num_active_out;
469 
470   // Setup
471   CeedCallBackend(CeedOperatorSetup_Cuda(op));
472 
473   // Check for identity
474   CeedCallBackend(CeedQFunctionIsIdentity(qf, &is_identity_qf));
475   CeedCheck(!is_identity_qf, ceed, CEED_ERROR_BACKEND, "Assembling identity QFunctions not supported");
476 
477   // Input e_vecs and Restriction
478   CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request));
479 
480   // Count number of active input fields
481   if (!num_active_in) {
482     for (CeedInt i = 0; i < num_input_fields; i++) {
483       CeedScalar *q_vec_array;
484       CeedVector  vec;
485 
486       // Get input vector
487       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
488       // Check if active input
489       if (vec == CEED_VECTOR_ACTIVE) {
490         CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
491         CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
492         CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &q_vec_array));
493         CeedCallBackend(CeedRealloc(num_active_in + size, &active_inputs));
494         for (CeedInt field = 0; field < size; field++) {
495           q_size = (CeedSize)Q * num_elem;
496           CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_inputs[num_active_in + field]));
497           CeedCallBackend(
498               CeedVectorSetArray(active_inputs[num_active_in + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &q_vec_array[field * Q * num_elem]));
499         }
500         num_active_in += size;
501         CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array));
502       }
503     }
504     impl->num_active_in = num_active_in;
505     impl->qf_active_in  = active_inputs;
506   }
507 
508   // Count number of active output fields
509   if (!num_active_out) {
510     for (CeedInt i = 0; i < num_output_fields; i++) {
511       CeedVector vec;
512 
513       // Get output vector
514       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
515       // Check if active output
516       if (vec == CEED_VECTOR_ACTIVE) {
517         CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size));
518         num_active_out += size;
519       }
520     }
521     impl->num_active_out = num_active_out;
522   }
523 
524   // Check sizes
525   CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
526 
527   // Build objects if needed
528   if (build_objects) {
529     // Create output restriction
530     CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */
531     CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out,
532                                                      num_active_in * num_active_out * num_elem * Q, strides, rstr));
533     // Create assembled vector
534     CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out;
535     CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled));
536   }
537   CeedCallBackend(CeedVectorSetValue(*assembled, 0.0));
538   CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &assembled_array));
539 
540   // Input basis apply
541   CeedCallBackend(CeedOperatorInputBasis_Cuda(num_elem, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl));
542 
543   // Assemble QFunction
544   for (CeedInt in = 0; in < num_active_in; in++) {
545     // Set Inputs
546     CeedCallBackend(CeedVectorSetValue(active_inputs[in], 1.0));
547     if (num_active_in > 1) {
548       CeedCallBackend(CeedVectorSetValue(active_inputs[(in + num_active_in - 1) % num_active_in], 0.0));
549     }
550     // Set Outputs
551     for (CeedInt out = 0; out < num_output_fields; out++) {
552       CeedVector vec;
553 
554       // Get output vector
555       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
556       // Check if active output
557       if (vec == CEED_VECTOR_ACTIVE) {
558         CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, CEED_USE_POINTER, assembled_array));
559         CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size));
560         assembled_array += size * Q * num_elem;  // Advance the pointer by the size of the output
561       }
562     }
563     // Apply QFunction
564     CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out));
565   }
566 
567   // Un-set output q_vecs to prevent accidental overwrite of Assembled
568   for (CeedInt out = 0; out < num_output_fields; out++) {
569     CeedVector vec;
570 
571     // Get output vector
572     CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
573     // Check if active output
574     if (vec == CEED_VECTOR_ACTIVE) {
575       CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, NULL));
576     }
577   }
578 
579   // Restore input arrays
580   CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl));
581 
582   // Restore output
583   CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array));
584   return CEED_ERROR_SUCCESS;
585 }
586 
587 //------------------------------------------------------------------------------
588 // Assemble Linear QFunction
589 //------------------------------------------------------------------------------
590 static int CeedOperatorLinearAssembleQFunction_Cuda(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
591   return CeedOperatorLinearAssembleQFunctionCore_Cuda(op, true, assembled, rstr, request);
592 }
593 
594 //------------------------------------------------------------------------------
595 // Update Assembled Linear QFunction
596 //------------------------------------------------------------------------------
597 static int CeedOperatorLinearAssembleQFunctionUpdate_Cuda(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
598   return CeedOperatorLinearAssembleQFunctionCore_Cuda(op, false, &assembled, &rstr, request);
599 }
600 
601 //------------------------------------------------------------------------------
602 // Create point block restriction
603 //------------------------------------------------------------------------------
604 static int CreatePointBlockRestriction(CeedElemRestriction rstr, CeedElemRestriction *point_block_rstr) {
605   Ceed           ceed;
606   CeedSize       l_size;
607   CeedInt        num_elem, num_comp, elem_size, comp_stride, *point_block_offsets;
608   const CeedInt *offsets;
609 
610   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
611   CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
612 
613   // Expand offsets
614   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
615   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
616   CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
617   CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride));
618   CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size));
619   CeedInt shift = num_comp;
620 
621   if (comp_stride != 1) shift *= num_comp;
622   CeedCallBackend(CeedCalloc(num_elem * elem_size, &point_block_offsets));
623   for (CeedInt i = 0; i < num_elem * elem_size; i++) {
624     point_block_offsets[i] = offsets[i] * shift;
625   }
626 
627   // Create new restriction
628   CeedCallBackend(CeedElemRestrictionCreate(ceed, num_elem, elem_size, num_comp * num_comp, 1, l_size * num_comp, CEED_MEM_HOST, CEED_OWN_POINTER,
629                                             point_block_offsets, point_block_rstr));
630 
631   // Cleanup
632   CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
633   return CEED_ERROR_SUCCESS;
634 }
635 
636 //------------------------------------------------------------------------------
637 // Assemble diagonal setup
638 //------------------------------------------------------------------------------
639 static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op, const bool is_point_block, CeedInt use_ceedsize_idx) {
640   Ceed                ceed;
641   char               *diagonal_kernel_path, *diagonal_kernel_source;
642   CeedInt             num_input_fields, num_output_fields, num_e_mode_in = 0, num_comp = 0, dim = 1, num_e_mode_out = 0, num_nodes, num_qpts;
643   CeedEvalMode       *e_mode_in = NULL, *e_mode_out = NULL;
644   CeedElemRestriction rstr_in = NULL, rstr_out = NULL;
645   CeedBasis           basis_in = NULL, basis_out = NULL;
646   CeedQFunctionField *qf_fields;
647   CeedQFunction       qf;
648   CeedOperatorField  *op_fields;
649   CeedOperator_Cuda  *impl;
650 
651   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
652   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
653   CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));
654 
655   // Determine active input basis
656   CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
657   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
658   for (CeedInt i = 0; i < num_input_fields; i++) {
659     CeedVector vec;
660 
661     CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
662     if (vec == CEED_VECTOR_ACTIVE) {
663       CeedEvalMode        e_mode;
664       CeedElemRestriction rstr;
665 
666       CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_in));
667       CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp));
668       CeedCallBackend(CeedBasisGetDimension(basis_in, &dim));
669       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr));
670       CeedCheck(!rstr_in || rstr_in == rstr, ceed, CEED_ERROR_BACKEND,
671                 "Backend does not implement multi-field non-composite operator diagonal assembly");
672       rstr_in = rstr;
673       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &e_mode));
674       switch (e_mode) {
675         case CEED_EVAL_NONE:
676         case CEED_EVAL_INTERP:
677           CeedCallBackend(CeedRealloc(num_e_mode_in + 1, &e_mode_in));
678           e_mode_in[num_e_mode_in] = e_mode;
679           num_e_mode_in += 1;
680           break;
681         case CEED_EVAL_GRAD:
682           CeedCallBackend(CeedRealloc(num_e_mode_in + dim, &e_mode_in));
683           for (CeedInt d = 0; d < dim; d++) e_mode_in[num_e_mode_in + d] = e_mode;
684           num_e_mode_in += dim;
685           break;
686         case CEED_EVAL_WEIGHT:
687         case CEED_EVAL_DIV:
688         case CEED_EVAL_CURL:
689           break;  // Caught by QF Assembly
690       }
691     }
692   }
693 
694   // Determine active output basis
695   CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
696   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
697   for (CeedInt i = 0; i < num_output_fields; i++) {
698     CeedVector vec;
699 
700     CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
701     if (vec == CEED_VECTOR_ACTIVE) {
702       CeedEvalMode        e_mode;
703       CeedElemRestriction rstr;
704 
705       CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_out));
706       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr));
707       CeedCheck(!rstr_out || rstr_out == rstr, ceed, CEED_ERROR_BACKEND,
708                 "Backend does not implement multi-field non-composite operator diagonal assembly");
709       rstr_out = rstr;
710       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &e_mode));
711       switch (e_mode) {
712         case CEED_EVAL_NONE:
713         case CEED_EVAL_INTERP:
714           CeedCallBackend(CeedRealloc(num_e_mode_out + 1, &e_mode_out));
715           e_mode_out[num_e_mode_out] = e_mode;
716           num_e_mode_out += 1;
717           break;
718         case CEED_EVAL_GRAD:
719           CeedCallBackend(CeedRealloc(num_e_mode_out + dim, &e_mode_out));
720           for (CeedInt d = 0; d < dim; d++) e_mode_out[num_e_mode_out + d] = e_mode;
721           num_e_mode_out += dim;
722           break;
723         case CEED_EVAL_WEIGHT:
724         case CEED_EVAL_DIV:
725         case CEED_EVAL_CURL:
726           break;  // Caught by QF Assembly
727       }
728     }
729   }
730 
731   // Operator data struct
732   CeedCallBackend(CeedOperatorGetData(op, &impl));
733   CeedCallBackend(CeedCalloc(1, &impl->diag));
734   CeedOperatorDiag_Cuda *diag = impl->diag;
735 
736   diag->basis_in       = basis_in;
737   diag->basis_out      = basis_out;
738   diag->h_e_mode_in    = e_mode_in;
739   diag->h_e_mode_out   = e_mode_out;
740   diag->num_e_mode_in  = num_e_mode_in;
741   diag->num_e_mode_out = num_e_mode_out;
742 
743   // Assemble kernel
744   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h", &diagonal_kernel_path));
745   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Kernel Source -----\n");
746   CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source));
747   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Source Complete! -----\n");
748   CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes));
749   CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
750   diag->num_nodes = num_nodes;
751   CeedCallCuda(ceed,
752                CeedCompile_Cuda(ceed, diagonal_kernel_source, &diag->module, 6, "NUM_E_MODE_IN", num_e_mode_in, "NUM_E_MODE_OUT", num_e_mode_out,
753                                 "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "NUM_COMP", num_comp, "USE_CEEDSIZE", use_ceedsize_idx));
754   CeedCallCuda(ceed, CeedGetKernel_Cuda(ceed, diag->module, "linearDiagonal", &diag->linearDiagonal));
755   CeedCallCuda(ceed, CeedGetKernel_Cuda(ceed, diag->module, "linearPointBlockDiagonal", &diag->linearPointBlock));
756   CeedCallBackend(CeedFree(&diagonal_kernel_path));
757   CeedCallBackend(CeedFree(&diagonal_kernel_source));
758 
759   // Basis matrices
760   const CeedInt     q_bytes      = num_qpts * sizeof(CeedScalar);
761   const CeedInt     interp_bytes = q_bytes * num_nodes;
762   const CeedInt     grad_bytes   = q_bytes * num_nodes * dim;
763   const CeedInt     e_mode_bytes = sizeof(CeedEvalMode);
764   const CeedScalar *interp_in, *interp_out, *grad_in, *grad_out;
765 
766   // CEED_EVAL_NONE
767   CeedScalar *identity     = NULL;
768   bool        is_eval_none = false;
769 
770   for (CeedInt i = 0; i < num_e_mode_in; i++) is_eval_none = is_eval_none || (e_mode_in[i] == CEED_EVAL_NONE);
771   for (CeedInt i = 0; i < num_e_mode_out; i++) is_eval_none = is_eval_none || (e_mode_out[i] == CEED_EVAL_NONE);
772   if (is_eval_none) {
773     CeedCallBackend(CeedCalloc(num_qpts * num_nodes, &identity));
774     for (CeedInt i = 0; i < (num_nodes < num_qpts ? num_nodes : num_qpts); i++) identity[i * num_nodes + i] = 1.0;
775     CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_identity, interp_bytes));
776     CeedCallCuda(ceed, cudaMemcpy(diag->d_identity, identity, interp_bytes, cudaMemcpyHostToDevice));
777   }
778 
779   // CEED_EVAL_INTERP
780   CeedCallBackend(CeedBasisGetInterp(basis_in, &interp_in));
781   CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_interp_in, interp_bytes));
782   CeedCallCuda(ceed, cudaMemcpy(diag->d_interp_in, interp_in, interp_bytes, cudaMemcpyHostToDevice));
783   CeedCallBackend(CeedBasisGetInterp(basis_out, &interp_out));
784   CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_interp_out, interp_bytes));
785   CeedCallCuda(ceed, cudaMemcpy(diag->d_interp_out, interp_out, interp_bytes, cudaMemcpyHostToDevice));
786 
787   // CEED_EVAL_GRAD
788   CeedCallBackend(CeedBasisGetGrad(basis_in, &grad_in));
789   CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_grad_in, grad_bytes));
790   CeedCallCuda(ceed, cudaMemcpy(diag->d_grad_in, grad_in, grad_bytes, cudaMemcpyHostToDevice));
791   CeedCallBackend(CeedBasisGetGrad(basis_out, &grad_out));
792   CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_grad_out, grad_bytes));
793   CeedCallCuda(ceed, cudaMemcpy(diag->d_grad_out, grad_out, grad_bytes, cudaMemcpyHostToDevice));
794 
795   // Arrays of e_modes
796   CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_e_mode_in, num_e_mode_in * e_mode_bytes));
797   CeedCallCuda(ceed, cudaMemcpy(diag->d_e_mode_in, e_mode_in, num_e_mode_in * e_mode_bytes, cudaMemcpyHostToDevice));
798   CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_e_mode_out, num_e_mode_out * e_mode_bytes));
799   CeedCallCuda(ceed, cudaMemcpy(diag->d_e_mode_out, e_mode_out, num_e_mode_out * e_mode_bytes, cudaMemcpyHostToDevice));
800 
801   // Restriction
802   diag->diag_rstr = rstr_out;
803   return CEED_ERROR_SUCCESS;
804 }
805 
806 //------------------------------------------------------------------------------
807 // Assemble diagonal common code
808 //------------------------------------------------------------------------------
809 static inline int CeedOperatorAssembleDiagonalCore_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) {
810   Ceed                ceed;
811   CeedSize            assembled_length = 0, assembled_qf_length = 0;
812   CeedInt             use_ceedsize_idx = 0, num_elem;
813   CeedScalar         *elem_diag_array;
814   const CeedScalar   *assembled_qf_array;
815   CeedVector          assembled_qf = NULL;
816   CeedElemRestriction rstr         = NULL;
817   CeedOperator_Cuda  *impl;
818 
819   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
820   CeedCallBackend(CeedOperatorGetData(op, &impl));
821 
822   // Assemble QFunction
823   CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &rstr, request));
824   CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
825 
826   CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length));
827   CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
828   if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;
829 
830   // Setup
831   if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Cuda(op, is_point_block, use_ceedsize_idx));
832   CeedOperatorDiag_Cuda *diag = impl->diag;
833 
834   assert(diag != NULL);
835 
836   // Restriction
837   if (is_point_block && !diag->point_block_rstr) {
838     CeedElemRestriction point_block_rstr;
839 
840     CeedCallBackend(CreatePointBlockRestriction(diag->diag_rstr, &point_block_rstr));
841     diag->point_block_rstr = point_block_rstr;
842   }
843   CeedElemRestriction diag_rstr = is_point_block ? diag->point_block_rstr : diag->diag_rstr;
844 
845   // Create diagonal vector
846   CeedVector elem_diag = is_point_block ? diag->point_block_elem_diag : diag->elem_diag;
847 
848   if (!elem_diag) {
849     CeedCallBackend(CeedElemRestrictionCreateVector(diag_rstr, NULL, &elem_diag));
850     if (is_point_block) diag->point_block_elem_diag = elem_diag;
851     else diag->elem_diag = elem_diag;
852   }
853   CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0));
854 
855   // Assemble element operator diagonals
856   CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array));
857   CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
858   CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem));
859 
860   // Compute the diagonal of B^T D B
861   int   elem_per_block = 1;
862   int   grid           = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0);
863   void *args[]         = {(void *)&num_elem, &diag->d_identity,  &diag->d_interp_in,  &diag->d_grad_in,    &diag->d_interp_out,
864                           &diag->d_grad_out, &diag->d_e_mode_in, &diag->d_e_mode_out, &assembled_qf_array, &elem_diag_array};
865   if (is_point_block) {
866     CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearPointBlock, grid, diag->num_nodes, 1, elem_per_block, args));
867   } else {
868     CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearDiagonal, grid, diag->num_nodes, 1, elem_per_block, args));
869   }
870 
871   // Restore arrays
872   CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array));
873   CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));
874 
875   // Assemble local operator diagonal
876   CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request));
877 
878   // Cleanup
879   CeedCallBackend(CeedVectorDestroy(&assembled_qf));
880   return CEED_ERROR_SUCCESS;
881 }
882 
883 //------------------------------------------------------------------------------
884 // Assemble Linear Diagonal
885 //------------------------------------------------------------------------------
886 static int CeedOperatorLinearAssembleAddDiagonal_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) {
887   CeedCallBackend(CeedOperatorAssembleDiagonalCore_Cuda(op, assembled, request, false));
888   return CEED_ERROR_SUCCESS;
889 }
890 
891 //------------------------------------------------------------------------------
892 // Assemble Linear Point Block Diagonal
893 //------------------------------------------------------------------------------
894 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) {
895   CeedCallBackend(CeedOperatorAssembleDiagonalCore_Cuda(op, assembled, request, true));
896   return CEED_ERROR_SUCCESS;
897 }
898 
899 //------------------------------------------------------------------------------
900 // Single operator assembly setup
901 //------------------------------------------------------------------------------
902 static int CeedSingleOperatorAssembleSetup_Cuda(CeedOperator op, CeedInt use_ceedsize_idx) {
903   Ceed    ceed;
904   char   *assembly_kernel_path, *assembly_kernel_source;
905   CeedInt num_input_fields, num_output_fields, num_e_mode_in = 0, dim = 1, num_B_in_mats_to_load = 0, size_B_in = 0, num_qpts = 0, elem_size = 0,
906                                                num_e_mode_out = 0, num_B_out_mats_to_load = 0, size_B_out = 0, num_elem, num_comp;
907   CeedEvalMode       *eval_mode_in = NULL, *eval_mode_out = NULL;
908   CeedElemRestriction rstr_in = NULL, rstr_out = NULL;
909   CeedBasis           basis_in = NULL, basis_out = NULL;
910   CeedQFunctionField *qf_fields;
911   CeedQFunction       qf;
912   CeedOperatorField  *input_fields, *output_fields;
913   CeedOperator_Cuda  *impl;
914 
915   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
916   CeedCallBackend(CeedOperatorGetData(op, &impl));
917 
918   // Get intput and output fields
919   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
920 
921   // Determine active input basis eval mode
922   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
923   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
924   // Note that the kernel will treat each dimension of a gradient action separately;
925   // i.e., when an active input has a CEED_EVAL_GRAD mode, num_e_mode_in will increment by dim.
926   // However, for the purposes of loading the B matrices, it will be treated as one mode, and we will load/copy the entire gradient matrix at once, so
927   // num_B_in_mats_to_load will be incremented by 1.
928   for (CeedInt i = 0; i < num_input_fields; i++) {
929     CeedVector vec;
930 
931     CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec));
932     if (vec == CEED_VECTOR_ACTIVE) {
933       CeedEvalMode eval_mode;
934 
935       CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis_in));
936       CeedCallBackend(CeedBasisGetDimension(basis_in, &dim));
937       CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
938       CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in));
939       CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size));
940       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
941       if (eval_mode != CEED_EVAL_NONE) {
942         CeedCallBackend(CeedRealloc(num_B_in_mats_to_load + 1, &eval_mode_in));
943         eval_mode_in[num_B_in_mats_to_load] = eval_mode;
944         num_B_in_mats_to_load += 1;
945         if (eval_mode == CEED_EVAL_GRAD) {
946           num_e_mode_in += dim;
947           size_B_in += dim * elem_size * num_qpts;
948         } else {
949           num_e_mode_in += 1;
950           size_B_in += elem_size * num_qpts;
951         }
952       }
953     }
954   }
955 
956   // Determine active output basis; basis_out and rstr_out only used if same as input, TODO
957   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
958   for (CeedInt i = 0; i < num_output_fields; i++) {
959     CeedVector vec;
960 
961     CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec));
962     if (vec == CEED_VECTOR_ACTIVE) {
963       CeedEvalMode eval_mode;
964 
965       CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis_out));
966       CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out));
967       CeedCheck(!rstr_out || rstr_out == rstr_in, ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator assembly");
968       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
969       if (eval_mode != CEED_EVAL_NONE) {
970         CeedCallBackend(CeedRealloc(num_B_out_mats_to_load + 1, &eval_mode_out));
971         eval_mode_out[num_B_out_mats_to_load] = eval_mode;
972         num_B_out_mats_to_load += 1;
973         if (eval_mode == CEED_EVAL_GRAD) {
974           num_e_mode_out += dim;
975           size_B_out += dim * elem_size * num_qpts;
976         } else {
977           num_e_mode_out += 1;
978           size_B_out += elem_size * num_qpts;
979         }
980       }
981     }
982   }
983   CeedCheck(num_e_mode_in > 0 && num_e_mode_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs");
984 
985   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem));
986   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp));
987 
988   CeedCallBackend(CeedCalloc(1, &impl->asmb));
989   CeedOperatorAssemble_Cuda *asmb = impl->asmb;
990   asmb->num_elem                  = num_elem;
991 
992   // Compile kernels
993   int elem_per_block    = 1;
994   asmb->elem_per_block  = elem_per_block;
995   CeedInt    block_size = elem_size * elem_size * elem_per_block;
996   Ceed_Cuda *cuda_data;
997 
998   CeedCallBackend(CeedGetData(ceed, &cuda_data));
999   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble.h", &assembly_kernel_path));
1000   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Kernel Source -----\n");
1001   CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source));
1002   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Source Complete! -----\n");
1003   bool fallback = block_size > cuda_data->device_prop.maxThreadsPerBlock;
1004 
1005   if (fallback) {
1006     // Use fallback kernel with 1D threadblock
1007     block_size         = elem_size * elem_per_block;
1008     asmb->block_size_x = elem_size;
1009     asmb->block_size_y = 1;
1010   } else {  // Use kernel with 2D threadblock
1011     asmb->block_size_x = elem_size;
1012     asmb->block_size_y = elem_size;
1013   }
1014   CeedCallBackend(CeedCompile_Cuda(ceed, assembly_kernel_source, &asmb->module, 8, "NUM_ELEM", num_elem, "NUM_E_MODE_IN", num_e_mode_in,
1015                                    "NUM_E_MODE_OUT", num_e_mode_out, "NUM_QPTS", num_qpts, "NUM_NODES", elem_size, "BLOCK_SIZE", block_size,
1016                                    "NUM_COMP", num_comp, "USE_CEEDSIZE", use_ceedsize_idx));
1017   CeedCallBackend(CeedGetKernel_Cuda(ceed, asmb->module, fallback ? "linearAssembleFallback" : "linearAssemble", &asmb->linearAssemble));
1018   CeedCallBackend(CeedFree(&assembly_kernel_path));
1019   CeedCallBackend(CeedFree(&assembly_kernel_source));
1020 
1021   // Build 'full' B matrices (not 1D arrays used for tensor-product matrices)
1022   const CeedScalar *interp_in, *grad_in;
1023 
1024   CeedCallBackend(CeedBasisGetInterp(basis_in, &interp_in));
1025   CeedCallBackend(CeedBasisGetGrad(basis_in, &grad_in));
1026 
1027   // Load into B_in, in order that they will be used in eval_mode
1028   const CeedInt inBytes   = size_B_in * sizeof(CeedScalar);
1029   CeedInt       mat_start = 0;
1030 
1031   CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_in, inBytes));
1032   for (int i = 0; i < num_B_in_mats_to_load; i++) {
1033     CeedEvalMode eval_mode = eval_mode_in[i];
1034 
1035     if (eval_mode == CEED_EVAL_INTERP) {
1036       CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_in[mat_start], interp_in, elem_size * num_qpts * sizeof(CeedScalar), cudaMemcpyHostToDevice));
1037       mat_start += elem_size * num_qpts;
1038     } else if (eval_mode == CEED_EVAL_GRAD) {
1039       CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_in[mat_start], grad_in, dim * elem_size * num_qpts * sizeof(CeedScalar), cudaMemcpyHostToDevice));
1040       mat_start += dim * elem_size * num_qpts;
1041     }
1042   }
1043 
1044   const CeedScalar *interp_out, *grad_out;
1045 
1046   // Note that this function currently assumes 1 basis, so this should always be true for now
1047   if (basis_out == basis_in) {
1048     interp_out = interp_in;
1049     grad_out   = grad_in;
1050   } else {
1051     CeedCallBackend(CeedBasisGetInterp(basis_out, &interp_out));
1052     CeedCallBackend(CeedBasisGetGrad(basis_out, &grad_out));
1053   }
1054 
1055   // Load into B_out, in order that they will be used in eval_mode
1056   const CeedInt outBytes = size_B_out * sizeof(CeedScalar);
1057   mat_start              = 0;
1058 
1059   CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_out, outBytes));
1060   for (int i = 0; i < num_B_out_mats_to_load; i++) {
1061     CeedEvalMode eval_mode = eval_mode_out[i];
1062 
1063     if (eval_mode == CEED_EVAL_INTERP) {
1064       CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_out[mat_start], interp_out, elem_size * num_qpts * sizeof(CeedScalar), cudaMemcpyHostToDevice));
1065       mat_start += elem_size * num_qpts;
1066     } else if (eval_mode == CEED_EVAL_GRAD) {
1067       CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_out[mat_start], grad_out, dim * elem_size * num_qpts * sizeof(CeedScalar), cudaMemcpyHostToDevice));
1068       mat_start += dim * elem_size * num_qpts;
1069     }
1070   }
1071   return CEED_ERROR_SUCCESS;
1072 }
1073 
1074 //------------------------------------------------------------------------------
1075 // Assemble matrix data for COO matrix of assembled operator.
1076 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
1077 //
1078 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator (could have multiple basis eval
1079 // modes).
1080 // TODO: allow multiple active input restrictions/basis objects
1081 //------------------------------------------------------------------------------
1082 static int CeedSingleOperatorAssemble_Cuda(CeedOperator op, CeedInt offset, CeedVector values) {
1083   Ceed                ceed;
1084   CeedSize            values_length = 0, assembled_qf_length = 0;
1085   CeedInt             use_ceedsize_idx = 0;
1086   CeedScalar         *values_array;
1087   const CeedScalar   *qf_array;
1088   CeedVector          assembled_qf = NULL;
1089   CeedElemRestriction rstr_q       = NULL;
1090   CeedOperator_Cuda  *impl;
1091 
1092   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1093   CeedCallBackend(CeedOperatorGetData(op, &impl));
1094 
1095   // Assemble QFunction
1096   CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &rstr_q, CEED_REQUEST_IMMEDIATE));
1097   CeedCallBackend(CeedElemRestrictionDestroy(&rstr_q));
1098   CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array));
1099   values_array += offset;
1100   CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &qf_array));
1101 
1102   CeedCallBackend(CeedVectorGetLength(values, &values_length));
1103   CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
1104   if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;
1105   // Setup
1106   if (!impl->asmb) {
1107     CeedCallBackend(CeedSingleOperatorAssembleSetup_Cuda(op, use_ceedsize_idx));
1108     assert(impl->asmb != NULL);
1109   }
1110 
1111   // Compute B^T D B
1112   const CeedInt num_elem       = impl->asmb->num_elem;
1113   const CeedInt elem_per_block = impl->asmb->elem_per_block;
1114   const CeedInt grid           = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0);
1115   void         *args[]         = {&impl->asmb->d_B_in, &impl->asmb->d_B_out, &qf_array, &values_array};
1116 
1117   CeedCallBackend(
1118       CeedRunKernelDim_Cuda(ceed, impl->asmb->linearAssemble, grid, impl->asmb->block_size_x, impl->asmb->block_size_y, elem_per_block, args));
1119 
1120   // Restore arrays
1121   CeedCallBackend(CeedVectorRestoreArray(values, &values_array));
1122   CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &qf_array));
1123 
1124   // Cleanup
1125   CeedCallBackend(CeedVectorDestroy(&assembled_qf));
1126   return CEED_ERROR_SUCCESS;
1127 }
1128 
1129 //------------------------------------------------------------------------------
1130 // Create operator
1131 //------------------------------------------------------------------------------
1132 int CeedOperatorCreate_Cuda(CeedOperator op) {
1133   Ceed               ceed;
1134   CeedOperator_Cuda *impl;
1135 
1136   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1137   CeedCallBackend(CeedCalloc(1, &impl));
1138   CeedCallBackend(CeedOperatorSetData(op, impl));
1139 
1140   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Cuda));
1141   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Cuda));
1142   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Cuda));
1143   CeedCallBackend(
1144       CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda));
1145   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Cuda));
1146   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda));
1147   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda));
1148   return CEED_ERROR_SUCCESS;
1149 }
1150 
1151 //------------------------------------------------------------------------------
1152