xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunction-load.cpp (revision 650a5d66e4f30da5db797426ea50232309c53955)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #include <iostream>
14 #include <sstream>
15 
16 #include "../hip/ceed-hip-common.h"
17 #include "../hip/ceed-hip-compile.h"
18 #include "ceed-hip-ref.h"
19 
20 //------------------------------------------------------------------------------
21 // Build QFunction kernel
22 //------------------------------------------------------------------------------
23 extern "C" int CeedQFunctionBuildKernel_Hip_ref(CeedQFunction qf) {
24   using std::ostringstream;
25   using std::string;
26 
27   Ceed                ceed;
28   char               *read_write_kernel_path, *read_write_kernel_source;
29   Ceed_Hip           *ceed_Hip;
30   CeedInt             num_input_fields, num_output_fields, size;
31   CeedQFunctionField *input_fields, *output_fields;
32   CeedQFunction_Hip  *data;
33 
34   CeedQFunctionGetCeed(qf, &ceed);
35   CeedCallBackend(CeedGetData(ceed, &ceed_Hip));
36   CeedCallBackend(CeedQFunctionGetData(qf, (void **)&data));
37 
38   // QFunction is built
39   if (data->QFunction) return CEED_ERROR_SUCCESS;
40 
41   CeedCheck(data->qfunction_source, ceed, CEED_ERROR_BACKEND, "No QFunction source or hipFunction_t provided.");
42 
43   // QFunction kernel generation
44   CeedCallBackend(CeedQFunctionGetFields(qf, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
45 
46   // Build strings for final kernel
47   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-qfunction.h", &read_write_kernel_path));
48   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source -----\n");
49   CeedCallBackend(CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &read_write_kernel_source));
50   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction Read/Write Kernel Source Complete! -----\n");
51   string        qfunction_source(data->qfunction_source);
52   string        qfunction_name(data->qfunction_name);
53   string        read_write(read_write_kernel_source);
54   string        kernel_name = "CeedKernelHipRefQFunction_" + qfunction_name;
55   ostringstream code;
56 
57   // Defintions
58   code << read_write;
59   code << qfunction_source;
60   code << "\n";
61   code << "extern \"C\" __launch_bounds__(BLOCK_SIZE)\n";
62   code << "__global__ void " << kernel_name << "(void *ctx, CeedInt Q, Fields_Hip fields) {\n";
63 
64   // Inputs
65   code << "  // Input fields\n";
66   for (CeedInt i = 0; i < num_input_fields; i++) {
67     CeedCallBackend(CeedQFunctionFieldGetSize(input_fields[i], &size));
68     code << "  const CeedInt size_input_" << i << " = " << size << ";\n";
69     code << "  CeedScalar input_" << i << "[size_input_" << i << "];\n";
70   }
71   code << "  const CeedScalar* inputs[" << num_input_fields << "];\n";
72   for (CeedInt i = 0; i < num_input_fields; i++) {
73     code << "  inputs[" << i << "] = input_" << i << ";\n";
74   }
75   code << "\n";
76 
77   // Outputs
78   code << "  // Output fields\n";
79   for (CeedInt i = 0; i < num_output_fields; i++) {
80     CeedCallBackend(CeedQFunctionFieldGetSize(output_fields[i], &size));
81     code << "  const CeedInt size_output_" << i << " = " << size << ";\n";
82     code << "  CeedScalar output_" << i << "[size_output_" << i << "];\n";
83   }
84   code << "  CeedScalar* outputs[" << num_output_fields << "];\n";
85   for (CeedInt i = 0; i < num_output_fields; i++) {
86     code << "  outputs[" << i << "] = output_" << i << ";\n";
87   }
88   code << "\n";
89 
90   // Loop over quadrature points
91   code << "  // Loop over quadrature points\n";
92   code << "  for (CeedInt q = blockIdx.x * blockDim.x + threadIdx.x; q < Q; q += blockDim.x * gridDim.x) {\n";
93 
94   // Load inputs
95   code << "    // -- Load inputs\n";
96   for (CeedInt i = 0; i < num_input_fields; i++) {
97     code << "    readQuads<size_input_" << i << ">(q, Q, fields.inputs[" << i << "], input_" << i << ");\n";
98   }
99   code << "\n";
100 
101   // QFunction
102   code << "    // -- Call QFunction\n";
103   code << "    " << qfunction_name << "(ctx, 1, inputs, outputs);\n\n";
104 
105   // Write outputs
106   code << "    // -- Write outputs\n";
107   for (CeedInt i = 0; i < num_output_fields; i++) {
108     code << "    writeQuads<size_output_" << i << ">(q, Q, output_" << i << ", fields.outputs[" << i << "]);\n";
109   }
110   code << "  }\n";
111   code << "}\n";
112 
113   // View kernel for debugging
114   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Generated QFunction Kernels:\n");
115   CeedDebug(ceed, code.str().c_str());
116 
117   // Compile kernel
118   CeedCallBackend(CeedCompile_Hip(ceed, code.str().c_str(), &data->module, 1, "BLOCK_SIZE", ceed_Hip->opt_block_size));
119   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, kernel_name.c_str(), &data->QFunction));
120 
121   // Cleanup
122   CeedCallBackend(CeedFree(&data->qfunction_source));
123   CeedCallBackend(CeedFree(&read_write_kernel_path));
124   CeedCallBackend(CeedFree(&read_write_kernel_source));
125   return CEED_ERROR_SUCCESS;
126 }
127 
128 //------------------------------------------------------------------------------
129