xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 66c9fbdd036b1e887ebf0d2bef6dbdcafd086d45)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmblasinterface.hpp>
6 #include <petsc/private/logimpl.h>
7 
8 #include <petsc/private/cpp/array.hpp>
9 
10 #include "../segmentedmempool.hpp"
11 #include "cupmallocator.hpp"
12 #include "cupmstream.hpp"
13 #include "cupmevent.hpp"
14 
15 #if defined(__cplusplus)
16 
17 namespace Petsc
18 {
19 
20 namespace device
21 {
22 
23 namespace cupm
24 {
25 
26 namespace impl
27 {
28 
29 template <DeviceType T>
30 class DeviceContext : BlasInterface<T> {
31 public:
32   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T);
33 
34 private:
35   template <typename H, std::size_t>
36   struct HandleTag {
37     using type = H;
38   };
39 
40   using stream_tag = HandleTag<cupmStream_t, 0>;
41   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
42   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
43 
44   using stream_type = CUPMStream<T>;
45   using event_type  = CUPMEvent<T>;
46 
47 public:
48   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
49   // header, but since we are using the power of templates it must be declared part of
50   // this class to have easy access the same typedefs. Technically one can make a
51   // templated struct outside the class but it's more code for the same result.
52   struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
53     stream_type stream{};
54     cupmEvent_t event{};
55     cupmEvent_t begin{}; // timer-only
56     cupmEvent_t end{};   // timer-only
57   #if PetscDefined(USE_DEBUG)
58     PetscBool timerInUse{};
59   #endif
60     cupmBlasHandle_t   blas{};
61     cupmSolverHandle_t solver{};
62 
63     constexpr PetscDeviceContext_IMPLS() noexcept = default;
64 
65     PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); }
66 
67     PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; }
68 
69     PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; }
70   };
71 
72 private:
73   static bool                                                     initialized_;
74   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
75   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
76 
77   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
78 
79   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
80 
81   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
82 
83   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
84 
85   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
86   // handles
87   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { return 0; }
88 
89   PETSC_NODISCARD static PetscErrorCode create_handle_(blas_tag, cupmBlasHandle_t &handle) noexcept
90   {
91     PetscLogEvent event;
92 
93     PetscFunctionBegin;
94     if (PetscLikely(handle)) PetscFunctionReturn(0);
95     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
96     PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
97     for (auto i = 0; i < 3; ++i) {
98       auto cberr = cupmBlasCreate(&handle);
99       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
100       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
101       if (i != 2) {
102         PetscCall(PetscSleep(3));
103         continue;
104       }
105       PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
106     }
107     PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
108     PetscCall(PetscLogEventResume_Internal(event));
109     PetscFunctionReturn(0);
110   }
111 
112   PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag tag, PetscDeviceContext dctx) noexcept
113   {
114     const auto dci    = impls_cast_(dctx);
115     auto      &handle = blashandles_[dctx->device->deviceId];
116 
117     PetscFunctionBegin;
118     PetscCall(create_handle_(tag, handle));
119     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
120     dci->blas = handle;
121     PetscFunctionReturn(0);
122   }
123 
124   PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(solver_tag, cupmSolverHandle_t &handle))
125   {
126     PetscLogEvent event;
127 
128     PetscFunctionBegin;
129     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
130     PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
131     PetscCall(cupmBlasInterface_t::InitializeHandle(handle));
132     PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
133     PetscCall(PetscLogEventResume_Internal(event));
134     PetscFunctionReturn(0);
135   }
136 
137   PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag tag, PetscDeviceContext dctx) noexcept
138   {
139     const auto dci    = impls_cast_(dctx);
140     auto      &handle = solverhandles_[dctx->device->deviceId];
141 
142     PetscFunctionBegin;
143     PetscCall(create_handle_(tag, handle));
144     PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream()));
145     dci->solver = handle;
146     PetscFunctionReturn(0);
147   }
148 
149   PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
150   {
151     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
152 
153     PetscFunctionBegin;
154     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
155                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
156     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
157     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
158     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
159     PetscFunctionReturn(0);
160   }
161 
162   PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
163 
164   PETSC_NODISCARD static PetscErrorCode finalize_() noexcept
165   {
166     PetscFunctionBegin;
167     for (auto &&handle : blashandles_) {
168       if (handle) {
169         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
170         handle = nullptr;
171       }
172     }
173     for (auto &&handle : solverhandles_) {
174       if (handle) {
175         PetscCall(cupmBlasInterface_t::DestroyHandle(handle));
176         handle = nullptr;
177       }
178     }
179     initialized_ = false;
180     PetscFunctionReturn(0);
181   }
182 
183   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
184   PETSC_NODISCARD static PoolType &default_pool_() noexcept
185   {
186     static PoolType pool;
187     return pool;
188   }
189 
190   PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
191   {
192     PetscFunctionBegin;
193     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
194     PetscFunctionReturn(0);
195   }
196 
197 public:
198   // All of these functions MUST be static in order to be callable from C, otherwise they
199   // get the implicit 'this' pointer tacked on
200   PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
201   PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType));
202   PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
203   PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *));
204   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext));
205   PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
206   template <typename Handle_t>
207   PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *));
208   PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
209   PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *));
210   PETSC_CXX_COMPAT_DECL(PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **));
211   PETSC_CXX_COMPAT_DECL(PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **));
212   PETSC_CXX_COMPAT_DECL(PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode));
213   PETSC_CXX_COMPAT_DECL(PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t));
214   PETSC_CXX_COMPAT_DECL(PetscErrorCode createEvent(PetscDeviceContext, PetscEvent));
215   PETSC_CXX_COMPAT_DECL(PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent));
216   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent));
217 
218   // not a PetscDeviceContext method, this registers the class
219   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize());
220 
221   // clang-format off
222   const _DeviceContextOps ops = {
223     destroy,
224     changeStreamType,
225     setUp,
226     query,
227     waitForContext,
228     synchronize,
229     getHandle<blas_tag>,
230     getHandle<solver_tag>,
231     getHandle<stream_tag>,
232     beginTimer,
233     endTimer,
234     memAlloc,
235     memFree,
236     memCopy,
237     memSet,
238     createEvent,
239     recordEvent,
240     waitForEvent
241   };
242   // clang-format on
243 };
244 
245 // not a PetscDeviceContext method, this initializes the CLASS
246 template <DeviceType T>
247 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize())
248 {
249   PetscFunctionBegin;
250   if (PetscUnlikely(!initialized_)) {
251     cupmMemPool_t mempool;
252     uint64_t      threshold = UINT64_MAX;
253 
254     initialized_ = true;
255     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, 0));
256     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
257     blashandles_.fill(nullptr);
258     solverhandles_.fill(nullptr);
259     PetscCall(PetscRegisterFinalize(finalize_));
260   }
261   PetscFunctionReturn(0);
262 }
263 
264 template <DeviceType T>
265 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx))
266 {
267   PetscFunctionBegin;
268   if (const auto dci = impls_cast_(dctx)) {
269     PetscCall(dci->stream.destroy());
270     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event)));
271     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
272     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
273     delete dci;
274     dctx->data = nullptr;
275   }
276   PetscFunctionReturn(0);
277 }
278 
279 template <DeviceType T>
280 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype))
281 {
282   const auto dci = impls_cast_(dctx);
283 
284   PetscFunctionBegin;
285   PetscCall(dci->stream.destroy());
286   // set these to null so they aren't usable until setup is called again
287   dci->blas   = nullptr;
288   dci->solver = nullptr;
289   PetscFunctionReturn(0);
290 }
291 
292 template <DeviceType T>
293 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx))
294 {
295   const auto dci   = impls_cast_(dctx);
296   auto      &event = dci->event;
297 
298   PetscFunctionBegin;
299   PetscCall(check_current_device_(dctx));
300   PetscCall(dci->stream.change_type(dctx->streamType));
301   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
302   #if PetscDefined(USE_DEBUG)
303   dci->timerInUse = PETSC_FALSE;
304   #endif
305   PetscFunctionReturn(0);
306 }
307 
308 template <DeviceType T>
309 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle))
310 {
311   PetscFunctionBegin;
312   PetscCall(check_current_device_(dctx));
313   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
314   case cupmSuccess:
315     *idle = PETSC_TRUE;
316     break;
317   case cupmErrorNotReady:
318     *idle = PETSC_FALSE;
319     // reset the error
320     cerr = cupmGetLastError();
321     static_cast<void>(cerr);
322     break;
323   default:
324     PetscCallCUPM(cerr);
325     PetscUnreachable();
326   }
327   PetscFunctionReturn(0);
328 }
329 
330 template <DeviceType T>
331 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb))
332 {
333   const auto dcib  = impls_cast_(dctxb);
334   const auto event = dcib->event;
335 
336   PetscFunctionBegin;
337   PetscCall(check_current_device_(dctxa, dctxb));
338   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
339   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
340   PetscFunctionReturn(0);
341 }
342 
343 template <DeviceType T>
344 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx))
345 {
346   auto idle = PETSC_TRUE;
347 
348   PetscFunctionBegin;
349   PetscCall(query(dctx, &idle));
350   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
351   PetscFunctionReturn(0);
352 }
353 
354 template <DeviceType T>
355 template <typename handle_t>
356 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle))
357 {
358   PetscFunctionBegin;
359   PetscCall(initialize_handle_(handle_t{}, dctx));
360   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
361   PetscFunctionReturn(0);
362 }
363 
364 template <DeviceType T>
365 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx))
366 {
367   const auto dci = impls_cast_(dctx);
368 
369   PetscFunctionBegin;
370   PetscCall(check_current_device_(dctx));
371   #if PetscDefined(USE_DEBUG)
372   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
373   dci->timerInUse = PETSC_TRUE;
374   #endif
375   if (!dci->begin) {
376     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
377     PetscCallCUPM(cupmEventCreate(&dci->begin));
378     PetscCallCUPM(cupmEventCreate(&dci->end));
379   }
380   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
381   PetscFunctionReturn(0);
382 }
383 
384 template <DeviceType T>
385 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed))
386 {
387   float      gtime;
388   const auto dci = impls_cast_(dctx);
389   const auto end = dci->end;
390 
391   PetscFunctionBegin;
392   PetscCall(check_current_device_(dctx));
393   #if PetscDefined(USE_DEBUG)
394   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
395   dci->timerInUse = PETSC_FALSE;
396   #endif
397   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
398   PetscCallCUPM(cupmEventSynchronize(end));
399   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
400   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
401   PetscFunctionReturn(0);
402 }
403 
404 template <DeviceType T>
405 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest))
406 {
407   const auto &stream = impls_cast_(dctx)->stream;
408 
409   PetscFunctionBegin;
410   PetscCall(check_current_device_(dctx));
411   PetscCall(check_memtype_(mtype, "allocating"));
412   if (PetscMemTypeHost(mtype)) {
413     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
414   } else {
415     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
416   }
417   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
418   PetscFunctionReturn(0);
419 }
420 
421 template <DeviceType T>
422 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr))
423 {
424   const auto &stream = impls_cast_(dctx)->stream;
425 
426   PetscFunctionBegin;
427   PetscCall(check_current_device_(dctx));
428   PetscCall(check_memtype_(mtype, "freeing"));
429   if (!*ptr) PetscFunctionReturn(0);
430   if (PetscMemTypeHost(mtype)) {
431     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
432     // if ptr exists still exists the pool didn't own it
433     if (*ptr) {
434       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
435 
436       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
437       if (registered) {
438         PetscCallCUPM(cupmFreeHost(*ptr));
439       } else if (managed) {
440         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
441       }
442     }
443   } else {
444     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
445     // if ptr exists still exists the pool didn't own it
446     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
447   }
448   PetscFunctionReturn(0);
449 }
450 
451 template <DeviceType T>
452 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode))
453 {
454   const auto stream = impls_cast_(dctx)->stream.get_stream();
455 
456   PetscFunctionBegin;
457   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
458   if (mode == PETSC_DEVICE_COPY_HTOH) {
459     // yes this is faster
460     if (cupmStreamQuery(stream) == cupmSuccess) {
461       PetscCall(PetscMemcpy(dest, src, n));
462       PetscFunctionReturn(0);
463     }
464     // in case cupmStreamQuery() did not return cupmErrorNotReady
465     PetscCallCUPM(cupmGetLastError());
466   }
467   PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
468   PetscFunctionReturn(0);
469 }
470 
471 template <DeviceType T>
472 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n))
473 {
474   PetscFunctionBegin;
475   PetscCall(check_current_device_(dctx));
476   PetscCall(check_memtype_(mtype, "zeroing"));
477   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
478   PetscFunctionReturn(0);
479 }
480 
481 template <DeviceType T>
482 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event))
483 {
484   PetscFunctionBegin;
485   PetscCallCXX(event->data = new event_type());
486   event->destroy = [](PetscEvent event) {
487     PetscFunctionBegin;
488     delete event_cast_(event);
489     event->data = nullptr;
490     PetscFunctionReturn(0);
491   };
492   PetscFunctionReturn(0);
493 }
494 
495 template <DeviceType T>
496 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event))
497 {
498   PetscFunctionBegin;
499   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
500   PetscFunctionReturn(0);
501 }
502 
503 template <DeviceType T>
504 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event))
505 {
506   PetscFunctionBegin;
507   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
508   PetscFunctionReturn(0);
509 }
510 
511 // initialize the static member variables
512 template <DeviceType T>
513 bool DeviceContext<T>::initialized_ = false;
514 
515 template <DeviceType T>
516 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
517 
518 template <DeviceType T>
519 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
520 
521 } // namespace impl
522 
523 // shorten this one up a bit (and instantiate the templates)
524 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
525 using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
526 
527   // shorthand for what is an EXTREMELY long name
528   #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
529 
530 } // namespace cupm
531 
532 } // namespace device
533 
534 } // namespace Petsc
535 
536 #endif // __cplusplus
537 
#endif // PETSCDEVICECONTEXTCUPM_HPP
539