xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunction-load.cpp (revision 3d8e882215d238700cdceb37404f76ca7fa24eaa)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <iostream>
12 #include <sstream>
13 #include <string.h>
14 #include "ceed-hip-ref.h"
15 #include "../hip/ceed-hip-compile.h"
16 
17 //------------------------------------------------------------------------------
18 // Build QFunction kernel
19 //------------------------------------------------------------------------------
20 extern "C" int CeedHipBuildQFunction(CeedQFunction qf) {
21   CeedInt ierr;
22   using std::ostringstream;
23   using std::string;
24   Ceed ceed;
25   CeedQFunctionGetCeed(qf, &ceed);
26   Ceed_Hip *ceed_Hip;
27   ierr = CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr);
28   CeedQFunction_Hip *data;
29   ierr = CeedQFunctionGetData(qf, (void **)&data); CeedChkBackend(ierr);
30 
31   // QFunction is built
32   if (data->QFunction)
33     return CEED_ERROR_SUCCESS;
34 
35   if (!data->qfunction_source)
36     // LCOV_EXCL_START
37     return CeedError(ceed, CEED_ERROR_BACKEND,
38                      "No QFunction source or hipFunction_t provided.");
39   // LCOV_EXCL_STOP
40 
41   // QFunction kernel generation
42   CeedInt num_input_fields, num_output_fields, size;
43   CeedQFunctionField *input_fields, *output_fields;
44   ierr = CeedQFunctionGetFields(qf, &num_input_fields, &input_fields,
45                                 &num_output_fields, &output_fields);
46   CeedChkBackend(ierr);
47 
48   // Build strings for final kernel
49   char *read_write_kernel_path, *read_write_kernel_source;
50   ierr = CeedPathConcatenate(ceed, __FILE__, "kernels/hip-ref-qfunction.h",
51                              &read_write_kernel_path); CeedChkBackend(ierr);
52   CeedDebug256(ceed, 2, "----- Loading QFunction Read/Write Kernel Source -----\n");
53   ierr = CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &read_write_kernel_source);
54   CeedChkBackend(ierr);
55   CeedDebug256(ceed, 2, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n");
56   string qfunction_source(data->qfunction_source);
57   string qfunction_name(data->qfunction_name);
58   string read_write(read_write_kernel_source);
59   string kernel_name = "CeedKernel_Hip_ref_" + qfunction_name;
60   ostringstream code;
61 
62   // Defintions
63   code << "\n#define CEED_QFUNCTION(name) inline __device__ int name\n";
64   code << "#define CEED_QFUNCTION_HELPER inline __device__ __forceinline__\n";
65   code << "#define CeedPragmaSIMD\n";
66   code << "#define CEED_ERROR_SUCCESS 0\n";
67   code << "#define CEED_Q_VLA 1\n\n";
68   code << "typedef struct { const CeedScalar* inputs[16]; CeedScalar* outputs[16]; } Fields_Hip;\n";
69   code << read_write;
70   code << qfunction_source;
71   code << "\n";
72   code << "extern \"C\" __launch_bounds__(BLOCK_SIZE)\n";
73   code << "__global__ void " << kernel_name << "(void *ctx, CeedInt Q, Fields_Hip fields) {\n";
74 
75   // Inputs
76   code << "  // Input fields\n";
77   for (CeedInt i = 0; i < num_input_fields; i++) {
78     ierr = CeedQFunctionFieldGetSize(input_fields[i], &size); CeedChkBackend(ierr);
79     code << "  const CeedInt size_input_" << i << " = " << size << ";\n";
80     code << "  CeedScalar input_" << i << "[size_input_" << i << "];\n";
81   }
82   code << "  const CeedScalar* inputs[" << num_input_fields << "];\n";
83   for (CeedInt i = 0; i < num_input_fields; i++) {
84     code << "  inputs[" << i << "] = input_" << i << ";\n";
85   }
86   code << "\n";
87 
88   // Outputs
89   code << "  // Output fields\n";
90   for (CeedInt i = 0; i < num_output_fields; i++) {
91     ierr = CeedQFunctionFieldGetSize(output_fields[i], &size); CeedChkBackend(ierr);
92     code << "  const CeedInt size_output_" << i << " = " << size << ";\n";
93     code << "  CeedScalar output_" << i << "[size_output_" << i << "];\n";
94   }
95   code << "  CeedScalar* outputs[" << num_output_fields << "];\n";
96   for (CeedInt i = 0; i < num_output_fields; i++) {
97     code << "  outputs[" << i << "] = output_" << i << ";\n";
98   }
99   code << "\n";
100 
101   // Loop over quadrature points
102   code << "  // Loop over quadrature points\n";
103   code << "  for (CeedInt q = blockIdx.x * blockDim.x + threadIdx.x; q < Q; q += blockDim.x * gridDim.x) {\n";
104 
105   // Load inputs
106   code << "    // -- Load inputs\n";
107   for (CeedInt i = 0; i < num_input_fields; i++) {
108     code << "    readQuads<size_input_" << i << ">(q, Q, fields.inputs[" << i << "], input_" << i << ");\n";
109   }
110   code << "\n";
111 
112   // QFunction
113   code << "    // -- Call QFunction\n";
114   code << "    " << qfunction_name << "(ctx, 1, inputs, outputs);\n\n";
115 
116   // Write outputs
117   code << "    // -- Write outputs\n";
118   for (CeedInt i = 0; i < num_output_fields; i++) {
119     code << "    writeQuads<size_output_" << i << ">(q, Q, output_" << i << ", fields.outputs[" << i << "]);\n";
120   }
121   code << "  }\n";
122   code << "}\n";
123 
124   // View kernel for debugging
125   CeedDebug256(ceed, 2, "Generated QFunction Kernels:\n");
126   CeedDebug(ceed, code.str().c_str());
127 
128   // Compile kernel
129   ierr = CeedCompileHip(ceed, code.str().c_str(), &data->module,
130 		        1, "BLOCK_SIZE", ceed_Hip->opt_block_size);
131   CeedChkBackend(ierr);
132   ierr = CeedGetKernelHip(ceed, data->module, kernel_name.c_str(), &data->QFunction);
133   CeedChkBackend(ierr);
134 
135   // Cleanup
136   ierr = CeedFree(&data->qfunction_source); CeedChkBackend(ierr);
137   ierr = CeedFree(&read_write_kernel_path); CeedChkBackend(ierr);
138   ierr = CeedFree(&read_write_kernel_source); CeedChkBackend(ierr);
139 
140   return CEED_ERROR_SUCCESS;
141 }
142 //------------------------------------------------------------------------------
143