xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-ref-qfunction-load.sycl.cpp (revision bd882c8a454763a096666645dc9a6229d5263694)
1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*bd882c8aSJames Wright //
4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*bd882c8aSJames Wright //
6*bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7*bd882c8aSJames Wright 
8*bd882c8aSJames Wright #include <ceed/backend.h>
9*bd882c8aSJames Wright #include <ceed/ceed.h>
10*bd882c8aSJames Wright #include <ceed/jit-tools.h>
11*bd882c8aSJames Wright 
12*bd882c8aSJames Wright #include <iostream>
13*bd882c8aSJames Wright #include <sstream>
14*bd882c8aSJames Wright #include <string>
15*bd882c8aSJames Wright #include <string_view>
16*bd882c8aSJames Wright #include <sycl/sycl.hpp>
17*bd882c8aSJames Wright #include <vector>
18*bd882c8aSJames Wright 
19*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
20*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
21*bd882c8aSJames Wright 
22*bd882c8aSJames Wright #define SUB_GROUP_SIZE_QF 16
23*bd882c8aSJames Wright 
24*bd882c8aSJames Wright //------------------------------------------------------------------------------
25*bd882c8aSJames Wright // Build QFunction kernel
26*bd882c8aSJames Wright //
27*bd882c8aSJames Wright // TODO: Refactor
28*bd882c8aSJames Wright //------------------------------------------------------------------------------
29*bd882c8aSJames Wright extern "C" int CeedBuildQFunction_Sycl(CeedQFunction qf) {
30*bd882c8aSJames Wright   CeedQFunction_Sycl* impl;
31*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetData(qf, (void**)&impl));
32*bd882c8aSJames Wright   // QFunction is built
33*bd882c8aSJames Wright   if (impl->QFunction) return CEED_ERROR_SUCCESS;
34*bd882c8aSJames Wright 
35*bd882c8aSJames Wright   Ceed ceed;
36*bd882c8aSJames Wright   CeedQFunctionGetCeed(qf, &ceed);
37*bd882c8aSJames Wright   Ceed_Sycl* data;
38*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
39*bd882c8aSJames Wright 
40*bd882c8aSJames Wright   // QFunction kernel generation
41*bd882c8aSJames Wright   CeedInt             num_input_fields, num_output_fields;
42*bd882c8aSJames Wright   CeedQFunctionField *input_fields, *output_fields;
43*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetFields(qf, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
44*bd882c8aSJames Wright 
45*bd882c8aSJames Wright   std::vector<CeedInt> input_sizes(num_input_fields);
46*bd882c8aSJames Wright   CeedQFunctionField*  input_i = input_fields;
47*bd882c8aSJames Wright   for (auto& size_i : input_sizes) {
48*bd882c8aSJames Wright     CeedCallBackend(CeedQFunctionFieldGetSize(*input_i, &size_i));
49*bd882c8aSJames Wright     ++input_i;
50*bd882c8aSJames Wright   }
51*bd882c8aSJames Wright 
52*bd882c8aSJames Wright   std::vector<CeedInt> output_sizes(num_output_fields);
53*bd882c8aSJames Wright   CeedQFunctionField*  output_i = output_fields;
54*bd882c8aSJames Wright   for (auto& size_i : output_sizes) {
55*bd882c8aSJames Wright     CeedCallBackend(CeedQFunctionFieldGetSize(*output_i, &size_i));
56*bd882c8aSJames Wright     ++output_i;
57*bd882c8aSJames Wright   }
58*bd882c8aSJames Wright 
59*bd882c8aSJames Wright   char* qfunction_name;
60*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetKernelName(qf, &qfunction_name));
61*bd882c8aSJames Wright 
62*bd882c8aSJames Wright   char* qfunction_source;
63*bd882c8aSJames Wright   CeedDebug256(ceed, 2, "----- Loading QFunction User Source -----\n");
64*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &qfunction_source));
65*bd882c8aSJames Wright   CeedDebug256(ceed, 2, "----- Loading QFunction User Source Complete! -----\n");
66*bd882c8aSJames Wright 
67*bd882c8aSJames Wright   char* read_write_kernel_path;
68*bd882c8aSJames Wright   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-ref-qfunction.h", &read_write_kernel_path));
69*bd882c8aSJames Wright 
70*bd882c8aSJames Wright   char* read_write_kernel_source;
71*bd882c8aSJames Wright   CeedDebug256(ceed, 2, "----- Loading QFunction Read/Write Kernel Source -----\n");
72*bd882c8aSJames Wright   CeedCallBackend(CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &read_write_kernel_source));
73*bd882c8aSJames Wright   CeedDebug256(ceed, 2, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n");
74*bd882c8aSJames Wright 
75*bd882c8aSJames Wright   std::string_view  qf_name_view(qfunction_name);
76*bd882c8aSJames Wright   std::string_view  qf_source_view(qfunction_source);
77*bd882c8aSJames Wright   std::string_view  rw_source_view(read_write_kernel_source);
78*bd882c8aSJames Wright   const std::string kernel_name = "CeedKernelSyclRefQFunction_" + std::string(qf_name_view);
79*bd882c8aSJames Wright 
80*bd882c8aSJames Wright   // Defintions
81*bd882c8aSJames Wright   std::ostringstream code;
82*bd882c8aSJames Wright   code << rw_source_view;
83*bd882c8aSJames Wright   code << qf_source_view;
84*bd882c8aSJames Wright   code << "\n";
85*bd882c8aSJames Wright 
86*bd882c8aSJames Wright   // Kernel function
87*bd882c8aSJames Wright   // Here we are fixing a lower sub-group size value to avoid register spills
88*bd882c8aSJames Wright   // This needs to be revisited if all qfunctions require this.
89*bd882c8aSJames Wright   code << "__attribute__((intel_reqd_sub_group_size(" << SUB_GROUP_SIZE_QF << "))) __kernel void " << kernel_name
90*bd882c8aSJames Wright        << "(__global void *ctx, CeedInt Q,\n";
91*bd882c8aSJames Wright 
92*bd882c8aSJames Wright   // OpenCL doesn't allow for structs with pointers.
93*bd882c8aSJames Wright   // We will need to pass all of the arguments individually.
94*bd882c8aSJames Wright   // Input parameters
95*bd882c8aSJames Wright   for (CeedInt i = 0; i < num_input_fields; ++i) {
96*bd882c8aSJames Wright     code << "  "
97*bd882c8aSJames Wright          << "__global const CeedScalar *in_" << i << ",\n";
98*bd882c8aSJames Wright   }
99*bd882c8aSJames Wright 
100*bd882c8aSJames Wright   // Output parameters
101*bd882c8aSJames Wright   code << "  "
102*bd882c8aSJames Wright        << "__global CeedScalar *out_0";
103*bd882c8aSJames Wright   for (CeedInt i = 1; i < num_output_fields; ++i) {
104*bd882c8aSJames Wright     code << "\n,  "
105*bd882c8aSJames Wright          << "__global CeedScalar *out_" << i;
106*bd882c8aSJames Wright   }
107*bd882c8aSJames Wright   // Begin kernel function body
108*bd882c8aSJames Wright   code << ") {\n\n";
109*bd882c8aSJames Wright 
110*bd882c8aSJames Wright   // Inputs
111*bd882c8aSJames Wright   code << "  // Input fields\n";
112*bd882c8aSJames Wright   for (CeedInt i = 0; i < num_input_fields; ++i) {
113*bd882c8aSJames Wright     code << "  CeedScalar U_" << i << "[" << input_sizes[i] << "];\n";
114*bd882c8aSJames Wright   }
115*bd882c8aSJames Wright   code << "  const CeedScalar *inputs[" << num_input_fields << "] = {U_0";
116*bd882c8aSJames Wright   for (CeedInt i = 1; i < num_input_fields; i++) {
117*bd882c8aSJames Wright     code << ", U_" << i << "\n";
118*bd882c8aSJames Wright   }
119*bd882c8aSJames Wright   code << "};\n\n";
120*bd882c8aSJames Wright 
121*bd882c8aSJames Wright   // Outputs
122*bd882c8aSJames Wright   code << "  // Output fields\n";
123*bd882c8aSJames Wright   for (CeedInt i = 0; i < num_output_fields; i++) {
124*bd882c8aSJames Wright     code << "  CeedScalar V_" << i << "[" << output_sizes[i] << "];\n";
125*bd882c8aSJames Wright   }
126*bd882c8aSJames Wright   code << "  CeedScalar *outputs[" << num_output_fields << "] = {V_0";
127*bd882c8aSJames Wright   for (CeedInt i = 1; i < num_output_fields; i++) {
128*bd882c8aSJames Wright     code << ", V_" << i << "\n";
129*bd882c8aSJames Wright   }
130*bd882c8aSJames Wright   code << "};\n\n";
131*bd882c8aSJames Wright 
132*bd882c8aSJames Wright   code << "  const CeedInt q = get_global_linear_id();\n\n";
133*bd882c8aSJames Wright 
134*bd882c8aSJames Wright   code << "if(q < Q){ \n\n";
135*bd882c8aSJames Wright 
136*bd882c8aSJames Wright   // Load inputs
137*bd882c8aSJames Wright   code << "  // -- Load inputs\n";
138*bd882c8aSJames Wright   for (CeedInt i = 0; i < num_input_fields; i++) {
139*bd882c8aSJames Wright     code << "  readQuads(" << input_sizes[i] << ", Q, q, "
140*bd882c8aSJames Wright          << "in_" << i << ", U_" << i << ");\n";
141*bd882c8aSJames Wright   }
142*bd882c8aSJames Wright   code << "\n";
143*bd882c8aSJames Wright 
144*bd882c8aSJames Wright   // QFunction
145*bd882c8aSJames Wright   code << "  // -- Call QFunction\n";
146*bd882c8aSJames Wright   code << "  " << qf_name_view << "(ctx, 1, inputs, outputs);\n\n";
147*bd882c8aSJames Wright 
148*bd882c8aSJames Wright   // Write outputs
149*bd882c8aSJames Wright   code << "  // -- Write outputs\n";
150*bd882c8aSJames Wright   for (CeedInt i = 0; i < num_output_fields; i++) {
151*bd882c8aSJames Wright     code << "  writeQuads(" << output_sizes[i] << ", Q, q, "
152*bd882c8aSJames Wright          << "V_" << i << ", out_" << i << ");\n";
153*bd882c8aSJames Wright   }
154*bd882c8aSJames Wright   code << "\n";
155*bd882c8aSJames Wright 
156*bd882c8aSJames Wright   // End kernel function body
157*bd882c8aSJames Wright   code << "}\n";
158*bd882c8aSJames Wright   code << "}\n";
159*bd882c8aSJames Wright 
160*bd882c8aSJames Wright   // View kernel for debugging
161*bd882c8aSJames Wright   CeedDebug256(ceed, 2, "Generated QFunction Kernels:\n");
162*bd882c8aSJames Wright   CeedDebug(ceed, code.str().c_str());
163*bd882c8aSJames Wright 
164*bd882c8aSJames Wright   // Compile kernel
165*bd882c8aSJames Wright   CeedCallBackend(CeedJitBuildModule_Sycl(ceed, code.str(), &impl->sycl_module));
166*bd882c8aSJames Wright   CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, kernel_name, &impl->QFunction));
167*bd882c8aSJames Wright 
168*bd882c8aSJames Wright   // Cleanup
169*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&qfunction_source));
170*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&read_write_kernel_path));
171*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&read_write_kernel_source));
172*bd882c8aSJames Wright 
173*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
174*bd882c8aSJames Wright }
175*bd882c8aSJames Wright //------------------------------------------------------------------------------
176