xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision c1730cc1cd3b6fc6205d70c430a9ca69e7268eec)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmblasinterface.hpp>
6 #include <petsc/private/logimpl.h>
7 
8 #include <petsc/private/cpp/array.hpp>
9 
10 #include "../segmentedmempool.hpp"
11 #include "cupmallocator.hpp"
12 #include "cupmstream.hpp"
13 #include "cupmevent.hpp"
14 
15 #if defined(__cplusplus)
16 
17 namespace Petsc
18 {
19 
20 namespace device
21 {
22 
23 namespace cupm
24 {
25 
26 namespace impl
27 {
28 
29 template <DeviceType T>
30 class DeviceContext : BlasInterface<T> {
31 public:
32   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T);
33 
34 private:
35   template <typename H, std::size_t>
36   struct HandleTag {
37     using type = H;
38   };
39 
40   using stream_tag = HandleTag<cupmStream_t, 0>;
41   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
42   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
43 
44   using stream_type = CUPMStream<T>;
45   using event_type  = CUPMEvent<T>;
46 
47 public:
48   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
49   // header, but since we are using the power of templates it must be declared part of
50   // this class to have easy access the same typedefs. Technically one can make a
51   // templated struct outside the class but it's more code for the same result.
52   struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
53     stream_type stream{};
54     cupmEvent_t event{};
55     cupmEvent_t begin{}; // timer-only
56     cupmEvent_t end{};   // timer-only
57   #if PetscDefined(USE_DEBUG)
58     PetscBool timerInUse{};
59   #endif
60     cupmBlasHandle_t   blas{};
61     cupmSolverHandle_t solver{};
62 
63     constexpr PetscDeviceContext_IMPLS() noexcept = default;
64 
65     PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); }
66 
67     PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; }
68 
69     PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; }
70   };
71 
72 private:
73   static bool initialized_;
74 
75   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
76   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
77 
78   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
79 
80   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
81 
82   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
83 
84   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
85 
86   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
87   // handles
88   PETSC_NODISCARD static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return 0; }
89 
90   PETSC_NODISCARD static PetscErrorCode create_handle_(blas_tag, cupmBlasHandle_t &handle) noexcept
91   {
92     PetscLogEvent event;
93 
94     PetscFunctionBegin;
95     if (PetscLikely(handle)) PetscFunctionReturn(0);
96     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
97     PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
98     for (auto i = 0; i < 3; ++i) {
99       auto cberr = cupmBlasCreate(&handle);
100       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
101       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
102       if (i != 2) {
103         PetscCall(PetscSleep(3));
104         continue;
105       }
106       PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
107     }
108     PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
109     PetscCall(PetscLogEventResume_Internal(event));
110     PetscFunctionReturn(0);
111   }
112 
113   PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag tag, PetscDeviceContext dctx) noexcept
114   {
115     const auto dci    = impls_cast_(dctx);
116     auto      &handle = blashandles_[dctx->device->deviceId];
117 
118     PetscFunctionBegin;
119     PetscCall(create_handle_(tag, handle));
120     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
121     dci->blas = handle;
122     PetscFunctionReturn(0);
123   }
124 
125   PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
126   {
127     const auto    dci    = impls_cast_(dctx);
128     auto         &handle = solverhandles_[dctx->device->deviceId];
129     PetscLogEvent event;
130 
131     PetscFunctionBegin;
132     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
133     PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
134     PetscCall(cupmBlasInterface_t::InitializeHandle(handle));
135     PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
136     PetscCall(PetscLogEventResume_Internal(event));
137     PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream()));
138     dci->solver = handle;
139     PetscFunctionReturn(0);
140   }
141 
142   PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
143   {
144     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
145 
146     PetscFunctionBegin;
147     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
148                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
149     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
150     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
151     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
152     PetscFunctionReturn(0);
153   }
154 
155   PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
156 
157   PETSC_NODISCARD static PetscErrorCode finalize_() noexcept
158   {
159     PetscFunctionBegin;
160     for (auto &&handle : blashandles_) {
161       if (handle) {
162         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
163         handle = nullptr;
164       }
165     }
166 
167     for (auto &&handle : solverhandles_) {
168       if (handle) {
169         PetscCall(cupmBlasInterface_t::DestroyHandle(handle));
170         handle = nullptr;
171       }
172     }
173     initialized_ = false;
174     PetscFunctionReturn(0);
175   }
176 
177   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
178   PETSC_NODISCARD static PoolType &default_pool_() noexcept
179   {
180     static PoolType pool;
181     return pool;
182   }
183 
184   PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
185   {
186     PetscFunctionBegin;
187     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
188     PetscFunctionReturn(0);
189   }
190 
191 public:
192   // All of these functions MUST be static in order to be callable from C, otherwise they
193   // get the implicit 'this' pointer tacked on
194   PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept;
195   PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
196   PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept;
197   PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
198   PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
199   PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
200   template <typename Handle_t>
201   PETSC_NODISCARD static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
202   PETSC_NODISCARD static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
203   PETSC_NODISCARD static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
204   PETSC_NODISCARD static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
205   PETSC_NODISCARD static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
206   PETSC_NODISCARD static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
207   PETSC_NODISCARD static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
208   PETSC_NODISCARD static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
209   PETSC_NODISCARD static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
210   PETSC_NODISCARD static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
211 
212   // not a PetscDeviceContext method, this registers the class
213   PETSC_NODISCARD static PetscErrorCode initialize(PetscDevice) noexcept;
214 
215   // clang-format off
216   const _DeviceContextOps ops = {
217     destroy,
218     changeStreamType,
219     setUp,
220     query,
221     waitForContext,
222     synchronize,
223     getHandle<blas_tag>,
224     getHandle<solver_tag>,
225     getHandle<stream_tag>,
226     beginTimer,
227     endTimer,
228     memAlloc,
229     memFree,
230     memCopy,
231     memSet,
232     createEvent,
233     recordEvent,
234     waitForEvent
235   };
236   // clang-format on
237 };
238 
239 // not a PetscDeviceContext method, this initializes the CLASS
240 template <DeviceType T>
241 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
242 {
243   PetscFunctionBegin;
244   if (PetscUnlikely(!initialized_)) {
245     uint64_t      threshold = UINT64_MAX;
246     cupmMemPool_t mempool;
247 
248     initialized_ = true;
249     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
250     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
251     blashandles_.fill(nullptr);
252     solverhandles_.fill(nullptr);
253     PetscCall(PetscRegisterFinalize(finalize_));
254   }
255   PetscFunctionReturn(0);
256 }
257 
258 template <DeviceType T>
259 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
260 {
261   PetscFunctionBegin;
262   if (const auto dci = impls_cast_(dctx)) {
263     PetscCall(dci->stream.destroy());
264     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event)));
265     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
266     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
267     delete dci;
268     dctx->data = nullptr;
269   }
270   PetscFunctionReturn(0);
271 }
272 
273 template <DeviceType T>
274 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
275 {
276   const auto dci = impls_cast_(dctx);
277 
278   PetscFunctionBegin;
279   PetscCall(dci->stream.destroy());
280   // set these to null so they aren't usable until setup is called again
281   dci->blas   = nullptr;
282   dci->solver = nullptr;
283   PetscFunctionReturn(0);
284 }
285 
286 template <DeviceType T>
287 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
288 {
289   const auto dci   = impls_cast_(dctx);
290   auto      &event = dci->event;
291 
292   PetscFunctionBegin;
293   PetscCall(check_current_device_(dctx));
294   PetscCall(dci->stream.change_type(dctx->streamType));
295   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
296   #if PetscDefined(USE_DEBUG)
297   dci->timerInUse = PETSC_FALSE;
298   #endif
299   PetscFunctionReturn(0);
300 }
301 
302 template <DeviceType T>
303 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
304 {
305   PetscFunctionBegin;
306   PetscCall(check_current_device_(dctx));
307   switch (const auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
308   case cupmSuccess:
309     *idle = PETSC_TRUE;
310     break;
311   case cupmErrorNotReady:
312     *idle = PETSC_FALSE;
313     break;
314   default:
315     PetscCallCUPM(cerr);
316     PetscUnreachable();
317   }
318   PetscFunctionReturn(0);
319 }
320 
321 template <DeviceType T>
322 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
323 {
324   const auto dcib  = impls_cast_(dctxb);
325   const auto event = dcib->event;
326 
327   PetscFunctionBegin;
328   PetscCall(check_current_device_(dctxa, dctxb));
329   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
330   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
331   PetscFunctionReturn(0);
332 }
333 
334 template <DeviceType T>
335 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
336 {
337   auto idle = PETSC_TRUE;
338 
339   PetscFunctionBegin;
340   PetscCall(query(dctx, &idle));
341   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
342   PetscFunctionReturn(0);
343 }
344 
345 template <DeviceType T>
346 template <typename handle_t>
347 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
348 {
349   PetscFunctionBegin;
350   PetscCall(initialize_handle_(handle_t{}, dctx));
351   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
352   PetscFunctionReturn(0);
353 }
354 
355 template <DeviceType T>
356 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
357 {
358   const auto dci = impls_cast_(dctx);
359 
360   PetscFunctionBegin;
361   PetscCall(check_current_device_(dctx));
362   #if PetscDefined(USE_DEBUG)
363   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
364   dci->timerInUse = PETSC_TRUE;
365   #endif
366   if (!dci->begin) {
367     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
368     PetscCallCUPM(cupmEventCreate(&dci->begin));
369     PetscCallCUPM(cupmEventCreate(&dci->end));
370   }
371   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
372   PetscFunctionReturn(0);
373 }
374 
375 template <DeviceType T>
376 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
377 {
378   float      gtime;
379   const auto dci = impls_cast_(dctx);
380   const auto end = dci->end;
381 
382   PetscFunctionBegin;
383   PetscCall(check_current_device_(dctx));
384   #if PetscDefined(USE_DEBUG)
385   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
386   dci->timerInUse = PETSC_FALSE;
387   #endif
388   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
389   PetscCallCUPM(cupmEventSynchronize(end));
390   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
391   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
392   PetscFunctionReturn(0);
393 }
394 
395 template <DeviceType T>
396 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
397 {
398   const auto &stream = impls_cast_(dctx)->stream;
399 
400   PetscFunctionBegin;
401   PetscCall(check_current_device_(dctx));
402   PetscCall(check_memtype_(mtype, "allocating"));
403   if (PetscMemTypeHost(mtype)) {
404     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
405   } else {
406     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
407   }
408   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
409   PetscFunctionReturn(0);
410 }
411 
412 template <DeviceType T>
413 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
414 {
415   const auto &stream = impls_cast_(dctx)->stream;
416 
417   PetscFunctionBegin;
418   PetscCall(check_current_device_(dctx));
419   PetscCall(check_memtype_(mtype, "freeing"));
420   if (!*ptr) PetscFunctionReturn(0);
421   if (PetscMemTypeHost(mtype)) {
422     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
423     // if ptr exists still exists the pool didn't own it
424     if (*ptr) {
425       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
426 
427       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
428       if (registered) {
429         PetscCallCUPM(cupmFreeHost(*ptr));
430       } else if (managed) {
431         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
432       }
433     }
434   } else {
435     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
436     // if ptr still exists the pool didn't own it
437     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
438   }
439   PetscFunctionReturn(0);
440 }
441 
442 template <DeviceType T>
443 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
444 {
445   const auto stream = impls_cast_(dctx)->stream.get_stream();
446 
447   PetscFunctionBegin;
448   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
449   if (mode == PETSC_DEVICE_COPY_HTOH) {
450     const auto cerr = cupmStreamQuery(stream);
451 
452     // yes this is faster
453     if (cerr == cupmSuccess) {
454       PetscCall(PetscMemcpy(dest, src, n));
455       PetscFunctionReturn(0);
456     } else if (cerr == cupmErrorNotReady) {
457       auto PETSC_UNUSED unused = cupmGetLastError();
458 
459       static_cast<void>(unused);
460     } else {
461       PetscCallCUPM(cerr);
462     }
463   }
464   PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
465   PetscFunctionReturn(0);
466 }
467 
468 template <DeviceType T>
469 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
470 {
471   PetscFunctionBegin;
472   PetscCall(check_current_device_(dctx));
473   PetscCall(check_memtype_(mtype, "zeroing"));
474   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
475   PetscFunctionReturn(0);
476 }
477 
478 template <DeviceType T>
479 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
480 {
481   PetscFunctionBegin;
482   PetscCallCXX(event->data = new event_type());
483   event->destroy = [](PetscEvent event) {
484     PetscFunctionBegin;
485     delete event_cast_(event);
486     event->data = nullptr;
487     PetscFunctionReturn(0);
488   };
489   PetscFunctionReturn(0);
490 }
491 
492 template <DeviceType T>
493 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
494 {
495   PetscFunctionBegin;
496   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
497   PetscFunctionReturn(0);
498 }
499 
500 template <DeviceType T>
501 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
502 {
503   PetscFunctionBegin;
504   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
505   PetscFunctionReturn(0);
506 }
507 
508 // initialize the static member variables
509 template <DeviceType T>
510 bool DeviceContext<T>::initialized_ = false;
511 
512 template <DeviceType T>
513 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
514 
515 template <DeviceType T>
516 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
517 
518 } // namespace impl
519 
520 // shorten this one up a bit (and instantiate the templates)
521 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
522 using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
523 
524   // shorthand for what is an EXTREMELY long name
525   #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
526 
527 } // namespace cupm
528 
529 } // namespace device
530 
531 } // namespace Petsc
532 
533 #endif // __cplusplus
534 
#endif // PETSCDEVICECONTEXTCUPM_HPP
536