xref: /libCEED/backends/sycl/online_compiler.hpp (revision 9330daecb0fc008043eec1b94c46ef7aecbb00cd)
1 //===------- online_compiler.hpp - Online source compilation service ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/sycl.hpp>
12 
13 #include <string>
14 #include <vector>
15 
16 namespace sycl {
17 namespace ext::libceed {
18 
/// Element type of the compiled binary blobs returned by online_compiler::compile.
using byte = unsigned char;
20 
/// Output formats that online_compiler can produce.
enum class compiled_code_format {
  /// SPIR-V; currently the only supported format.
  spir_v = 0,
};
24 
/// Identifies the target device architecture for online compilation.
///
/// A thin wrapper around an int. Values come either from `any` or from one
/// of the nested enumerations below (gpu, cpu, fpga).
/// NOTE(review): enumerators of different device kinds overlap numerically
/// (gpu_any == cpu_any == fpga_any == 1), so the stored value is only
/// meaningful together with a separately chosen device type.
class device_arch {
 public:
  /// Matches any architecture.
  static constexpr int any = 0;

  /// Wraps the raw architecture value \p Val. Deliberately non-explicit so
  /// that `any` and the enumerators below convert implicitly when passed to
  /// a parameter of type device_arch.
  constexpr device_arch(int Val) : Val(Val) {}

  // TODO1: the list must be extended with a bunch of new GPUs available.
  // TODO2: the list of supported GPUs grows rapidly.
  // The API must allow user to define the target GPU option even if it is
  // not listed in this enumerator below.
  /// GPU architectures; several code names alias the same generation value.
  enum gpu {
    gpu_any    = 1,
    gpu_gen9   = 2,
    gpu_skl    = gpu_gen9,
    gpu_gen9_5 = 3,
    gpu_kbl    = gpu_gen9_5,
    gpu_cfl    = gpu_gen9_5,
    gpu_gen11  = 4,
    gpu_icl    = gpu_gen11,
    gpu_gen12  = 5,
    gpu_tgl    = gpu_gen12,
    gpu_tgllp  = gpu_gen12
  };

  /// CPU architectures.
  enum cpu {
    cpu_any = 1,
  };

  /// FPGA architectures.
  enum fpga {
    fpga_any = 1,
  };

  /// Returns the raw architecture value. Marked const because it only reads
  /// Val; without const the conversion is unavailable on const objects.
  constexpr operator int() const { return Val; }

 private:
  int Val;
};
62 
63 /// Represents an error happend during online compilation.
64 class online_compile_error : public sycl::exception {
65  public:
66   online_compile_error() = default;
67   online_compile_error(const std::string &Msg) : sycl::exception(Msg) {}
68 };
69 
/// Source languages accepted by online_compiler (used as its template
/// parameter).
enum class source_language {
  opencl_c = 0,  ///< OpenCL C source.
  cm       = 1,  ///< CM source.
};
72 
73 /// Represents an online compiler for the language given as template
74 /// parameter.
75 template <source_language Lang>
76 class online_compiler {
77  public:
78   /// Constructs online compiler which can target any device and produces
79   /// given compiled code format. Produces 64-bit device code.
80   /// The created compiler is "optimistic" - it assumes all applicable SYCL
81   /// device capabilities are supported by the target device(s).
82   online_compiler(compiled_code_format fmt = compiled_code_format::spir_v)
83       : OutputFormat(fmt),
84         OutputFormatVersion({0, 0}),
85         DeviceType(sycl::info::device_type::all),
86         DeviceArch(device_arch::any),
87         Is64Bit(true),
88         DeviceStepping("") {}
89 
90   /// Constructs online compiler which targets given architecture and produces
91   /// given compiled code format. Produces 64-bit device code.
92   /// Throws online_compile_error if values of constructor arguments are
93   /// contradictory or not supported - e.g. if the source language is not
94   /// supported for given device type.
95   online_compiler(sycl::info::device_type dev_type, device_arch arch, compiled_code_format fmt = compiled_code_format::spir_v)
96       : OutputFormat(fmt), OutputFormatVersion({0, 0}), DeviceType(dev_type), DeviceArch(arch), Is64Bit(true), DeviceStepping("") {}
97 
98   /// Constructs online compiler for the target specified by given SYCL device.
99   // TODO: the initial version generates the generic code (SKL now), need
100   // to do additional device::info calls to determine the device by it's
101   // features.
102   online_compiler(const sycl::device &)
103       : OutputFormat(compiled_code_format::spir_v),
104         OutputFormatVersion({0, 0}),
105         DeviceType(sycl::info::device_type::all),
106         DeviceArch(device_arch::any),
107         Is64Bit(true),
108         DeviceStepping("") {}
109 
110   /// Compiles given in-memory \c Lang source to a binary blob. Blob format,
111   /// other parameters are set in the constructor by the compilation target
112   /// specification parameters.
113   /// Specialization for each language will provide exact signatures, which
114   /// can be different for different languages.
115   /// Throws online_compile_error if compilation is not successful.
116   template <typename... Tys>
117   std::vector<byte> compile(const std::string &src, const Tys &...args);
118 
119   /// Sets the compiled code format of the compilation target and returns *this.
120   online_compiler<Lang> &setOutputFormat(compiled_code_format fmt) {
121     OutputFormat = fmt;
122     return *this;
123   }
124 
125   /// Sets the compiled code format version of the compilation target and
126   /// returns *this.
127   online_compiler<Lang> &setOutputFormatVersion(int major, int minor) {
128     OutputFormatVersion = {major, minor};
129     return *this;
130   }
131 
132   /// Sets the device type of the compilation target and returns *this.
133   online_compiler<Lang> &setTargetDeviceType(sycl::info::device_type type) {
134     DeviceType = type;
135     return *this;
136   }
137 
138   /// Sets the device architecture of the compilation target and returns *this.
139   online_compiler<Lang> &setTargetDeviceArch(device_arch arch) {
140     DeviceArch = arch;
141     return *this;
142   }
143 
144   /// Makes the compilation target 32-bit and returns *this.
145   online_compiler<Lang> &set32bitTarget() {
146     Is64Bit = false;
147     return *this;
148   };
149 
150   /// Makes the compilation target 64-bit and returns *this.
151   online_compiler<Lang> &set64bitTarget() {
152     Is64Bit = true;
153     return *this;
154   };
155 
156   /// Sets implementation-defined target device stepping of the compilation
157   /// target and returns *this.
158   online_compiler<Lang> &setTargetDeviceStepping(const std::string &id) {
159     DeviceStepping = id;
160     return *this;
161   }
162 
163  private:
164   /// Compiled code format.
165   compiled_code_format OutputFormat;
166 
167   /// Compiled code format version - a pair of "major" and "minor" components
168   std::pair<int, int> OutputFormatVersion;
169 
170   /// Target device type
171   sycl::info::device_type DeviceType;
172 
173   /// Target device architecture
174   device_arch DeviceArch;
175 
176   /// Whether the target device architecture is 64-bit
177   bool Is64Bit;
178 
179   /// Target device stepping (implementation defined)
180   std::string DeviceStepping;
181 
182   /// Handles to helper functions used by the implementation.
183   void *CompileToSPIRVHandle   = nullptr;
184   void *FreeSPIRVOutputsHandle = nullptr;
185 };
186 
// Explicit specializations of online_compiler::compile for particular source
// languages and parameter types.
189 
/// Compiles the given OpenCL source. May throw \c online_compile_error.
/// @param src - contents of the source.
/// @param options - compilation options (implementation defined); standard
///   OpenCL JIT compiler options must be supported.
/// @return the compiled binary blob, in the format configured on the
///   online_compiler instance.
/// Only declared here; no definition appears in this header.
template <>
template <>
std::vector<byte> online_compiler<source_language::opencl_c>::compile(const std::string &src, const std::vector<std::string> &options);

// Disabled convenience overload: compiles the given OpenCL source \p src
// without extra options. Written with plain `//` rather than `///` so that
// Doxygen does not merge this text into the next documented declaration.
// NOTE(review): while this stays disabled, compile(src) with no options would
// select the primary compile template, which has no definition in this
// header - presumably a link error; pass an empty options vector instead.
// template <>
// template <>
// std::vector<byte>
// online_compiler<source_language::opencl_c>::compile(const std::string &src) {
//   return compile(src, std::vector<std::string>{});
// }
206 
/// Compiles the given CM source \p src. May throw \c online_compile_error.
/// @param src - contents of the source.
/// @param options - compilation options (implementation defined).
/// @return the compiled binary blob, in the format configured on the
///   online_compiler instance.
/// Only declared here; no definition appears in this header.
template <>
template <>
std::vector<byte> online_compiler<source_language::cm>::compile(const std::string &src, const std::vector<std::string> &options);

// Disabled convenience overload: compiles the given CM source \p src without
// extra options. Written with plain `//` rather than `///` so that Doxygen
// does not treat this text as documentation of an unrelated declaration.
// NOTE(review): while this stays disabled, compile(src) with no options would
// select the primary compile template, which has no definition in this
// header - presumably a link error; pass an empty options vector instead.
// template <>
// template <>
// std::vector<byte> online_compiler<source_language::cm>::compile(const std::string &src) {
//   return compile(src, std::vector<std::string>{});
// }
220 
221 }  // namespace ext::libceed
222 }  // namespace sycl
223