1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*bd882c8aSJames Wright // 4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5*bd882c8aSJames Wright // 6*bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7*bd882c8aSJames Wright 8*bd882c8aSJames Wright #include <ceed/backend.h> 9*bd882c8aSJames Wright #include <ceed/ceed.h> 10*bd882c8aSJames Wright #include <ceed/jit-tools.h> 11*bd882c8aSJames Wright 12*bd882c8aSJames Wright #include <iostream> 13*bd882c8aSJames Wright #include <sstream> 14*bd882c8aSJames Wright #include <string> 15*bd882c8aSJames Wright #include <string_view> 16*bd882c8aSJames Wright #include <sycl/sycl.hpp> 17*bd882c8aSJames Wright #include <vector> 18*bd882c8aSJames Wright 19*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 20*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 21*bd882c8aSJames Wright 22*bd882c8aSJames Wright #define SUB_GROUP_SIZE_QF 16 23*bd882c8aSJames Wright 24*bd882c8aSJames Wright //------------------------------------------------------------------------------ 25*bd882c8aSJames Wright // Build QFunction kernel 26*bd882c8aSJames Wright // 27*bd882c8aSJames Wright // TODO: Refactor 28*bd882c8aSJames Wright //------------------------------------------------------------------------------ 29*bd882c8aSJames Wright extern "C" int CeedBuildQFunction_Sycl(CeedQFunction qf) { 30*bd882c8aSJames Wright CeedQFunction_Sycl* impl; 31*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetData(qf, (void**)&impl)); 32*bd882c8aSJames Wright // QFunction is built 33*bd882c8aSJames Wright if (impl->QFunction) return CEED_ERROR_SUCCESS; 34*bd882c8aSJames Wright 35*bd882c8aSJames Wright Ceed ceed; 36*bd882c8aSJames Wright CeedQFunctionGetCeed(qf, &ceed); 37*bd882c8aSJames Wright Ceed_Sycl* data; 38*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 39*bd882c8aSJames Wright 40*bd882c8aSJames Wright // QFunction kernel generation 41*bd882c8aSJames Wright CeedInt num_input_fields, num_output_fields; 42*bd882c8aSJames Wright CeedQFunctionField *input_fields, *output_fields; 43*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetFields(qf, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 44*bd882c8aSJames Wright 45*bd882c8aSJames Wright std::vector<CeedInt> input_sizes(num_input_fields); 46*bd882c8aSJames Wright CeedQFunctionField* input_i = input_fields; 47*bd882c8aSJames Wright for (auto& size_i : input_sizes) { 48*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionFieldGetSize(*input_i, &size_i)); 49*bd882c8aSJames Wright ++input_i; 50*bd882c8aSJames Wright } 51*bd882c8aSJames Wright 52*bd882c8aSJames Wright std::vector<CeedInt> output_sizes(num_output_fields); 53*bd882c8aSJames Wright CeedQFunctionField* output_i = output_fields; 54*bd882c8aSJames Wright for (auto& size_i : output_sizes) { 55*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionFieldGetSize(*output_i, &size_i)); 56*bd882c8aSJames Wright ++output_i; 57*bd882c8aSJames Wright } 58*bd882c8aSJames Wright 59*bd882c8aSJames Wright char* qfunction_name; 60*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetKernelName(qf, &qfunction_name)); 61*bd882c8aSJames Wright 62*bd882c8aSJames Wright char* qfunction_source; 63*bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading QFunction User Source -----\n"); 64*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &qfunction_source)); 65*bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading QFunction User Source Complete! -----\n"); 66*bd882c8aSJames Wright 67*bd882c8aSJames Wright char* read_write_kernel_path; 68*bd882c8aSJames Wright CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-ref-qfunction.h", &read_write_kernel_path)); 69*bd882c8aSJames Wright 70*bd882c8aSJames Wright char* read_write_kernel_source; 71*bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading QFunction Read/Write Kernel Source -----\n"); 72*bd882c8aSJames Wright CeedCallBackend(CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &read_write_kernel_source)); 73*bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n"); 74*bd882c8aSJames Wright 75*bd882c8aSJames Wright std::string_view qf_name_view(qfunction_name); 76*bd882c8aSJames Wright std::string_view qf_source_view(qfunction_source); 77*bd882c8aSJames Wright std::string_view rw_source_view(read_write_kernel_source); 78*bd882c8aSJames Wright const std::string kernel_name = "CeedKernelSyclRefQFunction_" + std::string(qf_name_view); 79*bd882c8aSJames Wright 80*bd882c8aSJames Wright // Defintions 81*bd882c8aSJames Wright std::ostringstream code; 82*bd882c8aSJames Wright code << rw_source_view; 83*bd882c8aSJames Wright code << qf_source_view; 84*bd882c8aSJames Wright code << "\n"; 85*bd882c8aSJames Wright 86*bd882c8aSJames Wright // Kernel function 87*bd882c8aSJames Wright // Here we are fixing a lower sub-group size value to avoid register spills 88*bd882c8aSJames Wright // This needs to be revisited if all qfunctions require this. 89*bd882c8aSJames Wright code << "__attribute__((intel_reqd_sub_group_size(" << SUB_GROUP_SIZE_QF << "))) __kernel void " << kernel_name 90*bd882c8aSJames Wright << "(__global void *ctx, CeedInt Q,\n"; 91*bd882c8aSJames Wright 92*bd882c8aSJames Wright // OpenCL doesn't allow for structs with pointers. 93*bd882c8aSJames Wright // We will need to pass all of the arguments individually. 94*bd882c8aSJames Wright // Input parameters 95*bd882c8aSJames Wright for (CeedInt i = 0; i < num_input_fields; ++i) { 96*bd882c8aSJames Wright code << " " 97*bd882c8aSJames Wright << "__global const CeedScalar *in_" << i << ",\n"; 98*bd882c8aSJames Wright } 99*bd882c8aSJames Wright 100*bd882c8aSJames Wright // Output parameters 101*bd882c8aSJames Wright code << " " 102*bd882c8aSJames Wright << "__global CeedScalar *out_0"; 103*bd882c8aSJames Wright for (CeedInt i = 1; i < num_output_fields; ++i) { 104*bd882c8aSJames Wright code << "\n, " 105*bd882c8aSJames Wright << "__global CeedScalar *out_" << i; 106*bd882c8aSJames Wright } 107*bd882c8aSJames Wright // Begin kernel function body 108*bd882c8aSJames Wright code << ") {\n\n"; 109*bd882c8aSJames Wright 110*bd882c8aSJames Wright // Inputs 111*bd882c8aSJames Wright code << " // Input fields\n"; 112*bd882c8aSJames Wright for (CeedInt i = 0; i < num_input_fields; ++i) { 113*bd882c8aSJames Wright code << " CeedScalar U_" << i << "[" << input_sizes[i] << "];\n"; 114*bd882c8aSJames Wright } 115*bd882c8aSJames Wright code << " const CeedScalar *inputs[" << num_input_fields << "] = {U_0"; 116*bd882c8aSJames Wright for (CeedInt i = 1; i < num_input_fields; i++) { 117*bd882c8aSJames Wright code << ", U_" << i << "\n"; 118*bd882c8aSJames Wright } 119*bd882c8aSJames Wright code << "};\n\n"; 120*bd882c8aSJames Wright 121*bd882c8aSJames Wright // Outputs 122*bd882c8aSJames Wright code << " // Output fields\n"; 123*bd882c8aSJames Wright for (CeedInt i = 0; i < num_output_fields; i++) { 124*bd882c8aSJames Wright code << " CeedScalar V_" << i << "[" << output_sizes[i] << "];\n"; 125*bd882c8aSJames Wright } 126*bd882c8aSJames Wright code << " CeedScalar *outputs[" << num_output_fields << "] = {V_0"; 127*bd882c8aSJames Wright for (CeedInt i = 1; i < num_output_fields; i++) { 128*bd882c8aSJames Wright code << ", V_" << i << "\n"; 129*bd882c8aSJames Wright } 130*bd882c8aSJames Wright code << "};\n\n"; 131*bd882c8aSJames Wright 132*bd882c8aSJames Wright code << " const CeedInt q = get_global_linear_id();\n\n"; 133*bd882c8aSJames Wright 134*bd882c8aSJames Wright code << "if(q < Q){ \n\n"; 135*bd882c8aSJames Wright 136*bd882c8aSJames Wright // Load inputs 137*bd882c8aSJames Wright code << " // -- Load inputs\n"; 138*bd882c8aSJames Wright for (CeedInt i = 0; i < num_input_fields; i++) { 139*bd882c8aSJames Wright code << " readQuads(" << input_sizes[i] << ", Q, q, " 140*bd882c8aSJames Wright << "in_" << i << ", U_" << i << ");\n"; 141*bd882c8aSJames Wright } 142*bd882c8aSJames Wright code << "\n"; 143*bd882c8aSJames Wright 144*bd882c8aSJames Wright // QFunction 145*bd882c8aSJames Wright code << " // -- Call QFunction\n"; 146*bd882c8aSJames Wright code << " " << qf_name_view << "(ctx, 1, inputs, outputs);\n\n"; 147*bd882c8aSJames Wright 148*bd882c8aSJames Wright // Write outputs 149*bd882c8aSJames Wright code << " // -- Write outputs\n"; 150*bd882c8aSJames Wright for (CeedInt i = 0; i < num_output_fields; i++) { 151*bd882c8aSJames Wright code << " writeQuads(" << output_sizes[i] << ", Q, q, " 152*bd882c8aSJames Wright << "V_" << i << ", out_" << i << ");\n"; 153*bd882c8aSJames Wright } 154*bd882c8aSJames Wright code << "\n"; 155*bd882c8aSJames Wright 156*bd882c8aSJames Wright // End kernel function body 157*bd882c8aSJames Wright code << "}\n"; 158*bd882c8aSJames Wright code << "}\n"; 159*bd882c8aSJames Wright 160*bd882c8aSJames Wright // View kernel for debugging 161*bd882c8aSJames Wright CeedDebug256(ceed, 2, "Generated QFunction Kernels:\n"); 162*bd882c8aSJames Wright CeedDebug(ceed, code.str().c_str()); 163*bd882c8aSJames Wright 164*bd882c8aSJames Wright // Compile kernel 165*bd882c8aSJames Wright CeedCallBackend(CeedJitBuildModule_Sycl(ceed, code.str(), &impl->sycl_module)); 166*bd882c8aSJames Wright CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, kernel_name, &impl->QFunction)); 167*bd882c8aSJames Wright 168*bd882c8aSJames Wright // Cleanup 169*bd882c8aSJames Wright CeedCallBackend(CeedFree(&qfunction_source)); 170*bd882c8aSJames Wright CeedCallBackend(CeedFree(&read_write_kernel_path)); 171*bd882c8aSJames Wright CeedCallBackend(CeedFree(&read_write_kernel_source)); 172*bd882c8aSJames Wright 173*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 174*bd882c8aSJames Wright } 175*bd882c8aSJames Wright //------------------------------------------------------------------------------ 176