1a4963045SJacob Faibussowitsch #pragma once
2030f984aSJacob Faibussowitsch
3a4af0ceeSJacob Faibussowitsch #include <petsc/private/deviceimpl.h>
496a4b4d9SJacob Faibussowitsch #include <petsc/private/cupmsolverinterface.hpp>
57a101e5eSJacob Faibussowitsch #include <petsc/private/logimpl.h>
6030f984aSJacob Faibussowitsch
70e6b6b59SJacob Faibussowitsch #include <petsc/private/cpp/array.hpp>
8a4af0ceeSJacob Faibussowitsch
90e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp"
100e6b6b59SJacob Faibussowitsch #include "cupmallocator.hpp"
110e6b6b59SJacob Faibussowitsch #include "cupmstream.hpp"
120e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp"
130e6b6b59SJacob Faibussowitsch
14d71ae5a4SJacob Faibussowitsch namespace Petsc
15d71ae5a4SJacob Faibussowitsch {
16a4af0ceeSJacob Faibussowitsch
17d71ae5a4SJacob Faibussowitsch namespace device
18d71ae5a4SJacob Faibussowitsch {
1917f48955SJacob Faibussowitsch
20d71ae5a4SJacob Faibussowitsch namespace cupm
21d71ae5a4SJacob Faibussowitsch {
2217f48955SJacob Faibussowitsch
23d71ae5a4SJacob Faibussowitsch namespace impl
24d71ae5a4SJacob Faibussowitsch {
25030f984aSJacob Faibussowitsch
2617f48955SJacob Faibussowitsch template <DeviceType T>
2785f25e71SJed Brown class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL DeviceContext : SolverInterface<T> {
2817f48955SJacob Faibussowitsch public:
2996a4b4d9SJacob Faibussowitsch PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
3017f48955SJacob Faibussowitsch
3117f48955SJacob Faibussowitsch private:
329371c9d4SSatish Balay template <typename H, std::size_t>
339371c9d4SSatish Balay struct HandleTag {
349371c9d4SSatish Balay using type = H;
359371c9d4SSatish Balay };
360e6b6b59SJacob Faibussowitsch
377a101e5eSJacob Faibussowitsch using stream_tag = HandleTag<cupmStream_t, 0>;
387a101e5eSJacob Faibussowitsch using blas_tag = HandleTag<cupmBlasHandle_t, 1>;
397a101e5eSJacob Faibussowitsch using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
40a4af0ceeSJacob Faibussowitsch
410e6b6b59SJacob Faibussowitsch using stream_type = CUPMStream<T>;
420e6b6b59SJacob Faibussowitsch using event_type = CUPMEvent<T>;
430e6b6b59SJacob Faibussowitsch
44030f984aSJacob Faibussowitsch public:
45030f984aSJacob Faibussowitsch // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46030f984aSJacob Faibussowitsch // header, but since we are using the power of templates it must be declared part of
47030f984aSJacob Faibussowitsch // this class to have easy access the same typedefs. Technically one can make a
48030f984aSJacob Faibussowitsch // templated struct outside the class but it's more code for the same result.
493048253cSJacob Faibussowitsch struct PetscDeviceContext_IMPLS {
500e6b6b59SJacob Faibussowitsch stream_type stream{};
510e6b6b59SJacob Faibussowitsch cupmEvent_t event{};
520e6b6b59SJacob Faibussowitsch cupmEvent_t begin{}; // timer-only
530e6b6b59SJacob Faibussowitsch cupmEvent_t end{}; // timer-only
54a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
550e6b6b59SJacob Faibussowitsch PetscBool timerInUse{};
56*5268dc8aSHong Zhang PetscBool EnergyMeterInUse{};
57a4af0ceeSJacob Faibussowitsch #endif
580e6b6b59SJacob Faibussowitsch cupmBlasHandle_t blas{};
590e6b6b59SJacob Faibussowitsch cupmSolverHandle_t solver{};
60*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA)
61*5268dc8aSHong Zhang nvmlDevice_t nvmlHandle{};
62*5268dc8aSHong Zhang unsigned long long energymeterbegin{};
63*5268dc8aSHong Zhang unsigned long long energymeterend{};
64*5268dc8aSHong Zhang #endif
65a4af0ceeSJacob Faibussowitsch
660e6b6b59SJacob Faibussowitsch constexpr PetscDeviceContext_IMPLS() noexcept = default;
670e6b6b59SJacob Faibussowitsch
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS6831d47070SJunchao Zhang PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
690e6b6b59SJacob Faibussowitsch
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS7031d47070SJunchao Zhang PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
710e6b6b59SJacob Faibussowitsch
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS7231d47070SJunchao Zhang PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
73030f984aSJacob Faibussowitsch };
74030f984aSJacob Faibussowitsch
75030f984aSJacob Faibussowitsch private:
7617f48955SJacob Faibussowitsch static bool initialized_;
776d54fb17SJacob Faibussowitsch
7817f48955SJacob Faibussowitsch static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_;
7917f48955SJacob Faibussowitsch static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
80030f984aSJacob Faibussowitsch
impls_cast_(PetscDeviceContext ptr)81d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
82a4af0ceeSJacob Faibussowitsch
event_cast_(PetscEvent event)83d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
840e6b6b59SJacob Faibussowitsch
CUPMBLAS_HANDLE_CREATE()85d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
867a101e5eSJacob Faibussowitsch
CUPMSOLVER_HANDLE_CREATE()87d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
887a101e5eSJacob Faibussowitsch
897a101e5eSJacob Faibussowitsch // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
907a101e5eSJacob Faibussowitsch // handles
initialize_handle_(stream_tag,PetscDeviceContext)91089fb57cSJacob Faibussowitsch static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
927a101e5eSJacob Faibussowitsch
initialize_handle_(blas_tag,PetscDeviceContext dctx)9396a4b4d9SJacob Faibussowitsch static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
94d71ae5a4SJacob Faibussowitsch {
9596a4b4d9SJacob Faibussowitsch const auto dci = impls_cast_(dctx);
9696a4b4d9SJacob Faibussowitsch auto &handle = blashandles_[dctx->device->deviceId];
977a101e5eSJacob Faibussowitsch
98030f984aSJacob Faibussowitsch PetscFunctionBegin;
9996a4b4d9SJacob Faibussowitsch if (!handle) {
100b665b14eSToby Isaac PetscCall(PetscLogEventsPause());
1017a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
10217f48955SJacob Faibussowitsch for (auto i = 0; i < 3; ++i) {
10396a4b4d9SJacob Faibussowitsch const auto cberr = cupmBlasCreate(handle.ptr_to());
10417f48955SJacob Faibussowitsch if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
1059566063dSJacob Faibussowitsch if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
10617f48955SJacob Faibussowitsch if (i != 2) {
1079566063dSJacob Faibussowitsch PetscCall(PetscSleep(3));
10817f48955SJacob Faibussowitsch continue;
109a4af0ceeSJacob Faibussowitsch }
1105f80ce2aSJacob Faibussowitsch PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
111a4af0ceeSJacob Faibussowitsch }
1127a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
113b665b14eSToby Isaac PetscCall(PetscLogEventsResume());
114030f984aSJacob Faibussowitsch }
1150e6b6b59SJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
1167a101e5eSJacob Faibussowitsch dci->blas = handle;
1173ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
1187a101e5eSJacob Faibussowitsch }
1197a101e5eSJacob Faibussowitsch
initialize_handle_(solver_tag,PetscDeviceContext dctx)120089fb57cSJacob Faibussowitsch static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
121d71ae5a4SJacob Faibussowitsch {
1226d54fb17SJacob Faibussowitsch const auto dci = impls_cast_(dctx);
1236d54fb17SJacob Faibussowitsch auto &handle = solverhandles_[dctx->device->deviceId];
1247a101e5eSJacob Faibussowitsch
1257a101e5eSJacob Faibussowitsch PetscFunctionBegin;
12696a4b4d9SJacob Faibussowitsch if (!handle) {
127b665b14eSToby Isaac PetscCall(PetscLogEventsPause());
1287a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
12996a4b4d9SJacob Faibussowitsch for (auto i = 0; i < 3; ++i) {
13096a4b4d9SJacob Faibussowitsch const auto cerr = cupmSolverCreate(&handle);
13196a4b4d9SJacob Faibussowitsch if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
13296a4b4d9SJacob Faibussowitsch if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
13396a4b4d9SJacob Faibussowitsch if (i < 2) {
13496a4b4d9SJacob Faibussowitsch PetscCall(PetscSleep(3));
13596a4b4d9SJacob Faibussowitsch continue;
13696a4b4d9SJacob Faibussowitsch }
13796a4b4d9SJacob Faibussowitsch PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
13896a4b4d9SJacob Faibussowitsch }
1397a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
140b665b14eSToby Isaac PetscCall(PetscLogEventsResume());
14196a4b4d9SJacob Faibussowitsch }
14296a4b4d9SJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
1437a101e5eSJacob Faibussowitsch dci->solver = handle;
1443ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
14517f48955SJacob Faibussowitsch }
14617f48955SJacob Faibussowitsch
check_current_device_(PetscDeviceContext dctxl,PetscDeviceContext dctxr)147089fb57cSJacob Faibussowitsch static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
148d71ae5a4SJacob Faibussowitsch {
1490e6b6b59SJacob Faibussowitsch const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
1500e6b6b59SJacob Faibussowitsch
1510e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
1520e6b6b59SJacob Faibussowitsch PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
1530e6b6b59SJacob Faibussowitsch PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
1540e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
1550e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
1560e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
1573ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
1580e6b6b59SJacob Faibussowitsch }
1590e6b6b59SJacob Faibussowitsch
check_current_device_(PetscDeviceContext dctx)160089fb57cSJacob Faibussowitsch static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
1610e6b6b59SJacob Faibussowitsch
finalize_()162089fb57cSJacob Faibussowitsch static PetscErrorCode finalize_() noexcept
163d71ae5a4SJacob Faibussowitsch {
16417f48955SJacob Faibussowitsch PetscFunctionBegin;
16517f48955SJacob Faibussowitsch for (auto &&handle : blashandles_) {
16617f48955SJacob Faibussowitsch if (handle) {
1679566063dSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasDestroy(handle));
16817f48955SJacob Faibussowitsch handle = nullptr;
16917f48955SJacob Faibussowitsch }
17017f48955SJacob Faibussowitsch }
17117f48955SJacob Faibussowitsch for (auto &&handle : solverhandles_) {
17217f48955SJacob Faibussowitsch if (handle) {
17396a4b4d9SJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
17417f48955SJacob Faibussowitsch handle = nullptr;
17517f48955SJacob Faibussowitsch }
17617f48955SJacob Faibussowitsch }
17717f48955SJacob Faibussowitsch initialized_ = false;
1783ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
17917f48955SJacob Faibussowitsch }
18017f48955SJacob Faibussowitsch
1810e6b6b59SJacob Faibussowitsch template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
default_pool_()182d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PoolType &default_pool_() noexcept
183d71ae5a4SJacob Faibussowitsch {
1840e6b6b59SJacob Faibussowitsch static PoolType pool;
1850e6b6b59SJacob Faibussowitsch return pool;
1860e6b6b59SJacob Faibussowitsch }
187030f984aSJacob Faibussowitsch
check_memtype_(PetscMemType mtype,const char mess[])188089fb57cSJacob Faibussowitsch static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
189d71ae5a4SJacob Faibussowitsch {
1900e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
1910e6b6b59SJacob Faibussowitsch PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
1923ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
1930e6b6b59SJacob Faibussowitsch }
1940e6b6b59SJacob Faibussowitsch
1950e6b6b59SJacob Faibussowitsch public:
196030f984aSJacob Faibussowitsch // All of these functions MUST be static in order to be callable from C, otherwise they
197030f984aSJacob Faibussowitsch // get the implicit 'this' pointer tacked on
198089fb57cSJacob Faibussowitsch static PetscErrorCode destroy(PetscDeviceContext) noexcept;
199089fb57cSJacob Faibussowitsch static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
200089fb57cSJacob Faibussowitsch static PetscErrorCode setUp(PetscDeviceContext) noexcept;
201089fb57cSJacob Faibussowitsch static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
202089fb57cSJacob Faibussowitsch static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
203089fb57cSJacob Faibussowitsch static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
204a4af0ceeSJacob Faibussowitsch template <typename Handle_t>
205089fb57cSJacob Faibussowitsch static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
20631d47070SJunchao Zhang template <typename Handle_t>
20797cd0981SJacob Faibussowitsch static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
208089fb57cSJacob Faibussowitsch static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
209089fb57cSJacob Faibussowitsch static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
210*5268dc8aSHong Zhang static PetscErrorCode getPower(PetscDeviceContext, PetscLogDouble *) noexcept;
211*5268dc8aSHong Zhang static PetscErrorCode beginEnergyMeter(PetscDeviceContext) noexcept;
212*5268dc8aSHong Zhang static PetscErrorCode endEnergyMeter(PetscDeviceContext, PetscLogDouble *) noexcept;
213089fb57cSJacob Faibussowitsch static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
214089fb57cSJacob Faibussowitsch static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
215089fb57cSJacob Faibussowitsch static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
216089fb57cSJacob Faibussowitsch static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
217089fb57cSJacob Faibussowitsch static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
218089fb57cSJacob Faibussowitsch static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
219089fb57cSJacob Faibussowitsch static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
2207a101e5eSJacob Faibussowitsch
2217a101e5eSJacob Faibussowitsch // not a PetscDeviceContext method, this registers the class
222089fb57cSJacob Faibussowitsch static PetscErrorCode initialize(PetscDevice) noexcept;
2230e6b6b59SJacob Faibussowitsch
2240e6b6b59SJacob Faibussowitsch // clang-format off
2256ff55be4SJacob Faibussowitsch static constexpr _DeviceContextOps ops = {
2266ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(destroy, destroy),
2276ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(changestreamtype, changeStreamType),
2286ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(setup, setUp),
2296ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(query, query),
2306ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(waitforcontext, waitForContext),
2316ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(synchronize, synchronize),
2326ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
2336ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
23431d47070SJunchao Zhang PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
2356ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(begintimer, beginTimer),
2366ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(endtimer, endTimer),
237*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
238*5268dc8aSHong Zhang PetscDesignatedInitializer(getpower, getPower),
239*5268dc8aSHong Zhang #else
240*5268dc8aSHong Zhang PetscDesignatedInitializer(getpower, nullptr),
241*5268dc8aSHong Zhang #endif
242*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA)
243*5268dc8aSHong Zhang PetscDesignatedInitializer(beginenergymeter, beginEnergyMeter),
244*5268dc8aSHong Zhang PetscDesignatedInitializer(endenergymeter, endEnergyMeter),
245*5268dc8aSHong Zhang #else
246*5268dc8aSHong Zhang PetscDesignatedInitializer(beginenergymeter, nullptr),
247*5268dc8aSHong Zhang PetscDesignatedInitializer(endenergymeter, nullptr),
248*5268dc8aSHong Zhang #endif
2496ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(memalloc, memAlloc),
2506ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(memfree, memFree),
2516ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(memcopy, memCopy),
2526ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(memset, memSet),
2536ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(createevent, createEvent),
2546ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(recordevent, recordEvent),
2556ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(waitforevent, waitForEvent)
2560e6b6b59SJacob Faibussowitsch };
2570e6b6b59SJacob Faibussowitsch // clang-format on
258030f984aSJacob Faibussowitsch };
259030f984aSJacob Faibussowitsch
2600e6b6b59SJacob Faibussowitsch // not a PetscDeviceContext method, this initializes the CLASS
26117f48955SJacob Faibussowitsch template <DeviceType T>
initialize(PetscDevice device)2626d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
263d71ae5a4SJacob Faibussowitsch {
2647a101e5eSJacob Faibussowitsch PetscFunctionBegin;
2657a101e5eSJacob Faibussowitsch if (PetscUnlikely(!initialized_)) {
2660e6b6b59SJacob Faibussowitsch uint64_t threshold = UINT64_MAX;
2676d54fb17SJacob Faibussowitsch cupmMemPool_t mempool;
2680e6b6b59SJacob Faibussowitsch
2697a101e5eSJacob Faibussowitsch initialized_ = true;
2706d54fb17SJacob Faibussowitsch PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
2710e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
2720e6b6b59SJacob Faibussowitsch blashandles_.fill(nullptr);
2730e6b6b59SJacob Faibussowitsch solverhandles_.fill(nullptr);
2747a101e5eSJacob Faibussowitsch PetscCall(PetscRegisterFinalize(finalize_));
2757a101e5eSJacob Faibussowitsch }
2763ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
2777a101e5eSJacob Faibussowitsch }
2787a101e5eSJacob Faibussowitsch
2797a101e5eSJacob Faibussowitsch template <DeviceType T>
destroy(PetscDeviceContext dctx)2806d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
281d71ae5a4SJacob Faibussowitsch {
282030f984aSJacob Faibussowitsch PetscFunctionBegin;
2830e6b6b59SJacob Faibussowitsch if (const auto dci = impls_cast_(dctx)) {
2840e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy());
285146a86ebSJacob Faibussowitsch if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
2869566063dSJacob Faibussowitsch if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
2879566063dSJacob Faibussowitsch if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
2880e6b6b59SJacob Faibussowitsch delete dci;
2890e6b6b59SJacob Faibussowitsch dctx->data = nullptr;
2900e6b6b59SJacob Faibussowitsch }
2913ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
292030f984aSJacob Faibussowitsch }
293030f984aSJacob Faibussowitsch
29417f48955SJacob Faibussowitsch template <DeviceType T>
changeStreamType(PetscDeviceContext dctx,PETSC_UNUSED PetscStreamType stype)2956d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
296d71ae5a4SJacob Faibussowitsch {
2977a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx);
298030f984aSJacob Faibussowitsch
299030f984aSJacob Faibussowitsch PetscFunctionBegin;
3000e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy());
301030f984aSJacob Faibussowitsch // set these to null so they aren't usable until setup is called again
302030f984aSJacob Faibussowitsch dci->blas = nullptr;
303030f984aSJacob Faibussowitsch dci->solver = nullptr;
3043ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
305030f984aSJacob Faibussowitsch }
306030f984aSJacob Faibussowitsch
30717f48955SJacob Faibussowitsch template <DeviceType T>
setUp(PetscDeviceContext dctx)3086d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
309d71ae5a4SJacob Faibussowitsch {
3107a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx);
3110e6b6b59SJacob Faibussowitsch auto &event = dci->event;
312030f984aSJacob Faibussowitsch
313030f984aSJacob Faibussowitsch PetscFunctionBegin;
3140e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx));
3150e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.change_type(dctx->streamType));
3160e6b6b59SJacob Faibussowitsch if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
317a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
318a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE;
319a4af0ceeSJacob Faibussowitsch #endif
3203ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
321030f984aSJacob Faibussowitsch }
322030f984aSJacob Faibussowitsch
32317f48955SJacob Faibussowitsch template <DeviceType T>
query(PetscDeviceContext dctx,PetscBool * idle)3246d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
325d71ae5a4SJacob Faibussowitsch {
326030f984aSJacob Faibussowitsch PetscFunctionBegin;
3270e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx));
3284b955ea4SJacob Faibussowitsch switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
329d71ae5a4SJacob Faibussowitsch case cupmSuccess:
330d71ae5a4SJacob Faibussowitsch *idle = PETSC_TRUE;
331d71ae5a4SJacob Faibussowitsch break;
332d71ae5a4SJacob Faibussowitsch case cupmErrorNotReady:
333d71ae5a4SJacob Faibussowitsch *idle = PETSC_FALSE;
3344b955ea4SJacob Faibussowitsch // reset the error
3354b955ea4SJacob Faibussowitsch cerr = cupmGetLastError();
3364b955ea4SJacob Faibussowitsch static_cast<void>(cerr);
337d71ae5a4SJacob Faibussowitsch break;
338d71ae5a4SJacob Faibussowitsch default:
339d71ae5a4SJacob Faibussowitsch PetscCallCUPM(cerr);
340d71ae5a4SJacob Faibussowitsch PetscUnreachable();
341030f984aSJacob Faibussowitsch }
3423ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
343030f984aSJacob Faibussowitsch }
344030f984aSJacob Faibussowitsch
34517f48955SJacob Faibussowitsch template <DeviceType T>
waitForContext(PetscDeviceContext dctxa,PetscDeviceContext dctxb)3466d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
347d71ae5a4SJacob Faibussowitsch {
3480e6b6b59SJacob Faibussowitsch const auto dcib = impls_cast_(dctxb);
3490e6b6b59SJacob Faibussowitsch const auto event = dcib->event;
350030f984aSJacob Faibussowitsch
351030f984aSJacob Faibussowitsch PetscFunctionBegin;
3520e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctxa, dctxb));
3530e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
3540e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
3553ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
356030f984aSJacob Faibussowitsch }
357030f984aSJacob Faibussowitsch
35817f48955SJacob Faibussowitsch template <DeviceType T>
synchronize(PetscDeviceContext dctx)3596d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
360d71ae5a4SJacob Faibussowitsch {
3610e6b6b59SJacob Faibussowitsch auto idle = PETSC_TRUE;
362030f984aSJacob Faibussowitsch
363030f984aSJacob Faibussowitsch PetscFunctionBegin;
3640e6b6b59SJacob Faibussowitsch PetscCall(query(dctx, &idle));
3650e6b6b59SJacob Faibussowitsch if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
3663ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
367030f984aSJacob Faibussowitsch }
368030f984aSJacob Faibussowitsch
36917f48955SJacob Faibussowitsch template <DeviceType T>
37017f48955SJacob Faibussowitsch template <typename handle_t>
getHandle(PetscDeviceContext dctx,void * handle)3716d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
372d71ae5a4SJacob Faibussowitsch {
373a4af0ceeSJacob Faibussowitsch PetscFunctionBegin;
3747a101e5eSJacob Faibussowitsch PetscCall(initialize_handle_(handle_t{}, dctx));
3757a101e5eSJacob Faibussowitsch *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
3763ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
377a4af0ceeSJacob Faibussowitsch }
378a4af0ceeSJacob Faibussowitsch
37917f48955SJacob Faibussowitsch template <DeviceType T>
38031d47070SJunchao Zhang template <typename handle_t>
getHandlePtr(PetscDeviceContext dctx,void ** handle)38197cd0981SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
38231d47070SJunchao Zhang {
38331d47070SJunchao Zhang using handle_type = typename handle_t::type;
38431d47070SJunchao Zhang
38531d47070SJunchao Zhang PetscFunctionBegin;
38631d47070SJunchao Zhang PetscCall(initialize_handle_(handle_t{}, dctx));
38797cd0981SJacob Faibussowitsch *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
38831d47070SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
38931d47070SJunchao Zhang }
39031d47070SJunchao Zhang
39131d47070SJunchao Zhang template <DeviceType T>
beginTimer(PetscDeviceContext dctx)3926d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
393d71ae5a4SJacob Faibussowitsch {
3940e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx);
395a4af0ceeSJacob Faibussowitsch
396a4af0ceeSJacob Faibussowitsch PetscFunctionBegin;
3970e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx));
398a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
3995f80ce2aSJacob Faibussowitsch PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
400a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_TRUE;
401a4af0ceeSJacob Faibussowitsch #endif
40217f48955SJacob Faibussowitsch if (!dci->begin) {
4030e6b6b59SJacob Faibussowitsch PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
4049566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->begin));
4059566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->end));
40617f48955SJacob Faibussowitsch }
4070e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
4083ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
409a4af0ceeSJacob Faibussowitsch }
410a4af0ceeSJacob Faibussowitsch
41117f48955SJacob Faibussowitsch template <DeviceType T>
endTimer(PetscDeviceContext dctx,PetscLogDouble * elapsed)4126d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
413d71ae5a4SJacob Faibussowitsch {
414a4af0ceeSJacob Faibussowitsch float gtime;
4150e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx);
4160e6b6b59SJacob Faibussowitsch const auto end = dci->end;
417a4af0ceeSJacob Faibussowitsch
418a4af0ceeSJacob Faibussowitsch PetscFunctionBegin;
4190e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx));
420a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG)
4215f80ce2aSJacob Faibussowitsch PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
422a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE;
423a4af0ceeSJacob Faibussowitsch #endif
4240e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
4250e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventSynchronize(end));
4260e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end));
42717f48955SJacob Faibussowitsch *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
4283ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
429a4af0ceeSJacob Faibussowitsch }
430a4af0ceeSJacob Faibussowitsch
431*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
432*5268dc8aSHong Zhang template <DeviceType T>
getPower(PetscDeviceContext dctx,PetscLogDouble * power)433*5268dc8aSHong Zhang inline PetscErrorCode DeviceContext<T>::getPower(PetscDeviceContext dctx, PetscLogDouble *power) noexcept
434*5268dc8aSHong Zhang {
435*5268dc8aSHong Zhang const auto dci = impls_cast_(dctx);
436*5268dc8aSHong Zhang nvmlFieldValue_t values[1];
437*5268dc8aSHong Zhang
438*5268dc8aSHong Zhang PetscFunctionBegin;
439*5268dc8aSHong Zhang PetscCall(check_current_device_(dctx));
440*5268dc8aSHong Zhang PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
441*5268dc8aSHong Zhang values[0].fieldId = NVML_FI_DEV_POWER_INSTANT;
442*5268dc8aSHong Zhang if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
443*5268dc8aSHong Zhang PetscCallNVML(nvmlDeviceGetFieldValues(dci->nvmlHandle, 1, values));
444*5268dc8aSHong Zhang *power = static_cast<util::remove_pointer_t<decltype(power)>>(values[0].value.uiVal);
445*5268dc8aSHong Zhang PetscFunctionReturn(PETSC_SUCCESS);
446*5268dc8aSHong Zhang }
447*5268dc8aSHong Zhang #endif
448*5268dc8aSHong Zhang
449*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA)
450*5268dc8aSHong Zhang template <DeviceType T>
beginEnergyMeter(PetscDeviceContext dctx)451*5268dc8aSHong Zhang inline PetscErrorCode DeviceContext<T>::beginEnergyMeter(PetscDeviceContext dctx) noexcept
452*5268dc8aSHong Zhang {
453*5268dc8aSHong Zhang const auto dci = impls_cast_(dctx);
454*5268dc8aSHong Zhang
455*5268dc8aSHong Zhang PetscFunctionBegin;
456*5268dc8aSHong Zhang PetscCall(check_current_device_(dctx));
457*5268dc8aSHong Zhang #if PetscDefined(USE_DEBUG)
458*5268dc8aSHong Zhang PetscCheck(!dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterEnd()?");
459*5268dc8aSHong Zhang dci->EnergyMeterInUse = PETSC_TRUE;
460*5268dc8aSHong Zhang #endif
461*5268dc8aSHong Zhang if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
462*5268dc8aSHong Zhang PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterbegin));
463*5268dc8aSHong Zhang PetscFunctionReturn(PETSC_SUCCESS);
464*5268dc8aSHong Zhang }
465*5268dc8aSHong Zhang
466*5268dc8aSHong Zhang template <DeviceType T>
endEnergyMeter(PetscDeviceContext dctx,PetscLogDouble * energy)467*5268dc8aSHong Zhang inline PetscErrorCode DeviceContext<T>::endEnergyMeter(PetscDeviceContext dctx, PetscLogDouble *energy) noexcept
468*5268dc8aSHong Zhang {
469*5268dc8aSHong Zhang const auto dci = impls_cast_(dctx);
470*5268dc8aSHong Zhang
471*5268dc8aSHong Zhang PetscFunctionBegin;
472*5268dc8aSHong Zhang PetscCall(check_current_device_(dctx));
473*5268dc8aSHong Zhang #if PetscDefined(USE_DEBUG)
474*5268dc8aSHong Zhang PetscCheck(dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterBegin()?");
475*5268dc8aSHong Zhang dci->EnergyMeterInUse = PETSC_FALSE;
476*5268dc8aSHong Zhang #endif
477*5268dc8aSHong Zhang PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
478*5268dc8aSHong Zhang PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterend));
479*5268dc8aSHong Zhang *energy = static_cast<util::remove_pointer_t<decltype(energy)>>(dci->energymeterend - dci->energymeterbegin) / 1000; // convert to Joule
480*5268dc8aSHong Zhang PetscFunctionReturn(PETSC_SUCCESS);
481*5268dc8aSHong Zhang }
482*5268dc8aSHong Zhang #endif
483*5268dc8aSHong Zhang
4840e6b6b59SJacob Faibussowitsch template <DeviceType T>
memAlloc(PetscDeviceContext dctx,PetscBool clear,PetscMemType mtype,std::size_t n,std::size_t alignment,void ** dest)4856d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
486d71ae5a4SJacob Faibussowitsch {
4870e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream;
4880e6b6b59SJacob Faibussowitsch
4890e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
4900e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx));
4910e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "allocating"));
4920e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) {
4936797ed33SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
4940e6b6b59SJacob Faibussowitsch } else {
4956797ed33SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
4960e6b6b59SJacob Faibussowitsch }
4976797ed33SJacob Faibussowitsch if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
4983ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
4990e6b6b59SJacob Faibussowitsch }
5000e6b6b59SJacob Faibussowitsch
5010e6b6b59SJacob Faibussowitsch template <DeviceType T>
memFree(PetscDeviceContext dctx,PetscMemType mtype,void ** ptr)5026d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
503d71ae5a4SJacob Faibussowitsch {
5040e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream;
5050e6b6b59SJacob Faibussowitsch
5060e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
5070e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx));
5080e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "freeing"));
5093ba16761SJacob Faibussowitsch if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
5100e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) {
5110e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
5120e6b6b59SJacob Faibussowitsch // if ptr exists still exists the pool didn't own it
5130e6b6b59SJacob Faibussowitsch if (*ptr) {
5140e6b6b59SJacob Faibussowitsch auto registered = PETSC_FALSE, managed = PETSC_FALSE;
5150e6b6b59SJacob Faibussowitsch
5160e6b6b59SJacob Faibussowitsch PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed));
5170e6b6b59SJacob Faibussowitsch if (registered) {
5180e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeHost(*ptr));
5190e6b6b59SJacob Faibussowitsch } else if (managed) {
5200e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
5210e6b6b59SJacob Faibussowitsch }
5220e6b6b59SJacob Faibussowitsch }
5230e6b6b59SJacob Faibussowitsch } else {
5240e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
5256d54fb17SJacob Faibussowitsch // if ptr still exists the pool didn't own it
5260e6b6b59SJacob Faibussowitsch if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
5270e6b6b59SJacob Faibussowitsch }
5283ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
5290e6b6b59SJacob Faibussowitsch }
5300e6b6b59SJacob Faibussowitsch
5310e6b6b59SJacob Faibussowitsch template <DeviceType T>
memCopy(PetscDeviceContext dctx,void * PETSC_RESTRICT dest,const void * PETSC_RESTRICT src,std::size_t n,PetscDeviceCopyMode mode)5326d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
533d71ae5a4SJacob Faibussowitsch {
5340e6b6b59SJacob Faibussowitsch const auto stream = impls_cast_(dctx)->stream.get_stream();
5350e6b6b59SJacob Faibussowitsch
5360e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
5370e6b6b59SJacob Faibussowitsch // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
5380e6b6b59SJacob Faibussowitsch if (mode == PETSC_DEVICE_COPY_HTOH) {
5396d54fb17SJacob Faibussowitsch const auto cerr = cupmStreamQuery(stream);
5406d54fb17SJacob Faibussowitsch
5410e6b6b59SJacob Faibussowitsch // yes this is faster
5426d54fb17SJacob Faibussowitsch if (cerr == cupmSuccess) {
5430e6b6b59SJacob Faibussowitsch PetscCall(PetscMemcpy(dest, src, n));
5443ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
5456d54fb17SJacob Faibussowitsch } else if (cerr == cupmErrorNotReady) {
5466d54fb17SJacob Faibussowitsch auto PETSC_UNUSED unused = cupmGetLastError();
5476d54fb17SJacob Faibussowitsch
5486d54fb17SJacob Faibussowitsch static_cast<void>(unused);
5496d54fb17SJacob Faibussowitsch } else {
5506d54fb17SJacob Faibussowitsch PetscCallCUPM(cerr);
5510e6b6b59SJacob Faibussowitsch }
5520e6b6b59SJacob Faibussowitsch }
5533ba16761SJacob Faibussowitsch PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
5543ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
5550e6b6b59SJacob Faibussowitsch }
5560e6b6b59SJacob Faibussowitsch
5570e6b6b59SJacob Faibussowitsch template <DeviceType T>
memSet(PetscDeviceContext dctx,PetscMemType mtype,void * ptr,PetscInt v,std::size_t n)5586d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
559d71ae5a4SJacob Faibussowitsch {
5600e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
5610e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx));
5620e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "zeroing"));
5636797ed33SJacob Faibussowitsch PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
5643ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
5650e6b6b59SJacob Faibussowitsch }
5660e6b6b59SJacob Faibussowitsch
5670e6b6b59SJacob Faibussowitsch template <DeviceType T>
createEvent(PetscDeviceContext,PetscEvent event)5688eb1d50fSPierre Jolivet inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
569d71ae5a4SJacob Faibussowitsch {
5700e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
5713048253cSJacob Faibussowitsch PetscCallCXX(event->data = new event_type{});
5720e6b6b59SJacob Faibussowitsch event->destroy = [](PetscEvent event) {
5730e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
5740e6b6b59SJacob Faibussowitsch delete event_cast_(event);
5750e6b6b59SJacob Faibussowitsch event->data = nullptr;
5763ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
5770e6b6b59SJacob Faibussowitsch };
5783ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
5790e6b6b59SJacob Faibussowitsch }
5800e6b6b59SJacob Faibussowitsch
5810e6b6b59SJacob Faibussowitsch template <DeviceType T>
recordEvent(PetscDeviceContext dctx,PetscEvent event)5826d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
583d71ae5a4SJacob Faibussowitsch {
5840e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
5850e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
5863ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
5870e6b6b59SJacob Faibussowitsch }
5880e6b6b59SJacob Faibussowitsch
5890e6b6b59SJacob Faibussowitsch template <DeviceType T>
waitForEvent(PetscDeviceContext dctx,PetscEvent event)5906d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
591d71ae5a4SJacob Faibussowitsch {
5920e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
5930e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
5943ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
5950e6b6b59SJacob Faibussowitsch }
5960e6b6b59SJacob Faibussowitsch
597030f984aSJacob Faibussowitsch // initialize the static member variables
5989371c9d4SSatish Balay template <DeviceType T>
5999371c9d4SSatish Balay bool DeviceContext<T>::initialized_ = false;
600030f984aSJacob Faibussowitsch
60117f48955SJacob Faibussowitsch template <DeviceType T>
60217f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
603030f984aSJacob Faibussowitsch
60417f48955SJacob Faibussowitsch template <DeviceType T>
60517f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
60617f48955SJacob Faibussowitsch
6076ff55be4SJacob Faibussowitsch template <DeviceType T>
6086ff55be4SJacob Faibussowitsch constexpr _DeviceContextOps DeviceContext<T>::ops;
6096ff55be4SJacob Faibussowitsch
6100e6b6b59SJacob Faibussowitsch } // namespace impl
611030f984aSJacob Faibussowitsch
612a4af0ceeSJacob Faibussowitsch // shorten this one up a bit (and instantiate the templates)
6130e6b6b59SJacob Faibussowitsch using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
6140e6b6b59SJacob Faibussowitsch using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>;
615030f984aSJacob Faibussowitsch
616030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name
6170e6b6b59SJacob Faibussowitsch #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
618030f984aSJacob Faibussowitsch
6190e6b6b59SJacob Faibussowitsch } // namespace cupm
62017f48955SJacob Faibussowitsch
6210e6b6b59SJacob Faibussowitsch } // namespace device
62217f48955SJacob Faibussowitsch
62317f48955SJacob Faibussowitsch } // namespace Petsc
624