xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 1cc06b555e92f8ec64db10330b8bbd830e5bc876) !
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmsolverinterface.hpp>
6 #include <petsc/private/logimpl.h>
7 
8 #include <petsc/private/cpp/array.hpp>
9 
10 #include "../segmentedmempool.hpp"
11 #include "cupmallocator.hpp"
12 #include "cupmstream.hpp"
13 #include "cupmevent.hpp"
14 
15 namespace Petsc
16 {
17 
18 namespace device
19 {
20 
21 namespace cupm
22 {
23 
24 namespace impl
25 {
26 
27 template <DeviceType T>
28 class DeviceContext : SolverInterface<T> {
29 public:
30   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
31 
32 private:
33   template <typename H, std::size_t>
34   struct HandleTag {
35     using type = H;
36   };
37 
38   using stream_tag = HandleTag<cupmStream_t, 0>;
39   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
40   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
41 
42   using stream_type = CUPMStream<T>;
43   using event_type  = CUPMEvent<T>;
44 
45 public:
46   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
47   // header, but since we are using the power of templates it must be declared part of
48   // this class to have easy access the same typedefs. Technically one can make a
49   // templated struct outside the class but it's more code for the same result.
50   struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
51     stream_type stream{};
52     cupmEvent_t event{};
53     cupmEvent_t begin{}; // timer-only
54     cupmEvent_t end{};   // timer-only
55 #if PetscDefined(USE_DEBUG)
56     PetscBool timerInUse{};
57 #endif
58     cupmBlasHandle_t   blas{};
59     cupmSolverHandle_t solver{};
60 
61     constexpr PetscDeviceContext_IMPLS() noexcept = default;
62 
63     PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); }
64 
65     PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; }
66 
67     PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; }
68   };
69 
70 private:
71   static bool initialized_;
72 
73   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
74   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
75 
76   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
77 
78   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
79 
80   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
81 
82   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
83 
84   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
85   // handles
86   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
87 
88   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
89   {
90     const auto dci    = impls_cast_(dctx);
91     auto      &handle = blashandles_[dctx->device->deviceId];
92 
93     PetscFunctionBegin;
94     if (!handle) {
95       PetscLogEvent event;
96 
97       PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
98       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
99       for (auto i = 0; i < 3; ++i) {
100         const auto cberr = cupmBlasCreate(handle.ptr_to());
101         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
102         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
103         if (i != 2) {
104           PetscCall(PetscSleep(3));
105           continue;
106         }
107         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
108       }
109       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
110       PetscCall(PetscLogEventResume_Internal(event));
111     }
112     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
113     dci->blas = handle;
114     PetscFunctionReturn(PETSC_SUCCESS);
115   }
116 
117   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
118   {
119     const auto dci    = impls_cast_(dctx);
120     auto      &handle = solverhandles_[dctx->device->deviceId];
121 
122     PetscFunctionBegin;
123     if (!handle) {
124       PetscLogEvent event;
125 
126       PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
127       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
128       for (auto i = 0; i < 3; ++i) {
129         const auto cerr = cupmSolverCreate(&handle);
130         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
131         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
132         if (i < 2) {
133           PetscCall(PetscSleep(3));
134           continue;
135         }
136         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
137       }
138       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
139       PetscCall(PetscLogEventResume_Internal(event));
140     }
141     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
142     dci->solver = handle;
143     PetscFunctionReturn(PETSC_SUCCESS);
144   }
145 
146   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
147   {
148     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
149 
150     PetscFunctionBegin;
151     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
152                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
153     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
154     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
155     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
156     PetscFunctionReturn(PETSC_SUCCESS);
157   }
158 
159   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
160 
161   static PetscErrorCode finalize_() noexcept
162   {
163     PetscFunctionBegin;
164     for (auto &&handle : blashandles_) {
165       if (handle) {
166         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
167         handle = nullptr;
168       }
169     }
170 
171     for (auto &&handle : solverhandles_) {
172       if (handle) {
173         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
174         handle = nullptr;
175       }
176     }
177     initialized_ = false;
178     PetscFunctionReturn(PETSC_SUCCESS);
179   }
180 
181   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
182   PETSC_NODISCARD static PoolType &default_pool_() noexcept
183   {
184     static PoolType pool;
185     return pool;
186   }
187 
188   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
189   {
190     PetscFunctionBegin;
191     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
192     PetscFunctionReturn(PETSC_SUCCESS);
193   }
194 
195 public:
196   // All of these functions MUST be static in order to be callable from C, otherwise they
197   // get the implicit 'this' pointer tacked on
198   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
199   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
200   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
201   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
202   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
203   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
204   template <typename Handle_t>
205   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
206   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
207   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
208   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
209   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
210   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
211   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
212   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
213   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
214   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
215 
216   // not a PetscDeviceContext method, this registers the class
217   static PetscErrorCode initialize(PetscDevice) noexcept;
218 
219   // clang-format off
220   static constexpr _DeviceContextOps ops = {
221     PetscDesignatedInitializer(destroy, destroy),
222     PetscDesignatedInitializer(changestreamtype, changeStreamType),
223     PetscDesignatedInitializer(setup, setUp),
224     PetscDesignatedInitializer(query, query),
225     PetscDesignatedInitializer(waitforcontext, waitForContext),
226     PetscDesignatedInitializer(synchronize, synchronize),
227     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
228     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
229     PetscDesignatedInitializer(getstreamhandle, getHandle<stream_tag>),
230     PetscDesignatedInitializer(begintimer, beginTimer),
231     PetscDesignatedInitializer(endtimer, endTimer),
232     PetscDesignatedInitializer(memalloc, memAlloc),
233     PetscDesignatedInitializer(memfree, memFree),
234     PetscDesignatedInitializer(memcopy, memCopy),
235     PetscDesignatedInitializer(memset, memSet),
236     PetscDesignatedInitializer(createevent, createEvent),
237     PetscDesignatedInitializer(recordevent, recordEvent),
238     PetscDesignatedInitializer(waitforevent, waitForEvent)
239   };
240   // clang-format on
241 };
242 
243 // not a PetscDeviceContext method, this initializes the CLASS
244 template <DeviceType T>
245 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
246 {
247   PetscFunctionBegin;
248   if (PetscUnlikely(!initialized_)) {
249     uint64_t      threshold = UINT64_MAX;
250     cupmMemPool_t mempool;
251 
252     initialized_ = true;
253     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
254     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
255     blashandles_.fill(nullptr);
256     solverhandles_.fill(nullptr);
257     PetscCall(PetscRegisterFinalize(finalize_));
258   }
259   PetscFunctionReturn(PETSC_SUCCESS);
260 }
261 
262 template <DeviceType T>
263 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
264 {
265   PetscFunctionBegin;
266   if (const auto dci = impls_cast_(dctx)) {
267     PetscCall(dci->stream.destroy());
268     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
269     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
270     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
271     delete dci;
272     dctx->data = nullptr;
273   }
274   PetscFunctionReturn(PETSC_SUCCESS);
275 }
276 
277 template <DeviceType T>
278 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
279 {
280   const auto dci = impls_cast_(dctx);
281 
282   PetscFunctionBegin;
283   PetscCall(dci->stream.destroy());
284   // set these to null so they aren't usable until setup is called again
285   dci->blas   = nullptr;
286   dci->solver = nullptr;
287   PetscFunctionReturn(PETSC_SUCCESS);
288 }
289 
290 template <DeviceType T>
291 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
292 {
293   const auto dci   = impls_cast_(dctx);
294   auto      &event = dci->event;
295 
296   PetscFunctionBegin;
297   PetscCall(check_current_device_(dctx));
298   PetscCall(dci->stream.change_type(dctx->streamType));
299   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
300 #if PetscDefined(USE_DEBUG)
301   dci->timerInUse = PETSC_FALSE;
302 #endif
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
306 template <DeviceType T>
307 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
308 {
309   PetscFunctionBegin;
310   PetscCall(check_current_device_(dctx));
311   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
312   case cupmSuccess:
313     *idle = PETSC_TRUE;
314     break;
315   case cupmErrorNotReady:
316     *idle = PETSC_FALSE;
317     // reset the error
318     cerr = cupmGetLastError();
319     static_cast<void>(cerr);
320     break;
321   default:
322     PetscCallCUPM(cerr);
323     PetscUnreachable();
324   }
325   PetscFunctionReturn(PETSC_SUCCESS);
326 }
327 
328 template <DeviceType T>
329 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
330 {
331   const auto dcib  = impls_cast_(dctxb);
332   const auto event = dcib->event;
333 
334   PetscFunctionBegin;
335   PetscCall(check_current_device_(dctxa, dctxb));
336   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
337   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
338   PetscFunctionReturn(PETSC_SUCCESS);
339 }
340 
341 template <DeviceType T>
342 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
343 {
344   auto idle = PETSC_TRUE;
345 
346   PetscFunctionBegin;
347   PetscCall(query(dctx, &idle));
348   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
349   PetscFunctionReturn(PETSC_SUCCESS);
350 }
351 
352 template <DeviceType T>
353 template <typename handle_t>
354 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
355 {
356   PetscFunctionBegin;
357   PetscCall(initialize_handle_(handle_t{}, dctx));
358   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
359   PetscFunctionReturn(PETSC_SUCCESS);
360 }
361 
362 template <DeviceType T>
363 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
364 {
365   const auto dci = impls_cast_(dctx);
366 
367   PetscFunctionBegin;
368   PetscCall(check_current_device_(dctx));
369 #if PetscDefined(USE_DEBUG)
370   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
371   dci->timerInUse = PETSC_TRUE;
372 #endif
373   if (!dci->begin) {
374     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
375     PetscCallCUPM(cupmEventCreate(&dci->begin));
376     PetscCallCUPM(cupmEventCreate(&dci->end));
377   }
378   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
379   PetscFunctionReturn(PETSC_SUCCESS);
380 }
381 
382 template <DeviceType T>
383 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
384 {
385   float      gtime;
386   const auto dci = impls_cast_(dctx);
387   const auto end = dci->end;
388 
389   PetscFunctionBegin;
390   PetscCall(check_current_device_(dctx));
391 #if PetscDefined(USE_DEBUG)
392   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
393   dci->timerInUse = PETSC_FALSE;
394 #endif
395   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
396   PetscCallCUPM(cupmEventSynchronize(end));
397   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
398   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
399   PetscFunctionReturn(PETSC_SUCCESS);
400 }
401 
402 template <DeviceType T>
403 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
404 {
405   const auto &stream = impls_cast_(dctx)->stream;
406 
407   PetscFunctionBegin;
408   PetscCall(check_current_device_(dctx));
409   PetscCall(check_memtype_(mtype, "allocating"));
410   if (PetscMemTypeHost(mtype)) {
411     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
412   } else {
413     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
414   }
415   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
416   PetscFunctionReturn(PETSC_SUCCESS);
417 }
418 
419 template <DeviceType T>
420 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
421 {
422   const auto &stream = impls_cast_(dctx)->stream;
423 
424   PetscFunctionBegin;
425   PetscCall(check_current_device_(dctx));
426   PetscCall(check_memtype_(mtype, "freeing"));
427   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
428   if (PetscMemTypeHost(mtype)) {
429     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
430     // if ptr exists still exists the pool didn't own it
431     if (*ptr) {
432       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
433 
434       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
435       if (registered) {
436         PetscCallCUPM(cupmFreeHost(*ptr));
437       } else if (managed) {
438         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
439       }
440     }
441   } else {
442     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
443     // if ptr still exists the pool didn't own it
444     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
445   }
446   PetscFunctionReturn(PETSC_SUCCESS);
447 }
448 
449 template <DeviceType T>
450 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
451 {
452   const auto stream = impls_cast_(dctx)->stream.get_stream();
453 
454   PetscFunctionBegin;
455   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
456   if (mode == PETSC_DEVICE_COPY_HTOH) {
457     const auto cerr = cupmStreamQuery(stream);
458 
459     // yes this is faster
460     if (cerr == cupmSuccess) {
461       PetscCall(PetscMemcpy(dest, src, n));
462       PetscFunctionReturn(PETSC_SUCCESS);
463     } else if (cerr == cupmErrorNotReady) {
464       auto PETSC_UNUSED unused = cupmGetLastError();
465 
466       static_cast<void>(unused);
467     } else {
468       PetscCallCUPM(cerr);
469     }
470   }
471   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
472   PetscFunctionReturn(PETSC_SUCCESS);
473 }
474 
475 template <DeviceType T>
476 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
477 {
478   PetscFunctionBegin;
479   PetscCall(check_current_device_(dctx));
480   PetscCall(check_memtype_(mtype, "zeroing"));
481   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
482   PetscFunctionReturn(PETSC_SUCCESS);
483 }
484 
485 template <DeviceType T>
486 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
487 {
488   PetscFunctionBegin;
489   PetscCallCXX(event->data = new event_type());
490   event->destroy = [](PetscEvent event) {
491     PetscFunctionBegin;
492     delete event_cast_(event);
493     event->data = nullptr;
494     PetscFunctionReturn(PETSC_SUCCESS);
495   };
496   PetscFunctionReturn(PETSC_SUCCESS);
497 }
498 
499 template <DeviceType T>
500 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
501 {
502   PetscFunctionBegin;
503   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
504   PetscFunctionReturn(PETSC_SUCCESS);
505 }
506 
507 template <DeviceType T>
508 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
509 {
510   PetscFunctionBegin;
511   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
512   PetscFunctionReturn(PETSC_SUCCESS);
513 }
514 
515 // initialize the static member variables
516 template <DeviceType T>
517 bool DeviceContext<T>::initialized_ = false;
518 
519 template <DeviceType T>
520 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
521 
522 template <DeviceType T>
523 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
524 
525 template <DeviceType T>
526 constexpr _DeviceContextOps DeviceContext<T>::ops;
527 
528 } // namespace impl
529 
530 // short aliases for the CUDA and HIP specializations of impl::DeviceContext
531 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
532 using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
533 
534 // shorthand for what is an EXTREMELY long name: expands to the fully qualified name of the
// PetscDeviceContext_IMPLS struct nested in the DeviceContext specialization for the
// DeviceType enumerator IMPLS
535 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
536 
537 } // namespace cupm
538 
539 } // namespace device
540 
541 } // namespace Petsc
542 
543 #endif // PETSCDEVICECONTEXTCUPM_HPP
544