xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-gen/ceed-sycl-gen-operator-build.sycl.cpp (revision 22070f9510d4ff69aa6119a59aa3c46d57cc1cc7)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #define CEED_DEBUG_COLOR 12
9 
10 #include <ceed/backend.h>
11 #include <ceed/ceed.h>
12 #include <ceed/jit-source/sycl/sycl-types.h>
13 #include <ceed/jit-tools.h>
14 
15 #include <iostream>
16 #include <sstream>
17 #include <string>
18 #include <string_view>
19 #include <vector>
20 
21 #include "../sycl-ref/ceed-sycl-ref.hpp"
22 #include "../sycl-shared/ceed-sycl-shared.hpp"
23 #include "../sycl/ceed-sycl-compile.hpp"
24 
25 #include "ceed-sycl-gen.hpp"
26 
27 //------------------------------------------------------------------------------
28 // Calculate the block size used for launching the operator kernel
29 //------------------------------------------------------------------------------
30 extern "C" int BlockGridCalculate_Sycl_gen(const CeedInt dim, const CeedInt P_1d, const CeedInt Q_1d, CeedInt *block_sizes) {
31   const CeedInt thread1d = CeedIntMax(Q_1d, P_1d);
32 
33   if (dim == 1) {
34     CeedInt elems_per_block = 64 * thread1d > 256 ? 256 / thread1d : 64;
35 
36     elems_per_block = elems_per_block > 0 ? elems_per_block : 1;
37     block_sizes[0]  = thread1d;
38     block_sizes[1]  = 1;
39     block_sizes[2]  = elems_per_block;
40   } else if (dim == 2) {
41     const CeedInt elems_per_block = thread1d < 4 ? 16 : 2;
42 
43     block_sizes[0] = thread1d;
44     block_sizes[1] = thread1d;
45     block_sizes[2] = elems_per_block;
46   } else if (dim == 3) {
47     const CeedInt elems_per_block = thread1d < 6 ? 4 : (thread1d < 8 ? 2 : 1);
48 
49     block_sizes[0] = thread1d;
50     block_sizes[1] = thread1d;
51     block_sizes[2] = elems_per_block;
52   }
53   return CEED_ERROR_SUCCESS;
54 }
55 
56 //------------------------------------------------------------------------------
57 // Build single operator kernel
58 // - [ ] Check arguments to device functions reused from sycl-shared-basis are correct
59 // - [ ] Do kernel jitting!
60 //------------------------------------------------------------------------------
61 extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
62   Ceed                      ceed;
63   Ceed_Sycl                *sycl_data;
64   bool                      is_setup_done, is_identity_qf;
65   CeedSize                  l_size;
66   CeedInt                   Q, P_1d = 0, Q_1d = 0, elem_size, num_input_fields, num_output_fields, num_comp, dim = 1;
67   Fields_Sycl               h_B, h_G;
68   FieldsInt_Sycl            h_indices;
69   CeedEvalMode              eval_mode;
70   CeedElemRestriction       elem_rstr;
71   CeedElemRestriction_Sycl *rstr_impl;
72   CeedBasis                 basis;
73   CeedBasis_Sycl_shared    *basis_impl;
74   CeedQFunctionField       *qf_input_fields, *qf_output_fields;
75   CeedQFunction_Sycl_gen   *qf_impl;
76   CeedQFunction             qf;
77   CeedOperatorField        *op_input_fields, *op_output_fields;
78   CeedOperator_Sycl_gen    *impl;
79 
80   CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
81   if (is_setup_done) return CEED_ERROR_SUCCESS;
82 
83   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
84   CeedCallBackend(CeedGetData(ceed, &sycl_data));
85 
86   CeedCallBackend(CeedOperatorGetData(op, &impl));
87   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
88   CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl));
89   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
90   Q_1d = Q;
91 
92   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
93   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
94 
95   // Check for restriction only identity operator
96   CeedCallBackend(CeedQFunctionIsIdentity(qf, &is_identity_qf));
97   if (is_identity_qf) {
98     CeedEvalMode eval_mode_in, eval_mode_out;
99 
100     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[0], &eval_mode_in));
101     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[0], &eval_mode_out));
102     if (eval_mode_in == CEED_EVAL_NONE && eval_mode_out == CEED_EVAL_NONE) {
103       // LCOV_EXCL_START
104       return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement restriction only identity operators");
105       // LCOV_EXCL_STOP
106     }
107   }
108 
109   std::ostringstream code;
110   // TODO: generalize to accept different device functions?
111   {
112     char       *tensor_basis_code;
113     const char *tensor_basis_kernel_path;
114 
115     CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-shared-basis-tensor-templates.h", &tensor_basis_kernel_path));
116     CeedDebug256(ceed, 2, "----- Loading Tensor Basis Kernel Source -----\n");
117     CeedCallBackend(CeedLoadSourceToBuffer(ceed, tensor_basis_kernel_path, &tensor_basis_code));
118     code << tensor_basis_code;
119     CeedCallBackend(CeedFree(&tensor_basis_kernel_path));
120     CeedCallBackend(CeedFree(&tensor_basis_code));
121   }
122   {
123     char       *sycl_gen_template_source;
124     const char *sycl_gen_template_path;
125 
126     CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-gen-templates.h", &sycl_gen_template_path));
127     CeedDebug256(ceed, 2, "----- Loading Sycl-Gen Template Source -----\n");
128     CeedCallBackend(CeedLoadSourceToBuffer(ceed, sycl_gen_template_path, &sycl_gen_template_source));
129     code << sycl_gen_template_source;
130     CeedCallBackend(CeedFree(&sycl_gen_template_path));
131     CeedCallBackend(CeedFree(&sycl_gen_template_source));
132   }
133 
134   std::string_view  qfunction_source(qf_impl->qfunction_source);
135   std::string_view  qfunction_name(qf_impl->qfunction_name);
136   const std::string operator_name = "CeedKernelSyclGenOperator_" + std::string(qfunction_name);
137 
138   // Find dim, P_1d, Q_1d
139   impl->max_P_1d = 0;
140   for (CeedInt i = 0; i < num_input_fields; i++) {
141     CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
142     if (basis != CEED_BASIS_NONE) {
143       bool is_tensor;
144 
145       CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
146       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
147 
148       // Collect dim, P_1d, and Q_1d
149       CeedCallBackend(CeedBasisGetDimension(basis, &dim));
150       CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor));
151       if (is_tensor) {
152         CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
153         CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
154         if (P_1d > impl->max_P_1d) impl->max_P_1d = P_1d;
155       } else {
156         // LCOV_EXCL_START
157         return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement operators with non-tensor basis");
158         // LCOV_EXCL_STOP
159       }
160     }
161   }
162   // Check output bases for Q_1d, dim as well
163   //   The only input basis might be CEED_BASIS_NONE
164   for (CeedInt i = 0; i < num_output_fields; i++) {
165     CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
166 
167     if (basis != CEED_BASIS_NONE) {
168       bool is_tensor;
169 
170       CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
171       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
172 
173       // Collect Q_1d
174       CeedCallBackend(CeedBasisGetDimension(basis, &dim));
175       CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor));
176       if (is_tensor) {
177         CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
178       } else {
179         // LCOV_EXCL_START
180         return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement operators with non-tensor basis");
181         // LCOV_EXCL_STOP
182       }
183     }
184   }
185   impl->dim  = dim;
186   impl->Q_1d = Q_1d;
187 
188   // Only use 3D collocated gradient parallelization strategy when gradient is computed
189   // TODO: put in a function?
190   bool use_collograd_parallelization = false;
191 
192   if (dim == 3) {
193     bool was_grad_found = false;
194 
195     for (CeedInt i = 0; i < num_input_fields; i++) {
196       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
197       if (eval_mode == CEED_EVAL_GRAD) {
198         CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
199         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
200         use_collograd_parallelization = basis_impl->d_collo_grad_1d && (was_grad_found ? use_collograd_parallelization : true);
201         was_grad_found                = true;
202       }
203     }
204     for (CeedInt i = 0; i < num_output_fields; i++) {
205       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
206       if (eval_mode == CEED_EVAL_GRAD) {
207         CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
208         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
209         use_collograd_parallelization = basis_impl->d_collo_grad_1d && (was_grad_found ? use_collograd_parallelization : true);
210         was_grad_found                = true;
211       }
212     }
213   }
214 
215   CeedInt block_sizes[3];
216   CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes));
217 
218   // Define CEED_Q_VLA
219   code << "\n#undef CEED_Q_VLA\n";
220   if (dim != 3 || use_collograd_parallelization) {
221     code << "#define CEED_Q_VLA 1\n\n";
222   } else {
223     code << "#define CEED_Q_VLA " << Q_1d << "\n\n";
224   }
225 
226   // Determine subgroup size based on supported sizes : Default : 16 (if supported)
227   std::vector allowed_sg_sizes  = sycl_data->sycl_device.get_info<sycl::info::device::sub_group_sizes>();
228   CeedInt     sub_group_size_op = allowed_sg_sizes[allowed_sg_sizes.size() - 1];
229   for (const auto &s : allowed_sg_sizes) {
230     if (s == 16) {
231       sub_group_size_op = s;
232       break;
233     }
234   }
235 
236   code << qfunction_source;
237 
238   // Kernel function
239   code << "\n// -----------------------------------------------------------------------------\n";
240   code << "__attribute__((reqd_work_group_size(GROUP_SIZE_X, GROUP_SIZE_Y, GROUP_SIZE_Z), intel_reqd_sub_group_size(" << sub_group_size_op << ")))\n";
241   code << "kernel void " << operator_name << "(";
242   code << "const CeedInt num_elem, ";
243   code << "global void* ctx, ";
244   code << "global const FieldsInt_Sycl* indices, ";
245   code << "global Fields_Sycl* fields, ";
246   code << "global const Fields_Sycl* B, ";
247   code << "global const Fields_Sycl* G, ";
248   code << "global const CeedScalar * restrict W";
249   code << ") {\n";
250 
251   for (CeedInt i = 0; i < num_input_fields; i++) {
252     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
253     if (eval_mode != CEED_EVAL_WEIGHT) {  // Skip CEED_EVAL_WEIGHT
254       code << "  global const CeedScalar* d_u_" << i << " = fields->inputs[" << i << "];\n";
255     }
256   }
257 
258   for (CeedInt i = 0; i < num_output_fields; i++) {
259     code << "  global CeedScalar* d_v_" << i << " = fields->outputs[" << i << "];\n";
260   }
261 
262   // TODO: Convert these to defined constants to save on GRF
263   code << "  const CeedInt DIM = " << dim << ";\n";
264   code << "  const CeedInt Q_1D = " << Q_1d << ";\n";
265 
266   const CeedInt scratch_size = block_sizes[0] * block_sizes[1] * block_sizes[2];
267   code << "  local CeedScalar scratch[" << scratch_size << "];\n";
268   code << "  local CeedScalar * elem_scratch = scratch + get_local_id(2) * T_1D" << (dim > 1 ? "*T_1D" : "") << ";\n";
269 
270   code << "\n  // -- Input field constants and basis data --\n";
271   // Initialize constants, and matrices B and G
272   for (CeedInt i = 0; i < num_input_fields; i++) {
273     code << "  // ---- Input field " << i << " ----\n";
274     // Get elem_size, eval_mode, num_comp
275     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
276     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
277     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
278     CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
279 
280     // Set field constants
281     if (eval_mode != CEED_EVAL_WEIGHT) {
282       CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
283       if (basis != CEED_BASIS_NONE) {
284         CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
285         code << "  const CeedInt P_in_" << i << " = " << P_1d << ";\n";
286       } else {
287         code << "  const CeedInt P_in_" << i << " = " << Q_1d << ";\n";
288       }
289       code << "  const CeedInt num_comp_in_" << i << " = " << num_comp << ";\n";
290     }
291 
292     // Load basis data
293     code << "  // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
294     switch (eval_mode) {
295       case CEED_EVAL_NONE:
296         break;
297       case CEED_EVAL_INTERP:
298         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
299         h_B.inputs[i] = basis_impl->d_interp_1d;
300         code << "  local CeedScalar s_B_in_" << i << "[" << P_1d * Q_1d << "];\n";
301         code << "  loadMatrix(P_in_" << i << "*Q_1D, B->inputs[" << i << "], s_B_in_" << i << ");\n";
302         break;
303       case CEED_EVAL_GRAD:
304         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
305         h_B.inputs[i] = basis_impl->d_interp_1d;
306         code << "  local CeedScalar s_B_in_" << i << "[" << P_1d * Q_1d << "];\n";
307         code << "  loadMatrix(P_in_" << i << "*Q_1D, B->inputs[" << i << "], s_B_in_" << i << ");\n";
308         if (use_collograd_parallelization) {
309           h_G.inputs[i] = basis_impl->d_collo_grad_1d;
310           code << "  local CeedScalar s_G_in_" << i << "[" << Q_1d * Q_1d << "];\n";
311           code << "  loadMatrix(Q_1D*Q_1D, G->inputs[" << i << "], s_G_in_" << i << ");\n";
312         } else {
313           bool has_collo_grad = basis_impl->d_collo_grad_1d;
314           h_G.inputs[i]       = has_collo_grad ? basis_impl->d_collo_grad_1d : basis_impl->d_grad_1d;
315           code << "  local CeedScalar s_G_in_" << i << "[" << Q_1d * (has_collo_grad ? Q_1d : P_1d) << "];\n";
316           code << "  loadMatrix(" << (has_collo_grad ? "Q_1D" : ("P_in_" + std::to_string(i))) << "*Q_1D, G->inputs[" << i << "], s_G_in_" << i
317                << ");\n";
318         }
319         break;
320       case CEED_EVAL_WEIGHT:
321         break;  // No action
322       case CEED_EVAL_DIV:
323         break;  // TODO: Not implemented
324       case CEED_EVAL_CURL:
325         break;  // TODO: Not implemented
326     }
327   }
328 
329   code << "\n  // -- Output field constants and basis data --\n";
330   for (CeedInt i = 0; i < num_output_fields; i++) {
331     code << "  // ---- Output field " << i << " ----\n";
332     // Get elem_size, eval_mode, num_comp
333     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
334     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
335     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
336     CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
337 
338     // Set field constants
339     CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
340     if (basis != CEED_BASIS_NONE) {
341       CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
342       code << "  const CeedInt P_out_" << i << " = " << P_1d << ";\n";
343     } else {
344       code << "  const CeedInt P_out_" << i << " = " << Q_1d << ";\n";
345     }
346     code << "  const CeedInt num_comp_out_" << i << " = " << num_comp << ";\n";
347 
348     // Load basis data
349     code << "  // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
350     switch (eval_mode) {
351       case CEED_EVAL_NONE:
352         break;  // No action
353       case CEED_EVAL_INTERP:
354         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
355         h_B.outputs[i] = basis_impl->d_interp_1d;
356         code << "  local CeedScalar s_B_out_" << i << "[" << P_1d * Q_1d << "];\n";
357         code << "  loadMatrix(P_out_" << i << "*Q_1D, B->outputs[" << i << "], s_B_out_" << i << ");\n";
358         break;
359       case CEED_EVAL_GRAD:
360         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
361         h_B.outputs[i] = basis_impl->d_interp_1d;
362         code << "  local CeedScalar s_B_out_" << i << "[" << P_1d * Q_1d << "];\n";
363         code << "  loadMatrix(P_out_" << i << "*Q_1D, B->outputs[" << i << "], s_B_out_" << i << ");\n";
364         if (use_collograd_parallelization) {
365           h_G.outputs[i] = basis_impl->d_collo_grad_1d;
366           code << "  local CeedScalar s_G_out_" << i << "[" << Q_1d * Q_1d << "];\n";
367           code << "  loadMatrix(Q_1D*Q_1D, G->outputs[" << i << "], s_G_out_" << i << ");\n";
368         } else {
369           bool has_collo_grad = basis_impl->d_collo_grad_1d;
370           h_G.outputs[i]      = has_collo_grad ? basis_impl->d_collo_grad_1d : basis_impl->d_grad_1d;
371           code << "  local CeedScalar s_G_out_" << i << "[" << Q_1d * (has_collo_grad ? Q_1d : P_1d) << "];\n";
372           code << "  loadMatrix(" << (has_collo_grad ? "Q_1D" : ("P_out_" + std::to_string(i))) << "*Q_1D, G->outputs[" << i << "], s_G_out_" << i
373                << ");\n";
374         }
375         break;
376       // LCOV_EXCL_START
377       case CEED_EVAL_WEIGHT: {
378         Ceed ceed;
379         CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
380         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
381         break;  // Should not occur
382       }
383       case CEED_EVAL_DIV:
384         break;  // TODO: Not implemented
385       case CEED_EVAL_CURL:
386         break;  // TODO: Not implemented
387                 // LCOV_EXCL_STOP
388     }
389   }
390   code << "\n  // -- Element loop --\n";
391   code << "  work_group_barrier(CLK_LOCAL_MEM_FENCE);\n";
392   code << "  {\n";
393   // Input basis apply if needed
394   // Generate the correct eval mode code for each input
395   code << "    // -- Input field restrictions and basis actions --\n";
396   for (CeedInt i = 0; i < num_input_fields; i++) {
397     code << "    // ---- Input field " << i << " ----\n";
398     // Get elem_size, eval_mode, num_comp
399     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
400     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
401     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
402     CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
403 
404     // Restriction
405     if (eval_mode != CEED_EVAL_WEIGHT && !((eval_mode == CEED_EVAL_NONE) && use_collograd_parallelization)) {
406       bool is_strided;
407 
408       code << "    CeedScalar r_u_" << i << "[num_comp_in_" << i << "*P_in_" << i << "];\n";
409 
410       CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided));
411       if (!is_strided) {
412         CeedInt comp_stride;
413 
414         CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
415         code << "    const CeedInt l_size_in_" << i << " = " << l_size << ";\n";
416         CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
417         code << "    // CompStride: " << comp_stride << "\n";
418         CeedCallBackend(CeedElemRestrictionGetData(elem_rstr, &rstr_impl));
419         h_indices.inputs[i] = rstr_impl->d_ind;
420         code << "    readDofsOffset" << dim << "d(num_comp_in_" << i << ", " << comp_stride << ", P_in_" << i << ", num_elem, indices->inputs[" << i
421              << "], d_u_" << i << ", r_u_" << i << ");\n";
422       } else {
423         bool    has_backend_strides;
424         CeedInt num_elem;
425 
426         CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &has_backend_strides));
427         CeedCallBackend(CeedElemRestrictionGetNumElements(elem_rstr, &num_elem));
428         CeedInt strides[3] = {1, elem_size * num_elem, elem_size};
429 
430         if (!has_backend_strides) {
431           CeedCallBackend(CeedElemRestrictionGetStrides(elem_rstr, strides));
432         }
433         code << "    // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n";
434         code << "    readDofsStrided" << dim << "d(num_comp_in_" << i << ",P_in_" << i << "," << strides[0] << "," << strides[1] << "," << strides[2]
435              << ", num_elem, d_u_" << i << ", r_u_" << i << ");\n";
436       }
437     }
438 
439     // Basis action
440     code << "    // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
441     switch (eval_mode) {
442       case CEED_EVAL_NONE:
443         if (!use_collograd_parallelization) {
444           code << "    private CeedScalar* r_t_" << i << " = r_u_" << i << ";\n";
445         }
446         break;
447       case CEED_EVAL_INTERP:
448         code << "    CeedScalar r_t_" << i << "[num_comp_in_" << i << "*Q_1D];\n";
449         code << "    Interp" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_in_" << i << ", P_in_" << i << ", Q_1D, r_u_" << i << ", s_B_in_" << i
450              << ", r_t_" << i << ", elem_scratch);\n";
451         break;
452       case CEED_EVAL_GRAD:
453         if (use_collograd_parallelization) {
454           code << "    CeedScalar r_t_" << i << "[num_comp_in_" << i << "*Q_1D];\n";
455           code << "    Interp" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_in_" << i << ", P_in_" << i << ", Q_1D, r_u_" << i << ", s_B_in_"
456                << i << ", r_t_" << i << ", elem_scratch);\n";
457         } else {
458           CeedInt P_1d;
459           CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
460           CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
461           code << "    CeedScalar r_t_" << i << "[num_comp_in_" << i << "*DIM*Q_1D];\n";
462           code << "    Grad" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d(num_comp_in_" << i
463                << ", P_in_" << i << ", Q_1D, r_u_" << i << (dim > 1 ? ", s_B_in_" : "") << (dim > 1 ? std::to_string(i) : "") << ", s_G_in_" << i
464                << ", r_t_" << i << ", elem_scratch);\n";
465         }
466         break;
467       case CEED_EVAL_WEIGHT:
468         code << "    CeedScalar r_t_" << i << "[Q_1D];\n";
469         CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
470         CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
471         impl->W = basis_impl->d_q_weight_1d;
472         code << "    Weight" << (dim > 1 ? "Tensor" : "") << dim << "d(Q_1D, W, r_t_" << i << ");\n";
473         break;  // No action
474       case CEED_EVAL_DIV:
475         break;  // TODO: Not implemented
476       case CEED_EVAL_CURL:
477         break;  // TODO: Not implemented
478     }
479   }
480 
481   // Q function
482   code << "\n    // -- Output field setup --\n";
483   for (CeedInt i = 0; i < num_output_fields; i++) {
484     code << "\n    // ---- Output field " << i << " ----\n";
485     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
486     if (eval_mode == CEED_EVAL_GRAD) {
487       if (use_collograd_parallelization) {
488         // Accumulator for gradient slices
489         code << "    CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*Q_1D];\n";
490         code << "    for (CeedInt i = 0; i < num_comp_out_" << i << "; i++) {\n";
491         code << "      for (CeedInt j = 0; j < Q_1D; ++j) {\n";
492         code << "        r_tt_" << i << "[j + i*Q_1D] = 0.0;\n";
493         code << "      }\n";
494         code << "    }\n";
495       } else {
496         code << "    CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*DIM*Q_1D];\n";
497       }
498     }
499     if (eval_mode == CEED_EVAL_NONE || eval_mode == CEED_EVAL_INTERP) {
500       code << "    CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*Q_1D];\n";
501     }
502   }
503   // We treat quadrature points per slice in 3d to save registers
504   if (use_collograd_parallelization) {
505     code << "\n    // Note: Using planes of 3D elements\n";
506     code << "    for (CeedInt q = 0; q < Q_1D; q++) {\n";
507     code << "      // -- Input fields --\n";
508     for (CeedInt i = 0; i < num_input_fields; i++) {
509       code << "      // ---- Input field " << i << " ----\n";
510       // Get elem_size, eval_mode, num_comp
511       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
512       // Basis action
513       code << "      // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
514       switch (eval_mode) {
515         case CEED_EVAL_NONE:
516           bool is_strided;
517 
518           code << "      CeedScalar r_q_" << i << "[num_comp_in_" << i << "];\n";
519 
520           CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
521           CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided));
522           if (!is_strided) {
523             CeedInt comp_stride;
524 
525             CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
526             code << "      const CeedInt l_size_in_" << i << " = " << l_size << ";\n";
527             CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
528             code << "      // CompStride: " << comp_stride << "\n";
529             CeedCallBackend(CeedElemRestrictionGetData(elem_rstr, &rstr_impl));
530             h_indices.inputs[i] = rstr_impl->d_ind;
531             code << "      readSliceQuadsOffset"
532                  << "3d(num_comp_in_" << i << ", " << comp_stride << ", Q_1D, l_size_in_" << i << ", num_elem, q, indices->inputs[" << i << "], d_u_"
533                  << i << ", r_q_" << i << ");\n";
534           } else {
535             bool    has_backend_strides;
536             CeedInt num_elem;
537 
538             CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
539             CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &has_backend_strides));
540             CeedCallBackend(CeedElemRestrictionGetNumElements(elem_rstr, &num_elem));
541             CeedInt strides[3] = {1, elem_size * num_elem, elem_size};
542 
543             if (!has_backend_strides) {
544               CeedCallBackend(CeedElemRestrictionGetStrides(elem_rstr, strides));
545             }
546             code << "      // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n";
547             code << "      readSliceQuadsStrided"
548                  << "3d(num_comp_in_" << i << ", Q_1D," << strides[0] << ", " << strides[1] << ", " << strides[2] << ", num_elem, q, d_u_" << i
549                  << ", r_q_" << i << ");\n";
550           }
551           break;
552         case CEED_EVAL_INTERP:
553           code << "      CeedScalar r_q_" << i << "[num_comp_in_" << i << "];\n";
554           code << "      for (CeedInt j = 0; j < num_comp_in_" << i << " ; ++j) {\n";
555           code << "        r_q_" << i << "[j] = r_t_" << i << "[q + j*Q_1D];\n";
556           code << "      }\n";
557           break;
558         case CEED_EVAL_GRAD:
559           code << "      CeedScalar r_q_" << i << "[num_comp_in_" << i << "*DIM];\n";
560           code << "      gradCollo3d(num_comp_in_" << i << ", Q_1D, q, r_t_" << i << ", s_G_in_" << i << ", r_q_" << i << ", elem_scratch);\n";
561           break;
562         case CEED_EVAL_WEIGHT:
563           code << "      CeedScalar r_q_" << i << "[1];\n";
564           code << "      r_q_" << i << "[0] = r_t_" << i << "[q];\n";
565           break;  // No action
566         case CEED_EVAL_DIV:
567           break;  // TODO: Not implemented
568         case CEED_EVAL_CURL:
569           break;  // TODO: Not implemented
570       }
571     }
572     code << "\n      // -- Output fields --\n";
573     for (CeedInt i = 0; i < num_output_fields; i++) {
574       code << "      // ---- Output field " << i << " ----\n";
575       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
576       // Basis action
577       switch (eval_mode) {
578         case CEED_EVAL_NONE:
579           code << "      CeedScalar r_qq_" << i << "[num_comp_out_" << i << "];\n";
580           break;  // No action
581         case CEED_EVAL_INTERP:
582           code << "      CeedScalar r_qq_" << i << "[num_comp_out_" << i << "];\n";
583           break;
584         case CEED_EVAL_GRAD:
585           code << "      CeedScalar r_qq_" << i << "[num_comp_out_" << i << "*DIM];\n";
586           break;
587         case CEED_EVAL_WEIGHT:
588           break;  // Should not occur
589         case CEED_EVAL_DIV:
590           break;  // TODO: Not implemented
591         case CEED_EVAL_CURL:
592           break;  // TODO: Not implemented
593       }
594     }
595   } else {
596     code << "\n      // Note: Using full elements\n";
597     code << "      // -- Input fields --\n";
598     for (CeedInt i = 0; i < num_input_fields; i++) {
599       code << "      // ---- Input field " << i << " ----\n";
600       code << "      private CeedScalar* r_q_" << i << " = r_t_" << i << ";\n";
601     }
602     code << "      // -- Output fields --\n";
603     for (CeedInt i = 0; i < num_output_fields; i++) {
604       code << "      // ---- Output field " << i << " ----\n";
605       code << "      private CeedScalar* r_qq_" << i << " = r_tt_" << i << ";\n";
606     }
607   }
608   //--------------------------------------------------
609   code << "\n      // -- QFunction Inputs and outputs --\n";
610   code << "      const CeedScalar * in[" << num_input_fields << "];\n";
611   for (CeedInt i = 0; i < num_input_fields; i++) {
612     code << "      // ---- Input field " << i << " ----\n";
613     code << "      in[" << i << "] = r_q_" << i << ";\n";
614   }
615   code << "      CeedScalar * out[" << num_output_fields << "];\n";
616   for (CeedInt i = 0; i < num_output_fields; i++) {
617     code << "      // ---- Output field " << i << " ----\n";
618     code << "      out[" << i << "] = r_qq_" << i << ";\n";
619   }
620 
621   code << "\n      // -- Apply QFunction --\n";
622   code << "      " << qfunction_name << "(ctx, ";
623   if (dim != 3 || use_collograd_parallelization) {
624     code << "1";
625   } else {
626     code << "Q_1D";
627   }
628   code << ", in, out);\n";
629   //--------------------------------------------------
630 
631   if (use_collograd_parallelization) {
632     code << "      // -- Output fields --\n";
633     for (CeedInt i = 0; i < num_output_fields; i++) {
634       code << "      // ---- Output field " << i << " ----\n";
635       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
636       // Basis action
637       code << "      // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
638       switch (eval_mode) {
639         case CEED_EVAL_NONE:
640           code << "      for (CeedInt j = 0; j < num_comp_out_" << i << " ; ++j) {\n";
641           code << "        r_tt_" << i << "[q + j*Q_1D] = r_qq_" << i << "[j];\n";
642           code << "      }\n";
643           break;  // No action
644         case CEED_EVAL_INTERP:
645           code << "      for (CeedInt j = 0; j < num_comp_out_" << i << " ; ++j) {\n";
646           code << "        r_tt_" << i << "[q + j*Q_1D] = r_qq_" << i << "[j];\n";
647           code << "      }\n";
648           break;
649         case CEED_EVAL_GRAD:
650           code << "      gradColloTranspose3d(num_comp_out_" << i << ",Q_1D, q, r_qq_" << i << ", s_G_out_" << i << ", r_tt_" << i
651                << ", elem_scratch);\n";
652           break;
653         case CEED_EVAL_WEIGHT:
654           break;  // Should not occur
655         case CEED_EVAL_DIV:
656           break;  // TODO: Not implemented
657         case CEED_EVAL_CURL:
658           break;  // TODO: Not implemented
659       }
660     }
661     code << "    }\n";
662   }
663 
664   // Output basis apply if needed
665   // Generate the correct eval mode code for each output
666   code << "\n    // -- Output field basis action and restrictions --\n";
667   for (CeedInt i = 0; i < num_output_fields; i++) {
668     code << "    // ---- Output field " << i << " ----\n";
669     // Get elem_size, eval_mode, num_comp
670     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
671     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
672     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
673     CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
674     // Basis action
675     code << "    // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
676     switch (eval_mode) {
677       case CEED_EVAL_NONE:
678         code << "    private CeedScalar* r_v_" << i << " = r_tt_" << i << ";\n";
679         break;  // No action
680       case CEED_EVAL_INTERP:
681         code << "    CeedScalar r_v_" << i << "[num_comp_out_" << i << "*P_out_" << i << "];\n";
682         code << "    InterpTranspose" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_out_" << i << ",P_out_" << i << ", Q_1D, r_tt_" << i
683              << ", s_B_out_" << i << ", r_v_" << i << ", elem_scratch);\n";
684         break;
685       case CEED_EVAL_GRAD:
686         code << "    CeedScalar r_v_" << i << "[num_comp_out_" << i << "*P_out_" << i << "];\n";
687         if (use_collograd_parallelization) {
688           code << "    InterpTranspose" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_out_" << i << ",P_out_" << i << ", Q_1D, r_tt_" << i
689                << ", s_B_out_" << i << ", r_v_" << i << ", elem_scratch);\n";
690         } else {
691           CeedInt P_1d;
692           CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
693           CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
694           code << "    GradTranspose" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d(num_comp_out_" << i
695                << ", P_out_" << i << ", Q_1D, r_tt_" << i << (dim > 1 ? ", s_B_out_" : "") << (dim > 1 ? std::to_string(i) : "") << ", s_G_out_" << i
696                << ", r_v_" << i << ", elem_scratch);\n";
697         }
698         break;
699       // LCOV_EXCL_START
700       case CEED_EVAL_WEIGHT: {
701         Ceed ceed;
702         CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
703         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
704         break;  // Should not occur
705       }
706       case CEED_EVAL_DIV:
707         break;  // TODO: Not implemented
708       case CEED_EVAL_CURL:
709         break;  // TODO: Not implemented
710                 // LCOV_EXCL_STOP
711     }
712     // Restriction
713     bool is_strided;
714 
715     CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided));
716     if (!is_strided) {
717       CeedInt comp_stride;
718 
719       CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
720       code << "    const CeedInt l_size_out_" << i << " = " << l_size << ";\n";
721       CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
722       code << "    // CompStride: " << comp_stride << "\n";
723       CeedCallBackend(CeedElemRestrictionGetData(elem_rstr, &rstr_impl));
724       h_indices.outputs[i] = rstr_impl->d_ind;
725       code << "    writeDofsOffset" << dim << "d(num_comp_out_" << i << ", " << comp_stride << ", P_out_" << i << ", num_elem, indices->outputs[" << i
726            << "], r_v_" << i << ", d_v_" << i << ");\n";
727     } else {
728       bool    has_backend_strides;
729       CeedInt num_elem;
730 
731       CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &has_backend_strides));
732       CeedCallBackend(CeedElemRestrictionGetNumElements(elem_rstr, &num_elem));
733       CeedInt strides[3] = {1, elem_size * num_elem, elem_size};
734 
735       if (!has_backend_strides) {
736         CeedCallBackend(CeedElemRestrictionGetStrides(elem_rstr, strides));
737       }
738       code << "    // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n";
739       code << "    writeDofsStrided" << dim << "d(num_comp_out_" << i << ",P_out_" << i << "," << strides[0] << "," << strides[1] << "," << strides[2]
740            << ", num_elem, r_v_" << i << ", d_v_" << i << ");\n";
741     }
742   }
743 
744   code << "  }\n";
745   code << "}\n";
746   code << "// -----------------------------------------------------------------------------\n\n";
747 
748   // Copy the struct (containing device addresses) from the host to the device
749   sycl::event copy_B       = sycl_data->sycl_queue.copy<Fields_Sycl>(&h_B, impl->B, 1);
750   sycl::event copy_G       = sycl_data->sycl_queue.copy<Fields_Sycl>(&h_G, impl->G, 1);
751   sycl::event copy_indices = sycl_data->sycl_queue.copy<FieldsInt_Sycl>(&h_indices, impl->indices, 1);
752   // NOTE: the wait below blocks until these copies finish, before the JIT build starts, so they do not overlap with compilation
753   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_B, copy_G, copy_indices}));
754 
755   // View kernel for debugging
756   CeedDebug256(ceed, 2, "Generated Operator Kernels:\n");
757   CeedDebug(ceed, code.str().c_str());
758 
759   std::map<std::string, CeedInt> jit_constants;
760   jit_constants["T_1D"]         = block_sizes[0];
761   jit_constants["GROUP_SIZE_X"] = block_sizes[0];
762   jit_constants["GROUP_SIZE_Y"] = block_sizes[1];
763   jit_constants["GROUP_SIZE_Z"] = block_sizes[2];
764 
765   // Compile kernel into a kernel bundle
766   CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module, jit_constants));
767 
768   // Load kernel function
769   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, operator_name, &impl->op));
770 
771   CeedCallBackend(CeedOperatorSetSetupDone(op));
772   return CEED_ERROR_SUCCESS;
773 }
774 
775 //------------------------------------------------------------------------------
776