xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-gen/ceed-sycl-gen-operator-build.sycl.cpp (revision 6ca0f394dabdca92269b68ec74be8bebae3befa4)
1*6ca0f394SUmesh Unnikrishnan // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*6ca0f394SUmesh Unnikrishnan // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*6ca0f394SUmesh Unnikrishnan //
4*6ca0f394SUmesh Unnikrishnan // SPDX-License-Identifier: BSD-2-Clause
5*6ca0f394SUmesh Unnikrishnan //
6*6ca0f394SUmesh Unnikrishnan // This file is part of CEED:  http://github.com/ceed
7*6ca0f394SUmesh Unnikrishnan 
8*6ca0f394SUmesh Unnikrishnan #define CEED_DEBUG_COLOR 12
9*6ca0f394SUmesh Unnikrishnan 
10*6ca0f394SUmesh Unnikrishnan #include <ceed/backend.h>
11*6ca0f394SUmesh Unnikrishnan #include <ceed/ceed.h>
12*6ca0f394SUmesh Unnikrishnan #include <ceed/jit-source/sycl/sycl-types.h>
13*6ca0f394SUmesh Unnikrishnan #include <ceed/jit-tools.h>
14*6ca0f394SUmesh Unnikrishnan 
15*6ca0f394SUmesh Unnikrishnan #include <iostream>
16*6ca0f394SUmesh Unnikrishnan #include <sstream>
17*6ca0f394SUmesh Unnikrishnan #include <string>
18*6ca0f394SUmesh Unnikrishnan #include <string_view>
19*6ca0f394SUmesh Unnikrishnan #include <vector>
20*6ca0f394SUmesh Unnikrishnan 
21*6ca0f394SUmesh Unnikrishnan #include "../sycl-ref/ceed-sycl-ref.hpp"
22*6ca0f394SUmesh Unnikrishnan #include "../sycl-shared/ceed-sycl-shared.hpp"
23*6ca0f394SUmesh Unnikrishnan #include "../sycl/ceed-sycl-compile.hpp"
24*6ca0f394SUmesh Unnikrishnan 
25*6ca0f394SUmesh Unnikrishnan #include "ceed-sycl-gen.hpp"
26*6ca0f394SUmesh Unnikrishnan 
27*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
28*6ca0f394SUmesh Unnikrishnan // Calculate the block size used for launching the operator kernel
29*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
30*6ca0f394SUmesh Unnikrishnan extern "C" int BlockGridCalculate_Sycl_gen(const CeedInt dim, const CeedInt P_1d, const CeedInt Q_1d, CeedInt *block_sizes) {
31*6ca0f394SUmesh Unnikrishnan   const CeedInt thread1d = CeedIntMax(Q_1d, P_1d);
32*6ca0f394SUmesh Unnikrishnan   if (dim == 1) {
33*6ca0f394SUmesh Unnikrishnan     CeedInt elems_per_block = 64 * thread1d > 256 ? 256 / thread1d : 64;
34*6ca0f394SUmesh Unnikrishnan     elems_per_block         = elems_per_block > 0 ? elems_per_block : 1;
35*6ca0f394SUmesh Unnikrishnan     block_sizes[0]          = thread1d;
36*6ca0f394SUmesh Unnikrishnan     block_sizes[1]          = 1;
37*6ca0f394SUmesh Unnikrishnan     block_sizes[2]          = elems_per_block;
38*6ca0f394SUmesh Unnikrishnan   } else if (dim == 2) {
39*6ca0f394SUmesh Unnikrishnan     const CeedInt elems_per_block = thread1d < 4 ? 16 : 2;
40*6ca0f394SUmesh Unnikrishnan     block_sizes[0]                = thread1d;
41*6ca0f394SUmesh Unnikrishnan     block_sizes[1]                = thread1d;
42*6ca0f394SUmesh Unnikrishnan     block_sizes[2]                = elems_per_block;
43*6ca0f394SUmesh Unnikrishnan   } else if (dim == 3) {
44*6ca0f394SUmesh Unnikrishnan     const CeedInt elems_per_block = thread1d < 6 ? 4 : (thread1d < 8 ? 2 : 1);
45*6ca0f394SUmesh Unnikrishnan     block_sizes[0]                = thread1d;
46*6ca0f394SUmesh Unnikrishnan     block_sizes[1]                = thread1d;
47*6ca0f394SUmesh Unnikrishnan     block_sizes[2]                = elems_per_block;
48*6ca0f394SUmesh Unnikrishnan   }
49*6ca0f394SUmesh Unnikrishnan   return CEED_ERROR_SUCCESS;
50*6ca0f394SUmesh Unnikrishnan }
51*6ca0f394SUmesh Unnikrishnan 
52*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
53*6ca0f394SUmesh Unnikrishnan // Build single operator kernel
54*6ca0f394SUmesh Unnikrishnan // - [ ] Check arguments to device functions reused from sycl-shared-basis are correct
55*6ca0f394SUmesh Unnikrishnan // - [ ] Do kernel jitting!
56*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
57*6ca0f394SUmesh Unnikrishnan extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
58*6ca0f394SUmesh Unnikrishnan   bool is_setup_done;
59*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
60*6ca0f394SUmesh Unnikrishnan   if (is_setup_done) return CEED_ERROR_SUCCESS;
61*6ca0f394SUmesh Unnikrishnan 
62*6ca0f394SUmesh Unnikrishnan   Ceed ceed;
63*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
64*6ca0f394SUmesh Unnikrishnan   Ceed_Sycl *sycl_data;
65*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedGetData(ceed, &sycl_data));
66*6ca0f394SUmesh Unnikrishnan 
67*6ca0f394SUmesh Unnikrishnan   CeedOperator_Sycl_gen *impl;
68*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorGetData(op, &impl));
69*6ca0f394SUmesh Unnikrishnan   Fields_Sycl             h_B, h_G;
70*6ca0f394SUmesh Unnikrishnan   FieldsInt_Sycl          h_indices;
71*6ca0f394SUmesh Unnikrishnan   CeedQFunction           qf;
72*6ca0f394SUmesh Unnikrishnan   CeedQFunction_Sycl_gen *qf_impl;
73*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
74*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl));
75*6ca0f394SUmesh Unnikrishnan   CeedSize lsize;
76*6ca0f394SUmesh Unnikrishnan   CeedInt  Q, P_1d = 0, Q_1d = 0, elem_size, num_input_fields, num_output_fields, num_comp, dim = 1;
77*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
78*6ca0f394SUmesh Unnikrishnan   Q_1d = Q;
79*6ca0f394SUmesh Unnikrishnan 
80*6ca0f394SUmesh Unnikrishnan   CeedOperatorField *op_input_fields, *op_output_fields;
81*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
82*6ca0f394SUmesh Unnikrishnan   CeedQFunctionField *qf_input_fields, *qf_output_fields;
83*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
84*6ca0f394SUmesh Unnikrishnan 
85*6ca0f394SUmesh Unnikrishnan   CeedEvalMode              eval_mode;
86*6ca0f394SUmesh Unnikrishnan   CeedBasis                 basis;
87*6ca0f394SUmesh Unnikrishnan   CeedBasis_Sycl_shared    *basis_impl;
88*6ca0f394SUmesh Unnikrishnan   CeedElemRestriction       Erestrict;
89*6ca0f394SUmesh Unnikrishnan   CeedElemRestriction_Sycl *restr_impl;
90*6ca0f394SUmesh Unnikrishnan 
91*6ca0f394SUmesh Unnikrishnan   // Check for restriction only identity operator
92*6ca0f394SUmesh Unnikrishnan   bool is_identity_qf;
93*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedQFunctionIsIdentity(qf, &is_identity_qf));
94*6ca0f394SUmesh Unnikrishnan   if (is_identity_qf) {
95*6ca0f394SUmesh Unnikrishnan     CeedEvalMode eval_mode_in, eval_mode_out;
96*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[0], &eval_mode_in));
97*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[0], &eval_mode_out));
98*6ca0f394SUmesh Unnikrishnan     if (eval_mode_in == CEED_EVAL_NONE && eval_mode_out == CEED_EVAL_NONE) {
99*6ca0f394SUmesh Unnikrishnan       // LCOV_EXCL_START
100*6ca0f394SUmesh Unnikrishnan       return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement restriction only identity operators");
101*6ca0f394SUmesh Unnikrishnan       // LCOV_EXCL_STOP
102*6ca0f394SUmesh Unnikrishnan     }
103*6ca0f394SUmesh Unnikrishnan   }
104*6ca0f394SUmesh Unnikrishnan 
105*6ca0f394SUmesh Unnikrishnan   std::ostringstream code;
106*6ca0f394SUmesh Unnikrishnan   // TODO: generalize to accept different device functions?
107*6ca0f394SUmesh Unnikrishnan   {
108*6ca0f394SUmesh Unnikrishnan     char *tensor_basis_kernel_path, *tensor_basis_code;
109*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-shared-basis-tensor-templates.h", &tensor_basis_kernel_path));
110*6ca0f394SUmesh Unnikrishnan     CeedDebug256(ceed, 2, "----- Loading Tensor Basis Kernel Source -----\n");
111*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedLoadSourceToBuffer(ceed, tensor_basis_kernel_path, &tensor_basis_code));
112*6ca0f394SUmesh Unnikrishnan     code << tensor_basis_code;
113*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedFree(&tensor_basis_kernel_path));
114*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedFree(&tensor_basis_code));
115*6ca0f394SUmesh Unnikrishnan   }
116*6ca0f394SUmesh Unnikrishnan   {
117*6ca0f394SUmesh Unnikrishnan     char *sycl_gen_template_path, *sycl_gen_template_source;
118*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-gen-templates.h", &sycl_gen_template_path));
119*6ca0f394SUmesh Unnikrishnan     CeedDebug256(ceed, 2, "----- Loading Sycl-Gen Template Source -----\n");
120*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedLoadSourceToBuffer(ceed, sycl_gen_template_path, &sycl_gen_template_source));
121*6ca0f394SUmesh Unnikrishnan     code << sycl_gen_template_source;
122*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedFree(&sycl_gen_template_path));
123*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedFree(&sycl_gen_template_source));
124*6ca0f394SUmesh Unnikrishnan   }
125*6ca0f394SUmesh Unnikrishnan 
126*6ca0f394SUmesh Unnikrishnan   std::string_view  q_function_source(qf_impl->q_function_source);
127*6ca0f394SUmesh Unnikrishnan   std::string_view  q_function_name(qf_impl->q_function_name);
128*6ca0f394SUmesh Unnikrishnan   const std::string operator_name = "CeedKernelSyclGenOperator_" + std::string(q_function_name);
129*6ca0f394SUmesh Unnikrishnan 
130*6ca0f394SUmesh Unnikrishnan   // Find dim, P_1d, Q_1d
131*6ca0f394SUmesh Unnikrishnan   impl->max_P_1d = 0;
132*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_input_fields; i++) {
133*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
134*6ca0f394SUmesh Unnikrishnan     if (basis != CEED_BASIS_COLLOCATED) {
135*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
136*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
137*6ca0f394SUmesh Unnikrishnan 
138*6ca0f394SUmesh Unnikrishnan       // Collect dim, P_1d, and Q_1d
139*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedBasisGetDimension(basis, &dim));
140*6ca0f394SUmesh Unnikrishnan       bool isTensor;
141*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedBasisIsTensor(basis, &isTensor));
142*6ca0f394SUmesh Unnikrishnan       if (isTensor) {
143*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
144*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
145*6ca0f394SUmesh Unnikrishnan         if (P_1d > impl->max_P_1d) impl->max_P_1d = P_1d;
146*6ca0f394SUmesh Unnikrishnan       } else {
147*6ca0f394SUmesh Unnikrishnan         // LCOV_EXCL_START
148*6ca0f394SUmesh Unnikrishnan         return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement operators with non-tensor basis");
149*6ca0f394SUmesh Unnikrishnan         // LCOV_EXCL_STOP
150*6ca0f394SUmesh Unnikrishnan       }
151*6ca0f394SUmesh Unnikrishnan     }
152*6ca0f394SUmesh Unnikrishnan   }
153*6ca0f394SUmesh Unnikrishnan   // Check output bases for Q_1d, dim as well
154*6ca0f394SUmesh Unnikrishnan   //   The only input basis might be CEED_BASIS_COLLOCATED
155*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_output_fields; i++) {
156*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
157*6ca0f394SUmesh Unnikrishnan 
158*6ca0f394SUmesh Unnikrishnan     if (basis != CEED_BASIS_COLLOCATED) {
159*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
160*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
161*6ca0f394SUmesh Unnikrishnan 
162*6ca0f394SUmesh Unnikrishnan       // Collect Q_1d
163*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedBasisGetDimension(basis, &dim));
164*6ca0f394SUmesh Unnikrishnan       bool isTensor;
165*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedBasisIsTensor(basis, &isTensor));
166*6ca0f394SUmesh Unnikrishnan       if (isTensor) {
167*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
168*6ca0f394SUmesh Unnikrishnan       } else {
169*6ca0f394SUmesh Unnikrishnan         // LCOV_EXCL_START
170*6ca0f394SUmesh Unnikrishnan         return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement operators with non-tensor basis");
171*6ca0f394SUmesh Unnikrishnan         // LCOV_EXCL_STOP
172*6ca0f394SUmesh Unnikrishnan       }
173*6ca0f394SUmesh Unnikrishnan     }
174*6ca0f394SUmesh Unnikrishnan   }
175*6ca0f394SUmesh Unnikrishnan   impl->dim  = dim;
176*6ca0f394SUmesh Unnikrishnan   impl->Q_1d = Q_1d;
177*6ca0f394SUmesh Unnikrishnan 
178*6ca0f394SUmesh Unnikrishnan   // Only use 3D collocated gradient parallelization strategy when gradient is computed
179*6ca0f394SUmesh Unnikrishnan   // TODO: put in a function?
180*6ca0f394SUmesh Unnikrishnan   bool use_collograd_parallelization = false;
181*6ca0f394SUmesh Unnikrishnan   if (dim == 3) {
182*6ca0f394SUmesh Unnikrishnan     bool was_grad_found = false;
183*6ca0f394SUmesh Unnikrishnan     for (CeedInt i = 0; i < num_input_fields; i++) {
184*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
185*6ca0f394SUmesh Unnikrishnan       if (eval_mode == CEED_EVAL_GRAD) {
186*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
187*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
188*6ca0f394SUmesh Unnikrishnan         use_collograd_parallelization = !!basis_impl->d_collo_grad_1d && (was_grad_found ? use_collograd_parallelization : true);
189*6ca0f394SUmesh Unnikrishnan         was_grad_found                = true;
190*6ca0f394SUmesh Unnikrishnan       }
191*6ca0f394SUmesh Unnikrishnan     }
192*6ca0f394SUmesh Unnikrishnan     for (CeedInt i = 0; i < num_output_fields; i++) {
193*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
194*6ca0f394SUmesh Unnikrishnan       if (eval_mode == CEED_EVAL_GRAD) {
195*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
196*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
197*6ca0f394SUmesh Unnikrishnan         use_collograd_parallelization = !!basis_impl->d_collo_grad_1d && (was_grad_found ? use_collograd_parallelization : true);
198*6ca0f394SUmesh Unnikrishnan         was_grad_found                = true;
199*6ca0f394SUmesh Unnikrishnan       }
200*6ca0f394SUmesh Unnikrishnan     }
201*6ca0f394SUmesh Unnikrishnan   }
202*6ca0f394SUmesh Unnikrishnan 
203*6ca0f394SUmesh Unnikrishnan   CeedInt block_sizes[3];
204*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes));
205*6ca0f394SUmesh Unnikrishnan 
206*6ca0f394SUmesh Unnikrishnan   // Define CEED_Q_VLA
207*6ca0f394SUmesh Unnikrishnan   code << "\n#undef CEED_Q_VLA\n";
208*6ca0f394SUmesh Unnikrishnan   if (dim != 3 || use_collograd_parallelization) {
209*6ca0f394SUmesh Unnikrishnan     code << "#define CEED_Q_VLA 1\n\n";
210*6ca0f394SUmesh Unnikrishnan   } else {
211*6ca0f394SUmesh Unnikrishnan     code << "#define CEED_Q_VLA " << Q_1d << "\n\n";
212*6ca0f394SUmesh Unnikrishnan   }
213*6ca0f394SUmesh Unnikrishnan 
214*6ca0f394SUmesh Unnikrishnan   // Determine subgroup size based on supported sizes : Default : 16 (if supported)
215*6ca0f394SUmesh Unnikrishnan   std::vector allowed_sg_sizes  = sycl_data->sycl_device.get_info<sycl::info::device::sub_group_sizes>();
216*6ca0f394SUmesh Unnikrishnan   CeedInt     sub_group_size_op = allowed_sg_sizes[allowed_sg_sizes.size() - 1];
217*6ca0f394SUmesh Unnikrishnan   for (const auto &s : allowed_sg_sizes) {
218*6ca0f394SUmesh Unnikrishnan     if (s == 16) {
219*6ca0f394SUmesh Unnikrishnan       sub_group_size_op = s;
220*6ca0f394SUmesh Unnikrishnan       break;
221*6ca0f394SUmesh Unnikrishnan     }
222*6ca0f394SUmesh Unnikrishnan   }
223*6ca0f394SUmesh Unnikrishnan 
224*6ca0f394SUmesh Unnikrishnan   code << q_function_source;
225*6ca0f394SUmesh Unnikrishnan 
226*6ca0f394SUmesh Unnikrishnan   // Kernel function
227*6ca0f394SUmesh Unnikrishnan   code << "\n// -----------------------------------------------------------------------------\n";
228*6ca0f394SUmesh Unnikrishnan   code << "__attribute__((reqd_work_group_size(GROUP_SIZE_X, GROUP_SIZE_Y, GROUP_SIZE_Z), intel_reqd_sub_group_size(" << sub_group_size_op << ")))\n";
229*6ca0f394SUmesh Unnikrishnan   code << "kernel void " << operator_name << "(";
230*6ca0f394SUmesh Unnikrishnan   code << "const CeedInt num_elem, ";
231*6ca0f394SUmesh Unnikrishnan   code << "global void* ctx, ";
232*6ca0f394SUmesh Unnikrishnan   code << "global const FieldsInt_Sycl* indices, ";
233*6ca0f394SUmesh Unnikrishnan   code << "global Fields_Sycl* fields, ";
234*6ca0f394SUmesh Unnikrishnan   code << "global const Fields_Sycl* B, ";
235*6ca0f394SUmesh Unnikrishnan   code << "global const Fields_Sycl* G, ";
236*6ca0f394SUmesh Unnikrishnan   code << "global const CeedScalar * restrict W";
237*6ca0f394SUmesh Unnikrishnan   code << ") {\n";
238*6ca0f394SUmesh Unnikrishnan 
239*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_input_fields; i++) {
240*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
241*6ca0f394SUmesh Unnikrishnan     if (eval_mode != CEED_EVAL_WEIGHT) {  // Skip CEED_EVAL_WEIGHT
242*6ca0f394SUmesh Unnikrishnan       code << "  global const CeedScalar* d_u_" << i << " = fields->inputs[" << i << "];\n";
243*6ca0f394SUmesh Unnikrishnan     }
244*6ca0f394SUmesh Unnikrishnan   }
245*6ca0f394SUmesh Unnikrishnan 
246*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_output_fields; i++) {
247*6ca0f394SUmesh Unnikrishnan     code << "  global CeedScalar* d_v_" << i << " = fields->outputs[" << i << "];\n";
248*6ca0f394SUmesh Unnikrishnan   }
249*6ca0f394SUmesh Unnikrishnan 
250*6ca0f394SUmesh Unnikrishnan   // TODO: Convert these to defined constants to save on GRF
251*6ca0f394SUmesh Unnikrishnan   code << "  const CeedInt DIM = " << dim << ";\n";
252*6ca0f394SUmesh Unnikrishnan   code << "  const CeedInt Q_1D = " << Q_1d << ";\n";
253*6ca0f394SUmesh Unnikrishnan 
254*6ca0f394SUmesh Unnikrishnan   const CeedInt scratch_size = block_sizes[0] * block_sizes[1] * block_sizes[2];
255*6ca0f394SUmesh Unnikrishnan   code << "  local CeedScalar scratch[" << scratch_size << "];\n";
256*6ca0f394SUmesh Unnikrishnan   code << "  local CeedScalar * elem_scratch = scratch + get_local_id(2) * T_1D" << (dim > 1 ? "*T_1D" : "") << ";\n";
257*6ca0f394SUmesh Unnikrishnan 
258*6ca0f394SUmesh Unnikrishnan   code << "\n  // -- Input field constants and basis data --\n";
259*6ca0f394SUmesh Unnikrishnan   // Initialize constants, and matrices B and G
260*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_input_fields; i++) {
261*6ca0f394SUmesh Unnikrishnan     code << "  // ---- Input field " << i << " ----\n";
262*6ca0f394SUmesh Unnikrishnan     // Get elem_size, eval_mode, num_comp
263*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &Erestrict));
264*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size));
265*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
266*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedElemRestrictionGetNumComponents(Erestrict, &num_comp));
267*6ca0f394SUmesh Unnikrishnan 
268*6ca0f394SUmesh Unnikrishnan     // Set field constants
269*6ca0f394SUmesh Unnikrishnan     if (eval_mode != CEED_EVAL_WEIGHT) {
270*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
271*6ca0f394SUmesh Unnikrishnan       if (basis != CEED_BASIS_COLLOCATED) {
272*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
273*6ca0f394SUmesh Unnikrishnan         code << "  const CeedInt P_in_" << i << " = " << P_1d << ";\n";
274*6ca0f394SUmesh Unnikrishnan       } else {
275*6ca0f394SUmesh Unnikrishnan         code << "  const CeedInt P_in_" << i << " = " << Q_1d << ";\n";
276*6ca0f394SUmesh Unnikrishnan       }
277*6ca0f394SUmesh Unnikrishnan       code << "  const CeedInt num_comp_in_" << i << " = " << num_comp << ";\n";
278*6ca0f394SUmesh Unnikrishnan     }
279*6ca0f394SUmesh Unnikrishnan 
280*6ca0f394SUmesh Unnikrishnan     // Load basis data
281*6ca0f394SUmesh Unnikrishnan     code << "  // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
282*6ca0f394SUmesh Unnikrishnan     switch (eval_mode) {
283*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_NONE:
284*6ca0f394SUmesh Unnikrishnan         break;
285*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_INTERP:
286*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
287*6ca0f394SUmesh Unnikrishnan         h_B.inputs[i] = basis_impl->d_interp_1d;
288*6ca0f394SUmesh Unnikrishnan         code << "  local CeedScalar s_B_in_" << i << "[" << P_1d * Q_1d << "];\n";
289*6ca0f394SUmesh Unnikrishnan         code << "  loadMatrix(P_in_" << i << "*Q_1D, B->inputs[" << i << "], s_B_in_" << i << ");\n";
290*6ca0f394SUmesh Unnikrishnan         break;
291*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_GRAD:
292*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
293*6ca0f394SUmesh Unnikrishnan         h_B.inputs[i] = basis_impl->d_interp_1d;
294*6ca0f394SUmesh Unnikrishnan         code << "  local CeedScalar s_B_in_" << i << "[" << P_1d * Q_1d << "];\n";
295*6ca0f394SUmesh Unnikrishnan         code << "  loadMatrix(P_in_" << i << "*Q_1D, B->inputs[" << i << "], s_B_in_" << i << ");\n";
296*6ca0f394SUmesh Unnikrishnan         if (use_collograd_parallelization) {
297*6ca0f394SUmesh Unnikrishnan           h_G.inputs[i] = basis_impl->d_collo_grad_1d;
298*6ca0f394SUmesh Unnikrishnan           code << "  local CeedScalar s_G_in_" << i << "[" << Q_1d * Q_1d << "];\n";
299*6ca0f394SUmesh Unnikrishnan           code << "  loadMatrix(Q_1D*Q_1D, G->inputs[" << i << "], s_G_in_" << i << ");\n";
300*6ca0f394SUmesh Unnikrishnan         } else {
301*6ca0f394SUmesh Unnikrishnan           bool has_collo_grad = !!basis_impl->d_collo_grad_1d;
302*6ca0f394SUmesh Unnikrishnan           h_G.inputs[i]       = has_collo_grad ? basis_impl->d_collo_grad_1d : basis_impl->d_grad_1d;
303*6ca0f394SUmesh Unnikrishnan           code << "  local CeedScalar s_G_in_" << i << "[" << Q_1d * (has_collo_grad ? Q_1d : P_1d) << "];\n";
304*6ca0f394SUmesh Unnikrishnan           code << "  loadMatrix(" << (has_collo_grad ? "Q_1D" : ("P_in_" + std::to_string(i))) << "*Q_1D, G->inputs[" << i << "], s_G_in_" << i
305*6ca0f394SUmesh Unnikrishnan                << ");\n";
306*6ca0f394SUmesh Unnikrishnan         }
307*6ca0f394SUmesh Unnikrishnan         break;
308*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_WEIGHT:
309*6ca0f394SUmesh Unnikrishnan         break;  // No action
310*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_DIV:
311*6ca0f394SUmesh Unnikrishnan         break;  // TODO: Not implemented
312*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_CURL:
313*6ca0f394SUmesh Unnikrishnan         break;  // TODO: Not implemented
314*6ca0f394SUmesh Unnikrishnan     }
315*6ca0f394SUmesh Unnikrishnan   }
316*6ca0f394SUmesh Unnikrishnan 
317*6ca0f394SUmesh Unnikrishnan   code << "\n  // -- Output field constants and basis data --\n";
318*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_output_fields; i++) {
319*6ca0f394SUmesh Unnikrishnan     code << "  // ---- Output field " << i << " ----\n";
320*6ca0f394SUmesh Unnikrishnan     // Get elem_size, eval_mode, num_comp
321*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &Erestrict));
322*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size));
323*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
324*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedElemRestrictionGetNumComponents(Erestrict, &num_comp));
325*6ca0f394SUmesh Unnikrishnan 
326*6ca0f394SUmesh Unnikrishnan     // Set field constants
327*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
328*6ca0f394SUmesh Unnikrishnan     if (basis != CEED_BASIS_COLLOCATED) {
329*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
330*6ca0f394SUmesh Unnikrishnan       code << "  const CeedInt P_out_" << i << " = " << P_1d << ";\n";
331*6ca0f394SUmesh Unnikrishnan     } else {
332*6ca0f394SUmesh Unnikrishnan       code << "  const CeedInt P_out_" << i << " = " << Q_1d << ";\n";
333*6ca0f394SUmesh Unnikrishnan     }
334*6ca0f394SUmesh Unnikrishnan     code << "  const CeedInt num_comp_out_" << i << " = " << num_comp << ";\n";
335*6ca0f394SUmesh Unnikrishnan 
336*6ca0f394SUmesh Unnikrishnan     // Load basis data
337*6ca0f394SUmesh Unnikrishnan     code << "  // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
338*6ca0f394SUmesh Unnikrishnan     switch (eval_mode) {
339*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_NONE:
340*6ca0f394SUmesh Unnikrishnan         break;  // No action
341*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_INTERP:
342*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
343*6ca0f394SUmesh Unnikrishnan         h_B.outputs[i] = basis_impl->d_interp_1d;
344*6ca0f394SUmesh Unnikrishnan         code << "  local CeedScalar s_B_out_" << i << "[" << P_1d * Q_1d << "];\n";
345*6ca0f394SUmesh Unnikrishnan         code << "  loadMatrix(P_out_" << i << "*Q_1D, B->outputs[" << i << "], s_B_out_" << i << ");\n";
346*6ca0f394SUmesh Unnikrishnan         break;
347*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_GRAD:
348*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
349*6ca0f394SUmesh Unnikrishnan         h_B.outputs[i] = basis_impl->d_interp_1d;
350*6ca0f394SUmesh Unnikrishnan         code << "  local CeedScalar s_B_out_" << i << "[" << P_1d * Q_1d << "];\n";
351*6ca0f394SUmesh Unnikrishnan         code << "  loadMatrix(P_out_" << i << "*Q_1D, B->outputs[" << i << "], s_B_out_" << i << ");\n";
352*6ca0f394SUmesh Unnikrishnan         if (use_collograd_parallelization) {
353*6ca0f394SUmesh Unnikrishnan           h_G.outputs[i] = basis_impl->d_collo_grad_1d;
354*6ca0f394SUmesh Unnikrishnan           code << "  local CeedScalar s_G_out_" << i << "[" << Q_1d * Q_1d << "];\n";
355*6ca0f394SUmesh Unnikrishnan           code << "  loadMatrix(Q_1D*Q_1D, G->outputs[" << i << "], s_G_out_" << i << ");\n";
356*6ca0f394SUmesh Unnikrishnan         } else {
357*6ca0f394SUmesh Unnikrishnan           bool has_collo_grad = !!basis_impl->d_collo_grad_1d;
358*6ca0f394SUmesh Unnikrishnan           h_G.outputs[i]      = has_collo_grad ? basis_impl->d_collo_grad_1d : basis_impl->d_grad_1d;
359*6ca0f394SUmesh Unnikrishnan           code << "  local CeedScalar s_G_out_" << i << "[" << Q_1d * (has_collo_grad ? Q_1d : P_1d) << "];\n";
360*6ca0f394SUmesh Unnikrishnan           code << "  loadMatrix(" << (has_collo_grad ? "Q_1D" : ("P_out_" + std::to_string(i))) << "*Q_1D, G->outputs[" << i << "], s_G_out_" << i
361*6ca0f394SUmesh Unnikrishnan                << ");\n";
362*6ca0f394SUmesh Unnikrishnan         }
363*6ca0f394SUmesh Unnikrishnan         break;
364*6ca0f394SUmesh Unnikrishnan       // LCOV_EXCL_START
365*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_WEIGHT: {
366*6ca0f394SUmesh Unnikrishnan         Ceed ceed;
367*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
368*6ca0f394SUmesh Unnikrishnan         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
369*6ca0f394SUmesh Unnikrishnan         break;  // Should not occur
370*6ca0f394SUmesh Unnikrishnan       }
371*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_DIV:
372*6ca0f394SUmesh Unnikrishnan         break;  // TODO: Not implemented
373*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_CURL:
374*6ca0f394SUmesh Unnikrishnan         break;  // TODO: Not implemented
375*6ca0f394SUmesh Unnikrishnan                 // LCOV_EXCL_STOP
376*6ca0f394SUmesh Unnikrishnan     }
377*6ca0f394SUmesh Unnikrishnan   }
378*6ca0f394SUmesh Unnikrishnan   code << "\n  // -- Element loop --\n";
379*6ca0f394SUmesh Unnikrishnan   code << "  work_group_barrier(CLK_LOCAL_MEM_FENCE);\n";
380*6ca0f394SUmesh Unnikrishnan   code << "  {\n";
381*6ca0f394SUmesh Unnikrishnan   // Input basis apply if needed
382*6ca0f394SUmesh Unnikrishnan   // Generate the correct eval mode code for each input
383*6ca0f394SUmesh Unnikrishnan   code << "    // -- Input field restrictions and basis actions --\n";
384*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_input_fields; i++) {
385*6ca0f394SUmesh Unnikrishnan     code << "    // ---- Input field " << i << " ----\n";
386*6ca0f394SUmesh Unnikrishnan     // Get elem_size, eval_mode, num_comp
387*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &Erestrict));
388*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size));
389*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
390*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedElemRestrictionGetNumComponents(Erestrict, &num_comp));
391*6ca0f394SUmesh Unnikrishnan 
392*6ca0f394SUmesh Unnikrishnan     // Restriction
393*6ca0f394SUmesh Unnikrishnan     if (eval_mode != CEED_EVAL_WEIGHT && !((eval_mode == CEED_EVAL_NONE) && use_collograd_parallelization)) {
394*6ca0f394SUmesh Unnikrishnan       code << "    CeedScalar r_u_" << i << "[num_comp_in_" << i << "*P_in_" << i << "];\n";
395*6ca0f394SUmesh Unnikrishnan 
396*6ca0f394SUmesh Unnikrishnan       bool is_strided;
397*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedElemRestrictionIsStrided(Erestrict, &is_strided));
398*6ca0f394SUmesh Unnikrishnan       if (!is_strided) {
399*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedElemRestrictionGetLVectorSize(Erestrict, &lsize));
400*6ca0f394SUmesh Unnikrishnan         code << "    const CeedInt lsize_in_" << i << " = " << lsize << ";\n";
401*6ca0f394SUmesh Unnikrishnan         CeedInt comp_stride;
402*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedElemRestrictionGetCompStride(Erestrict, &comp_stride));
403*6ca0f394SUmesh Unnikrishnan         code << "    // CompStride: " << comp_stride << "\n";
404*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedElemRestrictionGetData(Erestrict, &restr_impl));
405*6ca0f394SUmesh Unnikrishnan         h_indices.inputs[i] = restr_impl->d_ind;
406*6ca0f394SUmesh Unnikrishnan         code << "    readDofsOffset" << dim << "d(num_comp_in_" << i << ", " << comp_stride << ", P_in_" << i << ", num_elem, indices->inputs[" << i
407*6ca0f394SUmesh Unnikrishnan              << "], d_u_" << i << ", r_u_" << i << ");\n";
408*6ca0f394SUmesh Unnikrishnan       } else {
409*6ca0f394SUmesh Unnikrishnan         bool has_backend_strides;
410*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedElemRestrictionHasBackendStrides(Erestrict, &has_backend_strides));
411*6ca0f394SUmesh Unnikrishnan         CeedInt num_elem;
412*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedElemRestrictionGetNumElements(Erestrict, &num_elem));
413*6ca0f394SUmesh Unnikrishnan         CeedInt strides[3] = {1, elem_size * num_elem, elem_size};
414*6ca0f394SUmesh Unnikrishnan         if (!has_backend_strides) {
415*6ca0f394SUmesh Unnikrishnan           CeedCallBackend(CeedElemRestrictionGetStrides(Erestrict, &strides));
416*6ca0f394SUmesh Unnikrishnan         }
417*6ca0f394SUmesh Unnikrishnan         code << "    // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n";
418*6ca0f394SUmesh Unnikrishnan         code << "    readDofsStrided" << dim << "d(num_comp_in_" << i << ",P_in_" << i << "," << strides[0] << "," << strides[1] << "," << strides[2]
419*6ca0f394SUmesh Unnikrishnan              << ", num_elem, d_u_" << i << ", r_u_" << i << ");\n";
420*6ca0f394SUmesh Unnikrishnan       }
421*6ca0f394SUmesh Unnikrishnan     }
422*6ca0f394SUmesh Unnikrishnan 
423*6ca0f394SUmesh Unnikrishnan     // Basis action
424*6ca0f394SUmesh Unnikrishnan     code << "    // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
425*6ca0f394SUmesh Unnikrishnan     switch (eval_mode) {
426*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_NONE:
427*6ca0f394SUmesh Unnikrishnan         if (!use_collograd_parallelization) {
428*6ca0f394SUmesh Unnikrishnan           code << "    private CeedScalar* r_t_" << i << " = r_u_" << i << ";\n";
429*6ca0f394SUmesh Unnikrishnan         }
430*6ca0f394SUmesh Unnikrishnan         break;
431*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_INTERP:
432*6ca0f394SUmesh Unnikrishnan         code << "    CeedScalar r_t_" << i << "[num_comp_in_" << i << "*Q_1D];\n";
433*6ca0f394SUmesh Unnikrishnan         code << "    Interp" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_in_" << i << ", P_in_" << i << ", Q_1D, r_u_" << i << ", s_B_in_" << i
434*6ca0f394SUmesh Unnikrishnan              << ", r_t_" << i << ", elem_scratch);\n";
435*6ca0f394SUmesh Unnikrishnan         break;
436*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_GRAD:
437*6ca0f394SUmesh Unnikrishnan         if (use_collograd_parallelization) {
438*6ca0f394SUmesh Unnikrishnan           code << "    CeedScalar r_t_" << i << "[num_comp_in_" << i << "*Q_1D];\n";
439*6ca0f394SUmesh Unnikrishnan           code << "    Interp" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_in_" << i << ", P_in_" << i << ", Q_1D, r_u_" << i << ", s_B_in_"
440*6ca0f394SUmesh Unnikrishnan                << i << ", r_t_" << i << ", elem_scratch);\n";
441*6ca0f394SUmesh Unnikrishnan         } else {
442*6ca0f394SUmesh Unnikrishnan           CeedInt P_1d;
443*6ca0f394SUmesh Unnikrishnan           CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
444*6ca0f394SUmesh Unnikrishnan           CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
445*6ca0f394SUmesh Unnikrishnan           code << "    CeedScalar r_t_" << i << "[num_comp_in_" << i << "*DIM*Q_1D];\n";
446*6ca0f394SUmesh Unnikrishnan           code << "    Grad" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d(num_comp_in_" << i
447*6ca0f394SUmesh Unnikrishnan                << ", P_in_" << i << ", Q_1D, r_u_" << i << (dim > 1 ? ", s_B_in_" : "") << (dim > 1 ? std::to_string(i) : "") << ", s_G_in_" << i
448*6ca0f394SUmesh Unnikrishnan                << ", r_t_" << i << ", elem_scratch);\n";
449*6ca0f394SUmesh Unnikrishnan         }
450*6ca0f394SUmesh Unnikrishnan         break;
451*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_WEIGHT:
452*6ca0f394SUmesh Unnikrishnan         code << "    CeedScalar r_t_" << i << "[Q_1D];\n";
453*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
454*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
455*6ca0f394SUmesh Unnikrishnan         impl->W = basis_impl->d_q_weight_1d;
456*6ca0f394SUmesh Unnikrishnan         code << "    Weight" << (dim > 1 ? "Tensor" : "") << dim << "d(Q_1D, W, r_t_" << i << ");\n";
457*6ca0f394SUmesh Unnikrishnan         break;  // No action
458*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_DIV:
459*6ca0f394SUmesh Unnikrishnan         break;  // TODO: Not implemented
460*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_CURL:
461*6ca0f394SUmesh Unnikrishnan         break;  // TODO: Not implemented
462*6ca0f394SUmesh Unnikrishnan     }
463*6ca0f394SUmesh Unnikrishnan   }
464*6ca0f394SUmesh Unnikrishnan 
465*6ca0f394SUmesh Unnikrishnan   // Q function
466*6ca0f394SUmesh Unnikrishnan   code << "\n    // -- Output field setup --\n";
467*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_output_fields; i++) {
468*6ca0f394SUmesh Unnikrishnan     code << "\n    // ---- Output field " << i << " ----\n";
469*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
470*6ca0f394SUmesh Unnikrishnan     if (eval_mode == CEED_EVAL_GRAD) {
471*6ca0f394SUmesh Unnikrishnan       if (use_collograd_parallelization) {
472*6ca0f394SUmesh Unnikrishnan         // Accumulator for gradient slices
473*6ca0f394SUmesh Unnikrishnan         code << "    CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*Q_1D];\n";
474*6ca0f394SUmesh Unnikrishnan         code << "    for (CeedInt i = 0; i < num_comp_out_" << i << "; i++) {\n";
475*6ca0f394SUmesh Unnikrishnan         code << "      for (CeedInt j = 0; j < Q_1D; ++j) {\n";
476*6ca0f394SUmesh Unnikrishnan         code << "        r_tt_" << i << "[j + i*Q_1D] = 0.0;\n";
477*6ca0f394SUmesh Unnikrishnan         code << "      }\n";
478*6ca0f394SUmesh Unnikrishnan         code << "    }\n";
479*6ca0f394SUmesh Unnikrishnan       } else {
480*6ca0f394SUmesh Unnikrishnan         code << "    CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*DIM*Q_1D];\n";
481*6ca0f394SUmesh Unnikrishnan       }
482*6ca0f394SUmesh Unnikrishnan     }
483*6ca0f394SUmesh Unnikrishnan     if (eval_mode == CEED_EVAL_NONE || eval_mode == CEED_EVAL_INTERP) {
484*6ca0f394SUmesh Unnikrishnan       code << "    CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*Q_1D];\n";
485*6ca0f394SUmesh Unnikrishnan     }
486*6ca0f394SUmesh Unnikrishnan   }
487*6ca0f394SUmesh Unnikrishnan   // We treat quadrature points per slice in 3d to save registers
488*6ca0f394SUmesh Unnikrishnan   if (use_collograd_parallelization) {
489*6ca0f394SUmesh Unnikrishnan     code << "\n    // Note: Using planes of 3D elements\n";
490*6ca0f394SUmesh Unnikrishnan     code << "    for (CeedInt q = 0; q < Q_1D; q++) {\n";
491*6ca0f394SUmesh Unnikrishnan     code << "      // -- Input fields --\n";
492*6ca0f394SUmesh Unnikrishnan     for (CeedInt i = 0; i < num_input_fields; i++) {
493*6ca0f394SUmesh Unnikrishnan       code << "      // ---- Input field " << i << " ----\n";
494*6ca0f394SUmesh Unnikrishnan       // Get elem_size, eval_mode, num_comp
495*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
496*6ca0f394SUmesh Unnikrishnan       // Basis action
497*6ca0f394SUmesh Unnikrishnan       code << "      // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
498*6ca0f394SUmesh Unnikrishnan       switch (eval_mode) {
499*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_NONE:
500*6ca0f394SUmesh Unnikrishnan           code << "      CeedScalar r_q_" << i << "[num_comp_in_" << i << "];\n";
501*6ca0f394SUmesh Unnikrishnan 
502*6ca0f394SUmesh Unnikrishnan           bool is_strided;
503*6ca0f394SUmesh Unnikrishnan           CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &Erestrict));
504*6ca0f394SUmesh Unnikrishnan           CeedCallBackend(CeedElemRestrictionIsStrided(Erestrict, &is_strided));
505*6ca0f394SUmesh Unnikrishnan           if (!is_strided) {
506*6ca0f394SUmesh Unnikrishnan             CeedCallBackend(CeedElemRestrictionGetLVectorSize(Erestrict, &lsize));
507*6ca0f394SUmesh Unnikrishnan             code << "      const CeedInt lsize_in_" << i << " = " << lsize << ";\n";
508*6ca0f394SUmesh Unnikrishnan             CeedInt comp_stride;
509*6ca0f394SUmesh Unnikrishnan             CeedCallBackend(CeedElemRestrictionGetCompStride(Erestrict, &comp_stride));
510*6ca0f394SUmesh Unnikrishnan             code << "      // CompStride: " << comp_stride << "\n";
511*6ca0f394SUmesh Unnikrishnan             CeedCallBackend(CeedElemRestrictionGetData(Erestrict, &restr_impl));
512*6ca0f394SUmesh Unnikrishnan             h_indices.inputs[i] = restr_impl->d_ind;
513*6ca0f394SUmesh Unnikrishnan             code << "      readSliceQuadsOffset"
514*6ca0f394SUmesh Unnikrishnan                  << "3d(num_comp_in_" << i << ", " << comp_stride << ", Q_1D, lsize_in_" << i << ", num_elem, q, indices->inputs[" << i << "], d_u_"
515*6ca0f394SUmesh Unnikrishnan                  << i << ", r_q_" << i << ");\n";
516*6ca0f394SUmesh Unnikrishnan           } else {
517*6ca0f394SUmesh Unnikrishnan             CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size));
518*6ca0f394SUmesh Unnikrishnan             bool has_backend_strides;
519*6ca0f394SUmesh Unnikrishnan             CeedCallBackend(CeedElemRestrictionHasBackendStrides(Erestrict, &has_backend_strides));
520*6ca0f394SUmesh Unnikrishnan             CeedInt num_elem;
521*6ca0f394SUmesh Unnikrishnan             CeedCallBackend(CeedElemRestrictionGetNumElements(Erestrict, &num_elem));
522*6ca0f394SUmesh Unnikrishnan             CeedInt strides[3] = {1, elem_size * num_elem, elem_size};
523*6ca0f394SUmesh Unnikrishnan             if (!has_backend_strides) {
524*6ca0f394SUmesh Unnikrishnan               CeedCallBackend(CeedElemRestrictionGetStrides(Erestrict, &strides));
525*6ca0f394SUmesh Unnikrishnan             }
526*6ca0f394SUmesh Unnikrishnan             code << "      // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n";
527*6ca0f394SUmesh Unnikrishnan             code << "      readSliceQuadsStrided"
528*6ca0f394SUmesh Unnikrishnan                  << "3d(num_comp_in_" << i << ", Q_1D," << strides[0] << ", " << strides[1] << ", " << strides[2] << ", num_elem, q, d_u_" << i
529*6ca0f394SUmesh Unnikrishnan                  << ", r_q_" << i << ");\n";
530*6ca0f394SUmesh Unnikrishnan           }
531*6ca0f394SUmesh Unnikrishnan           break;
532*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_INTERP:
533*6ca0f394SUmesh Unnikrishnan           code << "      CeedScalar r_q_" << i << "[num_comp_in_" << i << "];\n";
534*6ca0f394SUmesh Unnikrishnan           code << "      for (CeedInt j = 0; j < num_comp_in_" << i << " ; ++j) {\n";
535*6ca0f394SUmesh Unnikrishnan           code << "        r_q_" << i << "[j] = r_t_" << i << "[q + j*Q_1D];\n";
536*6ca0f394SUmesh Unnikrishnan           code << "      }\n";
537*6ca0f394SUmesh Unnikrishnan           break;
538*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_GRAD:
539*6ca0f394SUmesh Unnikrishnan           code << "      CeedScalar r_q_" << i << "[num_comp_in_" << i << "*DIM];\n";
540*6ca0f394SUmesh Unnikrishnan           code << "      gradCollo3d(num_comp_in_" << i << ", Q_1D, q, r_t_" << i << ", s_G_in_" << i << ", r_q_" << i << ", elem_scratch);\n";
541*6ca0f394SUmesh Unnikrishnan           break;
542*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_WEIGHT:
543*6ca0f394SUmesh Unnikrishnan           code << "      CeedScalar r_q_" << i << "[1];\n";
544*6ca0f394SUmesh Unnikrishnan           code << "      r_q_" << i << "[0] = r_t_" << i << "[q];\n";
545*6ca0f394SUmesh Unnikrishnan           break;  // No action
546*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_DIV:
547*6ca0f394SUmesh Unnikrishnan           break;  // TODO: Not implemented
548*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_CURL:
549*6ca0f394SUmesh Unnikrishnan           break;  // TODO: Not implemented
550*6ca0f394SUmesh Unnikrishnan       }
551*6ca0f394SUmesh Unnikrishnan     }
552*6ca0f394SUmesh Unnikrishnan     code << "\n      // -- Output fields --\n";
553*6ca0f394SUmesh Unnikrishnan     for (CeedInt i = 0; i < num_output_fields; i++) {
554*6ca0f394SUmesh Unnikrishnan       code << "      // ---- Output field " << i << " ----\n";
555*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
556*6ca0f394SUmesh Unnikrishnan       // Basis action
557*6ca0f394SUmesh Unnikrishnan       switch (eval_mode) {
558*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_NONE:
559*6ca0f394SUmesh Unnikrishnan           code << "      CeedScalar r_qq_" << i << "[num_comp_out_" << i << "];\n";
560*6ca0f394SUmesh Unnikrishnan           break;  // No action
561*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_INTERP:
562*6ca0f394SUmesh Unnikrishnan           code << "      CeedScalar r_qq_" << i << "[num_comp_out_" << i << "];\n";
563*6ca0f394SUmesh Unnikrishnan           break;
564*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_GRAD:
565*6ca0f394SUmesh Unnikrishnan           code << "      CeedScalar r_qq_" << i << "[num_comp_out_" << i << "*DIM];\n";
566*6ca0f394SUmesh Unnikrishnan           break;
567*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_WEIGHT:
568*6ca0f394SUmesh Unnikrishnan           break;  // Should not occur
569*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_DIV:
570*6ca0f394SUmesh Unnikrishnan           break;  // TODO: Not implemented
571*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_CURL:
572*6ca0f394SUmesh Unnikrishnan           break;  // TODO: Not implemented
573*6ca0f394SUmesh Unnikrishnan       }
574*6ca0f394SUmesh Unnikrishnan     }
575*6ca0f394SUmesh Unnikrishnan   } else {
576*6ca0f394SUmesh Unnikrishnan     code << "\n      // Note: Using full elements\n";
577*6ca0f394SUmesh Unnikrishnan     code << "      // -- Input fields --\n";
578*6ca0f394SUmesh Unnikrishnan     for (CeedInt i = 0; i < num_input_fields; i++) {
579*6ca0f394SUmesh Unnikrishnan       code << "      // ---- Input field " << i << " ----\n";
580*6ca0f394SUmesh Unnikrishnan       code << "      private CeedScalar* r_q_" << i << " = r_t_" << i << ";\n";
581*6ca0f394SUmesh Unnikrishnan     }
582*6ca0f394SUmesh Unnikrishnan     code << "      // -- Output fields --\n";
583*6ca0f394SUmesh Unnikrishnan     for (CeedInt i = 0; i < num_output_fields; i++) {
584*6ca0f394SUmesh Unnikrishnan       code << "      // ---- Output field " << i << " ----\n";
585*6ca0f394SUmesh Unnikrishnan       code << "      private CeedScalar* r_qq_" << i << " = r_tt_" << i << ";\n";
586*6ca0f394SUmesh Unnikrishnan     }
587*6ca0f394SUmesh Unnikrishnan   }
588*6ca0f394SUmesh Unnikrishnan   //--------------------------------------------------
589*6ca0f394SUmesh Unnikrishnan   code << "\n      // -- QFunction Inputs and outputs --\n";
590*6ca0f394SUmesh Unnikrishnan   code << "      const CeedScalar * in[" << num_input_fields << "];\n";
591*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_input_fields; i++) {
592*6ca0f394SUmesh Unnikrishnan     code << "      // ---- Input field " << i << " ----\n";
593*6ca0f394SUmesh Unnikrishnan     code << "      in[" << i << "] = r_q_" << i << ";\n";
594*6ca0f394SUmesh Unnikrishnan   }
595*6ca0f394SUmesh Unnikrishnan   code << "      CeedScalar * out[" << num_output_fields << "];\n";
596*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_output_fields; i++) {
597*6ca0f394SUmesh Unnikrishnan     code << "      // ---- Output field " << i << " ----\n";
598*6ca0f394SUmesh Unnikrishnan     code << "      out[" << i << "] = r_qq_" << i << ";\n";
599*6ca0f394SUmesh Unnikrishnan   }
600*6ca0f394SUmesh Unnikrishnan 
601*6ca0f394SUmesh Unnikrishnan   code << "\n      // -- Apply QFunction --\n";
602*6ca0f394SUmesh Unnikrishnan   code << "      " << q_function_name << "(ctx, ";
603*6ca0f394SUmesh Unnikrishnan   if (dim != 3 || use_collograd_parallelization) {
604*6ca0f394SUmesh Unnikrishnan     code << "1";
605*6ca0f394SUmesh Unnikrishnan   } else {
606*6ca0f394SUmesh Unnikrishnan     code << "Q_1D";
607*6ca0f394SUmesh Unnikrishnan   }
608*6ca0f394SUmesh Unnikrishnan   code << ", in, out);\n";
609*6ca0f394SUmesh Unnikrishnan   //--------------------------------------------------
610*6ca0f394SUmesh Unnikrishnan 
611*6ca0f394SUmesh Unnikrishnan   if (use_collograd_parallelization) {
612*6ca0f394SUmesh Unnikrishnan     code << "      // -- Output fields --\n";
613*6ca0f394SUmesh Unnikrishnan     for (CeedInt i = 0; i < num_output_fields; i++) {
614*6ca0f394SUmesh Unnikrishnan       code << "      // ---- Output field " << i << " ----\n";
615*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
616*6ca0f394SUmesh Unnikrishnan       // Basis action
617*6ca0f394SUmesh Unnikrishnan       code << "      // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
618*6ca0f394SUmesh Unnikrishnan       switch (eval_mode) {
619*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_NONE:
620*6ca0f394SUmesh Unnikrishnan           code << "      for (CeedInt j = 0; j < num_comp_out_" << i << " ; ++j) {\n";
621*6ca0f394SUmesh Unnikrishnan           code << "        r_tt_" << i << "[q + j*Q_1D] = r_qq_" << i << "[j];\n";
622*6ca0f394SUmesh Unnikrishnan           code << "      }\n";
623*6ca0f394SUmesh Unnikrishnan           break;  // No action
624*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_INTERP:
625*6ca0f394SUmesh Unnikrishnan           code << "      for (CeedInt j = 0; j < num_comp_out_" << i << " ; ++j) {\n";
626*6ca0f394SUmesh Unnikrishnan           code << "        r_tt_" << i << "[q + j*Q_1D] = r_qq_" << i << "[j];\n";
627*6ca0f394SUmesh Unnikrishnan           code << "      }\n";
628*6ca0f394SUmesh Unnikrishnan           break;
629*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_GRAD:
630*6ca0f394SUmesh Unnikrishnan           code << "      gradColloTranspose3d(num_comp_out_" << i << ",Q_1D, q, r_qq_" << i << ", s_G_out_" << i << ", r_tt_" << i
631*6ca0f394SUmesh Unnikrishnan                << ", elem_scratch);\n";
632*6ca0f394SUmesh Unnikrishnan           break;
633*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_WEIGHT:
634*6ca0f394SUmesh Unnikrishnan           break;  // Should not occur
635*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_DIV:
636*6ca0f394SUmesh Unnikrishnan           break;  // TODO: Not implemented
637*6ca0f394SUmesh Unnikrishnan         case CEED_EVAL_CURL:
638*6ca0f394SUmesh Unnikrishnan           break;  // TODO: Not implemented
639*6ca0f394SUmesh Unnikrishnan       }
640*6ca0f394SUmesh Unnikrishnan     }
641*6ca0f394SUmesh Unnikrishnan     code << "    }\n";
642*6ca0f394SUmesh Unnikrishnan   }
643*6ca0f394SUmesh Unnikrishnan 
644*6ca0f394SUmesh Unnikrishnan   // Output basis apply if needed
645*6ca0f394SUmesh Unnikrishnan   // Generate the correct eval mode code for each output
646*6ca0f394SUmesh Unnikrishnan   code << "\n    // -- Output field basis action and restrictions --\n";
647*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_output_fields; i++) {
648*6ca0f394SUmesh Unnikrishnan     code << "    // ---- Output field " << i << " ----\n";
649*6ca0f394SUmesh Unnikrishnan     // Get elem_size, eval_mode, num_comp
650*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &Erestrict));
651*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size));
652*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
653*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedElemRestrictionGetNumComponents(Erestrict, &num_comp));
654*6ca0f394SUmesh Unnikrishnan     // Basis action
655*6ca0f394SUmesh Unnikrishnan     code << "    // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
656*6ca0f394SUmesh Unnikrishnan     switch (eval_mode) {
657*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_NONE:
658*6ca0f394SUmesh Unnikrishnan         code << "    private CeedScalar* r_v_" << i << " = r_tt_" << i << ";\n";
659*6ca0f394SUmesh Unnikrishnan         break;  // No action
660*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_INTERP:
661*6ca0f394SUmesh Unnikrishnan         code << "    CeedScalar r_v_" << i << "[num_comp_out_" << i << "*P_out_" << i << "];\n";
662*6ca0f394SUmesh Unnikrishnan         code << "    InterpTranspose" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_out_" << i << ",P_out_" << i << ", Q_1D, r_tt_" << i
663*6ca0f394SUmesh Unnikrishnan              << ", s_B_out_" << i << ", r_v_" << i << ", elem_scratch);\n";
664*6ca0f394SUmesh Unnikrishnan         break;
665*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_GRAD:
666*6ca0f394SUmesh Unnikrishnan         code << "    CeedScalar r_v_" << i << "[num_comp_out_" << i << "*P_out_" << i << "];\n";
667*6ca0f394SUmesh Unnikrishnan         if (use_collograd_parallelization) {
668*6ca0f394SUmesh Unnikrishnan           code << "    InterpTranspose" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_out_" << i << ",P_out_" << i << ", Q_1D, r_tt_" << i
669*6ca0f394SUmesh Unnikrishnan                << ", s_B_out_" << i << ", r_v_" << i << ", elem_scratch);\n";
670*6ca0f394SUmesh Unnikrishnan         } else {
671*6ca0f394SUmesh Unnikrishnan           CeedInt P_1d;
672*6ca0f394SUmesh Unnikrishnan           CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
673*6ca0f394SUmesh Unnikrishnan           CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
674*6ca0f394SUmesh Unnikrishnan           code << "    GradTranspose" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d(num_comp_out_" << i
675*6ca0f394SUmesh Unnikrishnan                << ", P_out_" << i << ", Q_1D, r_tt_" << i << (dim > 1 ? ", s_B_out_" : "") << (dim > 1 ? std::to_string(i) : "") << ", s_G_out_" << i
676*6ca0f394SUmesh Unnikrishnan                << ", r_v_" << i << ", elem_scratch);\n";
677*6ca0f394SUmesh Unnikrishnan         }
678*6ca0f394SUmesh Unnikrishnan         break;
679*6ca0f394SUmesh Unnikrishnan       // LCOV_EXCL_START
680*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_WEIGHT: {
681*6ca0f394SUmesh Unnikrishnan         Ceed ceed;
682*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
683*6ca0f394SUmesh Unnikrishnan         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
684*6ca0f394SUmesh Unnikrishnan         break;  // Should not occur
685*6ca0f394SUmesh Unnikrishnan       }
686*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_DIV:
687*6ca0f394SUmesh Unnikrishnan         break;  // TODO: Not implemented
688*6ca0f394SUmesh Unnikrishnan       case CEED_EVAL_CURL:
689*6ca0f394SUmesh Unnikrishnan         break;  // TODO: Not implemented
690*6ca0f394SUmesh Unnikrishnan                 // LCOV_EXCL_STOP
691*6ca0f394SUmesh Unnikrishnan     }
692*6ca0f394SUmesh Unnikrishnan     // Restriction
693*6ca0f394SUmesh Unnikrishnan     bool is_strided;
694*6ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedElemRestrictionIsStrided(Erestrict, &is_strided));
695*6ca0f394SUmesh Unnikrishnan     if (!is_strided) {
696*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedElemRestrictionGetLVectorSize(Erestrict, &lsize));
697*6ca0f394SUmesh Unnikrishnan       code << "    const CeedInt lsize_out_" << i << " = " << lsize << ";\n";
698*6ca0f394SUmesh Unnikrishnan       CeedInt comp_stride;
699*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedElemRestrictionGetCompStride(Erestrict, &comp_stride));
700*6ca0f394SUmesh Unnikrishnan       code << "    // CompStride: " << comp_stride << "\n";
701*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedElemRestrictionGetData(Erestrict, &restr_impl));
702*6ca0f394SUmesh Unnikrishnan       h_indices.outputs[i] = restr_impl->d_ind;
703*6ca0f394SUmesh Unnikrishnan       code << "    writeDofsOffset" << dim << "d(num_comp_out_" << i << ", " << comp_stride << ", P_out_" << i << ", num_elem, indices->outputs[" << i
704*6ca0f394SUmesh Unnikrishnan            << "], r_v_" << i << ", d_v_" << i << ");\n";
705*6ca0f394SUmesh Unnikrishnan     } else {
706*6ca0f394SUmesh Unnikrishnan       bool has_backend_strides;
707*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedElemRestrictionHasBackendStrides(Erestrict, &has_backend_strides));
708*6ca0f394SUmesh Unnikrishnan       CeedInt num_elem;
709*6ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedElemRestrictionGetNumElements(Erestrict, &num_elem));
710*6ca0f394SUmesh Unnikrishnan       CeedInt strides[3] = {1, elem_size * num_elem, elem_size};
711*6ca0f394SUmesh Unnikrishnan       if (!has_backend_strides) {
712*6ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedElemRestrictionGetStrides(Erestrict, &strides));
713*6ca0f394SUmesh Unnikrishnan       }
714*6ca0f394SUmesh Unnikrishnan       code << "    // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n";
715*6ca0f394SUmesh Unnikrishnan       code << "    writeDofsStrided" << dim << "d(num_comp_out_" << i << ",P_out_" << i << "," << strides[0] << "," << strides[1] << "," << strides[2]
716*6ca0f394SUmesh Unnikrishnan            << ", num_elem, r_v_" << i << ", d_v_" << i << ");\n";
717*6ca0f394SUmesh Unnikrishnan     }
718*6ca0f394SUmesh Unnikrishnan   }
719*6ca0f394SUmesh Unnikrishnan 
720*6ca0f394SUmesh Unnikrishnan   code << "  }\n";
721*6ca0f394SUmesh Unnikrishnan   code << "}\n";
722*6ca0f394SUmesh Unnikrishnan   code << "// -----------------------------------------------------------------------------\n\n";
723*6ca0f394SUmesh Unnikrishnan 
724*6ca0f394SUmesh Unnikrishnan   // Copy the struct (containing device addresses) from the host to the device
725*6ca0f394SUmesh Unnikrishnan   sycl::event copy_B       = sycl_data->sycl_queue.copy<Fields_Sycl>(&h_B, impl->B, 1);
726*6ca0f394SUmesh Unnikrishnan   sycl::event copy_G       = sycl_data->sycl_queue.copy<Fields_Sycl>(&h_G, impl->G, 1);
727*6ca0f394SUmesh Unnikrishnan   sycl::event copy_indices = sycl_data->sycl_queue.copy<FieldsInt_Sycl>(&h_indices, impl->indices, 1);
728*6ca0f394SUmesh Unnikrishnan   // These copies can happen while the JIT is being done
729*6ca0f394SUmesh Unnikrishnan   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_B, copy_G, copy_indices}));
730*6ca0f394SUmesh Unnikrishnan 
731*6ca0f394SUmesh Unnikrishnan   // View kernel for debugging
732*6ca0f394SUmesh Unnikrishnan   CeedDebug256(ceed, 2, "Generated Operator Kernels:\n");
733*6ca0f394SUmesh Unnikrishnan   CeedDebug(ceed, code.str().c_str());
734*6ca0f394SUmesh Unnikrishnan 
735*6ca0f394SUmesh Unnikrishnan   std::map<std::string, CeedInt> jit_constants;
736*6ca0f394SUmesh Unnikrishnan   jit_constants["T_1D"]         = block_sizes[0];
737*6ca0f394SUmesh Unnikrishnan   jit_constants["GROUP_SIZE_X"] = block_sizes[0];
738*6ca0f394SUmesh Unnikrishnan   jit_constants["GROUP_SIZE_Y"] = block_sizes[1];
739*6ca0f394SUmesh Unnikrishnan   jit_constants["GROUP_SIZE_Z"] = block_sizes[2];
740*6ca0f394SUmesh Unnikrishnan 
741*6ca0f394SUmesh Unnikrishnan   // Compile kernel into a kernel bundle
742*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module, jit_constants));
743*6ca0f394SUmesh Unnikrishnan 
744*6ca0f394SUmesh Unnikrishnan   // Load kernel function
745*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, operator_name, &impl->op));
746*6ca0f394SUmesh Unnikrishnan 
747*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorSetSetupDone(op));
748*6ca0f394SUmesh Unnikrishnan   return CEED_ERROR_SUCCESS;
749*6ca0f394SUmesh Unnikrishnan }
750*6ca0f394SUmesh Unnikrishnan 
751*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
752