xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-ref-qfunction-load.sycl.cpp (revision dd64fc8452c2d35c954858232143719e6bb2e61d)
1bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3bd882c8aSJames Wright //
4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5bd882c8aSJames Wright //
6bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7bd882c8aSJames Wright 
8bd882c8aSJames Wright #include <ceed/backend.h>
9bd882c8aSJames Wright #include <ceed/ceed.h>
10bd882c8aSJames Wright #include <ceed/jit-tools.h>
11bd882c8aSJames Wright 
12bd882c8aSJames Wright #include <iostream>
13bd882c8aSJames Wright #include <sstream>
14bd882c8aSJames Wright #include <string>
15bd882c8aSJames Wright #include <string_view>
16bd882c8aSJames Wright #include <sycl/sycl.hpp>
17bd882c8aSJames Wright #include <vector>
18bd882c8aSJames Wright 
19bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
20bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
21bd882c8aSJames Wright 
// Sub-group size required (via intel_reqd_sub_group_size) by the generated QFunction kernels.
// A lower value is fixed to avoid register spills; revisit if all QFunctions require this.
#define SUB_GROUP_SIZE_QF 16
23bd882c8aSJames Wright 
24bd882c8aSJames Wright //------------------------------------------------------------------------------
25bd882c8aSJames Wright // Build QFunction kernel
26bd882c8aSJames Wright //
27bd882c8aSJames Wright // TODO: Refactor
28bd882c8aSJames Wright //------------------------------------------------------------------------------
29eb7e6cafSJeremy L Thompson extern "C" int CeedQFunctionBuildKernel_Sycl(CeedQFunction qf) {
30*dd64fc84SJeremy L Thompson   Ceed                ceed;
31*dd64fc84SJeremy L Thompson   Ceed_Sycl*          data;
32*dd64fc84SJeremy L Thompson   char *              qfunction_name, *qfunction_source, *read_write_kernel_path, *read_write_kernel_source;
33*dd64fc84SJeremy L Thompson   CeedInt             num_input_fields, num_output_fields;
34*dd64fc84SJeremy L Thompson   CeedQFunctionField *input_fields, *output_fields;
35bd882c8aSJames Wright   CeedQFunction_Sycl* impl;
36*dd64fc84SJeremy L Thompson 
37bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetData(qf, (void**)&impl));
38bd882c8aSJames Wright   // QFunction is built
39bd882c8aSJames Wright   if (impl->QFunction) return CEED_ERROR_SUCCESS;
40bd882c8aSJames Wright 
41bd882c8aSJames Wright   CeedQFunctionGetCeed(qf, &ceed);
42bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
43bd882c8aSJames Wright 
44bd882c8aSJames Wright   // QFunction kernel generation
45bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetFields(qf, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
46bd882c8aSJames Wright 
47bd882c8aSJames Wright   std::vector<CeedInt> input_sizes(num_input_fields);
48bd882c8aSJames Wright   CeedQFunctionField*  input_i = input_fields;
49*dd64fc84SJeremy L Thompson 
50bd882c8aSJames Wright   for (auto& size_i : input_sizes) {
51bd882c8aSJames Wright     CeedCallBackend(CeedQFunctionFieldGetSize(*input_i, &size_i));
52bd882c8aSJames Wright     ++input_i;
53bd882c8aSJames Wright   }
54bd882c8aSJames Wright 
55bd882c8aSJames Wright   std::vector<CeedInt> output_sizes(num_output_fields);
56bd882c8aSJames Wright   CeedQFunctionField*  output_i = output_fields;
57*dd64fc84SJeremy L Thompson 
58bd882c8aSJames Wright   for (auto& size_i : output_sizes) {
59bd882c8aSJames Wright     CeedCallBackend(CeedQFunctionFieldGetSize(*output_i, &size_i));
60bd882c8aSJames Wright     ++output_i;
61bd882c8aSJames Wright   }
62bd882c8aSJames Wright 
63bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetKernelName(qf, &qfunction_name));
64bd882c8aSJames Wright 
6523d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source -----\n");
66bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &qfunction_source));
6723d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source Complete! -----\n");
68bd882c8aSJames Wright 
69bd882c8aSJames Wright   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-ref-qfunction.h", &read_write_kernel_path));
70bd882c8aSJames Wright 
7123d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source -----\n");
72bd882c8aSJames Wright   CeedCallBackend(CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &read_write_kernel_source));
7323d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n");
74bd882c8aSJames Wright 
75bd882c8aSJames Wright   std::string_view  qf_name_view(qfunction_name);
76bd882c8aSJames Wright   std::string_view  qf_source_view(qfunction_source);
77bd882c8aSJames Wright   std::string_view  rw_source_view(read_write_kernel_source);
78bd882c8aSJames Wright   const std::string kernel_name = "CeedKernelSyclRefQFunction_" + std::string(qf_name_view);
79bd882c8aSJames Wright 
80bd882c8aSJames Wright   // Defintions
81bd882c8aSJames Wright   std::ostringstream code;
82bd882c8aSJames Wright   code << rw_source_view;
83bd882c8aSJames Wright   code << qf_source_view;
84bd882c8aSJames Wright   code << "\n";
85bd882c8aSJames Wright 
86bd882c8aSJames Wright   // Kernel function
87bd882c8aSJames Wright   // Here we are fixing a lower sub-group size value to avoid register spills
88bd882c8aSJames Wright   // This needs to be revisited if all qfunctions require this.
89bd882c8aSJames Wright   code << "__attribute__((intel_reqd_sub_group_size(" << SUB_GROUP_SIZE_QF << "))) __kernel void " << kernel_name
90bd882c8aSJames Wright        << "(__global void *ctx, CeedInt Q,\n";
91bd882c8aSJames Wright 
92bd882c8aSJames Wright   // OpenCL doesn't allow for structs with pointers.
93bd882c8aSJames Wright   // We will need to pass all of the arguments individually.
94bd882c8aSJames Wright   // Input parameters
95bd882c8aSJames Wright   for (CeedInt i = 0; i < num_input_fields; ++i) {
96bd882c8aSJames Wright     code << "  "
97bd882c8aSJames Wright          << "__global const CeedScalar *in_" << i << ",\n";
98bd882c8aSJames Wright   }
99bd882c8aSJames Wright 
100bd882c8aSJames Wright   // Output parameters
101bd882c8aSJames Wright   code << "  "
102bd882c8aSJames Wright        << "__global CeedScalar *out_0";
103bd882c8aSJames Wright   for (CeedInt i = 1; i < num_output_fields; ++i) {
104bd882c8aSJames Wright     code << "\n,  "
105bd882c8aSJames Wright          << "__global CeedScalar *out_" << i;
106bd882c8aSJames Wright   }
107bd882c8aSJames Wright   // Begin kernel function body
108bd882c8aSJames Wright   code << ") {\n\n";
109bd882c8aSJames Wright 
110bd882c8aSJames Wright   // Inputs
111bd882c8aSJames Wright   code << "  // Input fields\n";
112bd882c8aSJames Wright   for (CeedInt i = 0; i < num_input_fields; ++i) {
113bd882c8aSJames Wright     code << "  CeedScalar U_" << i << "[" << input_sizes[i] << "];\n";
114bd882c8aSJames Wright   }
115bd882c8aSJames Wright   code << "  const CeedScalar *inputs[" << num_input_fields << "] = {U_0";
116bd882c8aSJames Wright   for (CeedInt i = 1; i < num_input_fields; i++) {
117bd882c8aSJames Wright     code << ", U_" << i << "\n";
118bd882c8aSJames Wright   }
119bd882c8aSJames Wright   code << "};\n\n";
120bd882c8aSJames Wright 
121bd882c8aSJames Wright   // Outputs
122bd882c8aSJames Wright   code << "  // Output fields\n";
123bd882c8aSJames Wright   for (CeedInt i = 0; i < num_output_fields; i++) {
124bd882c8aSJames Wright     code << "  CeedScalar V_" << i << "[" << output_sizes[i] << "];\n";
125bd882c8aSJames Wright   }
126bd882c8aSJames Wright   code << "  CeedScalar *outputs[" << num_output_fields << "] = {V_0";
127bd882c8aSJames Wright   for (CeedInt i = 1; i < num_output_fields; i++) {
128bd882c8aSJames Wright     code << ", V_" << i << "\n";
129bd882c8aSJames Wright   }
130bd882c8aSJames Wright   code << "};\n\n";
131bd882c8aSJames Wright 
132bd882c8aSJames Wright   code << "  const CeedInt q = get_global_linear_id();\n\n";
133bd882c8aSJames Wright 
134bd882c8aSJames Wright   code << "if(q < Q){ \n\n";
135bd882c8aSJames Wright 
136bd882c8aSJames Wright   // Load inputs
137bd882c8aSJames Wright   code << "  // -- Load inputs\n";
138bd882c8aSJames Wright   for (CeedInt i = 0; i < num_input_fields; i++) {
139bd882c8aSJames Wright     code << "  readQuads(" << input_sizes[i] << ", Q, q, "
140bd882c8aSJames Wright          << "in_" << i << ", U_" << i << ");\n";
141bd882c8aSJames Wright   }
142bd882c8aSJames Wright   code << "\n";
143bd882c8aSJames Wright 
144bd882c8aSJames Wright   // QFunction
145bd882c8aSJames Wright   code << "  // -- Call QFunction\n";
146bd882c8aSJames Wright   code << "  " << qf_name_view << "(ctx, 1, inputs, outputs);\n\n";
147bd882c8aSJames Wright 
148bd882c8aSJames Wright   // Write outputs
149bd882c8aSJames Wright   code << "  // -- Write outputs\n";
150bd882c8aSJames Wright   for (CeedInt i = 0; i < num_output_fields; i++) {
151bd882c8aSJames Wright     code << "  writeQuads(" << output_sizes[i] << ", Q, q, "
152bd882c8aSJames Wright          << "V_" << i << ", out_" << i << ");\n";
153bd882c8aSJames Wright   }
154bd882c8aSJames Wright   code << "\n";
155bd882c8aSJames Wright 
156bd882c8aSJames Wright   // End kernel function body
157bd882c8aSJames Wright   code << "}\n";
158bd882c8aSJames Wright   code << "}\n";
159bd882c8aSJames Wright 
160bd882c8aSJames Wright   // View kernel for debugging
16123d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Generated QFunction Kernels:\n");
162bd882c8aSJames Wright   CeedDebug(ceed, code.str().c_str());
163bd882c8aSJames Wright 
164bd882c8aSJames Wright   // Compile kernel
165eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module));
166eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, kernel_name, &impl->QFunction));
167bd882c8aSJames Wright 
168bd882c8aSJames Wright   // Cleanup
169bd882c8aSJames Wright   CeedCallBackend(CeedFree(&qfunction_source));
170bd882c8aSJames Wright   CeedCallBackend(CeedFree(&read_write_kernel_path));
171bd882c8aSJames Wright   CeedCallBackend(CeedFree(&read_write_kernel_source));
172bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
173bd882c8aSJames Wright }
174ff1e7120SSebastian Grimberg 
175bd882c8aSJames Wright //------------------------------------------------------------------------------
176