117f48955SJacob Faibussowitsch #ifndef PETSCDEVICECONTEXTCUPM_HPP 2030f984aSJacob Faibussowitsch #define PETSCDEVICECONTEXTCUPM_HPP 3030f984aSJacob Faibussowitsch 4a4af0ceeSJacob Faibussowitsch #include <petsc/private/deviceimpl.h> 517f48955SJacob Faibussowitsch #include <petsc/private/cupmblasinterface.hpp> 67a101e5eSJacob Faibussowitsch #include <petsc/private/logimpl.h> 7030f984aSJacob Faibussowitsch 8*0e6b6b59SJacob Faibussowitsch #include <petsc/private/cpp/array.hpp> 9a4af0ceeSJacob Faibussowitsch 10*0e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp" 11*0e6b6b59SJacob Faibussowitsch #include "cupmallocator.hpp" 12*0e6b6b59SJacob Faibussowitsch #include "cupmstream.hpp" 13*0e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp" 14*0e6b6b59SJacob Faibussowitsch 15*0e6b6b59SJacob Faibussowitsch #if defined(__cplusplus) 169371c9d4SSatish Balay namespace Petsc { 17a4af0ceeSJacob Faibussowitsch 18*0e6b6b59SJacob Faibussowitsch namespace device { 1917f48955SJacob Faibussowitsch 20*0e6b6b59SJacob Faibussowitsch namespace cupm { 2117f48955SJacob Faibussowitsch 22*0e6b6b59SJacob Faibussowitsch namespace impl { 23030f984aSJacob Faibussowitsch 2417f48955SJacob Faibussowitsch template <DeviceType T> 25*0e6b6b59SJacob Faibussowitsch class DeviceContext : BlasInterface<T> { 2617f48955SJacob Faibussowitsch public: 2717f48955SJacob Faibussowitsch PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); 2817f48955SJacob Faibussowitsch 2917f48955SJacob Faibussowitsch private: 309371c9d4SSatish Balay template <typename H, std::size_t> 319371c9d4SSatish Balay struct HandleTag { 329371c9d4SSatish Balay using type = H; 339371c9d4SSatish Balay }; 34*0e6b6b59SJacob Faibussowitsch 357a101e5eSJacob Faibussowitsch using stream_tag = HandleTag<cupmStream_t, 0>; 367a101e5eSJacob Faibussowitsch using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 377a101e5eSJacob Faibussowitsch using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 38a4af0ceeSJacob Faibussowitsch 39*0e6b6b59SJacob Faibussowitsch using stream_type = CUPMStream<T>; 40*0e6b6b59SJacob Faibussowitsch using event_type = CUPMEvent<T>; 41*0e6b6b59SJacob Faibussowitsch 42030f984aSJacob Faibussowitsch public: 43030f984aSJacob Faibussowitsch // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 44030f984aSJacob Faibussowitsch // header, but since we are using the power of templates it must be declared part of 45030f984aSJacob Faibussowitsch // this class to have easy access the same typedefs. Technically one can make a 46030f984aSJacob Faibussowitsch // templated struct outside the class but it's more code for the same result. 47*0e6b6b59SJacob Faibussowitsch struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 48*0e6b6b59SJacob Faibussowitsch stream_type stream{}; 49*0e6b6b59SJacob Faibussowitsch cupmEvent_t event{}; 50*0e6b6b59SJacob Faibussowitsch cupmEvent_t begin{}; // timer-only 51*0e6b6b59SJacob Faibussowitsch cupmEvent_t end{}; // timer-only 52a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 53*0e6b6b59SJacob Faibussowitsch PetscBool timerInUse{}; 54a4af0ceeSJacob Faibussowitsch #endif 55*0e6b6b59SJacob Faibussowitsch cupmBlasHandle_t blas{}; 56*0e6b6b59SJacob Faibussowitsch cupmSolverHandle_t solver{}; 57a4af0ceeSJacob Faibussowitsch 58*0e6b6b59SJacob Faibussowitsch constexpr PetscDeviceContext_IMPLS() noexcept = default; 59*0e6b6b59SJacob Faibussowitsch 60*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { 61*0e6b6b59SJacob Faibussowitsch return this->stream.get_stream(); 629371c9d4SSatish Balay } 63*0e6b6b59SJacob Faibussowitsch 64*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { 659371c9d4SSatish Balay return this->blas; 669371c9d4SSatish Balay } 67*0e6b6b59SJacob Faibussowitsch 68*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { 699371c9d4SSatish Balay return this->solver; 709371c9d4SSatish Balay } 71030f984aSJacob Faibussowitsch }; 72030f984aSJacob Faibussowitsch 73030f984aSJacob Faibussowitsch private: 7417f48955SJacob Faibussowitsch static bool initialized_; 7517f48955SJacob Faibussowitsch static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 7617f48955SJacob Faibussowitsch static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 77030f984aSJacob Faibussowitsch 78*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { 79a4af0ceeSJacob Faibussowitsch return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); 80a4af0ceeSJacob Faibussowitsch } 81a4af0ceeSJacob Faibussowitsch 82*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { 83*0e6b6b59SJacob Faibussowitsch return static_cast<CUPMEvent<T> *>(event->data); 84*0e6b6b59SJacob Faibussowitsch } 85*0e6b6b59SJacob Faibussowitsch 86*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { 877a101e5eSJacob Faibussowitsch return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; 887a101e5eSJacob Faibussowitsch } 897a101e5eSJacob Faibussowitsch 90*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { 917a101e5eSJacob Faibussowitsch return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; 927a101e5eSJacob Faibussowitsch } 937a101e5eSJacob Faibussowitsch 947a101e5eSJacob Faibussowitsch // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 957a101e5eSJacob Faibussowitsch // handles 969371c9d4SSatish Balay PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { 979371c9d4SSatish Balay return 0; 989371c9d4SSatish Balay } 997a101e5eSJacob Faibussowitsch 100*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode create_handle_(cupmBlasHandle_t &handle) noexcept { 1017a101e5eSJacob Faibussowitsch PetscLogEvent event; 1027a101e5eSJacob Faibussowitsch 103030f984aSJacob Faibussowitsch PetscFunctionBegin; 1047a101e5eSJacob Faibussowitsch if (PetscLikely(handle)) PetscFunctionReturn(0); 1057a101e5eSJacob Faibussowitsch PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 1067a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 10717f48955SJacob Faibussowitsch for (auto i = 0; i < 3; ++i) { 10817f48955SJacob Faibussowitsch auto cberr = cupmBlasCreate(&handle); 10917f48955SJacob Faibussowitsch if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 1109566063dSJacob Faibussowitsch if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 11117f48955SJacob Faibussowitsch if (i != 2) { 1129566063dSJacob Faibussowitsch PetscCall(PetscSleep(3)); 11317f48955SJacob Faibussowitsch continue; 114a4af0ceeSJacob Faibussowitsch } 1155f80ce2aSJacob Faibussowitsch PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 116a4af0ceeSJacob Faibussowitsch } 1177a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 1187a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventResume_Internal(event)); 119030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 120030f984aSJacob Faibussowitsch } 121030f984aSJacob Faibussowitsch 122*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept { 1237a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 1247a101e5eSJacob Faibussowitsch auto &handle = blashandles_[dctx->device->deviceId]; 12517f48955SJacob Faibussowitsch 12617f48955SJacob Faibussowitsch PetscFunctionBegin; 1277a101e5eSJacob Faibussowitsch PetscCall(create_handle_(handle)); 128*0e6b6b59SJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 1297a101e5eSJacob Faibussowitsch dci->blas = handle; 1307a101e5eSJacob Faibussowitsch PetscFunctionReturn(0); 1317a101e5eSJacob Faibussowitsch } 1327a101e5eSJacob Faibussowitsch 1339371c9d4SSatish Balay PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle)) { 1347a101e5eSJacob Faibussowitsch PetscLogEvent event; 1357a101e5eSJacob Faibussowitsch 1367a101e5eSJacob Faibussowitsch PetscFunctionBegin; 1377a101e5eSJacob Faibussowitsch PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 1387a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 1397a101e5eSJacob Faibussowitsch PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 1407a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 1417a101e5eSJacob Faibussowitsch PetscCall(PetscLogEventResume_Internal(event)); 1427a101e5eSJacob Faibussowitsch PetscFunctionReturn(0); 1437a101e5eSJacob Faibussowitsch } 1447a101e5eSJacob Faibussowitsch 145*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept { 1467a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 1477a101e5eSJacob Faibussowitsch auto &handle = solverhandles_[dctx->device->deviceId]; 1487a101e5eSJacob Faibussowitsch 1497a101e5eSJacob Faibussowitsch PetscFunctionBegin; 1507a101e5eSJacob Faibussowitsch PetscCall(create_handle_(handle)); 151*0e6b6b59SJacob Faibussowitsch PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream())); 1527a101e5eSJacob Faibussowitsch dci->solver = handle; 15317f48955SJacob Faibussowitsch PetscFunctionReturn(0); 15417f48955SJacob Faibussowitsch } 15517f48955SJacob Faibussowitsch 156*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept { 157*0e6b6b59SJacob Faibussowitsch const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 158*0e6b6b59SJacob Faibussowitsch 159*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 160*0e6b6b59SJacob Faibussowitsch PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 161*0e6b6b59SJacob Faibussowitsch PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 162*0e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 163*0e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 164*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 165*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 166*0e6b6b59SJacob Faibussowitsch } 167*0e6b6b59SJacob Faibussowitsch 168*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { 169*0e6b6b59SJacob Faibussowitsch return check_current_device_(dctx, dctx); 170*0e6b6b59SJacob Faibussowitsch } 171*0e6b6b59SJacob Faibussowitsch 172*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode finalize_() noexcept { 17317f48955SJacob Faibussowitsch PetscFunctionBegin; 17417f48955SJacob Faibussowitsch for (auto &&handle : blashandles_) { 17517f48955SJacob Faibussowitsch if (handle) { 1769566063dSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 17717f48955SJacob Faibussowitsch handle = nullptr; 17817f48955SJacob Faibussowitsch } 17917f48955SJacob Faibussowitsch } 18017f48955SJacob Faibussowitsch for (auto &&handle : solverhandles_) { 18117f48955SJacob Faibussowitsch if (handle) { 1829566063dSJacob Faibussowitsch PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 18317f48955SJacob Faibussowitsch handle = nullptr; 18417f48955SJacob Faibussowitsch } 18517f48955SJacob Faibussowitsch } 18617f48955SJacob Faibussowitsch initialized_ = false; 18717f48955SJacob Faibussowitsch PetscFunctionReturn(0); 18817f48955SJacob Faibussowitsch } 18917f48955SJacob Faibussowitsch 190*0e6b6b59SJacob Faibussowitsch template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 191*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PoolType &default_pool_() noexcept { 192*0e6b6b59SJacob Faibussowitsch static PoolType pool; 193*0e6b6b59SJacob Faibussowitsch return pool; 194*0e6b6b59SJacob Faibussowitsch } 195030f984aSJacob Faibussowitsch 196*0e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept { 197*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 198*0e6b6b59SJacob Faibussowitsch PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 199*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 200*0e6b6b59SJacob Faibussowitsch } 201*0e6b6b59SJacob Faibussowitsch 202*0e6b6b59SJacob Faibussowitsch public: 203030f984aSJacob Faibussowitsch // All of these functions MUST be static in order to be callable from C, otherwise they 204030f984aSJacob Faibussowitsch // get the implicit 'this' pointer tacked on 20517f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 20617f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType)); 20717f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 20817f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *)); 20917f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext)); 21017f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 211a4af0ceeSJacob Faibussowitsch template <typename Handle_t> 21217f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *)); 21317f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 21417f48955SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *)); 215*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, void **)); 216*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **)); 217*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode)); 218*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t)); 219*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode createEvent(PetscDeviceContext, PetscEvent)); 220*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent)); 221*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent)); 2227a101e5eSJacob Faibussowitsch 2237a101e5eSJacob Faibussowitsch // not a PetscDeviceContext method, this registers the class 2247a101e5eSJacob Faibussowitsch PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize()); 225*0e6b6b59SJacob Faibussowitsch 226*0e6b6b59SJacob Faibussowitsch // clang-format off 227*0e6b6b59SJacob Faibussowitsch const _DeviceContextOps ops = { 228*0e6b6b59SJacob Faibussowitsch destroy, 229*0e6b6b59SJacob Faibussowitsch changeStreamType, 230*0e6b6b59SJacob Faibussowitsch setUp, 231*0e6b6b59SJacob Faibussowitsch query, 232*0e6b6b59SJacob Faibussowitsch waitForContext, 233*0e6b6b59SJacob Faibussowitsch synchronize, 234*0e6b6b59SJacob Faibussowitsch getHandle<blas_tag>, 235*0e6b6b59SJacob Faibussowitsch getHandle<solver_tag>, 236*0e6b6b59SJacob Faibussowitsch getHandle<stream_tag>, 237*0e6b6b59SJacob Faibussowitsch beginTimer, 238*0e6b6b59SJacob Faibussowitsch endTimer, 239*0e6b6b59SJacob Faibussowitsch memAlloc, 240*0e6b6b59SJacob Faibussowitsch memFree, 241*0e6b6b59SJacob Faibussowitsch memCopy, 242*0e6b6b59SJacob Faibussowitsch memSet, 243*0e6b6b59SJacob Faibussowitsch createEvent, 244*0e6b6b59SJacob Faibussowitsch recordEvent, 245*0e6b6b59SJacob Faibussowitsch waitForEvent 246*0e6b6b59SJacob Faibussowitsch }; 247*0e6b6b59SJacob Faibussowitsch // clang-format on 248030f984aSJacob Faibussowitsch }; 249030f984aSJacob Faibussowitsch 250*0e6b6b59SJacob Faibussowitsch // not a PetscDeviceContext method, this initializes the CLASS 25117f48955SJacob Faibussowitsch template <DeviceType T> 2529371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize()) { 2537a101e5eSJacob Faibussowitsch PetscFunctionBegin; 2547a101e5eSJacob Faibussowitsch if (PetscUnlikely(!initialized_)) { 255*0e6b6b59SJacob Faibussowitsch cupmMemPool_t mempool; 256*0e6b6b59SJacob Faibussowitsch uint64_t threshold = UINT64_MAX; 257*0e6b6b59SJacob Faibussowitsch 2587a101e5eSJacob Faibussowitsch initialized_ = true; 259*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmDeviceGetMemPool(&mempool, 0)); 260*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 261*0e6b6b59SJacob Faibussowitsch blashandles_.fill(nullptr); 262*0e6b6b59SJacob Faibussowitsch solverhandles_.fill(nullptr); 2637a101e5eSJacob Faibussowitsch PetscCall(PetscRegisterFinalize(finalize_)); 2647a101e5eSJacob Faibussowitsch } 2657a101e5eSJacob Faibussowitsch PetscFunctionReturn(0); 2667a101e5eSJacob Faibussowitsch } 2677a101e5eSJacob Faibussowitsch 2687a101e5eSJacob Faibussowitsch template <DeviceType T> 2699371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) { 270030f984aSJacob Faibussowitsch PetscFunctionBegin; 271*0e6b6b59SJacob Faibussowitsch if (const auto dci = impls_cast_(dctx)) { 272*0e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy()); 273*0e6b6b59SJacob Faibussowitsch if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event))); 2749566063dSJacob Faibussowitsch if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 2759566063dSJacob Faibussowitsch if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 276*0e6b6b59SJacob Faibussowitsch delete dci; 277*0e6b6b59SJacob Faibussowitsch dctx->data = nullptr; 278*0e6b6b59SJacob Faibussowitsch } 279030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 280030f984aSJacob Faibussowitsch } 281030f984aSJacob Faibussowitsch 28217f48955SJacob Faibussowitsch template <DeviceType T> 2839371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) { 2847a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 285030f984aSJacob Faibussowitsch 286030f984aSJacob Faibussowitsch PetscFunctionBegin; 287*0e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.destroy()); 288030f984aSJacob Faibussowitsch // set these to null so they aren't usable until setup is called again 289030f984aSJacob Faibussowitsch dci->blas = nullptr; 290030f984aSJacob Faibussowitsch dci->solver = nullptr; 291030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 292030f984aSJacob Faibussowitsch } 293030f984aSJacob Faibussowitsch 29417f48955SJacob Faibussowitsch template <DeviceType T> 2959371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) { 2967a101e5eSJacob Faibussowitsch const auto dci = impls_cast_(dctx); 297*0e6b6b59SJacob Faibussowitsch auto &event = dci->event; 298030f984aSJacob Faibussowitsch 299030f984aSJacob Faibussowitsch PetscFunctionBegin; 300*0e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 301*0e6b6b59SJacob Faibussowitsch PetscCall(dci->stream.change_type(dctx->streamType)); 302*0e6b6b59SJacob Faibussowitsch if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 303a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 304a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE; 305a4af0ceeSJacob Faibussowitsch #endif 306030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 307030f984aSJacob Faibussowitsch } 308030f984aSJacob Faibussowitsch 30917f48955SJacob Faibussowitsch template <DeviceType T> 3109371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) { 311030f984aSJacob Faibussowitsch PetscFunctionBegin; 312*0e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 313*0e6b6b59SJacob Faibussowitsch switch (const auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 314*0e6b6b59SJacob Faibussowitsch case cupmSuccess: *idle = PETSC_TRUE; break; 315*0e6b6b59SJacob Faibussowitsch case cupmErrorNotReady: *idle = PETSC_FALSE; break; 316*0e6b6b59SJacob Faibussowitsch default: PetscCallCUPM(cerr); PetscUnreachable(); 317030f984aSJacob Faibussowitsch } 318030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 319030f984aSJacob Faibussowitsch } 320030f984aSJacob Faibussowitsch 32117f48955SJacob Faibussowitsch template <DeviceType T> 3229371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) { 323*0e6b6b59SJacob Faibussowitsch const auto dcib = impls_cast_(dctxb); 324*0e6b6b59SJacob Faibussowitsch const auto event = dcib->event; 325030f984aSJacob Faibussowitsch 326030f984aSJacob Faibussowitsch PetscFunctionBegin; 327*0e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctxa, dctxb)); 328*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 329*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 330030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 331030f984aSJacob Faibussowitsch } 332030f984aSJacob Faibussowitsch 33317f48955SJacob Faibussowitsch template <DeviceType T> 3349371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) { 335*0e6b6b59SJacob Faibussowitsch auto idle = PETSC_TRUE; 336030f984aSJacob Faibussowitsch 337030f984aSJacob Faibussowitsch PetscFunctionBegin; 338*0e6b6b59SJacob Faibussowitsch PetscCall(query(dctx, &idle)); 339*0e6b6b59SJacob Faibussowitsch if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 340030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 341030f984aSJacob Faibussowitsch } 342030f984aSJacob Faibussowitsch 34317f48955SJacob Faibussowitsch template <DeviceType T> 34417f48955SJacob Faibussowitsch template <typename handle_t> 3459371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) { 346a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 3477a101e5eSJacob Faibussowitsch PetscCall(initialize_handle_(handle_t{}, dctx)); 3487a101e5eSJacob Faibussowitsch *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 349a4af0ceeSJacob Faibussowitsch PetscFunctionReturn(0); 350a4af0ceeSJacob Faibussowitsch } 351a4af0ceeSJacob Faibussowitsch 35217f48955SJacob Faibussowitsch template <DeviceType T> 3539371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) { 354*0e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 355a4af0ceeSJacob Faibussowitsch 356a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 357*0e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 358a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 3595f80ce2aSJacob Faibussowitsch PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 360a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_TRUE; 361a4af0ceeSJacob Faibussowitsch #endif 36217f48955SJacob Faibussowitsch if (!dci->begin) { 363*0e6b6b59SJacob Faibussowitsch PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 3649566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->begin)); 3659566063dSJacob Faibussowitsch PetscCallCUPM(cupmEventCreate(&dci->end)); 36617f48955SJacob Faibussowitsch } 367*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 368a4af0ceeSJacob Faibussowitsch PetscFunctionReturn(0); 369a4af0ceeSJacob Faibussowitsch } 370a4af0ceeSJacob Faibussowitsch 37117f48955SJacob Faibussowitsch template <DeviceType T> 3729371c9d4SSatish Balay PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) { 373a4af0ceeSJacob Faibussowitsch float gtime; 374*0e6b6b59SJacob Faibussowitsch const auto dci = impls_cast_(dctx); 375*0e6b6b59SJacob Faibussowitsch const auto end = dci->end; 376a4af0ceeSJacob Faibussowitsch 377a4af0ceeSJacob Faibussowitsch PetscFunctionBegin; 378*0e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 379a4af0ceeSJacob Faibussowitsch #if PetscDefined(USE_DEBUG) 3805f80ce2aSJacob Faibussowitsch PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 381a4af0ceeSJacob Faibussowitsch dci->timerInUse = PETSC_FALSE; 382a4af0ceeSJacob Faibussowitsch #endif 383*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 384*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventSynchronize(end)); 385*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 38617f48955SJacob Faibussowitsch *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 387a4af0ceeSJacob Faibussowitsch PetscFunctionReturn(0); 388a4af0ceeSJacob Faibussowitsch } 389a4af0ceeSJacob Faibussowitsch 390*0e6b6b59SJacob Faibussowitsch template <DeviceType T> 391*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, void **dest)) { 392*0e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream; 393*0e6b6b59SJacob Faibussowitsch 394*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 395*0e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 396*0e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "allocating")); 397*0e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) { 398*0e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream)); 399*0e6b6b59SJacob Faibussowitsch if (clear) std::memset(*dest, 0, n); 400*0e6b6b59SJacob Faibussowitsch } else { 401*0e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream)); 402*0e6b6b59SJacob Faibussowitsch if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 403*0e6b6b59SJacob Faibussowitsch } 404*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 405*0e6b6b59SJacob Faibussowitsch } 406*0e6b6b59SJacob Faibussowitsch 407*0e6b6b59SJacob Faibussowitsch template <DeviceType T> 408*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr)) { 409*0e6b6b59SJacob Faibussowitsch const auto &stream = impls_cast_(dctx)->stream; 410*0e6b6b59SJacob Faibussowitsch 411*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 412*0e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 413*0e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "freeing")); 414*0e6b6b59SJacob Faibussowitsch if (!*ptr) PetscFunctionReturn(0); 415*0e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) { 416*0e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 417*0e6b6b59SJacob Faibussowitsch // if ptr exists still exists the pool didn't own it 418*0e6b6b59SJacob Faibussowitsch if (*ptr) { 419*0e6b6b59SJacob Faibussowitsch auto registered = PETSC_FALSE, managed = PETSC_FALSE; 420*0e6b6b59SJacob Faibussowitsch 421*0e6b6b59SJacob Faibussowitsch PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 422*0e6b6b59SJacob Faibussowitsch if (registered) { 423*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeHost(*ptr)); 424*0e6b6b59SJacob Faibussowitsch } else if (managed) { 425*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 426*0e6b6b59SJacob Faibussowitsch } 427*0e6b6b59SJacob Faibussowitsch } 428*0e6b6b59SJacob Faibussowitsch } else { 429*0e6b6b59SJacob Faibussowitsch PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 430*0e6b6b59SJacob Faibussowitsch // if ptr exists still exists the pool didn't own it 431*0e6b6b59SJacob Faibussowitsch if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 432*0e6b6b59SJacob Faibussowitsch } 433*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 434*0e6b6b59SJacob Faibussowitsch } 435*0e6b6b59SJacob Faibussowitsch 436*0e6b6b59SJacob Faibussowitsch template <DeviceType T> 437*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode)) { 438*0e6b6b59SJacob Faibussowitsch const auto stream = impls_cast_(dctx)->stream.get_stream(); 439*0e6b6b59SJacob Faibussowitsch 440*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 441*0e6b6b59SJacob Faibussowitsch // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 442*0e6b6b59SJacob Faibussowitsch if (mode == PETSC_DEVICE_COPY_HTOH) { 443*0e6b6b59SJacob Faibussowitsch // yes this is faster 444*0e6b6b59SJacob Faibussowitsch if (cupmStreamQuery(stream) == cupmSuccess) { 445*0e6b6b59SJacob Faibussowitsch PetscCall(PetscMemcpy(dest, src, n)); 446*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 447*0e6b6b59SJacob Faibussowitsch } 448*0e6b6b59SJacob Faibussowitsch // in case cupmStreamQuery() did not return cupmErrorNotReady 449*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmGetLastError()); 450*0e6b6b59SJacob Faibussowitsch } 451*0e6b6b59SJacob Faibussowitsch PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 452*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 453*0e6b6b59SJacob Faibussowitsch } 454*0e6b6b59SJacob Faibussowitsch 455*0e6b6b59SJacob Faibussowitsch template <DeviceType T> 456*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n)) { 457*0e6b6b59SJacob Faibussowitsch auto vint = static_cast<int>(v); 458*0e6b6b59SJacob Faibussowitsch 459*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 460*0e6b6b59SJacob Faibussowitsch PetscCall(check_current_device_(dctx)); 461*0e6b6b59SJacob Faibussowitsch PetscCall(check_memtype_(mtype, "zeroing")); 462*0e6b6b59SJacob Faibussowitsch if (PetscMemTypeHost(mtype)) { 463*0e6b6b59SJacob Faibussowitsch // must call public sync to prune the dependency graph 464*0e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceContextSynchronize(dctx)); 465*0e6b6b59SJacob Faibussowitsch std::memset(ptr, vint, n); 466*0e6b6b59SJacob Faibussowitsch } else { 467*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmMemsetAsync(ptr, vint, n, impls_cast_(dctx)->stream.get_stream())); 468*0e6b6b59SJacob Faibussowitsch } 469*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 470*0e6b6b59SJacob Faibussowitsch } 471*0e6b6b59SJacob Faibussowitsch 472*0e6b6b59SJacob Faibussowitsch template <DeviceType T> 473*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event)) { 474*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 475*0e6b6b59SJacob Faibussowitsch PetscCallCXX(event->data = new event_type()); 476*0e6b6b59SJacob Faibussowitsch event->destroy = [](PetscEvent event) { 477*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 478*0e6b6b59SJacob Faibussowitsch delete event_cast_(event); 479*0e6b6b59SJacob Faibussowitsch event->data = nullptr; 480*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 481*0e6b6b59SJacob Faibussowitsch }; 482*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 483*0e6b6b59SJacob Faibussowitsch } 484*0e6b6b59SJacob Faibussowitsch 485*0e6b6b59SJacob Faibussowitsch template <DeviceType T> 486*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event)) { 487*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 488*0e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 489*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 490*0e6b6b59SJacob Faibussowitsch } 491*0e6b6b59SJacob Faibussowitsch 492*0e6b6b59SJacob Faibussowitsch template <DeviceType T> 493*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event)) { 494*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 495*0e6b6b59SJacob Faibussowitsch PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 496*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 497*0e6b6b59SJacob Faibussowitsch } 498*0e6b6b59SJacob Faibussowitsch 499030f984aSJacob Faibussowitsch // initialize the static member variables 5009371c9d4SSatish Balay template <DeviceType T> 5019371c9d4SSatish Balay bool DeviceContext<T>::initialized_ = false; 502030f984aSJacob Faibussowitsch 50317f48955SJacob Faibussowitsch template <DeviceType T> 50417f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 505030f984aSJacob Faibussowitsch 50617f48955SJacob Faibussowitsch template <DeviceType T> 50717f48955SJacob Faibussowitsch std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 50817f48955SJacob Faibussowitsch 509*0e6b6b59SJacob Faibussowitsch } // namespace impl 510030f984aSJacob Faibussowitsch 511a4af0ceeSJacob Faibussowitsch // shorten this one up a bit (and instantiate the templates) 512*0e6b6b59SJacob Faibussowitsch using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 513*0e6b6b59SJacob Faibussowitsch using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 514030f984aSJacob Faibussowitsch 515030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name 516*0e6b6b59SJacob Faibussowitsch #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 517030f984aSJacob Faibussowitsch 518*0e6b6b59SJacob Faibussowitsch } // namespace cupm 51917f48955SJacob Faibussowitsch 520*0e6b6b59SJacob Faibussowitsch } // namespace device 52117f48955SJacob Faibussowitsch 52217f48955SJacob Faibussowitsch } // namespace Petsc 523030f984aSJacob Faibussowitsch 524*0e6b6b59SJacob Faibussowitsch #endif // __cplusplus 525*0e6b6b59SJacob Faibussowitsch 526a4af0ceeSJacob Faibussowitsch #endif // PETSCDEVICECONTEXTCUDA_HPP 527