117f48955SJacob Faibussowitsch #ifndef PETSCDEVICECONTEXTCUPM_HPP 2030f984aSJacob Faibussowitsch #define PETSCDEVICECONTEXTCUPM_HPP 3030f984aSJacob Faibussowitsch 4a4af0ceeSJacob Faibussowitsch #include <petsc/private/deviceimpl.h> 517f48955SJacob Faibussowitsch #include <petsc/private/cupmblasinterface.hpp> 67a101e5eSJacob Faibussowitsch #include <petsc/private/logimpl.h> 7030f984aSJacob Faibussowitsch 80e6b6b59SJacob Faibussowitsch #include <petsc/private/cpp/array.hpp> 9a4af0ceeSJacob Faibussowitsch 100e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp" 110e6b6b59SJacob Faibussowitsch #include "cupmallocator.hpp" 120e6b6b59SJacob Faibussowitsch #include "cupmstream.hpp" 130e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp" 140e6b6b59SJacob Faibussowitsch 150e6b6b59SJacob Faibussowitsch #if defined(__cplusplus) 166797ed33SJacob Faibussowitsch 17d71ae5a4SJacob Faibussowitsch namespace Petsc 18d71ae5a4SJacob Faibussowitsch { 19a4af0ceeSJacob Faibussowitsch 20d71ae5a4SJacob Faibussowitsch namespace device 21d71ae5a4SJacob Faibussowitsch { 2217f48955SJacob Faibussowitsch 23d71ae5a4SJacob Faibussowitsch namespace cupm 24d71ae5a4SJacob Faibussowitsch { 2517f48955SJacob Faibussowitsch 26d71ae5a4SJacob Faibussowitsch namespace impl 27d71ae5a4SJacob Faibussowitsch { 28030f984aSJacob Faibussowitsch 2917f48955SJacob Faibussowitsch template <DeviceType T> 300e6b6b59SJacob Faibussowitsch class DeviceContext : BlasInterface<T> { 3117f48955SJacob Faibussowitsch public: 3217f48955SJacob Faibussowitsch PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); 3317f48955SJacob Faibussowitsch 3417f48955SJacob Faibussowitsch private: 359371c9d4SSatish Balay template <typename H, std::size_t> 369371c9d4SSatish Balay struct HandleTag { 379371c9d4SSatish Balay using type = H; 389371c9d4SSatish Balay }; 390e6b6b59SJacob Faibussowitsch 407a101e5eSJacob Faibussowitsch using stream_tag = HandleTag<cupmStream_t, 0>; 417a101e5eSJacob Faibussowitsch using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 427a101e5eSJacob Faibussowitsch using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 43a4af0ceeSJacob Faibussowitsch 440e6b6b59SJacob Faibussowitsch using stream_type = CUPMStream<T>; 450e6b6b59SJacob Faibussowitsch using event_type = CUPMEvent<T>; 460e6b6b59SJacob Faibussowitsch 47030f984aSJacob Faibussowitsch public: 48030f984aSJacob Faibussowitsch // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 49030f984aSJacob Faibussowitsch // header, but since we are using the power of templates it must be declared part of 50030f984aSJacob Faibussowitsch // this class to have easy access the same typedefs. Technically one can make a 51030f984aSJacob Faibussowitsch // templated struct outside the class but it's more code for the same result. 520e6b6b59SJacob Faibussowitsch struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 530e6b6b59SJacob Faibussowitsch stream_type stream{}; 540e6b6b59SJacob Faibussowitsch cupmEvent_t event{}; 550e6b6b59SJacob Faibussowitsch cupmEvent_t begin{}; // timer-only 560e6b6b59SJacob Faibussowitsch cupmEvent_t end{}; // timer-only 57a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 580e6b6b59SJacob Faibussowitsch PetscBool timerInUse{}; 59a4af0ceeSJacob Faibussowitsch #endif 600e6b6b59SJacob Faibussowitsch cupmBlasHandle_t blas{}; 610e6b6b59SJacob Faibussowitsch cupmSolverHandle_t solver{}; 62a4af0ceeSJacob Faibussowitsch 630e6b6b59SJacob Faibussowitsch constexpr PetscDeviceContext_IMPLS() noexcept = default; 640e6b6b59SJacob Faibussowitsch 65d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); } 660e6b6b59SJacob Faibussowitsch 67d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; } 680e6b6b59SJacob Faibussowitsch 69d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; } 70030f984aSJacob Faibussowitsch }; 71030f984aSJacob Faibussowitsch 72030f984aSJacob Faibussowitsch private: 7317f48955SJacob Faibussowitsch static bool initialized_; 746d54fb17SJacob Faibussowitsch 7517f48955SJacob Faibussowitsch static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 7617f48955SJacob Faibussowitsch static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 77030f984aSJacob Faibussowitsch 78d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 79a4af0ceeSJacob Faibussowitsch 80d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 810e6b6b59SJacob Faibussowitsch 82d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 837a101e5eSJacob Faibussowitsch 84d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 857a101e5eSJacob Faibussowitsch 867a101e5eSJacob Faibussowitsch // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 877a101e5eSJacob Faibussowitsch // handles 886d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return 0; } 897a101e5eSJacob Faibussowitsch 9047d993e7Ssuyashtn PETSC_NODISCARD static PetscErrorCode create_handle_(blas_tag, cupmBlasHandle_t &handle) noexcept 91d71ae5a4SJacob Faibussowitsch { 927a101e5eSJacob Faibussowitsch PetscLogEvent event; 937a101e5eSJacob Faibussowitsch 94030f984aSJacob Faibussowitsch PetscFunctionBegin; 957a101e5eSJacob Faibussowitsch if (PetscLikely(handle)) PetscFunctionReturn(0); 967a101e5eSJacob Faibussowitsch PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 977a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 9817f48955SJacob Faibussowitsch for (auto i = 0; i < 3; ++i) { 9917f48955SJacob Faibussowitsch auto cberr = cupmBlasCreate(&handle); 10017f48955SJacob Faibussowitsch if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 1019566063dSJacob Faibussowitsch if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 10217f48955SJacob Faibussowitsch if (i != 2) { 1039566063dSJacob Faibussowitsch PetscCall(PetscSleep(3)); 10417f48955SJacob Faibussowitsch continue; 105a4af0ceeSJacob Faibussowitsch } 1065f80ce2aSJacob Faibussowitsch PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 107a4af0ceeSJacob Faibussowitsch } 1087a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 1097a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventResume_Internal(event)); 110030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 111030f984aSJacob Faibussowitsch } 112030f984aSJacob Faibussowitsch 11347d993e7Ssuyashtn PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag tag, PetscDeviceContext dctx) noexcept 114d71ae5a4SJacob Faibussowitsch { 1157a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 1167a101e5eSJacob Faibussowitsch auto &handle = blashandles_[dctx->device->deviceId]; 11717f48955SJacob Faibussowitsch 11817f48955SJacob Faibussowitsch PetscFunctionBegin; 11947d993e7Ssuyashtn PetscCall(create_handle_(tag, handle)); 1200e6b6b59SJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 1217a101e5eSJacob Faibussowitsch dci->blas = handle; 1227a101e5eSJacob Faibussowitsch PetscFunctionReturn(0); 1237a101e5eSJacob Faibussowitsch } 1247a101e5eSJacob Faibussowitsch 1256d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept 126d71ae5a4SJacob Faibussowitsch { 1276d54fb17SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 1286d54fb17SJacob Faibussowitsch auto &handle = solverhandles_[dctx->device->deviceId]; 1297a101e5eSJacob Faibussowitsch PetscLogEvent event; 1307a101e5eSJacob Faibussowitsch 1317a101e5eSJacob Faibussowitsch PetscFunctionBegin; 1327a101e5eSJacob Faibussowitsch PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 1337a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 1347a101e5eSJacob Faibussowitsch PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 1357a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 1367a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventResume_Internal(event)); 1370e6b6b59SJacob Faibussowitsch PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream())); 1387a101e5eSJacob Faibussowitsch dci->solver = handle; 13917f48955SJacob Faibussowitsch PetscFunctionReturn(0); 14017f48955SJacob Faibussowitsch } 14117f48955SJacob Faibussowitsch 142d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 143d71ae5a4SJacob Faibussowitsch { 1440e6b6b59SJacob Faibussowitsch const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 1450e6b6b59SJacob Faibussowitsch 1460e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1470e6b6b59SJacob Faibussowitsch PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 1480e6b6b59SJacob Faibussowitsch PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 1490e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 1500e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 1510e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 1520e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 1530e6b6b59SJacob Faibussowitsch } 1540e6b6b59SJacob Faibussowitsch 155d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 1560e6b6b59SJacob Faibussowitsch 157d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode finalize_() noexcept 158d71ae5a4SJacob Faibussowitsch { 15917f48955SJacob Faibussowitsch PetscFunctionBegin; 16017f48955SJacob Faibussowitsch for (auto &&handle : blashandles_) { 16117f48955SJacob Faibussowitsch if (handle) { 1629566063dSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 16317f48955SJacob Faibussowitsch handle = nullptr; 16417f48955SJacob Faibussowitsch } 16517f48955SJacob Faibussowitsch } 1666d54fb17SJacob Faibussowitsch 16717f48955SJacob Faibussowitsch for (auto &&handle : solverhandles_) { 16817f48955SJacob Faibussowitsch if (handle) { 1699566063dSJacob Faibussowitsch PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 17017f48955SJacob Faibussowitsch handle = nullptr; 17117f48955SJacob Faibussowitsch } 17217f48955SJacob Faibussowitsch } 17317f48955SJacob Faibussowitsch initialized_ = false; 17417f48955SJacob Faibussowitsch PetscFunctionReturn(0); 17517f48955SJacob Faibussowitsch } 17617f48955SJacob Faibussowitsch 1770e6b6b59SJacob Faibussowitsch template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 178d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PoolType &default_pool_() noexcept 179d71ae5a4SJacob Faibussowitsch { 1800e6b6b59SJacob Faibussowitsch static PoolType pool; 1810e6b6b59SJacob Faibussowitsch return pool; 1820e6b6b59SJacob Faibussowitsch } 183030f984aSJacob Faibussowitsch 184d71ae5a4SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 185d71ae5a4SJacob Faibussowitsch { 1860e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1870e6b6b59SJacob Faibussowitsch PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 1880e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 1890e6b6b59SJacob Faibussowitsch } 1900e6b6b59SJacob Faibussowitsch 1910e6b6b59SJacob Faibussowitsch public: 192030f984aSJacob Faibussowitsch // All of these functions MUST be static in order to be callable from C, otherwise they 193030f984aSJacob Faibussowitsch // get the implicit 'this' pointer tacked on 1946d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept; 1956d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept; 1966d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept; 1976d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept; 1986d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept; 1996d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 200a4af0ceeSJacob Faibussowitsch template <typename Handle_t> 2016d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept; 2026d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 2036d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept; 2046d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept; 2056d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept; 2066d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept; 2076d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept; 2086d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept; 2096d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept; 2106d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept; 2117a101e5eSJacob Faibussowitsch 2127a101e5eSJacob Faibussowitsch // not a PetscDeviceContext method, this registers the class 2136d54fb17SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode initialize(PetscDevice) noexcept; 2140e6b6b59SJacob Faibussowitsch 2150e6b6b59SJacob Faibussowitsch // clang-format off 2160e6b6b59SJacob Faibussowitsch const _DeviceContextOps ops = { 2170e6b6b59SJacob Faibussowitsch destroy, 2180e6b6b59SJacob Faibussowitsch changeStreamType, 2190e6b6b59SJacob Faibussowitsch setUp, 2200e6b6b59SJacob Faibussowitsch query, 2210e6b6b59SJacob Faibussowitsch waitForContext, 2220e6b6b59SJacob Faibussowitsch synchronize, 2230e6b6b59SJacob Faibussowitsch getHandle<blas_tag>, 2240e6b6b59SJacob Faibussowitsch getHandle<solver_tag>, 2250e6b6b59SJacob Faibussowitsch getHandle<stream_tag>, 2260e6b6b59SJacob Faibussowitsch beginTimer, 2270e6b6b59SJacob Faibussowitsch endTimer, 2280e6b6b59SJacob Faibussowitsch memAlloc, 2290e6b6b59SJacob Faibussowitsch memFree, 2300e6b6b59SJacob Faibussowitsch memCopy, 2310e6b6b59SJacob Faibussowitsch memSet, 2320e6b6b59SJacob Faibussowitsch createEvent, 2330e6b6b59SJacob Faibussowitsch recordEvent, 2340e6b6b59SJacob Faibussowitsch waitForEvent 2350e6b6b59SJacob Faibussowitsch }; 2360e6b6b59SJacob Faibussowitsch // clang-format on 237030f984aSJacob Faibussowitsch }; 238030f984aSJacob Faibussowitsch 2390e6b6b59SJacob Faibussowitsch // not a PetscDeviceContext method, this initializes the CLASS 24017f48955SJacob Faibussowitsch template <DeviceType T> 2416d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept 242d71ae5a4SJacob Faibussowitsch { 2437a101e5eSJacob Faibussowitsch PetscFunctionBegin; 2447a101e5eSJacob Faibussowitsch if (PetscUnlikely(!initialized_)) { 2450e6b6b59SJacob Faibussowitsch uint64_t threshold = UINT64_MAX; 2466d54fb17SJacob Faibussowitsch cupmMemPool_t mempool; 2470e6b6b59SJacob Faibussowitsch 2487a101e5eSJacob Faibussowitsch initialized_ = true; 2496d54fb17SJacob Faibussowitsch PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId))); 2500e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 2510e6b6b59SJacob Faibussowitsch blashandles_.fill(nullptr); 2520e6b6b59SJacob Faibussowitsch solverhandles_.fill(nullptr); 2537a101e5eSJacob Faibussowitsch PetscCall(PetscRegisterFinalize(finalize_)); 2547a101e5eSJacob Faibussowitsch } 2557a101e5eSJacob Faibussowitsch PetscFunctionReturn(0); 2567a101e5eSJacob Faibussowitsch } 2577a101e5eSJacob Faibussowitsch 2587a101e5eSJacob Faibussowitsch template <DeviceType T> 2596d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept 260d71ae5a4SJacob Faibussowitsch { 261030f984aSJacob Faibussowitsch PetscFunctionBegin; 2620e6b6b59SJacob Faibussowitsch if (const auto dci = impls_cast_(dctx)) { 2630e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy()); 2640e6b6b59SJacob Faibussowitsch if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event))); 2659566063dSJacob Faibussowitsch if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 2669566063dSJacob Faibussowitsch if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 2670e6b6b59SJacob Faibussowitsch delete dci; 2680e6b6b59SJacob Faibussowitsch dctx->data = nullptr; 2690e6b6b59SJacob Faibussowitsch } 270030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 271030f984aSJacob Faibussowitsch } 272030f984aSJacob Faibussowitsch 27317f48955SJacob Faibussowitsch template <DeviceType T> 2746d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 275d71ae5a4SJacob Faibussowitsch { 2767a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 277030f984aSJacob Faibussowitsch 278030f984aSJacob Faibussowitsch PetscFunctionBegin; 2790e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy()); 280030f984aSJacob Faibussowitsch // set these to null so they aren't usable until setup is called again 281030f984aSJacob Faibussowitsch dci->blas = nullptr; 282030f984aSJacob Faibussowitsch dci->solver = nullptr; 283030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 284030f984aSJacob Faibussowitsch } 285030f984aSJacob Faibussowitsch 28617f48955SJacob Faibussowitsch template <DeviceType T> 2876d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept 288d71ae5a4SJacob Faibussowitsch { 2897a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 2900e6b6b59SJacob Faibussowitsch auto &event = dci->event; 291030f984aSJacob Faibussowitsch 292030f984aSJacob Faibussowitsch PetscFunctionBegin; 2930e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 2940e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.change_type(dctx->streamType)); 2950e6b6b59SJacob Faibussowitsch if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 296a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 297a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE; 298a4af0ceeSJacob Faibussowitsch #endif 299030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 300030f984aSJacob Faibussowitsch } 301030f984aSJacob Faibussowitsch 30217f48955SJacob Faibussowitsch template <DeviceType T> 3036d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 304d71ae5a4SJacob Faibussowitsch { 305030f984aSJacob Faibussowitsch PetscFunctionBegin; 3060e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 3074b955ea4SJacob Faibussowitsch switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 308d71ae5a4SJacob Faibussowitsch case cupmSuccess: 309d71ae5a4SJacob Faibussowitsch *idle = PETSC_TRUE; 310d71ae5a4SJacob Faibussowitsch break; 311d71ae5a4SJacob Faibussowitsch case cupmErrorNotReady: 312d71ae5a4SJacob Faibussowitsch *idle = PETSC_FALSE; 3134b955ea4SJacob Faibussowitsch // reset the error 3144b955ea4SJacob Faibussowitsch cerr = cupmGetLastError(); 3154b955ea4SJacob Faibussowitsch static_cast<void>(cerr); 316d71ae5a4SJacob Faibussowitsch break; 317d71ae5a4SJacob Faibussowitsch default: 318d71ae5a4SJacob Faibussowitsch PetscCallCUPM(cerr); 319d71ae5a4SJacob Faibussowitsch PetscUnreachable(); 320030f984aSJacob Faibussowitsch } 321030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 322030f984aSJacob Faibussowitsch } 323030f984aSJacob Faibussowitsch 32417f48955SJacob Faibussowitsch template <DeviceType T> 3256d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 326d71ae5a4SJacob Faibussowitsch { 3270e6b6b59SJacob Faibussowitsch const auto dcib = impls_cast_(dctxb); 3280e6b6b59SJacob Faibussowitsch const auto event = dcib->event; 329030f984aSJacob Faibussowitsch 330030f984aSJacob Faibussowitsch PetscFunctionBegin; 3310e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctxa, dctxb)); 3320e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 3330e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 334030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 335030f984aSJacob Faibussowitsch } 336030f984aSJacob Faibussowitsch 33717f48955SJacob Faibussowitsch template <DeviceType T> 3386d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept 339d71ae5a4SJacob Faibussowitsch { 3400e6b6b59SJacob Faibussowitsch auto idle = PETSC_TRUE; 341030f984aSJacob Faibussowitsch 342030f984aSJacob Faibussowitsch PetscFunctionBegin; 3430e6b6b59SJacob Faibussowitsch PetscCall(query(dctx, &idle)); 3440e6b6b59SJacob Faibussowitsch if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 345030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 346030f984aSJacob Faibussowitsch } 347030f984aSJacob Faibussowitsch 34817f48955SJacob Faibussowitsch template <DeviceType T> 34917f48955SJacob Faibussowitsch template <typename handle_t> 3506d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 351d71ae5a4SJacob Faibussowitsch { 352a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 3537a101e5eSJacob Faibussowitsch PetscCall(initialize_handle_(handle_t{}, dctx)); 3547a101e5eSJacob Faibussowitsch *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 355a4af0ceeSJacob Faibussowitsch PetscFunctionReturn(0); 356a4af0ceeSJacob Faibussowitsch } 357a4af0ceeSJacob Faibussowitsch 35817f48955SJacob Faibussowitsch template <DeviceType T> 3596d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 360d71ae5a4SJacob Faibussowitsch { 3610e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 362a4af0ceeSJacob Faibussowitsch 363a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 3640e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 365a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 3665f80ce2aSJacob Faibussowitsch PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 367a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_TRUE; 368a4af0ceeSJacob Faibussowitsch #endif 36917f48955SJacob Faibussowitsch if (!dci->begin) { 3700e6b6b59SJacob Faibussowitsch PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 3719566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->begin)); 3729566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->end)); 37317f48955SJacob Faibussowitsch } 3740e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 375a4af0ceeSJacob Faibussowitsch PetscFunctionReturn(0); 376a4af0ceeSJacob Faibussowitsch } 377a4af0ceeSJacob Faibussowitsch 37817f48955SJacob Faibussowitsch template <DeviceType T> 3796d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 380d71ae5a4SJacob Faibussowitsch { 381a4af0ceeSJacob Faibussowitsch float gtime; 3820e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 3830e6b6b59SJacob Faibussowitsch const auto end = dci->end; 384a4af0ceeSJacob Faibussowitsch 385a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 3860e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 387a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 3885f80ce2aSJacob Faibussowitsch PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 389a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE; 390a4af0ceeSJacob Faibussowitsch #endif 3910e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 3920e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventSynchronize(end)); 3930e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 39417f48955SJacob Faibussowitsch *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 395a4af0ceeSJacob Faibussowitsch PetscFunctionReturn(0); 396a4af0ceeSJacob Faibussowitsch } 397a4af0ceeSJacob Faibussowitsch 3980e6b6b59SJacob Faibussowitsch template <DeviceType T> 3996d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept 400d71ae5a4SJacob Faibussowitsch { 4010e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream; 4020e6b6b59SJacob Faibussowitsch 4030e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4040e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 4050e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "allocating")); 4060e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) { 4076797ed33SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 4080e6b6b59SJacob Faibussowitsch } else { 4096797ed33SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 4100e6b6b59SJacob Faibussowitsch } 4116797ed33SJacob Faibussowitsch if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 4120e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4130e6b6b59SJacob Faibussowitsch } 4140e6b6b59SJacob Faibussowitsch 4150e6b6b59SJacob Faibussowitsch template <DeviceType T> 4166d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept 417d71ae5a4SJacob Faibussowitsch { 4180e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream; 4190e6b6b59SJacob Faibussowitsch 4200e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4210e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 4220e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "freeing")); 4230e6b6b59SJacob Faibussowitsch if (!*ptr) PetscFunctionReturn(0); 4240e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) { 4250e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 4260e6b6b59SJacob Faibussowitsch // if ptr exists still exists the pool didn't own it 4270e6b6b59SJacob Faibussowitsch if (*ptr) { 4280e6b6b59SJacob Faibussowitsch auto registered = PETSC_FALSE, managed = PETSC_FALSE; 4290e6b6b59SJacob Faibussowitsch 4300e6b6b59SJacob Faibussowitsch PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 4310e6b6b59SJacob Faibussowitsch if (registered) { 4320e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeHost(*ptr)); 4330e6b6b59SJacob Faibussowitsch } else if (managed) { 4340e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 4350e6b6b59SJacob Faibussowitsch } 4360e6b6b59SJacob Faibussowitsch } 4370e6b6b59SJacob Faibussowitsch } else { 4380e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 4396d54fb17SJacob Faibussowitsch // if ptr still exists the pool didn't own it 4400e6b6b59SJacob Faibussowitsch if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 4410e6b6b59SJacob Faibussowitsch } 4420e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4430e6b6b59SJacob Faibussowitsch } 4440e6b6b59SJacob Faibussowitsch 4450e6b6b59SJacob Faibussowitsch template <DeviceType T> 4466d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept 447d71ae5a4SJacob Faibussowitsch { 4480e6b6b59SJacob Faibussowitsch const auto stream = impls_cast_(dctx)->stream.get_stream(); 4490e6b6b59SJacob Faibussowitsch 4500e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4510e6b6b59SJacob Faibussowitsch // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 4520e6b6b59SJacob Faibussowitsch if (mode == PETSC_DEVICE_COPY_HTOH) { 4536d54fb17SJacob Faibussowitsch const auto cerr = cupmStreamQuery(stream); 4546d54fb17SJacob Faibussowitsch 4550e6b6b59SJacob Faibussowitsch // yes this is faster 4566d54fb17SJacob Faibussowitsch if (cerr == cupmSuccess) { 4570e6b6b59SJacob Faibussowitsch PetscCall(PetscMemcpy(dest, src, n)); 4580e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4596d54fb17SJacob Faibussowitsch } else if (cerr == cupmErrorNotReady) { 4606d54fb17SJacob Faibussowitsch auto PETSC_UNUSED unused = cupmGetLastError(); 4616d54fb17SJacob Faibussowitsch 4626d54fb17SJacob Faibussowitsch static_cast<void>(unused); 4636d54fb17SJacob Faibussowitsch } else { 4646d54fb17SJacob Faibussowitsch PetscCallCUPM(cerr); 4650e6b6b59SJacob Faibussowitsch } 4660e6b6b59SJacob Faibussowitsch } 4670e6b6b59SJacob Faibussowitsch PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 4680e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4690e6b6b59SJacob Faibussowitsch } 4700e6b6b59SJacob Faibussowitsch 4710e6b6b59SJacob Faibussowitsch template <DeviceType T> 4726d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept 473d71ae5a4SJacob Faibussowitsch { 4740e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4750e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 4760e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "zeroing")); 4776797ed33SJacob Faibussowitsch PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 4780e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4790e6b6b59SJacob Faibussowitsch } 4800e6b6b59SJacob Faibussowitsch 4810e6b6b59SJacob Faibussowitsch template <DeviceType T> 482*8eb1d50fSPierre Jolivet inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept 483d71ae5a4SJacob Faibussowitsch { 4840e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4850e6b6b59SJacob Faibussowitsch PetscCallCXX(event->data = new event_type()); 4860e6b6b59SJacob Faibussowitsch event->destroy = [](PetscEvent event) { 4870e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4880e6b6b59SJacob Faibussowitsch delete event_cast_(event); 4890e6b6b59SJacob Faibussowitsch event->data = nullptr; 4900e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4910e6b6b59SJacob Faibussowitsch }; 4920e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 4930e6b6b59SJacob Faibussowitsch } 4940e6b6b59SJacob Faibussowitsch 4950e6b6b59SJacob Faibussowitsch template <DeviceType T> 4966d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 497d71ae5a4SJacob Faibussowitsch { 4980e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 4990e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 5000e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 5010e6b6b59SJacob Faibussowitsch } 5020e6b6b59SJacob Faibussowitsch 5030e6b6b59SJacob Faibussowitsch template <DeviceType T> 5046d54fb17SJacob Faibussowitsch inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 505d71ae5a4SJacob Faibussowitsch { 5060e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 5070e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 5080e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 5090e6b6b59SJacob Faibussowitsch } 5100e6b6b59SJacob Faibussowitsch 511030f984aSJacob Faibussowitsch // initialize the static member variables 5129371c9d4SSatish Balay template <DeviceType T> 5139371c9d4SSatish Balay bool DeviceContext<T>::initialized_ = false; 514030f984aSJacob Faibussowitsch 51517f48955SJacob Faibussowitsch template <DeviceType T> 51617f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 517030f984aSJacob Faibussowitsch 51817f48955SJacob Faibussowitsch template <DeviceType T> 51917f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 52017f48955SJacob Faibussowitsch 5210e6b6b59SJacob Faibussowitsch } // namespace impl 522030f984aSJacob Faibussowitsch 523a4af0ceeSJacob Faibussowitsch // shorten this one up a bit (and instantiate the templates) 5240e6b6b59SJacob Faibussowitsch using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 5250e6b6b59SJacob Faibussowitsch using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 526030f984aSJacob Faibussowitsch 527030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name 5280e6b6b59SJacob Faibussowitsch #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 529030f984aSJacob Faibussowitsch 5300e6b6b59SJacob Faibussowitsch } // namespace cupm 53117f48955SJacob Faibussowitsch 5320e6b6b59SJacob Faibussowitsch } // namespace device 53317f48955SJacob Faibussowitsch 53417f48955SJacob Faibussowitsch } // namespace Petsc 535030f984aSJacob Faibussowitsch 5360e6b6b59SJacob Faibussowitsch #endif // __cplusplus 5370e6b6b59SJacob Faibussowitsch 538a4af0ceeSJacob Faibussowitsch #endif // PETSCDEVICECONTEXTCUDA_HPP 539