xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision a496304597bacff3545e802853d69e8765312868)
1*a4963045SJacob Faibussowitsch #pragma once
2030f984aSJacob Faibussowitsch 
3a4af0ceeSJacob Faibussowitsch #include <petsc/private/deviceimpl.h>
496a4b4d9SJacob Faibussowitsch #include <petsc/private/cupmsolverinterface.hpp>
57a101e5eSJacob Faibussowitsch #include <petsc/private/logimpl.h>
6030f984aSJacob Faibussowitsch 
70e6b6b59SJacob Faibussowitsch #include <petsc/private/cpp/array.hpp>
8a4af0ceeSJacob Faibussowitsch 
90e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp"
100e6b6b59SJacob Faibussowitsch #include "cupmallocator.hpp"
110e6b6b59SJacob Faibussowitsch #include "cupmstream.hpp"
120e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp"
130e6b6b59SJacob Faibussowitsch 
14d71ae5a4SJacob Faibussowitsch namespace Petsc
15d71ae5a4SJacob Faibussowitsch {
16a4af0ceeSJacob Faibussowitsch 
17d71ae5a4SJacob Faibussowitsch namespace device
18d71ae5a4SJacob Faibussowitsch {
1917f48955SJacob Faibussowitsch 
20d71ae5a4SJacob Faibussowitsch namespace cupm
21d71ae5a4SJacob Faibussowitsch {
2217f48955SJacob Faibussowitsch 
23d71ae5a4SJacob Faibussowitsch namespace impl
24d71ae5a4SJacob Faibussowitsch {
25030f984aSJacob Faibussowitsch 
2617f48955SJacob Faibussowitsch template <DeviceType T>
2796a4b4d9SJacob Faibussowitsch class DeviceContext : SolverInterface<T> {
2817f48955SJacob Faibussowitsch public:
2996a4b4d9SJacob Faibussowitsch   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
3017f48955SJacob Faibussowitsch 
3117f48955SJacob Faibussowitsch private:
329371c9d4SSatish Balay   template <typename H, std::size_t>
339371c9d4SSatish Balay   struct HandleTag {
349371c9d4SSatish Balay     using type = H;
359371c9d4SSatish Balay   };
360e6b6b59SJacob Faibussowitsch 
377a101e5eSJacob Faibussowitsch   using stream_tag = HandleTag<cupmStream_t, 0>;
387a101e5eSJacob Faibussowitsch   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
397a101e5eSJacob Faibussowitsch   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
40a4af0ceeSJacob Faibussowitsch 
410e6b6b59SJacob Faibussowitsch   using stream_type = CUPMStream<T>;
420e6b6b59SJacob Faibussowitsch   using event_type  = CUPMEvent<T>;
430e6b6b59SJacob Faibussowitsch 
44030f984aSJacob Faibussowitsch public:
45030f984aSJacob Faibussowitsch   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46030f984aSJacob Faibussowitsch   // header, but since we are using the power of templates it must be declared part of
47030f984aSJacob Faibussowitsch   // this class to have easy access the same typedefs. Technically one can make a
48030f984aSJacob Faibussowitsch   // templated struct outside the class but it's more code for the same result.
493048253cSJacob Faibussowitsch   struct PetscDeviceContext_IMPLS {
500e6b6b59SJacob Faibussowitsch     stream_type stream{};
510e6b6b59SJacob Faibussowitsch     cupmEvent_t event{};
520e6b6b59SJacob Faibussowitsch     cupmEvent_t begin{}; // timer-only
530e6b6b59SJacob Faibussowitsch     cupmEvent_t end{};   // timer-only
54a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
550e6b6b59SJacob Faibussowitsch     PetscBool timerInUse{};
56a4af0ceeSJacob Faibussowitsch #endif
570e6b6b59SJacob Faibussowitsch     cupmBlasHandle_t   blas{};
580e6b6b59SJacob Faibussowitsch     cupmSolverHandle_t solver{};
59a4af0ceeSJacob Faibussowitsch 
600e6b6b59SJacob Faibussowitsch     constexpr PetscDeviceContext_IMPLS() noexcept = default;
610e6b6b59SJacob Faibussowitsch 
6231d47070SJunchao Zhang     PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
630e6b6b59SJacob Faibussowitsch 
6431d47070SJunchao Zhang     PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
650e6b6b59SJacob Faibussowitsch 
6631d47070SJunchao Zhang     PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
67030f984aSJacob Faibussowitsch   };
68030f984aSJacob Faibussowitsch 
69030f984aSJacob Faibussowitsch private:
7017f48955SJacob Faibussowitsch   static bool initialized_;
716d54fb17SJacob Faibussowitsch 
7217f48955SJacob Faibussowitsch   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
7317f48955SJacob Faibussowitsch   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
74030f984aSJacob Faibussowitsch 
75d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
76a4af0ceeSJacob Faibussowitsch 
77d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
780e6b6b59SJacob Faibussowitsch 
79d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
807a101e5eSJacob Faibussowitsch 
81d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
827a101e5eSJacob Faibussowitsch 
837a101e5eSJacob Faibussowitsch   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
847a101e5eSJacob Faibussowitsch   // handles
85089fb57cSJacob Faibussowitsch   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
867a101e5eSJacob Faibussowitsch 
8796a4b4d9SJacob Faibussowitsch   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
88d71ae5a4SJacob Faibussowitsch   {
8996a4b4d9SJacob Faibussowitsch     const auto dci    = impls_cast_(dctx);
9096a4b4d9SJacob Faibussowitsch     auto      &handle = blashandles_[dctx->device->deviceId];
917a101e5eSJacob Faibussowitsch 
92030f984aSJacob Faibussowitsch     PetscFunctionBegin;
9396a4b4d9SJacob Faibussowitsch     if (!handle) {
94b665b14eSToby Isaac       PetscCall(PetscLogEventsPause());
957a101e5eSJacob Faibussowitsch       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
9617f48955SJacob Faibussowitsch       for (auto i = 0; i < 3; ++i) {
9796a4b4d9SJacob Faibussowitsch         const auto cberr = cupmBlasCreate(handle.ptr_to());
9817f48955SJacob Faibussowitsch         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
999566063dSJacob Faibussowitsch         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
10017f48955SJacob Faibussowitsch         if (i != 2) {
1019566063dSJacob Faibussowitsch           PetscCall(PetscSleep(3));
10217f48955SJacob Faibussowitsch           continue;
103a4af0ceeSJacob Faibussowitsch         }
1045f80ce2aSJacob Faibussowitsch         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
105a4af0ceeSJacob Faibussowitsch       }
1067a101e5eSJacob Faibussowitsch       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
107b665b14eSToby Isaac       PetscCall(PetscLogEventsResume());
108030f984aSJacob Faibussowitsch     }
1090e6b6b59SJacob Faibussowitsch     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
1107a101e5eSJacob Faibussowitsch     dci->blas = handle;
1113ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1127a101e5eSJacob Faibussowitsch   }
1137a101e5eSJacob Faibussowitsch 
114089fb57cSJacob Faibussowitsch   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
115d71ae5a4SJacob Faibussowitsch   {
1166d54fb17SJacob Faibussowitsch     const auto dci    = impls_cast_(dctx);
1176d54fb17SJacob Faibussowitsch     auto      &handle = solverhandles_[dctx->device->deviceId];
1187a101e5eSJacob Faibussowitsch 
1197a101e5eSJacob Faibussowitsch     PetscFunctionBegin;
12096a4b4d9SJacob Faibussowitsch     if (!handle) {
121b665b14eSToby Isaac       PetscCall(PetscLogEventsPause());
1227a101e5eSJacob Faibussowitsch       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
12396a4b4d9SJacob Faibussowitsch       for (auto i = 0; i < 3; ++i) {
12496a4b4d9SJacob Faibussowitsch         const auto cerr = cupmSolverCreate(&handle);
12596a4b4d9SJacob Faibussowitsch         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
12696a4b4d9SJacob Faibussowitsch         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
12796a4b4d9SJacob Faibussowitsch         if (i < 2) {
12896a4b4d9SJacob Faibussowitsch           PetscCall(PetscSleep(3));
12996a4b4d9SJacob Faibussowitsch           continue;
13096a4b4d9SJacob Faibussowitsch         }
13196a4b4d9SJacob Faibussowitsch         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
13296a4b4d9SJacob Faibussowitsch       }
1337a101e5eSJacob Faibussowitsch       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
134b665b14eSToby Isaac       PetscCall(PetscLogEventsResume());
13596a4b4d9SJacob Faibussowitsch     }
13696a4b4d9SJacob Faibussowitsch     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
1377a101e5eSJacob Faibussowitsch     dci->solver = handle;
1383ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
13917f48955SJacob Faibussowitsch   }
14017f48955SJacob Faibussowitsch 
141089fb57cSJacob Faibussowitsch   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
142d71ae5a4SJacob Faibussowitsch   {
1430e6b6b59SJacob Faibussowitsch     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
1440e6b6b59SJacob Faibussowitsch 
1450e6b6b59SJacob Faibussowitsch     PetscFunctionBegin;
1460e6b6b59SJacob Faibussowitsch     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
1470e6b6b59SJacob Faibussowitsch                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
1480e6b6b59SJacob Faibussowitsch     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
1490e6b6b59SJacob Faibussowitsch     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
1500e6b6b59SJacob Faibussowitsch     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
1513ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1520e6b6b59SJacob Faibussowitsch   }
1530e6b6b59SJacob Faibussowitsch 
154089fb57cSJacob Faibussowitsch   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
1550e6b6b59SJacob Faibussowitsch 
156089fb57cSJacob Faibussowitsch   static PetscErrorCode finalize_() noexcept
157d71ae5a4SJacob Faibussowitsch   {
15817f48955SJacob Faibussowitsch     PetscFunctionBegin;
15917f48955SJacob Faibussowitsch     for (auto &&handle : blashandles_) {
16017f48955SJacob Faibussowitsch       if (handle) {
1619566063dSJacob Faibussowitsch         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
16217f48955SJacob Faibussowitsch         handle = nullptr;
16317f48955SJacob Faibussowitsch       }
16417f48955SJacob Faibussowitsch     }
16517f48955SJacob Faibussowitsch     for (auto &&handle : solverhandles_) {
16617f48955SJacob Faibussowitsch       if (handle) {
16796a4b4d9SJacob Faibussowitsch         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
16817f48955SJacob Faibussowitsch         handle = nullptr;
16917f48955SJacob Faibussowitsch       }
17017f48955SJacob Faibussowitsch     }
17117f48955SJacob Faibussowitsch     initialized_ = false;
1723ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
17317f48955SJacob Faibussowitsch   }
17417f48955SJacob Faibussowitsch 
1750e6b6b59SJacob Faibussowitsch   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
176d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static PoolType &default_pool_() noexcept
177d71ae5a4SJacob Faibussowitsch   {
1780e6b6b59SJacob Faibussowitsch     static PoolType pool;
1790e6b6b59SJacob Faibussowitsch     return pool;
1800e6b6b59SJacob Faibussowitsch   }
181030f984aSJacob Faibussowitsch 
182089fb57cSJacob Faibussowitsch   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
183d71ae5a4SJacob Faibussowitsch   {
1840e6b6b59SJacob Faibussowitsch     PetscFunctionBegin;
1850e6b6b59SJacob Faibussowitsch     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
1863ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1870e6b6b59SJacob Faibussowitsch   }
1880e6b6b59SJacob Faibussowitsch 
1890e6b6b59SJacob Faibussowitsch public:
190030f984aSJacob Faibussowitsch   // All of these functions MUST be static in order to be callable from C, otherwise they
191030f984aSJacob Faibussowitsch   // get the implicit 'this' pointer tacked on
192089fb57cSJacob Faibussowitsch   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
193089fb57cSJacob Faibussowitsch   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
194089fb57cSJacob Faibussowitsch   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
195089fb57cSJacob Faibussowitsch   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
196089fb57cSJacob Faibussowitsch   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
197089fb57cSJacob Faibussowitsch   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
198a4af0ceeSJacob Faibussowitsch   template <typename Handle_t>
199089fb57cSJacob Faibussowitsch   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
20031d47070SJunchao Zhang   template <typename Handle_t>
20197cd0981SJacob Faibussowitsch   static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
202089fb57cSJacob Faibussowitsch   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
203089fb57cSJacob Faibussowitsch   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
204089fb57cSJacob Faibussowitsch   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
205089fb57cSJacob Faibussowitsch   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
206089fb57cSJacob Faibussowitsch   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
207089fb57cSJacob Faibussowitsch   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
208089fb57cSJacob Faibussowitsch   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
209089fb57cSJacob Faibussowitsch   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
210089fb57cSJacob Faibussowitsch   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
2117a101e5eSJacob Faibussowitsch 
2127a101e5eSJacob Faibussowitsch   // not a PetscDeviceContext method, this registers the class
213089fb57cSJacob Faibussowitsch   static PetscErrorCode initialize(PetscDevice) noexcept;
2140e6b6b59SJacob Faibussowitsch 
2150e6b6b59SJacob Faibussowitsch   // clang-format off
2166ff55be4SJacob Faibussowitsch   static constexpr _DeviceContextOps ops = {
2176ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(destroy, destroy),
2186ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(changestreamtype, changeStreamType),
2196ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(setup, setUp),
2206ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(query, query),
2216ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(waitforcontext, waitForContext),
2226ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(synchronize, synchronize),
2236ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
2246ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
22531d47070SJunchao Zhang     PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
2266ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(begintimer, beginTimer),
2276ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(endtimer, endTimer),
2286ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(memalloc, memAlloc),
2296ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(memfree, memFree),
2306ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(memcopy, memCopy),
2316ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(memset, memSet),
2326ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(createevent, createEvent),
2336ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(recordevent, recordEvent),
2346ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(waitforevent, waitForEvent)
2350e6b6b59SJacob Faibussowitsch   };
2360e6b6b59SJacob Faibussowitsch   // clang-format on
237030f984aSJacob Faibussowitsch };
238030f984aSJacob Faibussowitsch 
2390e6b6b59SJacob Faibussowitsch // not a PetscDeviceContext method, this initializes the CLASS
24017f48955SJacob Faibussowitsch template <DeviceType T>
2416d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
242d71ae5a4SJacob Faibussowitsch {
2437a101e5eSJacob Faibussowitsch   PetscFunctionBegin;
2447a101e5eSJacob Faibussowitsch   if (PetscUnlikely(!initialized_)) {
2450e6b6b59SJacob Faibussowitsch     uint64_t      threshold = UINT64_MAX;
2466d54fb17SJacob Faibussowitsch     cupmMemPool_t mempool;
2470e6b6b59SJacob Faibussowitsch 
2487a101e5eSJacob Faibussowitsch     initialized_ = true;
2496d54fb17SJacob Faibussowitsch     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
2500e6b6b59SJacob Faibussowitsch     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
2510e6b6b59SJacob Faibussowitsch     blashandles_.fill(nullptr);
2520e6b6b59SJacob Faibussowitsch     solverhandles_.fill(nullptr);
2537a101e5eSJacob Faibussowitsch     PetscCall(PetscRegisterFinalize(finalize_));
2547a101e5eSJacob Faibussowitsch   }
2553ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2567a101e5eSJacob Faibussowitsch }
2577a101e5eSJacob Faibussowitsch 
2587a101e5eSJacob Faibussowitsch template <DeviceType T>
2596d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
260d71ae5a4SJacob Faibussowitsch {
261030f984aSJacob Faibussowitsch   PetscFunctionBegin;
2620e6b6b59SJacob Faibussowitsch   if (const auto dci = impls_cast_(dctx)) {
2630e6b6b59SJacob Faibussowitsch     PetscCall(dci->stream.destroy());
264146a86ebSJacob Faibussowitsch     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
2659566063dSJacob Faibussowitsch     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
2669566063dSJacob Faibussowitsch     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
2670e6b6b59SJacob Faibussowitsch     delete dci;
2680e6b6b59SJacob Faibussowitsch     dctx->data = nullptr;
2690e6b6b59SJacob Faibussowitsch   }
2703ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
271030f984aSJacob Faibussowitsch }
272030f984aSJacob Faibussowitsch 
27317f48955SJacob Faibussowitsch template <DeviceType T>
2746d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
275d71ae5a4SJacob Faibussowitsch {
2767a101e5eSJacob Faibussowitsch   const auto dci = impls_cast_(dctx);
277030f984aSJacob Faibussowitsch 
278030f984aSJacob Faibussowitsch   PetscFunctionBegin;
2790e6b6b59SJacob Faibussowitsch   PetscCall(dci->stream.destroy());
280030f984aSJacob Faibussowitsch   // set these to null so they aren't usable until setup is called again
281030f984aSJacob Faibussowitsch   dci->blas   = nullptr;
282030f984aSJacob Faibussowitsch   dci->solver = nullptr;
2833ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
284030f984aSJacob Faibussowitsch }
285030f984aSJacob Faibussowitsch 
28617f48955SJacob Faibussowitsch template <DeviceType T>
2876d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
288d71ae5a4SJacob Faibussowitsch {
2897a101e5eSJacob Faibussowitsch   const auto dci   = impls_cast_(dctx);
2900e6b6b59SJacob Faibussowitsch   auto      &event = dci->event;
291030f984aSJacob Faibussowitsch 
292030f984aSJacob Faibussowitsch   PetscFunctionBegin;
2930e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
2940e6b6b59SJacob Faibussowitsch   PetscCall(dci->stream.change_type(dctx->streamType));
2950e6b6b59SJacob Faibussowitsch   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
296a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
297a4af0ceeSJacob Faibussowitsch   dci->timerInUse = PETSC_FALSE;
298a4af0ceeSJacob Faibussowitsch #endif
2993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
300030f984aSJacob Faibussowitsch }
301030f984aSJacob Faibussowitsch 
30217f48955SJacob Faibussowitsch template <DeviceType T>
3036d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
304d71ae5a4SJacob Faibussowitsch {
305030f984aSJacob Faibussowitsch   PetscFunctionBegin;
3060e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
3074b955ea4SJacob Faibussowitsch   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
308d71ae5a4SJacob Faibussowitsch   case cupmSuccess:
309d71ae5a4SJacob Faibussowitsch     *idle = PETSC_TRUE;
310d71ae5a4SJacob Faibussowitsch     break;
311d71ae5a4SJacob Faibussowitsch   case cupmErrorNotReady:
312d71ae5a4SJacob Faibussowitsch     *idle = PETSC_FALSE;
3134b955ea4SJacob Faibussowitsch     // reset the error
3144b955ea4SJacob Faibussowitsch     cerr = cupmGetLastError();
3154b955ea4SJacob Faibussowitsch     static_cast<void>(cerr);
316d71ae5a4SJacob Faibussowitsch     break;
317d71ae5a4SJacob Faibussowitsch   default:
318d71ae5a4SJacob Faibussowitsch     PetscCallCUPM(cerr);
319d71ae5a4SJacob Faibussowitsch     PetscUnreachable();
320030f984aSJacob Faibussowitsch   }
3213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
322030f984aSJacob Faibussowitsch }
323030f984aSJacob Faibussowitsch 
32417f48955SJacob Faibussowitsch template <DeviceType T>
3256d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
326d71ae5a4SJacob Faibussowitsch {
3270e6b6b59SJacob Faibussowitsch   const auto dcib  = impls_cast_(dctxb);
3280e6b6b59SJacob Faibussowitsch   const auto event = dcib->event;
329030f984aSJacob Faibussowitsch 
330030f984aSJacob Faibussowitsch   PetscFunctionBegin;
3310e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctxa, dctxb));
3320e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
3330e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
3343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
335030f984aSJacob Faibussowitsch }
336030f984aSJacob Faibussowitsch 
33717f48955SJacob Faibussowitsch template <DeviceType T>
3386d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
339d71ae5a4SJacob Faibussowitsch {
3400e6b6b59SJacob Faibussowitsch   auto idle = PETSC_TRUE;
341030f984aSJacob Faibussowitsch 
342030f984aSJacob Faibussowitsch   PetscFunctionBegin;
3430e6b6b59SJacob Faibussowitsch   PetscCall(query(dctx, &idle));
3440e6b6b59SJacob Faibussowitsch   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
3453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
346030f984aSJacob Faibussowitsch }
347030f984aSJacob Faibussowitsch 
34817f48955SJacob Faibussowitsch template <DeviceType T>
34917f48955SJacob Faibussowitsch template <typename handle_t>
3506d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
351d71ae5a4SJacob Faibussowitsch {
352a4af0ceeSJacob Faibussowitsch   PetscFunctionBegin;
3537a101e5eSJacob Faibussowitsch   PetscCall(initialize_handle_(handle_t{}, dctx));
3547a101e5eSJacob Faibussowitsch   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
3553ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
356a4af0ceeSJacob Faibussowitsch }
357a4af0ceeSJacob Faibussowitsch 
35817f48955SJacob Faibussowitsch template <DeviceType T>
35931d47070SJunchao Zhang template <typename handle_t>
36097cd0981SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
36131d47070SJunchao Zhang {
36231d47070SJunchao Zhang   using handle_type = typename handle_t::type;
36331d47070SJunchao Zhang 
36431d47070SJunchao Zhang   PetscFunctionBegin;
36531d47070SJunchao Zhang   PetscCall(initialize_handle_(handle_t{}, dctx));
36697cd0981SJacob Faibussowitsch   *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
36731d47070SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
36831d47070SJunchao Zhang }
36931d47070SJunchao Zhang 
37031d47070SJunchao Zhang template <DeviceType T>
3716d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
372d71ae5a4SJacob Faibussowitsch {
3730e6b6b59SJacob Faibussowitsch   const auto dci = impls_cast_(dctx);
374a4af0ceeSJacob Faibussowitsch 
375a4af0ceeSJacob Faibussowitsch   PetscFunctionBegin;
3760e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
377a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
3785f80ce2aSJacob Faibussowitsch   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
379a4af0ceeSJacob Faibussowitsch   dci->timerInUse = PETSC_TRUE;
380a4af0ceeSJacob Faibussowitsch #endif
38117f48955SJacob Faibussowitsch   if (!dci->begin) {
3820e6b6b59SJacob Faibussowitsch     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
3839566063dSJacob Faibussowitsch     PetscCallCUPM(cupmEventCreate(&dci->begin));
3849566063dSJacob Faibussowitsch     PetscCallCUPM(cupmEventCreate(&dci->end));
38517f48955SJacob Faibussowitsch   }
3860e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
3873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
388a4af0ceeSJacob Faibussowitsch }
389a4af0ceeSJacob Faibussowitsch 
39017f48955SJacob Faibussowitsch template <DeviceType T>
3916d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
392d71ae5a4SJacob Faibussowitsch {
393a4af0ceeSJacob Faibussowitsch   float      gtime;
3940e6b6b59SJacob Faibussowitsch   const auto dci = impls_cast_(dctx);
3950e6b6b59SJacob Faibussowitsch   const auto end = dci->end;
396a4af0ceeSJacob Faibussowitsch 
397a4af0ceeSJacob Faibussowitsch   PetscFunctionBegin;
3980e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
399a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
4005f80ce2aSJacob Faibussowitsch   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
401a4af0ceeSJacob Faibussowitsch   dci->timerInUse = PETSC_FALSE;
402a4af0ceeSJacob Faibussowitsch #endif
4030e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
4040e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventSynchronize(end));
4050e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
40617f48955SJacob Faibussowitsch   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
4073ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
408a4af0ceeSJacob Faibussowitsch }
409a4af0ceeSJacob Faibussowitsch 
4100e6b6b59SJacob Faibussowitsch template <DeviceType T>
4116d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
412d71ae5a4SJacob Faibussowitsch {
4130e6b6b59SJacob Faibussowitsch   const auto &stream = impls_cast_(dctx)->stream;
4140e6b6b59SJacob Faibussowitsch 
4150e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
4160e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
4170e6b6b59SJacob Faibussowitsch   PetscCall(check_memtype_(mtype, "allocating"));
4180e6b6b59SJacob Faibussowitsch   if (PetscMemTypeHost(mtype)) {
4196797ed33SJacob Faibussowitsch     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
4200e6b6b59SJacob Faibussowitsch   } else {
4216797ed33SJacob Faibussowitsch     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
4220e6b6b59SJacob Faibussowitsch   }
4236797ed33SJacob Faibussowitsch   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
4243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4250e6b6b59SJacob Faibussowitsch }
4260e6b6b59SJacob Faibussowitsch 
4270e6b6b59SJacob Faibussowitsch template <DeviceType T>
4286d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
429d71ae5a4SJacob Faibussowitsch {
4300e6b6b59SJacob Faibussowitsch   const auto &stream = impls_cast_(dctx)->stream;
4310e6b6b59SJacob Faibussowitsch 
4320e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
4330e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
4340e6b6b59SJacob Faibussowitsch   PetscCall(check_memtype_(mtype, "freeing"));
4353ba16761SJacob Faibussowitsch   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
4360e6b6b59SJacob Faibussowitsch   if (PetscMemTypeHost(mtype)) {
4370e6b6b59SJacob Faibussowitsch     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
4380e6b6b59SJacob Faibussowitsch     // if ptr exists still exists the pool didn't own it
4390e6b6b59SJacob Faibussowitsch     if (*ptr) {
4400e6b6b59SJacob Faibussowitsch       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
4410e6b6b59SJacob Faibussowitsch 
4420e6b6b59SJacob Faibussowitsch       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
4430e6b6b59SJacob Faibussowitsch       if (registered) {
4440e6b6b59SJacob Faibussowitsch         PetscCallCUPM(cupmFreeHost(*ptr));
4450e6b6b59SJacob Faibussowitsch       } else if (managed) {
4460e6b6b59SJacob Faibussowitsch         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
4470e6b6b59SJacob Faibussowitsch       }
4480e6b6b59SJacob Faibussowitsch     }
4490e6b6b59SJacob Faibussowitsch   } else {
4500e6b6b59SJacob Faibussowitsch     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
4516d54fb17SJacob Faibussowitsch     // if ptr still exists the pool didn't own it
4520e6b6b59SJacob Faibussowitsch     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
4530e6b6b59SJacob Faibussowitsch   }
4543ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4550e6b6b59SJacob Faibussowitsch }
4560e6b6b59SJacob Faibussowitsch 
4570e6b6b59SJacob Faibussowitsch template <DeviceType T>
4586d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
459d71ae5a4SJacob Faibussowitsch {
4600e6b6b59SJacob Faibussowitsch   const auto stream = impls_cast_(dctx)->stream.get_stream();
4610e6b6b59SJacob Faibussowitsch 
4620e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
4630e6b6b59SJacob Faibussowitsch   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
4640e6b6b59SJacob Faibussowitsch   if (mode == PETSC_DEVICE_COPY_HTOH) {
4656d54fb17SJacob Faibussowitsch     const auto cerr = cupmStreamQuery(stream);
4666d54fb17SJacob Faibussowitsch 
4670e6b6b59SJacob Faibussowitsch     // yes this is faster
4686d54fb17SJacob Faibussowitsch     if (cerr == cupmSuccess) {
4690e6b6b59SJacob Faibussowitsch       PetscCall(PetscMemcpy(dest, src, n));
4703ba16761SJacob Faibussowitsch       PetscFunctionReturn(PETSC_SUCCESS);
4716d54fb17SJacob Faibussowitsch     } else if (cerr == cupmErrorNotReady) {
4726d54fb17SJacob Faibussowitsch       auto PETSC_UNUSED unused = cupmGetLastError();
4736d54fb17SJacob Faibussowitsch 
4746d54fb17SJacob Faibussowitsch       static_cast<void>(unused);
4756d54fb17SJacob Faibussowitsch     } else {
4766d54fb17SJacob Faibussowitsch       PetscCallCUPM(cerr);
4770e6b6b59SJacob Faibussowitsch     }
4780e6b6b59SJacob Faibussowitsch   }
4793ba16761SJacob Faibussowitsch   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
4803ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4810e6b6b59SJacob Faibussowitsch }
4820e6b6b59SJacob Faibussowitsch 
4830e6b6b59SJacob Faibussowitsch template <DeviceType T>
4846d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
485d71ae5a4SJacob Faibussowitsch {
4860e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
4870e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
4880e6b6b59SJacob Faibussowitsch   PetscCall(check_memtype_(mtype, "zeroing"));
4896797ed33SJacob Faibussowitsch   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
4903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4910e6b6b59SJacob Faibussowitsch }
4920e6b6b59SJacob Faibussowitsch 
4930e6b6b59SJacob Faibussowitsch template <DeviceType T>
4948eb1d50fSPierre Jolivet inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
495d71ae5a4SJacob Faibussowitsch {
4960e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
4973048253cSJacob Faibussowitsch   PetscCallCXX(event->data = new event_type{});
4980e6b6b59SJacob Faibussowitsch   event->destroy = [](PetscEvent event) {
4990e6b6b59SJacob Faibussowitsch     PetscFunctionBegin;
5000e6b6b59SJacob Faibussowitsch     delete event_cast_(event);
5010e6b6b59SJacob Faibussowitsch     event->data = nullptr;
5023ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
5030e6b6b59SJacob Faibussowitsch   };
5043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5050e6b6b59SJacob Faibussowitsch }
5060e6b6b59SJacob Faibussowitsch 
5070e6b6b59SJacob Faibussowitsch template <DeviceType T>
5086d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
509d71ae5a4SJacob Faibussowitsch {
5100e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
5110e6b6b59SJacob Faibussowitsch   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
5123ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5130e6b6b59SJacob Faibussowitsch }
5140e6b6b59SJacob Faibussowitsch 
5150e6b6b59SJacob Faibussowitsch template <DeviceType T>
5166d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
517d71ae5a4SJacob Faibussowitsch {
5180e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
5190e6b6b59SJacob Faibussowitsch   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
5203ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5210e6b6b59SJacob Faibussowitsch }
5220e6b6b59SJacob Faibussowitsch 
523030f984aSJacob Faibussowitsch // initialize the static member variables
5249371c9d4SSatish Balay template <DeviceType T>
5259371c9d4SSatish Balay bool DeviceContext<T>::initialized_ = false;
526030f984aSJacob Faibussowitsch 
52717f48955SJacob Faibussowitsch template <DeviceType T>
52817f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
529030f984aSJacob Faibussowitsch 
53017f48955SJacob Faibussowitsch template <DeviceType T>
53117f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
53217f48955SJacob Faibussowitsch 
5336ff55be4SJacob Faibussowitsch template <DeviceType T>
5346ff55be4SJacob Faibussowitsch constexpr _DeviceContextOps DeviceContext<T>::ops;
5356ff55be4SJacob Faibussowitsch 
5360e6b6b59SJacob Faibussowitsch } // namespace impl
537030f984aSJacob Faibussowitsch 
538a4af0ceeSJacob Faibussowitsch // shorten this one up a bit (and instantiate the templates)
5390e6b6b59SJacob Faibussowitsch using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
5400e6b6b59SJacob Faibussowitsch using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
541030f984aSJacob Faibussowitsch 
542030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name
5430e6b6b59SJacob Faibussowitsch #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
544030f984aSJacob Faibussowitsch 
5450e6b6b59SJacob Faibussowitsch } // namespace cupm
54617f48955SJacob Faibussowitsch 
5470e6b6b59SJacob Faibussowitsch } // namespace device
54817f48955SJacob Faibussowitsch 
54917f48955SJacob Faibussowitsch } // namespace Petsc
550