1 #if !defined(PETSCDEVICECONTEXTCUPM_HPP) 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupminterface.hpp> 6 7 #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11) 8 #error PetscDeviceContext backends for CUDA and HIP requires C++11 9 #endif 10 11 #include <array> 12 13 namespace Petsc 14 { 15 16 namespace detail 17 { 18 19 // for tag-based dispatch of handle retrieval 20 template <typename HT> struct HandleTag { }; 21 22 } // namespace detail 23 24 // Forward declare 25 template <CUPMDeviceType T> class CUPMContext; 26 27 template <CUPMDeviceType T> 28 class CUPMContext : CUPMInterface<T> 29 { 30 template <typename H> 31 using HandleTag = typename detail::HandleTag<H>; 32 33 public: 34 PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T) 35 36 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 37 // header, but since we are using the power of templates it must be declared part of 38 // this class to have easy access the same typedefs. Technically one can make a 39 // templated struct outside the class but it's more code for the same result. 40 struct PetscDeviceContext_IMPLS 41 { 42 cupmStream_t stream; 43 cupmEvent_t event; 44 cupmEvent_t begin; // timer-only 45 cupmEvent_t end; // timer-only 46 #if PetscDefined(USE_DEBUG) 47 PetscBool timerInUse; 48 #endif 49 cupmBlasHandle_t blas; 50 cupmSolverHandle_t solver; 51 52 PETSC_NODISCARD cupmBlasHandle_t handle(HandleTag<cupmBlasHandle_t>) { return blas; } 53 PETSC_NODISCARD cupmSolverHandle_t handle(HandleTag<cupmSolverHandle_t>) { return solver; } 54 }; 55 56 private: 57 static bool _initialized; 58 static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> _blashandles; 59 static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> _solverhandles; 60 61 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS* __impls_cast(PetscDeviceContext ptr) noexcept 62 { 63 return static_cast<PetscDeviceContext_IMPLS*>(ptr->data); 64 } 65 66 PETSC_NODISCARD static PetscErrorCode __finalize() noexcept 67 { 68 PetscErrorCode ierr; 69 70 PetscFunctionBegin; 71 for (auto&& handle : _blashandles) { 72 if (handle) {ierr = cupmInterface_t::DestroyHandle(handle);CHKERRQ(ierr);} 73 } 74 for (auto&& handle : _solverhandles) { 75 if (handle) {ierr = cupmInterface_t::DestroyHandle(handle);CHKERRQ(ierr);} 76 } 77 _initialized = false; 78 PetscFunctionReturn(0); 79 } 80 81 PETSC_NODISCARD static PetscErrorCode __initialize(PetscInt id, PetscDeviceContext_IMPLS *dci) noexcept 82 { 83 PetscErrorCode ierr; 84 85 PetscFunctionBegin; 86 ierr = PetscDeviceCheckDeviceCount_Internal(id);CHKERRQ(ierr); 87 if (!_initialized) { 88 _initialized = true; 89 ierr = PetscRegisterFinalize(__finalize);CHKERRQ(ierr); 90 } 91 // use the blashandle as a canary 92 if (!_blashandles[id]) { 93 ierr = cupmInterface_t::InitializeHandle(_blashandles[id]);CHKERRQ(ierr); 94 ierr = cupmInterface_t::InitializeHandle(_solverhandles[id]);CHKERRQ(ierr); 95 } 96 ierr = cupmInterface_t::SetHandleStream(_blashandles[id],dci->stream);CHKERRQ(ierr); 97 ierr = cupmInterface_t::SetHandleStream(_solverhandles[id],dci->stream);CHKERRQ(ierr); 98 dci->blas = _blashandles[id]; 99 dci->solver = _solverhandles[id]; 100 PetscFunctionReturn(0); 101 } 102 103 public: 104 const struct _DeviceContextOps ops = { 105 destroy, 106 changeStreamType, 107 setUp, 108 query, 109 waitForContext, 110 synchronize, 111 getHandle<cupmBlasHandle_t>, 112 getHandle<cupmSolverHandle_t>, 113 beginTimer, 114 endTimer 115 }; 116 117 // default constructor 118 constexpr CUPMContext() noexcept = default; 119 120 // All of these functions MUST be static in order to be callable from C, otherwise they 121 // get the implicit 'this' pointer tacked on 122 PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept; 123 PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept; 124 PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept; 125 PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept; 126 PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept; 127 PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 128 template <typename Handle_t> 129 PETSC_NODISCARD static PetscErrorCode getHandle(PetscDeviceContext,void*) noexcept; 130 PETSC_NODISCARD static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 131 PETSC_NODISCARD static PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*) noexcept; 132 }; 133 134 template <CUPMDeviceType T> 135 inline PetscErrorCode CUPMContext<T>::destroy(PetscDeviceContext dctx) noexcept 136 { 137 cupmError_t cerr; 138 PetscErrorCode ierr; 139 auto dci = __impls_cast(dctx); 140 141 PetscFunctionBegin; 142 if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} 143 if (dci->event) { 144 cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr); 145 cerr = cupmEventDestroy(dci->begin);CHKERRCUPM(cerr); 146 cerr = cupmEventDestroy(dci->end);CHKERRCUPM(cerr); 147 } 148 ierr = PetscFree(dctx->data);CHKERRQ(ierr); 149 PetscFunctionReturn(0); 150 } 151 152 template <CUPMDeviceType T> 153 inline PetscErrorCode CUPMContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 154 { 155 auto dci = __impls_cast(dctx); 156 157 PetscFunctionBegin; 158 if (dci->stream) { 159 cupmError_t cerr; 160 161 cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr); 162 dci->stream = nullptr; 163 } 164 // set these to null so they aren't usable until setup is called again 165 dci->blas = nullptr; 166 dci->solver = nullptr; 167 PetscFunctionReturn(0); 168 } 169 170 template <CUPMDeviceType T> 171 inline PetscErrorCode CUPMContext<T>::setUp(PetscDeviceContext dctx) noexcept 172 { 173 PetscErrorCode ierr; 174 cupmError_t cerr; 175 auto dci = __impls_cast(dctx); 176 177 PetscFunctionBegin; 178 if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} 179 switch (dctx->streamType) { 180 case PETSC_STREAM_GLOBAL_BLOCKING: 181 // don't create a stream for global blocking 182 dci->stream = nullptr; 183 break; 184 case PETSC_STREAM_DEFAULT_BLOCKING: 185 cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr); 186 break; 187 case PETSC_STREAM_GLOBAL_NONBLOCKING: 188 cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr); 189 break; 190 default: 191 SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[static_cast<int>(dctx->streamType)]); 192 break; 193 } 194 if (!dci->event) { 195 cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr); 196 cerr = cupmEventCreate(&dci->begin);CHKERRCUPM(cerr); 197 cerr = cupmEventCreate(&dci->end);CHKERRCUPM(cerr); 198 } 199 #if PetscDefined(USE_DEBUG) 200 dci->timerInUse = PETSC_FALSE; 201 #endif 202 ierr = __initialize(dctx->device->deviceId,dci);CHKERRQ(ierr); 203 PetscFunctionReturn(0); 204 } 205 206 template <CUPMDeviceType T> 207 inline PetscErrorCode CUPMContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 208 { 209 cupmError_t cerr; 210 211 PetscFunctionBegin; 212 cerr = cupmStreamQuery(__impls_cast(dctx)->stream); 213 if (cerr == cupmSuccess) *idle = PETSC_TRUE; 214 else { 215 // somethings gone wrong 216 if (PetscUnlikely(cerr != cupmErrorNotReady)) CHKERRCUPM(cerr); 217 *idle = PETSC_FALSE; 218 } 219 PetscFunctionReturn(0); 220 } 221 222 template <CUPMDeviceType T> 223 inline PetscErrorCode CUPMContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 224 { 225 cupmError_t cerr; 226 auto dcia = __impls_cast(dctxa),dcib = __impls_cast(dctxb); 227 228 PetscFunctionBegin; 229 cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr); 230 cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr); 231 PetscFunctionReturn(0); 232 } 233 234 template <CUPMDeviceType T> 235 inline PetscErrorCode CUPMContext<T>::synchronize(PetscDeviceContext dctx) noexcept 236 { 237 cupmError_t cerr; 238 auto dci = __impls_cast(dctx); 239 240 PetscFunctionBegin; 241 // in case anything was queued on the event 242 cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr); 243 cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr); 244 PetscFunctionReturn(0); 245 } 246 247 template <CUPMDeviceType T> 248 template <typename Handle_T> 249 inline PetscErrorCode CUPMContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 250 { 251 PetscFunctionBegin; 252 *static_cast<Handle_T*>(handle) = __impls_cast(dctx)->handle(HandleTag<Handle_T>()); 253 PetscFunctionReturn(0); 254 } 255 256 template <CUPMDeviceType T> 257 inline PetscErrorCode CUPMContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 258 { 259 auto dci = __impls_cast(dctx); 260 cupmError_t cerr; 261 262 PetscFunctionBegin; 263 #if PetscDefined(USE_DEBUG) 264 if (PetscUnlikely(dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?"); 265 dci->timerInUse = PETSC_TRUE; 266 #endif 267 cerr = cupmEventRecord(dci->begin,dci->stream);CHKERRCUPM(cerr); 268 PetscFunctionReturn(0); 269 } 270 271 template <CUPMDeviceType T> 272 inline PetscErrorCode CUPMContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 273 { 274 cupmError_t cerr; 275 float gtime; 276 auto dci = __impls_cast(dctx); 277 278 PetscFunctionBegin; 279 #if PetscDefined(USE_DEBUG) 280 if (PetscUnlikely(!dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?"); 281 dci->timerInUse = PETSC_FALSE; 282 #endif 283 cerr = cupmEventRecord(dci->end,dci->stream);CHKERRCUPM(cerr); 284 cerr = cupmEventSynchronize(dci->end);CHKERRCUPM(cerr); 285 cerr = cupmEventElapsedTime(>ime,dci->begin,dci->end);CHKERRCUPM(cerr); 286 *elapsed = static_cast<PetscLogDouble>(gtime); 287 PetscFunctionReturn(0); 288 } 289 290 // initialize the static member variables 291 template <CUPMDeviceType T> bool CUPMContext<T>::_initialized = false; 292 293 template <CUPMDeviceType T> 294 std::array<typename CUPMContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> CUPMContext<T>::_blashandles = {}; 295 296 template <CUPMDeviceType T> 297 std::array<typename CUPMContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> CUPMContext<T>::_solverhandles = {}; 298 299 // shorten this one up a bit (and instantiate the templates) 300 using CUPMContextCuda = CUPMContext<CUPMDeviceType::CUDA>; 301 using CUPMContextHip = CUPMContext<CUPMDeviceType::HIP>; 302 303 } // namespace Petsc 304 305 // shorthand for what is an EXTREMELY long name 306 #define PetscDeviceContext_(IMPLS) Petsc::CUPMContext<Petsc::CUPMDeviceType::IMPLS>::PetscDeviceContext_IMPLS 307 308 // shorthand for casting dctx->data to the appropriate object to access the handles 309 #define PDC_IMPLS_STATIC_CAST(IMPLS,obj) static_cast<PetscDeviceContext_(IMPLS) *>((obj)->data) 310 311 #endif // PETSCDEVICECONTEXTCUDA_HPP 312