1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3bd882c8aSJames Wright //
4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5bd882c8aSJames Wright //
6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed
7bd882c8aSJames Wright
8bd882c8aSJames Wright #include "ceed-sycl-compile.hpp"
9bd882c8aSJames Wright
10bd882c8aSJames Wright #include <ceed/backend.h>
11bd882c8aSJames Wright #include <ceed/ceed.h>
12bd882c8aSJames Wright #include <ceed/jit-tools.h>
13bd882c8aSJames Wright #include <level_zero/ze_api.h>
14bd882c8aSJames Wright
15bd882c8aSJames Wright #include <map>
16bd882c8aSJames Wright #include <sstream>
17bd882c8aSJames Wright #include <sycl/sycl.hpp>
18bd882c8aSJames Wright
19bd882c8aSJames Wright #include "./online_compiler.hpp"
20bd882c8aSJames Wright #include "ceed-sycl-common.hpp"
21bd882c8aSJames Wright
// Byte buffer holding the SPIR-V intermediate-language binary produced by the online compiler
// and consumed by zeModuleCreate (see CeedJitCompileSource_Sycl / CeedLoadModule_Sycl).
22bd882c8aSJames Wright using ByteVector_t = std::vector<unsigned char>;
23bd882c8aSJames Wright
24bd882c8aSJames Wright //------------------------------------------------------------------------------
256ca0f394SUmesh Unnikrishnan // Add defined constants at the beginning of kernel source
26bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedJitAddDefinitions_Sycl(Ceed ceed,const std::string & kernel_source,std::string & jit_source,const std::map<std::string,CeedInt> & constants={})27bd882c8aSJames Wright static int CeedJitAddDefinitions_Sycl(Ceed ceed, const std::string &kernel_source, std::string &jit_source,
28bd882c8aSJames Wright const std::map<std::string, CeedInt> &constants = {}) {
29bd882c8aSJames Wright std::ostringstream oss;
30bd882c8aSJames Wright
31f8608ea8SJed Brown const char *jit_defs_path, *jit_defs_source;
32dd64fc84SJeremy L Thompson const char *sycl_jith_path = "ceed/jit-source/sycl/sycl-jit.h";
33dd64fc84SJeremy L Thompson
34bd882c8aSJames Wright // Prepend defined constants
35bd882c8aSJames Wright for (const auto &[name, value] : constants) {
36bd882c8aSJames Wright oss << "#define " << name << " " << value << "\n";
37bd882c8aSJames Wright }
38bd882c8aSJames Wright
39bd882c8aSJames Wright // libCeed definitions for Sycl Backends
40bd882c8aSJames Wright CeedCallBackend(CeedGetJitAbsolutePath(ceed, sycl_jith_path, &jit_defs_path));
41f8608ea8SJed Brown {
42f8608ea8SJed Brown char *source;
4322070f95SJeremy L Thompson
44f8608ea8SJed Brown CeedCallBackend(CeedLoadSourceToBuffer(ceed, jit_defs_path, &source));
45f8608ea8SJed Brown jit_defs_source = source;
46f8608ea8SJed Brown }
47bd882c8aSJames Wright
48bd882c8aSJames Wright oss << jit_defs_source << "\n";
49bd882c8aSJames Wright
50bd882c8aSJames Wright CeedCallBackend(CeedFree(&jit_defs_path));
51bd882c8aSJames Wright CeedCallBackend(CeedFree(&jit_defs_source));
52bd882c8aSJames Wright
53bd882c8aSJames Wright // Append kernel_source
54bd882c8aSJames Wright oss << "\n" << kernel_source;
55bd882c8aSJames Wright
56bd882c8aSJames Wright jit_source = oss.str();
57bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
58bd882c8aSJames Wright }
59bd882c8aSJames Wright
60bd882c8aSJames Wright //------------------------------------------------------------------------------
61bd882c8aSJames Wright // TODO: Add architecture flags, optimization flags
62bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedJitGetFlags_Sycl(std::vector<std::string> & flags)63bd882c8aSJames Wright static inline int CeedJitGetFlags_Sycl(std::vector<std::string> &flags) {
64bce4db6fSJames Wright flags = {std::string("-cl-std=CL3.0"), std::string("-Dint32_t=int"), std::string("-DCEED_RUNNING_JIT_PASS=1")};
65bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
66bd882c8aSJames Wright }
67bd882c8aSJames Wright
68bd882c8aSJames Wright //------------------------------------------------------------------------------
69bd882c8aSJames Wright // Compile an OpenCL source to SPIR-V using Intel's online compiler extension
70bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedJitCompileSource_Sycl(Ceed ceed,const sycl::device & sycl_device,const std::string & opencl_source,ByteVector_t & il_binary,const std::vector<std::string> & flags={})71bd882c8aSJames Wright static inline int CeedJitCompileSource_Sycl(Ceed ceed, const sycl::device &sycl_device, const std::string &opencl_source, ByteVector_t &il_binary,
72bd882c8aSJames Wright const std::vector<std::string> &flags = {}) {
73bd882c8aSJames Wright sycl::ext::libceed::online_compiler<sycl::ext::libceed::source_language::opencl_c> compiler(sycl_device);
74bd882c8aSJames Wright
75bd882c8aSJames Wright try {
76bd882c8aSJames Wright il_binary = compiler.compile(opencl_source, flags);
77bd882c8aSJames Wright } catch (sycl::ext::libceed::online_compile_error &e) {
78bd882c8aSJames Wright return CeedError((ceed), CEED_ERROR_BACKEND, e.what());
79bd882c8aSJames Wright }
80bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
81bd882c8aSJames Wright }
82bd882c8aSJames Wright
83bd882c8aSJames Wright // ------------------------------------------------------------------------------
84bd882c8aSJames Wright // Load (compile) SPIR-V source and wrap in sycl kernel_bundle
85bd882c8aSJames Wright // ------------------------------------------------------------------------------
CeedLoadModule_Sycl(Ceed ceed,const sycl::context & sycl_context,const sycl::device & sycl_device,const ByteVector_t & il_binary,SyclModule_t ** sycl_module)866ca0f394SUmesh Unnikrishnan static int CeedLoadModule_Sycl(Ceed ceed, const sycl::context &sycl_context, const sycl::device &sycl_device, const ByteVector_t &il_binary,
87bd882c8aSJames Wright SyclModule_t **sycl_module) {
88bd882c8aSJames Wright auto lz_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
89bd882c8aSJames Wright auto lz_device = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
90bd882c8aSJames Wright
91bd882c8aSJames Wright ze_module_desc_t lz_mod_desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
926ca0f394SUmesh Unnikrishnan nullptr, // extension specific structs
93bd882c8aSJames Wright ZE_MODULE_FORMAT_IL_SPIRV,
94bd882c8aSJames Wright il_binary.size(),
95bd882c8aSJames Wright il_binary.data(),
96bd882c8aSJames Wright " -ze-opt-large-register-file", // flags
976ca0f394SUmesh Unnikrishnan nullptr}; // specialization constants
98bd882c8aSJames Wright
99bd882c8aSJames Wright ze_module_handle_t lz_module;
1006ca0f394SUmesh Unnikrishnan ze_module_build_log_handle_t lz_log;
1016ca0f394SUmesh Unnikrishnan ze_result_t lz_err = zeModuleCreate(lz_context, lz_device, &lz_mod_desc, &lz_module, &lz_log);
1026ca0f394SUmesh Unnikrishnan
1036ca0f394SUmesh Unnikrishnan if (ZE_RESULT_SUCCESS != lz_err) {
1046ca0f394SUmesh Unnikrishnan size_t log_size = 0;
105dd64fc84SJeremy L Thompson char *log_message;
106dd64fc84SJeremy L Thompson
1076ca0f394SUmesh Unnikrishnan zeModuleBuildLogGetString(lz_log, &log_size, nullptr);
1086ca0f394SUmesh Unnikrishnan
1095a5594ffSJeremy L Thompson CeedCallBackend(CeedCalloc(log_size, &log_message));
1106ca0f394SUmesh Unnikrishnan zeModuleBuildLogGetString(lz_log, &log_size, log_message);
1116ca0f394SUmesh Unnikrishnan
1126ca0f394SUmesh Unnikrishnan return CeedError(ceed, CEED_ERROR_BACKEND, "Failed to compile Level Zero module:\n%s", log_message);
1136ca0f394SUmesh Unnikrishnan }
114bd882c8aSJames Wright
115bd882c8aSJames Wright // sycl make_<type> only throws errors for backend mismatch--assume we have vetted this already
116bd882c8aSJames Wright *sycl_module = new SyclModule_t(sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero, sycl::bundle_state::executable>(
117bd882c8aSJames Wright {lz_module, sycl::ext::oneapi::level_zero::ownership::transfer}, sycl_context));
118bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
119bd882c8aSJames Wright }
120bd882c8aSJames Wright
121bd882c8aSJames Wright // ------------------------------------------------------------------------------
122bd882c8aSJames Wright // Compile kernel source to an executable `sycl::kernel_bundle`
123bd882c8aSJames Wright // ------------------------------------------------------------------------------
CeedBuildModule_Sycl(Ceed ceed,const std::string & kernel_source,SyclModule_t ** sycl_module,const std::map<std::string,CeedInt> & constants)124eb7e6cafSJeremy L Thompson int CeedBuildModule_Sycl(Ceed ceed, const std::string &kernel_source, SyclModule_t **sycl_module, const std::map<std::string, CeedInt> &constants) {
125bd882c8aSJames Wright Ceed_Sycl *data;
126bd882c8aSJames Wright std::string jit_source;
127bd882c8aSJames Wright std::vector<std::string> flags;
128bd882c8aSJames Wright ByteVector_t il_binary;
129dd64fc84SJeremy L Thompson
130dd64fc84SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data));
131dd64fc84SJeremy L Thompson CeedCallBackend(CeedJitAddDefinitions_Sycl(ceed, kernel_source, jit_source, constants));
132dd64fc84SJeremy L Thompson CeedCallBackend(CeedJitGetFlags_Sycl(flags));
133bd882c8aSJames Wright CeedCallBackend(CeedJitCompileSource_Sycl(ceed, data->sycl_device, jit_source, il_binary, flags));
1346ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedLoadModule_Sycl(ceed, data->sycl_context, data->sycl_device, il_binary, sycl_module));
135bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
136bd882c8aSJames Wright }
137bd882c8aSJames Wright
138bd882c8aSJames Wright // ------------------------------------------------------------------------------
139bd882c8aSJames Wright // Get a sycl kernel from an existing kernel_bundle
140bd882c8aSJames Wright //
141bd882c8aSJames Wright // TODO: Error handle lz calls
142bd882c8aSJames Wright // ------------------------------------------------------------------------------
CeedGetKernel_Sycl(Ceed ceed,const SyclModule_t * sycl_module,const std::string & kernel_name,sycl::kernel ** sycl_kernel)143eb7e6cafSJeremy L Thompson int CeedGetKernel_Sycl(Ceed ceed, const SyclModule_t *sycl_module, const std::string &kernel_name, sycl::kernel **sycl_kernel) {
144bd882c8aSJames Wright Ceed_Sycl *data;
145dd64fc84SJeremy L Thompson
146bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data));
147bd882c8aSJames Wright
148bd882c8aSJames Wright // sycl::get_native returns std::vector<ze_module_handle_t> for lz backend
149bd882c8aSJames Wright // https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/supported/sycl_ext_oneapi_backend_level_zero.md
150bd882c8aSJames Wright ze_module_handle_t lz_module = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(*sycl_module).front();
151bd882c8aSJames Wright
152bd882c8aSJames Wright ze_kernel_desc_t lz_kernel_desc = {ZE_STRUCTURE_TYPE_KERNEL_DESC, nullptr, 0, kernel_name.c_str()};
153bd882c8aSJames Wright ze_kernel_handle_t lz_kernel;
1546ca0f394SUmesh Unnikrishnan ze_result_t lz_err = zeKernelCreate(lz_module, &lz_kernel_desc, &lz_kernel);
1556ca0f394SUmesh Unnikrishnan
1566ca0f394SUmesh Unnikrishnan if (ZE_RESULT_SUCCESS != lz_err) {
1576ca0f394SUmesh Unnikrishnan return CeedError(ceed, CEED_ERROR_BACKEND, "Failed to retrieve kernel from Level Zero module");
1586ca0f394SUmesh Unnikrishnan }
159bd882c8aSJames Wright
1601a8516d0SJames Wright *sycl_kernel = new sycl::kernel(sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>({*sycl_module, lz_kernel,
1611a8516d0SJames Wright sycl::ext::oneapi::level_zero::ownership::transfer},
1621a8516d0SJames Wright data->sycl_context));
163bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
164bd882c8aSJames Wright }
1656ca0f394SUmesh Unnikrishnan
1666ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1676ca0f394SUmesh Unnikrishnan // Run SYCL kernel for spatial dimension with shared memory
1686ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
CeedRunKernelDimSharedSycl(Ceed ceed,sycl::kernel * kernel,const int grid_size,const int block_size_x,const int block_size_y,const int block_size_z,const int shared_mem_size,void ** kernel_args)1696ca0f394SUmesh Unnikrishnan int CeedRunKernelDimSharedSycl(Ceed ceed, sycl::kernel *kernel, const int grid_size, const int block_size_x, const int block_size_y,
1706ca0f394SUmesh Unnikrishnan const int block_size_z, const int shared_mem_size, void **kernel_args) {
1716ca0f394SUmesh Unnikrishnan sycl::range<3> local_range(block_size_z, block_size_y, block_size_x);
1726ca0f394SUmesh Unnikrishnan sycl::range<3> global_range(grid_size * block_size_z, block_size_y, block_size_x);
1736ca0f394SUmesh Unnikrishnan sycl::nd_range<3> kernel_range(global_range, local_range);
1746ca0f394SUmesh Unnikrishnan
1756ca0f394SUmesh Unnikrishnan //-----------
1766ca0f394SUmesh Unnikrishnan // Order queue
1776ca0f394SUmesh Unnikrishnan Ceed_Sycl *ceed_Sycl;
178dd64fc84SJeremy L Thompson
1796ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
1806ca0f394SUmesh Unnikrishnan sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
1816ca0f394SUmesh Unnikrishnan
1826ca0f394SUmesh Unnikrishnan ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
1836ca0f394SUmesh Unnikrishnan cgh.depends_on(e);
1846ca0f394SUmesh Unnikrishnan cgh.set_args(*kernel_args);
1856ca0f394SUmesh Unnikrishnan cgh.parallel_for(kernel_range, *kernel);
1866ca0f394SUmesh Unnikrishnan });
1876ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS;
1886ca0f394SUmesh Unnikrishnan }
189