1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <petsc/private/cpp/array.hpp> 9 10 #include "../segmentedmempool.hpp" 11 #include "cupmallocator.hpp" 12 #include "cupmstream.hpp" 13 #include "cupmevent.hpp" 14 15 #if defined(__cplusplus) 16 17 namespace Petsc 18 { 19 20 namespace device 21 { 22 23 namespace cupm 24 { 25 26 namespace impl 27 { 28 29 template <DeviceType T> 30 class DeviceContext : BlasInterface<T> { 31 public: 32 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); 33 34 private: 35 template <typename H, std::size_t> 36 struct HandleTag { 37 using type = H; 38 }; 39 40 using stream_tag = HandleTag<cupmStream_t, 0>; 41 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 42 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 43 44 using stream_type = CUPMStream<T>; 45 using event_type = CUPMEvent<T>; 46 47 public: 48 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 49 // header, but since we are using the power of templates it must be declared part of 50 // this class to have easy access the same typedefs. Technically one can make a 51 // templated struct outside the class but it's more code for the same result. 52 struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 53 stream_type stream{}; 54 cupmEvent_t event{}; 55 cupmEvent_t begin{}; // timer-only 56 cupmEvent_t end{}; // timer-only 57 #if PetscDefined(USE_DEBUG) 58 PetscBool timerInUse{}; 59 #endif 60 cupmBlasHandle_t blas{}; 61 cupmSolverHandle_t solver{}; 62 63 constexpr PetscDeviceContext_IMPLS() noexcept = default; 64 65 PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); } 66 67 PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; } 68 69 PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; } 70 }; 71 72 private: 73 static bool initialized_; 74 75 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 76 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 77 78 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 79 80 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 81 82 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 83 84 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 85 86 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 87 // handles 88 static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; } 89 90 static PetscErrorCode create_handle_(blas_tag, cupmBlasHandle_t &handle) noexcept 91 { 92 PetscLogEvent event; 93 94 PetscFunctionBegin; 95 if (PetscLikely(handle)) PetscFunctionReturn(PETSC_SUCCESS); 96 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 97 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 98 for (auto i = 0; i < 3; ++i) { 99 auto cberr = cupmBlasCreate(&handle); 100 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 101 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 102 if (i != 2) { 103 PetscCall(PetscSleep(3)); 104 continue; 105 } 106 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 107 } 108 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 109 PetscCall(PetscLogEventResume_Internal(event)); 110 PetscFunctionReturn(PETSC_SUCCESS); 111 } 112 113 static PetscErrorCode initialize_handle_(blas_tag tag, PetscDeviceContext dctx) noexcept 114 { 115 const auto dci = impls_cast_(dctx); 116 auto &handle = blashandles_[dctx->device->deviceId]; 117 118 PetscFunctionBegin; 119 PetscCall(create_handle_(tag, handle)); 120 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 121 dci->blas = handle; 122 PetscFunctionReturn(PETSC_SUCCESS); 123 } 124 125 static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept 126 { 127 const auto dci = impls_cast_(dctx); 128 auto &handle = solverhandles_[dctx->device->deviceId]; 129 PetscLogEvent event; 130 131 PetscFunctionBegin; 132 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 133 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 134 PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 135 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 136 PetscCall(PetscLogEventResume_Internal(event)); 137 PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream())); 138 dci->solver = handle; 139 PetscFunctionReturn(PETSC_SUCCESS); 140 } 141 142 static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 143 { 144 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 145 146 PetscFunctionBegin; 147 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 148 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 149 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 150 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 151 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 152 PetscFunctionReturn(PETSC_SUCCESS); 153 } 154 155 static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 156 157 static PetscErrorCode finalize_() noexcept 158 { 159 PetscFunctionBegin; 160 for (auto &&handle : blashandles_) { 161 if (handle) { 162 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 163 handle = nullptr; 164 } 165 } 166 167 for (auto &&handle : solverhandles_) { 168 if (handle) { 169 PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 170 handle = nullptr; 171 } 172 } 173 initialized_ = false; 174 PetscFunctionReturn(PETSC_SUCCESS); 175 } 176 177 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 178 PETSC_NODISCARD static PoolType &default_pool_() noexcept 179 { 180 static PoolType pool; 181 return pool; 182 } 183 184 static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 185 { 186 PetscFunctionBegin; 187 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 188 PetscFunctionReturn(PETSC_SUCCESS); 189 } 190 191 public: 192 // All of these functions MUST be static in order to be callable from C, otherwise they 193 // get the implicit 'this' pointer tacked on 194 static PetscErrorCode destroy(PetscDeviceContext) noexcept; 195 static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept; 196 static PetscErrorCode setUp(PetscDeviceContext) noexcept; 197 static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept; 198 static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept; 199 static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 200 template <typename Handle_t> 201 static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept; 202 static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 203 static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept; 204 static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept; 205 static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept; 206 static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept; 207 static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept; 208 static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept; 209 static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept; 210 static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept; 211 212 // not a PetscDeviceContext method, this registers the class 213 static PetscErrorCode initialize(PetscDevice) noexcept; 214 215 // clang-format off 216 const _DeviceContextOps ops = { 217 destroy, 218 changeStreamType, 219 setUp, 220 query, 221 waitForContext, 222 synchronize, 223 getHandle<blas_tag>, 224 getHandle<solver_tag>, 225 getHandle<stream_tag>, 226 beginTimer, 227 endTimer, 228 memAlloc, 229 memFree, 230 memCopy, 231 memSet, 232 createEvent, 233 recordEvent, 234 waitForEvent 235 }; 236 // clang-format on 237 }; 238 239 // not a PetscDeviceContext method, this initializes the CLASS 240 template <DeviceType T> 241 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept 242 { 243 PetscFunctionBegin; 244 if (PetscUnlikely(!initialized_)) { 245 uint64_t threshold = UINT64_MAX; 246 cupmMemPool_t mempool; 247 248 initialized_ = true; 249 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId))); 250 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 251 blashandles_.fill(nullptr); 252 solverhandles_.fill(nullptr); 253 PetscCall(PetscRegisterFinalize(finalize_)); 254 } 255 PetscFunctionReturn(PETSC_SUCCESS); 256 } 257 258 template <DeviceType T> 259 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept 260 { 261 PetscFunctionBegin; 262 if (const auto dci = impls_cast_(dctx)) { 263 PetscCall(dci->stream.destroy()); 264 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event)); 265 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 266 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 267 delete dci; 268 dctx->data = nullptr; 269 } 270 PetscFunctionReturn(PETSC_SUCCESS); 271 } 272 273 template <DeviceType T> 274 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 275 { 276 const auto dci = impls_cast_(dctx); 277 278 PetscFunctionBegin; 279 PetscCall(dci->stream.destroy()); 280 // set these to null so they aren't usable until setup is called again 281 dci->blas = nullptr; 282 dci->solver = nullptr; 283 PetscFunctionReturn(PETSC_SUCCESS); 284 } 285 286 template <DeviceType T> 287 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept 288 { 289 const auto dci = impls_cast_(dctx); 290 auto &event = dci->event; 291 292 PetscFunctionBegin; 293 PetscCall(check_current_device_(dctx)); 294 PetscCall(dci->stream.change_type(dctx->streamType)); 295 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 296 #if PetscDefined(USE_DEBUG) 297 dci->timerInUse = PETSC_FALSE; 298 #endif 299 PetscFunctionReturn(PETSC_SUCCESS); 300 } 301 302 template <DeviceType T> 303 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 304 { 305 PetscFunctionBegin; 306 PetscCall(check_current_device_(dctx)); 307 switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 308 case cupmSuccess: 309 *idle = PETSC_TRUE; 310 break; 311 case cupmErrorNotReady: 312 *idle = PETSC_FALSE; 313 // reset the error 314 cerr = cupmGetLastError(); 315 static_cast<void>(cerr); 316 break; 317 default: 318 PetscCallCUPM(cerr); 319 PetscUnreachable(); 320 } 321 PetscFunctionReturn(PETSC_SUCCESS); 322 } 323 324 template <DeviceType T> 325 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 326 { 327 const auto dcib = impls_cast_(dctxb); 328 const auto event = dcib->event; 329 330 PetscFunctionBegin; 331 PetscCall(check_current_device_(dctxa, dctxb)); 332 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 333 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 334 PetscFunctionReturn(PETSC_SUCCESS); 335 } 336 337 template <DeviceType T> 338 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept 339 { 340 auto idle = PETSC_TRUE; 341 342 PetscFunctionBegin; 343 PetscCall(query(dctx, &idle)); 344 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 345 PetscFunctionReturn(PETSC_SUCCESS); 346 } 347 348 template <DeviceType T> 349 template <typename handle_t> 350 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 351 { 352 PetscFunctionBegin; 353 PetscCall(initialize_handle_(handle_t{}, dctx)); 354 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 355 PetscFunctionReturn(PETSC_SUCCESS); 356 } 357 358 template <DeviceType T> 359 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 360 { 361 const auto dci = impls_cast_(dctx); 362 363 PetscFunctionBegin; 364 PetscCall(check_current_device_(dctx)); 365 #if PetscDefined(USE_DEBUG) 366 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 367 dci->timerInUse = PETSC_TRUE; 368 #endif 369 if (!dci->begin) { 370 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 371 PetscCallCUPM(cupmEventCreate(&dci->begin)); 372 PetscCallCUPM(cupmEventCreate(&dci->end)); 373 } 374 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 375 PetscFunctionReturn(PETSC_SUCCESS); 376 } 377 378 template <DeviceType T> 379 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 380 { 381 float gtime; 382 const auto dci = impls_cast_(dctx); 383 const auto end = dci->end; 384 385 PetscFunctionBegin; 386 PetscCall(check_current_device_(dctx)); 387 #if PetscDefined(USE_DEBUG) 388 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 389 dci->timerInUse = PETSC_FALSE; 390 #endif 391 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 392 PetscCallCUPM(cupmEventSynchronize(end)); 393 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 394 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 395 PetscFunctionReturn(PETSC_SUCCESS); 396 } 397 398 template <DeviceType T> 399 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept 400 { 401 const auto &stream = impls_cast_(dctx)->stream; 402 403 PetscFunctionBegin; 404 PetscCall(check_current_device_(dctx)); 405 PetscCall(check_memtype_(mtype, "allocating")); 406 if (PetscMemTypeHost(mtype)) { 407 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 408 } else { 409 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 410 } 411 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 412 PetscFunctionReturn(PETSC_SUCCESS); 413 } 414 415 template <DeviceType T> 416 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept 417 { 418 const auto &stream = impls_cast_(dctx)->stream; 419 420 PetscFunctionBegin; 421 PetscCall(check_current_device_(dctx)); 422 PetscCall(check_memtype_(mtype, "freeing")); 423 if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS); 424 if (PetscMemTypeHost(mtype)) { 425 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 426 // if ptr exists still exists the pool didn't own it 427 if (*ptr) { 428 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 429 430 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 431 if (registered) { 432 PetscCallCUPM(cupmFreeHost(*ptr)); 433 } else if (managed) { 434 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 435 } 436 } 437 } else { 438 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 439 // if ptr still exists the pool didn't own it 440 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 441 } 442 PetscFunctionReturn(PETSC_SUCCESS); 443 } 444 445 template <DeviceType T> 446 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept 447 { 448 const auto stream = impls_cast_(dctx)->stream.get_stream(); 449 450 PetscFunctionBegin; 451 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 452 if (mode == PETSC_DEVICE_COPY_HTOH) { 453 const auto cerr = cupmStreamQuery(stream); 454 455 // yes this is faster 456 if (cerr == cupmSuccess) { 457 PetscCall(PetscMemcpy(dest, src, n)); 458 PetscFunctionReturn(PETSC_SUCCESS); 459 } else if (cerr == cupmErrorNotReady) { 460 auto PETSC_UNUSED unused = cupmGetLastError(); 461 462 static_cast<void>(unused); 463 } else { 464 PetscCallCUPM(cerr); 465 } 466 } 467 PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 468 PetscFunctionReturn(PETSC_SUCCESS); 469 } 470 471 template <DeviceType T> 472 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept 473 { 474 PetscFunctionBegin; 475 PetscCall(check_current_device_(dctx)); 476 PetscCall(check_memtype_(mtype, "zeroing")); 477 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 478 PetscFunctionReturn(PETSC_SUCCESS); 479 } 480 481 template <DeviceType T> 482 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept 483 { 484 PetscFunctionBegin; 485 PetscCallCXX(event->data = new event_type()); 486 event->destroy = [](PetscEvent event) { 487 PetscFunctionBegin; 488 delete event_cast_(event); 489 event->data = nullptr; 490 PetscFunctionReturn(PETSC_SUCCESS); 491 }; 492 PetscFunctionReturn(PETSC_SUCCESS); 493 } 494 495 template <DeviceType T> 496 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 497 { 498 PetscFunctionBegin; 499 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 500 PetscFunctionReturn(PETSC_SUCCESS); 501 } 502 503 template <DeviceType T> 504 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 505 { 506 PetscFunctionBegin; 507 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 508 PetscFunctionReturn(PETSC_SUCCESS); 509 } 510 511 // initialize the static member variables 512 template <DeviceType T> 513 bool DeviceContext<T>::initialized_ = false; 514 515 template <DeviceType T> 516 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 517 518 template <DeviceType T> 519 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 520 521 } // namespace impl 522 523 // shorten this one up a bit (and instantiate the templates) 524 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 525 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 526 527 // shorthand for what is an EXTREMELY long name 528 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 529 530 } // namespace cupm 531 532 } // namespace device 533 534 } // namespace Petsc 535 536 #endif // __cplusplus 537 538 #endif // PETSCDEVICECONTEXTCUDA_HPP 539