xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 2d1571503b53ddc44d5587ea811639ee03673e35)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmsolverinterface.hpp>
6 #include <petsc/private/logimpl.h>
7 
8 #include <petsc/private/cpp/array.hpp>
9 
10 #include "../segmentedmempool.hpp"
11 #include "cupmallocator.hpp"
12 #include "cupmstream.hpp"
13 #include "cupmevent.hpp"
14 
15 #if defined(__cplusplus)
16 
17 namespace Petsc
18 {
19 
20 namespace device
21 {
22 
23 namespace cupm
24 {
25 
26 namespace impl
27 {
28 
29 template <DeviceType T>
30 class DeviceContext : SolverInterface<T> {
31 public:
32   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
33 
34 private:
35   template <typename H, std::size_t>
36   struct HandleTag {
37     using type = H;
38   };
39 
40   using stream_tag = HandleTag<cupmStream_t, 0>;
41   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
42   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
43 
44   using stream_type = CUPMStream<T>;
45   using event_type  = CUPMEvent<T>;
46 
47 public:
48   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
49   // header, but since we are using the power of templates it must be declared part of
50   // this class to have easy access the same typedefs. Technically one can make a
51   // templated struct outside the class but it's more code for the same result.
52   struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
53     stream_type stream{};
54     cupmEvent_t event{};
55     cupmEvent_t begin{}; // timer-only
56     cupmEvent_t end{};   // timer-only
57   #if PetscDefined(USE_DEBUG)
58     PetscBool timerInUse{};
59   #endif
60     cupmBlasHandle_t   blas{};
61     cupmSolverHandle_t solver{};
62 
63     constexpr PetscDeviceContext_IMPLS() noexcept = default;
64 
65     PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
66 
67     PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
68 
69     PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
70   };
71 
72 private:
73   static bool initialized_;
74 
75   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
76   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
77 
78   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
79 
80   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
81 
82   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
83 
84   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
85 
86   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
87   // handles
88   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
89 
90   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
91   {
92     const auto dci    = impls_cast_(dctx);
93     auto      &handle = blashandles_[dctx->device->deviceId];
94 
95     PetscFunctionBegin;
96     if (!handle) {
97       PetscLogEvent event;
98 
99       PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
100       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
101       for (auto i = 0; i < 3; ++i) {
102         const auto cberr = cupmBlasCreate(handle.ptr_to());
103         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
104         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
105         if (i != 2) {
106           PetscCall(PetscSleep(3));
107           continue;
108         }
109         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
110       }
111       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
112       PetscCall(PetscLogEventResume_Internal(event));
113     }
114     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
115     dci->blas = handle;
116     PetscFunctionReturn(PETSC_SUCCESS);
117   }
118 
119   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
120   {
121     const auto dci    = impls_cast_(dctx);
122     auto      &handle = solverhandles_[dctx->device->deviceId];
123 
124     PetscFunctionBegin;
125     if (!handle) {
126       PetscLogEvent event;
127 
128       PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
129       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
130       for (auto i = 0; i < 3; ++i) {
131         const auto cerr = cupmSolverCreate(&handle);
132         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
133         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
134         if (i < 2) {
135           PetscCall(PetscSleep(3));
136           continue;
137         }
138         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
139       }
140       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
141       PetscCall(PetscLogEventResume_Internal(event));
142     }
143     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
144     dci->solver = handle;
145     PetscFunctionReturn(PETSC_SUCCESS);
146   }
147 
148   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
149   {
150     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
151 
152     PetscFunctionBegin;
153     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
154                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
155     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
156     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
157     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
158     PetscFunctionReturn(PETSC_SUCCESS);
159   }
160 
161   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
162 
163   static PetscErrorCode finalize_() noexcept
164   {
165     PetscFunctionBegin;
166     for (auto &&handle : blashandles_) {
167       if (handle) {
168         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
169         handle = nullptr;
170       }
171     }
172 
173     for (auto &&handle : solverhandles_) {
174       if (handle) {
175         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
176         handle = nullptr;
177       }
178     }
179     initialized_ = false;
180     PetscFunctionReturn(PETSC_SUCCESS);
181   }
182 
183   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
184   PETSC_NODISCARD static PoolType &default_pool_() noexcept
185   {
186     static PoolType pool;
187     return pool;
188   }
189 
190   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
191   {
192     PetscFunctionBegin;
193     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
194     PetscFunctionReturn(PETSC_SUCCESS);
195   }
196 
197 public:
198   // All of these functions MUST be static in order to be callable from C, otherwise they
199   // get the implicit 'this' pointer tacked on
200   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
201   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
202   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
203   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
204   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
205   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
206   template <typename Handle_t>
207   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
208   template <typename Handle_t>
209   static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
210   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
211   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
212   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
213   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
214   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
215   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
216   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
217   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
218   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
219 
220   // not a PetscDeviceContext method, this registers the class
221   static PetscErrorCode initialize(PetscDevice) noexcept;
222 
223   // clang-format off
224   static constexpr _DeviceContextOps ops = {
225     PetscDesignatedInitializer(destroy, destroy),
226     PetscDesignatedInitializer(changestreamtype, changeStreamType),
227     PetscDesignatedInitializer(setup, setUp),
228     PetscDesignatedInitializer(query, query),
229     PetscDesignatedInitializer(waitforcontext, waitForContext),
230     PetscDesignatedInitializer(synchronize, synchronize),
231     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
232     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
233     PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
234     PetscDesignatedInitializer(begintimer, beginTimer),
235     PetscDesignatedInitializer(endtimer, endTimer),
236     PetscDesignatedInitializer(memalloc, memAlloc),
237     PetscDesignatedInitializer(memfree, memFree),
238     PetscDesignatedInitializer(memcopy, memCopy),
239     PetscDesignatedInitializer(memset, memSet),
240     PetscDesignatedInitializer(createevent, createEvent),
241     PetscDesignatedInitializer(recordevent, recordEvent),
242     PetscDesignatedInitializer(waitforevent, waitForEvent)
243   };
244   // clang-format on
245 };
246 
247 // not a PetscDeviceContext method, this initializes the CLASS
248 template <DeviceType T>
249 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
250 {
251   PetscFunctionBegin;
252   if (PetscUnlikely(!initialized_)) {
253     uint64_t      threshold = UINT64_MAX;
254     cupmMemPool_t mempool;
255 
256     initialized_ = true;
257     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
258     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
259     blashandles_.fill(nullptr);
260     solverhandles_.fill(nullptr);
261     PetscCall(PetscRegisterFinalize(finalize_));
262   }
263   PetscFunctionReturn(PETSC_SUCCESS);
264 }
265 
266 template <DeviceType T>
267 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
268 {
269   PetscFunctionBegin;
270   if (const auto dci = impls_cast_(dctx)) {
271     PetscCall(dci->stream.destroy());
272     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
273     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
274     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
275     delete dci;
276     dctx->data = nullptr;
277   }
278   PetscFunctionReturn(PETSC_SUCCESS);
279 }
280 
281 template <DeviceType T>
282 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
283 {
284   const auto dci = impls_cast_(dctx);
285 
286   PetscFunctionBegin;
287   PetscCall(dci->stream.destroy());
288   // set these to null so they aren't usable until setup is called again
289   dci->blas   = nullptr;
290   dci->solver = nullptr;
291   PetscFunctionReturn(PETSC_SUCCESS);
292 }
293 
294 template <DeviceType T>
295 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
296 {
297   const auto dci   = impls_cast_(dctx);
298   auto      &event = dci->event;
299 
300   PetscFunctionBegin;
301   PetscCall(check_current_device_(dctx));
302   PetscCall(dci->stream.change_type(dctx->streamType));
303   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
304   #if PetscDefined(USE_DEBUG)
305   dci->timerInUse = PETSC_FALSE;
306   #endif
307   PetscFunctionReturn(PETSC_SUCCESS);
308 }
309 
310 template <DeviceType T>
311 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
312 {
313   PetscFunctionBegin;
314   PetscCall(check_current_device_(dctx));
315   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
316   case cupmSuccess:
317     *idle = PETSC_TRUE;
318     break;
319   case cupmErrorNotReady:
320     *idle = PETSC_FALSE;
321     // reset the error
322     cerr = cupmGetLastError();
323     static_cast<void>(cerr);
324     break;
325   default:
326     PetscCallCUPM(cerr);
327     PetscUnreachable();
328   }
329   PetscFunctionReturn(PETSC_SUCCESS);
330 }
331 
332 template <DeviceType T>
333 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
334 {
335   const auto dcib  = impls_cast_(dctxb);
336   const auto event = dcib->event;
337 
338   PetscFunctionBegin;
339   PetscCall(check_current_device_(dctxa, dctxb));
340   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
341   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
342   PetscFunctionReturn(PETSC_SUCCESS);
343 }
344 
345 template <DeviceType T>
346 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
347 {
348   auto idle = PETSC_TRUE;
349 
350   PetscFunctionBegin;
351   PetscCall(query(dctx, &idle));
352   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
353   PetscFunctionReturn(PETSC_SUCCESS);
354 }
355 
356 template <DeviceType T>
357 template <typename handle_t>
358 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
359 {
360   PetscFunctionBegin;
361   PetscCall(initialize_handle_(handle_t{}, dctx));
362   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
363   PetscFunctionReturn(PETSC_SUCCESS);
364 }
365 
366 template <DeviceType T>
367 template <typename handle_t>
368 inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
369 {
370   using handle_type = typename handle_t::type;
371 
372   PetscFunctionBegin;
373   PetscCall(initialize_handle_(handle_t{}, dctx));
374   *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
375   PetscFunctionReturn(PETSC_SUCCESS);
376 }
377 
378 template <DeviceType T>
379 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
380 {
381   const auto dci = impls_cast_(dctx);
382 
383   PetscFunctionBegin;
384   PetscCall(check_current_device_(dctx));
385   #if PetscDefined(USE_DEBUG)
386   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
387   dci->timerInUse = PETSC_TRUE;
388   #endif
389   if (!dci->begin) {
390     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
391     PetscCallCUPM(cupmEventCreate(&dci->begin));
392     PetscCallCUPM(cupmEventCreate(&dci->end));
393   }
394   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
395   PetscFunctionReturn(PETSC_SUCCESS);
396 }
397 
398 template <DeviceType T>
399 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
400 {
401   float      gtime;
402   const auto dci = impls_cast_(dctx);
403   const auto end = dci->end;
404 
405   PetscFunctionBegin;
406   PetscCall(check_current_device_(dctx));
407   #if PetscDefined(USE_DEBUG)
408   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
409   dci->timerInUse = PETSC_FALSE;
410   #endif
411   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
412   PetscCallCUPM(cupmEventSynchronize(end));
413   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
414   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
415   PetscFunctionReturn(PETSC_SUCCESS);
416 }
417 
418 template <DeviceType T>
419 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
420 {
421   const auto &stream = impls_cast_(dctx)->stream;
422 
423   PetscFunctionBegin;
424   PetscCall(check_current_device_(dctx));
425   PetscCall(check_memtype_(mtype, "allocating"));
426   if (PetscMemTypeHost(mtype)) {
427     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
428   } else {
429     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
430   }
431   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
432   PetscFunctionReturn(PETSC_SUCCESS);
433 }
434 
435 template <DeviceType T>
436 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
437 {
438   const auto &stream = impls_cast_(dctx)->stream;
439 
440   PetscFunctionBegin;
441   PetscCall(check_current_device_(dctx));
442   PetscCall(check_memtype_(mtype, "freeing"));
443   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
444   if (PetscMemTypeHost(mtype)) {
445     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
446     // if ptr exists still exists the pool didn't own it
447     if (*ptr) {
448       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
449 
450       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
451       if (registered) {
452         PetscCallCUPM(cupmFreeHost(*ptr));
453       } else if (managed) {
454         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
455       }
456     }
457   } else {
458     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
459     // if ptr still exists the pool didn't own it
460     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
461   }
462   PetscFunctionReturn(PETSC_SUCCESS);
463 }
464 
465 template <DeviceType T>
466 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
467 {
468   const auto stream = impls_cast_(dctx)->stream.get_stream();
469 
470   PetscFunctionBegin;
471   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
472   if (mode == PETSC_DEVICE_COPY_HTOH) {
473     const auto cerr = cupmStreamQuery(stream);
474 
475     // yes this is faster
476     if (cerr == cupmSuccess) {
477       PetscCall(PetscMemcpy(dest, src, n));
478       PetscFunctionReturn(PETSC_SUCCESS);
479     } else if (cerr == cupmErrorNotReady) {
480       auto PETSC_UNUSED unused = cupmGetLastError();
481 
482       static_cast<void>(unused);
483     } else {
484       PetscCallCUPM(cerr);
485     }
486   }
487   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
488   PetscFunctionReturn(PETSC_SUCCESS);
489 }
490 
491 template <DeviceType T>
492 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
493 {
494   PetscFunctionBegin;
495   PetscCall(check_current_device_(dctx));
496   PetscCall(check_memtype_(mtype, "zeroing"));
497   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
498   PetscFunctionReturn(PETSC_SUCCESS);
499 }
500 
501 template <DeviceType T>
502 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
503 {
504   PetscFunctionBegin;
505   PetscCallCXX(event->data = new event_type());
506   event->destroy = [](PetscEvent event) {
507     PetscFunctionBegin;
508     delete event_cast_(event);
509     event->data = nullptr;
510     PetscFunctionReturn(PETSC_SUCCESS);
511   };
512   PetscFunctionReturn(PETSC_SUCCESS);
513 }
514 
515 template <DeviceType T>
516 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
517 {
518   PetscFunctionBegin;
519   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
520   PetscFunctionReturn(PETSC_SUCCESS);
521 }
522 
523 template <DeviceType T>
524 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
525 {
526   PetscFunctionBegin;
527   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
528   PetscFunctionReturn(PETSC_SUCCESS);
529 }
530 
531 // initialize the static member variables
532 template <DeviceType T>
533 bool DeviceContext<T>::initialized_ = false;
534 
535 template <DeviceType T>
536 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
537 
538 template <DeviceType T>
539 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
540 
541 template <DeviceType T>
542 constexpr _DeviceContextOps DeviceContext<T>::ops;
543 
544 } // namespace impl
545 
546 // shorten this one up a bit (and instantiate the templates)
547 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
548 using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
549 
550   // shorthand for what is an EXTREMELY long name
551   #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
552 
553 } // namespace cupm
554 
555 } // namespace device
556 
557 } // namespace Petsc
558 
559 #endif // __cplusplus
560 
561 #endif // PETSCDEVICECONTEXTCUDA_HPP
562