1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <array> 9 10 namespace Petsc { 11 12 namespace Device { 13 14 namespace CUPM { 15 16 namespace Impl { 17 18 // Forward declare 19 template <DeviceType T> 20 class PETSC_VISIBILITY_INTERNAL DeviceContext; 21 22 template <DeviceType T> 23 class DeviceContext : Impl::BlasInterface<T> { 24 public: 25 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); 26 27 private: 28 // for tag-based dispatch of handle retrieval 29 template <typename H, std::size_t> 30 struct HandleTag { 31 using type = H; 32 }; 33 using stream_tag = HandleTag<cupmStream_t, 0>; 34 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 35 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 36 37 public: 38 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 39 // header, but since we are using the power of templates it must be declared part of 40 // this class to have easy access the same typedefs. Technically one can make a 41 // templated struct outside the class but it's more code for the same result. 42 struct PetscDeviceContext_IMPLS { 43 cupmStream_t stream; 44 cupmEvent_t event; 45 cupmEvent_t begin; // timer-only 46 cupmEvent_t end; // timer-only 47 #if PetscDefined(USE_DEBUG) 48 PetscBool timerInUse; 49 #endif 50 cupmBlasHandle_t blas; 51 cupmSolverHandle_t solver; 52 53 PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { 54 return this->stream; 55 } 56 PETSC_NODISCARD auto get(blas_tag) const -> decltype(this->blas) { 57 return this->blas; 58 } 59 PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { 60 return this->solver; 61 } 62 }; 63 64 private: 65 static bool initialized_; 66 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 67 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 68 69 PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr)) { 70 return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); 71 } 72 73 PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMBLAS_HANDLE_CREATE()) { 74 return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; 75 } 76 77 PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMSOLVER_HANDLE_CREATE()) { 78 return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; 79 } 80 81 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 82 // handles 83 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { 84 return 0; 85 } 86 87 PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmBlasHandle_t &handle)) { 88 PetscLogEvent event; 89 90 PetscFunctionBegin; 91 if (PetscLikely(handle)) PetscFunctionReturn(0); 92 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 93 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 94 for (auto i = 0; i < 3; ++i) { 95 auto cberr = cupmBlasCreate(&handle); 96 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 97 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 98 if (i != 2) { 99 PetscCall(PetscSleep(3)); 100 continue; 101 } 102 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 103 } 104 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 105 PetscCall(PetscLogEventResume_Internal(event)); 106 PetscFunctionReturn(0); 107 } 108 109 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx)) { 110 const auto dci = impls_cast_(dctx); 111 auto &handle = blashandles_[dctx->device->deviceId]; 112 113 PetscFunctionBegin; 114 PetscCall(create_handle_(handle)); 115 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream)); 116 dci->blas = handle; 117 PetscFunctionReturn(0); 118 } 119 120 PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle)) { 121 PetscLogEvent event; 122 123 PetscFunctionBegin; 124 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 125 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 126 PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 127 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 128 PetscCall(PetscLogEventResume_Internal(event)); 129 PetscFunctionReturn(0); 130 } 131 132 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx)) { 133 const auto dci = impls_cast_(dctx); 134 auto &handle = solverhandles_[dctx->device->deviceId]; 135 136 PetscFunctionBegin; 137 PetscCall(create_handle_(handle)); 138 PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream)); 139 dci->solver = handle; 140 PetscFunctionReturn(0); 141 } 142 143 PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_()) { 144 PetscFunctionBegin; 145 for (auto &&handle : blashandles_) { 146 if (handle) { 147 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 148 handle = nullptr; 149 } 150 } 151 for (auto &&handle : solverhandles_) { 152 if (handle) { 153 PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 154 handle = nullptr; 155 } 156 } 157 initialized_ = false; 158 PetscFunctionReturn(0); 159 } 160 161 public: 162 const struct _DeviceContextOps ops = { 163 destroy, changeStreamType, setUp, query, waitForContext, synchronize, getHandle<blas_tag>, getHandle<solver_tag>, getHandle<stream_tag>, beginTimer, endTimer, 164 }; 165 166 // All of these functions MUST be static in order to be callable from C, otherwise they 167 // get the implicit 'this' pointer tacked on 168 PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 169 PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType)); 170 PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 171 PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *)); 172 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext)); 173 PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 174 template <typename Handle_t> 175 PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *)); 176 PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 177 PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *)); 178 179 // not a PetscDeviceContext method, this registers the class 180 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize()); 181 }; 182 183 template <DeviceType T> 184 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize()) { 185 PetscFunctionBegin; 186 if (PetscUnlikely(!initialized_)) { 187 initialized_ = true; 188 PetscCall(PetscRegisterFinalize(finalize_)); 189 } 190 PetscFunctionReturn(0); 191 } 192 193 template <DeviceType T> 194 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) { 195 const auto dci = impls_cast_(dctx); 196 197 PetscFunctionBegin; 198 if (dci->stream) PetscCallCUPM(cupmStreamDestroy(dci->stream)); 199 if (dci->event) PetscCallCUPM(cupmEventDestroy(dci->event)); 200 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 201 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 202 PetscCall(PetscFree(dctx->data)); 203 PetscFunctionReturn(0); 204 } 205 206 template <DeviceType T> 207 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) { 208 const auto dci = impls_cast_(dctx); 209 210 PetscFunctionBegin; 211 if (auto &stream = dci->stream) { 212 PetscCallCUPM(cupmStreamDestroy(stream)); 213 stream = nullptr; 214 } 215 // set these to null so they aren't usable until setup is called again 216 dci->blas = nullptr; 217 dci->solver = nullptr; 218 PetscFunctionReturn(0); 219 } 220 221 template <DeviceType T> 222 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) { 223 const auto dci = impls_cast_(dctx); 224 auto &stream = dci->stream; 225 226 PetscFunctionBegin; 227 if (stream) { 228 PetscCallCUPM(cupmStreamDestroy(stream)); 229 stream = nullptr; 230 } 231 switch (const auto stype = dctx->streamType) { 232 case PETSC_STREAM_GLOBAL_BLOCKING: 233 // don't create a stream for global blocking 234 break; 235 case PETSC_STREAM_DEFAULT_BLOCKING: PetscCallCUPM(cupmStreamCreate(&stream)); break; 236 case PETSC_STREAM_GLOBAL_NONBLOCKING: PetscCallCUPM(cupmStreamCreateWithFlags(&stream, cupmStreamNonBlocking)); break; 237 default: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_CORRUPT, "Invalid PetscStreamType %s", PetscStreamTypes[util::integral_value(stype)]); break; 238 } 239 if (!dci->event) PetscCallCUPM(cupmEventCreate(&dci->event)); 240 #if PetscDefined(USE_DEBUG) 241 dci->timerInUse = PETSC_FALSE; 242 #endif 243 PetscFunctionReturn(0); 244 } 245 246 template <DeviceType T> 247 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) { 248 cupmError_t cerr; 249 250 PetscFunctionBegin; 251 cerr = cupmStreamQuery(impls_cast_(dctx)->stream); 252 if (cerr == cupmSuccess) *idle = PETSC_TRUE; 253 else { 254 // somethings gone wrong 255 if (PetscUnlikely(cerr != cupmErrorNotReady)) PetscCallCUPM(cerr); 256 *idle = PETSC_FALSE; 257 } 258 PetscFunctionReturn(0); 259 } 260 261 template <DeviceType T> 262 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) { 263 auto dcib = impls_cast_(dctxb); 264 265 PetscFunctionBegin; 266 PetscCallCUPM(cupmEventRecord(dcib->event, dcib->stream)); 267 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream, dcib->event, 0)); 268 PetscFunctionReturn(0); 269 } 270 271 template <DeviceType T> 272 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) { 273 auto dci = impls_cast_(dctx); 274 275 PetscFunctionBegin; 276 // in case anything was queued on the event 277 PetscCallCUPM(cupmStreamWaitEvent(dci->stream, dci->event, 0)); 278 PetscCallCUPM(cupmStreamSynchronize(dci->stream)); 279 PetscFunctionReturn(0); 280 } 281 282 template <DeviceType T> 283 template <typename handle_t> 284 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) { 285 PetscFunctionBegin; 286 PetscCall(initialize_handle_(handle_t{}, dctx)); 287 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 288 PetscFunctionReturn(0); 289 } 290 291 template <DeviceType T> 292 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) { 293 auto dci = impls_cast_(dctx); 294 295 PetscFunctionBegin; 296 #if PetscDefined(USE_DEBUG) 297 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 298 dci->timerInUse = PETSC_TRUE; 299 #endif 300 if (!dci->begin) { 301 PetscCallCUPM(cupmEventCreate(&dci->begin)); 302 PetscCallCUPM(cupmEventCreate(&dci->end)); 303 } 304 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream)); 305 PetscFunctionReturn(0); 306 } 307 308 template <DeviceType T> 309 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) { 310 float gtime; 311 auto dci = impls_cast_(dctx); 312 313 PetscFunctionBegin; 314 #if PetscDefined(USE_DEBUG) 315 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 316 dci->timerInUse = PETSC_FALSE; 317 #endif 318 PetscCallCUPM(cupmEventRecord(dci->end, dci->stream)); 319 PetscCallCUPM(cupmEventSynchronize(dci->end)); 320 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, dci->end)); 321 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 322 PetscFunctionReturn(0); 323 } 324 325 // initialize the static member variables 326 template <DeviceType T> 327 bool DeviceContext<T>::initialized_ = false; 328 329 template <DeviceType T> 330 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 331 332 template <DeviceType T> 333 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 334 335 } // namespace Impl 336 337 // shorten this one up a bit (and instantiate the templates) 338 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>; 339 using CUPMContextHip = Impl::DeviceContext<DeviceType::HIP>; 340 341 // shorthand for what is an EXTREMELY long name 342 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 343 344 } // namespace CUPM 345 346 } // namespace Device 347 348 } // namespace Petsc 349 350 #endif // PETSCDEVICECONTEXTCUDA_HPP 351