1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <petsc/private/cpp/array.hpp> 9 10 #include "../segmentedmempool.hpp" 11 #include "cupmallocator.hpp" 12 #include "cupmstream.hpp" 13 #include "cupmevent.hpp" 14 15 #if defined(__cplusplus) 16 17 namespace Petsc 18 { 19 20 namespace device 21 { 22 23 namespace cupm 24 { 25 26 namespace impl 27 { 28 29 template <DeviceType T> 30 class DeviceContext : BlasInterface<T> { 31 public: 32 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); 33 34 private: 35 template <typename H, std::size_t> 36 struct HandleTag { 37 using type = H; 38 }; 39 40 using stream_tag = HandleTag<cupmStream_t, 0>; 41 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 42 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 43 44 using stream_type = CUPMStream<T>; 45 using event_type = CUPMEvent<T>; 46 47 public: 48 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 49 // header, but since we are using the power of templates it must be declared part of 50 // this class to have easy access the same typedefs. Technically one can make a 51 // templated struct outside the class but it's more code for the same result. 52 struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 53 stream_type stream{}; 54 cupmEvent_t event{}; 55 cupmEvent_t begin{}; // timer-only 56 cupmEvent_t end{}; // timer-only 57 #if PetscDefined(USE_DEBUG) 58 PetscBool timerInUse{}; 59 #endif 60 cupmBlasHandle_t blas{}; 61 cupmSolverHandle_t solver{}; 62 63 constexpr PetscDeviceContext_IMPLS() noexcept = default; 64 65 PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); } 66 67 PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; } 68 69 PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; } 70 }; 71 72 private: 73 static bool initialized_; 74 75 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 76 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 77 78 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 79 80 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 81 82 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 83 84 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 85 86 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 87 // handles 88 PETSC_NODISCARD static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return 0; } 89 90 PETSC_NODISCARD static PetscErrorCode create_handle_(blas_tag, cupmBlasHandle_t &handle) noexcept 91 { 92 PetscLogEvent event; 93 94 PetscFunctionBegin; 95 if (PetscLikely(handle)) PetscFunctionReturn(0); 96 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 97 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 98 for (auto i = 0; i < 3; ++i) { 99 auto cberr = cupmBlasCreate(&handle); 100 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 101 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 102 if (i != 2) { 103 PetscCall(PetscSleep(3)); 104 continue; 105 } 106 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 107 } 108 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 109 PetscCall(PetscLogEventResume_Internal(event)); 110 PetscFunctionReturn(0); 111 } 112 113 PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag tag, PetscDeviceContext dctx) noexcept 114 { 115 const auto dci = impls_cast_(dctx); 116 auto &handle = blashandles_[dctx->device->deviceId]; 117 118 PetscFunctionBegin; 119 PetscCall(create_handle_(tag, handle)); 120 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 121 dci->blas = handle; 122 PetscFunctionReturn(0); 123 } 124 125 PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept 126 { 127 const auto dci = impls_cast_(dctx); 128 auto &handle = solverhandles_[dctx->device->deviceId]; 129 PetscLogEvent event; 130 131 PetscFunctionBegin; 132 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 133 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 134 PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 135 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 136 PetscCall(PetscLogEventResume_Internal(event)); 137 PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream())); 138 dci->solver = handle; 139 PetscFunctionReturn(0); 140 } 141 142 PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 143 { 144 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 145 146 PetscFunctionBegin; 147 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 148 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 149 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 150 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 151 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 152 PetscFunctionReturn(0); 153 } 154 155 PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 156 157 PETSC_NODISCARD static PetscErrorCode finalize_() noexcept 158 { 159 PetscFunctionBegin; 160 for (auto &&handle : blashandles_) { 161 if (handle) { 162 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 163 handle = nullptr; 164 } 165 } 166 167 for (auto &&handle : solverhandles_) { 168 if (handle) { 169 PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 170 handle = nullptr; 171 } 172 } 173 initialized_ = false; 174 PetscFunctionReturn(0); 175 } 176 177 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 178 PETSC_NODISCARD static PoolType &default_pool_() noexcept 179 { 180 static PoolType pool; 181 return pool; 182 } 183 184 PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 185 { 186 PetscFunctionBegin; 187 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 188 PetscFunctionReturn(0); 189 } 190 191 public: 192 // All of these functions MUST be static in order to be callable from C, otherwise they 193 // get the implicit 'this' pointer tacked on 194 PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept; 195 PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept; 196 PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept; 197 PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept; 198 PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept; 199 PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 200 template <typename Handle_t> 201 PETSC_NODISCARD static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept; 202 PETSC_NODISCARD static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 203 PETSC_NODISCARD static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept; 204 PETSC_NODISCARD static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept; 205 PETSC_NODISCARD static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept; 206 PETSC_NODISCARD static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept; 207 PETSC_NODISCARD static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept; 208 PETSC_NODISCARD static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept; 209 PETSC_NODISCARD static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept; 210 PETSC_NODISCARD static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept; 211 212 // not a PetscDeviceContext method, this registers the class 213 PETSC_NODISCARD static PetscErrorCode initialize(PetscDevice) noexcept; 214 215 // clang-format off 216 const _DeviceContextOps ops = { 217 destroy, 218 changeStreamType, 219 setUp, 220 query, 221 waitForContext, 222 synchronize, 223 getHandle<blas_tag>, 224 getHandle<solver_tag>, 225 getHandle<stream_tag>, 226 beginTimer, 227 endTimer, 228 memAlloc, 229 memFree, 230 memCopy, 231 memSet, 232 createEvent, 233 recordEvent, 234 waitForEvent 235 }; 236 // clang-format on 237 }; 238 239 // not a PetscDeviceContext method, this initializes the CLASS 240 template <DeviceType T> 241 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept 242 { 243 PetscFunctionBegin; 244 if (PetscUnlikely(!initialized_)) { 245 uint64_t threshold = UINT64_MAX; 246 cupmMemPool_t mempool; 247 248 initialized_ = true; 249 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId))); 250 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 251 blashandles_.fill(nullptr); 252 solverhandles_.fill(nullptr); 253 PetscCall(PetscRegisterFinalize(finalize_)); 254 } 255 PetscFunctionReturn(0); 256 } 257 258 template <DeviceType T> 259 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept 260 { 261 PetscFunctionBegin; 262 if (const auto dci = impls_cast_(dctx)) { 263 PetscCall(dci->stream.destroy()); 264 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event))); 265 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 266 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 267 delete dci; 268 dctx->data = nullptr; 269 } 270 PetscFunctionReturn(0); 271 } 272 273 template <DeviceType T> 274 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 275 { 276 const auto dci = impls_cast_(dctx); 277 278 PetscFunctionBegin; 279 PetscCall(dci->stream.destroy()); 280 // set these to null so they aren't usable until setup is called again 281 dci->blas = nullptr; 282 dci->solver = nullptr; 283 PetscFunctionReturn(0); 284 } 285 286 template <DeviceType T> 287 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept 288 { 289 const auto dci = impls_cast_(dctx); 290 auto &event = dci->event; 291 292 PetscFunctionBegin; 293 PetscCall(check_current_device_(dctx)); 294 PetscCall(dci->stream.change_type(dctx->streamType)); 295 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 296 #if PetscDefined(USE_DEBUG) 297 dci->timerInUse = PETSC_FALSE; 298 #endif 299 PetscFunctionReturn(0); 300 } 301 302 template <DeviceType T> 303 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 304 { 305 PetscFunctionBegin; 306 PetscCall(check_current_device_(dctx)); 307 switch (const auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 308 case cupmSuccess: 309 *idle = PETSC_TRUE; 310 break; 311 case cupmErrorNotReady: 312 *idle = PETSC_FALSE; 313 break; 314 default: 315 PetscCallCUPM(cerr); 316 PetscUnreachable(); 317 } 318 PetscFunctionReturn(0); 319 } 320 321 template <DeviceType T> 322 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 323 { 324 const auto dcib = impls_cast_(dctxb); 325 const auto event = dcib->event; 326 327 PetscFunctionBegin; 328 PetscCall(check_current_device_(dctxa, dctxb)); 329 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 330 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 331 PetscFunctionReturn(0); 332 } 333 334 template <DeviceType T> 335 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept 336 { 337 auto idle = PETSC_TRUE; 338 339 PetscFunctionBegin; 340 PetscCall(query(dctx, &idle)); 341 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 342 PetscFunctionReturn(0); 343 } 344 345 template <DeviceType T> 346 template <typename handle_t> 347 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 348 { 349 PetscFunctionBegin; 350 PetscCall(initialize_handle_(handle_t{}, dctx)); 351 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 352 PetscFunctionReturn(0); 353 } 354 355 template <DeviceType T> 356 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 357 { 358 const auto dci = impls_cast_(dctx); 359 360 PetscFunctionBegin; 361 PetscCall(check_current_device_(dctx)); 362 #if PetscDefined(USE_DEBUG) 363 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 364 dci->timerInUse = PETSC_TRUE; 365 #endif 366 if (!dci->begin) { 367 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 368 PetscCallCUPM(cupmEventCreate(&dci->begin)); 369 PetscCallCUPM(cupmEventCreate(&dci->end)); 370 } 371 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 372 PetscFunctionReturn(0); 373 } 374 375 template <DeviceType T> 376 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 377 { 378 float gtime; 379 const auto dci = impls_cast_(dctx); 380 const auto end = dci->end; 381 382 PetscFunctionBegin; 383 PetscCall(check_current_device_(dctx)); 384 #if PetscDefined(USE_DEBUG) 385 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 386 dci->timerInUse = PETSC_FALSE; 387 #endif 388 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 389 PetscCallCUPM(cupmEventSynchronize(end)); 390 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 391 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 392 PetscFunctionReturn(0); 393 } 394 395 template <DeviceType T> 396 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept 397 { 398 const auto &stream = impls_cast_(dctx)->stream; 399 400 PetscFunctionBegin; 401 PetscCall(check_current_device_(dctx)); 402 PetscCall(check_memtype_(mtype, "allocating")); 403 if (PetscMemTypeHost(mtype)) { 404 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 405 } else { 406 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 407 } 408 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 409 PetscFunctionReturn(0); 410 } 411 412 template <DeviceType T> 413 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept 414 { 415 const auto &stream = impls_cast_(dctx)->stream; 416 417 PetscFunctionBegin; 418 PetscCall(check_current_device_(dctx)); 419 PetscCall(check_memtype_(mtype, "freeing")); 420 if (!*ptr) PetscFunctionReturn(0); 421 if (PetscMemTypeHost(mtype)) { 422 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 423 // if ptr exists still exists the pool didn't own it 424 if (*ptr) { 425 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 426 427 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 428 if (registered) { 429 PetscCallCUPM(cupmFreeHost(*ptr)); 430 } else if (managed) { 431 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 432 } 433 } 434 } else { 435 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 436 // if ptr still exists the pool didn't own it 437 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 438 } 439 PetscFunctionReturn(0); 440 } 441 442 template <DeviceType T> 443 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept 444 { 445 const auto stream = impls_cast_(dctx)->stream.get_stream(); 446 447 PetscFunctionBegin; 448 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 449 if (mode == PETSC_DEVICE_COPY_HTOH) { 450 const auto cerr = cupmStreamQuery(stream); 451 452 // yes this is faster 453 if (cerr == cupmSuccess) { 454 PetscCall(PetscMemcpy(dest, src, n)); 455 PetscFunctionReturn(0); 456 } else if (cerr == cupmErrorNotReady) { 457 auto PETSC_UNUSED unused = cupmGetLastError(); 458 459 static_cast<void>(unused); 460 } else { 461 PetscCallCUPM(cerr); 462 } 463 } 464 PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 465 PetscFunctionReturn(0); 466 } 467 468 template <DeviceType T> 469 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept 470 { 471 PetscFunctionBegin; 472 PetscCall(check_current_device_(dctx)); 473 PetscCall(check_memtype_(mtype, "zeroing")); 474 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 475 PetscFunctionReturn(0); 476 } 477 478 template <DeviceType T> 479 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 480 { 481 PetscFunctionBegin; 482 PetscCallCXX(event->data = new event_type()); 483 event->destroy = [](PetscEvent event) { 484 PetscFunctionBegin; 485 delete event_cast_(event); 486 event->data = nullptr; 487 PetscFunctionReturn(0); 488 }; 489 PetscFunctionReturn(0); 490 } 491 492 template <DeviceType T> 493 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 494 { 495 PetscFunctionBegin; 496 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 497 PetscFunctionReturn(0); 498 } 499 500 template <DeviceType T> 501 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 502 { 503 PetscFunctionBegin; 504 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 505 PetscFunctionReturn(0); 506 } 507 508 // initialize the static member variables 509 template <DeviceType T> 510 bool DeviceContext<T>::initialized_ = false; 511 512 template <DeviceType T> 513 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 514 515 template <DeviceType T> 516 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 517 518 } // namespace impl 519 520 // shorten this one up a bit (and instantiate the templates) 521 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 522 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 523 524 // shorthand for what is an EXTREMELY long name 525 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 526 527 } // namespace cupm 528 529 } // namespace device 530 531 } // namespace Petsc 532 533 #endif // __cplusplus 534 535 #endif // PETSCDEVICECONTEXTCUDA_HPP 536