#ifndef PETSCDEVICECONTEXTCUPM_HPP #define PETSCDEVICECONTEXTCUPM_HPP #include #include #include #include namespace Petsc { namespace Device { namespace CUPM { namespace Impl { // Forward declare template class PETSC_VISIBILITY_INTERNAL DeviceContext; template class DeviceContext : Impl::BlasInterface { public: PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); private: // for tag-based dispatch of handle retrieval template struct HandleTag { using type = H; }; using stream_tag = HandleTag; using blas_tag = HandleTag; using solver_tag = HandleTag; public: // This is the canonical PETSc "impls" struct that normally resides in a standalone impls // header, but since we are using the power of templates it must be declared part of // this class to have easy access the same typedefs. Technically one can make a // templated struct outside the class but it's more code for the same result. struct PetscDeviceContext_IMPLS { cupmStream_t stream; cupmEvent_t event; cupmEvent_t begin; // timer-only cupmEvent_t end; // timer-only #if PetscDefined(USE_DEBUG) PetscBool timerInUse; #endif cupmBlasHandle_t blas; cupmSolverHandle_t solver; PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; } PETSC_NODISCARD auto get(blas_tag) const -> decltype(this->blas) { return this->blas; } PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; } }; private: static bool initialized_; static std::array blashandles_; static std::array solverhandles_; PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr)) { return static_cast(ptr->data); } PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMBLAS_HANDLE_CREATE()) { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMSOLVER_HANDLE_CREATE()) { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } // this exists purely to satisfy the compiler so the tag-based dispatch works for the other // handles PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { return 0; } PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmBlasHandle_t &handle)) { PetscLogEvent event; PetscFunctionBegin; if (PetscLikely(handle)) PetscFunctionReturn(0); PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); for (auto i = 0; i < 3; ++i) { auto cberr = cupmBlasCreate(&handle); if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); if (i != 2) { PetscCall(PetscSleep(3)); continue; } PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); } PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); PetscCall(PetscLogEventResume_Internal(event)); PetscFunctionReturn(0); } PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx)) { const auto dci = impls_cast_(dctx); auto &handle = blashandles_[dctx->device->deviceId]; PetscFunctionBegin; PetscCall(create_handle_(handle)); PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream)); dci->blas = handle; PetscFunctionReturn(0); } PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle)) { PetscLogEvent event; PetscFunctionBegin; PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); PetscCall(PetscLogEventResume_Internal(event)); PetscFunctionReturn(0); } PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx)) { const auto dci = impls_cast_(dctx); auto &handle = solverhandles_[dctx->device->deviceId]; PetscFunctionBegin; PetscCall(create_handle_(handle)); PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream)); dci->solver = handle; PetscFunctionReturn(0); } PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_()) { PetscFunctionBegin; for (auto &&handle : blashandles_) { if (handle) { PetscCallCUPMBLAS(cupmBlasDestroy(handle)); handle = nullptr; } } for (auto &&handle : solverhandles_) { if (handle) { PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); handle = nullptr; } } initialized_ = false; PetscFunctionReturn(0); } public: const struct _DeviceContextOps ops = { destroy, changeStreamType, setUp, query, waitForContext, synchronize, getHandle, getHandle, getHandle, beginTimer, endTimer, }; // All of these functions MUST be static in order to be callable from C, otherwise they // get the implicit 'this' pointer tacked on PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType)); PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *)); PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext)); PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); template PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *)); PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *)); // not a PetscDeviceContext method, this registers the class PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize()); }; template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::initialize()) { PetscFunctionBegin; if (PetscUnlikely(!initialized_)) { initialized_ = true; PetscCall(PetscRegisterFinalize(finalize_)); } PetscFunctionReturn(0); } template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::destroy(PetscDeviceContext dctx)) { const auto dci = impls_cast_(dctx); PetscFunctionBegin; if (dci->stream) PetscCallCUPM(cupmStreamDestroy(dci->stream)); if (dci->event) PetscCallCUPM(cupmEventDestroy(dci->event)); if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); PetscCall(PetscFree(dctx->data)); PetscFunctionReturn(0); } template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) { const auto dci = impls_cast_(dctx); PetscFunctionBegin; if (auto &stream = dci->stream) { PetscCallCUPM(cupmStreamDestroy(stream)); stream = nullptr; } // set these to null so they aren't usable until setup is called again dci->blas = nullptr; dci->solver = nullptr; PetscFunctionReturn(0); } template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::setUp(PetscDeviceContext dctx)) { const auto dci = impls_cast_(dctx); auto &stream = dci->stream; PetscFunctionBegin; if (stream) { PetscCallCUPM(cupmStreamDestroy(stream)); stream = nullptr; } switch (const auto stype = dctx->streamType) { case PETSC_STREAM_GLOBAL_BLOCKING: // don't create a stream for global blocking break; case PETSC_STREAM_DEFAULT_BLOCKING: PetscCallCUPM(cupmStreamCreate(&stream)); break; case PETSC_STREAM_GLOBAL_NONBLOCKING: PetscCallCUPM(cupmStreamCreateWithFlags(&stream, cupmStreamNonBlocking)); break; default: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_CORRUPT, "Invalid PetscStreamType %s", PetscStreamTypes[util::integral_value(stype)]); break; } if (!dci->event) PetscCallCUPM(cupmEventCreate(&dci->event)); #if PetscDefined(USE_DEBUG) dci->timerInUse = PETSC_FALSE; #endif PetscFunctionReturn(0); } template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::query(PetscDeviceContext dctx, PetscBool *idle)) { cupmError_t cerr; PetscFunctionBegin; cerr = cupmStreamQuery(impls_cast_(dctx)->stream); if (cerr == cupmSuccess) *idle = PETSC_TRUE; else { // somethings gone wrong if (PetscUnlikely(cerr != cupmErrorNotReady)) PetscCallCUPM(cerr); *idle = PETSC_FALSE; } PetscFunctionReturn(0); } template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) { auto dcib = impls_cast_(dctxb); PetscFunctionBegin; PetscCallCUPM(cupmEventRecord(dcib->event, dcib->stream)); PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream, dcib->event, 0)); PetscFunctionReturn(0); } template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::synchronize(PetscDeviceContext dctx)) { auto dci = impls_cast_(dctx); PetscFunctionBegin; // in case anything was queued on the event PetscCallCUPM(cupmStreamWaitEvent(dci->stream, dci->event, 0)); PetscCallCUPM(cupmStreamSynchronize(dci->stream)); PetscFunctionReturn(0); } template template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::getHandle(PetscDeviceContext dctx, void *handle)) { PetscFunctionBegin; PetscCall(initialize_handle_(handle_t{}, dctx)); *static_cast(handle) = impls_cast_(dctx)->get(handle_t{}); PetscFunctionReturn(0); } template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::beginTimer(PetscDeviceContext dctx)) { auto dci = impls_cast_(dctx); PetscFunctionBegin; #if PetscDefined(USE_DEBUG) PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); dci->timerInUse = PETSC_TRUE; #endif if (!dci->begin) { PetscCallCUPM(cupmEventCreate(&dci->begin)); PetscCallCUPM(cupmEventCreate(&dci->end)); } PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream)); PetscFunctionReturn(0); } template PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) { float gtime; auto dci = impls_cast_(dctx); PetscFunctionBegin; #if PetscDefined(USE_DEBUG) PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); dci->timerInUse = PETSC_FALSE; #endif PetscCallCUPM(cupmEventRecord(dci->end, dci->stream)); PetscCallCUPM(cupmEventSynchronize(dci->end)); PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, dci->end)); *elapsed = static_cast>(gtime); PetscFunctionReturn(0); } // initialize the static member variables template bool DeviceContext::initialized_ = false; template std::array::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext::blashandles_ = {}; template std::array::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext::solverhandles_ = {}; } // namespace Impl // shorten this one up a bit (and instantiate the templates) using CUPMContextCuda = Impl::DeviceContext; using CUPMContextHip = Impl::DeviceContext; // shorthand for what is an EXTREMELY long name #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext::PetscDeviceContext_IMPLS } // namespace CUPM } // namespace Device } // namespace Petsc #endif // PETSCDEVICECONTEXTCUDA_HPP