117f48955SJacob Faibussowitsch #ifndef PETSCDEVICECONTEXTCUPM_HPP 2030f984aSJacob Faibussowitsch #define PETSCDEVICECONTEXTCUPM_HPP 3030f984aSJacob Faibussowitsch 4a4af0ceeSJacob Faibussowitsch #include <petsc/private/deviceimpl.h> 517f48955SJacob Faibussowitsch #include <petsc/private/cupmblasinterface.hpp> 67a101e5eSJacob Faibussowitsch #include <petsc/private/logimpl.h> 7030f984aSJacob Faibussowitsch 80e6b6b59SJacob Faibussowitsch #include <petsc/private/cpp/array.hpp> 9a4af0ceeSJacob Faibussowitsch 100e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp" 110e6b6b59SJacob Faibussowitsch #include "cupmallocator.hpp" 120e6b6b59SJacob Faibussowitsch #include "cupmstream.hpp" 130e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp" 140e6b6b59SJacob Faibussowitsch 150e6b6b59SJacob Faibussowitsch #if defined(__cplusplus) 16*6797ed33SJacob Faibussowitsch 179371c9d4SSatish Balay namespace Petsc { 18a4af0ceeSJacob Faibussowitsch 190e6b6b59SJacob Faibussowitsch namespace device { 2017f48955SJacob Faibussowitsch 210e6b6b59SJacob Faibussowitsch namespace cupm { 2217f48955SJacob Faibussowitsch 230e6b6b59SJacob Faibussowitsch namespace impl { 24030f984aSJacob Faibussowitsch 2517f48955SJacob Faibussowitsch template <DeviceType T> 260e6b6b59SJacob Faibussowitsch class DeviceContext : BlasInterface<T> { 2717f48955SJacob Faibussowitsch public: 2817f48955SJacob Faibussowitsch PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); 2917f48955SJacob Faibussowitsch 3017f48955SJacob Faibussowitsch private: 319371c9d4SSatish Balay template <typename H, std::size_t> 329371c9d4SSatish Balay struct HandleTag { 339371c9d4SSatish Balay using type = H; 349371c9d4SSatish Balay }; 350e6b6b59SJacob Faibussowitsch 367a101e5eSJacob Faibussowitsch using stream_tag = HandleTag<cupmStream_t, 0>; 377a101e5eSJacob Faibussowitsch using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 387a101e5eSJacob Faibussowitsch using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 39a4af0ceeSJacob Faibussowitsch 400e6b6b59SJacob Faibussowitsch using stream_type = CUPMStream<T>; 410e6b6b59SJacob Faibussowitsch using event_type = CUPMEvent<T>; 420e6b6b59SJacob Faibussowitsch 43030f984aSJacob Faibussowitsch public: 44030f984aSJacob Faibussowitsch // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 45030f984aSJacob Faibussowitsch // header, but since we are using the power of templates it must be declared part of 46030f984aSJacob Faibussowitsch // this class to have easy access the same typedefs. Technically one can make a 47030f984aSJacob Faibussowitsch // templated struct outside the class but it's more code for the same result. 480e6b6b59SJacob Faibussowitsch struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 490e6b6b59SJacob Faibussowitsch stream_type stream{}; 500e6b6b59SJacob Faibussowitsch cupmEvent_t event{}; 510e6b6b59SJacob Faibussowitsch cupmEvent_t begin{}; // timer-only 520e6b6b59SJacob Faibussowitsch cupmEvent_t end{}; // timer-only 53a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 540e6b6b59SJacob Faibussowitsch PetscBool timerInUse{}; 55a4af0ceeSJacob Faibussowitsch #endif 560e6b6b59SJacob Faibussowitsch cupmBlasHandle_t blas{}; 570e6b6b59SJacob Faibussowitsch cupmSolverHandle_t solver{}; 58a4af0ceeSJacob Faibussowitsch 590e6b6b59SJacob Faibussowitsch constexpr PetscDeviceContext_IMPLS() noexcept = default; 600e6b6b59SJacob Faibussowitsch 610e6b6b59SJacob Faibussowitsch PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { 620e6b6b59SJacob Faibussowitsch return this->stream.get_stream(); 639371c9d4SSatish Balay } 640e6b6b59SJacob Faibussowitsch 650e6b6b59SJacob Faibussowitsch PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { 669371c9d4SSatish Balay return this->blas; 679371c9d4SSatish Balay } 680e6b6b59SJacob Faibussowitsch 690e6b6b59SJacob Faibussowitsch PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { 709371c9d4SSatish Balay return this->solver; 719371c9d4SSatish Balay } 72030f984aSJacob Faibussowitsch }; 73030f984aSJacob Faibussowitsch 74030f984aSJacob Faibussowitsch private: 7517f48955SJacob Faibussowitsch static bool initialized_; 7617f48955SJacob Faibussowitsch static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 7717f48955SJacob Faibussowitsch static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 78030f984aSJacob Faibussowitsch 790e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { 80a4af0ceeSJacob Faibussowitsch return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); 81a4af0ceeSJacob Faibussowitsch } 82a4af0ceeSJacob Faibussowitsch 830e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { 840e6b6b59SJacob Faibussowitsch return static_cast<CUPMEvent<T> *>(event->data); 850e6b6b59SJacob Faibussowitsch } 860e6b6b59SJacob Faibussowitsch 870e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { 887a101e5eSJacob Faibussowitsch return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; 897a101e5eSJacob Faibussowitsch } 907a101e5eSJacob Faibussowitsch 910e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { 927a101e5eSJacob Faibussowitsch return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; 937a101e5eSJacob Faibussowitsch } 947a101e5eSJacob Faibussowitsch 957a101e5eSJacob Faibussowitsch // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 967a101e5eSJacob Faibussowitsch // handles 979371c9d4SSatish Balay PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { 989371c9d4SSatish Balay return 0; 999371c9d4SSatish Balay } 1007a101e5eSJacob Faibussowitsch 1010e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode create_handle_(cupmBlasHandle_t &handle) noexcept { 1027a101e5eSJacob Faibussowitsch PetscLogEvent event; 1037a101e5eSJacob Faibussowitsch 104030f984aSJacob Faibussowitsch PetscFunctionBegin; 1057a101e5eSJacob Faibussowitsch if (PetscLikely(handle)) PetscFunctionReturn(0); 1067a101e5eSJacob Faibussowitsch PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 1077a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 10817f48955SJacob Faibussowitsch for (auto i = 0; i < 3; ++i) { 10917f48955SJacob Faibussowitsch auto cberr = cupmBlasCreate(&handle); 11017f48955SJacob Faibussowitsch if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 1119566063dSJacob Faibussowitsch if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 11217f48955SJacob Faibussowitsch if (i != 2) { 1139566063dSJacob Faibussowitsch PetscCall(PetscSleep(3)); 11417f48955SJacob Faibussowitsch continue; 115a4af0ceeSJacob Faibussowitsch } 1165f80ce2aSJacob Faibussowitsch PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 117a4af0ceeSJacob Faibussowitsch } 1187a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 1197a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventResume_Internal(event)); 120030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 121030f984aSJacob Faibussowitsch } 122030f984aSJacob Faibussowitsch 1230e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept { 1247a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 1257a101e5eSJacob Faibussowitsch auto &handle = blashandles_[dctx->device->deviceId]; 12617f48955SJacob Faibussowitsch 12717f48955SJacob Faibussowitsch PetscFunctionBegin; 1287a101e5eSJacob Faibussowitsch PetscCall(create_handle_(handle)); 1290e6b6b59SJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 1307a101e5eSJacob Faibussowitsch dci->blas = handle; 1317a101e5eSJacob Faibussowitsch PetscFunctionReturn(0); 1327a101e5eSJacob Faibussowitsch } 1337a101e5eSJacob Faibussowitsch 1349371c9d4SSatish Balay PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle)) { 1357a101e5eSJacob Faibussowitsch PetscLogEvent event; 1367a101e5eSJacob Faibussowitsch 1377a101e5eSJacob Faibussowitsch PetscFunctionBegin; 1387a101e5eSJacob Faibussowitsch PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 1397a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 1407a101e5eSJacob Faibussowitsch PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 1417a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 1427a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventResume_Internal(event)); 1437a101e5eSJacob Faibussowitsch PetscFunctionReturn(0); 1447a101e5eSJacob Faibussowitsch } 1457a101e5eSJacob Faibussowitsch 1460e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept { 1477a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 1487a101e5eSJacob Faibussowitsch auto &handle = solverhandles_[dctx->device->deviceId]; 1497a101e5eSJacob Faibussowitsch 1507a101e5eSJacob Faibussowitsch PetscFunctionBegin; 1517a101e5eSJacob Faibussowitsch PetscCall(create_handle_(handle)); 1520e6b6b59SJacob Faibussowitsch PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream())); 1537a101e5eSJacob Faibussowitsch dci->solver = handle; 15417f48955SJacob Faibussowitsch PetscFunctionReturn(0); 15517f48955SJacob Faibussowitsch } 15617f48955SJacob Faibussowitsch 1570e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept { 1580e6b6b59SJacob Faibussowitsch const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 1590e6b6b59SJacob Faibussowitsch 1600e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1610e6b6b59SJacob Faibussowitsch PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 1620e6b6b59SJacob Faibussowitsch PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 1630e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 1640e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 1650e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 1660e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 1670e6b6b59SJacob Faibussowitsch } 1680e6b6b59SJacob Faibussowitsch 1690e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { 1700e6b6b59SJacob Faibussowitsch return check_current_device_(dctx, dctx); 1710e6b6b59SJacob Faibussowitsch } 1720e6b6b59SJacob Faibussowitsch 1730e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode finalize_() noexcept { 17417f48955SJacob Faibussowitsch PetscFunctionBegin; 17517f48955SJacob Faibussowitsch for (auto &&handle : blashandles_) { 17617f48955SJacob Faibussowitsch if (handle) { 1779566063dSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 17817f48955SJacob Faibussowitsch handle = nullptr; 17917f48955SJacob Faibussowitsch } 18017f48955SJacob Faibussowitsch } 18117f48955SJacob Faibussowitsch for (auto &&handle : solverhandles_) { 18217f48955SJacob Faibussowitsch if (handle) { 1839566063dSJacob Faibussowitsch PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 18417f48955SJacob Faibussowitsch handle = nullptr; 18517f48955SJacob Faibussowitsch } 18617f48955SJacob Faibussowitsch } 18717f48955SJacob Faibussowitsch initialized_ = false; 18817f48955SJacob Faibussowitsch PetscFunctionReturn(0); 18917f48955SJacob Faibussowitsch } 19017f48955SJacob Faibussowitsch 1910e6b6b59SJacob Faibussowitsch template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 1920e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PoolType &default_pool_() noexcept { 1930e6b6b59SJacob Faibussowitsch static PoolType pool; 1940e6b6b59SJacob Faibussowitsch return pool; 1950e6b6b59SJacob Faibussowitsch } 196030f984aSJacob Faibussowitsch 1970e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept { 1980e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1990e6b6b59SJacob Faibussowitsch PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 2000e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 2010e6b6b59SJacob Faibussowitsch } 2020e6b6b59SJacob Faibussowitsch 2030e6b6b59SJacob Faibussowitsch public: 204030f984aSJacob Faibussowitsch // All of these functions MUST be static in order to be callable from C, otherwise they 205030f984aSJacob Faibussowitsch // get the implicit 'this' pointer tacked on 20617f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 20717f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType)); 20817f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 20917f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *)); 21017f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext)); 21117f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 212a4af0ceeSJacob Faibussowitsch template <typename Handle_t> 21317f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *)); 21417f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 21517f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *)); 216*6797ed33SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **)); 2170e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **)); 2180e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode)); 2190e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t)); 2200e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode createEvent(PetscDeviceContext, PetscEvent)); 2210e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent)); 2220e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent)); 2237a101e5eSJacob Faibussowitsch 2247a101e5eSJacob Faibussowitsch // not a PetscDeviceContext method, this registers the class 2257a101e5eSJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize()); 2260e6b6b59SJacob Faibussowitsch 2270e6b6b59SJacob Faibussowitsch // clang-format off 2280e6b6b59SJacob Faibussowitsch const _DeviceContextOps ops = { 2290e6b6b59SJacob Faibussowitsch destroy, 2300e6b6b59SJacob Faibussowitsch changeStreamType, 2310e6b6b59SJacob Faibussowitsch setUp, 2320e6b6b59SJacob Faibussowitsch query, 2330e6b6b59SJacob Faibussowitsch waitForContext, 2340e6b6b59SJacob Faibussowitsch synchronize, 2350e6b6b59SJacob Faibussowitsch getHandle<blas_tag>, 2360e6b6b59SJacob Faibussowitsch getHandle<solver_tag>, 2370e6b6b59SJacob Faibussowitsch getHandle<stream_tag>, 2380e6b6b59SJacob Faibussowitsch beginTimer, 2390e6b6b59SJacob Faibussowitsch endTimer, 2400e6b6b59SJacob Faibussowitsch memAlloc, 2410e6b6b59SJacob Faibussowitsch memFree, 2420e6b6b59SJacob Faibussowitsch memCopy, 2430e6b6b59SJacob Faibussowitsch memSet, 2440e6b6b59SJacob Faibussowitsch createEvent, 2450e6b6b59SJacob Faibussowitsch recordEvent, 2460e6b6b59SJacob Faibussowitsch waitForEvent 2470e6b6b59SJacob Faibussowitsch }; 2480e6b6b59SJacob Faibussowitsch // clang-format on 249030f984aSJacob Faibussowitsch }; 250030f984aSJacob Faibussowitsch 2510e6b6b59SJacob Faibussowitsch // not a PetscDeviceContext method, this initializes the CLASS 25217f48955SJacob Faibussowitsch template <DeviceType T> 2539371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize()) { 2547a101e5eSJacob Faibussowitsch PetscFunctionBegin; 2557a101e5eSJacob Faibussowitsch if (PetscUnlikely(!initialized_)) { 2560e6b6b59SJacob Faibussowitsch cupmMemPool_t mempool; 2570e6b6b59SJacob Faibussowitsch uint64_t threshold = UINT64_MAX; 2580e6b6b59SJacob Faibussowitsch 2597a101e5eSJacob Faibussowitsch initialized_ = true; 2600e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmDeviceGetMemPool(&mempool, 0)); 2610e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 2620e6b6b59SJacob Faibussowitsch blashandles_.fill(nullptr); 2630e6b6b59SJacob Faibussowitsch solverhandles_.fill(nullptr); 2647a101e5eSJacob Faibussowitsch PetscCall(PetscRegisterFinalize(finalize_)); 2657a101e5eSJacob Faibussowitsch } 2667a101e5eSJacob Faibussowitsch PetscFunctionReturn(0); 2677a101e5eSJacob Faibussowitsch } 2687a101e5eSJacob Faibussowitsch 2697a101e5eSJacob Faibussowitsch template <DeviceType T> 2709371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) { 271030f984aSJacob Faibussowitsch PetscFunctionBegin; 2720e6b6b59SJacob Faibussowitsch if (const auto dci = impls_cast_(dctx)) { 2730e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy()); 2740e6b6b59SJacob Faibussowitsch if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event))); 2759566063dSJacob Faibussowitsch if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 2769566063dSJacob Faibussowitsch if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 2770e6b6b59SJacob Faibussowitsch delete dci; 2780e6b6b59SJacob Faibussowitsch dctx->data = nullptr; 2790e6b6b59SJacob Faibussowitsch } 280030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 281030f984aSJacob Faibussowitsch } 282030f984aSJacob Faibussowitsch 28317f48955SJacob Faibussowitsch template <DeviceType T> 2849371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) { 2857a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 286030f984aSJacob Faibussowitsch 287030f984aSJacob Faibussowitsch PetscFunctionBegin; 2880e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy()); 289030f984aSJacob Faibussowitsch // set these to null so they aren't usable until setup is called again 290030f984aSJacob Faibussowitsch dci->blas = nullptr; 291030f984aSJacob Faibussowitsch dci->solver = nullptr; 292030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 293030f984aSJacob Faibussowitsch } 294030f984aSJacob Faibussowitsch 29517f48955SJacob Faibussowitsch template <DeviceType T> 2969371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) { 2977a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 2980e6b6b59SJacob Faibussowitsch auto &event = dci->event; 299030f984aSJacob Faibussowitsch 300030f984aSJacob Faibussowitsch PetscFunctionBegin; 3010e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 3020e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.change_type(dctx->streamType)); 3030e6b6b59SJacob Faibussowitsch if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 304a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 305a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE; 306a4af0ceeSJacob Faibussowitsch #endif 307030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 308030f984aSJacob Faibussowitsch } 309030f984aSJacob Faibussowitsch 31017f48955SJacob Faibussowitsch template <DeviceType T> 3119371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) { 312030f984aSJacob Faibussowitsch PetscFunctionBegin; 3130e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 3140e6b6b59SJacob Faibussowitsch switch (const auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 3150e6b6b59SJacob Faibussowitsch case cupmSuccess: *idle = PETSC_TRUE; break; 3160e6b6b59SJacob Faibussowitsch case cupmErrorNotReady: *idle = PETSC_FALSE; break; 3170e6b6b59SJacob Faibussowitsch default: PetscCallCUPM(cerr); PetscUnreachable(); 318030f984aSJacob Faibussowitsch } 319030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 320030f984aSJacob Faibussowitsch } 321030f984aSJacob Faibussowitsch 32217f48955SJacob Faibussowitsch template <DeviceType T> 3239371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) { 3240e6b6b59SJacob Faibussowitsch const auto dcib = impls_cast_(dctxb); 3250e6b6b59SJacob Faibussowitsch const auto event = dcib->event; 326030f984aSJacob Faibussowitsch 327030f984aSJacob Faibussowitsch PetscFunctionBegin; 3280e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctxa, dctxb)); 3290e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 3300e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 331030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 332030f984aSJacob Faibussowitsch } 333030f984aSJacob Faibussowitsch 33417f48955SJacob Faibussowitsch template <DeviceType T> 3359371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) { 3360e6b6b59SJacob Faibussowitsch auto idle = PETSC_TRUE; 337030f984aSJacob Faibussowitsch 338030f984aSJacob Faibussowitsch PetscFunctionBegin; 3390e6b6b59SJacob Faibussowitsch PetscCall(query(dctx, &idle)); 3400e6b6b59SJacob Faibussowitsch if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 341030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 342030f984aSJacob Faibussowitsch } 343030f984aSJacob Faibussowitsch 34417f48955SJacob Faibussowitsch template <DeviceType T> 34517f48955SJacob Faibussowitsch template <typename handle_t> 3469371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) { 347a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 3487a101e5eSJacob Faibussowitsch PetscCall(initialize_handle_(handle_t{}, dctx)); 3497a101e5eSJacob Faibussowitsch *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 350a4af0ceeSJacob Faibussowitsch PetscFunctionReturn(0); 351a4af0ceeSJacob Faibussowitsch } 352a4af0ceeSJacob Faibussowitsch 35317f48955SJacob Faibussowitsch template <DeviceType T> 3549371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) { 3550e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 356a4af0ceeSJacob Faibussowitsch 357a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 3580e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 359a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 3605f80ce2aSJacob Faibussowitsch PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 361a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_TRUE; 362a4af0ceeSJacob Faibussowitsch #endif 36317f48955SJacob Faibussowitsch if (!dci->begin) { 3640e6b6b59SJacob Faibussowitsch PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 3659566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->begin)); 3669566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->end)); 36717f48955SJacob Faibussowitsch } 3680e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 369a4af0ceeSJacob Faibussowitsch PetscFunctionReturn(0); 370a4af0ceeSJacob Faibussowitsch } 371a4af0ceeSJacob Faibussowitsch 37217f48955SJacob Faibussowitsch template <DeviceType T> 3739371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) { 374a4af0ceeSJacob Faibussowitsch float gtime; 3750e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 3760e6b6b59SJacob Faibussowitsch const auto end = dci->end; 377a4af0ceeSJacob Faibussowitsch 378a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 3790e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 380a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 3815f80ce2aSJacob Faibussowitsch PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 382a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE; 383a4af0ceeSJacob Faibussowitsch #endif 3840e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 3850e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventSynchronize(end)); 3860e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 38717f48955SJacob Faibussowitsch *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 388a4af0ceeSJacob Faibussowitsch PetscFunctionReturn(0); 389a4af0ceeSJacob Faibussowitsch } 390a4af0ceeSJacob Faibussowitsch 3910e6b6b59SJacob Faibussowitsch template <DeviceType T> 392*6797ed33SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest)) { 3930e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream; 3940e6b6b59SJacob Faibussowitsch 3950e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 3960e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 3970e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "allocating")); 3980e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) { 399*6797ed33SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 4000e6b6b59SJacob Faibussowitsch } else { 401*6797ed33SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 4020e6b6b59SJacob Faibussowitsch } 403*6797ed33SJacob Faibussowitsch if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 4040e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4050e6b6b59SJacob Faibussowitsch } 4060e6b6b59SJacob Faibussowitsch 4070e6b6b59SJacob Faibussowitsch template <DeviceType T> 4080e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr)) { 4090e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream; 4100e6b6b59SJacob Faibussowitsch 4110e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4120e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 4130e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "freeing")); 4140e6b6b59SJacob Faibussowitsch if (!*ptr) PetscFunctionReturn(0); 4150e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) { 4160e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 4170e6b6b59SJacob Faibussowitsch // if ptr exists still exists the pool didn't own it 4180e6b6b59SJacob Faibussowitsch if (*ptr) { 4190e6b6b59SJacob Faibussowitsch auto registered = PETSC_FALSE, managed = PETSC_FALSE; 4200e6b6b59SJacob Faibussowitsch 4210e6b6b59SJacob Faibussowitsch PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 4220e6b6b59SJacob Faibussowitsch if (registered) { 4230e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeHost(*ptr)); 4240e6b6b59SJacob Faibussowitsch } else if (managed) { 4250e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 4260e6b6b59SJacob Faibussowitsch } 4270e6b6b59SJacob Faibussowitsch } 4280e6b6b59SJacob Faibussowitsch } else { 4290e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 4300e6b6b59SJacob Faibussowitsch // if ptr exists still exists the pool didn't own it 4310e6b6b59SJacob Faibussowitsch if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 4320e6b6b59SJacob Faibussowitsch } 4330e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4340e6b6b59SJacob Faibussowitsch } 4350e6b6b59SJacob Faibussowitsch 4360e6b6b59SJacob Faibussowitsch template <DeviceType T> 4370e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode)) { 4380e6b6b59SJacob Faibussowitsch const auto stream = impls_cast_(dctx)->stream.get_stream(); 4390e6b6b59SJacob Faibussowitsch 4400e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4410e6b6b59SJacob Faibussowitsch // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 4420e6b6b59SJacob Faibussowitsch if (mode == PETSC_DEVICE_COPY_HTOH) { 4430e6b6b59SJacob Faibussowitsch // yes this is faster 4440e6b6b59SJacob Faibussowitsch if (cupmStreamQuery(stream) == cupmSuccess) { 4450e6b6b59SJacob Faibussowitsch PetscCall(PetscMemcpy(dest, src, n)); 4460e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4470e6b6b59SJacob Faibussowitsch } 4480e6b6b59SJacob Faibussowitsch // in case cupmStreamQuery() did not return cupmErrorNotReady 4490e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmGetLastError()); 4500e6b6b59SJacob Faibussowitsch } 4510e6b6b59SJacob Faibussowitsch PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 4520e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4530e6b6b59SJacob Faibussowitsch } 4540e6b6b59SJacob Faibussowitsch 4550e6b6b59SJacob Faibussowitsch template <DeviceType T> 4560e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n)) { 4570e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4580e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 4590e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "zeroing")); 460*6797ed33SJacob Faibussowitsch PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 4610e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4620e6b6b59SJacob Faibussowitsch } 4630e6b6b59SJacob Faibussowitsch 4640e6b6b59SJacob Faibussowitsch template <DeviceType T> 4650e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event)) { 4660e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4670e6b6b59SJacob Faibussowitsch PetscCallCXX(event->data = new event_type()); 4680e6b6b59SJacob Faibussowitsch event->destroy = [](PetscEvent event) { 4690e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4700e6b6b59SJacob Faibussowitsch delete event_cast_(event); 4710e6b6b59SJacob Faibussowitsch event->data = nullptr; 4720e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4730e6b6b59SJacob Faibussowitsch }; 4740e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4750e6b6b59SJacob Faibussowitsch } 4760e6b6b59SJacob Faibussowitsch 4770e6b6b59SJacob Faibussowitsch template <DeviceType T> 4780e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event)) { 4790e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4800e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 4810e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4820e6b6b59SJacob Faibussowitsch } 4830e6b6b59SJacob Faibussowitsch 4840e6b6b59SJacob Faibussowitsch template <DeviceType T> 4850e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event)) { 4860e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4870e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 4880e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4890e6b6b59SJacob Faibussowitsch } 4900e6b6b59SJacob Faibussowitsch 491030f984aSJacob Faibussowitsch // initialize the static member variables 4929371c9d4SSatish Balay template <DeviceType T> 4939371c9d4SSatish Balay bool DeviceContext<T>::initialized_ = false; 494030f984aSJacob Faibussowitsch 49517f48955SJacob Faibussowitsch template <DeviceType T> 49617f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 497030f984aSJacob Faibussowitsch 49817f48955SJacob Faibussowitsch template <DeviceType T> 49917f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 50017f48955SJacob Faibussowitsch 5010e6b6b59SJacob Faibussowitsch } // namespace impl 502030f984aSJacob Faibussowitsch 503a4af0ceeSJacob Faibussowitsch // shorten this one up a bit (and instantiate the templates) 5040e6b6b59SJacob Faibussowitsch using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 5050e6b6b59SJacob Faibussowitsch using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 506030f984aSJacob Faibussowitsch 507030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name 5080e6b6b59SJacob Faibussowitsch #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 509030f984aSJacob Faibussowitsch 5100e6b6b59SJacob Faibussowitsch } // namespace cupm 51117f48955SJacob Faibussowitsch 5120e6b6b59SJacob Faibussowitsch } // namespace device 51317f48955SJacob Faibussowitsch 51417f48955SJacob Faibussowitsch } // namespace Petsc 515030f984aSJacob Faibussowitsch 5160e6b6b59SJacob Faibussowitsch #endif // __cplusplus 5170e6b6b59SJacob Faibussowitsch 518a4af0ceeSJacob Faibussowitsch #endif // PETSCDEVICECONTEXTCUDA_HPP 519