1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <ceed/jit-tools.h>
11
12 #include <iostream>
13 #include <sstream>
14 #include <string>
15 #include <string_view>
16 #include <sycl/sycl.hpp>
17 #include <vector>
18
19 #include "../sycl/ceed-sycl-compile.hpp"
20 #include "ceed-sycl-ref.hpp"
21
// Sub-group size written into the intel_reqd_sub_group_size attribute of the generated QFunction kernel
#define SUB_GROUP_SIZE_QF 16
23
24 //------------------------------------------------------------------------------
25 // Build QFunction kernel
26 //
27 // TODO: Refactor
28 //------------------------------------------------------------------------------
CeedQFunctionBuildKernel_Sycl(CeedQFunction qf)29 extern "C" int CeedQFunctionBuildKernel_Sycl(CeedQFunction qf) {
30 Ceed ceed;
31 Ceed_Sycl *data;
32 const char *read_write_kernel_path, *read_write_kernel_source;
33 const char *qfunction_name, *qfunction_source;
34 CeedInt num_input_fields, num_output_fields;
35 CeedQFunctionField *input_fields, *output_fields;
36 CeedQFunction_Sycl *impl;
37
38 // QFunction is built
39 CeedCallBackend(CeedQFunctionGetData(qf, (void **)&impl));
40 if (impl->QFunction) return CEED_ERROR_SUCCESS;
41
42 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
43 CeedCallBackend(CeedGetData(ceed, &data));
44
45 // QFunction kernel generation
46 CeedCallBackend(CeedQFunctionGetFields(qf, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
47
48 std::vector<CeedInt> input_sizes(num_input_fields);
49 CeedQFunctionField *input_i = input_fields;
50
51 for (auto &size_i : input_sizes) {
52 CeedCallBackend(CeedQFunctionFieldGetSize(*input_i, &size_i));
53 ++input_i;
54 }
55
56 std::vector<CeedInt> output_sizes(num_output_fields);
57 CeedQFunctionField *output_i = output_fields;
58
59 for (auto &size_i : output_sizes) {
60 CeedCallBackend(CeedQFunctionFieldGetSize(*output_i, &size_i));
61 ++output_i;
62 }
63
64 CeedCallBackend(CeedQFunctionGetKernelName(qf, &qfunction_name));
65
66 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source -----\n");
67 CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &qfunction_source));
68 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source Complete! -----\n");
69
70 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-ref-qfunction.h", &read_write_kernel_path));
71
72 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source -----\n");
73 {
74 char *source;
75
76 CeedCallBackend(CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &source));
77 read_write_kernel_source = source;
78 }
79 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n");
80
81 std::string_view qf_name_view(qfunction_name);
82 std::string_view qf_source_view(qfunction_source);
83 std::string_view rw_source_view(read_write_kernel_source);
84 const std::string kernel_name = "CeedKernelSyclRefQFunction_" + std::string(qf_name_view);
85
86 // Defintions
87 std::ostringstream code;
88 code << rw_source_view;
89 code << qf_source_view;
90 code << "\n";
91
92 // Kernel function
93 // Here we are fixing a lower sub-group size value to avoid register spills
94 // This needs to be revisited if all qfunctions require this.
95 code << "__attribute__((intel_reqd_sub_group_size(" << SUB_GROUP_SIZE_QF << "))) __kernel void " << kernel_name
96 << "(__global void *ctx, CeedInt Q,\n";
97
98 // OpenCL doesn't allow for structs with pointers.
99 // We will need to pass all of the arguments individually.
100 // Input parameters
101 for (CeedInt i = 0; i < num_input_fields; ++i) {
102 code << " "
103 << "__global const CeedScalar *in_" << i << ",\n";
104 }
105
106 // Output parameters
107 code << " "
108 << "__global CeedScalar *out_0";
109 for (CeedInt i = 1; i < num_output_fields; ++i) {
110 code << "\n, "
111 << "__global CeedScalar *out_" << i;
112 }
113 // Begin kernel function body
114 code << ") {\n\n";
115
116 // Inputs
117 code << " // Input fields\n";
118 for (CeedInt i = 0; i < num_input_fields; ++i) {
119 code << " CeedScalar U_" << i << "[" << input_sizes[i] << "];\n";
120 }
121 code << " const CeedScalar *inputs[" << CeedIntMax(num_input_fields, 1) << "] = {U_0";
122 for (CeedInt i = 1; i < num_input_fields; i++) {
123 code << ", U_" << i << "\n";
124 }
125 code << "};\n\n";
126
127 // Outputs
128 code << " // Output fields\n";
129 for (CeedInt i = 0; i < num_output_fields; i++) {
130 code << " CeedScalar V_" << i << "[" << output_sizes[i] << "];\n";
131 }
132 code << " CeedScalar *outputs[" << CeedIntMax(num_output_fields, 1) << "] = {V_0";
133 for (CeedInt i = 1; i < num_output_fields; i++) {
134 code << ", V_" << i << "\n";
135 }
136 code << "};\n\n";
137
138 code << " const CeedInt q = get_global_linear_id();\n\n";
139
140 code << "if(q < Q){ \n\n";
141
142 // Load inputs
143 code << " // -- Load inputs\n";
144 for (CeedInt i = 0; i < num_input_fields; i++) {
145 code << " readQuads(" << input_sizes[i] << ", Q, q, "
146 << "in_" << i << ", U_" << i << ");\n";
147 }
148 code << "\n";
149
150 // QFunction
151 code << " // -- Call QFunction\n";
152 code << " " << qf_name_view << "(ctx, 1, inputs, outputs);\n\n";
153
154 // Write outputs
155 code << " // -- Write outputs\n";
156 for (CeedInt i = 0; i < num_output_fields; i++) {
157 code << " writeQuads(" << output_sizes[i] << ", Q, q, "
158 << "V_" << i << ", out_" << i << ");\n";
159 }
160 code << "\n";
161
162 // End kernel function body
163 code << "}\n";
164 code << "}\n";
165
166 // View kernel for debugging
167 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Generated QFunction Kernels:\n");
168 CeedDebug(ceed, code.str().c_str());
169
170 // Compile kernel
171 CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module));
172 CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, kernel_name, &impl->QFunction));
173
174 // Cleanup
175 CeedCallBackend(CeedFree(&qfunction_source));
176 CeedCallBackend(CeedFree(&read_write_kernel_path));
177 CeedCallBackend(CeedFree(&read_write_kernel_source));
178 CeedCallBackend(CeedDestroy(&ceed));
179 return CEED_ERROR_SUCCESS;
180 }
181
182 //------------------------------------------------------------------------------
183