1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 #include <ceed/jit-tools.h> 11 12 #include <iostream> 13 #include <sstream> 14 #include <string> 15 #include <string_view> 16 #include <sycl/sycl.hpp> 17 #include <vector> 18 19 #include "../sycl/ceed-sycl-compile.hpp" 20 #include "ceed-sycl-ref.hpp" 21 22 #define SUB_GROUP_SIZE_QF 16 23 24 //------------------------------------------------------------------------------ 25 // Build QFunction kernel 26 // 27 // TODO: Refactor 28 //------------------------------------------------------------------------------ 29 extern "C" int CeedQFunctionBuildKernel_Sycl(CeedQFunction qf) { 30 Ceed ceed; 31 Ceed_Sycl *data; 32 const char *read_write_kernel_path, *read_write_kernel_source; 33 const char *qfunction_name, *qfunction_source; 34 CeedInt num_input_fields, num_output_fields; 35 CeedQFunctionField *input_fields, *output_fields; 36 CeedQFunction_Sycl *impl; 37 38 CeedCallBackend(CeedQFunctionGetData(qf, (void **)&impl)); 39 // QFunction is built 40 if (impl->QFunction) return CEED_ERROR_SUCCESS; 41 42 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 43 CeedCallBackend(CeedGetData(ceed, &data)); 44 45 // QFunction kernel generation 46 CeedCallBackend(CeedQFunctionGetFields(qf, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 47 48 std::vector<CeedInt> input_sizes(num_input_fields); 49 CeedQFunctionField *input_i = input_fields; 50 51 for (auto &size_i : input_sizes) { 52 CeedCallBackend(CeedQFunctionFieldGetSize(*input_i, &size_i)); 53 ++input_i; 54 } 55 56 std::vector<CeedInt> output_sizes(num_output_fields); 57 CeedQFunctionField *output_i = output_fields; 58 59 for (auto &size_i : output_sizes) { 60 CeedCallBackend(CeedQFunctionFieldGetSize(*output_i, &size_i)); 61 ++output_i; 62 } 63 64 CeedCallBackend(CeedQFunctionGetKernelName(qf, &qfunction_name)); 65 66 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source -----\n"); 67 CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &qfunction_source)); 68 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source Complete! -----\n"); 69 70 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-ref-qfunction.h", &read_write_kernel_path)); 71 72 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source -----\n"); 73 { 74 char *source; 75 76 CeedCallBackend(CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &source)); 77 read_write_kernel_source = source; 78 } 79 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n"); 80 81 std::string_view qf_name_view(qfunction_name); 82 std::string_view qf_source_view(qfunction_source); 83 std::string_view rw_source_view(read_write_kernel_source); 84 const std::string kernel_name = "CeedKernelSyclRefQFunction_" + std::string(qf_name_view); 85 86 // Defintions 87 std::ostringstream code; 88 code << rw_source_view; 89 code << qf_source_view; 90 code << "\n"; 91 92 // Kernel function 93 // Here we are fixing a lower sub-group size value to avoid register spills 94 // This needs to be revisited if all qfunctions require this. 95 code << "__attribute__((intel_reqd_sub_group_size(" << SUB_GROUP_SIZE_QF << "))) __kernel void " << kernel_name 96 << "(__global void *ctx, CeedInt Q,\n"; 97 98 // OpenCL doesn't allow for structs with pointers. 99 // We will need to pass all of the arguments individually. 100 // Input parameters 101 for (CeedInt i = 0; i < num_input_fields; ++i) { 102 code << " " 103 << "__global const CeedScalar *in_" << i << ",\n"; 104 } 105 106 // Output parameters 107 code << " " 108 << "__global CeedScalar *out_0"; 109 for (CeedInt i = 1; i < num_output_fields; ++i) { 110 code << "\n, " 111 << "__global CeedScalar *out_" << i; 112 } 113 // Begin kernel function body 114 code << ") {\n\n"; 115 116 // Inputs 117 code << " // Input fields\n"; 118 for (CeedInt i = 0; i < num_input_fields; ++i) { 119 code << " CeedScalar U_" << i << "[" << input_sizes[i] << "];\n"; 120 } 121 code << " const CeedScalar *inputs[" << num_input_fields << "] = {U_0"; 122 for (CeedInt i = 1; i < num_input_fields; i++) { 123 code << ", U_" << i << "\n"; 124 } 125 code << "};\n\n"; 126 127 // Outputs 128 code << " // Output fields\n"; 129 for (CeedInt i = 0; i < num_output_fields; i++) { 130 code << " CeedScalar V_" << i << "[" << output_sizes[i] << "];\n"; 131 } 132 code << " CeedScalar *outputs[" << num_output_fields << "] = {V_0"; 133 for (CeedInt i = 1; i < num_output_fields; i++) { 134 code << ", V_" << i << "\n"; 135 } 136 code << "};\n\n"; 137 138 code << " const CeedInt q = get_global_linear_id();\n\n"; 139 140 code << "if(q < Q){ \n\n"; 141 142 // Load inputs 143 code << " // -- Load inputs\n"; 144 for (CeedInt i = 0; i < num_input_fields; i++) { 145 code << " readQuads(" << input_sizes[i] << ", Q, q, " 146 << "in_" << i << ", U_" << i << ");\n"; 147 } 148 code << "\n"; 149 150 // QFunction 151 code << " // -- Call QFunction\n"; 152 code << " " << qf_name_view << "(ctx, 1, inputs, outputs);\n\n"; 153 154 // Write outputs 155 code << " // -- Write outputs\n"; 156 for (CeedInt i = 0; i < num_output_fields; i++) { 157 code << " writeQuads(" << output_sizes[i] << ", Q, q, " 158 << "V_" << i << ", out_" << i << ");\n"; 159 } 160 code << "\n"; 161 162 // End kernel function body 163 code << "}\n"; 164 code << "}\n"; 165 166 // View kernel for debugging 167 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Generated QFunction Kernels:\n"); 168 CeedDebug(ceed, code.str().c_str()); 169 170 // Compile kernel 171 CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module)); 172 CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, kernel_name, &impl->QFunction)); 173 174 // Cleanup 175 CeedCallBackend(CeedFree(&qfunction_source)); 176 CeedCallBackend(CeedFree(&read_write_kernel_path)); 177 CeedCallBackend(CeedFree(&read_write_kernel_source)); 178 return CEED_ERROR_SUCCESS; 179 } 180 181 //------------------------------------------------------------------------------ 182