1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <petsc/private/cpp/array.hpp> 9 10 #include "../segmentedmempool.hpp" 11 #include "cupmallocator.hpp" 12 #include "cupmstream.hpp" 13 #include "cupmevent.hpp" 14 15 #if defined(__cplusplus) 16 17 namespace Petsc 18 { 19 20 namespace device 21 { 22 23 namespace cupm 24 { 25 26 namespace impl 27 { 28 29 template <DeviceType T> 30 class DeviceContext : BlasInterface<T> { 31 public: 32 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); 33 34 private: 35 template <typename H, std::size_t> 36 struct HandleTag { 37 using type = H; 38 }; 39 40 using stream_tag = HandleTag<cupmStream_t, 0>; 41 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 42 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 43 44 using stream_type = CUPMStream<T>; 45 using event_type = CUPMEvent<T>; 46 47 public: 48 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 49 // header, but since we are using the power of templates it must be declared part of 50 // this class to have easy access the same typedefs. Technically one can make a 51 // templated struct outside the class but it's more code for the same result. 52 struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 53 stream_type stream{}; 54 cupmEvent_t event{}; 55 cupmEvent_t begin{}; // timer-only 56 cupmEvent_t end{}; // timer-only 57 #if PetscDefined(USE_DEBUG) 58 PetscBool timerInUse{}; 59 #endif 60 cupmBlasHandle_t blas{}; 61 cupmSolverHandle_t solver{}; 62 63 constexpr PetscDeviceContext_IMPLS() noexcept = default; 64 65 PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); } 66 67 PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; } 68 69 PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; } 70 }; 71 72 private: 73 static bool initialized_; 74 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 75 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 76 77 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 78 79 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 80 81 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 82 83 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 84 85 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 86 // handles 87 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { return 0; } 88 89 PETSC_NODISCARD static PetscErrorCode create_handle_(blas_tag, cupmBlasHandle_t &handle) noexcept 90 { 91 PetscLogEvent event; 92 93 PetscFunctionBegin; 94 if (PetscLikely(handle)) PetscFunctionReturn(0); 95 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 96 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 97 for (auto i = 0; i < 3; ++i) { 98 auto cberr = cupmBlasCreate(&handle); 99 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 100 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 101 if (i != 2) { 102 PetscCall(PetscSleep(3)); 103 continue; 104 } 105 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 106 } 107 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 108 PetscCall(PetscLogEventResume_Internal(event)); 109 PetscFunctionReturn(0); 110 } 111 112 PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag tag, PetscDeviceContext dctx) noexcept 113 { 114 const auto dci = impls_cast_(dctx); 115 auto &handle = blashandles_[dctx->device->deviceId]; 116 117 PetscFunctionBegin; 118 PetscCall(create_handle_(tag, handle)); 119 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 120 dci->blas = handle; 121 PetscFunctionReturn(0); 122 } 123 124 PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(solver_tag, cupmSolverHandle_t &handle)) 125 { 126 PetscLogEvent event; 127 128 PetscFunctionBegin; 129 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 130 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 131 PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 132 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 133 PetscCall(PetscLogEventResume_Internal(event)); 134 PetscFunctionReturn(0); 135 } 136 137 PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag tag, PetscDeviceContext dctx) noexcept 138 { 139 const auto dci = impls_cast_(dctx); 140 auto &handle = solverhandles_[dctx->device->deviceId]; 141 142 PetscFunctionBegin; 143 PetscCall(create_handle_(tag, handle)); 144 PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream())); 145 dci->solver = handle; 146 PetscFunctionReturn(0); 147 } 148 149 PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 150 { 151 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 152 153 PetscFunctionBegin; 154 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 155 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 156 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 157 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 158 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 159 PetscFunctionReturn(0); 160 } 161 162 PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 163 164 PETSC_NODISCARD static PetscErrorCode finalize_() noexcept 165 { 166 PetscFunctionBegin; 167 for (auto &&handle : blashandles_) { 168 if (handle) { 169 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 170 handle = nullptr; 171 } 172 } 173 for (auto &&handle : solverhandles_) { 174 if (handle) { 175 PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 176 handle = nullptr; 177 } 178 } 179 initialized_ = false; 180 PetscFunctionReturn(0); 181 } 182 183 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 184 PETSC_NODISCARD static PoolType &default_pool_() noexcept 185 { 186 static PoolType pool; 187 return pool; 188 } 189 190 PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 191 { 192 PetscFunctionBegin; 193 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 194 PetscFunctionReturn(0); 195 } 196 197 public: 198 // All of these functions MUST be static in order to be callable from C, otherwise they 199 // get the implicit 'this' pointer tacked on 200 PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 201 PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType)); 202 PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 203 PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *)); 204 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext)); 205 PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 206 template <typename Handle_t> 207 PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *)); 208 PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 209 PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *)); 210 PETSC_CXX_COMPAT_DECL(PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **)); 211 PETSC_CXX_COMPAT_DECL(PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **)); 212 PETSC_CXX_COMPAT_DECL(PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode)); 213 PETSC_CXX_COMPAT_DECL(PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t)); 214 PETSC_CXX_COMPAT_DECL(PetscErrorCode createEvent(PetscDeviceContext, PetscEvent)); 215 PETSC_CXX_COMPAT_DECL(PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent)); 216 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent)); 217 218 // not a PetscDeviceContext method, this registers the class 219 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize()); 220 221 // clang-format off 222 const _DeviceContextOps ops = { 223 destroy, 224 changeStreamType, 225 setUp, 226 query, 227 waitForContext, 228 synchronize, 229 getHandle<blas_tag>, 230 getHandle<solver_tag>, 231 getHandle<stream_tag>, 232 beginTimer, 233 endTimer, 234 memAlloc, 235 memFree, 236 memCopy, 237 memSet, 238 createEvent, 239 recordEvent, 240 waitForEvent 241 }; 242 // clang-format on 243 }; 244 245 // not a PetscDeviceContext method, this initializes the CLASS 246 template <DeviceType T> 247 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize()) 248 { 249 PetscFunctionBegin; 250 if (PetscUnlikely(!initialized_)) { 251 cupmMemPool_t mempool; 252 uint64_t threshold = UINT64_MAX; 253 254 initialized_ = true; 255 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, 0)); 256 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 257 blashandles_.fill(nullptr); 258 solverhandles_.fill(nullptr); 259 PetscCall(PetscRegisterFinalize(finalize_)); 260 } 261 PetscFunctionReturn(0); 262 } 263 264 template <DeviceType T> 265 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) 266 { 267 PetscFunctionBegin; 268 if (const auto dci = impls_cast_(dctx)) { 269 PetscCall(dci->stream.destroy()); 270 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event))); 271 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 272 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 273 delete dci; 274 dctx->data = nullptr; 275 } 276 PetscFunctionReturn(0); 277 } 278 279 template <DeviceType T> 280 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) 281 { 282 const auto dci = impls_cast_(dctx); 283 284 PetscFunctionBegin; 285 PetscCall(dci->stream.destroy()); 286 // set these to null so they aren't usable until setup is called again 287 dci->blas = nullptr; 288 dci->solver = nullptr; 289 PetscFunctionReturn(0); 290 } 291 292 template <DeviceType T> 293 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) 294 { 295 const auto dci = impls_cast_(dctx); 296 auto &event = dci->event; 297 298 PetscFunctionBegin; 299 PetscCall(check_current_device_(dctx)); 300 PetscCall(dci->stream.change_type(dctx->streamType)); 301 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 302 #if PetscDefined(USE_DEBUG) 303 dci->timerInUse = PETSC_FALSE; 304 #endif 305 PetscFunctionReturn(0); 306 } 307 308 template <DeviceType T> 309 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) 310 { 311 PetscFunctionBegin; 312 PetscCall(check_current_device_(dctx)); 313 switch (const auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 314 case cupmSuccess: 315 *idle = PETSC_TRUE; 316 break; 317 case cupmErrorNotReady: 318 *idle = PETSC_FALSE; 319 break; 320 default: 321 PetscCallCUPM(cerr); 322 PetscUnreachable(); 323 } 324 PetscFunctionReturn(0); 325 } 326 327 template <DeviceType T> 328 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) 329 { 330 const auto dcib = impls_cast_(dctxb); 331 const auto event = dcib->event; 332 333 PetscFunctionBegin; 334 PetscCall(check_current_device_(dctxa, dctxb)); 335 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 336 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 337 PetscFunctionReturn(0); 338 } 339 340 template <DeviceType T> 341 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) 342 { 343 auto idle = PETSC_TRUE; 344 345 PetscFunctionBegin; 346 PetscCall(query(dctx, &idle)); 347 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 348 PetscFunctionReturn(0); 349 } 350 351 template <DeviceType T> 352 template <typename handle_t> 353 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) 354 { 355 PetscFunctionBegin; 356 PetscCall(initialize_handle_(handle_t{}, dctx)); 357 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 358 PetscFunctionReturn(0); 359 } 360 361 template <DeviceType T> 362 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) 363 { 364 const auto dci = impls_cast_(dctx); 365 366 PetscFunctionBegin; 367 PetscCall(check_current_device_(dctx)); 368 #if PetscDefined(USE_DEBUG) 369 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 370 dci->timerInUse = PETSC_TRUE; 371 #endif 372 if (!dci->begin) { 373 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 374 PetscCallCUPM(cupmEventCreate(&dci->begin)); 375 PetscCallCUPM(cupmEventCreate(&dci->end)); 376 } 377 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 378 PetscFunctionReturn(0); 379 } 380 381 template <DeviceType T> 382 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) 383 { 384 float gtime; 385 const auto dci = impls_cast_(dctx); 386 const auto end = dci->end; 387 388 PetscFunctionBegin; 389 PetscCall(check_current_device_(dctx)); 390 #if PetscDefined(USE_DEBUG) 391 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 392 dci->timerInUse = PETSC_FALSE; 393 #endif 394 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 395 PetscCallCUPM(cupmEventSynchronize(end)); 396 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 397 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 398 PetscFunctionReturn(0); 399 } 400 401 template <DeviceType T> 402 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest)) 403 { 404 const auto &stream = impls_cast_(dctx)->stream; 405 406 PetscFunctionBegin; 407 PetscCall(check_current_device_(dctx)); 408 PetscCall(check_memtype_(mtype, "allocating")); 409 if (PetscMemTypeHost(mtype)) { 410 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 411 } else { 412 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 413 } 414 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 415 PetscFunctionReturn(0); 416 } 417 418 template <DeviceType T> 419 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr)) 420 { 421 const auto &stream = impls_cast_(dctx)->stream; 422 423 PetscFunctionBegin; 424 PetscCall(check_current_device_(dctx)); 425 PetscCall(check_memtype_(mtype, "freeing")); 426 if (!*ptr) PetscFunctionReturn(0); 427 if (PetscMemTypeHost(mtype)) { 428 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 429 // if ptr exists still exists the pool didn't own it 430 if (*ptr) { 431 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 432 433 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 434 if (registered) { 435 PetscCallCUPM(cupmFreeHost(*ptr)); 436 } else if (managed) { 437 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 438 } 439 } 440 } else { 441 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 442 // if ptr exists still exists the pool didn't own it 443 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 444 } 445 PetscFunctionReturn(0); 446 } 447 448 template <DeviceType T> 449 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode)) 450 { 451 const auto stream = impls_cast_(dctx)->stream.get_stream(); 452 453 PetscFunctionBegin; 454 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 455 if (mode == PETSC_DEVICE_COPY_HTOH) { 456 // yes this is faster 457 if (cupmStreamQuery(stream) == cupmSuccess) { 458 PetscCall(PetscMemcpy(dest, src, n)); 459 PetscFunctionReturn(0); 460 } 461 // in case cupmStreamQuery() did not return cupmErrorNotReady 462 PetscCallCUPM(cupmGetLastError()); 463 } 464 PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 465 PetscFunctionReturn(0); 466 } 467 468 template <DeviceType T> 469 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n)) 470 { 471 PetscFunctionBegin; 472 PetscCall(check_current_device_(dctx)); 473 PetscCall(check_memtype_(mtype, "zeroing")); 474 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 475 PetscFunctionReturn(0); 476 } 477 478 template <DeviceType T> 479 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event)) 480 { 481 PetscFunctionBegin; 482 PetscCallCXX(event->data = new event_type()); 483 event->destroy = [](PetscEvent event) { 484 PetscFunctionBegin; 485 delete event_cast_(event); 486 event->data = nullptr; 487 PetscFunctionReturn(0); 488 }; 489 PetscFunctionReturn(0); 490 } 491 492 template <DeviceType T> 493 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event)) 494 { 495 PetscFunctionBegin; 496 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 497 PetscFunctionReturn(0); 498 } 499 500 template <DeviceType T> 501 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event)) 502 { 503 PetscFunctionBegin; 504 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 505 PetscFunctionReturn(0); 506 } 507 508 // initialize the static member variables 509 template <DeviceType T> 510 bool DeviceContext<T>::initialized_ = false; 511 512 template <DeviceType T> 513 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 514 515 template <DeviceType T> 516 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 517 518 } // namespace impl 519 520 // shorten this one up a bit (and instantiate the templates) 521 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 522 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 523 524 // shorthand for what is an EXTREMELY long name 525 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 526 527 } // namespace cupm 528 529 } // namespace device 530 531 } // namespace Petsc 532 533 #endif // __cplusplus 534 535 #endif // PETSCDEVICECONTEXTCUDA_HPP 536