xref: /libCEED/backends/hip/ceed-hip-compile.cpp (revision 8d12f40e0e187f71c4a1a78742076f931e72da09)
15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
330f4f45fSnbeams //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
530f4f45fSnbeams //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
730f4f45fSnbeams 
82b730f8bSJeremy L Thompson #include "ceed-hip-compile.h"
92b730f8bSJeremy L Thompson 
1049aac155SJeremy L Thompson #include <ceed.h>
11ec3da8bcSJed Brown #include <ceed/backend.h>
12c9c2c079SJeremy L Thompson #include <ceed/jit-tools.h>
1330f4f45fSnbeams #include <stdarg.h>
143d576824SJeremy L Thompson #include <string.h>
15c85e8640SSebastian Grimberg #include <hip/hiprtc.h>
162b730f8bSJeremy L Thompson 
172b730f8bSJeremy L Thompson #include <sstream>
182b730f8bSJeremy L Thompson 
197fcac036SJeremy L Thompson #include "ceed-hip-common.h"
2030f4f45fSnbeams 
// Return a CEED_ERROR_BACKEND error carrying hiprtc's message if the hiprtc status `x` is not HIPRTC_SUCCESS.
// The message is passed as a "%s" argument (never as the format string itself, since CeedError is printf-style), and the
// local uses a name callers are unlikely to have, so `CeedChk_hiprtc(ceed, result)` with an outer `result` does not
// self-initialize the local from itself.
#define CeedChk_hiprtc(ceed, x)                                                                                     \
  do {                                                                                                              \
    hiprtcResult hiprtc_result_ = static_cast<hiprtcResult>(x);                                                     \
    if (hiprtc_result_ != HIPRTC_SUCCESS) {                                                                         \
      return CeedError((ceed), CEED_ERROR_BACKEND, "%s", hiprtcGetErrorString(hiprtc_result_));                    \
    }                                                                                                               \
  } while (0)
2630f4f45fSnbeams 
// Evaluate a hiprtc call exactly once, then hand its status to CeedChk_hiprtc (which returns from the caller on failure).
#define CeedCallHiprtc(ceed, ...)                      \
  do {                                                 \
    const int hiprtc_call_status_ = (__VA_ARGS__);     \
    CeedChk_hiprtc(ceed, hiprtc_call_status_);         \
  } while (0)
322b730f8bSJeremy L Thompson 
3330f4f45fSnbeams //------------------------------------------------------------------------------
3430f4f45fSnbeams // Compile HIP kernel
3530f4f45fSnbeams //------------------------------------------------------------------------------
36*8d12f40eSJeremy L Thompson static int CeedCompileCore_Hip(Ceed ceed, const char *source, const bool throw_error, bool *is_compile_good, hipModule_t *module,
37*8d12f40eSJeremy L Thompson                                const CeedInt num_defines, va_list args) {
38b7453713SJeremy L Thompson   size_t                 ptx_size;
3991adc9c8SJeremy L Thompson   char                  *ptx;
40a491a57eSJeremy L Thompson   const int              num_opts            = 4;
414753b775SJeremy L Thompson   CeedInt                num_jit_source_dirs = 0, num_jit_defines = 0;
42b13efd58SJeremy L Thompson   const char           **opts;
43b7453713SJeremy L Thompson   int                    runtime_version;
4430f4f45fSnbeams   hiprtcProgram          prog;
45b7453713SJeremy L Thompson   struct hipDeviceProp_t prop;
46b7453713SJeremy L Thompson   Ceed_Hip              *ceed_data;
47b7453713SJeremy L Thompson 
48b7453713SJeremy L Thompson   hipFree(0);  // Make sure a Context exists for hiprtc
4930f4f45fSnbeams 
5030f4f45fSnbeams   std::ostringstream code;
51c9c2c079SJeremy L Thompson 
52ea61e9acSJeremy L Thompson   // Add hip runtime include statement for generation if runtime < 40400000 (implies ROCm < 4.5)
532b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipRuntimeGetVersion(&runtime_version));
549faa5937SNatalie Beams   if (runtime_version < 40400000) {
5530f4f45fSnbeams     code << "\n#include <hip/hip_runtime.h>\n";
569faa5937SNatalie Beams   }
57ea61e9acSJeremy L Thompson   // With ROCm 4.5, need to include these definitions specifically for hiprtc (but cannot include the runtime header)
589faa5937SNatalie Beams   else {
599faa5937SNatalie Beams     code << "#include <stddef.h>\n";
609faa5937SNatalie Beams     code << "#define __forceinline__ inline __attribute__((always_inline))\n";
619faa5937SNatalie Beams     code << "#define HIP_DYNAMIC_SHARED(type, var) extern __shared__ type var[];\n";
629faa5937SNatalie Beams   }
6330f4f45fSnbeams 
64c9c2c079SJeremy L Thompson   // Kernel specific options, such as kernel constants
65c9c2c079SJeremy L Thompson   if (num_defines > 0) {
6630f4f45fSnbeams     char *name;
6730f4f45fSnbeams     int   val;
68b7453713SJeremy L Thompson 
69c9c2c079SJeremy L Thompson     for (int i = 0; i < num_defines; i++) {
7030f4f45fSnbeams       name = va_arg(args, char *);
7130f4f45fSnbeams       val  = va_arg(args, int);
7230f4f45fSnbeams       code << "#define " << name << " " << val << "\n";
7330f4f45fSnbeams     }
7430f4f45fSnbeams   }
7530f4f45fSnbeams 
76c9c2c079SJeremy L Thompson   // Standard libCEED definitions for HIP backends
7791adc9c8SJeremy L Thompson   code << "#include <ceed/jit-source/hip/hip-jit.h>\n\n";
7830f4f45fSnbeams 
7930f4f45fSnbeams   // Non-macro options
80b13efd58SJeremy L Thompson   CeedCallBackend(CeedCalloc(num_opts, &opts));
8130f4f45fSnbeams   opts[0] = "-default-device";
822b730f8bSJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, (void **)&ceed_data));
832b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipGetDeviceProperties(&prop, ceed_data->device_id));
840d0321e0SJeremy L Thompson   std::string arch_arg = "--gpu-architecture=" + std::string(prop.gcnArchName);
850d0321e0SJeremy L Thompson   opts[1]              = arch_arg.c_str();
86b3c5430cSnbeams   opts[2]              = "-munsafe-fp-atomics";
87a491a57eSJeremy L Thompson   opts[3]              = "-DCEED_RUNNING_JIT_PASS=1";
884753b775SJeremy L Thompson   // Additional include dirs
89b13efd58SJeremy L Thompson   {
90b13efd58SJeremy L Thompson     const char **jit_source_dirs;
91b13efd58SJeremy L Thompson 
92b13efd58SJeremy L Thompson     CeedCallBackend(CeedGetJitSourceRoots(ceed, &num_jit_source_dirs, &jit_source_dirs));
93b13efd58SJeremy L Thompson     CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs, &opts));
94b13efd58SJeremy L Thompson     for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
954753b775SJeremy L Thompson       std::ostringstream include_dir_arg;
96b13efd58SJeremy L Thompson 
974753b775SJeremy L Thompson       include_dir_arg << "-I" << jit_source_dirs[i];
984753b775SJeremy L Thompson       CeedCallBackend(CeedStringAllocCopy(include_dir_arg.str().c_str(), (char **)&opts[num_opts + i]));
99b13efd58SJeremy L Thompson     }
100b13efd58SJeremy L Thompson     CeedCallBackend(CeedRestoreJitSourceRoots(ceed, &jit_source_dirs));
101b13efd58SJeremy L Thompson   }
1024753b775SJeremy L Thompson   // User defines
1034753b775SJeremy L Thompson   {
1044753b775SJeremy L Thompson     const char **jit_defines;
1054753b775SJeremy L Thompson 
1064753b775SJeremy L Thompson     CeedCallBackend(CeedGetJitDefines(ceed, &num_jit_defines, &jit_defines));
1074753b775SJeremy L Thompson     CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs + num_jit_defines, &opts));
1084753b775SJeremy L Thompson     for (CeedInt i = 0; i < num_jit_defines; i++) {
1094753b775SJeremy L Thompson       std::ostringstream define_arg;
1104753b775SJeremy L Thompson 
1114753b775SJeremy L Thompson       define_arg << "-D" << jit_defines[i];
1124753b775SJeremy L Thompson       CeedCallBackend(CeedStringAllocCopy(define_arg.str().c_str(), (char **)&opts[num_opts + num_jit_source_dirs + i]));
1134753b775SJeremy L Thompson     }
1144753b775SJeremy L Thompson     CeedCallBackend(CeedRestoreJitDefines(ceed, &jit_defines));
1154753b775SJeremy L Thompson   }
11630f4f45fSnbeams 
11730f4f45fSnbeams   // Add string source argument provided in call
11830f4f45fSnbeams   code << source;
11930f4f45fSnbeams 
12030f4f45fSnbeams   // Create Program
1212b730f8bSJeremy L Thompson   CeedCallHiprtc(ceed, hiprtcCreateProgram(&prog, code.str().c_str(), NULL, 0, NULL, NULL));
12230f4f45fSnbeams 
12330f4f45fSnbeams   // Compile kernel
12426ef7cdaSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_ERROR, "---------- ATTEMPTING TO COMPILE JIT SOURCE ----------\n");
12526ef7cdaSJeremy L Thompson   CeedDebug(ceed, "Source:\n%s\n", code.str().c_str());
12626ef7cdaSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_ERROR, "---------- END OF JIT SOURCE ----------\n");
1274753b775SJeremy L Thompson   hiprtcResult result = hiprtcCompileProgram(prog, num_opts + num_jit_source_dirs + num_jit_defines, opts);
128b7453713SJeremy L Thompson 
129b13efd58SJeremy L Thompson   for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
130b13efd58SJeremy L Thompson     CeedCallBackend(CeedFree(&opts[num_opts + i]));
131b13efd58SJeremy L Thompson   }
1324753b775SJeremy L Thompson   for (CeedInt i = 0; i < num_jit_defines; i++) {
1334753b775SJeremy L Thompson     CeedCallBackend(CeedFree(&opts[num_opts + num_jit_source_dirs + i]));
1344753b775SJeremy L Thompson   }
135b13efd58SJeremy L Thompson   CeedCallBackend(CeedFree(&opts));
136*8d12f40eSJeremy L Thompson   *is_compile_good = result == HIPRTC_SUCCESS;
137*8d12f40eSJeremy L Thompson   if (!*is_compile_good && throw_error) {
1380d0321e0SJeremy L Thompson     size_t log_size;
13930f4f45fSnbeams     char  *log;
140b7453713SJeremy L Thompson 
141b7453713SJeremy L Thompson     CeedChk_hiprtc(ceed, hiprtcGetProgramLogSize(prog, &log_size));
1422b730f8bSJeremy L Thompson     CeedCallBackend(CeedMalloc(log_size, &log));
1432b730f8bSJeremy L Thompson     CeedCallHiprtc(ceed, hiprtcGetProgramLog(prog, log));
1442b730f8bSJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND, "%s\n%s", hiprtcGetErrorString(result), log);
14530f4f45fSnbeams   }
14630f4f45fSnbeams 
1472b730f8bSJeremy L Thompson   CeedCallHiprtc(ceed, hiprtcGetCodeSize(prog, &ptx_size));
1482b730f8bSJeremy L Thompson   CeedCallBackend(CeedMalloc(ptx_size, &ptx));
1492b730f8bSJeremy L Thompson   CeedCallHiprtc(ceed, hiprtcGetCode(prog, ptx));
1502b730f8bSJeremy L Thompson   CeedCallHiprtc(ceed, hiprtcDestroyProgram(&prog));
15130f4f45fSnbeams 
1522b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipModuleLoadData(module, ptx));
1532b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&ptx));
154e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
15530f4f45fSnbeams }
15630f4f45fSnbeams 
157*8d12f40eSJeremy L Thompson int CeedCompile_Hip(Ceed ceed, const char *source, hipModule_t *module, const CeedInt num_defines, ...) {
158*8d12f40eSJeremy L Thompson   bool    is_compile_good = true;
159*8d12f40eSJeremy L Thompson   va_list args;
160*8d12f40eSJeremy L Thompson 
161*8d12f40eSJeremy L Thompson   va_start(args, num_defines);
162*8d12f40eSJeremy L Thompson   CeedCallBackend(CeedCompileCore_Hip(ceed, source, true, &is_compile_good, module, num_defines, args));
163*8d12f40eSJeremy L Thompson   va_end(args);
164*8d12f40eSJeremy L Thompson   return CEED_ERROR_SUCCESS;
165*8d12f40eSJeremy L Thompson }
166*8d12f40eSJeremy L Thompson 
167*8d12f40eSJeremy L Thompson int CeedTryCompile_Hip(Ceed ceed, const char *source, bool *is_compile_good, hipModule_t *module, const CeedInt num_defines, ...) {
168*8d12f40eSJeremy L Thompson   va_list args;
169*8d12f40eSJeremy L Thompson 
170*8d12f40eSJeremy L Thompson   va_start(args, num_defines);
171*8d12f40eSJeremy L Thompson   CeedCallBackend(CeedCompileCore_Hip(ceed, source, false, is_compile_good, module, num_defines, args));
172*8d12f40eSJeremy L Thompson   va_end(args);
173*8d12f40eSJeremy L Thompson   return CEED_ERROR_SUCCESS;
174*8d12f40eSJeremy L Thompson }
175*8d12f40eSJeremy L Thompson 
17630f4f45fSnbeams //------------------------------------------------------------------------------
17730f4f45fSnbeams // Get HIP kernel
17830f4f45fSnbeams //------------------------------------------------------------------------------
179eb7e6cafSJeremy L Thompson int CeedGetKernel_Hip(Ceed ceed, hipModule_t module, const char *name, hipFunction_t *kernel) {
1802b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipModuleGetFunction(kernel, module, name));
181e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
18230f4f45fSnbeams }
18330f4f45fSnbeams 
18430f4f45fSnbeams //------------------------------------------------------------------------------
18530f4f45fSnbeams // Run HIP kernel
18630f4f45fSnbeams //------------------------------------------------------------------------------
187eb7e6cafSJeremy L Thompson int CeedRunKernel_Hip(Ceed ceed, hipFunction_t kernel, const int grid_size, const int block_size, void **args) {
1882b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipModuleLaunchKernel(kernel, grid_size, 1, 1, block_size, 1, 1, 0, NULL, args, NULL));
189e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
19030f4f45fSnbeams }
19130f4f45fSnbeams 
19230f4f45fSnbeams //------------------------------------------------------------------------------
19330f4f45fSnbeams // Run HIP kernel for spatial dimension
19430f4f45fSnbeams //------------------------------------------------------------------------------
195eb7e6cafSJeremy L Thompson int CeedRunKernelDim_Hip(Ceed ceed, hipFunction_t kernel, const int grid_size, const int block_size_x, const int block_size_y, const int block_size_z,
1962b730f8bSJeremy L Thompson                          void **args) {
1972b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipModuleLaunchKernel(kernel, grid_size, 1, 1, block_size_x, block_size_y, block_size_z, 0, NULL, args, NULL));
198e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
19930f4f45fSnbeams }
20030f4f45fSnbeams 
20130f4f45fSnbeams //------------------------------------------------------------------------------
202e15f9bd0SJeremy L Thompson // Run HIP kernel for spatial dimension with shared memory
20330f4f45fSnbeams //------------------------------------------------------------------------------
204*8d12f40eSJeremy L Thompson static int CeedRunKernelDimSharedCore_Hip(Ceed ceed, hipFunction_t kernel, const int grid_size, const int block_size_x, const int block_size_y,
205*8d12f40eSJeremy L Thompson                                           const int block_size_z, const int shared_mem_size, const bool throw_error, bool *is_good_run, void **args) {
206*8d12f40eSJeremy L Thompson   hipError_t result = hipModuleLaunchKernel(kernel, grid_size, 1, 1, block_size_x, block_size_y, block_size_z, shared_mem_size, NULL, args, NULL);
207*8d12f40eSJeremy L Thompson 
208*8d12f40eSJeremy L Thompson   *is_good_run = result == hipSuccess;
209*8d12f40eSJeremy L Thompson   if (throw_error) CeedCallHip(ceed, result);
210*8d12f40eSJeremy L Thompson   return CEED_ERROR_SUCCESS;
211*8d12f40eSJeremy L Thompson }
212*8d12f40eSJeremy L Thompson 
213eb7e6cafSJeremy L Thompson int CeedRunKernelDimShared_Hip(Ceed ceed, hipFunction_t kernel, const int grid_size, const int block_size_x, const int block_size_y,
2142b730f8bSJeremy L Thompson                                const int block_size_z, const int shared_mem_size, void **args) {
215*8d12f40eSJeremy L Thompson   bool is_good_run = true;
216*8d12f40eSJeremy L Thompson 
217*8d12f40eSJeremy L Thompson   CeedCallBackend(
218*8d12f40eSJeremy L Thompson       CeedRunKernelDimSharedCore_Hip(ceed, kernel, grid_size, block_size_x, block_size_y, block_size_z, shared_mem_size, true, &is_good_run, args));
219*8d12f40eSJeremy L Thompson   return CEED_ERROR_SUCCESS;
220*8d12f40eSJeremy L Thompson }
221*8d12f40eSJeremy L Thompson 
222*8d12f40eSJeremy L Thompson int CeedTryRunKernelDimShared_Hip(Ceed ceed, hipFunction_t kernel, const int grid_size, const int block_size_x, const int block_size_y,
223*8d12f40eSJeremy L Thompson                                   const int block_size_z, const int shared_mem_size, bool *is_good_run, void **args) {
224*8d12f40eSJeremy L Thompson   CeedCallBackend(
225*8d12f40eSJeremy L Thompson       CeedRunKernelDimSharedCore_Hip(ceed, kernel, grid_size, block_size_x, block_size_y, block_size_z, shared_mem_size, false, is_good_run, args));
226e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
22730f4f45fSnbeams }
2282a86cc9dSSebastian Grimberg 
2292a86cc9dSSebastian Grimberg //------------------------------------------------------------------------------
230