xref: /libCEED/rust/libceed-sys/c-src/backends/sycl/ceed-sycl-compile.sycl.cpp (revision bd882c8a454763a096666645dc9a6229d5263694)
1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*bd882c8aSJames Wright //
4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*bd882c8aSJames Wright //
6*bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7*bd882c8aSJames Wright 
8*bd882c8aSJames Wright #include "ceed-sycl-compile.hpp"
9*bd882c8aSJames Wright 
10*bd882c8aSJames Wright #include <ceed/backend.h>
11*bd882c8aSJames Wright #include <ceed/ceed.h>
12*bd882c8aSJames Wright #include <ceed/jit-tools.h>
13*bd882c8aSJames Wright #include <level_zero/ze_api.h>
14*bd882c8aSJames Wright 
15*bd882c8aSJames Wright #include <map>
16*bd882c8aSJames Wright #include <sstream>
17*bd882c8aSJames Wright #include <sycl/sycl.hpp>
18*bd882c8aSJames Wright 
19*bd882c8aSJames Wright #include "./online_compiler.hpp"
20*bd882c8aSJames Wright #include "ceed-sycl-common.hpp"
21*bd882c8aSJames Wright 
22*bd882c8aSJames Wright using ByteVector_t = std::vector<unsigned char>;
23*bd882c8aSJames Wright 
24*bd882c8aSJames Wright //------------------------------------------------------------------------------
25*bd882c8aSJames Wright //
26*bd882c8aSJames Wright //------------------------------------------------------------------------------
27*bd882c8aSJames Wright static int CeedJitAddDefinitions_Sycl(Ceed ceed, const std::string &kernel_source, std::string &jit_source,
28*bd882c8aSJames Wright                                       const std::map<std::string, CeedInt> &constants = {}) {
29*bd882c8aSJames Wright   std::ostringstream oss;
30*bd882c8aSJames Wright 
31*bd882c8aSJames Wright   // Prepend defined constants
32*bd882c8aSJames Wright   for (const auto &[name, value] : constants) {
33*bd882c8aSJames Wright     oss << "#define " << name << " " << value << "\n";
34*bd882c8aSJames Wright   }
35*bd882c8aSJames Wright 
36*bd882c8aSJames Wright   // libCeed definitions for Sycl Backends
37*bd882c8aSJames Wright   char       *jit_defs_path, *jit_defs_source;
38*bd882c8aSJames Wright   const char *sycl_jith_path = "ceed/jit-source/sycl/sycl-jit.h";
39*bd882c8aSJames Wright   CeedCallBackend(CeedGetJitAbsolutePath(ceed, sycl_jith_path, &jit_defs_path));
40*bd882c8aSJames Wright   CeedCallBackend(CeedLoadSourceToBuffer(ceed, jit_defs_path, &jit_defs_source));
41*bd882c8aSJames Wright 
42*bd882c8aSJames Wright   oss << jit_defs_source << "\n";
43*bd882c8aSJames Wright 
44*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&jit_defs_path));
45*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&jit_defs_source));
46*bd882c8aSJames Wright 
47*bd882c8aSJames Wright   // Append kernel_source
48*bd882c8aSJames Wright   oss << "\n" << kernel_source;
49*bd882c8aSJames Wright 
50*bd882c8aSJames Wright   jit_source = oss.str();
51*bd882c8aSJames Wright 
52*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
53*bd882c8aSJames Wright }
54*bd882c8aSJames Wright 
55*bd882c8aSJames Wright //------------------------------------------------------------------------------
56*bd882c8aSJames Wright // TODO: Add architecture flags, optimization flags
57*bd882c8aSJames Wright //------------------------------------------------------------------------------
58*bd882c8aSJames Wright static inline int CeedJitGetFlags_Sycl(std::vector<std::string> &flags) {
59*bd882c8aSJames Wright   flags = {std::string("-cl-std=CL3.0"), std::string("-Dint32_t=int")};
60*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
61*bd882c8aSJames Wright }
62*bd882c8aSJames Wright 
63*bd882c8aSJames Wright //------------------------------------------------------------------------------
64*bd882c8aSJames Wright // Compile an OpenCL source to SPIR-V using Intel's online compiler extension
65*bd882c8aSJames Wright //------------------------------------------------------------------------------
66*bd882c8aSJames Wright static inline int CeedJitCompileSource_Sycl(Ceed ceed, const sycl::device &sycl_device, const std::string &opencl_source, ByteVector_t &il_binary,
67*bd882c8aSJames Wright                                             const std::vector<std::string> &flags = {}) {
68*bd882c8aSJames Wright   sycl::ext::libceed::online_compiler<sycl::ext::libceed::source_language::opencl_c> compiler(sycl_device);
69*bd882c8aSJames Wright 
70*bd882c8aSJames Wright   try {
71*bd882c8aSJames Wright     il_binary = compiler.compile(opencl_source, flags);
72*bd882c8aSJames Wright   } catch (sycl::ext::libceed::online_compile_error &e) {
73*bd882c8aSJames Wright     return CeedError((ceed), CEED_ERROR_BACKEND, e.what());
74*bd882c8aSJames Wright   }
75*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
76*bd882c8aSJames Wright }
77*bd882c8aSJames Wright 
78*bd882c8aSJames Wright // ------------------------------------------------------------------------------
79*bd882c8aSJames Wright // Load (compile) SPIR-V source and wrap in sycl kernel_bundle
80*bd882c8aSJames Wright // TODO: determine appropriate flags
81*bd882c8aSJames Wright // TODO: Error handle lz calls
82*bd882c8aSJames Wright // ------------------------------------------------------------------------------
83*bd882c8aSJames Wright static int CeedJitLoadModule_Sycl(const sycl::context &sycl_context, const sycl::device &sycl_device, const ByteVector_t &il_binary,
84*bd882c8aSJames Wright                                   SyclModule_t **sycl_module) {
85*bd882c8aSJames Wright   auto lz_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
86*bd882c8aSJames Wright   auto lz_device  = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
87*bd882c8aSJames Wright 
88*bd882c8aSJames Wright   ze_module_desc_t lz_mod_desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
89*bd882c8aSJames Wright                                   nullptr,
90*bd882c8aSJames Wright                                   ZE_MODULE_FORMAT_IL_SPIRV,
91*bd882c8aSJames Wright                                   il_binary.size(),
92*bd882c8aSJames Wright                                   il_binary.data(),
93*bd882c8aSJames Wright                                   " -ze-opt-large-register-file",  // flags
94*bd882c8aSJames Wright                                   nullptr};                        // build log
95*bd882c8aSJames Wright 
96*bd882c8aSJames Wright   ze_module_handle_t lz_module;
97*bd882c8aSJames Wright   zeModuleCreate(lz_context, lz_device, &lz_mod_desc, &lz_module, nullptr);
98*bd882c8aSJames Wright 
99*bd882c8aSJames Wright   // sycl make_<type> only throws errors for backend mismatch--assume we have vetted this already
100*bd882c8aSJames Wright   *sycl_module = new SyclModule_t(sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero, sycl::bundle_state::executable>(
101*bd882c8aSJames Wright       {lz_module, sycl::ext::oneapi::level_zero::ownership::transfer}, sycl_context));
102*bd882c8aSJames Wright 
103*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
104*bd882c8aSJames Wright }
105*bd882c8aSJames Wright 
106*bd882c8aSJames Wright // ------------------------------------------------------------------------------
107*bd882c8aSJames Wright // Compile kernel source to an executable `sycl::kernel_bundle`
108*bd882c8aSJames Wright // ------------------------------------------------------------------------------
109*bd882c8aSJames Wright int CeedJitBuildModule_Sycl(Ceed ceed, const std::string &kernel_source, SyclModule_t **sycl_module,
110*bd882c8aSJames Wright                             const std::map<std::string, CeedInt> &constants) {
111*bd882c8aSJames Wright   Ceed_Sycl *data;
112*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
113*bd882c8aSJames Wright 
114*bd882c8aSJames Wright   std::string jit_source;
115*bd882c8aSJames Wright   CeedCallBackend(CeedJitAddDefinitions_Sycl(ceed, kernel_source, jit_source, constants));
116*bd882c8aSJames Wright 
117*bd882c8aSJames Wright   std::vector<std::string> flags;
118*bd882c8aSJames Wright   CeedCallBackend(CeedJitGetFlags_Sycl(flags));
119*bd882c8aSJames Wright 
120*bd882c8aSJames Wright   ByteVector_t il_binary;
121*bd882c8aSJames Wright   CeedCallBackend(CeedJitCompileSource_Sycl(ceed, data->sycl_device, jit_source, il_binary, flags));
122*bd882c8aSJames Wright 
123*bd882c8aSJames Wright   CeedCallBackend(CeedJitLoadModule_Sycl(data->sycl_context, data->sycl_device, il_binary, sycl_module));
124*bd882c8aSJames Wright 
125*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
126*bd882c8aSJames Wright }
127*bd882c8aSJames Wright 
128*bd882c8aSJames Wright // ------------------------------------------------------------------------------
129*bd882c8aSJames Wright // Get a sycl kernel from an existing kernel_bundle
130*bd882c8aSJames Wright //
131*bd882c8aSJames Wright // TODO: Error handle lz calls
132*bd882c8aSJames Wright // ------------------------------------------------------------------------------
133*bd882c8aSJames Wright int CeedJitGetKernel_Sycl(Ceed ceed, const SyclModule_t *sycl_module, const std::string &kernel_name, sycl::kernel **sycl_kernel) {
134*bd882c8aSJames Wright   Ceed_Sycl *data;
135*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
136*bd882c8aSJames Wright 
137*bd882c8aSJames Wright   // sycl::get_native returns std::vector<ze_module_handle_t> for lz backend
138*bd882c8aSJames Wright   // https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/supported/sycl_ext_oneapi_backend_level_zero.md
139*bd882c8aSJames Wright   ze_module_handle_t lz_module = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(*sycl_module).front();
140*bd882c8aSJames Wright 
141*bd882c8aSJames Wright   ze_kernel_desc_t   lz_kernel_desc = {ZE_STRUCTURE_TYPE_KERNEL_DESC, nullptr, 0, kernel_name.c_str()};
142*bd882c8aSJames Wright   ze_kernel_handle_t lz_kernel;
143*bd882c8aSJames Wright   zeKernelCreate(lz_module, &lz_kernel_desc, &lz_kernel);
144*bd882c8aSJames Wright 
145*bd882c8aSJames Wright   *sycl_kernel = new sycl::kernel(sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
146*bd882c8aSJames Wright       {*sycl_module, lz_kernel, sycl::ext::oneapi::level_zero::ownership::transfer}, data->sycl_context));
147*bd882c8aSJames Wright 
148*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
149*bd882c8aSJames Wright }
150