1*030f984aSJacob Faibussowitsch #if !defined(PETSCDEVICECONTEXTCUPM_HPP) 2*030f984aSJacob Faibussowitsch #define PETSCDEVICECONTEXTCUPM_HPP 3*030f984aSJacob Faibussowitsch 4*030f984aSJacob Faibussowitsch #include <petsc/private/deviceimpl.h> /*I "petscdevice.h" I*/ 5*030f984aSJacob Faibussowitsch #include <petsc/private/cupminterface.hpp> 6*030f984aSJacob Faibussowitsch 7*030f984aSJacob Faibussowitsch #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11) 8*030f984aSJacob Faibussowitsch #error PetscDeviceContext backends for CUDA and HIP requires C++11 9*030f984aSJacob Faibussowitsch #endif 10*030f984aSJacob Faibussowitsch 11*030f984aSJacob Faibussowitsch namespace Petsc { 12*030f984aSJacob Faibussowitsch 13*030f984aSJacob Faibussowitsch // Forward declare 14*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> class CUPMContext; 15*030f984aSJacob Faibussowitsch 16*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> 17*030f984aSJacob Faibussowitsch class CUPMContext : CUPMInterface<T> 18*030f984aSJacob Faibussowitsch { 19*030f984aSJacob Faibussowitsch public: 20*030f984aSJacob Faibussowitsch PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T); 21*030f984aSJacob Faibussowitsch 22*030f984aSJacob Faibussowitsch // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 23*030f984aSJacob Faibussowitsch // header, but since we are using the power of templates it must be declared part of 24*030f984aSJacob Faibussowitsch // this class to have easy access the same typedefs. Technically one can make a 25*030f984aSJacob Faibussowitsch // templated struct outside the class but it's more code for the same result. 26*030f984aSJacob Faibussowitsch struct PetscDeviceContext_IMPLS 27*030f984aSJacob Faibussowitsch { 28*030f984aSJacob Faibussowitsch cupmStream_t stream; 29*030f984aSJacob Faibussowitsch cupmEvent_t event; 30*030f984aSJacob Faibussowitsch cupmBlasHandle_t blas; 31*030f984aSJacob Faibussowitsch cupmSolverHandle_t solver; 32*030f984aSJacob Faibussowitsch }; 33*030f984aSJacob Faibussowitsch 34*030f984aSJacob Faibussowitsch private: 35*030f984aSJacob Faibussowitsch static cupmBlasHandle_t _blashandle; 36*030f984aSJacob Faibussowitsch static cupmSolverHandle_t _solverhandle; 37*030f984aSJacob Faibussowitsch 38*030f984aSJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode __finalizeBLASHandle() noexcept 39*030f984aSJacob Faibussowitsch { 40*030f984aSJacob Faibussowitsch PetscErrorCode ierr; 41*030f984aSJacob Faibussowitsch 42*030f984aSJacob Faibussowitsch PetscFunctionBegin; 43*030f984aSJacob Faibussowitsch ierr = cupmInterface_t::DestroyHandle(_blashandle);CHKERRQ(ierr); 44*030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 45*030f984aSJacob Faibussowitsch } 46*030f984aSJacob Faibussowitsch 47*030f984aSJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode __finalizeSOLVERHandle() noexcept 48*030f984aSJacob Faibussowitsch { 49*030f984aSJacob Faibussowitsch PetscErrorCode ierr; 50*030f984aSJacob Faibussowitsch 51*030f984aSJacob Faibussowitsch PetscFunctionBegin; 52*030f984aSJacob Faibussowitsch ierr = cupmInterface_t::DestroyHandle(_solverhandle);CHKERRQ(ierr); 53*030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 54*030f984aSJacob Faibussowitsch } 55*030f984aSJacob Faibussowitsch 56*030f984aSJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode __setupHandles(PetscDeviceContext_IMPLS *dci) noexcept 57*030f984aSJacob Faibussowitsch { 58*030f984aSJacob Faibussowitsch PetscErrorCode ierr; 59*030f984aSJacob Faibussowitsch 60*030f984aSJacob Faibussowitsch PetscFunctionBegin; 61*030f984aSJacob Faibussowitsch if (!_blashandle) { 62*030f984aSJacob Faibussowitsch ierr = cupmInterface_t::InitializeHandle(_blashandle);CHKERRQ(ierr); 63*030f984aSJacob Faibussowitsch ierr = PetscRegisterFinalize(__finalizeBLASHandle);CHKERRQ(ierr); 64*030f984aSJacob Faibussowitsch } 65*030f984aSJacob Faibussowitsch if (!_solverhandle) { 66*030f984aSJacob Faibussowitsch ierr = cupmInterface_t::InitializeHandle(_solverhandle);CHKERRQ(ierr); 67*030f984aSJacob Faibussowitsch ierr = PetscRegisterFinalize(__finalizeSOLVERHandle);CHKERRQ(ierr); 68*030f984aSJacob Faibussowitsch } 69*030f984aSJacob Faibussowitsch ierr = cupmInterface_t::SetHandleStream(_blashandle,dci->stream);CHKERRQ(ierr); 70*030f984aSJacob Faibussowitsch ierr = cupmInterface_t::SetHandleStream(_solverhandle,dci->stream);CHKERRQ(ierr); 71*030f984aSJacob Faibussowitsch dci->blas = _blashandle; 72*030f984aSJacob Faibussowitsch dci->solver = _solverhandle; 73*030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 74*030f984aSJacob Faibussowitsch } 75*030f984aSJacob Faibussowitsch 76*030f984aSJacob Faibussowitsch public: 77*030f984aSJacob Faibussowitsch const struct _DeviceContextOps ops {destroy,changeStreamType,setUp,query,waitForContext,synchronize}; 78*030f984aSJacob Faibussowitsch 79*030f984aSJacob Faibussowitsch // default constructor 80*030f984aSJacob Faibussowitsch constexpr CUPMContext() noexcept = default; 81*030f984aSJacob Faibussowitsch 82*030f984aSJacob Faibussowitsch // All of these functions MUST be static in order to be callable from C, otherwise they 83*030f984aSJacob Faibussowitsch // get the implicit 'this' pointer tacked on 84*030f984aSJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept; 85*030f984aSJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept; 86*030f984aSJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept; 87*030f984aSJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept; 88*030f984aSJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept; 89*030f984aSJacob Faibussowitsch PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 90*030f984aSJacob Faibussowitsch }; 91*030f984aSJacob Faibussowitsch 92*030f984aSJacob Faibussowitsch #define IMPLS_RCAST_(obj_) static_cast<PetscDeviceContext_IMPLS*>((obj_)->data) 93*030f984aSJacob Faibussowitsch 94*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> 95*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::destroy(PetscDeviceContext dctx) noexcept 96*030f984aSJacob Faibussowitsch { 97*030f984aSJacob Faibussowitsch PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); 98*030f984aSJacob Faibussowitsch cupmError_t cerr; 99*030f984aSJacob Faibussowitsch PetscErrorCode ierr; 100*030f984aSJacob Faibussowitsch 101*030f984aSJacob Faibussowitsch PetscFunctionBegin; 102*030f984aSJacob Faibussowitsch if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} 103*030f984aSJacob Faibussowitsch if (dci->event) {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);} 104*030f984aSJacob Faibussowitsch ierr = PetscFree(dctx->data);CHKERRQ(ierr); 105*030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 106*030f984aSJacob Faibussowitsch } 107*030f984aSJacob Faibussowitsch 108*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> 109*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::changeStreamType(PetscDeviceContext dctx, PetscStreamType stype) noexcept 110*030f984aSJacob Faibussowitsch { 111*030f984aSJacob Faibussowitsch PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); 112*030f984aSJacob Faibussowitsch 113*030f984aSJacob Faibussowitsch PetscFunctionBegin; 114*030f984aSJacob Faibussowitsch if (dci->stream) { 115*030f984aSJacob Faibussowitsch cupmError_t cerr; 116*030f984aSJacob Faibussowitsch 117*030f984aSJacob Faibussowitsch cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr); 118*030f984aSJacob Faibussowitsch dci->stream = nullptr; 119*030f984aSJacob Faibussowitsch } 120*030f984aSJacob Faibussowitsch // set these to null so they aren't usable until setup is called again 121*030f984aSJacob Faibussowitsch dci->blas = nullptr; 122*030f984aSJacob Faibussowitsch dci->solver = nullptr; 123*030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 124*030f984aSJacob Faibussowitsch } 125*030f984aSJacob Faibussowitsch 126*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> 127*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::setUp(PetscDeviceContext dctx) noexcept 128*030f984aSJacob Faibussowitsch { 129*030f984aSJacob Faibussowitsch PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); 130*030f984aSJacob Faibussowitsch PetscErrorCode ierr; 131*030f984aSJacob Faibussowitsch cupmError_t cerr; 132*030f984aSJacob Faibussowitsch 133*030f984aSJacob Faibussowitsch PetscFunctionBegin; 134*030f984aSJacob Faibussowitsch if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} 135*030f984aSJacob Faibussowitsch switch (dctx->streamType) { 136*030f984aSJacob Faibussowitsch case PETSC_STREAM_GLOBAL_BLOCKING: 137*030f984aSJacob Faibussowitsch // don't create a stream for global blocking 138*030f984aSJacob Faibussowitsch dci->stream = nullptr; 139*030f984aSJacob Faibussowitsch break; 140*030f984aSJacob Faibussowitsch case PETSC_STREAM_DEFAULT_BLOCKING: 141*030f984aSJacob Faibussowitsch cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr); 142*030f984aSJacob Faibussowitsch break; 143*030f984aSJacob Faibussowitsch case PETSC_STREAM_GLOBAL_NONBLOCKING: 144*030f984aSJacob Faibussowitsch cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr); 145*030f984aSJacob Faibussowitsch break; 146*030f984aSJacob Faibussowitsch default: 147*030f984aSJacob Faibussowitsch SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %d",dctx->streamType); 148*030f984aSJacob Faibussowitsch break; 149*030f984aSJacob Faibussowitsch } 150*030f984aSJacob Faibussowitsch if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);} 151*030f984aSJacob Faibussowitsch ierr = __setupHandles(dci);CHKERRQ(ierr); 152*030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 153*030f984aSJacob Faibussowitsch } 154*030f984aSJacob Faibussowitsch 155*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> 156*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 157*030f984aSJacob Faibussowitsch { 158*030f984aSJacob Faibussowitsch cupmError_t cerr; 159*030f984aSJacob Faibussowitsch 160*030f984aSJacob Faibussowitsch PetscFunctionBegin; 161*030f984aSJacob Faibussowitsch cerr = cupmStreamQuery(IMPLS_RCAST_(dctx)->stream); 162*030f984aSJacob Faibussowitsch if (cerr == cupmSuccess) 163*030f984aSJacob Faibussowitsch *idle = PETSC_TRUE; 164*030f984aSJacob Faibussowitsch else if (cerr == cupmErrorNotReady) { 165*030f984aSJacob Faibussowitsch *idle = PETSC_FALSE; 166*030f984aSJacob Faibussowitsch } else { 167*030f984aSJacob Faibussowitsch // somethings gone wrong 168*030f984aSJacob Faibussowitsch CHKERRCUPM(cerr); 169*030f984aSJacob Faibussowitsch } 170*030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 171*030f984aSJacob Faibussowitsch } 172*030f984aSJacob Faibussowitsch 173*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> 174*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 175*030f984aSJacob Faibussowitsch { 176*030f984aSJacob Faibussowitsch PetscDeviceContext_IMPLS *dcia = IMPLS_RCAST_(dctxa); 177*030f984aSJacob Faibussowitsch PetscDeviceContext_IMPLS *dcib = IMPLS_RCAST_(dctxb); 178*030f984aSJacob Faibussowitsch cupmError_t cerr; 179*030f984aSJacob Faibussowitsch 180*030f984aSJacob Faibussowitsch PetscFunctionBegin; 181*030f984aSJacob Faibussowitsch cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr); 182*030f984aSJacob Faibussowitsch cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr); 183*030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 184*030f984aSJacob Faibussowitsch } 185*030f984aSJacob Faibussowitsch 186*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> 187*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::synchronize(PetscDeviceContext dctx) noexcept 188*030f984aSJacob Faibussowitsch { 189*030f984aSJacob Faibussowitsch PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); 190*030f984aSJacob Faibussowitsch cupmError_t cerr; 191*030f984aSJacob Faibussowitsch 192*030f984aSJacob Faibussowitsch PetscFunctionBegin; 193*030f984aSJacob Faibussowitsch // in case anything was queued on the event 194*030f984aSJacob Faibussowitsch cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr); 195*030f984aSJacob Faibussowitsch cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr); 196*030f984aSJacob Faibussowitsch PetscFunctionReturn(0); 197*030f984aSJacob Faibussowitsch } 198*030f984aSJacob Faibussowitsch 199*030f984aSJacob Faibussowitsch // initialize the static member variables 200*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> 201*030f984aSJacob Faibussowitsch typename CUPMContext<T>::cupmBlasHandle_t CUPMContext<T>::_blashandle = nullptr; 202*030f984aSJacob Faibussowitsch 203*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> 204*030f984aSJacob Faibussowitsch typename CUPMContext<T>::cupmSolverHandle_t CUPMContext<T>::_solverhandle = nullptr; 205*030f984aSJacob Faibussowitsch 206*030f984aSJacob Faibussowitsch // shorten this one up a bit 207*030f984aSJacob Faibussowitsch using CUPMContextCuda = CUPMContext<CUPMDeviceKind::CUDA>; 208*030f984aSJacob Faibussowitsch using CUPMContextHip = CUPMContext<CUPMDeviceKind::HIP>; 209*030f984aSJacob Faibussowitsch 210*030f984aSJacob Faibussowitsch // make sure these doesn't leak out 211*030f984aSJacob Faibussowitsch #undef CHKERRCUPM 212*030f984aSJacob Faibussowitsch #undef IMPLS_RCAST_ 213*030f984aSJacob Faibussowitsch 214*030f984aSJacob Faibussowitsch } // namespace Petsc 215*030f984aSJacob Faibussowitsch 216*030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name 217*030f984aSJacob Faibussowitsch #define PetscDeviceContext_(impls_) Petsc::CUPMContext<Petsc::CUPMDeviceKind::impls_>::PetscDeviceContext_IMPLS 218*030f984aSJacob Faibussowitsch 219*030f984aSJacob Faibussowitsch // shorthand for casting dctx->data to the appropriate object to access the handles 220*030f984aSJacob Faibussowitsch #define PDC_IMPLS_RCAST(impls_,obj_) reinterpret_cast<PetscDeviceContext_(impls_) *>((obj_)->data) 221*030f984aSJacob Faibussowitsch 222*030f984aSJacob Faibussowitsch #endif /* PETSCDEVICECONTEXTCUDA_HPP */ 223