xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunction-load.sycl.cpp (revision 70dc8078dbe287928145230e5a499c81da54eabb)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <ceed/jit-tools.h>
11 
12 #include <iostream>
13 #include <sstream>
14 #include <string>
15 #include <string_view>
16 #include <sycl/sycl.hpp>
17 #include <vector>
18 
19 #include "../sycl/ceed-sycl-compile.hpp"
20 #include "ceed-sycl-ref.hpp"
21 
22 #define SUB_GROUP_SIZE_QF 16
23 
24 //------------------------------------------------------------------------------
25 // Build QFunction kernel
26 //
27 // TODO: Refactor
28 //------------------------------------------------------------------------------
29 extern "C" int CeedQFunctionBuildKernel_Sycl(CeedQFunction qf) {
30   Ceed                ceed;
31   Ceed_Sycl          *data;
32   const char         *read_write_kernel_path, *read_write_kernel_source;
33   const char         *qfunction_name, *qfunction_source;
34   CeedInt             num_input_fields, num_output_fields;
35   CeedQFunctionField *input_fields, *output_fields;
36   CeedQFunction_Sycl *impl;
37 
38   CeedCallBackend(CeedQFunctionGetData(qf, (void **)&impl));
39   // QFunction is built
40   if (impl->QFunction) return CEED_ERROR_SUCCESS;
41 
42   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
43   CeedCallBackend(CeedGetData(ceed, &data));
44 
45   // QFunction kernel generation
46   CeedCallBackend(CeedQFunctionGetFields(qf, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
47 
48   std::vector<CeedInt> input_sizes(num_input_fields);
49   CeedQFunctionField  *input_i = input_fields;
50 
51   for (auto &size_i : input_sizes) {
52     CeedCallBackend(CeedQFunctionFieldGetSize(*input_i, &size_i));
53     ++input_i;
54   }
55 
56   std::vector<CeedInt> output_sizes(num_output_fields);
57   CeedQFunctionField  *output_i = output_fields;
58 
59   for (auto &size_i : output_sizes) {
60     CeedCallBackend(CeedQFunctionFieldGetSize(*output_i, &size_i));
61     ++output_i;
62   }
63 
64   CeedCallBackend(CeedQFunctionGetKernelName(qf, &qfunction_name));
65 
66   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source -----\n");
67   CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &qfunction_source));
68   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source Complete! -----\n");
69 
70   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-ref-qfunction.h", &read_write_kernel_path));
71 
72   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source -----\n");
73   {
74     char *source;
75 
76     CeedCallBackend(CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &source));
77     read_write_kernel_source = source;
78   }
79   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n");
80 
81   std::string_view  qf_name_view(qfunction_name);
82   std::string_view  qf_source_view(qfunction_source);
83   std::string_view  rw_source_view(read_write_kernel_source);
84   const std::string kernel_name = "CeedKernelSyclRefQFunction_" + std::string(qf_name_view);
85 
86   // Defintions
87   std::ostringstream code;
88   code << rw_source_view;
89   code << qf_source_view;
90   code << "\n";
91 
92   // Kernel function
93   // Here we are fixing a lower sub-group size value to avoid register spills
94   // This needs to be revisited if all qfunctions require this.
95   code << "__attribute__((intel_reqd_sub_group_size(" << SUB_GROUP_SIZE_QF << "))) __kernel void " << kernel_name
96        << "(__global void *ctx, CeedInt Q,\n";
97 
98   // OpenCL doesn't allow for structs with pointers.
99   // We will need to pass all of the arguments individually.
100   // Input parameters
101   for (CeedInt i = 0; i < num_input_fields; ++i) {
102     code << "  "
103          << "__global const CeedScalar *in_" << i << ",\n";
104   }
105 
106   // Output parameters
107   code << "  "
108        << "__global CeedScalar *out_0";
109   for (CeedInt i = 1; i < num_output_fields; ++i) {
110     code << "\n,  "
111          << "__global CeedScalar *out_" << i;
112   }
113   // Begin kernel function body
114   code << ") {\n\n";
115 
116   // Inputs
117   code << "  // Input fields\n";
118   for (CeedInt i = 0; i < num_input_fields; ++i) {
119     code << "  CeedScalar U_" << i << "[" << input_sizes[i] << "];\n";
120   }
121   code << "  const CeedScalar *inputs[" << CeedIntMax(num_input_fields, 1) << "] = {U_0";
122   for (CeedInt i = 1; i < num_input_fields; i++) {
123     code << ", U_" << i << "\n";
124   }
125   code << "};\n\n";
126 
127   // Outputs
128   code << "  // Output fields\n";
129   for (CeedInt i = 0; i < num_output_fields; i++) {
130     code << "  CeedScalar V_" << i << "[" << output_sizes[i] << "];\n";
131   }
132   code << "  CeedScalar *outputs[" << CeedIntMax(num_output_fields, 1) << "] = {V_0";
133   for (CeedInt i = 1; i < num_output_fields; i++) {
134     code << ", V_" << i << "\n";
135   }
136   code << "};\n\n";
137 
138   code << "  const CeedInt q = get_global_linear_id();\n\n";
139 
140   code << "if(q < Q){ \n\n";
141 
142   // Load inputs
143   code << "  // -- Load inputs\n";
144   for (CeedInt i = 0; i < num_input_fields; i++) {
145     code << "  readQuads(" << input_sizes[i] << ", Q, q, "
146          << "in_" << i << ", U_" << i << ");\n";
147   }
148   code << "\n";
149 
150   // QFunction
151   code << "  // -- Call QFunction\n";
152   code << "  " << qf_name_view << "(ctx, 1, inputs, outputs);\n\n";
153 
154   // Write outputs
155   code << "  // -- Write outputs\n";
156   for (CeedInt i = 0; i < num_output_fields; i++) {
157     code << "  writeQuads(" << output_sizes[i] << ", Q, q, "
158          << "V_" << i << ", out_" << i << ");\n";
159   }
160   code << "\n";
161 
162   // End kernel function body
163   code << "}\n";
164   code << "}\n";
165 
166   // View kernel for debugging
167   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Generated QFunction Kernels:\n");
168   CeedDebug(ceed, code.str().c_str());
169 
170   // Compile kernel
171   CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module));
172   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, kernel_name, &impl->QFunction));
173 
174   // Cleanup
175   CeedCallBackend(CeedFree(&qfunction_source));
176   CeedCallBackend(CeedFree(&read_write_kernel_path));
177   CeedCallBackend(CeedFree(&read_write_kernel_source));
178   return CEED_ERROR_SUCCESS;
179 }
180 
181 //------------------------------------------------------------------------------
182