xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 6016107252cfa03568230349a14202a384fbf0c0)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmblasinterface.hpp>
6 #include <petsc/private/logimpl.h>
7 
8 #include <petsc/private/cpp/array.hpp>
9 
10 #include "../segmentedmempool.hpp"
11 #include "cupmallocator.hpp"
12 #include "cupmstream.hpp"
13 #include "cupmevent.hpp"
14 
15 #if defined(__cplusplus)
16 namespace Petsc {
17 
18 namespace device {
19 
20 namespace cupm {
21 
22 namespace impl {
23 
24 template <DeviceType T>
25 class DeviceContext : BlasInterface<T> {
26 public:
27   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T);
28 
29 private:
30   template <typename H, std::size_t>
31   struct HandleTag {
32     using type = H;
33   };
34 
35   using stream_tag = HandleTag<cupmStream_t, 0>;
36   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
37   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
38 
39   using stream_type = CUPMStream<T>;
40   using event_type  = CUPMEvent<T>;
41 
42 public:
43   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
44   // header, but since we are using the power of templates it must be declared part of
45   // this class to have easy access the same typedefs. Technically one can make a
46   // templated struct outside the class but it's more code for the same result.
47   struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
48     stream_type stream{};
49     cupmEvent_t event{};
50     cupmEvent_t begin{}; // timer-only
51     cupmEvent_t end{};   // timer-only
52 #if PetscDefined(USE_DEBUG)
53     PetscBool timerInUse{};
54 #endif
55     cupmBlasHandle_t   blas{};
56     cupmSolverHandle_t solver{};
57 
58     constexpr PetscDeviceContext_IMPLS() noexcept = default;
59 
60     PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept {
61       return this->stream.get_stream();
62     }
63 
64     PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept {
65       return this->blas;
66     }
67 
68     PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept {
69       return this->solver;
70     }
71   };
72 
73 private:
74   static bool                                                     initialized_;
75   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
76   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
77 
78   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept {
79     return static_cast<PetscDeviceContext_IMPLS *>(ptr->data);
80   }
81 
82   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept {
83     return static_cast<CUPMEvent<T> *>(event->data);
84   }
85 
86   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept {
87     return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE;
88   }
89 
90   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept {
91     return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE;
92   }
93 
94   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
95   // handles
96   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) {
97     return 0;
98   }
99 
100   PETSC_NODISCARD static PetscErrorCode create_handle_(cupmBlasHandle_t &handle) noexcept {
101     PetscLogEvent event;
102 
103     PetscFunctionBegin;
104     if (PetscLikely(handle)) PetscFunctionReturn(0);
105     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
106     PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
107     for (auto i = 0; i < 3; ++i) {
108       auto cberr = cupmBlasCreate(&handle);
109       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
110       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
111       if (i != 2) {
112         PetscCall(PetscSleep(3));
113         continue;
114       }
115       PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
116     }
117     PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
118     PetscCall(PetscLogEventResume_Internal(event));
119     PetscFunctionReturn(0);
120   }
121 
122   PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept {
123     const auto dci    = impls_cast_(dctx);
124     auto      &handle = blashandles_[dctx->device->deviceId];
125 
126     PetscFunctionBegin;
127     PetscCall(create_handle_(handle));
128     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
129     dci->blas = handle;
130     PetscFunctionReturn(0);
131   }
132 
133   PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle)) {
134     PetscLogEvent event;
135 
136     PetscFunctionBegin;
137     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
138     PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
139     PetscCall(cupmBlasInterface_t::InitializeHandle(handle));
140     PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
141     PetscCall(PetscLogEventResume_Internal(event));
142     PetscFunctionReturn(0);
143   }
144 
145   PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept {
146     const auto dci    = impls_cast_(dctx);
147     auto      &handle = solverhandles_[dctx->device->deviceId];
148 
149     PetscFunctionBegin;
150     PetscCall(create_handle_(handle));
151     PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream()));
152     dci->solver = handle;
153     PetscFunctionReturn(0);
154   }
155 
156   PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept {
157     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
158 
159     PetscFunctionBegin;
160     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
161                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
162     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
163     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
164     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
165     PetscFunctionReturn(0);
166   }
167 
168   PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept {
169     return check_current_device_(dctx, dctx);
170   }
171 
172   PETSC_NODISCARD static PetscErrorCode finalize_() noexcept {
173     PetscFunctionBegin;
174     for (auto &&handle : blashandles_) {
175       if (handle) {
176         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
177         handle = nullptr;
178       }
179     }
180     for (auto &&handle : solverhandles_) {
181       if (handle) {
182         PetscCall(cupmBlasInterface_t::DestroyHandle(handle));
183         handle = nullptr;
184       }
185     }
186     initialized_ = false;
187     PetscFunctionReturn(0);
188   }
189 
190   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
191   PETSC_NODISCARD static PoolType &default_pool_() noexcept {
192     static PoolType pool;
193     return pool;
194   }
195 
196   PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept {
197     PetscFunctionBegin;
198     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
199     PetscFunctionReturn(0);
200   }
201 
202 public:
203   // All of these functions MUST be static in order to be callable from C, otherwise they
204   // get the implicit 'this' pointer tacked on
205   PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
206   PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType));
207   PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
208   PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *));
209   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext));
210   PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
211   template <typename Handle_t>
212   PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *));
213   PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
214   PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *));
215   PETSC_CXX_COMPAT_DECL(PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, void **));
216   PETSC_CXX_COMPAT_DECL(PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **));
217   PETSC_CXX_COMPAT_DECL(PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode));
218   PETSC_CXX_COMPAT_DECL(PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t));
219   PETSC_CXX_COMPAT_DECL(PetscErrorCode createEvent(PetscDeviceContext, PetscEvent));
220   PETSC_CXX_COMPAT_DECL(PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent));
221   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent));
222 
223   // not a PetscDeviceContext method, this registers the class
224   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize());
225 
226   // clang-format off
227   const _DeviceContextOps ops = {
228     destroy,
229     changeStreamType,
230     setUp,
231     query,
232     waitForContext,
233     synchronize,
234     getHandle<blas_tag>,
235     getHandle<solver_tag>,
236     getHandle<stream_tag>,
237     beginTimer,
238     endTimer,
239     memAlloc,
240     memFree,
241     memCopy,
242     memSet,
243     createEvent,
244     recordEvent,
245     waitForEvent
246   };
247   // clang-format on
248 };
249 
250 // not a PetscDeviceContext method, this initializes the CLASS
251 template <DeviceType T>
252 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize()) {
253   PetscFunctionBegin;
254   if (PetscUnlikely(!initialized_)) {
255     cupmMemPool_t mempool;
256     uint64_t      threshold = UINT64_MAX;
257 
258     initialized_ = true;
259     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, 0));
260     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
261     blashandles_.fill(nullptr);
262     solverhandles_.fill(nullptr);
263     PetscCall(PetscRegisterFinalize(finalize_));
264   }
265   PetscFunctionReturn(0);
266 }
267 
268 template <DeviceType T>
269 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) {
270   PetscFunctionBegin;
271   if (const auto dci = impls_cast_(dctx)) {
272     PetscCall(dci->stream.destroy());
273     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event)));
274     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
275     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
276     delete dci;
277     dctx->data = nullptr;
278   }
279   PetscFunctionReturn(0);
280 }
281 
282 template <DeviceType T>
283 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) {
284   const auto dci = impls_cast_(dctx);
285 
286   PetscFunctionBegin;
287   PetscCall(dci->stream.destroy());
288   // set these to null so they aren't usable until setup is called again
289   dci->blas   = nullptr;
290   dci->solver = nullptr;
291   PetscFunctionReturn(0);
292 }
293 
294 template <DeviceType T>
295 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) {
296   const auto dci   = impls_cast_(dctx);
297   auto      &event = dci->event;
298 
299   PetscFunctionBegin;
300   PetscCall(check_current_device_(dctx));
301   PetscCall(dci->stream.change_type(dctx->streamType));
302   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
303 #if PetscDefined(USE_DEBUG)
304   dci->timerInUse = PETSC_FALSE;
305 #endif
306   PetscFunctionReturn(0);
307 }
308 
309 template <DeviceType T>
310 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) {
311   PetscFunctionBegin;
312   PetscCall(check_current_device_(dctx));
313   switch (const auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
314   case cupmSuccess: *idle = PETSC_TRUE; break;
315   case cupmErrorNotReady: *idle = PETSC_FALSE; break;
316   default: PetscCallCUPM(cerr); PetscUnreachable();
317   }
318   PetscFunctionReturn(0);
319 }
320 
321 template <DeviceType T>
322 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) {
323   const auto dcib  = impls_cast_(dctxb);
324   const auto event = dcib->event;
325 
326   PetscFunctionBegin;
327   PetscCall(check_current_device_(dctxa, dctxb));
328   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
329   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
330   PetscFunctionReturn(0);
331 }
332 
333 template <DeviceType T>
334 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) {
335   auto idle = PETSC_TRUE;
336 
337   PetscFunctionBegin;
338   PetscCall(query(dctx, &idle));
339   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
340   PetscFunctionReturn(0);
341 }
342 
343 template <DeviceType T>
344 template <typename handle_t>
345 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) {
346   PetscFunctionBegin;
347   PetscCall(initialize_handle_(handle_t{}, dctx));
348   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
349   PetscFunctionReturn(0);
350 }
351 
352 template <DeviceType T>
353 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) {
354   const auto dci = impls_cast_(dctx);
355 
356   PetscFunctionBegin;
357   PetscCall(check_current_device_(dctx));
358 #if PetscDefined(USE_DEBUG)
359   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
360   dci->timerInUse = PETSC_TRUE;
361 #endif
362   if (!dci->begin) {
363     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
364     PetscCallCUPM(cupmEventCreate(&dci->begin));
365     PetscCallCUPM(cupmEventCreate(&dci->end));
366   }
367   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
368   PetscFunctionReturn(0);
369 }
370 
371 template <DeviceType T>
372 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) {
373   float      gtime;
374   const auto dci = impls_cast_(dctx);
375   const auto end = dci->end;
376 
377   PetscFunctionBegin;
378   PetscCall(check_current_device_(dctx));
379 #if PetscDefined(USE_DEBUG)
380   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
381   dci->timerInUse = PETSC_FALSE;
382 #endif
383   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
384   PetscCallCUPM(cupmEventSynchronize(end));
385   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
386   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
387   PetscFunctionReturn(0);
388 }
389 
390 template <DeviceType T>
391 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, void **dest)) {
392   const auto &stream = impls_cast_(dctx)->stream;
393 
394   PetscFunctionBegin;
395   PetscCall(check_current_device_(dctx));
396   PetscCall(check_memtype_(mtype, "allocating"));
397   if (PetscMemTypeHost(mtype)) {
398     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream));
399     if (clear) std::memset(*dest, 0, n);
400   } else {
401     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream));
402     if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
403   }
404   PetscFunctionReturn(0);
405 }
406 
407 template <DeviceType T>
408 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr)) {
409   const auto &stream = impls_cast_(dctx)->stream;
410 
411   PetscFunctionBegin;
412   PetscCall(check_current_device_(dctx));
413   PetscCall(check_memtype_(mtype, "freeing"));
414   if (!*ptr) PetscFunctionReturn(0);
415   if (PetscMemTypeHost(mtype)) {
416     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
417     // if ptr exists still exists the pool didn't own it
418     if (*ptr) {
419       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
420 
421       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
422       if (registered) {
423         PetscCallCUPM(cupmFreeHost(*ptr));
424       } else if (managed) {
425         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
426       }
427     }
428   } else {
429     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
430     // if ptr exists still exists the pool didn't own it
431     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
432   }
433   PetscFunctionReturn(0);
434 }
435 
436 template <DeviceType T>
437 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode)) {
438   const auto stream = impls_cast_(dctx)->stream.get_stream();
439 
440   PetscFunctionBegin;
441   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
442   if (mode == PETSC_DEVICE_COPY_HTOH) {
443     // yes this is faster
444     if (cupmStreamQuery(stream) == cupmSuccess) {
445       PetscCall(PetscMemcpy(dest, src, n));
446       PetscFunctionReturn(0);
447     }
448     // in case cupmStreamQuery() did not return cupmErrorNotReady
449     PetscCallCUPM(cupmGetLastError());
450   }
451   PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
452   PetscFunctionReturn(0);
453 }
454 
455 template <DeviceType T>
456 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n)) {
457   auto vint = static_cast<int>(v);
458 
459   PetscFunctionBegin;
460   PetscCall(check_current_device_(dctx));
461   PetscCall(check_memtype_(mtype, "zeroing"));
462   if (PetscMemTypeHost(mtype)) {
463     // must call public sync to prune the dependency graph
464     PetscCall(PetscDeviceContextSynchronize(dctx));
465     std::memset(ptr, vint, n);
466   } else {
467     PetscCallCUPM(cupmMemsetAsync(ptr, vint, n, impls_cast_(dctx)->stream.get_stream()));
468   }
469   PetscFunctionReturn(0);
470 }
471 
472 template <DeviceType T>
473 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event)) {
474   PetscFunctionBegin;
475   PetscCallCXX(event->data = new event_type());
476   event->destroy = [](PetscEvent event) {
477     PetscFunctionBegin;
478     delete event_cast_(event);
479     event->data = nullptr;
480     PetscFunctionReturn(0);
481   };
482   PetscFunctionReturn(0);
483 }
484 
485 template <DeviceType T>
486 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event)) {
487   PetscFunctionBegin;
488   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
489   PetscFunctionReturn(0);
490 }
491 
492 template <DeviceType T>
493 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event)) {
494   PetscFunctionBegin;
495   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
496   PetscFunctionReturn(0);
497 }
498 
499 // initialize the static member variables
500 template <DeviceType T>
501 bool DeviceContext<T>::initialized_ = false;
502 
503 template <DeviceType T>
504 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
505 
506 template <DeviceType T>
507 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
508 
509 } // namespace impl
510 
511 // shorten this one up a bit (and instantiate the templates)
512 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
513 using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
514 
515 // shorthand for what is an EXTREMELY long name
516 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
517 
518 } // namespace cupm
519 
520 } // namespace device
521 
522 } // namespace Petsc
523 
524 #endif // __cplusplus
525 
#endif // PETSCDEVICECONTEXTCUPM_HPP
527