xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 0619917b5a674bb687c64e7daba2ab22be99af31)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmsolverinterface.hpp>
6 #include <petsc/private/logimpl.h>
7 
8 #include <petsc/private/cpp/array.hpp>
9 
10 #include "../segmentedmempool.hpp"
11 #include "cupmallocator.hpp"
12 #include "cupmstream.hpp"
13 #include "cupmevent.hpp"
14 
15 namespace Petsc
16 {
17 
18 namespace device
19 {
20 
21 namespace cupm
22 {
23 
24 namespace impl
25 {
26 
27 template <DeviceType T>
28 class DeviceContext : SolverInterface<T> {
29 public:
30   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
31 
32 private:
33   template <typename H, std::size_t>
34   struct HandleTag {
35     using type = H;
36   };
37 
38   using stream_tag = HandleTag<cupmStream_t, 0>;
39   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
40   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
41 
42   using stream_type = CUPMStream<T>;
43   using event_type  = CUPMEvent<T>;
44 
45 public:
46   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
47   // header, but since we are using the power of templates it must be declared part of
48   // this class to have easy access the same typedefs. Technically one can make a
49   // templated struct outside the class but it's more code for the same result.
50   struct PetscDeviceContext_IMPLS {
51     stream_type stream{};
52     cupmEvent_t event{};
53     cupmEvent_t begin{}; // timer-only
54     cupmEvent_t end{};   // timer-only
55 #if PetscDefined(USE_DEBUG)
56     PetscBool timerInUse{};
57 #endif
58     cupmBlasHandle_t   blas{};
59     cupmSolverHandle_t solver{};
60 
61     constexpr PetscDeviceContext_IMPLS() noexcept = default;
62 
63     PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
64 
65     PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
66 
67     PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
68   };
69 
70 private:
71   static bool initialized_;
72 
73   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
74   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
75 
76   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
77 
78   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
79 
80   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
81 
82   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
83 
84   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
85   // handles
86   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
87 
88   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
89   {
90     const auto dci    = impls_cast_(dctx);
91     auto      &handle = blashandles_[dctx->device->deviceId];
92 
93     PetscFunctionBegin;
94     if (!handle) {
95       PetscCall(PetscLogEventsPause());
96       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
97       for (auto i = 0; i < 3; ++i) {
98         const auto cberr = cupmBlasCreate(handle.ptr_to());
99         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
100         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
101         if (i != 2) {
102           PetscCall(PetscSleep(3));
103           continue;
104         }
105         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
106       }
107       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
108       PetscCall(PetscLogEventsResume());
109     }
110     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
111     dci->blas = handle;
112     PetscFunctionReturn(PETSC_SUCCESS);
113   }
114 
115   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
116   {
117     const auto dci    = impls_cast_(dctx);
118     auto      &handle = solverhandles_[dctx->device->deviceId];
119 
120     PetscFunctionBegin;
121     if (!handle) {
122       PetscCall(PetscLogEventsPause());
123       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
124       for (auto i = 0; i < 3; ++i) {
125         const auto cerr = cupmSolverCreate(&handle);
126         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
127         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
128         if (i < 2) {
129           PetscCall(PetscSleep(3));
130           continue;
131         }
132         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
133       }
134       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
135       PetscCall(PetscLogEventsResume());
136     }
137     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
138     dci->solver = handle;
139     PetscFunctionReturn(PETSC_SUCCESS);
140   }
141 
142   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
143   {
144     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
145 
146     PetscFunctionBegin;
147     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
148                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
149     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
150     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
151     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
152     PetscFunctionReturn(PETSC_SUCCESS);
153   }
154 
155   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
156 
157   static PetscErrorCode finalize_() noexcept
158   {
159     PetscFunctionBegin;
160     for (auto &&handle : blashandles_) {
161       if (handle) {
162         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
163         handle = nullptr;
164       }
165     }
166     for (auto &&handle : solverhandles_) {
167       if (handle) {
168         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
169         handle = nullptr;
170       }
171     }
172     initialized_ = false;
173     PetscFunctionReturn(PETSC_SUCCESS);
174   }
175 
176   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
177   PETSC_NODISCARD static PoolType &default_pool_() noexcept
178   {
179     static PoolType pool;
180     return pool;
181   }
182 
183   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
184   {
185     PetscFunctionBegin;
186     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
187     PetscFunctionReturn(PETSC_SUCCESS);
188   }
189 
190 public:
191   // All of these functions MUST be static in order to be callable from C, otherwise they
192   // get the implicit 'this' pointer tacked on
193   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
194   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
195   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
196   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
197   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
198   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
199   template <typename Handle_t>
200   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
201   template <typename Handle_t>
202   static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
203   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
204   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
205   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
206   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
207   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
208   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
209   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
210   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
211   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
212 
213   // not a PetscDeviceContext method, this registers the class
214   static PetscErrorCode initialize(PetscDevice) noexcept;
215 
216   // clang-format off
217   static constexpr _DeviceContextOps ops = {
218     PetscDesignatedInitializer(destroy, destroy),
219     PetscDesignatedInitializer(changestreamtype, changeStreamType),
220     PetscDesignatedInitializer(setup, setUp),
221     PetscDesignatedInitializer(query, query),
222     PetscDesignatedInitializer(waitforcontext, waitForContext),
223     PetscDesignatedInitializer(synchronize, synchronize),
224     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
225     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
226     PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
227     PetscDesignatedInitializer(begintimer, beginTimer),
228     PetscDesignatedInitializer(endtimer, endTimer),
229     PetscDesignatedInitializer(memalloc, memAlloc),
230     PetscDesignatedInitializer(memfree, memFree),
231     PetscDesignatedInitializer(memcopy, memCopy),
232     PetscDesignatedInitializer(memset, memSet),
233     PetscDesignatedInitializer(createevent, createEvent),
234     PetscDesignatedInitializer(recordevent, recordEvent),
235     PetscDesignatedInitializer(waitforevent, waitForEvent)
236   };
237   // clang-format on
238 };
239 
240 // not a PetscDeviceContext method, this initializes the CLASS
241 template <DeviceType T>
242 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
243 {
244   PetscFunctionBegin;
245   if (PetscUnlikely(!initialized_)) {
246     uint64_t      threshold = UINT64_MAX;
247     cupmMemPool_t mempool;
248 
249     initialized_ = true;
250     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
251     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
252     blashandles_.fill(nullptr);
253     solverhandles_.fill(nullptr);
254     PetscCall(PetscRegisterFinalize(finalize_));
255   }
256   PetscFunctionReturn(PETSC_SUCCESS);
257 }
258 
259 template <DeviceType T>
260 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
261 {
262   PetscFunctionBegin;
263   if (const auto dci = impls_cast_(dctx)) {
264     PetscCall(dci->stream.destroy());
265     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
266     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
267     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
268     delete dci;
269     dctx->data = nullptr;
270   }
271   PetscFunctionReturn(PETSC_SUCCESS);
272 }
273 
274 template <DeviceType T>
275 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
276 {
277   const auto dci = impls_cast_(dctx);
278 
279   PetscFunctionBegin;
280   PetscCall(dci->stream.destroy());
281   // set these to null so they aren't usable until setup is called again
282   dci->blas   = nullptr;
283   dci->solver = nullptr;
284   PetscFunctionReturn(PETSC_SUCCESS);
285 }
286 
287 template <DeviceType T>
288 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
289 {
290   const auto dci   = impls_cast_(dctx);
291   auto      &event = dci->event;
292 
293   PetscFunctionBegin;
294   PetscCall(check_current_device_(dctx));
295   PetscCall(dci->stream.change_type(dctx->streamType));
296   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
297 #if PetscDefined(USE_DEBUG)
298   dci->timerInUse = PETSC_FALSE;
299 #endif
300   PetscFunctionReturn(PETSC_SUCCESS);
301 }
302 
303 template <DeviceType T>
304 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
305 {
306   PetscFunctionBegin;
307   PetscCall(check_current_device_(dctx));
308   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
309   case cupmSuccess:
310     *idle = PETSC_TRUE;
311     break;
312   case cupmErrorNotReady:
313     *idle = PETSC_FALSE;
314     // reset the error
315     cerr = cupmGetLastError();
316     static_cast<void>(cerr);
317     break;
318   default:
319     PetscCallCUPM(cerr);
320     PetscUnreachable();
321   }
322   PetscFunctionReturn(PETSC_SUCCESS);
323 }
324 
325 template <DeviceType T>
326 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
327 {
328   const auto dcib  = impls_cast_(dctxb);
329   const auto event = dcib->event;
330 
331   PetscFunctionBegin;
332   PetscCall(check_current_device_(dctxa, dctxb));
333   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
334   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
335   PetscFunctionReturn(PETSC_SUCCESS);
336 }
337 
338 template <DeviceType T>
339 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
340 {
341   auto idle = PETSC_TRUE;
342 
343   PetscFunctionBegin;
344   PetscCall(query(dctx, &idle));
345   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
346   PetscFunctionReturn(PETSC_SUCCESS);
347 }
348 
349 template <DeviceType T>
350 template <typename handle_t>
351 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
352 {
353   PetscFunctionBegin;
354   PetscCall(initialize_handle_(handle_t{}, dctx));
355   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
356   PetscFunctionReturn(PETSC_SUCCESS);
357 }
358 
359 template <DeviceType T>
360 template <typename handle_t>
361 inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
362 {
363   using handle_type = typename handle_t::type;
364 
365   PetscFunctionBegin;
366   PetscCall(initialize_handle_(handle_t{}, dctx));
367   *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
368   PetscFunctionReturn(PETSC_SUCCESS);
369 }
370 
371 template <DeviceType T>
372 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
373 {
374   const auto dci = impls_cast_(dctx);
375 
376   PetscFunctionBegin;
377   PetscCall(check_current_device_(dctx));
378 #if PetscDefined(USE_DEBUG)
379   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
380   dci->timerInUse = PETSC_TRUE;
381 #endif
382   if (!dci->begin) {
383     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
384     PetscCallCUPM(cupmEventCreate(&dci->begin));
385     PetscCallCUPM(cupmEventCreate(&dci->end));
386   }
387   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
388   PetscFunctionReturn(PETSC_SUCCESS);
389 }
390 
391 template <DeviceType T>
392 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
393 {
394   float      gtime;
395   const auto dci = impls_cast_(dctx);
396   const auto end = dci->end;
397 
398   PetscFunctionBegin;
399   PetscCall(check_current_device_(dctx));
400 #if PetscDefined(USE_DEBUG)
401   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
402   dci->timerInUse = PETSC_FALSE;
403 #endif
404   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
405   PetscCallCUPM(cupmEventSynchronize(end));
406   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
407   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
408   PetscFunctionReturn(PETSC_SUCCESS);
409 }
410 
411 template <DeviceType T>
412 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
413 {
414   const auto &stream = impls_cast_(dctx)->stream;
415 
416   PetscFunctionBegin;
417   PetscCall(check_current_device_(dctx));
418   PetscCall(check_memtype_(mtype, "allocating"));
419   if (PetscMemTypeHost(mtype)) {
420     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
421   } else {
422     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
423   }
424   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
425   PetscFunctionReturn(PETSC_SUCCESS);
426 }
427 
428 template <DeviceType T>
429 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
430 {
431   const auto &stream = impls_cast_(dctx)->stream;
432 
433   PetscFunctionBegin;
434   PetscCall(check_current_device_(dctx));
435   PetscCall(check_memtype_(mtype, "freeing"));
436   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
437   if (PetscMemTypeHost(mtype)) {
438     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
439     // if ptr exists still exists the pool didn't own it
440     if (*ptr) {
441       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
442 
443       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
444       if (registered) {
445         PetscCallCUPM(cupmFreeHost(*ptr));
446       } else if (managed) {
447         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
448       }
449     }
450   } else {
451     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
452     // if ptr still exists the pool didn't own it
453     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
454   }
455   PetscFunctionReturn(PETSC_SUCCESS);
456 }
457 
458 template <DeviceType T>
459 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
460 {
461   const auto stream = impls_cast_(dctx)->stream.get_stream();
462 
463   PetscFunctionBegin;
464   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
465   if (mode == PETSC_DEVICE_COPY_HTOH) {
466     const auto cerr = cupmStreamQuery(stream);
467 
468     // yes this is faster
469     if (cerr == cupmSuccess) {
470       PetscCall(PetscMemcpy(dest, src, n));
471       PetscFunctionReturn(PETSC_SUCCESS);
472     } else if (cerr == cupmErrorNotReady) {
473       auto PETSC_UNUSED unused = cupmGetLastError();
474 
475       static_cast<void>(unused);
476     } else {
477       PetscCallCUPM(cerr);
478     }
479   }
480   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
481   PetscFunctionReturn(PETSC_SUCCESS);
482 }
483 
484 template <DeviceType T>
485 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
486 {
487   PetscFunctionBegin;
488   PetscCall(check_current_device_(dctx));
489   PetscCall(check_memtype_(mtype, "zeroing"));
490   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
491   PetscFunctionReturn(PETSC_SUCCESS);
492 }
493 
494 template <DeviceType T>
495 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
496 {
497   PetscFunctionBegin;
498   PetscCallCXX(event->data = new event_type{});
499   event->destroy = [](PetscEvent event) {
500     PetscFunctionBegin;
501     delete event_cast_(event);
502     event->data = nullptr;
503     PetscFunctionReturn(PETSC_SUCCESS);
504   };
505   PetscFunctionReturn(PETSC_SUCCESS);
506 }
507 
508 template <DeviceType T>
509 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
510 {
511   PetscFunctionBegin;
512   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
513   PetscFunctionReturn(PETSC_SUCCESS);
514 }
515 
516 template <DeviceType T>
517 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
518 {
519   PetscFunctionBegin;
520   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
521   PetscFunctionReturn(PETSC_SUCCESS);
522 }
523 
524 // initialize the static member variables
525 template <DeviceType T>
526 bool DeviceContext<T>::initialized_ = false;
527 
528 template <DeviceType T>
529 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
530 
531 template <DeviceType T>
532 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
533 
534 template <DeviceType T>
535 constexpr _DeviceContextOps DeviceContext<T>::ops;
536 
537 } // namespace impl
538 
539 // shorten this one up a bit (and instantiate the templates)
540 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
541 using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
542 
543 // shorthand for what is an EXTREMELY long name
544 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
545 
546 } // namespace cupm
547 
548 } // namespace device
549 
550 } // namespace Petsc
551 
#endif // PETSCDEVICECONTEXTCUPM_HPP
553