xref: /libCEED/backends/sycl/ceed-sycl-compile.sycl.cpp (revision 6ca0f394dabdca92269b68ec74be8bebae3befa4)
1bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3bd882c8aSJames Wright //
4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5bd882c8aSJames Wright //
6bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7bd882c8aSJames Wright 
8bd882c8aSJames Wright #include "ceed-sycl-compile.hpp"
9bd882c8aSJames Wright 
10bd882c8aSJames Wright #include <ceed/backend.h>
11bd882c8aSJames Wright #include <ceed/ceed.h>
12bd882c8aSJames Wright #include <ceed/jit-tools.h>
13bd882c8aSJames Wright #include <level_zero/ze_api.h>
14bd882c8aSJames Wright 
15bd882c8aSJames Wright #include <map>
16bd882c8aSJames Wright #include <sstream>
17bd882c8aSJames Wright #include <sycl/sycl.hpp>
18bd882c8aSJames Wright 
19bd882c8aSJames Wright #include "./online_compiler.hpp"
20bd882c8aSJames Wright #include "ceed-sycl-common.hpp"
21bd882c8aSJames Wright 
// Byte buffer holding a SPIR-V (intermediate language) binary: produced by the
// online compiler and passed to zeModuleCreate with ZE_MODULE_FORMAT_IL_SPIRV.
using ByteVector_t = std::vector<unsigned char>;
23bd882c8aSJames Wright 
24bd882c8aSJames Wright //------------------------------------------------------------------------------
25*6ca0f394SUmesh Unnikrishnan // Add defined constants at the beginning of kernel source
26bd882c8aSJames Wright //------------------------------------------------------------------------------
27bd882c8aSJames Wright static int CeedJitAddDefinitions_Sycl(Ceed ceed, const std::string &kernel_source, std::string &jit_source,
28bd882c8aSJames Wright                                       const std::map<std::string, CeedInt> &constants = {}) {
29bd882c8aSJames Wright   std::ostringstream oss;
30bd882c8aSJames Wright 
31bd882c8aSJames Wright   // Prepend defined constants
32bd882c8aSJames Wright   for (const auto &[name, value] : constants) {
33bd882c8aSJames Wright     oss << "#define " << name << " " << value << "\n";
34bd882c8aSJames Wright   }
35bd882c8aSJames Wright 
36bd882c8aSJames Wright   // libCeed definitions for Sycl Backends
37bd882c8aSJames Wright   char       *jit_defs_path, *jit_defs_source;
38bd882c8aSJames Wright   const char *sycl_jith_path = "ceed/jit-source/sycl/sycl-jit.h";
39bd882c8aSJames Wright   CeedCallBackend(CeedGetJitAbsolutePath(ceed, sycl_jith_path, &jit_defs_path));
40bd882c8aSJames Wright   CeedCallBackend(CeedLoadSourceToBuffer(ceed, jit_defs_path, &jit_defs_source));
41bd882c8aSJames Wright 
42bd882c8aSJames Wright   oss << jit_defs_source << "\n";
43bd882c8aSJames Wright 
44bd882c8aSJames Wright   CeedCallBackend(CeedFree(&jit_defs_path));
45bd882c8aSJames Wright   CeedCallBackend(CeedFree(&jit_defs_source));
46bd882c8aSJames Wright 
47bd882c8aSJames Wright   // Append kernel_source
48bd882c8aSJames Wright   oss << "\n" << kernel_source;
49bd882c8aSJames Wright 
50bd882c8aSJames Wright   jit_source = oss.str();
51bd882c8aSJames Wright 
52bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
53bd882c8aSJames Wright }
54bd882c8aSJames Wright 
55bd882c8aSJames Wright //------------------------------------------------------------------------------
56bd882c8aSJames Wright // TODO: Add architecture flags, optimization flags
57bd882c8aSJames Wright //------------------------------------------------------------------------------
58bd882c8aSJames Wright static inline int CeedJitGetFlags_Sycl(std::vector<std::string> &flags) {
59bd882c8aSJames Wright   flags = {std::string("-cl-std=CL3.0"), std::string("-Dint32_t=int")};
60bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
61bd882c8aSJames Wright }
62bd882c8aSJames Wright 
63bd882c8aSJames Wright //------------------------------------------------------------------------------
64bd882c8aSJames Wright // Compile an OpenCL source to SPIR-V using Intel's online compiler extension
65bd882c8aSJames Wright //------------------------------------------------------------------------------
66bd882c8aSJames Wright static inline int CeedJitCompileSource_Sycl(Ceed ceed, const sycl::device &sycl_device, const std::string &opencl_source, ByteVector_t &il_binary,
67bd882c8aSJames Wright                                             const std::vector<std::string> &flags = {}) {
68bd882c8aSJames Wright   sycl::ext::libceed::online_compiler<sycl::ext::libceed::source_language::opencl_c> compiler(sycl_device);
69bd882c8aSJames Wright 
70bd882c8aSJames Wright   try {
71bd882c8aSJames Wright     il_binary = compiler.compile(opencl_source, flags);
72bd882c8aSJames Wright   } catch (sycl::ext::libceed::online_compile_error &e) {
73bd882c8aSJames Wright     return CeedError((ceed), CEED_ERROR_BACKEND, e.what());
74bd882c8aSJames Wright   }
75bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
76bd882c8aSJames Wright }
77bd882c8aSJames Wright 
78bd882c8aSJames Wright // ------------------------------------------------------------------------------
79bd882c8aSJames Wright // Load (compile) SPIR-V source and wrap in sycl kernel_bundle
80bd882c8aSJames Wright // ------------------------------------------------------------------------------
81*6ca0f394SUmesh Unnikrishnan static int CeedLoadModule_Sycl(Ceed ceed, const sycl::context &sycl_context, const sycl::device &sycl_device, const ByteVector_t &il_binary,
82bd882c8aSJames Wright                                SyclModule_t **sycl_module) {
83bd882c8aSJames Wright   auto lz_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
84bd882c8aSJames Wright   auto lz_device  = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
85bd882c8aSJames Wright 
86bd882c8aSJames Wright   ze_module_desc_t lz_mod_desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
87*6ca0f394SUmesh Unnikrishnan                                   nullptr,  // extension specific structs
88bd882c8aSJames Wright                                   ZE_MODULE_FORMAT_IL_SPIRV,
89bd882c8aSJames Wright                                   il_binary.size(),
90bd882c8aSJames Wright                                   il_binary.data(),
91bd882c8aSJames Wright                                   " -ze-opt-large-register-file",  // flags
92*6ca0f394SUmesh Unnikrishnan                                   nullptr};                        // specialization constants
93bd882c8aSJames Wright 
94bd882c8aSJames Wright   ze_module_handle_t           lz_module;
95*6ca0f394SUmesh Unnikrishnan   ze_module_build_log_handle_t lz_log;
96*6ca0f394SUmesh Unnikrishnan   ze_result_t                  lz_err = zeModuleCreate(lz_context, lz_device, &lz_mod_desc, &lz_module, &lz_log);
97*6ca0f394SUmesh Unnikrishnan 
98*6ca0f394SUmesh Unnikrishnan   if (ZE_RESULT_SUCCESS != lz_err) {
99*6ca0f394SUmesh Unnikrishnan     size_t log_size = 0;
100*6ca0f394SUmesh Unnikrishnan     zeModuleBuildLogGetString(lz_log, &log_size, nullptr);
101*6ca0f394SUmesh Unnikrishnan 
102*6ca0f394SUmesh Unnikrishnan     char *log_message;
103*6ca0f394SUmesh Unnikrishnan     CeedCall(CeedCalloc(log_size, &log_message));
104*6ca0f394SUmesh Unnikrishnan     zeModuleBuildLogGetString(lz_log, &log_size, log_message);
105*6ca0f394SUmesh Unnikrishnan 
106*6ca0f394SUmesh Unnikrishnan     return CeedError(ceed, CEED_ERROR_BACKEND, "Failed to compile Level Zero module:\n%s", log_message);
107*6ca0f394SUmesh Unnikrishnan   }
108bd882c8aSJames Wright 
109bd882c8aSJames Wright   // sycl make_<type> only throws errors for backend mismatch--assume we have vetted this already
110bd882c8aSJames Wright   *sycl_module = new SyclModule_t(sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero, sycl::bundle_state::executable>(
111bd882c8aSJames Wright       {lz_module, sycl::ext::oneapi::level_zero::ownership::transfer}, sycl_context));
112bd882c8aSJames Wright 
113bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
114bd882c8aSJames Wright }
115bd882c8aSJames Wright 
116bd882c8aSJames Wright // ------------------------------------------------------------------------------
117bd882c8aSJames Wright // Compile kernel source to an executable `sycl::kernel_bundle`
118bd882c8aSJames Wright // ------------------------------------------------------------------------------
119eb7e6cafSJeremy L Thompson int CeedBuildModule_Sycl(Ceed ceed, const std::string &kernel_source, SyclModule_t **sycl_module, const std::map<std::string, CeedInt> &constants) {
120bd882c8aSJames Wright   Ceed_Sycl *data;
121bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
122bd882c8aSJames Wright 
123bd882c8aSJames Wright   std::string jit_source;
124bd882c8aSJames Wright   CeedCallBackend(CeedJitAddDefinitions_Sycl(ceed, kernel_source, jit_source, constants));
125bd882c8aSJames Wright 
126bd882c8aSJames Wright   std::vector<std::string> flags;
127bd882c8aSJames Wright   CeedCallBackend(CeedJitGetFlags_Sycl(flags));
128bd882c8aSJames Wright 
129bd882c8aSJames Wright   ByteVector_t il_binary;
130bd882c8aSJames Wright   CeedCallBackend(CeedJitCompileSource_Sycl(ceed, data->sycl_device, jit_source, il_binary, flags));
131bd882c8aSJames Wright 
132*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedLoadModule_Sycl(ceed, data->sycl_context, data->sycl_device, il_binary, sycl_module));
133bd882c8aSJames Wright 
134bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
135bd882c8aSJames Wright }
136bd882c8aSJames Wright 
137bd882c8aSJames Wright // ------------------------------------------------------------------------------
138bd882c8aSJames Wright // Get a sycl kernel from an existing kernel_bundle
139bd882c8aSJames Wright //
140bd882c8aSJames Wright // TODO: Error handle lz calls
141bd882c8aSJames Wright // ------------------------------------------------------------------------------
142eb7e6cafSJeremy L Thompson int CeedGetKernel_Sycl(Ceed ceed, const SyclModule_t *sycl_module, const std::string &kernel_name, sycl::kernel **sycl_kernel) {
143bd882c8aSJames Wright   Ceed_Sycl *data;
144bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
145bd882c8aSJames Wright 
146bd882c8aSJames Wright   // sycl::get_native returns std::vector<ze_module_handle_t> for lz backend
147bd882c8aSJames Wright   // https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/supported/sycl_ext_oneapi_backend_level_zero.md
148bd882c8aSJames Wright   ze_module_handle_t lz_module = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(*sycl_module).front();
149bd882c8aSJames Wright 
150bd882c8aSJames Wright   ze_kernel_desc_t   lz_kernel_desc = {ZE_STRUCTURE_TYPE_KERNEL_DESC, nullptr, 0, kernel_name.c_str()};
151bd882c8aSJames Wright   ze_kernel_handle_t lz_kernel;
152*6ca0f394SUmesh Unnikrishnan   ze_result_t        lz_err = zeKernelCreate(lz_module, &lz_kernel_desc, &lz_kernel);
153*6ca0f394SUmesh Unnikrishnan 
154*6ca0f394SUmesh Unnikrishnan   if (ZE_RESULT_SUCCESS != lz_err) {
155*6ca0f394SUmesh Unnikrishnan     return CeedError(ceed, CEED_ERROR_BACKEND, "Failed to retrieve kernel from Level Zero module");
156*6ca0f394SUmesh Unnikrishnan   }
157bd882c8aSJames Wright 
158bd882c8aSJames Wright   *sycl_kernel = new sycl::kernel(sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
159bd882c8aSJames Wright       {*sycl_module, lz_kernel, sycl::ext::oneapi::level_zero::ownership::transfer}, data->sycl_context));
160bd882c8aSJames Wright 
161bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
162bd882c8aSJames Wright }
163*6ca0f394SUmesh Unnikrishnan 
164*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
165*6ca0f394SUmesh Unnikrishnan // Run SYCL kernel for spatial dimension with shared memory
166*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
167*6ca0f394SUmesh Unnikrishnan int CeedRunKernelDimSharedSycl(Ceed ceed, sycl::kernel *kernel, const int grid_size, const int block_size_x, const int block_size_y,
168*6ca0f394SUmesh Unnikrishnan                                const int block_size_z, const int shared_mem_size, void **kernel_args) {
169*6ca0f394SUmesh Unnikrishnan   sycl::range<3>    local_range(block_size_z, block_size_y, block_size_x);
170*6ca0f394SUmesh Unnikrishnan   sycl::range<3>    global_range(grid_size * block_size_z, block_size_y, block_size_x);
171*6ca0f394SUmesh Unnikrishnan   sycl::nd_range<3> kernel_range(global_range, local_range);
172*6ca0f394SUmesh Unnikrishnan 
173*6ca0f394SUmesh Unnikrishnan   //-----------
174*6ca0f394SUmesh Unnikrishnan   // Order queue
175*6ca0f394SUmesh Unnikrishnan   Ceed_Sycl *ceed_Sycl;
176*6ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
177*6ca0f394SUmesh Unnikrishnan   sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
178*6ca0f394SUmesh Unnikrishnan 
179*6ca0f394SUmesh Unnikrishnan   ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
180*6ca0f394SUmesh Unnikrishnan     cgh.depends_on(e);
181*6ca0f394SUmesh Unnikrishnan     cgh.set_args(*kernel_args);
182*6ca0f394SUmesh Unnikrishnan     cgh.parallel_for(kernel_range, *kernel);
183*6ca0f394SUmesh Unnikrishnan   });
184*6ca0f394SUmesh Unnikrishnan 
185*6ca0f394SUmesh Unnikrishnan   return CEED_ERROR_SUCCESS;
186*6ca0f394SUmesh Unnikrishnan }
187