xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 1b37a2a7cc4a4fb30c3e967db1c694c0a1013f51)
1 #pragma once
2 
3 #include <petsc/private/deviceimpl.h>
4 #include <petsc/private/cupmsolverinterface.hpp>
5 #include <petsc/private/logimpl.h>
6 
7 #include <petsc/private/cpp/array.hpp>
8 
9 #include "../segmentedmempool.hpp"
10 #include "cupmallocator.hpp"
11 #include "cupmstream.hpp"
12 #include "cupmevent.hpp"
13 
14 namespace Petsc
15 {
16 
17 namespace device
18 {
19 
20 namespace cupm
21 {
22 
23 namespace impl
24 {
25 
26 template <DeviceType T>
27 class DeviceContext : SolverInterface<T> {
28 public:
29   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
30 
31 private:
32   template <typename H, std::size_t>
33   struct HandleTag {
34     using type = H;
35   };
36 
37   using stream_tag = HandleTag<cupmStream_t, 0>;
38   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
39   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
40 
41   using stream_type = CUPMStream<T>;
42   using event_type  = CUPMEvent<T>;
43 
44 public:
45   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46   // header, but since we are using the power of templates it must be declared part of
47   // this class to have easy access the same typedefs. Technically one can make a
48   // templated struct outside the class but it's more code for the same result.
49   struct PetscDeviceContext_IMPLS {
50     stream_type stream{};
51     cupmEvent_t event{};
52     cupmEvent_t begin{}; // timer-only
53     cupmEvent_t end{};   // timer-only
54 #if PetscDefined(USE_DEBUG)
55     PetscBool timerInUse{};
56 #endif
57     cupmBlasHandle_t   blas{};
58     cupmSolverHandle_t solver{};
59 
60     constexpr PetscDeviceContext_IMPLS() noexcept = default;
61 
62     PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
63 
64     PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
65 
66     PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
67   };
68 
69 private:
70   static bool initialized_;
71 
72   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
73   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
74 
75   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
76 
77   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
78 
79   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
80 
81   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
82 
83   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
84   // handles
85   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
86 
87   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
88   {
89     const auto dci    = impls_cast_(dctx);
90     auto      &handle = blashandles_[dctx->device->deviceId];
91 
92     PetscFunctionBegin;
93     if (!handle) {
94       PetscCall(PetscLogEventsPause());
95       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
96       for (auto i = 0; i < 3; ++i) {
97         const auto cberr = cupmBlasCreate(handle.ptr_to());
98         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
99         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
100         if (i != 2) {
101           PetscCall(PetscSleep(3));
102           continue;
103         }
104         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
105       }
106       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
107       PetscCall(PetscLogEventsResume());
108     }
109     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
110     dci->blas = handle;
111     PetscFunctionReturn(PETSC_SUCCESS);
112   }
113 
114   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
115   {
116     const auto dci    = impls_cast_(dctx);
117     auto      &handle = solverhandles_[dctx->device->deviceId];
118 
119     PetscFunctionBegin;
120     if (!handle) {
121       PetscCall(PetscLogEventsPause());
122       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
123       for (auto i = 0; i < 3; ++i) {
124         const auto cerr = cupmSolverCreate(&handle);
125         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
126         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
127         if (i < 2) {
128           PetscCall(PetscSleep(3));
129           continue;
130         }
131         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
132       }
133       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
134       PetscCall(PetscLogEventsResume());
135     }
136     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
137     dci->solver = handle;
138     PetscFunctionReturn(PETSC_SUCCESS);
139   }
140 
141   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
142   {
143     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
144 
145     PetscFunctionBegin;
146     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
147                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
148     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
149     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
150     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
151     PetscFunctionReturn(PETSC_SUCCESS);
152   }
153 
154   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
155 
156   static PetscErrorCode finalize_() noexcept
157   {
158     PetscFunctionBegin;
159     for (auto &&handle : blashandles_) {
160       if (handle) {
161         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
162         handle = nullptr;
163       }
164     }
165     for (auto &&handle : solverhandles_) {
166       if (handle) {
167         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
168         handle = nullptr;
169       }
170     }
171     initialized_ = false;
172     PetscFunctionReturn(PETSC_SUCCESS);
173   }
174 
175   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
176   PETSC_NODISCARD static PoolType &default_pool_() noexcept
177   {
178     static PoolType pool;
179     return pool;
180   }
181 
182   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
183   {
184     PetscFunctionBegin;
185     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
186     PetscFunctionReturn(PETSC_SUCCESS);
187   }
188 
189 public:
190   // All of these functions MUST be static in order to be callable from C, otherwise they
191   // get the implicit 'this' pointer tacked on
192   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
193   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
194   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
195   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
196   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
197   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
198   template <typename Handle_t>
199   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
200   template <typename Handle_t>
201   static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
202   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
203   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
204   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
205   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
206   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
207   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
208   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
209   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
210   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
211 
212   // not a PetscDeviceContext method, this registers the class
213   static PetscErrorCode initialize(PetscDevice) noexcept;
214 
215   // clang-format off
216   static constexpr _DeviceContextOps ops = {
217     PetscDesignatedInitializer(destroy, destroy),
218     PetscDesignatedInitializer(changestreamtype, changeStreamType),
219     PetscDesignatedInitializer(setup, setUp),
220     PetscDesignatedInitializer(query, query),
221     PetscDesignatedInitializer(waitforcontext, waitForContext),
222     PetscDesignatedInitializer(synchronize, synchronize),
223     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
224     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
225     PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
226     PetscDesignatedInitializer(begintimer, beginTimer),
227     PetscDesignatedInitializer(endtimer, endTimer),
228     PetscDesignatedInitializer(memalloc, memAlloc),
229     PetscDesignatedInitializer(memfree, memFree),
230     PetscDesignatedInitializer(memcopy, memCopy),
231     PetscDesignatedInitializer(memset, memSet),
232     PetscDesignatedInitializer(createevent, createEvent),
233     PetscDesignatedInitializer(recordevent, recordEvent),
234     PetscDesignatedInitializer(waitforevent, waitForEvent)
235   };
236   // clang-format on
237 };
238 
239 // not a PetscDeviceContext method, this initializes the CLASS
240 template <DeviceType T>
241 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
242 {
243   PetscFunctionBegin;
244   if (PetscUnlikely(!initialized_)) {
245     uint64_t      threshold = UINT64_MAX;
246     cupmMemPool_t mempool;
247 
248     initialized_ = true;
249     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
250     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
251     blashandles_.fill(nullptr);
252     solverhandles_.fill(nullptr);
253     PetscCall(PetscRegisterFinalize(finalize_));
254   }
255   PetscFunctionReturn(PETSC_SUCCESS);
256 }
257 
258 template <DeviceType T>
259 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
260 {
261   PetscFunctionBegin;
262   if (const auto dci = impls_cast_(dctx)) {
263     PetscCall(dci->stream.destroy());
264     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
265     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
266     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
267     delete dci;
268     dctx->data = nullptr;
269   }
270   PetscFunctionReturn(PETSC_SUCCESS);
271 }
272 
273 template <DeviceType T>
274 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
275 {
276   const auto dci = impls_cast_(dctx);
277 
278   PetscFunctionBegin;
279   PetscCall(dci->stream.destroy());
280   // set these to null so they aren't usable until setup is called again
281   dci->blas   = nullptr;
282   dci->solver = nullptr;
283   PetscFunctionReturn(PETSC_SUCCESS);
284 }
285 
286 template <DeviceType T>
287 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
288 {
289   const auto dci   = impls_cast_(dctx);
290   auto      &event = dci->event;
291 
292   PetscFunctionBegin;
293   PetscCall(check_current_device_(dctx));
294   PetscCall(dci->stream.change_type(dctx->streamType));
295   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
296 #if PetscDefined(USE_DEBUG)
297   dci->timerInUse = PETSC_FALSE;
298 #endif
299   PetscFunctionReturn(PETSC_SUCCESS);
300 }
301 
302 template <DeviceType T>
303 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
304 {
305   PetscFunctionBegin;
306   PetscCall(check_current_device_(dctx));
307   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
308   case cupmSuccess:
309     *idle = PETSC_TRUE;
310     break;
311   case cupmErrorNotReady:
312     *idle = PETSC_FALSE;
313     // reset the error
314     cerr = cupmGetLastError();
315     static_cast<void>(cerr);
316     break;
317   default:
318     PetscCallCUPM(cerr);
319     PetscUnreachable();
320   }
321   PetscFunctionReturn(PETSC_SUCCESS);
322 }
323 
324 template <DeviceType T>
325 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
326 {
327   const auto dcib  = impls_cast_(dctxb);
328   const auto event = dcib->event;
329 
330   PetscFunctionBegin;
331   PetscCall(check_current_device_(dctxa, dctxb));
332   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
333   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
334   PetscFunctionReturn(PETSC_SUCCESS);
335 }
336 
337 template <DeviceType T>
338 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
339 {
340   auto idle = PETSC_TRUE;
341 
342   PetscFunctionBegin;
343   PetscCall(query(dctx, &idle));
344   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
345   PetscFunctionReturn(PETSC_SUCCESS);
346 }
347 
348 template <DeviceType T>
349 template <typename handle_t>
350 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
351 {
352   PetscFunctionBegin;
353   PetscCall(initialize_handle_(handle_t{}, dctx));
354   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
355   PetscFunctionReturn(PETSC_SUCCESS);
356 }
357 
358 template <DeviceType T>
359 template <typename handle_t>
360 inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
361 {
362   using handle_type = typename handle_t::type;
363 
364   PetscFunctionBegin;
365   PetscCall(initialize_handle_(handle_t{}, dctx));
366   *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
367   PetscFunctionReturn(PETSC_SUCCESS);
368 }
369 
370 template <DeviceType T>
371 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
372 {
373   const auto dci = impls_cast_(dctx);
374 
375   PetscFunctionBegin;
376   PetscCall(check_current_device_(dctx));
377 #if PetscDefined(USE_DEBUG)
378   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
379   dci->timerInUse = PETSC_TRUE;
380 #endif
381   if (!dci->begin) {
382     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
383     PetscCallCUPM(cupmEventCreate(&dci->begin));
384     PetscCallCUPM(cupmEventCreate(&dci->end));
385   }
386   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
387   PetscFunctionReturn(PETSC_SUCCESS);
388 }
389 
390 template <DeviceType T>
391 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
392 {
393   float      gtime;
394   const auto dci = impls_cast_(dctx);
395   const auto end = dci->end;
396 
397   PetscFunctionBegin;
398   PetscCall(check_current_device_(dctx));
399 #if PetscDefined(USE_DEBUG)
400   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
401   dci->timerInUse = PETSC_FALSE;
402 #endif
403   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
404   PetscCallCUPM(cupmEventSynchronize(end));
405   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
406   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
407   PetscFunctionReturn(PETSC_SUCCESS);
408 }
409 
410 template <DeviceType T>
411 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
412 {
413   const auto &stream = impls_cast_(dctx)->stream;
414 
415   PetscFunctionBegin;
416   PetscCall(check_current_device_(dctx));
417   PetscCall(check_memtype_(mtype, "allocating"));
418   if (PetscMemTypeHost(mtype)) {
419     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
420   } else {
421     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
422   }
423   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
424   PetscFunctionReturn(PETSC_SUCCESS);
425 }
426 
427 template <DeviceType T>
428 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
429 {
430   const auto &stream = impls_cast_(dctx)->stream;
431 
432   PetscFunctionBegin;
433   PetscCall(check_current_device_(dctx));
434   PetscCall(check_memtype_(mtype, "freeing"));
435   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
436   if (PetscMemTypeHost(mtype)) {
437     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
438     // if ptr exists still exists the pool didn't own it
439     if (*ptr) {
440       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
441 
442       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
443       if (registered) {
444         PetscCallCUPM(cupmFreeHost(*ptr));
445       } else if (managed) {
446         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
447       }
448     }
449   } else {
450     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
451     // if ptr still exists the pool didn't own it
452     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
453   }
454   PetscFunctionReturn(PETSC_SUCCESS);
455 }
456 
457 template <DeviceType T>
458 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
459 {
460   const auto stream = impls_cast_(dctx)->stream.get_stream();
461 
462   PetscFunctionBegin;
463   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
464   if (mode == PETSC_DEVICE_COPY_HTOH) {
465     const auto cerr = cupmStreamQuery(stream);
466 
467     // yes this is faster
468     if (cerr == cupmSuccess) {
469       PetscCall(PetscMemcpy(dest, src, n));
470       PetscFunctionReturn(PETSC_SUCCESS);
471     } else if (cerr == cupmErrorNotReady) {
472       auto PETSC_UNUSED unused = cupmGetLastError();
473 
474       static_cast<void>(unused);
475     } else {
476       PetscCallCUPM(cerr);
477     }
478   }
479   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
480   PetscFunctionReturn(PETSC_SUCCESS);
481 }
482 
483 template <DeviceType T>
484 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
485 {
486   PetscFunctionBegin;
487   PetscCall(check_current_device_(dctx));
488   PetscCall(check_memtype_(mtype, "zeroing"));
489   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
490   PetscFunctionReturn(PETSC_SUCCESS);
491 }
492 
493 template <DeviceType T>
494 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
495 {
496   PetscFunctionBegin;
497   PetscCallCXX(event->data = new event_type{});
498   event->destroy = [](PetscEvent event) {
499     PetscFunctionBegin;
500     delete event_cast_(event);
501     event->data = nullptr;
502     PetscFunctionReturn(PETSC_SUCCESS);
503   };
504   PetscFunctionReturn(PETSC_SUCCESS);
505 }
506 
507 template <DeviceType T>
508 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
509 {
510   PetscFunctionBegin;
511   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
512   PetscFunctionReturn(PETSC_SUCCESS);
513 }
514 
515 template <DeviceType T>
516 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
517 {
518   PetscFunctionBegin;
519   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
520   PetscFunctionReturn(PETSC_SUCCESS);
521 }
522 
523 // initialize the static member variables
524 template <DeviceType T>
525 bool DeviceContext<T>::initialized_ = false;
526 
527 template <DeviceType T>
528 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
529 
530 template <DeviceType T>
531 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
532 
533 template <DeviceType T>
534 constexpr _DeviceContextOps DeviceContext<T>::ops;
535 
536 } // namespace impl
537 
538 // shorten this one up a bit (and instantiate the templates)
539 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
540 using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
541 
542 // shorthand for what is an EXTREMELY long name
543 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
544 
545 } // namespace cupm
546 
547 } // namespace device
548 
549 } // namespace Petsc
550