1bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3bd882c8aSJames Wright // 4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5bd882c8aSJames Wright // 6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7bd882c8aSJames Wright 8bd882c8aSJames Wright #include <ceed/backend.h> 9bd882c8aSJames Wright #include <ceed/ceed.h> 10bd882c8aSJames Wright #include <ceed/jit-tools.h> 11bd882c8aSJames Wright 12bd882c8aSJames Wright #include <iostream> 13bd882c8aSJames Wright #include <sstream> 14bd882c8aSJames Wright #include <string> 15bd882c8aSJames Wright #include <string_view> 16bd882c8aSJames Wright #include <sycl/sycl.hpp> 17bd882c8aSJames Wright #include <vector> 18bd882c8aSJames Wright 19bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 20bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 21bd882c8aSJames Wright 22bd882c8aSJames Wright #define SUB_GROUP_SIZE_QF 16 23bd882c8aSJames Wright 24bd882c8aSJames Wright //------------------------------------------------------------------------------ 25bd882c8aSJames Wright // Build QFunction kernel 26bd882c8aSJames Wright // 27bd882c8aSJames Wright // TODO: Refactor 28bd882c8aSJames Wright //------------------------------------------------------------------------------ 29*eb7e6cafSJeremy L Thompson extern "C" int CeedQFunctionBuildKernel_Sycl(CeedQFunction qf) { 30bd882c8aSJames Wright CeedQFunction_Sycl* impl; 31bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetData(qf, (void**)&impl)); 32bd882c8aSJames Wright // QFunction is built 33bd882c8aSJames Wright if (impl->QFunction) return CEED_ERROR_SUCCESS; 34bd882c8aSJames Wright 35bd882c8aSJames Wright Ceed ceed; 36bd882c8aSJames Wright CeedQFunctionGetCeed(qf, &ceed); 37bd882c8aSJames Wright Ceed_Sycl* data; 38bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 39bd882c8aSJames Wright 40bd882c8aSJames Wright // QFunction kernel generation 41bd882c8aSJames Wright CeedInt num_input_fields, num_output_fields; 42bd882c8aSJames Wright CeedQFunctionField *input_fields, *output_fields; 43bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetFields(qf, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 44bd882c8aSJames Wright 45bd882c8aSJames Wright std::vector<CeedInt> input_sizes(num_input_fields); 46bd882c8aSJames Wright CeedQFunctionField* input_i = input_fields; 47bd882c8aSJames Wright for (auto& size_i : input_sizes) { 48bd882c8aSJames Wright CeedCallBackend(CeedQFunctionFieldGetSize(*input_i, &size_i)); 49bd882c8aSJames Wright ++input_i; 50bd882c8aSJames Wright } 51bd882c8aSJames Wright 52bd882c8aSJames Wright std::vector<CeedInt> output_sizes(num_output_fields); 53bd882c8aSJames Wright CeedQFunctionField* output_i = output_fields; 54bd882c8aSJames Wright for (auto& size_i : output_sizes) { 55bd882c8aSJames Wright CeedCallBackend(CeedQFunctionFieldGetSize(*output_i, &size_i)); 56bd882c8aSJames Wright ++output_i; 57bd882c8aSJames Wright } 58bd882c8aSJames Wright 59bd882c8aSJames Wright char* qfunction_name; 60bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetKernelName(qf, &qfunction_name)); 61bd882c8aSJames Wright 62bd882c8aSJames Wright char* qfunction_source; 63bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading QFunction User Source -----\n"); 64bd882c8aSJames Wright CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &qfunction_source)); 65bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading QFunction User Source Complete! -----\n"); 66bd882c8aSJames Wright 67bd882c8aSJames Wright char* read_write_kernel_path; 68bd882c8aSJames Wright CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-ref-qfunction.h", &read_write_kernel_path)); 69bd882c8aSJames Wright 70bd882c8aSJames Wright char* read_write_kernel_source; 71bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading QFunction Read/Write Kernel Source -----\n"); 72bd882c8aSJames Wright CeedCallBackend(CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &read_write_kernel_source)); 73bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n"); 74bd882c8aSJames Wright 75bd882c8aSJames Wright std::string_view qf_name_view(qfunction_name); 76bd882c8aSJames Wright std::string_view qf_source_view(qfunction_source); 77bd882c8aSJames Wright std::string_view rw_source_view(read_write_kernel_source); 78bd882c8aSJames Wright const std::string kernel_name = "CeedKernelSyclRefQFunction_" + std::string(qf_name_view); 79bd882c8aSJames Wright 80bd882c8aSJames Wright // Defintions 81bd882c8aSJames Wright std::ostringstream code; 82bd882c8aSJames Wright code << rw_source_view; 83bd882c8aSJames Wright code << qf_source_view; 84bd882c8aSJames Wright code << "\n"; 85bd882c8aSJames Wright 86bd882c8aSJames Wright // Kernel function 87bd882c8aSJames Wright // Here we are fixing a lower sub-group size value to avoid register spills 88bd882c8aSJames Wright // This needs to be revisited if all qfunctions require this. 89bd882c8aSJames Wright code << "__attribute__((intel_reqd_sub_group_size(" << SUB_GROUP_SIZE_QF << "))) __kernel void " << kernel_name 90bd882c8aSJames Wright << "(__global void *ctx, CeedInt Q,\n"; 91bd882c8aSJames Wright 92bd882c8aSJames Wright // OpenCL doesn't allow for structs with pointers. 93bd882c8aSJames Wright // We will need to pass all of the arguments individually. 94bd882c8aSJames Wright // Input parameters 95bd882c8aSJames Wright for (CeedInt i = 0; i < num_input_fields; ++i) { 96bd882c8aSJames Wright code << " " 97bd882c8aSJames Wright << "__global const CeedScalar *in_" << i << ",\n"; 98bd882c8aSJames Wright } 99bd882c8aSJames Wright 100bd882c8aSJames Wright // Output parameters 101bd882c8aSJames Wright code << " " 102bd882c8aSJames Wright << "__global CeedScalar *out_0"; 103bd882c8aSJames Wright for (CeedInt i = 1; i < num_output_fields; ++i) { 104bd882c8aSJames Wright code << "\n, " 105bd882c8aSJames Wright << "__global CeedScalar *out_" << i; 106bd882c8aSJames Wright } 107bd882c8aSJames Wright // Begin kernel function body 108bd882c8aSJames Wright code << ") {\n\n"; 109bd882c8aSJames Wright 110bd882c8aSJames Wright // Inputs 111bd882c8aSJames Wright code << " // Input fields\n"; 112bd882c8aSJames Wright for (CeedInt i = 0; i < num_input_fields; ++i) { 113bd882c8aSJames Wright code << " CeedScalar U_" << i << "[" << input_sizes[i] << "];\n"; 114bd882c8aSJames Wright } 115bd882c8aSJames Wright code << " const CeedScalar *inputs[" << num_input_fields << "] = {U_0"; 116bd882c8aSJames Wright for (CeedInt i = 1; i < num_input_fields; i++) { 117bd882c8aSJames Wright code << ", U_" << i << "\n"; 118bd882c8aSJames Wright } 119bd882c8aSJames Wright code << "};\n\n"; 120bd882c8aSJames Wright 121bd882c8aSJames Wright // Outputs 122bd882c8aSJames Wright code << " // Output fields\n"; 123bd882c8aSJames Wright for (CeedInt i = 0; i < num_output_fields; i++) { 124bd882c8aSJames Wright code << " CeedScalar V_" << i << "[" << output_sizes[i] << "];\n"; 125bd882c8aSJames Wright } 126bd882c8aSJames Wright code << " CeedScalar *outputs[" << num_output_fields << "] = {V_0"; 127bd882c8aSJames Wright for (CeedInt i = 1; i < num_output_fields; i++) { 128bd882c8aSJames Wright code << ", V_" << i << "\n"; 129bd882c8aSJames Wright } 130bd882c8aSJames Wright code << "};\n\n"; 131bd882c8aSJames Wright 132bd882c8aSJames Wright code << " const CeedInt q = get_global_linear_id();\n\n"; 133bd882c8aSJames Wright 134bd882c8aSJames Wright code << "if(q < Q){ \n\n"; 135bd882c8aSJames Wright 136bd882c8aSJames Wright // Load inputs 137bd882c8aSJames Wright code << " // -- Load inputs\n"; 138bd882c8aSJames Wright for (CeedInt i = 0; i < num_input_fields; i++) { 139bd882c8aSJames Wright code << " readQuads(" << input_sizes[i] << ", Q, q, " 140bd882c8aSJames Wright << "in_" << i << ", U_" << i << ");\n"; 141bd882c8aSJames Wright } 142bd882c8aSJames Wright code << "\n"; 143bd882c8aSJames Wright 144bd882c8aSJames Wright // QFunction 145bd882c8aSJames Wright code << " // -- Call QFunction\n"; 146bd882c8aSJames Wright code << " " << qf_name_view << "(ctx, 1, inputs, outputs);\n\n"; 147bd882c8aSJames Wright 148bd882c8aSJames Wright // Write outputs 149bd882c8aSJames Wright code << " // -- Write outputs\n"; 150bd882c8aSJames Wright for (CeedInt i = 0; i < num_output_fields; i++) { 151bd882c8aSJames Wright code << " writeQuads(" << output_sizes[i] << ", Q, q, " 152bd882c8aSJames Wright << "V_" << i << ", out_" << i << ");\n"; 153bd882c8aSJames Wright } 154bd882c8aSJames Wright code << "\n"; 155bd882c8aSJames Wright 156bd882c8aSJames Wright // End kernel function body 157bd882c8aSJames Wright code << "}\n"; 158bd882c8aSJames Wright code << "}\n"; 159bd882c8aSJames Wright 160bd882c8aSJames Wright // View kernel for debugging 161bd882c8aSJames Wright CeedDebug256(ceed, 2, "Generated QFunction Kernels:\n"); 162bd882c8aSJames Wright CeedDebug(ceed, code.str().c_str()); 163bd882c8aSJames Wright 164bd882c8aSJames Wright // Compile kernel 165*eb7e6cafSJeremy L Thompson CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module)); 166*eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, kernel_name, &impl->QFunction)); 167bd882c8aSJames Wright 168bd882c8aSJames Wright // Cleanup 169bd882c8aSJames Wright CeedCallBackend(CeedFree(&qfunction_source)); 170bd882c8aSJames Wright CeedCallBackend(CeedFree(&read_write_kernel_path)); 171bd882c8aSJames Wright CeedCallBackend(CeedFree(&read_write_kernel_source)); 172bd882c8aSJames Wright 173bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 174bd882c8aSJames Wright } 175bd882c8aSJames Wright //------------------------------------------------------------------------------ 176