#if !defined(PETSCDEVICECONTEXTCUPM_HPP) #define PETSCDEVICECONTEXTCUPM_HPP #include /*I "petscdevice.h" I*/ #include #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11) #error PetscDeviceContext backends for CUDA and HIP requires C++11 #endif namespace Petsc { // Forward declare template class CUPMContext; template class CUPMContext : CUPMInterface { public: PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T); // This is the canonical PETSc "impls" struct that normally resides in a standalone impls // header, but since we are using the power of templates it must be declared part of // this class to have easy access the same typedefs. Technically one can make a // templated struct outside the class but it's more code for the same result. struct PetscDeviceContext_IMPLS { cupmStream_t stream; cupmEvent_t event; cupmBlasHandle_t blas; cupmSolverHandle_t solver; }; private: static cupmBlasHandle_t _blashandle; static cupmSolverHandle_t _solverhandle; PETSC_NODISCARD static PetscErrorCode __finalizeBLASHandle() noexcept { PetscErrorCode ierr; PetscFunctionBegin; ierr = cupmInterface_t::DestroyHandle(_blashandle);CHKERRQ(ierr); PetscFunctionReturn(0); } PETSC_NODISCARD static PetscErrorCode __finalizeSOLVERHandle() noexcept { PetscErrorCode ierr; PetscFunctionBegin; ierr = cupmInterface_t::DestroyHandle(_solverhandle);CHKERRQ(ierr); PetscFunctionReturn(0); } PETSC_NODISCARD static PetscErrorCode __setupHandles(PetscDeviceContext_IMPLS *dci) noexcept { PetscErrorCode ierr; PetscFunctionBegin; if (!_blashandle) { ierr = cupmInterface_t::InitializeHandle(_blashandle);CHKERRQ(ierr); ierr = PetscRegisterFinalize(__finalizeBLASHandle);CHKERRQ(ierr); } if (!_solverhandle) { ierr = cupmInterface_t::InitializeHandle(_solverhandle);CHKERRQ(ierr); ierr = PetscRegisterFinalize(__finalizeSOLVERHandle);CHKERRQ(ierr); } ierr = cupmInterface_t::SetHandleStream(_blashandle,dci->stream);CHKERRQ(ierr); ierr = cupmInterface_t::SetHandleStream(_solverhandle,dci->stream);CHKERRQ(ierr); dci->blas = _blashandle; dci->solver = _solverhandle; PetscFunctionReturn(0); } public: const struct _DeviceContextOps ops {destroy,changeStreamType,setUp,query,waitForContext,synchronize}; // default constructor constexpr CUPMContext() noexcept = default; // All of these functions MUST be static in order to be callable from C, otherwise they // get the implicit 'this' pointer tacked on PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept; PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept; PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept; PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept; PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept; PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept; }; #define IMPLS_RCAST_(obj_) static_cast((obj_)->data) template inline PetscErrorCode CUPMContext::destroy(PetscDeviceContext dctx) noexcept { PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); cupmError_t cerr; PetscErrorCode ierr; PetscFunctionBegin; if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} if (dci->event) {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);} ierr = PetscFree(dctx->data);CHKERRQ(ierr); PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::changeStreamType(PetscDeviceContext dctx, PetscStreamType stype) noexcept { PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); PetscFunctionBegin; if (dci->stream) { cupmError_t cerr; cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr); dci->stream = nullptr; } // set these to null so they aren't usable until setup is called again dci->blas = nullptr; dci->solver = nullptr; PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::setUp(PetscDeviceContext dctx) noexcept { PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); PetscErrorCode ierr; cupmError_t cerr; PetscFunctionBegin; if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} switch (dctx->streamType) { case PETSC_STREAM_GLOBAL_BLOCKING: // don't create a stream for global blocking dci->stream = nullptr; break; case PETSC_STREAM_DEFAULT_BLOCKING: cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr); break; case PETSC_STREAM_GLOBAL_NONBLOCKING: cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr); break; default: SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %d",dctx->streamType); break; } if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);} ierr = __setupHandles(dci);CHKERRQ(ierr); PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::query(PetscDeviceContext dctx, PetscBool *idle) noexcept { cupmError_t cerr; PetscFunctionBegin; cerr = cupmStreamQuery(IMPLS_RCAST_(dctx)->stream); if (cerr == cupmSuccess) *idle = PETSC_TRUE; else if (cerr == cupmErrorNotReady) { *idle = PETSC_FALSE; } else { // somethings gone wrong CHKERRCUPM(cerr); } PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept { PetscDeviceContext_IMPLS *dcia = IMPLS_RCAST_(dctxa); PetscDeviceContext_IMPLS *dcib = IMPLS_RCAST_(dctxb); cupmError_t cerr; PetscFunctionBegin; cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr); cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr); PetscFunctionReturn(0); } template inline PetscErrorCode CUPMContext::synchronize(PetscDeviceContext dctx) noexcept { PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); cupmError_t cerr; PetscFunctionBegin; // in case anything was queued on the event cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr); cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr); PetscFunctionReturn(0); } // initialize the static member variables template typename CUPMContext::cupmBlasHandle_t CUPMContext::_blashandle = nullptr; template typename CUPMContext::cupmSolverHandle_t CUPMContext::_solverhandle = nullptr; // shorten this one up a bit using CUPMContextCuda = CUPMContext; using CUPMContextHip = CUPMContext; // make sure these doesn't leak out #undef CHKERRCUPM #undef IMPLS_RCAST_ } // namespace Petsc // shorthand for what is an EXTREMELY long name #define PetscDeviceContext_(impls_) Petsc::CUPMContext::PetscDeviceContext_IMPLS // shorthand for casting dctx->data to the appropriate object to access the handles #define PDC_IMPLS_RCAST(impls_,obj_) reinterpret_cast((obj_)->data) #endif /* PETSCDEVICECONTEXTCUDA_HPP */