1a4963045SJacob Faibussowitsch #pragma once 2030f984aSJacob Faibussowitsch 3a4af0ceeSJacob Faibussowitsch #include <petsc/private/deviceimpl.h> 496a4b4d9SJacob Faibussowitsch #include <petsc/private/cupmsolverinterface.hpp> 57a101e5eSJacob Faibussowitsch #include <petsc/private/logimpl.h> 6030f984aSJacob Faibussowitsch 70e6b6b59SJacob Faibussowitsch #include <petsc/private/cpp/array.hpp> 8a4af0ceeSJacob Faibussowitsch 90e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp" 100e6b6b59SJacob Faibussowitsch #include "cupmallocator.hpp" 110e6b6b59SJacob Faibussowitsch #include "cupmstream.hpp" 120e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp" 130e6b6b59SJacob Faibussowitsch 14d71ae5a4SJacob Faibussowitsch namespace Petsc 15d71ae5a4SJacob Faibussowitsch { 16a4af0ceeSJacob Faibussowitsch 17d71ae5a4SJacob Faibussowitsch namespace device 18d71ae5a4SJacob Faibussowitsch { 1917f48955SJacob Faibussowitsch 20d71ae5a4SJacob Faibussowitsch namespace cupm 21d71ae5a4SJacob Faibussowitsch { 2217f48955SJacob Faibussowitsch 23d71ae5a4SJacob Faibussowitsch namespace impl 24d71ae5a4SJacob Faibussowitsch { 25030f984aSJacob Faibussowitsch 2617f48955SJacob Faibussowitsch template <DeviceType T> 2785f25e71SJed Brown class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL DeviceContext : SolverInterface<T> { 2817f48955SJacob Faibussowitsch public: 2996a4b4d9SJacob Faibussowitsch PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); 3017f48955SJacob Faibussowitsch 3117f48955SJacob Faibussowitsch private: 329371c9d4SSatish Balay template <typename H, std::size_t> 339371c9d4SSatish Balay struct HandleTag { 349371c9d4SSatish Balay using type = H; 359371c9d4SSatish Balay }; 360e6b6b59SJacob Faibussowitsch 377a101e5eSJacob Faibussowitsch using stream_tag = HandleTag<cupmStream_t, 0>; 387a101e5eSJacob Faibussowitsch using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 397a101e5eSJacob Faibussowitsch using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 40a4af0ceeSJacob Faibussowitsch 410e6b6b59SJacob Faibussowitsch using stream_type = CUPMStream<T>; 420e6b6b59SJacob Faibussowitsch using event_type = CUPMEvent<T>; 430e6b6b59SJacob Faibussowitsch 44030f984aSJacob Faibussowitsch public: 45030f984aSJacob Faibussowitsch // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 46030f984aSJacob Faibussowitsch // header, but since we are using the power of templates it must be declared part of 47030f984aSJacob Faibussowitsch // this class to have easy access the same typedefs. Technically one can make a 48030f984aSJacob Faibussowitsch // templated struct outside the class but it's more code for the same result. 493048253cSJacob Faibussowitsch struct PetscDeviceContext_IMPLS { 500e6b6b59SJacob Faibussowitsch stream_type stream{}; 510e6b6b59SJacob Faibussowitsch cupmEvent_t event{}; 520e6b6b59SJacob Faibussowitsch cupmEvent_t begin{}; // timer-only 530e6b6b59SJacob Faibussowitsch cupmEvent_t end{}; // timer-only 54a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 550e6b6b59SJacob Faibussowitsch PetscBool timerInUse{}; 56*5268dc8aSHong Zhang PetscBool EnergyMeterInUse{}; 57a4af0ceeSJacob Faibussowitsch #endif 580e6b6b59SJacob Faibussowitsch cupmBlasHandle_t blas{}; 590e6b6b59SJacob Faibussowitsch cupmSolverHandle_t solver{}; 60*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA) 61*5268dc8aSHong Zhang nvmlDevice_t nvmlHandle{}; 62*5268dc8aSHong Zhang unsigned long long energymeterbegin{}; 63*5268dc8aSHong Zhang unsigned long long energymeterend{}; 64*5268dc8aSHong Zhang #endif 65a4af0ceeSJacob Faibussowitsch 660e6b6b59SJacob Faibussowitsch constexpr PetscDeviceContext_IMPLS() noexcept = default; 670e6b6b59SJacob Faibussowitsch 6831d47070SJunchao Zhang PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); } 690e6b6b59SJacob Faibussowitsch 7031d47070SJunchao Zhang PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; } 710e6b6b59SJacob Faibussowitsch 7231d47070SJunchao Zhang PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; } 73030f984aSJacob Faibussowitsch }; 74030f984aSJacob Faibussowitsch 75030f984aSJacob Faibussowitsch private: 7617f48955SJacob Faibussowitsch static bool initialized_; 776d54fb17SJacob Faibussowitsch 7817f48955SJacob Faibussowitsch static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 7917f48955SJacob Faibussowitsch static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 80030f984aSJacob Faibussowitsch 81d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 82a4af0ceeSJacob Faibussowitsch 83d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 840e6b6b59SJacob Faibussowitsch 85d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 867a101e5eSJacob Faibussowitsch 87d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 887a101e5eSJacob Faibussowitsch 897a101e5eSJacob Faibussowitsch // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 907a101e5eSJacob Faibussowitsch // handles 91089fb57cSJacob Faibussowitsch static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; } 927a101e5eSJacob Faibussowitsch 9396a4b4d9SJacob Faibussowitsch static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept 94d71ae5a4SJacob Faibussowitsch { 9596a4b4d9SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 9696a4b4d9SJacob Faibussowitsch auto &handle = blashandles_[dctx->device->deviceId]; 977a101e5eSJacob Faibussowitsch 98030f984aSJacob Faibussowitsch PetscFunctionBegin; 9996a4b4d9SJacob Faibussowitsch if (!handle) { 100b665b14eSToby Isaac PetscCall(PetscLogEventsPause()); 1017a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 10217f48955SJacob Faibussowitsch for (auto i = 0; i < 3; ++i) { 10396a4b4d9SJacob Faibussowitsch const auto cberr = cupmBlasCreate(handle.ptr_to()); 10417f48955SJacob Faibussowitsch if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 1059566063dSJacob Faibussowitsch if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 10617f48955SJacob Faibussowitsch if (i != 2) { 1079566063dSJacob Faibussowitsch PetscCall(PetscSleep(3)); 10817f48955SJacob Faibussowitsch continue; 109a4af0ceeSJacob Faibussowitsch } 1105f80ce2aSJacob Faibussowitsch PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 111a4af0ceeSJacob Faibussowitsch } 1127a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 113b665b14eSToby Isaac PetscCall(PetscLogEventsResume()); 114030f984aSJacob Faibussowitsch } 1150e6b6b59SJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 1167a101e5eSJacob Faibussowitsch dci->blas = handle; 1173ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1187a101e5eSJacob Faibussowitsch } 1197a101e5eSJacob Faibussowitsch 120089fb57cSJacob Faibussowitsch static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept 121d71ae5a4SJacob Faibussowitsch { 1226d54fb17SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 1236d54fb17SJacob Faibussowitsch auto &handle = solverhandles_[dctx->device->deviceId]; 1247a101e5eSJacob Faibussowitsch 1257a101e5eSJacob Faibussowitsch PetscFunctionBegin; 12696a4b4d9SJacob Faibussowitsch if (!handle) { 127b665b14eSToby Isaac PetscCall(PetscLogEventsPause()); 1287a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 12996a4b4d9SJacob Faibussowitsch for (auto i = 0; i < 3; ++i) { 13096a4b4d9SJacob Faibussowitsch const auto cerr = cupmSolverCreate(&handle); 13196a4b4d9SJacob Faibussowitsch if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break; 13296a4b4d9SJacob Faibussowitsch if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr); 13396a4b4d9SJacob Faibussowitsch if (i < 2) { 13496a4b4d9SJacob Faibussowitsch PetscCall(PetscSleep(3)); 13596a4b4d9SJacob Faibussowitsch continue; 13696a4b4d9SJacob Faibussowitsch } 13796a4b4d9SJacob Faibussowitsch PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName()); 13896a4b4d9SJacob Faibussowitsch } 1397a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 140b665b14eSToby Isaac PetscCall(PetscLogEventsResume()); 14196a4b4d9SJacob Faibussowitsch } 14296a4b4d9SJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream())); 1437a101e5eSJacob Faibussowitsch dci->solver = handle; 1443ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 14517f48955SJacob Faibussowitsch } 14617f48955SJacob Faibussowitsch 147089fb57cSJacob Faibussowitsch static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 148d71ae5a4SJacob Faibussowitsch { 1490e6b6b59SJacob Faibussowitsch const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 1500e6b6b59SJacob Faibussowitsch 1510e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1520e6b6b59SJacob Faibussowitsch PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 1530e6b6b59SJacob Faibussowitsch PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 1540e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 1550e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 1560e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 1573ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1580e6b6b59SJacob Faibussowitsch } 1590e6b6b59SJacob Faibussowitsch 160089fb57cSJacob Faibussowitsch static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 1610e6b6b59SJacob Faibussowitsch 162089fb57cSJacob Faibussowitsch static PetscErrorCode finalize_() noexcept 163d71ae5a4SJacob Faibussowitsch { 16417f48955SJacob Faibussowitsch PetscFunctionBegin; 16517f48955SJacob Faibussowitsch for (auto &&handle : blashandles_) { 16617f48955SJacob Faibussowitsch if (handle) { 1679566063dSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 16817f48955SJacob Faibussowitsch handle = nullptr; 16917f48955SJacob Faibussowitsch } 17017f48955SJacob Faibussowitsch } 17117f48955SJacob Faibussowitsch for (auto &&handle : solverhandles_) { 17217f48955SJacob Faibussowitsch if (handle) { 17396a4b4d9SJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverDestroy(handle)); 17417f48955SJacob Faibussowitsch handle = nullptr; 17517f48955SJacob Faibussowitsch } 17617f48955SJacob Faibussowitsch } 17717f48955SJacob Faibussowitsch initialized_ = false; 1783ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17917f48955SJacob Faibussowitsch } 18017f48955SJacob Faibussowitsch 1810e6b6b59SJacob Faibussowitsch template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 182d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PoolType &default_pool_() noexcept 183d71ae5a4SJacob Faibussowitsch { 1840e6b6b59SJacob Faibussowitsch static PoolType pool; 1850e6b6b59SJacob Faibussowitsch return pool; 1860e6b6b59SJacob Faibussowitsch } 187030f984aSJacob Faibussowitsch 188089fb57cSJacob Faibussowitsch static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 189d71ae5a4SJacob Faibussowitsch { 1900e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1910e6b6b59SJacob Faibussowitsch PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 1923ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1930e6b6b59SJacob Faibussowitsch } 1940e6b6b59SJacob Faibussowitsch 1950e6b6b59SJacob Faibussowitsch public: 196030f984aSJacob Faibussowitsch // All of these functions MUST be static in order to be callable from C, otherwise they 197030f984aSJacob Faibussowitsch // get the implicit 'this' pointer tacked on 198089fb57cSJacob Faibussowitsch static PetscErrorCode destroy(PetscDeviceContext) noexcept; 199089fb57cSJacob Faibussowitsch static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept; 200089fb57cSJacob Faibussowitsch static PetscErrorCode setUp(PetscDeviceContext) noexcept; 201089fb57cSJacob Faibussowitsch static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept; 202089fb57cSJacob Faibussowitsch static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept; 203089fb57cSJacob Faibussowitsch static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 204a4af0ceeSJacob Faibussowitsch template <typename Handle_t> 205089fb57cSJacob Faibussowitsch static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept; 20631d47070SJunchao Zhang template <typename Handle_t> 20797cd0981SJacob Faibussowitsch static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept; 208089fb57cSJacob Faibussowitsch static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 209089fb57cSJacob Faibussowitsch static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept; 210*5268dc8aSHong Zhang static PetscErrorCode getPower(PetscDeviceContext, PetscLogDouble *) noexcept; 211*5268dc8aSHong Zhang static PetscErrorCode beginEnergyMeter(PetscDeviceContext) noexcept; 212*5268dc8aSHong Zhang static PetscErrorCode endEnergyMeter(PetscDeviceContext, PetscLogDouble *) noexcept; 213089fb57cSJacob Faibussowitsch static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept; 214089fb57cSJacob Faibussowitsch static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept; 215089fb57cSJacob Faibussowitsch static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept; 216089fb57cSJacob Faibussowitsch static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept; 217089fb57cSJacob Faibussowitsch static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept; 218089fb57cSJacob Faibussowitsch static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept; 219089fb57cSJacob Faibussowitsch static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept; 2207a101e5eSJacob Faibussowitsch 2217a101e5eSJacob Faibussowitsch // not a PetscDeviceContext method, this registers the class 222089fb57cSJacob Faibussowitsch static PetscErrorCode initialize(PetscDevice) noexcept; 2230e6b6b59SJacob Faibussowitsch 2240e6b6b59SJacob Faibussowitsch // clang-format off 2256ff55be4SJacob Faibussowitsch static constexpr _DeviceContextOps ops = { 2266ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(destroy, destroy), 2276ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(changestreamtype, changeStreamType), 2286ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(setup, setUp), 2296ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(query, query), 2306ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(waitforcontext, waitForContext), 2316ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(synchronize, synchronize), 2326ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>), 2336ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>), 23431d47070SJunchao Zhang PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>), 2356ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(begintimer, beginTimer), 2366ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(endtimer, endTimer), 237*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS) 238*5268dc8aSHong Zhang PetscDesignatedInitializer(getpower, getPower), 239*5268dc8aSHong Zhang #else 240*5268dc8aSHong Zhang PetscDesignatedInitializer(getpower, nullptr), 241*5268dc8aSHong Zhang #endif 242*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA) 243*5268dc8aSHong Zhang PetscDesignatedInitializer(beginenergymeter, beginEnergyMeter), 244*5268dc8aSHong Zhang PetscDesignatedInitializer(endenergymeter, endEnergyMeter), 245*5268dc8aSHong Zhang #else 246*5268dc8aSHong Zhang PetscDesignatedInitializer(beginenergymeter, nullptr), 247*5268dc8aSHong Zhang PetscDesignatedInitializer(endenergymeter, nullptr), 248*5268dc8aSHong Zhang #endif 2496ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(memalloc, memAlloc), 2506ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(memfree, memFree), 2516ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(memcopy, memCopy), 2526ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(memset, memSet), 2536ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(createevent, createEvent), 2546ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(recordevent, recordEvent), 2556ff55be4SJacob Faibussowitsch PetscDesignatedInitializer(waitforevent, waitForEvent) 2560e6b6b59SJacob Faibussowitsch }; 2570e6b6b59SJacob Faibussowitsch // clang-format on 258030f984aSJacob Faibussowitsch }; 259030f984aSJacob Faibussowitsch 2600e6b6b59SJacob Faibussowitsch // not a PetscDeviceContext method, this initializes the CLASS 26117f48955SJacob Faibussowitsch template <DeviceType T> 2626d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept 263d71ae5a4SJacob Faibussowitsch { 2647a101e5eSJacob Faibussowitsch PetscFunctionBegin; 2657a101e5eSJacob Faibussowitsch if (PetscUnlikely(!initialized_)) { 2660e6b6b59SJacob Faibussowitsch uint64_t threshold = UINT64_MAX; 2676d54fb17SJacob Faibussowitsch cupmMemPool_t mempool; 2680e6b6b59SJacob Faibussowitsch 2697a101e5eSJacob Faibussowitsch initialized_ = true; 2706d54fb17SJacob Faibussowitsch PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId))); 2710e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 2720e6b6b59SJacob Faibussowitsch blashandles_.fill(nullptr); 2730e6b6b59SJacob Faibussowitsch solverhandles_.fill(nullptr); 2747a101e5eSJacob Faibussowitsch PetscCall(PetscRegisterFinalize(finalize_)); 2757a101e5eSJacob Faibussowitsch } 2763ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 2777a101e5eSJacob Faibussowitsch } 2787a101e5eSJacob Faibussowitsch 2797a101e5eSJacob Faibussowitsch template <DeviceType T> 2806d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept 281d71ae5a4SJacob Faibussowitsch { 282030f984aSJacob Faibussowitsch PetscFunctionBegin; 2830e6b6b59SJacob Faibussowitsch if (const auto dci = impls_cast_(dctx)) { 2840e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy()); 285146a86ebSJacob Faibussowitsch if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event)); 2869566063dSJacob Faibussowitsch if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 2879566063dSJacob Faibussowitsch if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 2880e6b6b59SJacob Faibussowitsch delete dci; 2890e6b6b59SJacob Faibussowitsch dctx->data = nullptr; 2900e6b6b59SJacob Faibussowitsch } 2913ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 292030f984aSJacob Faibussowitsch } 293030f984aSJacob Faibussowitsch 29417f48955SJacob Faibussowitsch template <DeviceType T> 2956d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 296d71ae5a4SJacob Faibussowitsch { 2977a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 298030f984aSJacob Faibussowitsch 299030f984aSJacob Faibussowitsch PetscFunctionBegin; 3000e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy()); 301030f984aSJacob Faibussowitsch // set these to null so they aren't usable until setup is called again 302030f984aSJacob Faibussowitsch dci->blas = nullptr; 303030f984aSJacob Faibussowitsch dci->solver = nullptr; 3043ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 305030f984aSJacob Faibussowitsch } 306030f984aSJacob Faibussowitsch 30717f48955SJacob Faibussowitsch template <DeviceType T> 3086d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept 309d71ae5a4SJacob Faibussowitsch { 3107a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 3110e6b6b59SJacob Faibussowitsch auto &event = dci->event; 312030f984aSJacob Faibussowitsch 313030f984aSJacob Faibussowitsch PetscFunctionBegin; 3140e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 3150e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.change_type(dctx->streamType)); 3160e6b6b59SJacob Faibussowitsch if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 317a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 318a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE; 319a4af0ceeSJacob Faibussowitsch #endif 3203ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 321030f984aSJacob Faibussowitsch } 322030f984aSJacob Faibussowitsch 32317f48955SJacob Faibussowitsch template <DeviceType T> 3246d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 325d71ae5a4SJacob Faibussowitsch { 326030f984aSJacob Faibussowitsch PetscFunctionBegin; 3270e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 3284b955ea4SJacob Faibussowitsch switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 329d71ae5a4SJacob Faibussowitsch case cupmSuccess: 330d71ae5a4SJacob Faibussowitsch *idle = PETSC_TRUE; 331d71ae5a4SJacob Faibussowitsch break; 332d71ae5a4SJacob Faibussowitsch case cupmErrorNotReady: 333d71ae5a4SJacob Faibussowitsch *idle = PETSC_FALSE; 3344b955ea4SJacob Faibussowitsch // reset the error 3354b955ea4SJacob Faibussowitsch cerr = cupmGetLastError(); 3364b955ea4SJacob Faibussowitsch static_cast<void>(cerr); 337d71ae5a4SJacob Faibussowitsch break; 338d71ae5a4SJacob Faibussowitsch default: 339d71ae5a4SJacob Faibussowitsch PetscCallCUPM(cerr); 340d71ae5a4SJacob Faibussowitsch PetscUnreachable(); 341030f984aSJacob Faibussowitsch } 3423ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 343030f984aSJacob Faibussowitsch } 344030f984aSJacob Faibussowitsch 34517f48955SJacob Faibussowitsch template <DeviceType T> 3466d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 347d71ae5a4SJacob Faibussowitsch { 3480e6b6b59SJacob Faibussowitsch const auto dcib = impls_cast_(dctxb); 3490e6b6b59SJacob Faibussowitsch const auto event = dcib->event; 350030f984aSJacob Faibussowitsch 351030f984aSJacob Faibussowitsch PetscFunctionBegin; 3520e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctxa, dctxb)); 3530e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 3540e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 3553ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 356030f984aSJacob Faibussowitsch } 357030f984aSJacob Faibussowitsch 35817f48955SJacob Faibussowitsch template <DeviceType T> 3596d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept 360d71ae5a4SJacob Faibussowitsch { 3610e6b6b59SJacob Faibussowitsch auto idle = PETSC_TRUE; 362030f984aSJacob Faibussowitsch 363030f984aSJacob Faibussowitsch PetscFunctionBegin; 3640e6b6b59SJacob Faibussowitsch PetscCall(query(dctx, &idle)); 3650e6b6b59SJacob Faibussowitsch if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 3663ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 367030f984aSJacob Faibussowitsch } 368030f984aSJacob Faibussowitsch 36917f48955SJacob Faibussowitsch template <DeviceType T> 37017f48955SJacob Faibussowitsch template <typename handle_t> 3716d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 372d71ae5a4SJacob Faibussowitsch { 373a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 3747a101e5eSJacob Faibussowitsch PetscCall(initialize_handle_(handle_t{}, dctx)); 3757a101e5eSJacob Faibussowitsch *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 3763ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 377a4af0ceeSJacob Faibussowitsch } 378a4af0ceeSJacob Faibussowitsch 37917f48955SJacob Faibussowitsch template <DeviceType T> 38031d47070SJunchao Zhang template <typename handle_t> 38197cd0981SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept 38231d47070SJunchao Zhang { 38331d47070SJunchao Zhang using handle_type = typename handle_t::type; 38431d47070SJunchao Zhang 38531d47070SJunchao Zhang PetscFunctionBegin; 38631d47070SJunchao Zhang PetscCall(initialize_handle_(handle_t{}, dctx)); 38797cd0981SJacob Faibussowitsch *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{}))); 38831d47070SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 38931d47070SJunchao Zhang } 39031d47070SJunchao Zhang 39131d47070SJunchao Zhang template <DeviceType T> 3926d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 393d71ae5a4SJacob Faibussowitsch { 3940e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 395a4af0ceeSJacob Faibussowitsch 396a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 3970e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 398a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 3995f80ce2aSJacob Faibussowitsch PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 400a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_TRUE; 401a4af0ceeSJacob Faibussowitsch #endif 40217f48955SJacob Faibussowitsch if (!dci->begin) { 4030e6b6b59SJacob Faibussowitsch PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 4049566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->begin)); 4059566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->end)); 40617f48955SJacob Faibussowitsch } 4070e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 4083ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 409a4af0ceeSJacob Faibussowitsch } 410a4af0ceeSJacob Faibussowitsch 41117f48955SJacob Faibussowitsch template <DeviceType T> 4126d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 413d71ae5a4SJacob Faibussowitsch { 414a4af0ceeSJacob Faibussowitsch float gtime; 4150e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 4160e6b6b59SJacob Faibussowitsch const auto end = dci->end; 417a4af0ceeSJacob Faibussowitsch 418a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 4190e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 420a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 4215f80ce2aSJacob Faibussowitsch PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 422a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE; 423a4af0ceeSJacob Faibussowitsch #endif 4240e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 4250e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventSynchronize(end)); 4260e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 42717f48955SJacob Faibussowitsch *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 4283ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 429a4af0ceeSJacob Faibussowitsch } 430a4af0ceeSJacob Faibussowitsch 431*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS) 432*5268dc8aSHong Zhang template <DeviceType T> 433*5268dc8aSHong Zhang inline PetscErrorCode DeviceContext<T>::getPower(PetscDeviceContext dctx, PetscLogDouble *power) noexcept 434*5268dc8aSHong Zhang { 435*5268dc8aSHong Zhang const auto dci = impls_cast_(dctx); 436*5268dc8aSHong Zhang nvmlFieldValue_t values[1]; 437*5268dc8aSHong Zhang 438*5268dc8aSHong Zhang PetscFunctionBegin; 439*5268dc8aSHong Zhang PetscCall(check_current_device_(dctx)); 440*5268dc8aSHong Zhang PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream())); 441*5268dc8aSHong Zhang values[0].fieldId = NVML_FI_DEV_POWER_INSTANT; 442*5268dc8aSHong Zhang if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle)); 443*5268dc8aSHong Zhang PetscCallNVML(nvmlDeviceGetFieldValues(dci->nvmlHandle, 1, values)); 444*5268dc8aSHong Zhang *power = static_cast<util::remove_pointer_t<decltype(power)>>(values[0].value.uiVal); 445*5268dc8aSHong Zhang PetscFunctionReturn(PETSC_SUCCESS); 446*5268dc8aSHong Zhang } 447*5268dc8aSHong Zhang #endif 448*5268dc8aSHong Zhang 449*5268dc8aSHong Zhang #if PetscDefined(HAVE_CUDA) 450*5268dc8aSHong Zhang template <DeviceType T> 451*5268dc8aSHong Zhang inline PetscErrorCode DeviceContext<T>::beginEnergyMeter(PetscDeviceContext dctx) noexcept 452*5268dc8aSHong Zhang { 453*5268dc8aSHong Zhang const auto dci = impls_cast_(dctx); 454*5268dc8aSHong Zhang 455*5268dc8aSHong Zhang PetscFunctionBegin; 456*5268dc8aSHong Zhang PetscCall(check_current_device_(dctx)); 457*5268dc8aSHong Zhang #if PetscDefined(USE_DEBUG) 458*5268dc8aSHong Zhang PetscCheck(!dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterEnd()?"); 459*5268dc8aSHong Zhang dci->EnergyMeterInUse = PETSC_TRUE; 460*5268dc8aSHong Zhang #endif 461*5268dc8aSHong Zhang if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle)); 462*5268dc8aSHong Zhang PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterbegin)); 463*5268dc8aSHong Zhang PetscFunctionReturn(PETSC_SUCCESS); 464*5268dc8aSHong Zhang } 465*5268dc8aSHong Zhang 466*5268dc8aSHong Zhang template <DeviceType T> 467*5268dc8aSHong Zhang inline PetscErrorCode DeviceContext<T>::endEnergyMeter(PetscDeviceContext dctx, PetscLogDouble *energy) noexcept 468*5268dc8aSHong Zhang { 469*5268dc8aSHong Zhang const auto dci = impls_cast_(dctx); 470*5268dc8aSHong Zhang 471*5268dc8aSHong Zhang PetscFunctionBegin; 472*5268dc8aSHong Zhang PetscCall(check_current_device_(dctx)); 473*5268dc8aSHong Zhang #if PetscDefined(USE_DEBUG) 474*5268dc8aSHong Zhang PetscCheck(dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterBegin()?"); 475*5268dc8aSHong Zhang dci->EnergyMeterInUse = PETSC_FALSE; 476*5268dc8aSHong Zhang #endif 477*5268dc8aSHong Zhang PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream())); 478*5268dc8aSHong Zhang PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterend)); 479*5268dc8aSHong Zhang *energy = static_cast<util::remove_pointer_t<decltype(energy)>>(dci->energymeterend - dci->energymeterbegin) / 1000; // convert to Joule 480*5268dc8aSHong Zhang PetscFunctionReturn(PETSC_SUCCESS); 481*5268dc8aSHong Zhang } 482*5268dc8aSHong Zhang #endif 483*5268dc8aSHong Zhang 4840e6b6b59SJacob Faibussowitsch template <DeviceType T> 4856d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept 486d71ae5a4SJacob Faibussowitsch { 4870e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream; 4880e6b6b59SJacob Faibussowitsch 4890e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4900e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 4910e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "allocating")); 4920e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) { 4936797ed33SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 4940e6b6b59SJacob Faibussowitsch } else { 4956797ed33SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 4960e6b6b59SJacob Faibussowitsch } 4976797ed33SJacob Faibussowitsch if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 4983ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 4990e6b6b59SJacob Faibussowitsch } 5000e6b6b59SJacob Faibussowitsch 5010e6b6b59SJacob Faibussowitsch template <DeviceType T> 5026d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept 503d71ae5a4SJacob Faibussowitsch { 5040e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream; 5050e6b6b59SJacob Faibussowitsch 5060e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 5070e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 5080e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "freeing")); 5093ba16761SJacob Faibussowitsch if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS); 5100e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) { 5110e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 5120e6b6b59SJacob Faibussowitsch // if ptr exists still exists the pool didn't own it 5130e6b6b59SJacob Faibussowitsch if (*ptr) { 5140e6b6b59SJacob Faibussowitsch auto registered = PETSC_FALSE, managed = PETSC_FALSE; 5150e6b6b59SJacob Faibussowitsch 5160e6b6b59SJacob Faibussowitsch PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 5170e6b6b59SJacob Faibussowitsch if (registered) { 5180e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeHost(*ptr)); 5190e6b6b59SJacob Faibussowitsch } else if (managed) { 5200e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 5210e6b6b59SJacob Faibussowitsch } 5220e6b6b59SJacob Faibussowitsch } 5230e6b6b59SJacob Faibussowitsch } else { 5240e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 5256d54fb17SJacob Faibussowitsch // if ptr still exists the pool didn't own it 5260e6b6b59SJacob Faibussowitsch if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 5270e6b6b59SJacob Faibussowitsch } 5283ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5290e6b6b59SJacob Faibussowitsch } 5300e6b6b59SJacob Faibussowitsch 5310e6b6b59SJacob Faibussowitsch template <DeviceType T> 5326d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept 533d71ae5a4SJacob Faibussowitsch { 5340e6b6b59SJacob Faibussowitsch const auto stream = impls_cast_(dctx)->stream.get_stream(); 5350e6b6b59SJacob Faibussowitsch 5360e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 5370e6b6b59SJacob Faibussowitsch // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 5380e6b6b59SJacob Faibussowitsch if (mode == PETSC_DEVICE_COPY_HTOH) { 5396d54fb17SJacob Faibussowitsch const auto cerr = cupmStreamQuery(stream); 5406d54fb17SJacob Faibussowitsch 5410e6b6b59SJacob Faibussowitsch // yes this is faster 5426d54fb17SJacob Faibussowitsch if (cerr == cupmSuccess) { 5430e6b6b59SJacob Faibussowitsch PetscCall(PetscMemcpy(dest, src, n)); 5443ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5456d54fb17SJacob Faibussowitsch } else if (cerr == cupmErrorNotReady) { 5466d54fb17SJacob Faibussowitsch auto PETSC_UNUSED unused = cupmGetLastError(); 5476d54fb17SJacob Faibussowitsch 5486d54fb17SJacob Faibussowitsch static_cast<void>(unused); 5496d54fb17SJacob Faibussowitsch } else { 5506d54fb17SJacob Faibussowitsch PetscCallCUPM(cerr); 5510e6b6b59SJacob Faibussowitsch } 5520e6b6b59SJacob Faibussowitsch } 5533ba16761SJacob Faibussowitsch PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 5543ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5550e6b6b59SJacob Faibussowitsch } 5560e6b6b59SJacob Faibussowitsch 5570e6b6b59SJacob Faibussowitsch template <DeviceType T> 5586d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept 559d71ae5a4SJacob Faibussowitsch { 5600e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 5610e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 5620e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "zeroing")); 5636797ed33SJacob Faibussowitsch PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 5643ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5650e6b6b59SJacob Faibussowitsch } 5660e6b6b59SJacob Faibussowitsch 5670e6b6b59SJacob Faibussowitsch template <DeviceType T> 5688eb1d50fSPierre Jolivet inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept 569d71ae5a4SJacob Faibussowitsch { 5700e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 5713048253cSJacob Faibussowitsch PetscCallCXX(event->data = new event_type{}); 5720e6b6b59SJacob Faibussowitsch event->destroy = [](PetscEvent event) { 5730e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 5740e6b6b59SJacob Faibussowitsch delete event_cast_(event); 5750e6b6b59SJacob Faibussowitsch event->data = nullptr; 5763ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5770e6b6b59SJacob Faibussowitsch }; 5783ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5790e6b6b59SJacob Faibussowitsch } 5800e6b6b59SJacob Faibussowitsch 5810e6b6b59SJacob Faibussowitsch template <DeviceType T> 5826d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 583d71ae5a4SJacob Faibussowitsch { 5840e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 5850e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 5863ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5870e6b6b59SJacob Faibussowitsch } 5880e6b6b59SJacob Faibussowitsch 5890e6b6b59SJacob Faibussowitsch template <DeviceType T> 5906d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 591d71ae5a4SJacob Faibussowitsch { 5920e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 5930e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 5943ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5950e6b6b59SJacob Faibussowitsch } 5960e6b6b59SJacob Faibussowitsch 597030f984aSJacob Faibussowitsch // initialize the static member variables 5989371c9d4SSatish Balay template <DeviceType T> 5999371c9d4SSatish Balay bool DeviceContext<T>::initialized_ = false; 600030f984aSJacob Faibussowitsch 60117f48955SJacob Faibussowitsch template <DeviceType T> 60217f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 603030f984aSJacob Faibussowitsch 60417f48955SJacob Faibussowitsch template <DeviceType T> 60517f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 60617f48955SJacob Faibussowitsch 6076ff55be4SJacob Faibussowitsch template <DeviceType T> 6086ff55be4SJacob Faibussowitsch constexpr _DeviceContextOps DeviceContext<T>::ops; 6096ff55be4SJacob Faibussowitsch 6100e6b6b59SJacob Faibussowitsch } // namespace impl 611030f984aSJacob Faibussowitsch 612a4af0ceeSJacob Faibussowitsch // shorten this one up a bit (and instantiate the templates) 6130e6b6b59SJacob Faibussowitsch using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 6140e6b6b59SJacob Faibussowitsch using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 615030f984aSJacob Faibussowitsch 616030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name 6170e6b6b59SJacob Faibussowitsch #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 618030f984aSJacob Faibussowitsch 6190e6b6b59SJacob Faibussowitsch } // namespace cupm 62017f48955SJacob Faibussowitsch 6210e6b6b59SJacob Faibussowitsch } // namespace device 62217f48955SJacob Faibussowitsch 62317f48955SJacob Faibussowitsch } // namespace Petsc 624