1*6ca0f394SUmesh Unnikrishnan // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*6ca0f394SUmesh Unnikrishnan // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*6ca0f394SUmesh Unnikrishnan // 4*6ca0f394SUmesh Unnikrishnan // SPDX-License-Identifier: BSD-2-Clause 5*6ca0f394SUmesh Unnikrishnan // 6*6ca0f394SUmesh Unnikrishnan // This file is part of CEED: http://github.com/ceed 7*6ca0f394SUmesh Unnikrishnan 8*6ca0f394SUmesh Unnikrishnan #define CEED_DEBUG_COLOR 12 9*6ca0f394SUmesh Unnikrishnan 10*6ca0f394SUmesh Unnikrishnan #include <ceed/backend.h> 11*6ca0f394SUmesh Unnikrishnan #include <ceed/ceed.h> 12*6ca0f394SUmesh Unnikrishnan #include <ceed/jit-source/sycl/sycl-types.h> 13*6ca0f394SUmesh Unnikrishnan #include <ceed/jit-tools.h> 14*6ca0f394SUmesh Unnikrishnan 15*6ca0f394SUmesh Unnikrishnan #include <iostream> 16*6ca0f394SUmesh Unnikrishnan #include <sstream> 17*6ca0f394SUmesh Unnikrishnan #include <string> 18*6ca0f394SUmesh Unnikrishnan #include <string_view> 19*6ca0f394SUmesh Unnikrishnan #include <vector> 20*6ca0f394SUmesh Unnikrishnan 21*6ca0f394SUmesh Unnikrishnan #include "../sycl-ref/ceed-sycl-ref.hpp" 22*6ca0f394SUmesh Unnikrishnan #include "../sycl-shared/ceed-sycl-shared.hpp" 23*6ca0f394SUmesh Unnikrishnan #include "../sycl/ceed-sycl-compile.hpp" 24*6ca0f394SUmesh Unnikrishnan 25*6ca0f394SUmesh Unnikrishnan #include "ceed-sycl-gen.hpp" 26*6ca0f394SUmesh Unnikrishnan 27*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 28*6ca0f394SUmesh Unnikrishnan // Calculate the block size used for launching the operator kernel 29*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 30*6ca0f394SUmesh Unnikrishnan extern "C" int BlockGridCalculate_Sycl_gen(const CeedInt dim, const CeedInt P_1d, const CeedInt Q_1d, CeedInt *block_sizes) { 31*6ca0f394SUmesh Unnikrishnan const CeedInt thread1d = CeedIntMax(Q_1d, P_1d); 32*6ca0f394SUmesh Unnikrishnan if (dim == 1) { 33*6ca0f394SUmesh Unnikrishnan CeedInt elems_per_block = 64 * thread1d > 256 ? 256 / thread1d : 64; 34*6ca0f394SUmesh Unnikrishnan elems_per_block = elems_per_block > 0 ? elems_per_block : 1; 35*6ca0f394SUmesh Unnikrishnan block_sizes[0] = thread1d; 36*6ca0f394SUmesh Unnikrishnan block_sizes[1] = 1; 37*6ca0f394SUmesh Unnikrishnan block_sizes[2] = elems_per_block; 38*6ca0f394SUmesh Unnikrishnan } else if (dim == 2) { 39*6ca0f394SUmesh Unnikrishnan const CeedInt elems_per_block = thread1d < 4 ? 16 : 2; 40*6ca0f394SUmesh Unnikrishnan block_sizes[0] = thread1d; 41*6ca0f394SUmesh Unnikrishnan block_sizes[1] = thread1d; 42*6ca0f394SUmesh Unnikrishnan block_sizes[2] = elems_per_block; 43*6ca0f394SUmesh Unnikrishnan } else if (dim == 3) { 44*6ca0f394SUmesh Unnikrishnan const CeedInt elems_per_block = thread1d < 6 ? 4 : (thread1d < 8 ? 2 : 1); 45*6ca0f394SUmesh Unnikrishnan block_sizes[0] = thread1d; 46*6ca0f394SUmesh Unnikrishnan block_sizes[1] = thread1d; 47*6ca0f394SUmesh Unnikrishnan block_sizes[2] = elems_per_block; 48*6ca0f394SUmesh Unnikrishnan } 49*6ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS; 50*6ca0f394SUmesh Unnikrishnan } 51*6ca0f394SUmesh Unnikrishnan 52*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 53*6ca0f394SUmesh Unnikrishnan // Build single operator kernel 54*6ca0f394SUmesh Unnikrishnan // - [ ] Check arguments to device functions reudsed from sycl-shared-basis are correct 55*6ca0f394SUmesh Unnikrishnan // - [ ] Do kernel jitting! 56*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 57*6ca0f394SUmesh Unnikrishnan extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) { 58*6ca0f394SUmesh Unnikrishnan bool is_setup_done; 59*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 60*6ca0f394SUmesh Unnikrishnan if (is_setup_done) return CEED_ERROR_SUCCESS; 61*6ca0f394SUmesh Unnikrishnan 62*6ca0f394SUmesh Unnikrishnan Ceed ceed; 63*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 64*6ca0f394SUmesh Unnikrishnan Ceed_Sycl *sycl_data; 65*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetData(ceed, &sycl_data)); 66*6ca0f394SUmesh Unnikrishnan 67*6ca0f394SUmesh Unnikrishnan CeedOperator_Sycl_gen *impl; 68*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetData(op, &impl)); 69*6ca0f394SUmesh Unnikrishnan Fields_Sycl h_B, h_G; 70*6ca0f394SUmesh Unnikrishnan FieldsInt_Sycl h_indices; 71*6ca0f394SUmesh Unnikrishnan CeedQFunction qf; 72*6ca0f394SUmesh Unnikrishnan CeedQFunction_Sycl_gen *qf_impl; 73*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 74*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl)); 75*6ca0f394SUmesh Unnikrishnan CeedSize lsize; 76*6ca0f394SUmesh Unnikrishnan CeedInt Q, P_1d = 0, Q_1d = 0, elem_size, num_input_fields, num_output_fields, num_comp, dim = 1; 77*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 78*6ca0f394SUmesh Unnikrishnan Q_1d = Q; 79*6ca0f394SUmesh Unnikrishnan 80*6ca0f394SUmesh Unnikrishnan CeedOperatorField *op_input_fields, *op_output_fields; 81*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 82*6ca0f394SUmesh Unnikrishnan CeedQFunctionField *qf_input_fields, *qf_output_fields; 83*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 84*6ca0f394SUmesh Unnikrishnan 85*6ca0f394SUmesh Unnikrishnan CeedEvalMode eval_mode; 86*6ca0f394SUmesh Unnikrishnan CeedBasis basis; 87*6ca0f394SUmesh Unnikrishnan CeedBasis_Sycl_shared *basis_impl; 88*6ca0f394SUmesh Unnikrishnan CeedElemRestriction Erestrict; 89*6ca0f394SUmesh Unnikrishnan CeedElemRestriction_Sycl *restr_impl; 90*6ca0f394SUmesh Unnikrishnan 91*6ca0f394SUmesh Unnikrishnan // Check for restriction only identity operator 92*6ca0f394SUmesh Unnikrishnan bool is_identity_qf; 93*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionIsIdentity(qf, &is_identity_qf)); 94*6ca0f394SUmesh Unnikrishnan if (is_identity_qf) { 95*6ca0f394SUmesh Unnikrishnan CeedEvalMode eval_mode_in, eval_mode_out; 96*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[0], &eval_mode_in)); 97*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[0], &eval_mode_out)); 98*6ca0f394SUmesh Unnikrishnan if (eval_mode_in == CEED_EVAL_NONE && eval_mode_out == CEED_EVAL_NONE) { 99*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_START 100*6ca0f394SUmesh Unnikrishnan return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement restriction only identity operators"); 101*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_STOP 102*6ca0f394SUmesh Unnikrishnan } 103*6ca0f394SUmesh Unnikrishnan } 104*6ca0f394SUmesh Unnikrishnan 105*6ca0f394SUmesh Unnikrishnan std::ostringstream code; 106*6ca0f394SUmesh Unnikrishnan // TODO: generalize to accept different device functions? 107*6ca0f394SUmesh Unnikrishnan { 108*6ca0f394SUmesh Unnikrishnan char *tensor_basis_kernel_path, *tensor_basis_code; 109*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-shared-basis-tensor-templates.h", &tensor_basis_kernel_path)); 110*6ca0f394SUmesh Unnikrishnan CeedDebug256(ceed, 2, "----- Loading Tensor Basis Kernel Source -----\n"); 111*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedLoadSourceToBuffer(ceed, tensor_basis_kernel_path, &tensor_basis_code)); 112*6ca0f394SUmesh Unnikrishnan code << tensor_basis_code; 113*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&tensor_basis_kernel_path)); 114*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&tensor_basis_code)); 115*6ca0f394SUmesh Unnikrishnan } 116*6ca0f394SUmesh Unnikrishnan { 117*6ca0f394SUmesh Unnikrishnan char *sycl_gen_template_path, *sycl_gen_template_source; 118*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-gen-templates.h", &sycl_gen_template_path)); 119*6ca0f394SUmesh Unnikrishnan CeedDebug256(ceed, 2, "----- Loading Sycl-Gen Template Source -----\n"); 120*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedLoadSourceToBuffer(ceed, sycl_gen_template_path, &sycl_gen_template_source)); 121*6ca0f394SUmesh Unnikrishnan code << sycl_gen_template_source; 122*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&sycl_gen_template_path)); 123*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&sycl_gen_template_source)); 124*6ca0f394SUmesh Unnikrishnan } 125*6ca0f394SUmesh Unnikrishnan 126*6ca0f394SUmesh Unnikrishnan std::string_view q_function_source(qf_impl->q_function_source); 127*6ca0f394SUmesh Unnikrishnan std::string_view q_function_name(qf_impl->q_function_name); 128*6ca0f394SUmesh Unnikrishnan const std::string operator_name = "CeedKernelSyclGenOperator_" + std::string(q_function_name); 129*6ca0f394SUmesh Unnikrishnan 130*6ca0f394SUmesh Unnikrishnan // Find dim, P_1d, Q_1d 131*6ca0f394SUmesh Unnikrishnan impl->max_P_1d = 0; 132*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 133*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 134*6ca0f394SUmesh Unnikrishnan if (basis != CEED_BASIS_COLLOCATED) { 135*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl)); 136*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 137*6ca0f394SUmesh Unnikrishnan 138*6ca0f394SUmesh Unnikrishnan // Collect dim, P_1d, and Q_1d 139*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetDimension(basis, &dim)); 140*6ca0f394SUmesh Unnikrishnan bool isTensor; 141*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisIsTensor(basis, &isTensor)); 142*6ca0f394SUmesh Unnikrishnan if (isTensor) { 143*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d)); 144*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d)); 145*6ca0f394SUmesh Unnikrishnan if (P_1d > impl->max_P_1d) impl->max_P_1d = P_1d; 146*6ca0f394SUmesh Unnikrishnan } else { 147*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_START 148*6ca0f394SUmesh Unnikrishnan return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement operators with non-tensor basis"); 149*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_STOP 150*6ca0f394SUmesh Unnikrishnan } 151*6ca0f394SUmesh Unnikrishnan } 152*6ca0f394SUmesh Unnikrishnan } 153*6ca0f394SUmesh Unnikrishnan // Check output bases for Q_1d, dim as well 154*6ca0f394SUmesh Unnikrishnan // The only input basis might be CEED_BASIS_COLLOCATED 155*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 156*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 157*6ca0f394SUmesh Unnikrishnan 158*6ca0f394SUmesh Unnikrishnan if (basis != CEED_BASIS_COLLOCATED) { 159*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl)); 160*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 161*6ca0f394SUmesh Unnikrishnan 162*6ca0f394SUmesh Unnikrishnan // Collect Q_1d 163*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetDimension(basis, &dim)); 164*6ca0f394SUmesh Unnikrishnan bool isTensor; 165*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisIsTensor(basis, &isTensor)); 166*6ca0f394SUmesh Unnikrishnan if (isTensor) { 167*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d)); 168*6ca0f394SUmesh Unnikrishnan } else { 169*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_START 170*6ca0f394SUmesh Unnikrishnan return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement operators with non-tensor basis"); 171*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_STOP 172*6ca0f394SUmesh Unnikrishnan } 173*6ca0f394SUmesh Unnikrishnan } 174*6ca0f394SUmesh Unnikrishnan } 175*6ca0f394SUmesh Unnikrishnan impl->dim = dim; 176*6ca0f394SUmesh Unnikrishnan impl->Q_1d = Q_1d; 177*6ca0f394SUmesh Unnikrishnan 178*6ca0f394SUmesh Unnikrishnan // Only use 3D collocated gradient parallelization strategy when gradient is computed 179*6ca0f394SUmesh Unnikrishnan // TODO: put in a function? 180*6ca0f394SUmesh Unnikrishnan bool use_collograd_parallelization = false; 181*6ca0f394SUmesh Unnikrishnan if (dim == 3) { 182*6ca0f394SUmesh Unnikrishnan bool was_grad_found = false; 183*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 184*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 185*6ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_GRAD) { 186*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 187*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl)); 188*6ca0f394SUmesh Unnikrishnan use_collograd_parallelization = !!basis_impl->d_collo_grad_1d && (was_grad_found ? use_collograd_parallelization : true); 189*6ca0f394SUmesh Unnikrishnan was_grad_found = true; 190*6ca0f394SUmesh Unnikrishnan } 191*6ca0f394SUmesh Unnikrishnan } 192*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 193*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 194*6ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_GRAD) { 195*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 196*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl)); 197*6ca0f394SUmesh Unnikrishnan use_collograd_parallelization = !!basis_impl->d_collo_grad_1d && (was_grad_found ? use_collograd_parallelization : true); 198*6ca0f394SUmesh Unnikrishnan was_grad_found = true; 199*6ca0f394SUmesh Unnikrishnan } 200*6ca0f394SUmesh Unnikrishnan } 201*6ca0f394SUmesh Unnikrishnan } 202*6ca0f394SUmesh Unnikrishnan 203*6ca0f394SUmesh Unnikrishnan CeedInt block_sizes[3]; 204*6ca0f394SUmesh Unnikrishnan CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes)); 205*6ca0f394SUmesh Unnikrishnan 206*6ca0f394SUmesh Unnikrishnan // Define CEED_Q_VLA 207*6ca0f394SUmesh Unnikrishnan code << "\n#undef CEED_Q_VLA\n"; 208*6ca0f394SUmesh Unnikrishnan if (dim != 3 || use_collograd_parallelization) { 209*6ca0f394SUmesh Unnikrishnan code << "#define CEED_Q_VLA 1\n\n"; 210*6ca0f394SUmesh Unnikrishnan } else { 211*6ca0f394SUmesh Unnikrishnan code << "#define CEED_Q_VLA " << Q_1d << "\n\n"; 212*6ca0f394SUmesh Unnikrishnan } 213*6ca0f394SUmesh Unnikrishnan 214*6ca0f394SUmesh Unnikrishnan // Determine subgroup size based on supported sizes : Default : 16 (if supported) 215*6ca0f394SUmesh Unnikrishnan std::vector allowed_sg_sizes = sycl_data->sycl_device.get_info<sycl::info::device::sub_group_sizes>(); 216*6ca0f394SUmesh Unnikrishnan CeedInt sub_group_size_op = allowed_sg_sizes[allowed_sg_sizes.size() - 1]; 217*6ca0f394SUmesh Unnikrishnan for (const auto &s : allowed_sg_sizes) { 218*6ca0f394SUmesh Unnikrishnan if (s == 16) { 219*6ca0f394SUmesh Unnikrishnan sub_group_size_op = s; 220*6ca0f394SUmesh Unnikrishnan break; 221*6ca0f394SUmesh Unnikrishnan } 222*6ca0f394SUmesh Unnikrishnan } 223*6ca0f394SUmesh Unnikrishnan 224*6ca0f394SUmesh Unnikrishnan code << q_function_source; 225*6ca0f394SUmesh Unnikrishnan 226*6ca0f394SUmesh Unnikrishnan // Kernel function 227*6ca0f394SUmesh Unnikrishnan code << "\n// -----------------------------------------------------------------------------\n"; 228*6ca0f394SUmesh Unnikrishnan code << "__attribute__((reqd_work_group_size(GROUP_SIZE_X, GROUP_SIZE_Y, GROUP_SIZE_Z), intel_reqd_sub_group_size(" << sub_group_size_op << ")))\n"; 229*6ca0f394SUmesh Unnikrishnan code << "kernel void " << operator_name << "("; 230*6ca0f394SUmesh Unnikrishnan code << "const CeedInt num_elem, "; 231*6ca0f394SUmesh Unnikrishnan code << "global void* ctx, "; 232*6ca0f394SUmesh Unnikrishnan code << "global const FieldsInt_Sycl* indices, "; 233*6ca0f394SUmesh Unnikrishnan code << "global Fields_Sycl* fields, "; 234*6ca0f394SUmesh Unnikrishnan code << "global const Fields_Sycl* B, "; 235*6ca0f394SUmesh Unnikrishnan code << "global const Fields_Sycl* G, "; 236*6ca0f394SUmesh Unnikrishnan code << "global const CeedScalar * restrict W"; 237*6ca0f394SUmesh Unnikrishnan code << ") {\n"; 238*6ca0f394SUmesh Unnikrishnan 239*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 240*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 241*6ca0f394SUmesh Unnikrishnan if (eval_mode != CEED_EVAL_WEIGHT) { // Skip CEED_EVAL_WEIGHT 242*6ca0f394SUmesh Unnikrishnan code << " global const CeedScalar* d_u_" << i << " = fields->inputs[" << i << "];\n"; 243*6ca0f394SUmesh Unnikrishnan } 244*6ca0f394SUmesh Unnikrishnan } 245*6ca0f394SUmesh Unnikrishnan 246*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 247*6ca0f394SUmesh Unnikrishnan code << " global CeedScalar* d_v_" << i << " = fields->outputs[" << i << "];\n"; 248*6ca0f394SUmesh Unnikrishnan } 249*6ca0f394SUmesh Unnikrishnan 250*6ca0f394SUmesh Unnikrishnan // TODO: Convert these to defined constants to save on GRF 251*6ca0f394SUmesh Unnikrishnan code << " const CeedInt DIM = " << dim << ";\n"; 252*6ca0f394SUmesh Unnikrishnan code << " const CeedInt Q_1D = " << Q_1d << ";\n"; 253*6ca0f394SUmesh Unnikrishnan 254*6ca0f394SUmesh Unnikrishnan const CeedInt scratch_size = block_sizes[0] * block_sizes[1] * block_sizes[2]; 255*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar scratch[" << scratch_size << "];\n"; 256*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar * elem_scratch = scratch + get_local_id(2) * T_1D" << (dim > 1 ? "*T_1D" : "") << ";\n"; 257*6ca0f394SUmesh Unnikrishnan 258*6ca0f394SUmesh Unnikrishnan code << "\n // -- Input field constants and basis data --\n"; 259*6ca0f394SUmesh Unnikrishnan // Initialize constants, and matrices B and G 260*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 261*6ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n"; 262*6ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp 263*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &Erestrict)); 264*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size)); 265*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 266*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetNumComponents(Erestrict, &num_comp)); 267*6ca0f394SUmesh Unnikrishnan 268*6ca0f394SUmesh Unnikrishnan // Set field constants 269*6ca0f394SUmesh Unnikrishnan if (eval_mode != CEED_EVAL_WEIGHT) { 270*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 271*6ca0f394SUmesh Unnikrishnan if (basis != CEED_BASIS_COLLOCATED) { 272*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d)); 273*6ca0f394SUmesh Unnikrishnan code << " const CeedInt P_in_" << i << " = " << P_1d << ";\n"; 274*6ca0f394SUmesh Unnikrishnan } else { 275*6ca0f394SUmesh Unnikrishnan code << " const CeedInt P_in_" << i << " = " << Q_1d << ";\n"; 276*6ca0f394SUmesh Unnikrishnan } 277*6ca0f394SUmesh Unnikrishnan code << " const CeedInt num_comp_in_" << i << " = " << num_comp << ";\n"; 278*6ca0f394SUmesh Unnikrishnan } 279*6ca0f394SUmesh Unnikrishnan 280*6ca0f394SUmesh Unnikrishnan // Load basis data 281*6ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n"; 282*6ca0f394SUmesh Unnikrishnan switch (eval_mode) { 283*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE: 284*6ca0f394SUmesh Unnikrishnan break; 285*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP: 286*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl)); 287*6ca0f394SUmesh Unnikrishnan h_B.inputs[i] = basis_impl->d_interp_1d; 288*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_B_in_" << i << "[" << P_1d * Q_1d << "];\n"; 289*6ca0f394SUmesh Unnikrishnan code << " loadMatrix(P_in_" << i << "*Q_1D, B->inputs[" << i << "], s_B_in_" << i << ");\n"; 290*6ca0f394SUmesh Unnikrishnan break; 291*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD: 292*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl)); 293*6ca0f394SUmesh Unnikrishnan h_B.inputs[i] = basis_impl->d_interp_1d; 294*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_B_in_" << i << "[" << P_1d * Q_1d << "];\n"; 295*6ca0f394SUmesh Unnikrishnan code << " loadMatrix(P_in_" << i << "*Q_1D, B->inputs[" << i << "], s_B_in_" << i << ");\n"; 296*6ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) { 297*6ca0f394SUmesh Unnikrishnan h_G.inputs[i] = basis_impl->d_collo_grad_1d; 298*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_G_in_" << i << "[" << Q_1d * Q_1d << "];\n"; 299*6ca0f394SUmesh Unnikrishnan code << " loadMatrix(Q_1D*Q_1D, G->inputs[" << i << "], s_G_in_" << i << ");\n"; 300*6ca0f394SUmesh Unnikrishnan } else { 301*6ca0f394SUmesh Unnikrishnan bool has_collo_grad = !!basis_impl->d_collo_grad_1d; 302*6ca0f394SUmesh Unnikrishnan h_G.inputs[i] = has_collo_grad ? basis_impl->d_collo_grad_1d : basis_impl->d_grad_1d; 303*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_G_in_" << i << "[" << Q_1d * (has_collo_grad ? Q_1d : P_1d) << "];\n"; 304*6ca0f394SUmesh Unnikrishnan code << " loadMatrix(" << (has_collo_grad ? "Q_1D" : ("P_in_" + std::to_string(i))) << "*Q_1D, G->inputs[" << i << "], s_G_in_" << i 305*6ca0f394SUmesh Unnikrishnan << ");\n"; 306*6ca0f394SUmesh Unnikrishnan } 307*6ca0f394SUmesh Unnikrishnan break; 308*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT: 309*6ca0f394SUmesh Unnikrishnan break; // No action 310*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV: 311*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 312*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL: 313*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 314*6ca0f394SUmesh Unnikrishnan } 315*6ca0f394SUmesh Unnikrishnan } 316*6ca0f394SUmesh Unnikrishnan 317*6ca0f394SUmesh Unnikrishnan code << "\n // -- Output field constants and basis data --\n"; 318*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 319*6ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n"; 320*6ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp 321*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &Erestrict)); 322*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size)); 323*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 324*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetNumComponents(Erestrict, &num_comp)); 325*6ca0f394SUmesh Unnikrishnan 326*6ca0f394SUmesh Unnikrishnan // Set field constants 327*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 328*6ca0f394SUmesh Unnikrishnan if (basis != CEED_BASIS_COLLOCATED) { 329*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d)); 330*6ca0f394SUmesh Unnikrishnan code << " const CeedInt P_out_" << i << " = " << P_1d << ";\n"; 331*6ca0f394SUmesh Unnikrishnan } else { 332*6ca0f394SUmesh Unnikrishnan code << " const CeedInt P_out_" << i << " = " << Q_1d << ";\n"; 333*6ca0f394SUmesh Unnikrishnan } 334*6ca0f394SUmesh Unnikrishnan code << " const CeedInt num_comp_out_" << i << " = " << num_comp << ";\n"; 335*6ca0f394SUmesh Unnikrishnan 336*6ca0f394SUmesh Unnikrishnan // Load basis data 337*6ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n"; 338*6ca0f394SUmesh Unnikrishnan switch (eval_mode) { 339*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE: 340*6ca0f394SUmesh Unnikrishnan break; // No action 341*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP: 342*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl)); 343*6ca0f394SUmesh Unnikrishnan h_B.outputs[i] = basis_impl->d_interp_1d; 344*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_B_out_" << i << "[" << P_1d * Q_1d << "];\n"; 345*6ca0f394SUmesh Unnikrishnan code << " loadMatrix(P_out_" << i << "*Q_1D, B->outputs[" << i << "], s_B_out_" << i << ");\n"; 346*6ca0f394SUmesh Unnikrishnan break; 347*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD: 348*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl)); 349*6ca0f394SUmesh Unnikrishnan h_B.outputs[i] = basis_impl->d_interp_1d; 350*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_B_out_" << i << "[" << P_1d * Q_1d << "];\n"; 351*6ca0f394SUmesh Unnikrishnan code << " loadMatrix(P_out_" << i << "*Q_1D, B->outputs[" << i << "], s_B_out_" << i << ");\n"; 352*6ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) { 353*6ca0f394SUmesh Unnikrishnan h_G.outputs[i] = basis_impl->d_collo_grad_1d; 354*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_G_out_" << i << "[" << Q_1d * Q_1d << "];\n"; 355*6ca0f394SUmesh Unnikrishnan code << " loadMatrix(Q_1D*Q_1D, G->outputs[" << i << "], s_G_out_" << i << ");\n"; 356*6ca0f394SUmesh Unnikrishnan } else { 357*6ca0f394SUmesh Unnikrishnan bool has_collo_grad = !!basis_impl->d_collo_grad_1d; 358*6ca0f394SUmesh Unnikrishnan h_G.outputs[i] = has_collo_grad ? basis_impl->d_collo_grad_1d : basis_impl->d_grad_1d; 359*6ca0f394SUmesh Unnikrishnan code << " local CeedScalar s_G_out_" << i << "[" << Q_1d * (has_collo_grad ? Q_1d : P_1d) << "];\n"; 360*6ca0f394SUmesh Unnikrishnan code << " loadMatrix(" << (has_collo_grad ? "Q_1D" : ("P_out_" + std::to_string(i))) << "*Q_1D, G->outputs[" << i << "], s_G_out_" << i 361*6ca0f394SUmesh Unnikrishnan << ");\n"; 362*6ca0f394SUmesh Unnikrishnan } 363*6ca0f394SUmesh Unnikrishnan break; 364*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_START 365*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT: { 366*6ca0f394SUmesh Unnikrishnan Ceed ceed; 367*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 368*6ca0f394SUmesh Unnikrishnan return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 369*6ca0f394SUmesh Unnikrishnan break; // Should not occur 370*6ca0f394SUmesh Unnikrishnan } 371*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV: 372*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 373*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL: 374*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 375*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_STOP 376*6ca0f394SUmesh Unnikrishnan } 377*6ca0f394SUmesh Unnikrishnan } 378*6ca0f394SUmesh Unnikrishnan code << "\n // -- Element loop --\n"; 379*6ca0f394SUmesh Unnikrishnan code << " work_group_barrier(CLK_LOCAL_MEM_FENCE);\n"; 380*6ca0f394SUmesh Unnikrishnan code << " {\n"; 381*6ca0f394SUmesh Unnikrishnan // Input basis apply if needed 382*6ca0f394SUmesh Unnikrishnan // Generate the correct eval mode code for each input 383*6ca0f394SUmesh Unnikrishnan code << " // -- Input field restrictions and basis actions --\n"; 384*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 385*6ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n"; 386*6ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp 387*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &Erestrict)); 388*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size)); 389*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 390*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetNumComponents(Erestrict, &num_comp)); 391*6ca0f394SUmesh Unnikrishnan 392*6ca0f394SUmesh Unnikrishnan // Restriction 393*6ca0f394SUmesh Unnikrishnan if (eval_mode != CEED_EVAL_WEIGHT && !((eval_mode == CEED_EVAL_NONE) && use_collograd_parallelization)) { 394*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_u_" << i << "[num_comp_in_" << i << "*P_in_" << i << "];\n"; 395*6ca0f394SUmesh Unnikrishnan 396*6ca0f394SUmesh Unnikrishnan bool is_strided; 397*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionIsStrided(Erestrict, &is_strided)); 398*6ca0f394SUmesh Unnikrishnan if (!is_strided) { 399*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetLVectorSize(Erestrict, &lsize)); 400*6ca0f394SUmesh Unnikrishnan code << " const CeedInt lsize_in_" << i << " = " << lsize << ";\n"; 401*6ca0f394SUmesh Unnikrishnan CeedInt comp_stride; 402*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetCompStride(Erestrict, &comp_stride)); 403*6ca0f394SUmesh Unnikrishnan code << " // CompStride: " << comp_stride << "\n"; 404*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetData(Erestrict, &restr_impl)); 405*6ca0f394SUmesh Unnikrishnan h_indices.inputs[i] = restr_impl->d_ind; 406*6ca0f394SUmesh Unnikrishnan code << " readDofsOffset" << dim << "d(num_comp_in_" << i << ", " << comp_stride << ", P_in_" << i << ", num_elem, indices->inputs[" << i 407*6ca0f394SUmesh Unnikrishnan << "], d_u_" << i << ", r_u_" << i << ");\n"; 408*6ca0f394SUmesh Unnikrishnan } else { 409*6ca0f394SUmesh Unnikrishnan bool has_backend_strides; 410*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionHasBackendStrides(Erestrict, &has_backend_strides)); 411*6ca0f394SUmesh Unnikrishnan CeedInt num_elem; 412*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetNumElements(Erestrict, &num_elem)); 413*6ca0f394SUmesh Unnikrishnan CeedInt strides[3] = {1, elem_size * num_elem, elem_size}; 414*6ca0f394SUmesh Unnikrishnan if (!has_backend_strides) { 415*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetStrides(Erestrict, &strides)); 416*6ca0f394SUmesh Unnikrishnan } 417*6ca0f394SUmesh Unnikrishnan code << " // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n"; 418*6ca0f394SUmesh Unnikrishnan code << " readDofsStrided" << dim << "d(num_comp_in_" << i << ",P_in_" << i << "," << strides[0] << "," << strides[1] << "," << strides[2] 419*6ca0f394SUmesh Unnikrishnan << ", num_elem, d_u_" << i << ", r_u_" << i << ");\n"; 420*6ca0f394SUmesh Unnikrishnan } 421*6ca0f394SUmesh Unnikrishnan } 422*6ca0f394SUmesh Unnikrishnan 423*6ca0f394SUmesh Unnikrishnan // Basis action 424*6ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n"; 425*6ca0f394SUmesh Unnikrishnan switch (eval_mode) { 426*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE: 427*6ca0f394SUmesh Unnikrishnan if (!use_collograd_parallelization) { 428*6ca0f394SUmesh Unnikrishnan code << " private CeedScalar* r_t_" << i << " = r_u_" << i << ";\n"; 429*6ca0f394SUmesh Unnikrishnan } 430*6ca0f394SUmesh Unnikrishnan break; 431*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP: 432*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_t_" << i << "[num_comp_in_" << i << "*Q_1D];\n"; 433*6ca0f394SUmesh Unnikrishnan code << " Interp" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_in_" << i << ", P_in_" << i << ", Q_1D, r_u_" << i << ", s_B_in_" << i 434*6ca0f394SUmesh Unnikrishnan << ", r_t_" << i << ", elem_scratch);\n"; 435*6ca0f394SUmesh Unnikrishnan break; 436*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD: 437*6ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) { 438*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_t_" << i << "[num_comp_in_" << i << "*Q_1D];\n"; 439*6ca0f394SUmesh Unnikrishnan code << " Interp" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_in_" << i << ", P_in_" << i << ", Q_1D, r_u_" << i << ", s_B_in_" 440*6ca0f394SUmesh Unnikrishnan << i << ", r_t_" << i << ", elem_scratch);\n"; 441*6ca0f394SUmesh Unnikrishnan } else { 442*6ca0f394SUmesh Unnikrishnan CeedInt P_1d; 443*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 444*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d)); 445*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_t_" << i << "[num_comp_in_" << i << "*DIM*Q_1D];\n"; 446*6ca0f394SUmesh Unnikrishnan code << " Grad" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d(num_comp_in_" << i 447*6ca0f394SUmesh Unnikrishnan << ", P_in_" << i << ", Q_1D, r_u_" << i << (dim > 1 ? ", s_B_in_" : "") << (dim > 1 ? std::to_string(i) : "") << ", s_G_in_" << i 448*6ca0f394SUmesh Unnikrishnan << ", r_t_" << i << ", elem_scratch);\n"; 449*6ca0f394SUmesh Unnikrishnan } 450*6ca0f394SUmesh Unnikrishnan break; 451*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT: 452*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_t_" << i << "[Q_1D];\n"; 453*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 454*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetData(basis, &basis_impl)); 455*6ca0f394SUmesh Unnikrishnan impl->W = basis_impl->d_q_weight_1d; 456*6ca0f394SUmesh Unnikrishnan code << " Weight" << (dim > 1 ? "Tensor" : "") << dim << "d(Q_1D, W, r_t_" << i << ");\n"; 457*6ca0f394SUmesh Unnikrishnan break; // No action 458*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV: 459*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 460*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL: 461*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 462*6ca0f394SUmesh Unnikrishnan } 463*6ca0f394SUmesh Unnikrishnan } 464*6ca0f394SUmesh Unnikrishnan 465*6ca0f394SUmesh Unnikrishnan // Q function 466*6ca0f394SUmesh Unnikrishnan code << "\n // -- Output field setup --\n"; 467*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 468*6ca0f394SUmesh Unnikrishnan code << "\n // ---- Output field " << i << " ----\n"; 469*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 470*6ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_GRAD) { 471*6ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) { 472*6ca0f394SUmesh Unnikrishnan // Accumulator for gradient slices 473*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*Q_1D];\n"; 474*6ca0f394SUmesh Unnikrishnan code << " for (CeedInt i = 0; i < num_comp_out_" << i << "; i++) {\n"; 475*6ca0f394SUmesh Unnikrishnan code << " for (CeedInt j = 0; j < Q_1D; ++j) {\n"; 476*6ca0f394SUmesh Unnikrishnan code << " r_tt_" << i << "[j + i*Q_1D] = 0.0;\n"; 477*6ca0f394SUmesh Unnikrishnan code << " }\n"; 478*6ca0f394SUmesh Unnikrishnan code << " }\n"; 479*6ca0f394SUmesh Unnikrishnan } else { 480*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*DIM*Q_1D];\n"; 481*6ca0f394SUmesh Unnikrishnan } 482*6ca0f394SUmesh Unnikrishnan } 483*6ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_NONE || eval_mode == CEED_EVAL_INTERP) { 484*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_tt_" << i << "[num_comp_out_" << i << "*Q_1D];\n"; 485*6ca0f394SUmesh Unnikrishnan } 486*6ca0f394SUmesh Unnikrishnan } 487*6ca0f394SUmesh Unnikrishnan // We treat quadrature points per slice in 3d to save registers 488*6ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) { 489*6ca0f394SUmesh Unnikrishnan code << "\n // Note: Using planes of 3D elements\n"; 490*6ca0f394SUmesh Unnikrishnan code << " for (CeedInt q = 0; q < Q_1D; q++) {\n"; 491*6ca0f394SUmesh Unnikrishnan code << " // -- Input fields --\n"; 492*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 493*6ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n"; 494*6ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp 495*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 496*6ca0f394SUmesh Unnikrishnan // Basis action 497*6ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n"; 498*6ca0f394SUmesh Unnikrishnan switch (eval_mode) { 499*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE: 500*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_q_" << i << "[num_comp_in_" << i << "];\n"; 501*6ca0f394SUmesh Unnikrishnan 502*6ca0f394SUmesh Unnikrishnan bool is_strided; 503*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &Erestrict)); 504*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionIsStrided(Erestrict, &is_strided)); 505*6ca0f394SUmesh Unnikrishnan if (!is_strided) { 506*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetLVectorSize(Erestrict, &lsize)); 507*6ca0f394SUmesh Unnikrishnan code << " const CeedInt lsize_in_" << i << " = " << lsize << ";\n"; 508*6ca0f394SUmesh Unnikrishnan CeedInt comp_stride; 509*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetCompStride(Erestrict, &comp_stride)); 510*6ca0f394SUmesh Unnikrishnan code << " // CompStride: " << comp_stride << "\n"; 511*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetData(Erestrict, &restr_impl)); 512*6ca0f394SUmesh Unnikrishnan h_indices.inputs[i] = restr_impl->d_ind; 513*6ca0f394SUmesh Unnikrishnan code << " readSliceQuadsOffset" 514*6ca0f394SUmesh Unnikrishnan << "3d(num_comp_in_" << i << ", " << comp_stride << ", Q_1D, lsize_in_" << i << ", num_elem, q, indices->inputs[" << i << "], d_u_" 515*6ca0f394SUmesh Unnikrishnan << i << ", r_q_" << i << ");\n"; 516*6ca0f394SUmesh Unnikrishnan } else { 517*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size)); 518*6ca0f394SUmesh Unnikrishnan bool has_backend_strides; 519*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionHasBackendStrides(Erestrict, &has_backend_strides)); 520*6ca0f394SUmesh Unnikrishnan CeedInt num_elem; 521*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetNumElements(Erestrict, &num_elem)); 522*6ca0f394SUmesh Unnikrishnan CeedInt strides[3] = {1, elem_size * num_elem, elem_size}; 523*6ca0f394SUmesh Unnikrishnan if (!has_backend_strides) { 524*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetStrides(Erestrict, &strides)); 525*6ca0f394SUmesh Unnikrishnan } 526*6ca0f394SUmesh Unnikrishnan code << " // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n"; 527*6ca0f394SUmesh Unnikrishnan code << " readSliceQuadsStrided" 528*6ca0f394SUmesh Unnikrishnan << "3d(num_comp_in_" << i << ", Q_1D," << strides[0] << ", " << strides[1] << ", " << strides[2] << ", num_elem, q, d_u_" << i 529*6ca0f394SUmesh Unnikrishnan << ", r_q_" << i << ");\n"; 530*6ca0f394SUmesh Unnikrishnan } 531*6ca0f394SUmesh Unnikrishnan break; 532*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP: 533*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_q_" << i << "[num_comp_in_" << i << "];\n"; 534*6ca0f394SUmesh Unnikrishnan code << " for (CeedInt j = 0; j < num_comp_in_" << i << " ; ++j) {\n"; 535*6ca0f394SUmesh Unnikrishnan code << " r_q_" << i << "[j] = r_t_" << i << "[q + j*Q_1D];\n"; 536*6ca0f394SUmesh Unnikrishnan code << " }\n"; 537*6ca0f394SUmesh Unnikrishnan break; 538*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD: 539*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_q_" << i << "[num_comp_in_" << i << "*DIM];\n"; 540*6ca0f394SUmesh Unnikrishnan code << " gradCollo3d(num_comp_in_" << i << ", Q_1D, q, r_t_" << i << ", s_G_in_" << i << ", r_q_" << i << ", elem_scratch);\n"; 541*6ca0f394SUmesh Unnikrishnan break; 542*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT: 543*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_q_" << i << "[1];\n"; 544*6ca0f394SUmesh Unnikrishnan code << " r_q_" << i << "[0] = r_t_" << i << "[q];\n"; 545*6ca0f394SUmesh Unnikrishnan break; // No action 546*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV: 547*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 548*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL: 549*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 550*6ca0f394SUmesh Unnikrishnan } 551*6ca0f394SUmesh Unnikrishnan } 552*6ca0f394SUmesh Unnikrishnan code << "\n // -- Output fields --\n"; 553*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 554*6ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n"; 555*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 556*6ca0f394SUmesh Unnikrishnan // Basis action 557*6ca0f394SUmesh Unnikrishnan switch (eval_mode) { 558*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE: 559*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_qq_" << i << "[num_comp_out_" << i << "];\n"; 560*6ca0f394SUmesh Unnikrishnan break; // No action 561*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP: 562*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_qq_" << i << "[num_comp_out_" << i << "];\n"; 563*6ca0f394SUmesh Unnikrishnan break; 564*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD: 565*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_qq_" << i << "[num_comp_out_" << i << "*DIM];\n"; 566*6ca0f394SUmesh Unnikrishnan break; 567*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT: 568*6ca0f394SUmesh Unnikrishnan break; // Should not occur 569*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV: 570*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 571*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL: 572*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 573*6ca0f394SUmesh Unnikrishnan } 574*6ca0f394SUmesh Unnikrishnan } 575*6ca0f394SUmesh Unnikrishnan } else { 576*6ca0f394SUmesh Unnikrishnan code << "\n // Note: Using full elements\n"; 577*6ca0f394SUmesh Unnikrishnan code << " // -- Input fields --\n"; 578*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 579*6ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n"; 580*6ca0f394SUmesh Unnikrishnan code << " private CeedScalar* r_q_" << i << " = r_t_" << i << ";\n"; 581*6ca0f394SUmesh Unnikrishnan } 582*6ca0f394SUmesh Unnikrishnan code << " // -- Output fields --\n"; 583*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 584*6ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n"; 585*6ca0f394SUmesh Unnikrishnan code << " private CeedScalar* r_qq_" << i << " = r_tt_" << i << ";\n"; 586*6ca0f394SUmesh Unnikrishnan } 587*6ca0f394SUmesh Unnikrishnan } 588*6ca0f394SUmesh Unnikrishnan //-------------------------------------------------- 589*6ca0f394SUmesh Unnikrishnan code << "\n // -- QFunction Inputs and outputs --\n"; 590*6ca0f394SUmesh Unnikrishnan code << " const CeedScalar * in[" << num_input_fields << "];\n"; 591*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 592*6ca0f394SUmesh Unnikrishnan code << " // ---- Input field " << i << " ----\n"; 593*6ca0f394SUmesh Unnikrishnan code << " in[" << i << "] = r_q_" << i << ";\n"; 594*6ca0f394SUmesh Unnikrishnan } 595*6ca0f394SUmesh Unnikrishnan code << " CeedScalar * out[" << num_output_fields << "];\n"; 596*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 597*6ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n"; 598*6ca0f394SUmesh Unnikrishnan code << " out[" << i << "] = r_qq_" << i << ";\n"; 599*6ca0f394SUmesh Unnikrishnan } 600*6ca0f394SUmesh Unnikrishnan 601*6ca0f394SUmesh Unnikrishnan code << "\n // -- Apply QFunction --\n"; 602*6ca0f394SUmesh Unnikrishnan code << " " << q_function_name << "(ctx, "; 603*6ca0f394SUmesh Unnikrishnan if (dim != 3 || use_collograd_parallelization) { 604*6ca0f394SUmesh Unnikrishnan code << "1"; 605*6ca0f394SUmesh Unnikrishnan } else { 606*6ca0f394SUmesh Unnikrishnan code << "Q_1D"; 607*6ca0f394SUmesh Unnikrishnan } 608*6ca0f394SUmesh Unnikrishnan code << ", in, out);\n"; 609*6ca0f394SUmesh Unnikrishnan //-------------------------------------------------- 610*6ca0f394SUmesh Unnikrishnan 611*6ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) { 612*6ca0f394SUmesh Unnikrishnan code << " // -- Output fields --\n"; 613*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 614*6ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n"; 615*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 616*6ca0f394SUmesh Unnikrishnan // Basis action 617*6ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n"; 618*6ca0f394SUmesh Unnikrishnan switch (eval_mode) { 619*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE: 620*6ca0f394SUmesh Unnikrishnan code << " for (CeedInt j = 0; j < num_comp_out_" << i << " ; ++j) {\n"; 621*6ca0f394SUmesh Unnikrishnan code << " r_tt_" << i << "[q + j*Q_1D] = r_qq_" << i << "[j];\n"; 622*6ca0f394SUmesh Unnikrishnan code << " }\n"; 623*6ca0f394SUmesh Unnikrishnan break; // No action 624*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP: 625*6ca0f394SUmesh Unnikrishnan code << " for (CeedInt j = 0; j < num_comp_out_" << i << " ; ++j) {\n"; 626*6ca0f394SUmesh Unnikrishnan code << " r_tt_" << i << "[q + j*Q_1D] = r_qq_" << i << "[j];\n"; 627*6ca0f394SUmesh Unnikrishnan code << " }\n"; 628*6ca0f394SUmesh Unnikrishnan break; 629*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD: 630*6ca0f394SUmesh Unnikrishnan code << " gradColloTranspose3d(num_comp_out_" << i << ",Q_1D, q, r_qq_" << i << ", s_G_out_" << i << ", r_tt_" << i 631*6ca0f394SUmesh Unnikrishnan << ", elem_scratch);\n"; 632*6ca0f394SUmesh Unnikrishnan break; 633*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT: 634*6ca0f394SUmesh Unnikrishnan break; // Should not occur 635*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV: 636*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 637*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL: 638*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 639*6ca0f394SUmesh Unnikrishnan } 640*6ca0f394SUmesh Unnikrishnan } 641*6ca0f394SUmesh Unnikrishnan code << " }\n"; 642*6ca0f394SUmesh Unnikrishnan } 643*6ca0f394SUmesh Unnikrishnan 644*6ca0f394SUmesh Unnikrishnan // Output basis apply if needed 645*6ca0f394SUmesh Unnikrishnan // Generate the correct eval mode code for each output 646*6ca0f394SUmesh Unnikrishnan code << "\n // -- Output field basis action and restrictions --\n"; 647*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 648*6ca0f394SUmesh Unnikrishnan code << " // ---- Output field " << i << " ----\n"; 649*6ca0f394SUmesh Unnikrishnan // Get elem_size, eval_mode, num_comp 650*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &Erestrict)); 651*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elem_size)); 652*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 653*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetNumComponents(Erestrict, &num_comp)); 654*6ca0f394SUmesh Unnikrishnan // Basis action 655*6ca0f394SUmesh Unnikrishnan code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n"; 656*6ca0f394SUmesh Unnikrishnan switch (eval_mode) { 657*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_NONE: 658*6ca0f394SUmesh Unnikrishnan code << " private CeedScalar* r_v_" << i << " = r_tt_" << i << ";\n"; 659*6ca0f394SUmesh Unnikrishnan break; // No action 660*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_INTERP: 661*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_v_" << i << "[num_comp_out_" << i << "*P_out_" << i << "];\n"; 662*6ca0f394SUmesh Unnikrishnan code << " InterpTranspose" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_out_" << i << ",P_out_" << i << ", Q_1D, r_tt_" << i 663*6ca0f394SUmesh Unnikrishnan << ", s_B_out_" << i << ", r_v_" << i << ", elem_scratch);\n"; 664*6ca0f394SUmesh Unnikrishnan break; 665*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_GRAD: 666*6ca0f394SUmesh Unnikrishnan code << " CeedScalar r_v_" << i << "[num_comp_out_" << i << "*P_out_" << i << "];\n"; 667*6ca0f394SUmesh Unnikrishnan if (use_collograd_parallelization) { 668*6ca0f394SUmesh Unnikrishnan code << " InterpTranspose" << (dim > 1 ? "Tensor" : "") << dim << "d(num_comp_out_" << i << ",P_out_" << i << ", Q_1D, r_tt_" << i 669*6ca0f394SUmesh Unnikrishnan << ", s_B_out_" << i << ", r_v_" << i << ", elem_scratch);\n"; 670*6ca0f394SUmesh Unnikrishnan } else { 671*6ca0f394SUmesh Unnikrishnan CeedInt P_1d; 672*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 673*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d)); 674*6ca0f394SUmesh Unnikrishnan code << " GradTranspose" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d(num_comp_out_" << i 675*6ca0f394SUmesh Unnikrishnan << ", P_out_" << i << ", Q_1D, r_tt_" << i << (dim > 1 ? ", s_B_out_" : "") << (dim > 1 ? std::to_string(i) : "") << ", s_G_out_" << i 676*6ca0f394SUmesh Unnikrishnan << ", r_v_" << i << ", elem_scratch);\n"; 677*6ca0f394SUmesh Unnikrishnan } 678*6ca0f394SUmesh Unnikrishnan break; 679*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_START 680*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_WEIGHT: { 681*6ca0f394SUmesh Unnikrishnan Ceed ceed; 682*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 683*6ca0f394SUmesh Unnikrishnan return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 684*6ca0f394SUmesh Unnikrishnan break; // Should not occur 685*6ca0f394SUmesh Unnikrishnan } 686*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_DIV: 687*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 688*6ca0f394SUmesh Unnikrishnan case CEED_EVAL_CURL: 689*6ca0f394SUmesh Unnikrishnan break; // TODO: Not implemented 690*6ca0f394SUmesh Unnikrishnan // LCOV_EXCL_STOP 691*6ca0f394SUmesh Unnikrishnan } 692*6ca0f394SUmesh Unnikrishnan // Restriction 693*6ca0f394SUmesh Unnikrishnan bool is_strided; 694*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionIsStrided(Erestrict, &is_strided)); 695*6ca0f394SUmesh Unnikrishnan if (!is_strided) { 696*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetLVectorSize(Erestrict, &lsize)); 697*6ca0f394SUmesh Unnikrishnan code << " const CeedInt lsize_out_" << i << " = " << lsize << ";\n"; 698*6ca0f394SUmesh Unnikrishnan CeedInt comp_stride; 699*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetCompStride(Erestrict, &comp_stride)); 700*6ca0f394SUmesh Unnikrishnan code << " // CompStride: " << comp_stride << "\n"; 701*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetData(Erestrict, &restr_impl)); 702*6ca0f394SUmesh Unnikrishnan h_indices.outputs[i] = restr_impl->d_ind; 703*6ca0f394SUmesh Unnikrishnan code << " writeDofsOffset" << dim << "d(num_comp_out_" << i << ", " << comp_stride << ", P_out_" << i << ", num_elem, indices->outputs[" << i 704*6ca0f394SUmesh Unnikrishnan << "], r_v_" << i << ", d_v_" << i << ");\n"; 705*6ca0f394SUmesh Unnikrishnan } else { 706*6ca0f394SUmesh Unnikrishnan bool has_backend_strides; 707*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionHasBackendStrides(Erestrict, &has_backend_strides)); 708*6ca0f394SUmesh Unnikrishnan CeedInt num_elem; 709*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetNumElements(Erestrict, &num_elem)); 710*6ca0f394SUmesh Unnikrishnan CeedInt strides[3] = {1, elem_size * num_elem, elem_size}; 711*6ca0f394SUmesh Unnikrishnan if (!has_backend_strides) { 712*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedElemRestrictionGetStrides(Erestrict, &strides)); 713*6ca0f394SUmesh Unnikrishnan } 714*6ca0f394SUmesh Unnikrishnan code << " // Strides: {" << strides[0] << ", " << strides[1] << ", " << strides[2] << "}\n"; 715*6ca0f394SUmesh Unnikrishnan code << " writeDofsStrided" << dim << "d(num_comp_out_" << i << ",P_out_" << i << "," << strides[0] << "," << strides[1] << "," << strides[2] 716*6ca0f394SUmesh Unnikrishnan << ", num_elem, r_v_" << i << ", d_v_" << i << ");\n"; 717*6ca0f394SUmesh Unnikrishnan } 718*6ca0f394SUmesh Unnikrishnan } 719*6ca0f394SUmesh Unnikrishnan 720*6ca0f394SUmesh Unnikrishnan code << " }\n"; 721*6ca0f394SUmesh Unnikrishnan code << "}\n"; 722*6ca0f394SUmesh Unnikrishnan code << "// -----------------------------------------------------------------------------\n\n"; 723*6ca0f394SUmesh Unnikrishnan 724*6ca0f394SUmesh Unnikrishnan // Copy the struct (containing device addresses) from the host to the device 725*6ca0f394SUmesh Unnikrishnan sycl::event copy_B = sycl_data->sycl_queue.copy<Fields_Sycl>(&h_B, impl->B, 1); 726*6ca0f394SUmesh Unnikrishnan sycl::event copy_G = sycl_data->sycl_queue.copy<Fields_Sycl>(&h_G, impl->G, 1); 727*6ca0f394SUmesh Unnikrishnan sycl::event copy_indices = sycl_data->sycl_queue.copy<FieldsInt_Sycl>(&h_indices, impl->indices, 1); 728*6ca0f394SUmesh Unnikrishnan // These copies can happen while the JIT is being done 729*6ca0f394SUmesh Unnikrishnan CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_B, copy_G, copy_indices})); 730*6ca0f394SUmesh Unnikrishnan 731*6ca0f394SUmesh Unnikrishnan // View kernel for debugging 732*6ca0f394SUmesh Unnikrishnan CeedDebug256(ceed, 2, "Generated Operator Kernels:\n"); 733*6ca0f394SUmesh Unnikrishnan CeedDebug(ceed, code.str().c_str()); 734*6ca0f394SUmesh Unnikrishnan 735*6ca0f394SUmesh Unnikrishnan std::map<std::string, CeedInt> jit_constants; 736*6ca0f394SUmesh Unnikrishnan jit_constants["T_1D"] = block_sizes[0]; 737*6ca0f394SUmesh Unnikrishnan jit_constants["GROUP_SIZE_X"] = block_sizes[0]; 738*6ca0f394SUmesh Unnikrishnan jit_constants["GROUP_SIZE_Y"] = block_sizes[1]; 739*6ca0f394SUmesh Unnikrishnan jit_constants["GROUP_SIZE_Z"] = block_sizes[2]; 740*6ca0f394SUmesh Unnikrishnan 741*6ca0f394SUmesh Unnikrishnan // Compile kernel into a kernel bundle 742*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module, jit_constants)); 743*6ca0f394SUmesh Unnikrishnan 744*6ca0f394SUmesh Unnikrishnan // Load kernel function 745*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, operator_name, &impl->op)); 746*6ca0f394SUmesh Unnikrishnan 747*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorSetSetupDone(op)); 748*6ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS; 749*6ca0f394SUmesh Unnikrishnan } 750*6ca0f394SUmesh Unnikrishnan 751*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 752