xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 9371c9d470a9602b6d10a8bf50c9b2280a79e45a)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmblasinterface.hpp>
6 #include <petsc/private/logimpl.h>
7 
8 #include <array>
9 
10 namespace Petsc {
11 
12 namespace Device {
13 
14 namespace CUPM {
15 
16 namespace Impl {
17 
18 // Forward declare
19 template <DeviceType T>
20 class PETSC_VISIBILITY_INTERNAL DeviceContext;
21 
22 template <DeviceType T>
23 class DeviceContext : Impl::BlasInterface<T> {
24 public:
25   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T);
26 
27 private:
28   // for tag-based dispatch of handle retrieval
29   template <typename H, std::size_t>
30   struct HandleTag {
31     using type = H;
32   };
33   using stream_tag = HandleTag<cupmStream_t, 0>;
34   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
35   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
36 
37 public:
38   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
39   // header, but since we are using the power of templates it must be declared part of
40   // this class to have easy access the same typedefs. Technically one can make a
41   // templated struct outside the class but it's more code for the same result.
42   struct PetscDeviceContext_IMPLS {
43     cupmStream_t stream;
44     cupmEvent_t  event;
45     cupmEvent_t  begin; // timer-only
46     cupmEvent_t  end;   // timer-only
47 #if PetscDefined(USE_DEBUG)
48     PetscBool timerInUse;
49 #endif
50     cupmBlasHandle_t   blas;
51     cupmSolverHandle_t solver;
52 
53     PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) {
54       return this->stream;
55     }
56     PETSC_NODISCARD auto get(blas_tag) const -> decltype(this->blas) {
57       return this->blas;
58     }
59     PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) {
60       return this->solver;
61     }
62   };
63 
64 private:
65   static bool                                                     initialized_;
66   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
67   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
68 
69   PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr)) {
70     return static_cast<PetscDeviceContext_IMPLS *>(ptr->data);
71   }
72 
73   PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMBLAS_HANDLE_CREATE()) {
74     return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE;
75   }
76 
77   PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMSOLVER_HANDLE_CREATE()) {
78     return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE;
79   }
80 
81   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
82   // handles
83   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) {
84     return 0;
85   }
86 
87   PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmBlasHandle_t &handle)) {
88     PetscLogEvent event;
89 
90     PetscFunctionBegin;
91     if (PetscLikely(handle)) PetscFunctionReturn(0);
92     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
93     PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
94     for (auto i = 0; i < 3; ++i) {
95       auto cberr = cupmBlasCreate(&handle);
96       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
97       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
98       if (i != 2) {
99         PetscCall(PetscSleep(3));
100         continue;
101       }
102       PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
103     }
104     PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
105     PetscCall(PetscLogEventResume_Internal(event));
106     PetscFunctionReturn(0);
107   }
108 
109   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx)) {
110     const auto dci    = impls_cast_(dctx);
111     auto      &handle = blashandles_[dctx->device->deviceId];
112 
113     PetscFunctionBegin;
114     PetscCall(create_handle_(handle));
115     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream));
116     dci->blas = handle;
117     PetscFunctionReturn(0);
118   }
119 
120   PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle)) {
121     PetscLogEvent event;
122 
123     PetscFunctionBegin;
124     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
125     PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
126     PetscCall(cupmBlasInterface_t::InitializeHandle(handle));
127     PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
128     PetscCall(PetscLogEventResume_Internal(event));
129     PetscFunctionReturn(0);
130   }
131 
132   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx)) {
133     const auto dci    = impls_cast_(dctx);
134     auto      &handle = solverhandles_[dctx->device->deviceId];
135 
136     PetscFunctionBegin;
137     PetscCall(create_handle_(handle));
138     PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream));
139     dci->solver = handle;
140     PetscFunctionReturn(0);
141   }
142 
143   PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_()) {
144     PetscFunctionBegin;
145     for (auto &&handle : blashandles_) {
146       if (handle) {
147         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
148         handle = nullptr;
149       }
150     }
151     for (auto &&handle : solverhandles_) {
152       if (handle) {
153         PetscCall(cupmBlasInterface_t::DestroyHandle(handle));
154         handle = nullptr;
155       }
156     }
157     initialized_ = false;
158     PetscFunctionReturn(0);
159   }
160 
161 public:
162   const struct _DeviceContextOps ops = {
163     destroy, changeStreamType, setUp, query, waitForContext, synchronize, getHandle<blas_tag>, getHandle<solver_tag>, getHandle<stream_tag>, beginTimer, endTimer,
164   };
165 
166   // All of these functions MUST be static in order to be callable from C, otherwise they
167   // get the implicit 'this' pointer tacked on
168   PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
169   PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType));
170   PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
171   PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *));
172   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext));
173   PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
174   template <typename Handle_t>
175   PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *));
176   PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
177   PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *));
178 
179   // not a PetscDeviceContext method, this registers the class
180   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize());
181 };
182 
183 template <DeviceType T>
184 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize()) {
185   PetscFunctionBegin;
186   if (PetscUnlikely(!initialized_)) {
187     initialized_ = true;
188     PetscCall(PetscRegisterFinalize(finalize_));
189   }
190   PetscFunctionReturn(0);
191 }
192 
193 template <DeviceType T>
194 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) {
195   const auto dci = impls_cast_(dctx);
196 
197   PetscFunctionBegin;
198   if (dci->stream) PetscCallCUPM(cupmStreamDestroy(dci->stream));
199   if (dci->event) PetscCallCUPM(cupmEventDestroy(dci->event));
200   if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
201   if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
202   PetscCall(PetscFree(dctx->data));
203   PetscFunctionReturn(0);
204 }
205 
206 template <DeviceType T>
207 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) {
208   const auto dci = impls_cast_(dctx);
209 
210   PetscFunctionBegin;
211   if (auto &stream = dci->stream) {
212     PetscCallCUPM(cupmStreamDestroy(stream));
213     stream = nullptr;
214   }
215   // set these to null so they aren't usable until setup is called again
216   dci->blas   = nullptr;
217   dci->solver = nullptr;
218   PetscFunctionReturn(0);
219 }
220 
221 template <DeviceType T>
222 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) {
223   const auto dci    = impls_cast_(dctx);
224   auto      &stream = dci->stream;
225 
226   PetscFunctionBegin;
227   if (stream) {
228     PetscCallCUPM(cupmStreamDestroy(stream));
229     stream = nullptr;
230   }
231   switch (const auto stype = dctx->streamType) {
232   case PETSC_STREAM_GLOBAL_BLOCKING:
233     // don't create a stream for global blocking
234     break;
235   case PETSC_STREAM_DEFAULT_BLOCKING: PetscCallCUPM(cupmStreamCreate(&stream)); break;
236   case PETSC_STREAM_GLOBAL_NONBLOCKING: PetscCallCUPM(cupmStreamCreateWithFlags(&stream, cupmStreamNonBlocking)); break;
237   default: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_CORRUPT, "Invalid PetscStreamType %s", PetscStreamTypes[util::integral_value(stype)]); break;
238   }
239   if (!dci->event) PetscCallCUPM(cupmEventCreate(&dci->event));
240 #if PetscDefined(USE_DEBUG)
241   dci->timerInUse = PETSC_FALSE;
242 #endif
243   PetscFunctionReturn(0);
244 }
245 
246 template <DeviceType T>
247 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) {
248   cupmError_t cerr;
249 
250   PetscFunctionBegin;
251   cerr = cupmStreamQuery(impls_cast_(dctx)->stream);
252   if (cerr == cupmSuccess) *idle = PETSC_TRUE;
253   else {
254     // somethings gone wrong
255     if (PetscUnlikely(cerr != cupmErrorNotReady)) PetscCallCUPM(cerr);
256     *idle = PETSC_FALSE;
257   }
258   PetscFunctionReturn(0);
259 }
260 
261 template <DeviceType T>
262 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) {
263   auto dcib = impls_cast_(dctxb);
264 
265   PetscFunctionBegin;
266   PetscCallCUPM(cupmEventRecord(dcib->event, dcib->stream));
267   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream, dcib->event, 0));
268   PetscFunctionReturn(0);
269 }
270 
271 template <DeviceType T>
272 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) {
273   auto dci = impls_cast_(dctx);
274 
275   PetscFunctionBegin;
276   // in case anything was queued on the event
277   PetscCallCUPM(cupmStreamWaitEvent(dci->stream, dci->event, 0));
278   PetscCallCUPM(cupmStreamSynchronize(dci->stream));
279   PetscFunctionReturn(0);
280 }
281 
282 template <DeviceType T>
283 template <typename handle_t>
284 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) {
285   PetscFunctionBegin;
286   PetscCall(initialize_handle_(handle_t{}, dctx));
287   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
288   PetscFunctionReturn(0);
289 }
290 
291 template <DeviceType T>
292 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) {
293   auto dci = impls_cast_(dctx);
294 
295   PetscFunctionBegin;
296 #if PetscDefined(USE_DEBUG)
297   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
298   dci->timerInUse = PETSC_TRUE;
299 #endif
300   if (!dci->begin) {
301     PetscCallCUPM(cupmEventCreate(&dci->begin));
302     PetscCallCUPM(cupmEventCreate(&dci->end));
303   }
304   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream));
305   PetscFunctionReturn(0);
306 }
307 
308 template <DeviceType T>
309 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) {
310   float gtime;
311   auto  dci = impls_cast_(dctx);
312 
313   PetscFunctionBegin;
314 #if PetscDefined(USE_DEBUG)
315   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
316   dci->timerInUse = PETSC_FALSE;
317 #endif
318   PetscCallCUPM(cupmEventRecord(dci->end, dci->stream));
319   PetscCallCUPM(cupmEventSynchronize(dci->end));
320   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, dci->end));
321   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
322   PetscFunctionReturn(0);
323 }
324 
325 // initialize the static member variables
326 template <DeviceType T>
327 bool DeviceContext<T>::initialized_ = false;
328 
329 template <DeviceType T>
330 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
331 
332 template <DeviceType T>
333 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
334 
335 } // namespace Impl
336 
337 // shorten this one up a bit (and instantiate the templates)
338 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>;
339 using CUPMContextHip  = Impl::DeviceContext<DeviceType::HIP>;
340 
341 // shorthand for what is an EXTREMELY long name
342 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
343 
344 } // namespace CUPM
345 
346 } // namespace Device
347 
348 } // namespace Petsc
349 
#endif // PETSCDEVICECONTEXTCUPM_HPP
351