1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <array> 9 10 namespace Petsc 11 { 12 13 namespace Device 14 { 15 16 namespace CUPM 17 { 18 19 namespace Impl 20 { 21 22 // Forward declare 23 template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext; 24 25 template <DeviceType T> 26 class DeviceContext : Impl::BlasInterface<T> 27 { 28 public: 29 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T); 30 31 private: 32 // for tag-based dispatch of handle retrieval 33 template <typename H,std::size_t> struct HandleTag { using type = H; }; 34 using stream_tag = HandleTag<cupmStream_t,0>; 35 using blas_tag = HandleTag<cupmBlasHandle_t,1>; 36 using solver_tag = HandleTag<cupmSolverHandle_t,2>; 37 38 public: 39 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 40 // header, but since we are using the power of templates it must be declared part of 41 // this class to have easy access the same typedefs. Technically one can make a 42 // templated struct outside the class but it's more code for the same result. 43 struct PetscDeviceContext_IMPLS 44 { 45 cupmStream_t stream; 46 cupmEvent_t event; 47 cupmEvent_t begin; // timer-only 48 cupmEvent_t end; // timer-only 49 #if PetscDefined(USE_DEBUG) 50 PetscBool timerInUse; 51 #endif 52 cupmBlasHandle_t blas; 53 cupmSolverHandle_t solver; 54 55 PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; } 56 PETSC_NODISCARD auto get(blas_tag) const -> decltype(this->blas) { return this->blas; } 57 PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; } 58 }; 59 60 private: 61 static bool initialized_; 62 static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> blashandles_; 63 static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_; 64 65 PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr)) 66 { 67 return static_cast<PetscDeviceContext_IMPLS*>(ptr->data); 68 } 69 70 PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMBLAS_HANDLE_CREATE()) 71 { 72 return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; 73 } 74 75 PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMSOLVER_HANDLE_CREATE()) 76 { 77 return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; 78 } 79 80 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 81 // handles 82 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag,PetscDeviceContext)) { return 0; } 83 84 PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmBlasHandle_t &handle)) 85 { 86 PetscLogEvent event; 87 88 PetscFunctionBegin; 89 if (PetscLikely(handle)) PetscFunctionReturn(0); 90 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 91 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(),0,0,0,0)); 92 for (auto i = 0; i < 3; ++i) { 93 auto cberr = cupmBlasCreate(&handle); 94 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 95 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 96 if (i != 2) { 97 PetscCall(PetscSleep(3)); 98 continue; 99 } 100 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS,PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"Unable to initialize %s",cupmBlasName()); 101 } 102 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(),0,0,0,0)); 103 PetscCall(PetscLogEventResume_Internal(event)); 104 PetscFunctionReturn(0); 105 } 106 107 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx)) 108 { 109 const auto dci = impls_cast_(dctx); 110 auto& handle = blashandles_[dctx->device->deviceId]; 111 112 PetscFunctionBegin; 113 PetscCall(create_handle_(handle)); 114 PetscCallCUPMBLAS(cupmBlasSetStream(handle,dci->stream)); 115 dci->blas = handle; 116 PetscFunctionReturn(0); 117 } 118 119 PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle)) 120 { 121 PetscLogEvent event; 122 123 PetscFunctionBegin; 124 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 125 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(),0,0,0,0)); 126 PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 127 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(),0,0,0,0)); 128 PetscCall(PetscLogEventResume_Internal(event)); 129 PetscFunctionReturn(0); 130 } 131 132 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx)) 133 { 134 const auto dci = impls_cast_(dctx); 135 auto& handle = solverhandles_[dctx->device->deviceId]; 136 137 PetscFunctionBegin; 138 PetscCall(create_handle_(handle)); 139 PetscCall(cupmBlasInterface_t::SetHandleStream(handle,dci->stream)); 140 dci->solver = handle; 141 PetscFunctionReturn(0); 142 } 143 144 PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_()) 145 { 146 PetscFunctionBegin; 147 for (auto&& handle : blashandles_) { 148 if (handle) { 149 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 150 handle = nullptr; 151 } 152 } 153 for (auto&& handle : solverhandles_) { 154 if (handle) { 155 PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 156 handle = nullptr; 157 } 158 } 159 initialized_ = false; 160 PetscFunctionReturn(0); 161 } 162 163 public: 164 const struct _DeviceContextOps ops = { 165 destroy, 166 changeStreamType, 167 setUp, 168 query, 169 waitForContext, 170 synchronize, 171 getHandle<blas_tag>, 172 getHandle<solver_tag>, 173 getHandle<stream_tag>, 174 beginTimer, 175 endTimer, 176 }; 177 178 // All of these functions MUST be static in order to be callable from C, otherwise they 179 // get the implicit 'this' pointer tacked on 180 PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 181 PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType)); 182 PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 183 PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*)); 184 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext)); 185 PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 186 template <typename Handle_t> 187 PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*)); 188 PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 189 PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*)); 190 191 // not a PetscDeviceContext method, this registers the class 192 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize()); 193 }; 194 195 template <DeviceType T> 196 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize()) 197 { 198 PetscFunctionBegin; 199 if (PetscUnlikely(!initialized_)) { 200 initialized_ = true; 201 PetscCall(PetscRegisterFinalize(finalize_)); 202 } 203 PetscFunctionReturn(0); 204 } 205 206 template <DeviceType T> 207 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) 208 { 209 const auto dci = impls_cast_(dctx); 210 211 PetscFunctionBegin; 212 if (dci->stream) PetscCallCUPM(cupmStreamDestroy(dci->stream)); 213 if (dci->event) PetscCallCUPM(cupmEventDestroy(dci->event)); 214 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 215 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 216 PetscCall(PetscFree(dctx->data)); 217 PetscFunctionReturn(0); 218 } 219 220 template <DeviceType T> 221 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) 222 { 223 const auto dci = impls_cast_(dctx); 224 225 PetscFunctionBegin; 226 if (auto& stream = dci->stream) { 227 PetscCallCUPM(cupmStreamDestroy(stream)); 228 stream = nullptr; 229 } 230 // set these to null so they aren't usable until setup is called again 231 dci->blas = nullptr; 232 dci->solver = nullptr; 233 PetscFunctionReturn(0); 234 } 235 236 template <DeviceType T> 237 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) 238 { 239 const auto dci = impls_cast_(dctx); 240 auto& stream = dci->stream; 241 242 PetscFunctionBegin; 243 if (stream) { 244 PetscCallCUPM(cupmStreamDestroy(stream)); 245 stream = nullptr; 246 } 247 switch (const auto stype = dctx->streamType) { 248 case PETSC_STREAM_GLOBAL_BLOCKING: 249 // don't create a stream for global blocking 250 break; 251 case PETSC_STREAM_DEFAULT_BLOCKING: 252 PetscCallCUPM(cupmStreamCreate(&stream)); 253 break; 254 case PETSC_STREAM_GLOBAL_NONBLOCKING: 255 PetscCallCUPM(cupmStreamCreateWithFlags(&stream,cupmStreamNonBlocking)); 256 break; 257 default: 258 SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(stype)]); 259 break; 260 } 261 if (!dci->event) PetscCallCUPM(cupmEventCreate(&dci->event)); 262 #if PetscDefined(USE_DEBUG) 263 dci->timerInUse = PETSC_FALSE; 264 #endif 265 PetscFunctionReturn(0); 266 } 267 268 template <DeviceType T> 269 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) 270 { 271 cupmError_t cerr; 272 273 PetscFunctionBegin; 274 cerr = cupmStreamQuery(impls_cast_(dctx)->stream); 275 if (cerr == cupmSuccess) *idle = PETSC_TRUE; 276 else { 277 // somethings gone wrong 278 if (PetscUnlikely(cerr != cupmErrorNotReady)) PetscCallCUPM(cerr); 279 *idle = PETSC_FALSE; 280 } 281 PetscFunctionReturn(0); 282 } 283 284 template <DeviceType T> 285 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) 286 { 287 auto dcib = impls_cast_(dctxb); 288 289 PetscFunctionBegin; 290 PetscCallCUPM(cupmEventRecord(dcib->event,dcib->stream)); 291 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0)); 292 PetscFunctionReturn(0); 293 } 294 295 template <DeviceType T> 296 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) 297 { 298 auto dci = impls_cast_(dctx); 299 300 PetscFunctionBegin; 301 // in case anything was queued on the event 302 PetscCallCUPM(cupmStreamWaitEvent(dci->stream,dci->event,0)); 303 PetscCallCUPM(cupmStreamSynchronize(dci->stream)); 304 PetscFunctionReturn(0); 305 } 306 307 template <DeviceType T> 308 template <typename handle_t> 309 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) 310 { 311 PetscFunctionBegin; 312 PetscCall(initialize_handle_(handle_t{},dctx)); 313 *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t{}); 314 PetscFunctionReturn(0); 315 } 316 317 template <DeviceType T> 318 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) 319 { 320 auto dci = impls_cast_(dctx); 321 322 PetscFunctionBegin; 323 #if PetscDefined(USE_DEBUG) 324 PetscCheck(!dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?"); 325 dci->timerInUse = PETSC_TRUE; 326 #endif 327 if (!dci->begin) { 328 PetscCallCUPM(cupmEventCreate(&dci->begin)); 329 PetscCallCUPM(cupmEventCreate(&dci->end)); 330 } 331 PetscCallCUPM(cupmEventRecord(dci->begin,dci->stream)); 332 PetscFunctionReturn(0); 333 } 334 335 template <DeviceType T> 336 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) 337 { 338 float gtime; 339 auto dci = impls_cast_(dctx); 340 341 PetscFunctionBegin; 342 #if PetscDefined(USE_DEBUG) 343 PetscCheck(dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?"); 344 dci->timerInUse = PETSC_FALSE; 345 #endif 346 PetscCallCUPM(cupmEventRecord(dci->end,dci->stream)); 347 PetscCallCUPM(cupmEventSynchronize(dci->end)); 348 PetscCallCUPM(cupmEventElapsedTime(>ime,dci->begin,dci->end)); 349 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 350 PetscFunctionReturn(0); 351 } 352 353 // initialize the static member variables 354 template <DeviceType T> bool DeviceContext<T>::initialized_ = false; 355 356 template <DeviceType T> 357 std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 358 359 template <DeviceType T> 360 std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 361 362 } // namespace Impl 363 364 // shorten this one up a bit (and instantiate the templates) 365 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>; 366 using CUPMContextHip = Impl::DeviceContext<DeviceType::HIP>; 367 368 // shorthand for what is an EXTREMELY long name 369 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 370 371 } // namespace CUPM 372 373 } // namespace Device 374 375 } // namespace Petsc 376 377 #endif // PETSCDEVICECONTEXTCUDA_HPP 378