1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
26ca0f394SUmesh Unnikrishnan // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
36ca0f394SUmesh Unnikrishnan //
46ca0f394SUmesh Unnikrishnan // SPDX-License-Identifier: BSD-2-Clause
56ca0f394SUmesh Unnikrishnan //
66ca0f394SUmesh Unnikrishnan // This file is part of CEED: http://github.com/ceed
76ca0f394SUmesh Unnikrishnan
86ca0f394SUmesh Unnikrishnan #define CEED_DEBUG_COLOR 12
96ca0f394SUmesh Unnikrishnan
106ca0f394SUmesh Unnikrishnan #include <ceed/backend.h>
116ca0f394SUmesh Unnikrishnan #include <ceed/ceed.h>
126ca0f394SUmesh Unnikrishnan #include <ceed/jit-source/sycl/sycl-types.h>
136ca0f394SUmesh Unnikrishnan #include <ceed/jit-tools.h>
146ca0f394SUmesh Unnikrishnan
156ca0f394SUmesh Unnikrishnan #include <iostream>
166ca0f394SUmesh Unnikrishnan #include <sstream>
176ca0f394SUmesh Unnikrishnan #include <string>
186ca0f394SUmesh Unnikrishnan #include <string_view>
196ca0f394SUmesh Unnikrishnan #include <vector>
206ca0f394SUmesh Unnikrishnan
216ca0f394SUmesh Unnikrishnan #include "../sycl-ref/ceed-sycl-ref.hpp"
226ca0f394SUmesh Unnikrishnan #include "../sycl-shared/ceed-sycl-shared.hpp"
236ca0f394SUmesh Unnikrishnan #include "../sycl/ceed-sycl-compile.hpp"
246ca0f394SUmesh Unnikrishnan
256ca0f394SUmesh Unnikrishnan #include "ceed-sycl-gen.hpp"
266ca0f394SUmesh Unnikrishnan
276ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
286ca0f394SUmesh Unnikrishnan // Calculate the block size used for launching the operator kernel
296ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
BlockGridCalculate_Sycl_gen(const CeedInt dim,const CeedInt P_1d,const CeedInt Q_1d,CeedInt * block_sizes)306ca0f394SUmesh Unnikrishnan extern "C" int BlockGridCalculate_Sycl_gen(const CeedInt dim, const CeedInt P_1d, const CeedInt Q_1d, CeedInt *block_sizes) {
316ca0f394SUmesh Unnikrishnan const CeedInt thread1d = CeedIntMax(Q_1d, P_1d);
32dd64fc84SJeremy L Thompson
336ca0f394SUmesh Unnikrishnan if (dim == 1) {
346ca0f394SUmesh Unnikrishnan CeedInt elems_per_block = 64 * thread1d > 256 ? 256 / thread1d : 64;
35dd64fc84SJeremy L Thompson
366ca0f394SUmesh Unnikrishnan elems_per_block = elems_per_block > 0 ? elems_per_block : 1;
376ca0f394SUmesh Unnikrishnan block_sizes[0] = thread1d;
386ca0f394SUmesh Unnikrishnan block_sizes[1] = 1;
396ca0f394SUmesh Unnikrishnan block_sizes[2] = elems_per_block;
406ca0f394SUmesh Unnikrishnan } else if (dim == 2) {
416ca0f394SUmesh Unnikrishnan const CeedInt elems_per_block = thread1d < 4 ? 16 : 2;
42dd64fc84SJeremy L Thompson
436ca0f394SUmesh Unnikrishnan block_sizes[0] = thread1d;
446ca0f394SUmesh Unnikrishnan block_sizes[1] = thread1d;
456ca0f394SUmesh Unnikrishnan block_sizes[2] = elems_per_block;
466ca0f394SUmesh Unnikrishnan } else if (dim == 3) {
476ca0f394SUmesh Unnikrishnan const CeedInt elems_per_block = thread1d < 6 ? 4 : (thread1d < 8 ? 2 : 1);
48dd64fc84SJeremy L Thompson
496ca0f394SUmesh Unnikrishnan block_sizes[0] = thread1d;
506ca0f394SUmesh Unnikrishnan block_sizes[1] = thread1d;
516ca0f394SUmesh Unnikrishnan block_sizes[2] = elems_per_block;
526ca0f394SUmesh Unnikrishnan }
536ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS;
546ca0f394SUmesh Unnikrishnan }
556ca0f394SUmesh Unnikrishnan
//------------------------------------------------------------------------------
// Build single operator kernel
// - [ ] Check arguments to device functions reused from sycl-shared-basis are correct
// - [ ] Do kernel jitting!
//------------------------------------------------------------------------------
CeedOperatorBuildKernel_Sycl_gen(CeedOperator op)616ca0f394SUmesh Unnikrishnan extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
62dd64fc84SJeremy L Thompson Ceed ceed;
63dd64fc84SJeremy L Thompson Ceed_Sycl *sycl_data;
64dd64fc84SJeremy L Thompson bool is_setup_done, is_identity_qf;
65dd64fc84SJeremy L Thompson CeedSize l_size;
66dd64fc84SJeremy L Thompson CeedInt Q, P_1d = 0, Q_1d = 0, elem_size, num_input_fields, num_output_fields, num_comp, dim = 1;
67dd64fc84SJeremy L Thompson Fields_Sycl h_B, h_G;
68dd64fc84SJeremy L Thompson FieldsInt_Sycl h_indices;
69dd64fc84SJeremy L Thompson CeedEvalMode eval_mode;
70dd64fc84SJeremy L Thompson CeedElemRestriction elem_rstr;
71edb2538eSJeremy L Thompson CeedElemRestriction_Sycl *rstr_impl;
72dd64fc84SJeremy L Thompson CeedBasis basis;
73dd64fc84SJeremy L Thompson CeedBasis_Sycl_shared *basis_impl;
74dd64fc84SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields;
75dd64fc84SJeremy L Thompson CeedQFunction_Sycl_gen *qf_impl;
76dd64fc84SJeremy L Thompson CeedQFunction qf;
77dd64fc84SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields;
78dd64fc84SJeremy L Thompson CeedOperator_Sycl_gen *impl;
79dd64fc84SJeremy L Thompson
806ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
816ca0f394SUmesh Unnikrishnan if (is_setup_done) return CEED_ERROR_SUCCESS;
826ca0f394SUmesh Unnikrishnan
836ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
846ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetData(ceed, &sycl_data));
856ca0f394SUmesh Unnikrishnan
866ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetData(op, &impl));
876ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
886ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl));
896ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
906ca0f394SUmesh Unnikrishnan Q_1d = Q;
916ca0f394SUmesh Unnikrishnan
926ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
936ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
946ca0f394SUmesh Unnikrishnan
956ca0f394SUmesh Unnikrishnan // Check for restriction only identity operator
966ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionIsIdentity(qf, &is_identity_qf));
976ca0f394SUmesh Unnikrishnan if (is_identity_qf) {
986ca0f394SUmesh Unnikrishnan CeedEvalMode eval_mode_in, eval_mode_out;
99dd64fc84SJeremy L Thompson
1006ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[0], &eval_mode_in));
1016ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[0], &eval_mode_out));
1024e3038a5SJeremy L Thompson CeedCheck(eval_mode_in != CEED_EVAL_NONE || eval_mode_out != CEED_EVAL_NONE, ceed, CEED_ERROR_BACKEND,
1034e3038a5SJeremy L Thompson "Backend does not implement restriction only identity operators");
1046ca0f394SUmesh Unnikrishnan }
1056ca0f394SUmesh Unnikrishnan
1066ca0f394SUmesh Unnikrishnan std::ostringstream code;
1076ca0f394SUmesh Unnikrishnan // TODO: generalize to accept different device functions?
1086ca0f394SUmesh Unnikrishnan {
10922070f95SJeremy L Thompson char *tensor_basis_code;
11022070f95SJeremy L Thompson const char *tensor_basis_kernel_path;
111dd64fc84SJeremy L Thompson
1126ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-shared-basis-tensor-templates.h", &tensor_basis_kernel_path));
1136ca0f394SUmesh Unnikrishnan CeedDebug256(ceed, 2, "----- Loading Tensor Basis Kernel Source -----\n");
1146ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedLoadSourceToBuffer(ceed, tensor_basis_kernel_path, &tensor_basis_code));
1156ca0f394SUmesh Unnikrishnan code << tensor_basis_code;
1166ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&tensor_basis_kernel_path));
1176ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&tensor_basis_code));
1186ca0f394SUmesh Unnikrishnan }
1196ca0f394SUmesh Unnikrishnan {
12022070f95SJeremy L Thompson char *sycl_gen_template_source;
12122070f95SJeremy L Thompson const char *sycl_gen_template_path;
122dd64fc84SJeremy L Thompson
1236ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-gen-templates.h", &sycl_gen_template_path));
1246ca0f394SUmesh Unnikrishnan CeedDebug256(ceed, 2, "----- Loading Sycl-Gen Template Source -----\n");
1256ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedLoadSourceToBuffer(ceed, sycl_gen_template_path, &sycl_gen_template_source));
1266ca0f394SUmesh Unnikrishnan code << sycl_gen_template_source;
1276ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&sycl_gen_template_path));
1286ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&sycl_gen_template_source));
1296ca0f394SUmesh Unnikrishnan }
1306ca0f394SUmesh Unnikrishnan
13109095acaSJeremy L Thompson std::string_view qfunction_source(qf_impl->qfunction_source);
13209095acaSJeremy L Thompson std::string_view qfunction_name(qf_impl->qfunction_name);
13309095acaSJeremy L Thompson const std::string operator_name = "CeedKernelSyclGenOperator_" + std::string(qfunction_name);
1346ca0f394SUmesh Unnikrishnan
1356ca0f394SUmesh Unnikrishnan // Find dim, P_1d, Q_1d
1366ca0f394SUmesh Unnikrishnan impl->max_P_1d = 0;
1376ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
1386ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
139356036faSJeremy L Thompson if (basis != CEED_BASIS_NONE) {
140dd64fc84SJeremy L Thompson bool is_tensor;
141dd64fc84SJeremy L Thompson
1426ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
1436ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
1446ca0f394SUmesh Unnikrishnan
1456ca0f394SUmesh Unnikrishnan // Collect dim, P_1d, and Q_1d
1466ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetDimension(basis, &dim));
147dd64fc84SJeremy L Thompson CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor));
148dd64fc84SJeremy L Thompson if (is_tensor) {
1496ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
1506ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
1516ca0f394SUmesh Unnikrishnan if (P_1d > impl->max_P_1d) impl->max_P_1d = P_1d;
1526ca0f394SUmesh Unnikrishnan } else {
1536ca0f394SUmesh Unnikrishnan // LCOV_EXCL_START
1546ca0f394SUmesh Unnikrishnan return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement operators with non-tensor basis");
1556ca0f394SUmesh Unnikrishnan // LCOV_EXCL_STOP
1566ca0f394SUmesh Unnikrishnan }
1576ca0f394SUmesh Unnikrishnan }
158681d0ea7SJeremy L Thompson CeedCallBackend(CeedBasisDestroy(&basis));
1596ca0f394SUmesh Unnikrishnan }
1606ca0f394SUmesh Unnikrishnan // Check output bases for Q_1d, dim as well
161356036faSJeremy L Thompson // The only input basis might be CEED_BASIS_NONE
1626ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
1636ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
164356036faSJeremy L Thompson if (basis != CEED_BASIS_NONE) {
165dd64fc84SJeremy L Thompson bool is_tensor;
166dd64fc84SJeremy L Thompson
1676ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
1686ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
1696ca0f394SUmesh Unnikrishnan
1706ca0f394SUmesh Unnikrishnan // Collect Q_1d
1716ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetDimension(basis, &dim));
172dd64fc84SJeremy L Thompson CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor));
173dd64fc84SJeremy L Thompson if (is_tensor) {
1746ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
1756ca0f394SUmesh Unnikrishnan } else {
1766ca0f394SUmesh Unnikrishnan // LCOV_EXCL_START
1776ca0f394SUmesh Unnikrishnan return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement operators with non-tensor basis");
1786ca0f394SUmesh Unnikrishnan // LCOV_EXCL_STOP
1796ca0f394SUmesh Unnikrishnan }
1806ca0f394SUmesh Unnikrishnan }
181681d0ea7SJeremy L Thompson CeedCallBackend(CeedBasisDestroy(&basis));
1826ca0f394SUmesh Unnikrishnan }
1836ca0f394SUmesh Unnikrishnan impl->dim = dim;
1846ca0f394SUmesh Unnikrishnan impl->Q_1d = Q_1d;
1856ca0f394SUmesh Unnikrishnan
1866ca0f394SUmesh Unnikrishnan // Only use 3D collocated gradient parallelization strategy when gradient is computed
1876ca0f394SUmesh Unnikrishnan // TODO: put in a function?
1886ca0f394SUmesh Unnikrishnan bool use_collograd_parallelization = false;
189dd64fc84SJeremy L Thompson
1906ca0f394SUmesh Unnikrishnan if (dim == 3) {
1916ca0f394SUmesh Unnikrishnan bool was_grad_found = false;
192dd64fc84SJeremy L Thompson
1936ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
1946ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
1956ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_GRAD) {
1966ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
1976ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
1981c66c397SJeremy L Thompson use_collograd_parallelization = basis_impl->d_collo_grad_1d && (was_grad_found ? use_collograd_parallelization : true);
1996ca0f394SUmesh Unnikrishnan was_grad_found = true;
200681d0ea7SJeremy L Thompson CeedCallBackend(CeedBasisDestroy(&basis));
2016ca0f394SUmesh Unnikrishnan }
2026ca0f394SUmesh Unnikrishnan }
2036ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
2046ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
2056ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_GRAD) {
2066ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
2076ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
2081c66c397SJeremy L Thompson use_collograd_parallelization = basis_impl->d_collo_grad_1d && (was_grad_found ? use_collograd_parallelization : true);
2096ca0f394SUmesh Unnikrishnan was_grad_found = true;
210681d0ea7SJeremy L Thompson CeedCallBackend(CeedBasisDestroy(&basis));
2116ca0f394SUmesh Unnikrishnan }
2126ca0f394SUmesh Unnikrishnan }
2136ca0f394SUmesh Unnikrishnan }
2146ca0f394SUmesh Unnikrishnan
2156ca0f394SUmesh Unnikrishnan CeedInt block_sizes[3];
2166ca0f394SUmesh Unnikrishnan CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes));
2176ca0f394SUmesh Unnikrishnan
2186ca0f394SUmesh Unnikrishnan // Define CEED_Q_VLA
2196ca0f394SUmesh Unnikrishnan code << "\n#undef CEED_Q_VLA\n";
2206ca0f394SUmesh Unnikrishnan if (dim != 3 || use_collograd_parallelization) {
2216ca0f394SUmesh Unnikrishnan code << "#define CEED_Q_VLA 1\n\n";
2226ca0f394SUmesh Unnikrishnan } else {
2236ca0f394SUmesh Unnikrishnan code << "#define CEED_Q_VLA " << Q_1d << "\n\n";
2246ca0f394SUmesh Unnikrishnan }
2256ca0f394SUmesh Unnikrishnan
2266ca0f394SUmesh Unnikrishnan // Determine subgroup size based on supported sizes : Default : 16 (if supported)
2276ca0f394SUmesh Unnikrishnan std::vector allowed_sg_sizes = sycl_data->sycl_device.get_info<sycl::info::device::sub_group_sizes>();
2286ca0f394SUmesh Unnikrishnan CeedInt sub_group_size_op = allowed_sg_sizes[allowed_sg_sizes.size() - 1];
2296ca0f394SUmesh Unnikrishnan for (const auto &s : allowed_sg_sizes) {
2306ca0f394SUmesh Unnikrishnan if (s == 16) {
2316ca0f394SUmesh Unnikrishnan sub_group_size_op = s;
2326ca0f394SUmesh Unnikrishnan break;
2336ca0f394SUmesh Unnikrishnan }
2346ca0f394SUmesh Unnikrishnan }
2356ca0f394SUmesh Unnikrishnan
23609095acaSJeremy L Thompson code << qfunction_source;
2376ca0f394SUmesh Unnikrishnan
2386ca0f394SUmesh Unnikrishnan // Kernel function
2396ca0f394SUmesh Unnikrishnan code << "\n// -----------------------------------------------------------------------------\n";
2406ca0f394SUmesh Unnikrishnan code << "__attribute__((reqd_work_group_size(GROUP_SIZE_X, GROUP_SIZE_Y, GROUP_SIZE_Z), intel_reqd_sub_group_size(" << sub_group_size_op << ")))\n";
2416ca0f394SUmesh Unnikrishnan code << "kernel void " << operator_name << "(";
2426ca0f394SUmesh Unnikrishnan code << "const CeedInt num_elem, ";
2436ca0f394SUmesh Unnikrishnan code << "global void* ctx, ";
2446ca0f394SUmesh Unnikrishnan code << "global const FieldsInt_Sycl* indices, ";
2456ca0f394SUmesh Unnikrishnan code << "global Fields_Sycl* fields, ";
2466ca0f394SUmesh Unnikrishnan code << "global const Fields_Sycl* B, ";
2476ca0f394SUmesh Unnikrishnan code << "global const Fields_Sycl* G, ";
2486ca0f394SUmesh Unnikrishnan code << "global const CeedScalar * restrict W";
2496ca0f394SUmesh Unnikrishnan code << ") {\n";
2506ca0f394SUmesh Unnikrishnan
2516ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
2526ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
2536ca0f394SUmesh Unnikrishnan if (eval_mode != CEED_EVAL_WEIGHT) { // Skip CEED_EVAL_WEIGHT
2546ca0f394SUmesh Unnikrishnan code << " global const CeedScalar* d_u_" << i << " = fields->inputs[" << i << "];\n";
2556ca0f394SUmesh Unnikrishnan }
2566ca0f394SUmesh Unnikrishnan }
2576ca0f394SUmesh Unnikrishnan
2586ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
2596ca0f394SUmesh Unnikrishnan code << " global CeedScalar* d_v_" << i << " = fields->outputs[" << i << "];\n";
2606ca0f394SUmesh Unnikrishnan }
2616ca0f394SUmesh Unnikrishnan
2626ca0f394SUmesh Unnikrishnan // TODO: Convert these to defined constants to save on GRF
2636ca0f394SUmesh Unnikrishnan code << " const CeedInt DIM = " << dim << ";\n";
2646ca0f394SUmesh Unnikrishnan code << " const CeedInt Q_1D = " << Q_1d << ";\n";
2656ca0f394SUmesh Unnikrishnan
2666ca0f394SUmesh Unnikrishnan const CeedInt scratch_size = block_sizes[0] * block_sizes[1] * block_sizes[2];
2676ca0f394SUmesh Unnikrishnan code << " local CeedScalar scratch[" << scratch_size << "];\n";
2686ca0f394SUmesh Unnikrishnan code << " local CeedScalar * elem_scratch = scratch + get_local_id(2) * T_1D" << (dim > 1 ? "*T_1D" : "") << ";\n";
2696ca0f394SUmesh Unnikrishnan
2706ca0f394SUmesh Unnikrishnan code << "\n // -- Input field constants and basis data --\n";
2716ca0f394SUmesh Unnikrishnan // Initialize constants, and matrices B and G
2726ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
2736ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n";
2746ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp
275dd64fc84SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
276dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
277dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
278681d0ea7SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
2796782e2f8SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
2806ca0f394SUmesh Unnikrishnan
2816ca0f394SUmesh Unnikrishnan // Set field constants
2826ca0f394SUmesh Unnikrishnan if (eval_mode != CEED_EVAL_WEIGHT) {
2836ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
284356036faSJeremy L Thompson if (basis != CEED_BASIS_NONE) {
2856ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
2866ca0f394SUmesh Unnikrishnan code << " const CeedInt P_in_" << i << " = " << P_1d << ";\n";
2876ca0f394SUmesh Unnikrishnan } else {
2886ca0f394SUmesh Unnikrishnan code << " const CeedInt P_in_" << i << " = " << Q_1d << ";\n";
2896ca0f394SUmesh Unnikrishnan }
2906ca0f394SUmesh Unnikrishnan code << " const CeedInt num_comp_in_" << i << " = " << num_comp << ";\n";
2916ca0f394SUmesh Unnikrishnan }
2926ca0f394SUmesh Unnikrishnan
2936ca0f394SUmesh Unnikrishnan // Load basis data
2946ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
2956ca0f394SUmesh Unnikrishnan switch (eval_mode) {
2966ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE:
2976ca0f394SUmesh Unnikrishnan break;
2986ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP:
2996ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
3006ca0f394SUmesh Unnikrishnan h_B.inputs[i] = basis_impl->d_interp_1d;
3016ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_B_in_" << i << "[" << P_1d * Q_1d << "];\n";
3026ca0f394SUmesh Unnikrishnan code << " loadMatrix(P_in_" << i << "*Q_1D, B->inputs[" << i << "], s_B_in_" << i << ");\n";
3036ca0f394SUmesh Unnikrishnan break;
3046ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD:
3056ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
3066ca0f394SUmesh Unnikrishnan h_B.inputs[i] = basis_impl->d_interp_1d;
3076ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_B_in_" << i << "[" << P_1d * Q_1d << "];\n";
3086ca0f394SUmesh Unnikrishnan code << " loadMatrix(P_in_" << i << "*Q_1D, B->inputs[" << i << "], s_B_in_" << i << ");\n";
3096ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) {
3106ca0f394SUmesh Unnikrishnan h_G.inputs[i] = basis_impl->d_collo_grad_1d;
3116ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_G_in_" << i << "[" << Q_1d * Q_1d << "];\n";
3126ca0f394SUmesh Unnikrishnan code << " loadMatrix(Q_1D*Q_1D, G->inputs[" << i << "], s_G_in_" << i << ");\n";
3136ca0f394SUmesh Unnikrishnan } else {
3141c66c397SJeremy L Thompson bool has_collo_grad = basis_impl->d_collo_grad_1d;
3156ca0f394SUmesh Unnikrishnan h_G.inputs[i] = has_collo_grad ? basis_impl->d_collo_grad_1d : basis_impl->d_grad_1d;
3166ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_G_in_" << i << "[" << Q_1d * (has_collo_grad ? Q_1d : P_1d) << "];\n";
3176ca0f394SUmesh Unnikrishnan code << " loadMatrix(" << (has_collo_grad ? "Q_1D" : ("P_in_" + std::to_string(i))) << "*Q_1D, G->inputs[" << i << "], s_G_in_" << i
3186ca0f394SUmesh Unnikrishnan << ");\n";
3196ca0f394SUmesh Unnikrishnan }
3206ca0f394SUmesh Unnikrishnan break;
3216ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT:
3226ca0f394SUmesh Unnikrishnan break; // No action
3236ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV:
3246ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
3256ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL:
3266ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
3276ca0f394SUmesh Unnikrishnan }
328681d0ea7SJeremy L Thompson CeedCallBackend(CeedBasisDestroy(&basis));
3296ca0f394SUmesh Unnikrishnan }
3306ca0f394SUmesh Unnikrishnan
3316ca0f394SUmesh Unnikrishnan code << "\n // -- Output field constants and basis data --\n";
3326ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
3336ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n";
3346ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp
335dd64fc84SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
336dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
337dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
338681d0ea7SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
3396782e2f8SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
3406ca0f394SUmesh Unnikrishnan
3416ca0f394SUmesh Unnikrishnan // Set field constants
3426ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
343356036faSJeremy L Thompson if (basis != CEED_BASIS_NONE) {
3446ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
3456ca0f394SUmesh Unnikrishnan code << " const CeedInt P_out_" << i << " = " << P_1d << ";\n";
3466ca0f394SUmesh Unnikrishnan } else {
3476ca0f394SUmesh Unnikrishnan code << " const CeedInt P_out_" << i << " = " << Q_1d << ";\n";
3486ca0f394SUmesh Unnikrishnan }
3496ca0f394SUmesh Unnikrishnan code << " const CeedInt num_comp_out_" << i << " = " << num_comp << ";\n";
3506ca0f394SUmesh Unnikrishnan
3516ca0f394SUmesh Unnikrishnan // Load basis data
3526ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
3536ca0f394SUmesh Unnikrishnan switch (eval_mode) {
3546ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE:
3556ca0f394SUmesh Unnikrishnan break; // No action
3566ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP:
3576ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
3586ca0f394SUmesh Unnikrishnan h_B.outputs[i] = basis_impl->d_interp_1d;
3596ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_B_out_" << i << "[" << P_1d * Q_1d << "];\n";
3606ca0f394SUmesh Unnikrishnan code << " loadMatrix(P_out_" << i << "*Q_1D, B->outputs[" << i << "], s_B_out_" << i << ");\n";
3616ca0f394SUmesh Unnikrishnan break;
3626ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD:
3636ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
3646ca0f394SUmesh Unnikrishnan h_B.outputs[i] = basis_impl->d_interp_1d;
3656ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_B_out_" << i << "[" << P_1d * Q_1d << "];\n";
3666ca0f394SUmesh Unnikrishnan code << " loadMatrix(P_out_" << i << "*Q_1D, B->outputs[" << i << "], s_B_out_" << i << ");\n";
3676ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) {
3686ca0f394SUmesh Unnikrishnan h_G.outputs[i] = basis_impl->d_collo_grad_1d;
3696ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_G_out_" << i << "[" << Q_1d * Q_1d << "];\n";
3706ca0f394SUmesh Unnikrishnan code << " loadMatrix(Q_1D*Q_1D, G->outputs[" << i << "], s_G_out_" << i << ");\n";
3716ca0f394SUmesh Unnikrishnan } else {
3721c66c397SJeremy L Thompson bool has_collo_grad = basis_impl->d_collo_grad_1d;
3736ca0f394SUmesh Unnikrishnan h_G.outputs[i] = has_collo_grad ? basis_impl->d_collo_grad_1d : basis_impl->d_grad_1d;
3746ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_G_out_" << i << "[" << Q_1d * (has_collo_grad ? Q_1d : P_1d) << "];\n";
3756ca0f394SUmesh Unnikrishnan code << " loadMatrix(" << (has_collo_grad ? "Q_1D" : ("P_out_" + std::to_string(i))) << "*Q_1D, G->outputs[" << i << "], s_G_out_" << i
3766ca0f394SUmesh Unnikrishnan << ");\n";
3776ca0f394SUmesh Unnikrishnan }
3786ca0f394SUmesh Unnikrishnan break;
3796ca0f394SUmesh Unnikrishnan // LCOV_EXCL_START
3806ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT: {
3816e536b99SJeremy L Thompson return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
3826ca0f394SUmesh Unnikrishnan break; // Should not occur
3836ca0f394SUmesh Unnikrishnan }
3846ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV:
3854e3038a5SJeremy L Thompson case CEED_EVAL_CURL: {
3866e536b99SJeremy L Thompson return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
3874e3038a5SJeremy L Thompson break; // Should not occur
3884e3038a5SJeremy L Thompson }
3896ca0f394SUmesh Unnikrishnan // LCOV_EXCL_STOP
3906ca0f394SUmesh Unnikrishnan }
391681d0ea7SJeremy L Thompson CeedCallBackend(CeedBasisDestroy(&basis));
3926ca0f394SUmesh Unnikrishnan }
3936ca0f394SUmesh Unnikrishnan code << "\n // -- Element loop --\n";
3946ca0f394SUmesh Unnikrishnan code << " work_group_barrier(CLK_LOCAL_MEM_FENCE);\n";
3956ca0f394SUmesh Unnikrishnan code << " {\n";
3966ca0f394SUmesh Unnikrishnan // Input basis apply if needed
3976ca0f394SUmesh Unnikrishnan // Generate the correct eval mode code for each input
3986ca0f394SUmesh Unnikrishnan code << " // -- Input field restrictions and basis actions --\n";
3996ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
4006ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n";
4016ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp
402dd64fc84SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
403dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
404dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
4056782e2f8SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
4066ca0f394SUmesh Unnikrishnan
4076ca0f394SUmesh Unnikrishnan // Restriction
4086ca0f394SUmesh Unnikrishnan if (eval_mode != CEED_EVAL_WEIGHT && !((eval_mode == CEED_EVAL_NONE) && use_collograd_parallelization)) {
409dd64fc84SJeremy L Thompson bool is_strided;
410dd64fc84SJeremy L Thompson
4116ca0f394SUmesh Unnikrishnan code << " CeedScalar r_u_" << i << "[num_comp_in_" << i << "*P_in_" << i << "];\n";
4126ca0f394SUmesh Unnikrishnan
413dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided));
4146ca0f394SUmesh Unnikrishnan if (!is_strided) {
4156ca0f394SUmesh Unnikrishnan CeedInt comp_stride;
416dd64fc84SJeremy L Thompson
417dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
418dd64fc84SJeremy L Thompson code << " const CeedInt l_size_in_" << i << " = " << l_size << ";\n";
419dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
4206ca0f394SUmesh Unnikrishnan code << " // CompStride: " << comp_stride << "\n";
421edb2538eSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(elem_rstr, &rstr_impl));
422f59ebe5eSJeremy L Thompson h_indices.inputs[i] = rstr_impl->d_offsets;
4236ca0f394SUmesh Unnikrishnan code << " readDofsOffset" << dim << "d(num_comp_in_" << i << ", " << comp_stride << ", P_in_" << i << ", num_elem, indices->inputs[" << i
4246ca0f394SUmesh Unnikrishnan << "], d_u_" << i << ", r_u_" << i << ");\n";
4256ca0f394SUmesh Unnikrishnan } else {
4266ca0f394SUmesh Unnikrishnan bool has_backend_strides;
4276ca0f394SUmesh Unnikrishnan CeedInt num_elem;
428dd64fc84SJeremy L Thompson
429dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &has_backend_strides));
430dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumElements(elem_rstr, &num_elem));
4316ca0f394SUmesh Unnikrishnan CeedInt strides[3] = {1, elem_size * num_elem, elem_size};
432dd64fc84SJeremy L Thompson
4336ca0f394SUmesh Unnikrishnan if (!has_backend_strides) {
43456c48462SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetStrides(elem_rstr, strides));
4356ca0f394SUmesh Unnikrishnan }
4366ca0f394SUmesh Unnikrishnan code << " // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n";
4376ca0f394SUmesh Unnikrishnan code << " readDofsStrided" << dim << "d(num_comp_in_" << i << ",P_in_" << i << "," << strides[0] << "," << strides[1] << "," << strides[2]
4386ca0f394SUmesh Unnikrishnan << ", num_elem, d_u_" << i << ", r_u_" << i << ");\n";
4396ca0f394SUmesh Unnikrishnan }
4406ca0f394SUmesh Unnikrishnan }
441681d0ea7SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
4426ca0f394SUmesh Unnikrishnan
4436ca0f394SUmesh Unnikrishnan // Basis action
4446ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
4456ca0f394SUmesh Unnikrishnan switch (eval_mode) {
4466ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE:
4476ca0f394SUmesh Unnikrishnan if (!use_collograd_parallelization) {
4486ca0f394SUmesh Unnikrishnan code << " private CeedScalar* r_t_" << i << " = r_u_" << i << ";\n";
4496ca0f394SUmesh Unnikrishnan }
4506ca0f394SUmesh Unnikrishnan break;
4516ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP:
4526ca0f394SUmesh Unnikrishnan code << " CeedScalar r_t_" << i << "[num_comp_in_" << i << "*Q_1D];\n";
4536ca0f394SUmesh Unnikrishnan code << " Interp" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_in_" << i << ", P_in_" << i << ", Q_1D, r_u_" << i << ", s_B_in_" << i
4546ca0f394SUmesh Unnikrishnan << ", r_t_" << i << ", elem_scratch);\n";
4556ca0f394SUmesh Unnikrishnan break;
4566ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD:
4576ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) {
4586ca0f394SUmesh Unnikrishnan code << " CeedScalar r_t_" << i << "[num_comp_in_" << i << "*Q_1D];\n";
4596ca0f394SUmesh Unnikrishnan code << " Interp" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_in_" << i << ", P_in_" << i << ", Q_1D, r_u_" << i << ", s_B_in_"
4606ca0f394SUmesh Unnikrishnan << i << ", r_t_" << i << ", elem_scratch);\n";
4616ca0f394SUmesh Unnikrishnan } else {
4626ca0f394SUmesh Unnikrishnan CeedInt P_1d;
463681d0ea7SJeremy L Thompson
4646ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
4656ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
4666ca0f394SUmesh Unnikrishnan code << " CeedScalar r_t_" << i << "[num_comp_in_" << i << "*DIM*Q_1D];\n";
4676ca0f394SUmesh Unnikrishnan code << " Grad" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d(num_comp_in_" << i
4686ca0f394SUmesh Unnikrishnan << ", P_in_" << i << ", Q_1D, r_u_" << i << (dim > 1 ? ", s_B_in_" : "") << (dim > 1 ? std::to_string(i) : "") << ", s_G_in_" << i
4696ca0f394SUmesh Unnikrishnan << ", r_t_" << i << ", elem_scratch);\n";
470681d0ea7SJeremy L Thompson CeedCallBackend(CeedBasisDestroy(&basis));
4716ca0f394SUmesh Unnikrishnan }
4726ca0f394SUmesh Unnikrishnan break;
4736ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT:
4746ca0f394SUmesh Unnikrishnan code << " CeedScalar r_t_" << i << "[Q_1D];\n";
4756ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
4766ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl));
4776ca0f394SUmesh Unnikrishnan impl->W = basis_impl->d_q_weight_1d;
4786ca0f394SUmesh Unnikrishnan code << " Weight" << (dim > 1 ? "Tensor" : "") << dim << "d(Q_1D, W, r_t_" << i << ");\n";
479681d0ea7SJeremy L Thompson CeedCallBackend(CeedBasisDestroy(&basis));
4806ca0f394SUmesh Unnikrishnan break; // No action
4816ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV:
4826ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
4836ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL:
4846ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
4856ca0f394SUmesh Unnikrishnan }
4866ca0f394SUmesh Unnikrishnan }
4876ca0f394SUmesh Unnikrishnan
4886ca0f394SUmesh Unnikrishnan // Q function
4896ca0f394SUmesh Unnikrishnan code << "\n // -- Output field setup --\n";
4906ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
4916ca0f394SUmesh Unnikrishnan code << "\n // ---- Output field " << i << " ----\n";
4926ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
4936ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_GRAD) {
4946ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) {
4956ca0f394SUmesh Unnikrishnan // Accumulator for gradient slices
4966ca0f394SUmesh Unnikrishnan code << " CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*Q_1D];\n";
4976ca0f394SUmesh Unnikrishnan code << " for (CeedInt i = 0; i < num_comp_out_" << i << "; i++) {\n";
4986ca0f394SUmesh Unnikrishnan code << " for (CeedInt j = 0; j < Q_1D; ++j) {\n";
4996ca0f394SUmesh Unnikrishnan code << " r_tt_" << i << "[j + i*Q_1D] = 0.0;\n";
5006ca0f394SUmesh Unnikrishnan code << " }\n";
5016ca0f394SUmesh Unnikrishnan code << " }\n";
5026ca0f394SUmesh Unnikrishnan } else {
5036ca0f394SUmesh Unnikrishnan code << " CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*DIM*Q_1D];\n";
5046ca0f394SUmesh Unnikrishnan }
5056ca0f394SUmesh Unnikrishnan }
5066ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_NONE || eval_mode == CEED_EVAL_INTERP) {
5076ca0f394SUmesh Unnikrishnan code << " CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*Q_1D];\n";
5086ca0f394SUmesh Unnikrishnan }
5096ca0f394SUmesh Unnikrishnan }
5106ca0f394SUmesh Unnikrishnan // We treat quadrature points per slice in 3d to save registers
5116ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) {
5126ca0f394SUmesh Unnikrishnan code << "\n // Note: Using planes of 3D elements\n";
5136ca0f394SUmesh Unnikrishnan code << " for (CeedInt q = 0; q < Q_1D; q++) {\n";
5146ca0f394SUmesh Unnikrishnan code << " // -- Input fields --\n";
5156ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
5166ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n";
5176ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp
5186ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
5196ca0f394SUmesh Unnikrishnan // Basis action
5206ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
5216ca0f394SUmesh Unnikrishnan switch (eval_mode) {
5226ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE:
523dd64fc84SJeremy L Thompson bool is_strided;
524dd64fc84SJeremy L Thompson
5256ca0f394SUmesh Unnikrishnan code << " CeedScalar r_q_" << i << "[num_comp_in_" << i << "];\n";
5266ca0f394SUmesh Unnikrishnan
527dd64fc84SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
528dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided));
5296ca0f394SUmesh Unnikrishnan if (!is_strided) {
5306ca0f394SUmesh Unnikrishnan CeedInt comp_stride;
531dd64fc84SJeremy L Thompson
532dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
533dd64fc84SJeremy L Thompson code << " const CeedInt l_size_in_" << i << " = " << l_size << ";\n";
534dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
5356ca0f394SUmesh Unnikrishnan code << " // CompStride: " << comp_stride << "\n";
536edb2538eSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(elem_rstr, &rstr_impl));
537f59ebe5eSJeremy L Thompson h_indices.inputs[i] = rstr_impl->d_offsets;
5386ca0f394SUmesh Unnikrishnan code << " readSliceQuadsOffset"
539dd64fc84SJeremy L Thompson << "3d(num_comp_in_" << i << ", " << comp_stride << ", Q_1D, l_size_in_" << i << ", num_elem, q, indices->inputs[" << i << "], d_u_"
5406ca0f394SUmesh Unnikrishnan << i << ", r_q_" << i << ");\n";
5416ca0f394SUmesh Unnikrishnan } else {
5426ca0f394SUmesh Unnikrishnan bool has_backend_strides;
5436ca0f394SUmesh Unnikrishnan CeedInt num_elem;
544dd64fc84SJeremy L Thompson
545dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
546dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &has_backend_strides));
547dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumElements(elem_rstr, &num_elem));
5486ca0f394SUmesh Unnikrishnan CeedInt strides[3] = {1, elem_size * num_elem, elem_size};
549dd64fc84SJeremy L Thompson
5506ca0f394SUmesh Unnikrishnan if (!has_backend_strides) {
55156c48462SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetStrides(elem_rstr, strides));
5526ca0f394SUmesh Unnikrishnan }
5536ca0f394SUmesh Unnikrishnan code << " // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n";
5546ca0f394SUmesh Unnikrishnan code << " readSliceQuadsStrided"
5556ca0f394SUmesh Unnikrishnan << "3d(num_comp_in_" << i << ", Q_1D," << strides[0] << ", " << strides[1] << ", " << strides[2] << ", num_elem, q, d_u_" << i
5566ca0f394SUmesh Unnikrishnan << ", r_q_" << i << ");\n";
5576ca0f394SUmesh Unnikrishnan }
558681d0ea7SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
5596ca0f394SUmesh Unnikrishnan break;
5606ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP:
5616ca0f394SUmesh Unnikrishnan code << " CeedScalar r_q_" << i << "[num_comp_in_" << i << "];\n";
5626ca0f394SUmesh Unnikrishnan code << " for (CeedInt j = 0; j < num_comp_in_" << i << " ; ++j) {\n";
5636ca0f394SUmesh Unnikrishnan code << " r_q_" << i << "[j] = r_t_" << i << "[q + j*Q_1D];\n";
5646ca0f394SUmesh Unnikrishnan code << " }\n";
5656ca0f394SUmesh Unnikrishnan break;
5666ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD:
5676ca0f394SUmesh Unnikrishnan code << " CeedScalar r_q_" << i << "[num_comp_in_" << i << "*DIM];\n";
5686ca0f394SUmesh Unnikrishnan code << " gradCollo3d(num_comp_in_" << i << ", Q_1D, q, r_t_" << i << ", s_G_in_" << i << ", r_q_" << i << ", elem_scratch);\n";
5696ca0f394SUmesh Unnikrishnan break;
5706ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT:
5716ca0f394SUmesh Unnikrishnan code << " CeedScalar r_q_" << i << "[1];\n";
5726ca0f394SUmesh Unnikrishnan code << " r_q_" << i << "[0] = r_t_" << i << "[q];\n";
5736ca0f394SUmesh Unnikrishnan break; // No action
5746ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV:
5756ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
5766ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL:
5776ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
5786ca0f394SUmesh Unnikrishnan }
5796ca0f394SUmesh Unnikrishnan }
5806ca0f394SUmesh Unnikrishnan code << "\n // -- Output fields --\n";
5816ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
5826ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n";
5836ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
5846ca0f394SUmesh Unnikrishnan // Basis action
5856ca0f394SUmesh Unnikrishnan switch (eval_mode) {
5866ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE:
5876ca0f394SUmesh Unnikrishnan code << " CeedScalar r_qq_" << i << "[num_comp_out_" << i << "];\n";
5886ca0f394SUmesh Unnikrishnan break; // No action
5896ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP:
5906ca0f394SUmesh Unnikrishnan code << " CeedScalar r_qq_" << i << "[num_comp_out_" << i << "];\n";
5916ca0f394SUmesh Unnikrishnan break;
5926ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD:
5936ca0f394SUmesh Unnikrishnan code << " CeedScalar r_qq_" << i << "[num_comp_out_" << i << "*DIM];\n";
5946ca0f394SUmesh Unnikrishnan break;
5956ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT:
5966ca0f394SUmesh Unnikrishnan break; // Should not occur
5976ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV:
5986ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
5996ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL:
6006ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
6016ca0f394SUmesh Unnikrishnan }
6026ca0f394SUmesh Unnikrishnan }
6036ca0f394SUmesh Unnikrishnan } else {
6046ca0f394SUmesh Unnikrishnan code << "\n // Note: Using full elements\n";
6056ca0f394SUmesh Unnikrishnan code << " // -- Input fields --\n";
6066ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
6076ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n";
6086ca0f394SUmesh Unnikrishnan code << " private CeedScalar* r_q_" << i << " = r_t_" << i << ";\n";
6096ca0f394SUmesh Unnikrishnan }
6106ca0f394SUmesh Unnikrishnan code << " // -- Output fields --\n";
6116ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
6126ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n";
6136ca0f394SUmesh Unnikrishnan code << " private CeedScalar* r_qq_" << i << " = r_tt_" << i << ";\n";
6146ca0f394SUmesh Unnikrishnan }
6156ca0f394SUmesh Unnikrishnan }
6166ca0f394SUmesh Unnikrishnan //--------------------------------------------------
6176ca0f394SUmesh Unnikrishnan code << "\n // -- QFunction Inputs and outputs --\n";
6186ca0f394SUmesh Unnikrishnan code << " const CeedScalar * in[" << num_input_fields << "];\n";
6196ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
6206ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n";
6216ca0f394SUmesh Unnikrishnan code << " in[" << i << "] = r_q_" << i << ";\n";
6226ca0f394SUmesh Unnikrishnan }
6236ca0f394SUmesh Unnikrishnan code << " CeedScalar * out[" << num_output_fields << "];\n";
6246ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
6256ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n";
6266ca0f394SUmesh Unnikrishnan code << " out[" << i << "] = r_qq_" << i << ";\n";
6276ca0f394SUmesh Unnikrishnan }
6286ca0f394SUmesh Unnikrishnan
6296ca0f394SUmesh Unnikrishnan code << "\n // -- Apply QFunction --\n";
63009095acaSJeremy L Thompson code << " " << qfunction_name << "(ctx, ";
6316ca0f394SUmesh Unnikrishnan if (dim != 3 || use_collograd_parallelization) {
6326ca0f394SUmesh Unnikrishnan code << "1";
6336ca0f394SUmesh Unnikrishnan } else {
6346ca0f394SUmesh Unnikrishnan code << "Q_1D";
6356ca0f394SUmesh Unnikrishnan }
6366ca0f394SUmesh Unnikrishnan code << ", in, out);\n";
6376ca0f394SUmesh Unnikrishnan //--------------------------------------------------
6386ca0f394SUmesh Unnikrishnan
6396ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) {
6406ca0f394SUmesh Unnikrishnan code << " // -- Output fields --\n";
6416ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
6426ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n";
6436ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
6446ca0f394SUmesh Unnikrishnan // Basis action
6456ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
6466ca0f394SUmesh Unnikrishnan switch (eval_mode) {
6476ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE:
6486ca0f394SUmesh Unnikrishnan code << " for (CeedInt j = 0; j < num_comp_out_" << i << " ; ++j) {\n";
6496ca0f394SUmesh Unnikrishnan code << " r_tt_" << i << "[q + j*Q_1D] = r_qq_" << i << "[j];\n";
6506ca0f394SUmesh Unnikrishnan code << " }\n";
6516ca0f394SUmesh Unnikrishnan break; // No action
6526ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP:
6536ca0f394SUmesh Unnikrishnan code << " for (CeedInt j = 0; j < num_comp_out_" << i << " ; ++j) {\n";
6546ca0f394SUmesh Unnikrishnan code << " r_tt_" << i << "[q + j*Q_1D] = r_qq_" << i << "[j];\n";
6556ca0f394SUmesh Unnikrishnan code << " }\n";
6566ca0f394SUmesh Unnikrishnan break;
6576ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD:
6586ca0f394SUmesh Unnikrishnan code << " gradColloTranspose3d(num_comp_out_" << i << ",Q_1D, q, r_qq_" << i << ", s_G_out_" << i << ", r_tt_" << i
6596ca0f394SUmesh Unnikrishnan << ", elem_scratch);\n";
6606ca0f394SUmesh Unnikrishnan break;
6616ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT:
6626ca0f394SUmesh Unnikrishnan break; // Should not occur
6636ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV:
6646ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
6656ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL:
6666ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented
6676ca0f394SUmesh Unnikrishnan }
6686ca0f394SUmesh Unnikrishnan }
6696ca0f394SUmesh Unnikrishnan code << " }\n";
6706ca0f394SUmesh Unnikrishnan }
6716ca0f394SUmesh Unnikrishnan
6726ca0f394SUmesh Unnikrishnan // Output basis apply if needed
6736ca0f394SUmesh Unnikrishnan // Generate the correct eval mode code for each output
6746ca0f394SUmesh Unnikrishnan code << "\n // -- Output field basis action and restrictions --\n";
6756ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
6766ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n";
6776ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp
678dd64fc84SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
679dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
680dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
6816782e2f8SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
6826ca0f394SUmesh Unnikrishnan // Basis action
6836ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
6846ca0f394SUmesh Unnikrishnan switch (eval_mode) {
6856ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE:
6866ca0f394SUmesh Unnikrishnan code << " private CeedScalar* r_v_" << i << " = r_tt_" << i << ";\n";
6876ca0f394SUmesh Unnikrishnan break; // No action
6886ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP:
6896ca0f394SUmesh Unnikrishnan code << " CeedScalar r_v_" << i << "[num_comp_out_" << i << "*P_out_" << i << "];\n";
6906ca0f394SUmesh Unnikrishnan code << " InterpTranspose" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_out_" << i << ",P_out_" << i << ", Q_1D, r_tt_" << i
6916ca0f394SUmesh Unnikrishnan << ", s_B_out_" << i << ", r_v_" << i << ", elem_scratch);\n";
6926ca0f394SUmesh Unnikrishnan break;
6936ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD:
6946ca0f394SUmesh Unnikrishnan code << " CeedScalar r_v_" << i << "[num_comp_out_" << i << "*P_out_" << i << "];\n";
6956ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) {
6966ca0f394SUmesh Unnikrishnan code << " InterpTranspose" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_out_" << i << ",P_out_" << i << ", Q_1D, r_tt_" << i
6976ca0f394SUmesh Unnikrishnan << ", s_B_out_" << i << ", r_v_" << i << ", elem_scratch);\n";
6986ca0f394SUmesh Unnikrishnan } else {
6996ca0f394SUmesh Unnikrishnan CeedInt P_1d;
7006ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
7016ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
7026ca0f394SUmesh Unnikrishnan code << " GradTranspose" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d(num_comp_out_" << i
7036ca0f394SUmesh Unnikrishnan << ", P_out_" << i << ", Q_1D, r_tt_" << i << (dim > 1 ? ", s_B_out_" : "") << (dim > 1 ? std::to_string(i) : "") << ", s_G_out_" << i
7046ca0f394SUmesh Unnikrishnan << ", r_v_" << i << ", elem_scratch);\n";
705681d0ea7SJeremy L Thompson CeedCallBackend(CeedBasisDestroy(&basis));
7066ca0f394SUmesh Unnikrishnan }
7076ca0f394SUmesh Unnikrishnan break;
7086ca0f394SUmesh Unnikrishnan // LCOV_EXCL_START
7096ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT: {
7106e536b99SJeremy L Thompson return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
7116ca0f394SUmesh Unnikrishnan break; // Should not occur
7126ca0f394SUmesh Unnikrishnan }
7136ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV:
7144e3038a5SJeremy L Thompson case CEED_EVAL_CURL: {
7156e536b99SJeremy L Thompson return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
7164e3038a5SJeremy L Thompson break; // Should not occur
7174e3038a5SJeremy L Thompson }
7186ca0f394SUmesh Unnikrishnan // LCOV_EXCL_STOP
7196ca0f394SUmesh Unnikrishnan }
7206ca0f394SUmesh Unnikrishnan // Restriction
7216ca0f394SUmesh Unnikrishnan bool is_strided;
722dd64fc84SJeremy L Thompson
723dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided));
7246ca0f394SUmesh Unnikrishnan if (!is_strided) {
7256ca0f394SUmesh Unnikrishnan CeedInt comp_stride;
726dd64fc84SJeremy L Thompson
727dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
728dd64fc84SJeremy L Thompson code << " const CeedInt l_size_out_" << i << " = " << l_size << ";\n";
729dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
7306ca0f394SUmesh Unnikrishnan code << " // CompStride: " << comp_stride << "\n";
731edb2538eSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(elem_rstr, &rstr_impl));
732f59ebe5eSJeremy L Thompson h_indices.outputs[i] = rstr_impl->d_offsets;
7336ca0f394SUmesh Unnikrishnan code << " writeDofsOffset" << dim << "d(num_comp_out_" << i << ", " << comp_stride << ", P_out_" << i << ", num_elem, indices->outputs[" << i
7346ca0f394SUmesh Unnikrishnan << "], r_v_" << i << ", d_v_" << i << ");\n";
7356ca0f394SUmesh Unnikrishnan } else {
7366ca0f394SUmesh Unnikrishnan bool has_backend_strides;
7376ca0f394SUmesh Unnikrishnan CeedInt num_elem;
738dd64fc84SJeremy L Thompson
739dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &has_backend_strides));
740dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumElements(elem_rstr, &num_elem));
7416ca0f394SUmesh Unnikrishnan CeedInt strides[3] = {1, elem_size * num_elem, elem_size};
742dd64fc84SJeremy L Thompson
7436ca0f394SUmesh Unnikrishnan if (!has_backend_strides) {
74456c48462SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetStrides(elem_rstr, strides));
7456ca0f394SUmesh Unnikrishnan }
7466ca0f394SUmesh Unnikrishnan code << " // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n";
7476ca0f394SUmesh Unnikrishnan code << " writeDofsStrided" << dim << "d(num_comp_out_" << i << ",P_out_" << i << "," << strides[0] << "," << strides[1] << "," << strides[2]
7486ca0f394SUmesh Unnikrishnan << ", num_elem, r_v_" << i << ", d_v_" << i << ");\n";
7496ca0f394SUmesh Unnikrishnan }
750681d0ea7SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
7516ca0f394SUmesh Unnikrishnan }
7526ca0f394SUmesh Unnikrishnan
7536ca0f394SUmesh Unnikrishnan code << " }\n";
7546ca0f394SUmesh Unnikrishnan code << "}\n";
7556ca0f394SUmesh Unnikrishnan code << "// -----------------------------------------------------------------------------\n\n";
7566ca0f394SUmesh Unnikrishnan
7576ca0f394SUmesh Unnikrishnan // Copy the struct (containing device addresses) from the host to the device
7581f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
7591f4b1b45SUmesh Unnikrishnan
7601f4b1b45SUmesh Unnikrishnan if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
7611f4b1b45SUmesh Unnikrishnan
7621f4b1b45SUmesh Unnikrishnan sycl::event copy_B = sycl_data->sycl_queue.copy<Fields_Sycl>(&h_B, impl->B, 1, e);
7631f4b1b45SUmesh Unnikrishnan sycl::event copy_G = sycl_data->sycl_queue.copy<Fields_Sycl>(&h_G, impl->G, 1, e);
7641f4b1b45SUmesh Unnikrishnan sycl::event copy_indices = sycl_data->sycl_queue.copy<FieldsInt_Sycl>(&h_indices, impl->indices, 1, e);
7656ca0f394SUmesh Unnikrishnan // These copies can happen while the JIT is being done
7666ca0f394SUmesh Unnikrishnan CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_B, copy_G, copy_indices}));
7676ca0f394SUmesh Unnikrishnan
7686ca0f394SUmesh Unnikrishnan // View kernel for debugging
7696ca0f394SUmesh Unnikrishnan CeedDebug256(ceed, 2, "Generated Operator Kernels:\n");
7706ca0f394SUmesh Unnikrishnan CeedDebug(ceed, code.str().c_str());
7716ca0f394SUmesh Unnikrishnan
7726ca0f394SUmesh Unnikrishnan std::map<std::string, CeedInt> jit_constants;
7736ca0f394SUmesh Unnikrishnan jit_constants["T_1D"] = block_sizes[0];
7746ca0f394SUmesh Unnikrishnan jit_constants["GROUP_SIZE_X"] = block_sizes[0];
7756ca0f394SUmesh Unnikrishnan jit_constants["GROUP_SIZE_Y"] = block_sizes[1];
7766ca0f394SUmesh Unnikrishnan jit_constants["GROUP_SIZE_Z"] = block_sizes[2];
7776ca0f394SUmesh Unnikrishnan
7786ca0f394SUmesh Unnikrishnan // Compile kernel into a kernel bundle
7796ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module, jit_constants));
7806ca0f394SUmesh Unnikrishnan
7816ca0f394SUmesh Unnikrishnan // Load kernel function
7826ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, operator_name, &impl->op));
7836ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorSetSetupDone(op));
7849bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
785c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf));
7866ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS;
7876ca0f394SUmesh Unnikrishnan }
7886ca0f394SUmesh Unnikrishnan
7896ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
790