1 #if !defined(PETSCDEVICECONTEXTCUPM_HPP) 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> /*I "petscdevice.h" I*/ 5 #include <petsc/private/cupminterface.hpp> 6 7 #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11) 8 #error PetscDeviceContext backends for CUDA and HIP requires C++11 9 #endif 10 11 namespace Petsc { 12 13 // Forward declare 14 template <CUPMDeviceKind T> class CUPMContext; 15 16 template <CUPMDeviceKind T> 17 class CUPMContext : CUPMInterface<T> 18 { 19 public: 20 PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T); 21 22 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 23 // header, but since we are using the power of templates it must be declared part of 24 // this class to have easy access the same typedefs. Technically one can make a 25 // templated struct outside the class but it's more code for the same result. 26 struct PetscDeviceContext_IMPLS 27 { 28 cupmStream_t stream; 29 cupmEvent_t event; 30 cupmBlasHandle_t blas; 31 cupmSolverHandle_t solver; 32 }; 33 34 private: 35 static cupmBlasHandle_t _blashandle; 36 static cupmSolverHandle_t _solverhandle; 37 38 PETSC_NODISCARD static PetscErrorCode __finalizeBLASHandle() noexcept 39 { 40 PetscErrorCode ierr; 41 42 PetscFunctionBegin; 43 ierr = cupmInterface_t::DestroyHandle(_blashandle);CHKERRQ(ierr); 44 PetscFunctionReturn(0); 45 } 46 47 PETSC_NODISCARD static PetscErrorCode __finalizeSOLVERHandle() noexcept 48 { 49 PetscErrorCode ierr; 50 51 PetscFunctionBegin; 52 ierr = cupmInterface_t::DestroyHandle(_solverhandle);CHKERRQ(ierr); 53 PetscFunctionReturn(0); 54 } 55 56 PETSC_NODISCARD static PetscErrorCode __setupHandles(PetscDeviceContext_IMPLS *dci) noexcept 57 { 58 PetscErrorCode ierr; 59 60 PetscFunctionBegin; 61 if (!_blashandle) { 62 ierr = cupmInterface_t::InitializeHandle(_blashandle);CHKERRQ(ierr); 63 ierr = PetscRegisterFinalize(__finalizeBLASHandle);CHKERRQ(ierr); 64 } 65 if (!_solverhandle) { 66 ierr = cupmInterface_t::InitializeHandle(_solverhandle);CHKERRQ(ierr); 67 ierr = PetscRegisterFinalize(__finalizeSOLVERHandle);CHKERRQ(ierr); 68 } 69 ierr = cupmInterface_t::SetHandleStream(_blashandle,dci->stream);CHKERRQ(ierr); 70 ierr = cupmInterface_t::SetHandleStream(_solverhandle,dci->stream);CHKERRQ(ierr); 71 dci->blas = _blashandle; 72 dci->solver = _solverhandle; 73 PetscFunctionReturn(0); 74 } 75 76 public: 77 const struct _DeviceContextOps ops {destroy,changeStreamType,setUp,query,waitForContext,synchronize}; 78 79 // default constructor 80 constexpr CUPMContext() noexcept = default; 81 82 // All of these functions MUST be static in order to be callable from C, otherwise they 83 // get the implicit 'this' pointer tacked on 84 PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept; 85 PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept; 86 PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept; 87 PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept; 88 PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept; 89 PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 90 }; 91 92 #define IMPLS_RCAST_(obj_) static_cast<PetscDeviceContext_IMPLS*>((obj_)->data) 93 94 template <CUPMDeviceKind T> 95 inline PetscErrorCode CUPMContext<T>::destroy(PetscDeviceContext dctx) noexcept 96 { 97 PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); 98 cupmError_t cerr; 99 PetscErrorCode ierr; 100 101 PetscFunctionBegin; 102 if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} 103 if (dci->event) {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);} 104 ierr = PetscFree(dctx->data);CHKERRQ(ierr); 105 PetscFunctionReturn(0); 106 } 107 108 template <CUPMDeviceKind T> 109 inline PetscErrorCode CUPMContext<T>::changeStreamType(PetscDeviceContext dctx, PetscStreamType stype) noexcept 110 { 111 PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); 112 113 PetscFunctionBegin; 114 if (dci->stream) { 115 cupmError_t cerr; 116 117 cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr); 118 dci->stream = nullptr; 119 } 120 // set these to null so they aren't usable until setup is called again 121 dci->blas = nullptr; 122 dci->solver = nullptr; 123 PetscFunctionReturn(0); 124 } 125 126 template <CUPMDeviceKind T> 127 inline PetscErrorCode CUPMContext<T>::setUp(PetscDeviceContext dctx) noexcept 128 { 129 PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); 130 PetscErrorCode ierr; 131 cupmError_t cerr; 132 133 PetscFunctionBegin; 134 if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} 135 switch (dctx->streamType) { 136 case PETSC_STREAM_GLOBAL_BLOCKING: 137 // don't create a stream for global blocking 138 dci->stream = nullptr; 139 break; 140 case PETSC_STREAM_DEFAULT_BLOCKING: 141 cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr); 142 break; 143 case PETSC_STREAM_GLOBAL_NONBLOCKING: 144 cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr); 145 break; 146 default: 147 SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %d",dctx->streamType); 148 break; 149 } 150 if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);} 151 ierr = __setupHandles(dci);CHKERRQ(ierr); 152 PetscFunctionReturn(0); 153 } 154 155 template <CUPMDeviceKind T> 156 inline PetscErrorCode CUPMContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 157 { 158 cupmError_t cerr; 159 160 PetscFunctionBegin; 161 cerr = cupmStreamQuery(IMPLS_RCAST_(dctx)->stream); 162 if (cerr == cupmSuccess) 163 *idle = PETSC_TRUE; 164 else if (cerr == cupmErrorNotReady) { 165 *idle = PETSC_FALSE; 166 } else { 167 // somethings gone wrong 168 CHKERRCUPM(cerr); 169 } 170 PetscFunctionReturn(0); 171 } 172 173 template <CUPMDeviceKind T> 174 inline PetscErrorCode CUPMContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 175 { 176 PetscDeviceContext_IMPLS *dcia = IMPLS_RCAST_(dctxa); 177 PetscDeviceContext_IMPLS *dcib = IMPLS_RCAST_(dctxb); 178 cupmError_t cerr; 179 180 PetscFunctionBegin; 181 cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr); 182 cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr); 183 PetscFunctionReturn(0); 184 } 185 186 template <CUPMDeviceKind T> 187 inline PetscErrorCode CUPMContext<T>::synchronize(PetscDeviceContext dctx) noexcept 188 { 189 PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx); 190 cupmError_t cerr; 191 192 PetscFunctionBegin; 193 // in case anything was queued on the event 194 cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr); 195 cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr); 196 PetscFunctionReturn(0); 197 } 198 199 // initialize the static member variables 200 template <CUPMDeviceKind T> 201 typename CUPMContext<T>::cupmBlasHandle_t CUPMContext<T>::_blashandle = nullptr; 202 203 template <CUPMDeviceKind T> 204 typename CUPMContext<T>::cupmSolverHandle_t CUPMContext<T>::_solverhandle = nullptr; 205 206 // shorten this one up a bit 207 using CUPMContextCuda = CUPMContext<CUPMDeviceKind::CUDA>; 208 using CUPMContextHip = CUPMContext<CUPMDeviceKind::HIP>; 209 210 // make sure these doesn't leak out 211 #undef CHKERRCUPM 212 #undef IMPLS_RCAST_ 213 214 } // namespace Petsc 215 216 // shorthand for what is an EXTREMELY long name 217 #define PetscDeviceContext_(impls_) Petsc::CUPMContext<Petsc::CUPMDeviceKind::impls_>::PetscDeviceContext_IMPLS 218 219 // shorthand for casting dctx->data to the appropriate object to access the handles 220 #define PDC_IMPLS_RCAST(impls_,obj_) reinterpret_cast<PetscDeviceContext_(impls_) *>((obj_)->data) 221 222 #endif /* PETSCDEVICECONTEXTCUDA_HPP */ 223