1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*bd882c8aSJames Wright // 4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5*bd882c8aSJames Wright // 6*bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7*bd882c8aSJames Wright 8*bd882c8aSJames Wright #include "ceed-sycl-compile.hpp" 9*bd882c8aSJames Wright 10*bd882c8aSJames Wright #include <ceed/backend.h> 11*bd882c8aSJames Wright #include <ceed/ceed.h> 12*bd882c8aSJames Wright #include <ceed/jit-tools.h> 13*bd882c8aSJames Wright #include <level_zero/ze_api.h> 14*bd882c8aSJames Wright 15*bd882c8aSJames Wright #include <map> 16*bd882c8aSJames Wright #include <sstream> 17*bd882c8aSJames Wright #include <sycl/sycl.hpp> 18*bd882c8aSJames Wright 19*bd882c8aSJames Wright #include "./online_compiler.hpp" 20*bd882c8aSJames Wright #include "ceed-sycl-common.hpp" 21*bd882c8aSJames Wright 22*bd882c8aSJames Wright using ByteVector_t = std::vector<unsigned char>; 23*bd882c8aSJames Wright 24*bd882c8aSJames Wright //------------------------------------------------------------------------------ 25*bd882c8aSJames Wright // 26*bd882c8aSJames Wright //------------------------------------------------------------------------------ 27*bd882c8aSJames Wright static int CeedJitAddDefinitions_Sycl(Ceed ceed, const std::string &kernel_source, std::string &jit_source, 28*bd882c8aSJames Wright const std::map<std::string, CeedInt> &constants = {}) { 29*bd882c8aSJames Wright std::ostringstream oss; 30*bd882c8aSJames Wright 31*bd882c8aSJames Wright // Prepend defined constants 32*bd882c8aSJames Wright for (const auto &[name, value] : constants) { 33*bd882c8aSJames Wright oss << "#define " << name << " " << value << "\n"; 34*bd882c8aSJames Wright } 35*bd882c8aSJames Wright 36*bd882c8aSJames Wright // libCeed definitions for Sycl Backends 37*bd882c8aSJames Wright char *jit_defs_path, *jit_defs_source; 38*bd882c8aSJames Wright const char *sycl_jith_path = "ceed/jit-source/sycl/sycl-jit.h"; 39*bd882c8aSJames Wright CeedCallBackend(CeedGetJitAbsolutePath(ceed, sycl_jith_path, &jit_defs_path)); 40*bd882c8aSJames Wright CeedCallBackend(CeedLoadSourceToBuffer(ceed, jit_defs_path, &jit_defs_source)); 41*bd882c8aSJames Wright 42*bd882c8aSJames Wright oss << jit_defs_source << "\n"; 43*bd882c8aSJames Wright 44*bd882c8aSJames Wright CeedCallBackend(CeedFree(&jit_defs_path)); 45*bd882c8aSJames Wright CeedCallBackend(CeedFree(&jit_defs_source)); 46*bd882c8aSJames Wright 47*bd882c8aSJames Wright // Append kernel_source 48*bd882c8aSJames Wright oss << "\n" << kernel_source; 49*bd882c8aSJames Wright 50*bd882c8aSJames Wright jit_source = oss.str(); 51*bd882c8aSJames Wright 52*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 53*bd882c8aSJames Wright } 54*bd882c8aSJames Wright 55*bd882c8aSJames Wright //------------------------------------------------------------------------------ 56*bd882c8aSJames Wright // TODO: Add architecture flags, optimization flags 57*bd882c8aSJames Wright //------------------------------------------------------------------------------ 58*bd882c8aSJames Wright static inline int CeedJitGetFlags_Sycl(std::vector<std::string> &flags) { 59*bd882c8aSJames Wright flags = {std::string("-cl-std=CL3.0"), std::string("-Dint32_t=int")}; 60*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 61*bd882c8aSJames Wright } 62*bd882c8aSJames Wright 63*bd882c8aSJames Wright //------------------------------------------------------------------------------ 64*bd882c8aSJames Wright // Compile an OpenCL source to SPIR-V using Intel's online compiler extension 65*bd882c8aSJames Wright //------------------------------------------------------------------------------ 66*bd882c8aSJames Wright static inline int CeedJitCompileSource_Sycl(Ceed ceed, const sycl::device &sycl_device, const std::string &opencl_source, ByteVector_t &il_binary, 67*bd882c8aSJames Wright const std::vector<std::string> &flags = {}) { 68*bd882c8aSJames Wright sycl::ext::libceed::online_compiler<sycl::ext::libceed::source_language::opencl_c> compiler(sycl_device); 69*bd882c8aSJames Wright 70*bd882c8aSJames Wright try { 71*bd882c8aSJames Wright il_binary = compiler.compile(opencl_source, flags); 72*bd882c8aSJames Wright } catch (sycl::ext::libceed::online_compile_error &e) { 73*bd882c8aSJames Wright return CeedError((ceed), CEED_ERROR_BACKEND, e.what()); 74*bd882c8aSJames Wright } 75*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 76*bd882c8aSJames Wright } 77*bd882c8aSJames Wright 78*bd882c8aSJames Wright // ------------------------------------------------------------------------------ 79*bd882c8aSJames Wright // Load (compile) SPIR-V source and wrap in sycl kernel_bundle 80*bd882c8aSJames Wright // TODO: determine appropriate flags 81*bd882c8aSJames Wright // TODO: Error handle lz calls 82*bd882c8aSJames Wright // ------------------------------------------------------------------------------ 83*bd882c8aSJames Wright static int CeedJitLoadModule_Sycl(const sycl::context &sycl_context, const sycl::device &sycl_device, const ByteVector_t &il_binary, 84*bd882c8aSJames Wright SyclModule_t **sycl_module) { 85*bd882c8aSJames Wright auto lz_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context); 86*bd882c8aSJames Wright auto lz_device = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device); 87*bd882c8aSJames Wright 88*bd882c8aSJames Wright ze_module_desc_t lz_mod_desc = {ZE_STRUCTURE_TYPE_MODULE_DESC, 89*bd882c8aSJames Wright nullptr, 90*bd882c8aSJames Wright ZE_MODULE_FORMAT_IL_SPIRV, 91*bd882c8aSJames Wright il_binary.size(), 92*bd882c8aSJames Wright il_binary.data(), 93*bd882c8aSJames Wright " -ze-opt-large-register-file", // flags 94*bd882c8aSJames Wright nullptr}; // build log 95*bd882c8aSJames Wright 96*bd882c8aSJames Wright ze_module_handle_t lz_module; 97*bd882c8aSJames Wright zeModuleCreate(lz_context, lz_device, &lz_mod_desc, &lz_module, nullptr); 98*bd882c8aSJames Wright 99*bd882c8aSJames Wright // sycl make_<type> only throws errors for backend mismatch--assume we have vetted this already 100*bd882c8aSJames Wright *sycl_module = new SyclModule_t(sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero, sycl::bundle_state::executable>( 101*bd882c8aSJames Wright {lz_module, sycl::ext::oneapi::level_zero::ownership::transfer}, sycl_context)); 102*bd882c8aSJames Wright 103*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 104*bd882c8aSJames Wright } 105*bd882c8aSJames Wright 106*bd882c8aSJames Wright // ------------------------------------------------------------------------------ 107*bd882c8aSJames Wright // Compile kernel source to an executable `sycl::kernel_bundle` 108*bd882c8aSJames Wright // ------------------------------------------------------------------------------ 109*bd882c8aSJames Wright int CeedJitBuildModule_Sycl(Ceed ceed, const std::string &kernel_source, SyclModule_t **sycl_module, 110*bd882c8aSJames Wright const std::map<std::string, CeedInt> &constants) { 111*bd882c8aSJames Wright Ceed_Sycl *data; 112*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 113*bd882c8aSJames Wright 114*bd882c8aSJames Wright std::string jit_source; 115*bd882c8aSJames Wright CeedCallBackend(CeedJitAddDefinitions_Sycl(ceed, kernel_source, jit_source, constants)); 116*bd882c8aSJames Wright 117*bd882c8aSJames Wright std::vector<std::string> flags; 118*bd882c8aSJames Wright CeedCallBackend(CeedJitGetFlags_Sycl(flags)); 119*bd882c8aSJames Wright 120*bd882c8aSJames Wright ByteVector_t il_binary; 121*bd882c8aSJames Wright CeedCallBackend(CeedJitCompileSource_Sycl(ceed, data->sycl_device, jit_source, il_binary, flags)); 122*bd882c8aSJames Wright 123*bd882c8aSJames Wright CeedCallBackend(CeedJitLoadModule_Sycl(data->sycl_context, data->sycl_device, il_binary, sycl_module)); 124*bd882c8aSJames Wright 125*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 126*bd882c8aSJames Wright } 127*bd882c8aSJames Wright 128*bd882c8aSJames Wright // ------------------------------------------------------------------------------ 129*bd882c8aSJames Wright // Get a sycl kernel from an existing kernel_bundle 130*bd882c8aSJames Wright // 131*bd882c8aSJames Wright // TODO: Error handle lz calls 132*bd882c8aSJames Wright // ------------------------------------------------------------------------------ 133*bd882c8aSJames Wright int CeedJitGetKernel_Sycl(Ceed ceed, const SyclModule_t *sycl_module, const std::string &kernel_name, sycl::kernel **sycl_kernel) { 134*bd882c8aSJames Wright Ceed_Sycl *data; 135*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 136*bd882c8aSJames Wright 137*bd882c8aSJames Wright // sycl::get_native returns std::vector<ze_module_handle_t> for lz backend 138*bd882c8aSJames Wright // https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/supported/sycl_ext_oneapi_backend_level_zero.md 139*bd882c8aSJames Wright ze_module_handle_t lz_module = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(*sycl_module).front(); 140*bd882c8aSJames Wright 141*bd882c8aSJames Wright ze_kernel_desc_t lz_kernel_desc = {ZE_STRUCTURE_TYPE_KERNEL_DESC, nullptr, 0, kernel_name.c_str()}; 142*bd882c8aSJames Wright ze_kernel_handle_t lz_kernel; 143*bd882c8aSJames Wright zeKernelCreate(lz_module, &lz_kernel_desc, &lz_kernel); 144*bd882c8aSJames Wright 145*bd882c8aSJames Wright *sycl_kernel = new sycl::kernel(sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>( 146*bd882c8aSJames Wright {*sycl_module, lz_kernel, sycl::ext::oneapi::level_zero::ownership::transfer}, data->sycl_context)); 147*bd882c8aSJames Wright 148*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 149*bd882c8aSJames Wright } 150