xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunction-load.sycl.cpp (revision 94b7b29b41ad8a17add4c577886859ef16f89dec)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <ceed/jit-tools.h>
11 
12 #include <iostream>
13 #include <sstream>
14 #include <string>
15 #include <string_view>
16 #include <sycl/sycl.hpp>
17 #include <vector>
18 
19 #include "../sycl/ceed-sycl-compile.hpp"
20 #include "ceed-sycl-ref.hpp"
21 
22 #define SUB_GROUP_SIZE_QF 16
23 
24 //------------------------------------------------------------------------------
25 // Build QFunction kernel
26 //
27 // TODO: Refactor
28 //------------------------------------------------------------------------------
29 extern "C" int CeedQFunctionBuildKernel_Sycl(CeedQFunction qf) {
30   CeedQFunction_Sycl* impl;
31   CeedCallBackend(CeedQFunctionGetData(qf, (void**)&impl));
32   // QFunction is built
33   if (impl->QFunction) return CEED_ERROR_SUCCESS;
34 
35   Ceed ceed;
36   CeedQFunctionGetCeed(qf, &ceed);
37   Ceed_Sycl* data;
38   CeedCallBackend(CeedGetData(ceed, &data));
39 
40   // QFunction kernel generation
41   CeedInt             num_input_fields, num_output_fields;
42   CeedQFunctionField *input_fields, *output_fields;
43   CeedCallBackend(CeedQFunctionGetFields(qf, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
44 
45   std::vector<CeedInt> input_sizes(num_input_fields);
46   CeedQFunctionField*  input_i = input_fields;
47   for (auto& size_i : input_sizes) {
48     CeedCallBackend(CeedQFunctionFieldGetSize(*input_i, &size_i));
49     ++input_i;
50   }
51 
52   std::vector<CeedInt> output_sizes(num_output_fields);
53   CeedQFunctionField*  output_i = output_fields;
54   for (auto& size_i : output_sizes) {
55     CeedCallBackend(CeedQFunctionFieldGetSize(*output_i, &size_i));
56     ++output_i;
57   }
58 
59   char* qfunction_name;
60   CeedCallBackend(CeedQFunctionGetKernelName(qf, &qfunction_name));
61 
62   char* qfunction_source;
63   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source -----\n");
64   CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &qfunction_source));
65   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source Complete! -----\n");
66 
67   char* read_write_kernel_path;
68   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-ref-qfunction.h", &read_write_kernel_path));
69 
70   char* read_write_kernel_source;
71   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source -----\n");
72   CeedCallBackend(CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &read_write_kernel_source));
73   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n");
74 
75   std::string_view  qf_name_view(qfunction_name);
76   std::string_view  qf_source_view(qfunction_source);
77   std::string_view  rw_source_view(read_write_kernel_source);
78   const std::string kernel_name = "CeedKernelSyclRefQFunction_" + std::string(qf_name_view);
79 
80   // Defintions
81   std::ostringstream code;
82   code << rw_source_view;
83   code << qf_source_view;
84   code << "\n";
85 
86   // Kernel function
87   // Here we are fixing a lower sub-group size value to avoid register spills
88   // This needs to be revisited if all qfunctions require this.
89   code << "__attribute__((intel_reqd_sub_group_size(" << SUB_GROUP_SIZE_QF << "))) __kernel void " << kernel_name
90        << "(__global void *ctx, CeedInt Q,\n";
91 
92   // OpenCL doesn't allow for structs with pointers.
93   // We will need to pass all of the arguments individually.
94   // Input parameters
95   for (CeedInt i = 0; i < num_input_fields; ++i) {
96     code << "  "
97          << "__global const CeedScalar *in_" << i << ",\n";
98   }
99 
100   // Output parameters
101   code << "  "
102        << "__global CeedScalar *out_0";
103   for (CeedInt i = 1; i < num_output_fields; ++i) {
104     code << "\n,  "
105          << "__global CeedScalar *out_" << i;
106   }
107   // Begin kernel function body
108   code << ") {\n\n";
109 
110   // Inputs
111   code << "  // Input fields\n";
112   for (CeedInt i = 0; i < num_input_fields; ++i) {
113     code << "  CeedScalar U_" << i << "[" << input_sizes[i] << "];\n";
114   }
115   code << "  const CeedScalar *inputs[" << num_input_fields << "] = {U_0";
116   for (CeedInt i = 1; i < num_input_fields; i++) {
117     code << ", U_" << i << "\n";
118   }
119   code << "};\n\n";
120 
121   // Outputs
122   code << "  // Output fields\n";
123   for (CeedInt i = 0; i < num_output_fields; i++) {
124     code << "  CeedScalar V_" << i << "[" << output_sizes[i] << "];\n";
125   }
126   code << "  CeedScalar *outputs[" << num_output_fields << "] = {V_0";
127   for (CeedInt i = 1; i < num_output_fields; i++) {
128     code << ", V_" << i << "\n";
129   }
130   code << "};\n\n";
131 
132   code << "  const CeedInt q = get_global_linear_id();\n\n";
133 
134   code << "if(q < Q){ \n\n";
135 
136   // Load inputs
137   code << "  // -- Load inputs\n";
138   for (CeedInt i = 0; i < num_input_fields; i++) {
139     code << "  readQuads(" << input_sizes[i] << ", Q, q, "
140          << "in_" << i << ", U_" << i << ");\n";
141   }
142   code << "\n";
143 
144   // QFunction
145   code << "  // -- Call QFunction\n";
146   code << "  " << qf_name_view << "(ctx, 1, inputs, outputs);\n\n";
147 
148   // Write outputs
149   code << "  // -- Write outputs\n";
150   for (CeedInt i = 0; i < num_output_fields; i++) {
151     code << "  writeQuads(" << output_sizes[i] << ", Q, q, "
152          << "V_" << i << ", out_" << i << ");\n";
153   }
154   code << "\n";
155 
156   // End kernel function body
157   code << "}\n";
158   code << "}\n";
159 
160   // View kernel for debugging
161   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Generated QFunction Kernels:\n");
162   CeedDebug(ceed, code.str().c_str());
163 
164   // Compile kernel
165   CeedCallBackend(CeedBuildModule_Sycl(ceed, code.str(), &impl->sycl_module));
166   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, kernel_name, &impl->QFunction));
167 
168   // Cleanup
169   CeedCallBackend(CeedFree(&qfunction_source));
170   CeedCallBackend(CeedFree(&read_write_kernel_path));
171   CeedCallBackend(CeedFree(&read_write_kernel_source));
172 
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177