xref: /libCEED/backends/sycl/ceed-sycl-compile.sycl.cpp (revision 8e6aa226c2c84e58dd7feb551fd506c4f25986db)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include "ceed-sycl-compile.hpp"
9 
10 #include <ceed/backend.h>
11 #include <ceed/ceed.h>
12 #include <ceed/jit-tools.h>
13 #include <level_zero/ze_api.h>
14 
15 #include <map>
16 #include <sstream>
17 #include <sycl/sycl.hpp>
18 
19 #include "./online_compiler.hpp"
20 #include "ceed-sycl-common.hpp"
21 
22 using ByteVector_t = std::vector<unsigned char>;
23 
24 //------------------------------------------------------------------------------
25 //
26 //------------------------------------------------------------------------------
27 static int CeedJitAddDefinitions_Sycl(Ceed ceed, const std::string &kernel_source, std::string &jit_source,
28                                       const std::map<std::string, CeedInt> &constants = {}) {
29   std::ostringstream oss;
30 
31   // Prepend defined constants
32   for (const auto &[name, value] : constants) {
33     oss << "#define " << name << " " << value << "\n";
34   }
35 
36   // libCeed definitions for Sycl Backends
37   char       *jit_defs_path, *jit_defs_source;
38   const char *sycl_jith_path = "ceed/jit-source/sycl/sycl-jit.h";
39   CeedCallBackend(CeedGetJitAbsolutePath(ceed, sycl_jith_path, &jit_defs_path));
40   CeedCallBackend(CeedLoadSourceToBuffer(ceed, jit_defs_path, &jit_defs_source));
41 
42   oss << jit_defs_source << "\n";
43 
44   CeedCallBackend(CeedFree(&jit_defs_path));
45   CeedCallBackend(CeedFree(&jit_defs_source));
46 
47   // Append kernel_source
48   oss << "\n" << kernel_source;
49 
50   jit_source = oss.str();
51 
52   return CEED_ERROR_SUCCESS;
53 }
54 
55 //------------------------------------------------------------------------------
56 // TODO: Add architecture flags, optimization flags
57 //------------------------------------------------------------------------------
58 static inline int CeedJitGetFlags_Sycl(std::vector<std::string> &flags) {
59   flags = {std::string("-cl-std=CL3.0"), std::string("-Dint32_t=int")};
60   return CEED_ERROR_SUCCESS;
61 }
62 
63 //------------------------------------------------------------------------------
64 // Compile an OpenCL source to SPIR-V using Intel's online compiler extension
65 //------------------------------------------------------------------------------
66 static inline int CeedJitCompileSource_Sycl(Ceed ceed, const sycl::device &sycl_device, const std::string &opencl_source, ByteVector_t &il_binary,
67                                             const std::vector<std::string> &flags = {}) {
68   sycl::ext::libceed::online_compiler<sycl::ext::libceed::source_language::opencl_c> compiler(sycl_device);
69 
70   try {
71     il_binary = compiler.compile(opencl_source, flags);
72   } catch (sycl::ext::libceed::online_compile_error &e) {
73     return CeedError((ceed), CEED_ERROR_BACKEND, e.what());
74   }
75   return CEED_ERROR_SUCCESS;
76 }
77 
78 // ------------------------------------------------------------------------------
79 // Load (compile) SPIR-V source and wrap in sycl kernel_bundle
80 // TODO: determine appropriate flags
81 // TODO: Error handle lz calls
82 // ------------------------------------------------------------------------------
83 static int CeedJitLoadModule_Sycl(const sycl::context &sycl_context, const sycl::device &sycl_device, const ByteVector_t &il_binary,
84                                   SyclModule_t **sycl_module) {
85   auto lz_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
86   auto lz_device  = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
87 
88   ze_module_desc_t lz_mod_desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
89                                   nullptr,
90                                   ZE_MODULE_FORMAT_IL_SPIRV,
91                                   il_binary.size(),
92                                   il_binary.data(),
93                                   " -ze-opt-large-register-file",  // flags
94                                   nullptr};                        // build log
95 
96   ze_module_handle_t lz_module;
97   zeModuleCreate(lz_context, lz_device, &lz_mod_desc, &lz_module, nullptr);
98 
99   // sycl make_<type> only throws errors for backend mismatch--assume we have vetted this already
100   *sycl_module = new SyclModule_t(sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero, sycl::bundle_state::executable>(
101       {lz_module, sycl::ext::oneapi::level_zero::ownership::transfer}, sycl_context));
102 
103   return CEED_ERROR_SUCCESS;
104 }
105 
106 // ------------------------------------------------------------------------------
107 // Compile kernel source to an executable `sycl::kernel_bundle`
108 // ------------------------------------------------------------------------------
109 int CeedBuildModule_Sycl(Ceed ceed, const std::string &kernel_source, SyclModule_t **sycl_module, const std::map<std::string, CeedInt> &constants) {
110   Ceed_Sycl *data;
111   CeedCallBackend(CeedGetData(ceed, &data));
112 
113   std::string jit_source;
114   CeedCallBackend(CeedJitAddDefinitions_Sycl(ceed, kernel_source, jit_source, constants));
115 
116   std::vector<std::string> flags;
117   CeedCallBackend(CeedJitGetFlags_Sycl(flags));
118 
119   ByteVector_t il_binary;
120   CeedCallBackend(CeedJitCompileSource_Sycl(ceed, data->sycl_device, jit_source, il_binary, flags));
121 
122   CeedCallBackend(CeedJitLoadModule_Sycl(data->sycl_context, data->sycl_device, il_binary, sycl_module));
123 
124   return CEED_ERROR_SUCCESS;
125 }
126 
127 // ------------------------------------------------------------------------------
128 // Get a sycl kernel from an existing kernel_bundle
129 //
130 // TODO: Error handle lz calls
131 // ------------------------------------------------------------------------------
132 int CeedGetKernel_Sycl(Ceed ceed, const SyclModule_t *sycl_module, const std::string &kernel_name, sycl::kernel **sycl_kernel) {
133   Ceed_Sycl *data;
134   CeedCallBackend(CeedGetData(ceed, &data));
135 
136   // sycl::get_native returns std::vector<ze_module_handle_t> for lz backend
137   // https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/supported/sycl_ext_oneapi_backend_level_zero.md
138   ze_module_handle_t lz_module = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(*sycl_module).front();
139 
140   ze_kernel_desc_t   lz_kernel_desc = {ZE_STRUCTURE_TYPE_KERNEL_DESC, nullptr, 0, kernel_name.c_str()};
141   ze_kernel_handle_t lz_kernel;
142   zeKernelCreate(lz_module, &lz_kernel_desc, &lz_kernel);
143 
144   *sycl_kernel = new sycl::kernel(sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
145       {*sycl_module, lz_kernel, sycl::ext::oneapi::level_zero::ownership::transfer}, data->sycl_context));
146 
147   return CEED_ERROR_SUCCESS;
148 }
149