ceed-cuda-compile.cpp (daaf13a462f999a7d367f3df68e0e3c34270722c) ceed-cuda-compile.cpp (b13efd58b277efef1db70d6f06eaaf4d415a7642)
1// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3//
4// SPDX-License-Identifier: BSD-2-Clause
5//
6// This file is part of CEED: http://github.com/ceed
7
8#include "ceed-cuda-compile.h"

--- 24 unchanged lines hidden (view full) ---

33
34//------------------------------------------------------------------------------
35// Compile CUDA kernel
36//------------------------------------------------------------------------------
37int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const CeedInt num_defines, ...) {
38 size_t ptx_size;
39 char *ptx;
40 const char *jit_defs_path, *jit_defs_source;
1// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3//
4// SPDX-License-Identifier: BSD-2-Clause
5//
6// This file is part of CEED: http://github.com/ceed
7
8#include "ceed-cuda-compile.h"

--- 24 unchanged lines hidden (view full) ---

33
34//------------------------------------------------------------------------------
35// Compile CUDA kernel
36//------------------------------------------------------------------------------
37int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const CeedInt num_defines, ...) {
38 size_t ptx_size;
39 char *ptx;
40 const char *jit_defs_path, *jit_defs_source;
41 const int num_opts = 4;
42 const char *opts[num_opts];
41 const int num_opts = 3;
42 CeedInt num_jit_source_dirs = 0;
43 const char **opts;
43 nvrtcProgram prog;
44 struct cudaDeviceProp prop;
45 Ceed_Cuda *ceed_data;
46
47 cudaFree(0); // Make sure a Context exists for nvrtc
48
49 std::ostringstream code;
50

--- 21 unchanged lines hidden (view full) ---

72 jit_defs_source = source;
73 }
74 code << jit_defs_source;
75 code << "\n\n";
76 CeedCallBackend(CeedFree(&jit_defs_path));
77 CeedCallBackend(CeedFree(&jit_defs_source));
78
79 // Non-macro options
44 nvrtcProgram prog;
45 struct cudaDeviceProp prop;
46 Ceed_Cuda *ceed_data;
47
48 cudaFree(0); // Make sure a Context exists for nvrtc
49
50 std::ostringstream code;
51

--- 21 unchanged lines hidden (view full) ---

73 jit_defs_source = source;
74 }
75 code << jit_defs_source;
76 code << "\n\n";
77 CeedCallBackend(CeedFree(&jit_defs_path));
78 CeedCallBackend(CeedFree(&jit_defs_source));
79
80 // Non-macro options
81 CeedCallBackend(CeedCalloc(num_opts, &opts));
80 opts[0] = "-default-device";
81 CeedCallBackend(CeedGetData(ceed, &ceed_data));
82 CeedCallCuda(ceed, cudaGetDeviceProperties(&prop, ceed_data->device_id));
83 std::string arch_arg =
84#if CUDA_VERSION >= 11010
85 // NVRTC used to support only virtual architectures through the option
86 // -arch, since it was only emitting PTX. It will now support actual
87 // architectures as well to emit SASS.
88 // https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#dynamic-code-generation
89 "-arch=sm_"
90#else
91 "-arch=compute_"
92#endif
93 + std::to_string(prop.major) + std::to_string(prop.minor);
94 opts[1] = arch_arg.c_str();
95 opts[2] = "-Dint32_t=int";
82 opts[0] = "-default-device";
83 CeedCallBackend(CeedGetData(ceed, &ceed_data));
84 CeedCallCuda(ceed, cudaGetDeviceProperties(&prop, ceed_data->device_id));
85 std::string arch_arg =
86#if CUDA_VERSION >= 11010
87 // NVRTC used to support only virtual architectures through the option
88 // -arch, since it was only emitting PTX. It will now support actual
89 // architectures as well to emit SASS.
90 // https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#dynamic-code-generation
91 "-arch=sm_"
92#else
93 "-arch=compute_"
94#endif
95 + std::to_string(prop.major) + std::to_string(prop.minor);
96 opts[1] = arch_arg.c_str();
97 opts[2] = "-Dint32_t=int";
96 opts[3] = "-I/home/jeremy/Dev/libCEED/include/ceed/jit-source/"
98 {
99 const char **jit_source_dirs;
97
100
101 CeedCallBackend(CeedGetJitSourceRoots(ceed, &num_jit_source_dirs, &jit_source_dirs));
102 CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs, &opts));
103 for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
104 std::ostringstream include_dirs_arg;
105
106 include_dirs_arg << "-I" << jit_source_dirs[i];
107 CeedCallBackend(CeedStringAllocCopy(include_dirs_arg.str().c_str(), (char **)&opts[num_opts + i]));
108 }
109 CeedCallBackend(CeedRestoreJitSourceRoots(ceed, &jit_source_dirs));
110 }
111
98 // Add string source argument provided in call
99 code << source;
100
101 // Create Program
102 CeedCallNvrtc(ceed, nvrtcCreateProgram(&prog, code.str().c_str(), NULL, 0, NULL, NULL));
103
104 // Compile kernel
112 // Add string source argument provided in call
113 code << source;
114
115 // Create Program
116 CeedCallNvrtc(ceed, nvrtcCreateProgram(&prog, code.str().c_str(), NULL, 0, NULL, NULL));
117
118 // Compile kernel
105 nvrtcResult result = nvrtcCompileProgram(prog, num_opts, opts);
119 nvrtcResult result = nvrtcCompileProgram(prog, num_opts + num_jit_source_dirs, opts);
106
120
121 for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
122 CeedCallBackend(CeedFree(&opts[num_opts + i]));
123 }
124 CeedCallBackend(CeedFree(&opts));
107 if (result != NVRTC_SUCCESS) {
108 char *log;
109 size_t log_size;
110
111 CeedDebug256(ceed, CEED_DEBUG_COLOR_ERROR, "---------- CEED JIT SOURCE FAILED TO COMPILE ----------\n");
112 CeedDebug(ceed, "Source:\n%s\n", code.str().c_str());
113 CeedDebug256(ceed, CEED_DEBUG_COLOR_ERROR, "---------- CEED JIT SOURCE FAILED TO COMPILE ----------\n");
114 CeedCallNvrtc(ceed, nvrtcGetProgramLogSize(prog, &log_size));

--- 83 unchanged lines hidden ---
125 if (result != NVRTC_SUCCESS) {
126 char *log;
127 size_t log_size;
128
129 CeedDebug256(ceed, CEED_DEBUG_COLOR_ERROR, "---------- CEED JIT SOURCE FAILED TO COMPILE ----------\n");
130 CeedDebug(ceed, "Source:\n%s\n", code.str().c_str());
131 CeedDebug256(ceed, CEED_DEBUG_COLOR_ERROR, "---------- CEED JIT SOURCE FAILED TO COMPILE ----------\n");
132 CeedCallNvrtc(ceed, nvrtcGetProgramLogSize(prog, &log_size));

--- 83 unchanged lines hidden ---