#if !defined(PETSCDEVICECONTEXTCUPM_HPP) #define PETSCDEVICECONTEXTCUPM_HPP #include #include #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11) #error PetscDeviceContext backends for CUDA and HIP requires C++11 #endif #include namespace Petsc { namespace detail { // for tag-based dispatch of handle retrieval template struct HandleTag { }; } // namespace detail // Forward declare template class CUPMContext; template class CUPMContext : CUPMInterface { template using HandleTag = typename detail::HandleTag; public: PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T) // This is the canonical PETSc "impls" struct that normally resides in a standalone impls // header, but since we are using the power of templates it must be declared part of // this class to have easy access the same typedefs. Technically one can make a // templated struct outside the class but it's more code for the same result. struct PetscDeviceContext_IMPLS { cupmStream_t stream; cupmEvent_t event; cupmEvent_t begin; // timer-only cupmEvent_t end; // timer-only #if PetscDefined(USE_DEBUG) PetscBool timerInUse; #endif cupmBlasHandle_t blas; cupmSolverHandle_t solver; PETSC_NODISCARD cupmBlasHandle_t handle(HandleTag) { return blas; } PETSC_NODISCARD cupmSolverHandle_t handle(HandleTag) { return solver; } }; private: static bool _initialized; static std::array _blashandles; static std::array _solverhandles; PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS* __impls_cast(PetscDeviceContext ptr) noexcept { return static_cast(ptr->data); } PETSC_NODISCARD static PetscErrorCode __finalize() noexcept { PetscErrorCode ierr; PetscFunctionBegin; for (auto&& handle : _blashandles) { if (handle) {ierr = cupmInterface_t::DestroyHandle(handle);CHKERRQ(ierr);} } for (auto&& handle : _solverhandles) { if (handle) {ierr = cupmInterface_t::DestroyHandle(handle);CHKERRQ(ierr);} } _initialized = false; PetscFunctionReturn(0); } PETSC_NODISCARD static PetscErrorCode __initialize(PetscInt id, PetscDeviceContext_IMPLS *dci) noexcept { PetscErrorCode ierr; PetscFunctionBegin; ierr = PetscDeviceCheckDeviceCount_Internal(id);CHKERRQ(ierr); if (!_initialized) { _initialized = true; ierr = PetscRegisterFinalize(__finalize);CHKERRQ(ierr); } // use the blashandle as a canary if (!_blashandles[id]) { ierr = cupmInterface_t::InitializeHandle(_blashandles[id]);CHKERRQ(ierr); ierr = cupmInterface_t::InitializeHandle(_solverhandles[id]);CHKERRQ(ierr); } ierr = cupmInterface_t::SetHandleStream(_blashandles[id],dci->stream);CHKERRQ(ierr); ierr = cupmInterface_t::SetHandleStream(_solverhandles[id],dci->stream);CHKERRQ(ierr); dci->blas = _blashandles[id]; dci->solver = _solverhandles[id]; PetscFunctionReturn(0); } public: const struct _DeviceContextOps ops = { destroy, changeStreamType, setUp, query, waitForContext, synchronize, getHandle, getHandle, beginTimer, endTimer }; // default constructor constexpr CUPMContext() noexcept = default; // All of these functions MUST be static in order to be callable from C, otherwise they // get the implicit 'this' pointer tacked on PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept; PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept; PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept; PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept; PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept; PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept; template PETSC_NODISCARD static PetscErrorCode getHandle(PetscDeviceContext,void*) noexcept; PETSC_NODISCARD static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; PETSC_NODISCARD static PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*) noexcept; }; template inline PetscErrorCode CUPMContext::destroy(PetscDeviceContext dctx) noexcept { cupmError_t cerr; PetscErrorCode ierr; auto dci = __impls_cast(dctx); PetscFunctionBegin; if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} if (dci->event) { cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr); cerr = cupmEventDestroy(dci->begin);CHKERRCUPM(cerr); cerr = cupmEventDestroy(dci->end);CHKERRCUPM(cerr); } ierr = PetscFree(dctx->data);CHKERRQ(ierr); PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept { auto dci = __impls_cast(dctx); PetscFunctionBegin; if (dci->stream) { cupmError_t cerr; cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr); dci->stream = nullptr; } // set these to null so they aren't usable until setup is called again dci->blas = nullptr; dci->solver = nullptr; PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::setUp(PetscDeviceContext dctx) noexcept { PetscErrorCode ierr; cupmError_t cerr; auto dci = __impls_cast(dctx); PetscFunctionBegin; if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} switch (dctx->streamType) { case PETSC_STREAM_GLOBAL_BLOCKING: // don't create a stream for global blocking dci->stream = nullptr; break; case PETSC_STREAM_DEFAULT_BLOCKING: cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr); break; case PETSC_STREAM_GLOBAL_NONBLOCKING: cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr); break; default: SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[static_cast(dctx->streamType)]); break; } if (!dci->event) { cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr); cerr = cupmEventCreate(&dci->begin);CHKERRCUPM(cerr); cerr = cupmEventCreate(&dci->end);CHKERRCUPM(cerr); } #if PetscDefined(USE_DEBUG) dci->timerInUse = PETSC_FALSE; #endif ierr = __initialize(dctx->device->deviceId,dci);CHKERRQ(ierr); PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::query(PetscDeviceContext dctx, PetscBool *idle) noexcept { cupmError_t cerr; PetscFunctionBegin; cerr = cupmStreamQuery(__impls_cast(dctx)->stream); if (cerr == cupmSuccess) *idle = PETSC_TRUE; else { // somethings gone wrong if (PetscUnlikely(cerr != cupmErrorNotReady)) CHKERRCUPM(cerr); *idle = PETSC_FALSE; } PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept { cupmError_t cerr; auto dcia = __impls_cast(dctxa),dcib = __impls_cast(dctxb); PetscFunctionBegin; cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr); cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr); PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::synchronize(PetscDeviceContext dctx) noexcept { cupmError_t cerr; auto dci = __impls_cast(dctx); PetscFunctionBegin; // in case anything was queued on the event cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr); cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr); PetscFunctionReturn(0); } template template inline PetscErrorCode CUPMContext::getHandle(PetscDeviceContext dctx, void *handle) noexcept { PetscFunctionBegin; *static_cast(handle) = __impls_cast(dctx)->handle(HandleTag()); PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::beginTimer(PetscDeviceContext dctx) noexcept { auto dci = __impls_cast(dctx); cupmError_t cerr; PetscFunctionBegin; #if PetscDefined(USE_DEBUG) if (PetscUnlikely(dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?"); dci->timerInUse = PETSC_TRUE; #endif cerr = cupmEventRecord(dci->begin,dci->stream);CHKERRCUPM(cerr); PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept { cupmError_t cerr; float gtime; auto dci = __impls_cast(dctx); PetscFunctionBegin; #if PetscDefined(USE_DEBUG) if (PetscUnlikely(!dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?"); dci->timerInUse = PETSC_FALSE; #endif cerr = cupmEventRecord(dci->end,dci->stream);CHKERRCUPM(cerr); cerr = cupmEventSynchronize(dci->end);CHKERRCUPM(cerr); cerr = cupmEventElapsedTime(>ime,dci->begin,dci->end);CHKERRCUPM(cerr); *elapsed = static_cast(gtime); PetscFunctionReturn(0); } // initialize the static member variables template bool CUPMContext::_initialized = false; template std::array::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> CUPMContext::_blashandles = {}; template std::array::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> CUPMContext::_solverhandles = {}; // shorten this one up a bit (and instantiate the templates) using CUPMContextCuda = CUPMContext; using CUPMContextHip = CUPMContext; } // namespace Petsc // shorthand for what is an EXTREMELY long name #define PetscDeviceContext_(IMPLS) Petsc::CUPMContext::PetscDeviceContext_IMPLS // shorthand for casting dctx->data to the appropriate object to access the handles #define PDC_IMPLS_STATIC_CAST(IMPLS,obj) static_cast((obj)->data) #endif // PETSCDEVICECONTEXTCUDA_HPP