xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 5fa70555f2cfa5f8527759fb2fd8b5523acdf153)
1a4963045SJacob Faibussowitsch #pragma once
2030f984aSJacob Faibussowitsch 
3a4af0ceeSJacob Faibussowitsch #include <petsc/private/deviceimpl.h>
496a4b4d9SJacob Faibussowitsch #include <petsc/private/cupmsolverinterface.hpp>
57a101e5eSJacob Faibussowitsch #include <petsc/private/logimpl.h>
6030f984aSJacob Faibussowitsch 
70e6b6b59SJacob Faibussowitsch #include <petsc/private/cpp/array.hpp>
8a4af0ceeSJacob Faibussowitsch 
90e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp"
100e6b6b59SJacob Faibussowitsch #include "cupmallocator.hpp"
110e6b6b59SJacob Faibussowitsch #include "cupmstream.hpp"
120e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp"
130e6b6b59SJacob Faibussowitsch 
14d71ae5a4SJacob Faibussowitsch namespace Petsc
15d71ae5a4SJacob Faibussowitsch {
16a4af0ceeSJacob Faibussowitsch 
17d71ae5a4SJacob Faibussowitsch namespace device
18d71ae5a4SJacob Faibussowitsch {
1917f48955SJacob Faibussowitsch 
20d71ae5a4SJacob Faibussowitsch namespace cupm
21d71ae5a4SJacob Faibussowitsch {
2217f48955SJacob Faibussowitsch 
23d71ae5a4SJacob Faibussowitsch namespace impl
24d71ae5a4SJacob Faibussowitsch {
25030f984aSJacob Faibussowitsch 
2617f48955SJacob Faibussowitsch template <DeviceType T>
2785f25e71SJed Brown class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL DeviceContext : SolverInterface<T> {
2817f48955SJacob Faibussowitsch public:
2996a4b4d9SJacob Faibussowitsch   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
3017f48955SJacob Faibussowitsch 
3117f48955SJacob Faibussowitsch private:
329371c9d4SSatish Balay   template <typename H, std::size_t>
339371c9d4SSatish Balay   struct HandleTag {
349371c9d4SSatish Balay     using type = H;
359371c9d4SSatish Balay   };
360e6b6b59SJacob Faibussowitsch 
377a101e5eSJacob Faibussowitsch   using stream_tag = HandleTag<cupmStream_t, 0>;
387a101e5eSJacob Faibussowitsch   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
397a101e5eSJacob Faibussowitsch   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
40a4af0ceeSJacob Faibussowitsch 
410e6b6b59SJacob Faibussowitsch   using stream_type = CUPMStream<T>;
420e6b6b59SJacob Faibussowitsch   using event_type  = CUPMEvent<T>;
430e6b6b59SJacob Faibussowitsch 
44030f984aSJacob Faibussowitsch public:
45030f984aSJacob Faibussowitsch   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46030f984aSJacob Faibussowitsch   // header, but since we are using the power of templates it must be declared part of
47030f984aSJacob Faibussowitsch   // this class to have easy access the same typedefs. Technically one can make a
48030f984aSJacob Faibussowitsch   // templated struct outside the class but it's more code for the same result.
493048253cSJacob Faibussowitsch   struct PetscDeviceContext_IMPLS {
500e6b6b59SJacob Faibussowitsch     stream_type stream{};
510e6b6b59SJacob Faibussowitsch     cupmEvent_t event{};
520e6b6b59SJacob Faibussowitsch     cupmEvent_t begin{}; // timer-only
530e6b6b59SJacob Faibussowitsch     cupmEvent_t end{};   // timer-only
54a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
550e6b6b59SJacob Faibussowitsch     PetscBool timerInUse{};
56*5268dc8aSHong Zhang     PetscBool EnergyMeterInUse{};
57a4af0ceeSJacob Faibussowitsch #endif
580e6b6b59SJacob Faibussowitsch     cupmBlasHandle_t   blas{};
590e6b6b59SJacob Faibussowitsch     cupmSolverHandle_t solver{};
60*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA)
61*5268dc8aSHong Zhang     nvmlDevice_t       nvmlHandle{};
62*5268dc8aSHong Zhang     unsigned long long energymeterbegin{};
63*5268dc8aSHong Zhang     unsigned long long energymeterend{};
64*5268dc8aSHong Zhang #endif
65a4af0ceeSJacob Faibussowitsch 
660e6b6b59SJacob Faibussowitsch     constexpr PetscDeviceContext_IMPLS() noexcept = default;
670e6b6b59SJacob Faibussowitsch 
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS6831d47070SJunchao Zhang     PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
690e6b6b59SJacob Faibussowitsch 
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS7031d47070SJunchao Zhang     PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
710e6b6b59SJacob Faibussowitsch 
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS7231d47070SJunchao Zhang     PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
73030f984aSJacob Faibussowitsch   };
74030f984aSJacob Faibussowitsch 
75030f984aSJacob Faibussowitsch private:
7617f48955SJacob Faibussowitsch   static bool initialized_;
776d54fb17SJacob Faibussowitsch 
7817f48955SJacob Faibussowitsch   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
7917f48955SJacob Faibussowitsch   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
80030f984aSJacob Faibussowitsch 
impls_cast_(PetscDeviceContext ptr)81d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
82a4af0ceeSJacob Faibussowitsch 
event_cast_(PetscEvent event)83d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
840e6b6b59SJacob Faibussowitsch 
CUPMBLAS_HANDLE_CREATE()85d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
867a101e5eSJacob Faibussowitsch 
CUPMSOLVER_HANDLE_CREATE()87d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
887a101e5eSJacob Faibussowitsch 
897a101e5eSJacob Faibussowitsch   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
907a101e5eSJacob Faibussowitsch   // handles
initialize_handle_(stream_tag,PetscDeviceContext)91089fb57cSJacob Faibussowitsch   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
927a101e5eSJacob Faibussowitsch 
initialize_handle_(blas_tag,PetscDeviceContext dctx)9396a4b4d9SJacob Faibussowitsch   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
94d71ae5a4SJacob Faibussowitsch   {
9596a4b4d9SJacob Faibussowitsch     const auto dci    = impls_cast_(dctx);
9696a4b4d9SJacob Faibussowitsch     auto      &handle = blashandles_[dctx->device->deviceId];
977a101e5eSJacob Faibussowitsch 
98030f984aSJacob Faibussowitsch     PetscFunctionBegin;
9996a4b4d9SJacob Faibussowitsch     if (!handle) {
100b665b14eSToby Isaac       PetscCall(PetscLogEventsPause());
1017a101e5eSJacob Faibussowitsch       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
10217f48955SJacob Faibussowitsch       for (auto i = 0; i < 3; ++i) {
10396a4b4d9SJacob Faibussowitsch         const auto cberr = cupmBlasCreate(handle.ptr_to());
10417f48955SJacob Faibussowitsch         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
1059566063dSJacob Faibussowitsch         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
10617f48955SJacob Faibussowitsch         if (i != 2) {
1079566063dSJacob Faibussowitsch           PetscCall(PetscSleep(3));
10817f48955SJacob Faibussowitsch           continue;
109a4af0ceeSJacob Faibussowitsch         }
1105f80ce2aSJacob Faibussowitsch         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
111a4af0ceeSJacob Faibussowitsch       }
1127a101e5eSJacob Faibussowitsch       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
113b665b14eSToby Isaac       PetscCall(PetscLogEventsResume());
114030f984aSJacob Faibussowitsch     }
1150e6b6b59SJacob Faibussowitsch     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
1167a101e5eSJacob Faibussowitsch     dci->blas = handle;
1173ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1187a101e5eSJacob Faibussowitsch   }
1197a101e5eSJacob Faibussowitsch 
initialize_handle_(solver_tag,PetscDeviceContext dctx)120089fb57cSJacob Faibussowitsch   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
121d71ae5a4SJacob Faibussowitsch   {
1226d54fb17SJacob Faibussowitsch     const auto dci    = impls_cast_(dctx);
1236d54fb17SJacob Faibussowitsch     auto      &handle = solverhandles_[dctx->device->deviceId];
1247a101e5eSJacob Faibussowitsch 
1257a101e5eSJacob Faibussowitsch     PetscFunctionBegin;
12696a4b4d9SJacob Faibussowitsch     if (!handle) {
127b665b14eSToby Isaac       PetscCall(PetscLogEventsPause());
1287a101e5eSJacob Faibussowitsch       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
12996a4b4d9SJacob Faibussowitsch       for (auto i = 0; i < 3; ++i) {
13096a4b4d9SJacob Faibussowitsch         const auto cerr = cupmSolverCreate(&handle);
13196a4b4d9SJacob Faibussowitsch         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
13296a4b4d9SJacob Faibussowitsch         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
13396a4b4d9SJacob Faibussowitsch         if (i < 2) {
13496a4b4d9SJacob Faibussowitsch           PetscCall(PetscSleep(3));
13596a4b4d9SJacob Faibussowitsch           continue;
13696a4b4d9SJacob Faibussowitsch         }
13796a4b4d9SJacob Faibussowitsch         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
13896a4b4d9SJacob Faibussowitsch       }
1397a101e5eSJacob Faibussowitsch       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
140b665b14eSToby Isaac       PetscCall(PetscLogEventsResume());
14196a4b4d9SJacob Faibussowitsch     }
14296a4b4d9SJacob Faibussowitsch     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
1437a101e5eSJacob Faibussowitsch     dci->solver = handle;
1443ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
14517f48955SJacob Faibussowitsch   }
14617f48955SJacob Faibussowitsch 
check_current_device_(PetscDeviceContext dctxl,PetscDeviceContext dctxr)147089fb57cSJacob Faibussowitsch   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
148d71ae5a4SJacob Faibussowitsch   {
1490e6b6b59SJacob Faibussowitsch     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
1500e6b6b59SJacob Faibussowitsch 
1510e6b6b59SJacob Faibussowitsch     PetscFunctionBegin;
1520e6b6b59SJacob Faibussowitsch     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
1530e6b6b59SJacob Faibussowitsch                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
1540e6b6b59SJacob Faibussowitsch     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
1550e6b6b59SJacob Faibussowitsch     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
1560e6b6b59SJacob Faibussowitsch     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
1573ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1580e6b6b59SJacob Faibussowitsch   }
1590e6b6b59SJacob Faibussowitsch 
check_current_device_(PetscDeviceContext dctx)160089fb57cSJacob Faibussowitsch   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
1610e6b6b59SJacob Faibussowitsch 
finalize_()162089fb57cSJacob Faibussowitsch   static PetscErrorCode finalize_() noexcept
163d71ae5a4SJacob Faibussowitsch   {
16417f48955SJacob Faibussowitsch     PetscFunctionBegin;
16517f48955SJacob Faibussowitsch     for (auto &&handle : blashandles_) {
16617f48955SJacob Faibussowitsch       if (handle) {
1679566063dSJacob Faibussowitsch         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
16817f48955SJacob Faibussowitsch         handle = nullptr;
16917f48955SJacob Faibussowitsch       }
17017f48955SJacob Faibussowitsch     }
17117f48955SJacob Faibussowitsch     for (auto &&handle : solverhandles_) {
17217f48955SJacob Faibussowitsch       if (handle) {
17396a4b4d9SJacob Faibussowitsch         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
17417f48955SJacob Faibussowitsch         handle = nullptr;
17517f48955SJacob Faibussowitsch       }
17617f48955SJacob Faibussowitsch     }
17717f48955SJacob Faibussowitsch     initialized_ = false;
1783ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
17917f48955SJacob Faibussowitsch   }
18017f48955SJacob Faibussowitsch 
1810e6b6b59SJacob Faibussowitsch   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
default_pool_()182d71ae5a4SJacob Faibussowitsch   PETSC_NODISCARD static PoolType &default_pool_() noexcept
183d71ae5a4SJacob Faibussowitsch   {
1840e6b6b59SJacob Faibussowitsch     static PoolType pool;
1850e6b6b59SJacob Faibussowitsch     return pool;
1860e6b6b59SJacob Faibussowitsch   }
187030f984aSJacob Faibussowitsch 
check_memtype_(PetscMemType mtype,const char mess[])188089fb57cSJacob Faibussowitsch   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
189d71ae5a4SJacob Faibussowitsch   {
1900e6b6b59SJacob Faibussowitsch     PetscFunctionBegin;
1910e6b6b59SJacob Faibussowitsch     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
1923ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1930e6b6b59SJacob Faibussowitsch   }
1940e6b6b59SJacob Faibussowitsch 
1950e6b6b59SJacob Faibussowitsch public:
196030f984aSJacob Faibussowitsch   // All of these functions MUST be static in order to be callable from C, otherwise they
197030f984aSJacob Faibussowitsch   // get the implicit 'this' pointer tacked on
198089fb57cSJacob Faibussowitsch   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
199089fb57cSJacob Faibussowitsch   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
200089fb57cSJacob Faibussowitsch   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
201089fb57cSJacob Faibussowitsch   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
202089fb57cSJacob Faibussowitsch   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
203089fb57cSJacob Faibussowitsch   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
204a4af0ceeSJacob Faibussowitsch   template <typename Handle_t>
205089fb57cSJacob Faibussowitsch   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
20631d47070SJunchao Zhang   template <typename Handle_t>
20797cd0981SJacob Faibussowitsch   static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
208089fb57cSJacob Faibussowitsch   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
209089fb57cSJacob Faibussowitsch   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
210*5268dc8aSHong Zhang   static PetscErrorCode getPower(PetscDeviceContext, PetscLogDouble *) noexcept;
211*5268dc8aSHong Zhang   static PetscErrorCode beginEnergyMeter(PetscDeviceContext) noexcept;
212*5268dc8aSHong Zhang   static PetscErrorCode endEnergyMeter(PetscDeviceContext, PetscLogDouble *) noexcept;
213089fb57cSJacob Faibussowitsch   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
214089fb57cSJacob Faibussowitsch   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
215089fb57cSJacob Faibussowitsch   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
216089fb57cSJacob Faibussowitsch   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
217089fb57cSJacob Faibussowitsch   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
218089fb57cSJacob Faibussowitsch   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
219089fb57cSJacob Faibussowitsch   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
2207a101e5eSJacob Faibussowitsch 
2217a101e5eSJacob Faibussowitsch   // not a PetscDeviceContext method, this registers the class
222089fb57cSJacob Faibussowitsch   static PetscErrorCode initialize(PetscDevice) noexcept;
2230e6b6b59SJacob Faibussowitsch 
2240e6b6b59SJacob Faibussowitsch   // clang-format off
2256ff55be4SJacob Faibussowitsch   static constexpr _DeviceContextOps ops = {
2266ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(destroy, destroy),
2276ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(changestreamtype, changeStreamType),
2286ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(setup, setUp),
2296ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(query, query),
2306ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(waitforcontext, waitForContext),
2316ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(synchronize, synchronize),
2326ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
2336ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
23431d47070SJunchao Zhang     PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
2356ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(begintimer, beginTimer),
2366ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(endtimer, endTimer),
237*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
238*5268dc8aSHong Zhang     PetscDesignatedInitializer(getpower, getPower),
239*5268dc8aSHong Zhang #else
240*5268dc8aSHong Zhang     PetscDesignatedInitializer(getpower, nullptr),
241*5268dc8aSHong Zhang #endif
242*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA)
243*5268dc8aSHong Zhang     PetscDesignatedInitializer(beginenergymeter, beginEnergyMeter),
244*5268dc8aSHong Zhang     PetscDesignatedInitializer(endenergymeter, endEnergyMeter),
245*5268dc8aSHong Zhang #else
246*5268dc8aSHong Zhang     PetscDesignatedInitializer(beginenergymeter, nullptr),
247*5268dc8aSHong Zhang     PetscDesignatedInitializer(endenergymeter, nullptr),
248*5268dc8aSHong Zhang #endif
2496ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(memalloc, memAlloc),
2506ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(memfree, memFree),
2516ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(memcopy, memCopy),
2526ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(memset, memSet),
2536ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(createevent, createEvent),
2546ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(recordevent, recordEvent),
2556ff55be4SJacob Faibussowitsch     PetscDesignatedInitializer(waitforevent, waitForEvent)
2560e6b6b59SJacob Faibussowitsch   };
2570e6b6b59SJacob Faibussowitsch   // clang-format on
258030f984aSJacob Faibussowitsch };
259030f984aSJacob Faibussowitsch 
2600e6b6b59SJacob Faibussowitsch // not a PetscDeviceContext method, this initializes the CLASS
26117f48955SJacob Faibussowitsch template <DeviceType T>
initialize(PetscDevice device)2626d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
263d71ae5a4SJacob Faibussowitsch {
2647a101e5eSJacob Faibussowitsch   PetscFunctionBegin;
2657a101e5eSJacob Faibussowitsch   if (PetscUnlikely(!initialized_)) {
2660e6b6b59SJacob Faibussowitsch     uint64_t      threshold = UINT64_MAX;
2676d54fb17SJacob Faibussowitsch     cupmMemPool_t mempool;
2680e6b6b59SJacob Faibussowitsch 
2697a101e5eSJacob Faibussowitsch     initialized_ = true;
2706d54fb17SJacob Faibussowitsch     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
2710e6b6b59SJacob Faibussowitsch     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
2720e6b6b59SJacob Faibussowitsch     blashandles_.fill(nullptr);
2730e6b6b59SJacob Faibussowitsch     solverhandles_.fill(nullptr);
2747a101e5eSJacob Faibussowitsch     PetscCall(PetscRegisterFinalize(finalize_));
2757a101e5eSJacob Faibussowitsch   }
2763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2777a101e5eSJacob Faibussowitsch }
2787a101e5eSJacob Faibussowitsch 
2797a101e5eSJacob Faibussowitsch template <DeviceType T>
destroy(PetscDeviceContext dctx)2806d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
281d71ae5a4SJacob Faibussowitsch {
282030f984aSJacob Faibussowitsch   PetscFunctionBegin;
2830e6b6b59SJacob Faibussowitsch   if (const auto dci = impls_cast_(dctx)) {
2840e6b6b59SJacob Faibussowitsch     PetscCall(dci->stream.destroy());
285146a86ebSJacob Faibussowitsch     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
2869566063dSJacob Faibussowitsch     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
2879566063dSJacob Faibussowitsch     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
2880e6b6b59SJacob Faibussowitsch     delete dci;
2890e6b6b59SJacob Faibussowitsch     dctx->data = nullptr;
2900e6b6b59SJacob Faibussowitsch   }
2913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
292030f984aSJacob Faibussowitsch }
293030f984aSJacob Faibussowitsch 
29417f48955SJacob Faibussowitsch template <DeviceType T>
changeStreamType(PetscDeviceContext dctx,PETSC_UNUSED PetscStreamType stype)2956d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
296d71ae5a4SJacob Faibussowitsch {
2977a101e5eSJacob Faibussowitsch   const auto dci = impls_cast_(dctx);
298030f984aSJacob Faibussowitsch 
299030f984aSJacob Faibussowitsch   PetscFunctionBegin;
3000e6b6b59SJacob Faibussowitsch   PetscCall(dci->stream.destroy());
301030f984aSJacob Faibussowitsch   // set these to null so they aren't usable until setup is called again
302030f984aSJacob Faibussowitsch   dci->blas   = nullptr;
303030f984aSJacob Faibussowitsch   dci->solver = nullptr;
3043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
305030f984aSJacob Faibussowitsch }
306030f984aSJacob Faibussowitsch 
30717f48955SJacob Faibussowitsch template <DeviceType T>
setUp(PetscDeviceContext dctx)3086d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
309d71ae5a4SJacob Faibussowitsch {
3107a101e5eSJacob Faibussowitsch   const auto dci   = impls_cast_(dctx);
3110e6b6b59SJacob Faibussowitsch   auto      &event = dci->event;
312030f984aSJacob Faibussowitsch 
313030f984aSJacob Faibussowitsch   PetscFunctionBegin;
3140e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
3150e6b6b59SJacob Faibussowitsch   PetscCall(dci->stream.change_type(dctx->streamType));
3160e6b6b59SJacob Faibussowitsch   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
317a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
318a4af0ceeSJacob Faibussowitsch   dci->timerInUse = PETSC_FALSE;
319a4af0ceeSJacob Faibussowitsch #endif
3203ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
321030f984aSJacob Faibussowitsch }
322030f984aSJacob Faibussowitsch 
32317f48955SJacob Faibussowitsch template <DeviceType T>
query(PetscDeviceContext dctx,PetscBool * idle)3246d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
325d71ae5a4SJacob Faibussowitsch {
326030f984aSJacob Faibussowitsch   PetscFunctionBegin;
3270e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
3284b955ea4SJacob Faibussowitsch   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
329d71ae5a4SJacob Faibussowitsch   case cupmSuccess:
330d71ae5a4SJacob Faibussowitsch     *idle = PETSC_TRUE;
331d71ae5a4SJacob Faibussowitsch     break;
332d71ae5a4SJacob Faibussowitsch   case cupmErrorNotReady:
333d71ae5a4SJacob Faibussowitsch     *idle = PETSC_FALSE;
3344b955ea4SJacob Faibussowitsch     // reset the error
3354b955ea4SJacob Faibussowitsch     cerr = cupmGetLastError();
3364b955ea4SJacob Faibussowitsch     static_cast<void>(cerr);
337d71ae5a4SJacob Faibussowitsch     break;
338d71ae5a4SJacob Faibussowitsch   default:
339d71ae5a4SJacob Faibussowitsch     PetscCallCUPM(cerr);
340d71ae5a4SJacob Faibussowitsch     PetscUnreachable();
341030f984aSJacob Faibussowitsch   }
3423ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
343030f984aSJacob Faibussowitsch }
344030f984aSJacob Faibussowitsch 
34517f48955SJacob Faibussowitsch template <DeviceType T>
waitForContext(PetscDeviceContext dctxa,PetscDeviceContext dctxb)3466d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
347d71ae5a4SJacob Faibussowitsch {
3480e6b6b59SJacob Faibussowitsch   const auto dcib  = impls_cast_(dctxb);
3490e6b6b59SJacob Faibussowitsch   const auto event = dcib->event;
350030f984aSJacob Faibussowitsch 
351030f984aSJacob Faibussowitsch   PetscFunctionBegin;
3520e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctxa, dctxb));
3530e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
3540e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
3553ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
356030f984aSJacob Faibussowitsch }
357030f984aSJacob Faibussowitsch 
35817f48955SJacob Faibussowitsch template <DeviceType T>
synchronize(PetscDeviceContext dctx)3596d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
360d71ae5a4SJacob Faibussowitsch {
3610e6b6b59SJacob Faibussowitsch   auto idle = PETSC_TRUE;
362030f984aSJacob Faibussowitsch 
363030f984aSJacob Faibussowitsch   PetscFunctionBegin;
3640e6b6b59SJacob Faibussowitsch   PetscCall(query(dctx, &idle));
3650e6b6b59SJacob Faibussowitsch   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
3663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
367030f984aSJacob Faibussowitsch }
368030f984aSJacob Faibussowitsch 
36917f48955SJacob Faibussowitsch template <DeviceType T>
37017f48955SJacob Faibussowitsch template <typename handle_t>
getHandle(PetscDeviceContext dctx,void * handle)3716d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
372d71ae5a4SJacob Faibussowitsch {
373a4af0ceeSJacob Faibussowitsch   PetscFunctionBegin;
3747a101e5eSJacob Faibussowitsch   PetscCall(initialize_handle_(handle_t{}, dctx));
3757a101e5eSJacob Faibussowitsch   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
3763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
377a4af0ceeSJacob Faibussowitsch }
378a4af0ceeSJacob Faibussowitsch 
37917f48955SJacob Faibussowitsch template <DeviceType T>
38031d47070SJunchao Zhang template <typename handle_t>
getHandlePtr(PetscDeviceContext dctx,void ** handle)38197cd0981SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
38231d47070SJunchao Zhang {
38331d47070SJunchao Zhang   using handle_type = typename handle_t::type;
38431d47070SJunchao Zhang 
38531d47070SJunchao Zhang   PetscFunctionBegin;
38631d47070SJunchao Zhang   PetscCall(initialize_handle_(handle_t{}, dctx));
38797cd0981SJacob Faibussowitsch   *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
38831d47070SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
38931d47070SJunchao Zhang }
39031d47070SJunchao Zhang 
39131d47070SJunchao Zhang template <DeviceType T>
beginTimer(PetscDeviceContext dctx)3926d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
393d71ae5a4SJacob Faibussowitsch {
3940e6b6b59SJacob Faibussowitsch   const auto dci = impls_cast_(dctx);
395a4af0ceeSJacob Faibussowitsch 
396a4af0ceeSJacob Faibussowitsch   PetscFunctionBegin;
3970e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
398a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
3995f80ce2aSJacob Faibussowitsch   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
400a4af0ceeSJacob Faibussowitsch   dci->timerInUse = PETSC_TRUE;
401a4af0ceeSJacob Faibussowitsch #endif
40217f48955SJacob Faibussowitsch   if (!dci->begin) {
4030e6b6b59SJacob Faibussowitsch     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
4049566063dSJacob Faibussowitsch     PetscCallCUPM(cupmEventCreate(&dci->begin));
4059566063dSJacob Faibussowitsch     PetscCallCUPM(cupmEventCreate(&dci->end));
40617f48955SJacob Faibussowitsch   }
4070e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
4083ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
409a4af0ceeSJacob Faibussowitsch }
410a4af0ceeSJacob Faibussowitsch 
41117f48955SJacob Faibussowitsch template <DeviceType T>
endTimer(PetscDeviceContext dctx,PetscLogDouble * elapsed)4126d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
413d71ae5a4SJacob Faibussowitsch {
414a4af0ceeSJacob Faibussowitsch   float      gtime;
4150e6b6b59SJacob Faibussowitsch   const auto dci = impls_cast_(dctx);
4160e6b6b59SJacob Faibussowitsch   const auto end = dci->end;
417a4af0ceeSJacob Faibussowitsch 
418a4af0ceeSJacob Faibussowitsch   PetscFunctionBegin;
4190e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
420a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
4215f80ce2aSJacob Faibussowitsch   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
422a4af0ceeSJacob Faibussowitsch   dci->timerInUse = PETSC_FALSE;
423a4af0ceeSJacob Faibussowitsch #endif
4240e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
4250e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventSynchronize(end));
4260e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
42717f48955SJacob Faibussowitsch   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
4283ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
429a4af0ceeSJacob Faibussowitsch }
430a4af0ceeSJacob Faibussowitsch 
431*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
432*5268dc8aSHong Zhang template <DeviceType T>
getPower(PetscDeviceContext dctx,PetscLogDouble * power)433*5268dc8aSHong Zhang inline PetscErrorCode DeviceContext<T>::getPower(PetscDeviceContext dctx, PetscLogDouble *power) noexcept
434*5268dc8aSHong Zhang {
435*5268dc8aSHong Zhang   const auto       dci = impls_cast_(dctx);
436*5268dc8aSHong Zhang   nvmlFieldValue_t values[1];
437*5268dc8aSHong Zhang 
438*5268dc8aSHong Zhang   PetscFunctionBegin;
439*5268dc8aSHong Zhang   PetscCall(check_current_device_(dctx));
440*5268dc8aSHong Zhang   PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
441*5268dc8aSHong Zhang   values[0].fieldId = NVML_FI_DEV_POWER_INSTANT;
442*5268dc8aSHong Zhang   if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
443*5268dc8aSHong Zhang   PetscCallNVML(nvmlDeviceGetFieldValues(dci->nvmlHandle, 1, values));
444*5268dc8aSHong Zhang   *power = static_cast<util::remove_pointer_t<decltype(power)>>(values[0].value.uiVal);
445*5268dc8aSHong Zhang   PetscFunctionReturn(PETSC_SUCCESS);
446*5268dc8aSHong Zhang }
447*5268dc8aSHong Zhang #endif
448*5268dc8aSHong Zhang 
449*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA)
450*5268dc8aSHong Zhang template <DeviceType T>
beginEnergyMeter(PetscDeviceContext dctx)451*5268dc8aSHong Zhang inline PetscErrorCode DeviceContext<T>::beginEnergyMeter(PetscDeviceContext dctx) noexcept
452*5268dc8aSHong Zhang {
453*5268dc8aSHong Zhang   const auto dci = impls_cast_(dctx);
454*5268dc8aSHong Zhang 
455*5268dc8aSHong Zhang   PetscFunctionBegin;
456*5268dc8aSHong Zhang   PetscCall(check_current_device_(dctx));
457*5268dc8aSHong Zhang   #if PetscDefined(USE_DEBUG)
458*5268dc8aSHong Zhang   PetscCheck(!dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterEnd()?");
459*5268dc8aSHong Zhang   dci->EnergyMeterInUse = PETSC_TRUE;
460*5268dc8aSHong Zhang   #endif
461*5268dc8aSHong Zhang   if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
462*5268dc8aSHong Zhang   PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterbegin));
463*5268dc8aSHong Zhang   PetscFunctionReturn(PETSC_SUCCESS);
464*5268dc8aSHong Zhang }
465*5268dc8aSHong Zhang 
466*5268dc8aSHong Zhang template <DeviceType T>
endEnergyMeter(PetscDeviceContext dctx,PetscLogDouble * energy)467*5268dc8aSHong Zhang inline PetscErrorCode DeviceContext<T>::endEnergyMeter(PetscDeviceContext dctx, PetscLogDouble *energy) noexcept
468*5268dc8aSHong Zhang {
469*5268dc8aSHong Zhang   const auto dci = impls_cast_(dctx);
470*5268dc8aSHong Zhang 
471*5268dc8aSHong Zhang   PetscFunctionBegin;
472*5268dc8aSHong Zhang   PetscCall(check_current_device_(dctx));
473*5268dc8aSHong Zhang   #if PetscDefined(USE_DEBUG)
474*5268dc8aSHong Zhang   PetscCheck(dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterBegin()?");
475*5268dc8aSHong Zhang   dci->EnergyMeterInUse = PETSC_FALSE;
476*5268dc8aSHong Zhang   #endif
477*5268dc8aSHong Zhang   PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
478*5268dc8aSHong Zhang   PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterend));
479*5268dc8aSHong Zhang   *energy = static_cast<util::remove_pointer_t<decltype(energy)>>(dci->energymeterend - dci->energymeterbegin) / 1000; // convert to Joule
480*5268dc8aSHong Zhang   PetscFunctionReturn(PETSC_SUCCESS);
481*5268dc8aSHong Zhang }
482*5268dc8aSHong Zhang #endif
483*5268dc8aSHong Zhang 
4840e6b6b59SJacob Faibussowitsch template <DeviceType T>
memAlloc(PetscDeviceContext dctx,PetscBool clear,PetscMemType mtype,std::size_t n,std::size_t alignment,void ** dest)4856d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
486d71ae5a4SJacob Faibussowitsch {
4870e6b6b59SJacob Faibussowitsch   const auto &stream = impls_cast_(dctx)->stream;
4880e6b6b59SJacob Faibussowitsch 
4890e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
4900e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
4910e6b6b59SJacob Faibussowitsch   PetscCall(check_memtype_(mtype, "allocating"));
4920e6b6b59SJacob Faibussowitsch   if (PetscMemTypeHost(mtype)) {
4936797ed33SJacob Faibussowitsch     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
4940e6b6b59SJacob Faibussowitsch   } else {
4956797ed33SJacob Faibussowitsch     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
4960e6b6b59SJacob Faibussowitsch   }
4976797ed33SJacob Faibussowitsch   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
4983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4990e6b6b59SJacob Faibussowitsch }
5000e6b6b59SJacob Faibussowitsch 
5010e6b6b59SJacob Faibussowitsch template <DeviceType T>
memFree(PetscDeviceContext dctx,PetscMemType mtype,void ** ptr)5026d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
503d71ae5a4SJacob Faibussowitsch {
5040e6b6b59SJacob Faibussowitsch   const auto &stream = impls_cast_(dctx)->stream;
5050e6b6b59SJacob Faibussowitsch 
5060e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
5070e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
5080e6b6b59SJacob Faibussowitsch   PetscCall(check_memtype_(mtype, "freeing"));
5093ba16761SJacob Faibussowitsch   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
5100e6b6b59SJacob Faibussowitsch   if (PetscMemTypeHost(mtype)) {
5110e6b6b59SJacob Faibussowitsch     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
5120e6b6b59SJacob Faibussowitsch     // if ptr exists still exists the pool didn't own it
5130e6b6b59SJacob Faibussowitsch     if (*ptr) {
5140e6b6b59SJacob Faibussowitsch       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
5150e6b6b59SJacob Faibussowitsch 
5160e6b6b59SJacob Faibussowitsch       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
5170e6b6b59SJacob Faibussowitsch       if (registered) {
5180e6b6b59SJacob Faibussowitsch         PetscCallCUPM(cupmFreeHost(*ptr));
5190e6b6b59SJacob Faibussowitsch       } else if (managed) {
5200e6b6b59SJacob Faibussowitsch         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
5210e6b6b59SJacob Faibussowitsch       }
5220e6b6b59SJacob Faibussowitsch     }
5230e6b6b59SJacob Faibussowitsch   } else {
5240e6b6b59SJacob Faibussowitsch     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
5256d54fb17SJacob Faibussowitsch     // if ptr still exists the pool didn't own it
5260e6b6b59SJacob Faibussowitsch     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
5270e6b6b59SJacob Faibussowitsch   }
5283ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5290e6b6b59SJacob Faibussowitsch }
5300e6b6b59SJacob Faibussowitsch 
5310e6b6b59SJacob Faibussowitsch template <DeviceType T>
memCopy(PetscDeviceContext dctx,void * PETSC_RESTRICT dest,const void * PETSC_RESTRICT src,std::size_t n,PetscDeviceCopyMode mode)5326d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
533d71ae5a4SJacob Faibussowitsch {
5340e6b6b59SJacob Faibussowitsch   const auto stream = impls_cast_(dctx)->stream.get_stream();
5350e6b6b59SJacob Faibussowitsch 
5360e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
5370e6b6b59SJacob Faibussowitsch   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
5380e6b6b59SJacob Faibussowitsch   if (mode == PETSC_DEVICE_COPY_HTOH) {
5396d54fb17SJacob Faibussowitsch     const auto cerr = cupmStreamQuery(stream);
5406d54fb17SJacob Faibussowitsch 
5410e6b6b59SJacob Faibussowitsch     // yes this is faster
5426d54fb17SJacob Faibussowitsch     if (cerr == cupmSuccess) {
5430e6b6b59SJacob Faibussowitsch       PetscCall(PetscMemcpy(dest, src, n));
5443ba16761SJacob Faibussowitsch       PetscFunctionReturn(PETSC_SUCCESS);
5456d54fb17SJacob Faibussowitsch     } else if (cerr == cupmErrorNotReady) {
5466d54fb17SJacob Faibussowitsch       auto PETSC_UNUSED unused = cupmGetLastError();
5476d54fb17SJacob Faibussowitsch 
5486d54fb17SJacob Faibussowitsch       static_cast<void>(unused);
5496d54fb17SJacob Faibussowitsch     } else {
5506d54fb17SJacob Faibussowitsch       PetscCallCUPM(cerr);
5510e6b6b59SJacob Faibussowitsch     }
5520e6b6b59SJacob Faibussowitsch   }
5533ba16761SJacob Faibussowitsch   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
5543ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5550e6b6b59SJacob Faibussowitsch }
5560e6b6b59SJacob Faibussowitsch 
5570e6b6b59SJacob Faibussowitsch template <DeviceType T>
memSet(PetscDeviceContext dctx,PetscMemType mtype,void * ptr,PetscInt v,std::size_t n)5586d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
559d71ae5a4SJacob Faibussowitsch {
5600e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
5610e6b6b59SJacob Faibussowitsch   PetscCall(check_current_device_(dctx));
5620e6b6b59SJacob Faibussowitsch   PetscCall(check_memtype_(mtype, "zeroing"));
5636797ed33SJacob Faibussowitsch   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
5643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5650e6b6b59SJacob Faibussowitsch }
5660e6b6b59SJacob Faibussowitsch 
5670e6b6b59SJacob Faibussowitsch template <DeviceType T>
createEvent(PetscDeviceContext,PetscEvent event)5688eb1d50fSPierre Jolivet inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
569d71ae5a4SJacob Faibussowitsch {
5700e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
5713048253cSJacob Faibussowitsch   PetscCallCXX(event->data = new event_type{});
5720e6b6b59SJacob Faibussowitsch   event->destroy = [](PetscEvent event) {
5730e6b6b59SJacob Faibussowitsch     PetscFunctionBegin;
5740e6b6b59SJacob Faibussowitsch     delete event_cast_(event);
5750e6b6b59SJacob Faibussowitsch     event->data = nullptr;
5763ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
5770e6b6b59SJacob Faibussowitsch   };
5783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5790e6b6b59SJacob Faibussowitsch }
5800e6b6b59SJacob Faibussowitsch 
5810e6b6b59SJacob Faibussowitsch template <DeviceType T>
recordEvent(PetscDeviceContext dctx,PetscEvent event)5826d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
583d71ae5a4SJacob Faibussowitsch {
5840e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
5850e6b6b59SJacob Faibussowitsch   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
5863ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5870e6b6b59SJacob Faibussowitsch }
5880e6b6b59SJacob Faibussowitsch 
5890e6b6b59SJacob Faibussowitsch template <DeviceType T>
waitForEvent(PetscDeviceContext dctx,PetscEvent event)5906d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
591d71ae5a4SJacob Faibussowitsch {
5920e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
5930e6b6b59SJacob Faibussowitsch   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
5943ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5950e6b6b59SJacob Faibussowitsch }
5960e6b6b59SJacob Faibussowitsch 
597030f984aSJacob Faibussowitsch // initialize the static member variables
5989371c9d4SSatish Balay template <DeviceType T>
5999371c9d4SSatish Balay bool DeviceContext<T>::initialized_ = false;
600030f984aSJacob Faibussowitsch 
60117f48955SJacob Faibussowitsch template <DeviceType T>
60217f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
603030f984aSJacob Faibussowitsch 
60417f48955SJacob Faibussowitsch template <DeviceType T>
60517f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
60617f48955SJacob Faibussowitsch 
6076ff55be4SJacob Faibussowitsch template <DeviceType T>
6086ff55be4SJacob Faibussowitsch constexpr _DeviceContextOps DeviceContext<T>::ops;
6096ff55be4SJacob Faibussowitsch 
6100e6b6b59SJacob Faibussowitsch } // namespace impl
611030f984aSJacob Faibussowitsch 
612a4af0ceeSJacob Faibussowitsch // shorten this one up a bit (and instantiate the templates)
6130e6b6b59SJacob Faibussowitsch using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
6140e6b6b59SJacob Faibussowitsch using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
615030f984aSJacob Faibussowitsch 
616030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name
6170e6b6b59SJacob Faibussowitsch #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
618030f984aSJacob Faibussowitsch 
6190e6b6b59SJacob Faibussowitsch } // namespace cupm
62017f48955SJacob Faibussowitsch 
6210e6b6b59SJacob Faibussowitsch } // namespace device
62217f48955SJacob Faibussowitsch 
62317f48955SJacob Faibussowitsch } // namespace Petsc
624