1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <petsc/private/cpp/array.hpp> 9 10 #include "../segmentedmempool.hpp" 11 #include "cupmallocator.hpp" 12 #include "cupmstream.hpp" 13 #include "cupmevent.hpp" 14 15 #if defined(__cplusplus) 16 17 namespace Petsc { 18 19 namespace device { 20 21 namespace cupm { 22 23 namespace impl { 24 25 template <DeviceType T> 26 class DeviceContext : BlasInterface<T> { 27 public: 28 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); 29 30 private: 31 template <typename H, std::size_t> 32 struct HandleTag { 33 using type = H; 34 }; 35 36 using stream_tag = HandleTag<cupmStream_t, 0>; 37 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 38 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 39 40 using stream_type = CUPMStream<T>; 41 using event_type = CUPMEvent<T>; 42 43 public: 44 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 45 // header, but since we are using the power of templates it must be declared part of 46 // this class to have easy access the same typedefs. Technically one can make a 47 // templated struct outside the class but it's more code for the same result. 48 struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 49 stream_type stream{}; 50 cupmEvent_t event{}; 51 cupmEvent_t begin{}; // timer-only 52 cupmEvent_t end{}; // timer-only 53 #if PetscDefined(USE_DEBUG) 54 PetscBool timerInUse{}; 55 #endif 56 cupmBlasHandle_t blas{}; 57 cupmSolverHandle_t solver{}; 58 59 constexpr PetscDeviceContext_IMPLS() noexcept = default; 60 61 PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { 62 return this->stream.get_stream(); 63 } 64 65 PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { 66 return this->blas; 67 } 68 69 PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { 70 return this->solver; 71 } 72 }; 73 74 private: 75 static bool initialized_; 76 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 77 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 78 79 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { 80 return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); 81 } 82 83 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { 84 return static_cast<CUPMEvent<T> *>(event->data); 85 } 86 87 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { 88 return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; 89 } 90 91 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { 92 return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; 93 } 94 95 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 96 // handles 97 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { 98 return 0; 99 } 100 101 PETSC_NODISCARD static PetscErrorCode create_handle_(cupmBlasHandle_t &handle) noexcept { 102 PetscLogEvent event; 103 104 PetscFunctionBegin; 105 if (PetscLikely(handle)) PetscFunctionReturn(0); 106 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 107 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 108 for (auto i = 0; i < 3; ++i) { 109 auto cberr = cupmBlasCreate(&handle); 110 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 111 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 112 if (i != 2) { 113 PetscCall(PetscSleep(3)); 114 continue; 115 } 116 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 117 } 118 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 119 PetscCall(PetscLogEventResume_Internal(event)); 120 PetscFunctionReturn(0); 121 } 122 123 PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept { 124 const auto dci = impls_cast_(dctx); 125 auto &handle = blashandles_[dctx->device->deviceId]; 126 127 PetscFunctionBegin; 128 PetscCall(create_handle_(handle)); 129 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 130 dci->blas = handle; 131 PetscFunctionReturn(0); 132 } 133 134 PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle)) { 135 PetscLogEvent event; 136 137 PetscFunctionBegin; 138 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 139 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 140 PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 141 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 142 PetscCall(PetscLogEventResume_Internal(event)); 143 PetscFunctionReturn(0); 144 } 145 146 PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept { 147 const auto dci = impls_cast_(dctx); 148 auto &handle = solverhandles_[dctx->device->deviceId]; 149 150 PetscFunctionBegin; 151 PetscCall(create_handle_(handle)); 152 PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream())); 153 dci->solver = handle; 154 PetscFunctionReturn(0); 155 } 156 157 PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept { 158 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 159 160 PetscFunctionBegin; 161 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 162 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 163 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 164 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 165 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 166 PetscFunctionReturn(0); 167 } 168 169 PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { 170 return check_current_device_(dctx, dctx); 171 } 172 173 PETSC_NODISCARD static PetscErrorCode finalize_() noexcept { 174 PetscFunctionBegin; 175 for (auto &&handle : blashandles_) { 176 if (handle) { 177 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 178 handle = nullptr; 179 } 180 } 181 for (auto &&handle : solverhandles_) { 182 if (handle) { 183 PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 184 handle = nullptr; 185 } 186 } 187 initialized_ = false; 188 PetscFunctionReturn(0); 189 } 190 191 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 192 PETSC_NODISCARD static PoolType &default_pool_() noexcept { 193 static PoolType pool; 194 return pool; 195 } 196 197 PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept { 198 PetscFunctionBegin; 199 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 200 PetscFunctionReturn(0); 201 } 202 203 public: 204 // All of these functions MUST be static in order to be callable from C, otherwise they 205 // get the implicit 'this' pointer tacked on 206 PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 207 PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType)); 208 PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 209 PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *)); 210 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext)); 211 PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 212 template <typename Handle_t> 213 PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *)); 214 PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 215 PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *)); 216 PETSC_CXX_COMPAT_DECL(PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **)); 217 PETSC_CXX_COMPAT_DECL(PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **)); 218 PETSC_CXX_COMPAT_DECL(PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode)); 219 PETSC_CXX_COMPAT_DECL(PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t)); 220 PETSC_CXX_COMPAT_DECL(PetscErrorCode createEvent(PetscDeviceContext, PetscEvent)); 221 PETSC_CXX_COMPAT_DECL(PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent)); 222 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent)); 223 224 // not a PetscDeviceContext method, this registers the class 225 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize()); 226 227 // clang-format off 228 const _DeviceContextOps ops = { 229 destroy, 230 changeStreamType, 231 setUp, 232 query, 233 waitForContext, 234 synchronize, 235 getHandle<blas_tag>, 236 getHandle<solver_tag>, 237 getHandle<stream_tag>, 238 beginTimer, 239 endTimer, 240 memAlloc, 241 memFree, 242 memCopy, 243 memSet, 244 createEvent, 245 recordEvent, 246 waitForEvent 247 }; 248 // clang-format on 249 }; 250 251 // not a PetscDeviceContext method, this initializes the CLASS 252 template <DeviceType T> 253 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize()) { 254 PetscFunctionBegin; 255 if (PetscUnlikely(!initialized_)) { 256 cupmMemPool_t mempool; 257 uint64_t threshold = UINT64_MAX; 258 259 initialized_ = true; 260 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, 0)); 261 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 262 blashandles_.fill(nullptr); 263 solverhandles_.fill(nullptr); 264 PetscCall(PetscRegisterFinalize(finalize_)); 265 } 266 PetscFunctionReturn(0); 267 } 268 269 template <DeviceType T> 270 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) { 271 PetscFunctionBegin; 272 if (const auto dci = impls_cast_(dctx)) { 273 PetscCall(dci->stream.destroy()); 274 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event))); 275 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 276 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 277 delete dci; 278 dctx->data = nullptr; 279 } 280 PetscFunctionReturn(0); 281 } 282 283 template <DeviceType T> 284 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) { 285 const auto dci = impls_cast_(dctx); 286 287 PetscFunctionBegin; 288 PetscCall(dci->stream.destroy()); 289 // set these to null so they aren't usable until setup is called again 290 dci->blas = nullptr; 291 dci->solver = nullptr; 292 PetscFunctionReturn(0); 293 } 294 295 template <DeviceType T> 296 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) { 297 const auto dci = impls_cast_(dctx); 298 auto &event = dci->event; 299 300 PetscFunctionBegin; 301 PetscCall(check_current_device_(dctx)); 302 PetscCall(dci->stream.change_type(dctx->streamType)); 303 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 304 #if PetscDefined(USE_DEBUG) 305 dci->timerInUse = PETSC_FALSE; 306 #endif 307 PetscFunctionReturn(0); 308 } 309 310 template <DeviceType T> 311 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) { 312 PetscFunctionBegin; 313 PetscCall(check_current_device_(dctx)); 314 switch (const auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 315 case cupmSuccess: *idle = PETSC_TRUE; break; 316 case cupmErrorNotReady: *idle = PETSC_FALSE; break; 317 default: PetscCallCUPM(cerr); PetscUnreachable(); 318 } 319 PetscFunctionReturn(0); 320 } 321 322 template <DeviceType T> 323 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) { 324 const auto dcib = impls_cast_(dctxb); 325 const auto event = dcib->event; 326 327 PetscFunctionBegin; 328 PetscCall(check_current_device_(dctxa, dctxb)); 329 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 330 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 331 PetscFunctionReturn(0); 332 } 333 334 template <DeviceType T> 335 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) { 336 auto idle = PETSC_TRUE; 337 338 PetscFunctionBegin; 339 PetscCall(query(dctx, &idle)); 340 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 341 PetscFunctionReturn(0); 342 } 343 344 template <DeviceType T> 345 template <typename handle_t> 346 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) { 347 PetscFunctionBegin; 348 PetscCall(initialize_handle_(handle_t{}, dctx)); 349 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 350 PetscFunctionReturn(0); 351 } 352 353 template <DeviceType T> 354 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) { 355 const auto dci = impls_cast_(dctx); 356 357 PetscFunctionBegin; 358 PetscCall(check_current_device_(dctx)); 359 #if PetscDefined(USE_DEBUG) 360 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 361 dci->timerInUse = PETSC_TRUE; 362 #endif 363 if (!dci->begin) { 364 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 365 PetscCallCUPM(cupmEventCreate(&dci->begin)); 366 PetscCallCUPM(cupmEventCreate(&dci->end)); 367 } 368 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 369 PetscFunctionReturn(0); 370 } 371 372 template <DeviceType T> 373 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) { 374 float gtime; 375 const auto dci = impls_cast_(dctx); 376 const auto end = dci->end; 377 378 PetscFunctionBegin; 379 PetscCall(check_current_device_(dctx)); 380 #if PetscDefined(USE_DEBUG) 381 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 382 dci->timerInUse = PETSC_FALSE; 383 #endif 384 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 385 PetscCallCUPM(cupmEventSynchronize(end)); 386 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 387 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 388 PetscFunctionReturn(0); 389 } 390 391 template <DeviceType T> 392 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest)) { 393 const auto &stream = impls_cast_(dctx)->stream; 394 395 PetscFunctionBegin; 396 PetscCall(check_current_device_(dctx)); 397 PetscCall(check_memtype_(mtype, "allocating")); 398 if (PetscMemTypeHost(mtype)) { 399 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 400 } else { 401 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 402 } 403 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 404 PetscFunctionReturn(0); 405 } 406 407 template <DeviceType T> 408 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr)) { 409 const auto &stream = impls_cast_(dctx)->stream; 410 411 PetscFunctionBegin; 412 PetscCall(check_current_device_(dctx)); 413 PetscCall(check_memtype_(mtype, "freeing")); 414 if (!*ptr) PetscFunctionReturn(0); 415 if (PetscMemTypeHost(mtype)) { 416 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 417 // if ptr exists still exists the pool didn't own it 418 if (*ptr) { 419 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 420 421 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 422 if (registered) { 423 PetscCallCUPM(cupmFreeHost(*ptr)); 424 } else if (managed) { 425 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 426 } 427 } 428 } else { 429 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 430 // if ptr exists still exists the pool didn't own it 431 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 432 } 433 PetscFunctionReturn(0); 434 } 435 436 template <DeviceType T> 437 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode)) { 438 const auto stream = impls_cast_(dctx)->stream.get_stream(); 439 440 PetscFunctionBegin; 441 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 442 if (mode == PETSC_DEVICE_COPY_HTOH) { 443 // yes this is faster 444 if (cupmStreamQuery(stream) == cupmSuccess) { 445 PetscCall(PetscMemcpy(dest, src, n)); 446 PetscFunctionReturn(0); 447 } 448 // in case cupmStreamQuery() did not return cupmErrorNotReady 449 PetscCallCUPM(cupmGetLastError()); 450 } 451 PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 452 PetscFunctionReturn(0); 453 } 454 455 template <DeviceType T> 456 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n)) { 457 PetscFunctionBegin; 458 PetscCall(check_current_device_(dctx)); 459 PetscCall(check_memtype_(mtype, "zeroing")); 460 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 461 PetscFunctionReturn(0); 462 } 463 464 template <DeviceType T> 465 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event)) { 466 PetscFunctionBegin; 467 PetscCallCXX(event->data = new event_type()); 468 event->destroy = [](PetscEvent event) { 469 PetscFunctionBegin; 470 delete event_cast_(event); 471 event->data = nullptr; 472 PetscFunctionReturn(0); 473 }; 474 PetscFunctionReturn(0); 475 } 476 477 template <DeviceType T> 478 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event)) { 479 PetscFunctionBegin; 480 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 481 PetscFunctionReturn(0); 482 } 483 484 template <DeviceType T> 485 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event)) { 486 PetscFunctionBegin; 487 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 488 PetscFunctionReturn(0); 489 } 490 491 // initialize the static member variables 492 template <DeviceType T> 493 bool DeviceContext<T>::initialized_ = false; 494 495 template <DeviceType T> 496 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 497 498 template <DeviceType T> 499 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 500 501 } // namespace impl 502 503 // shorten this one up a bit (and instantiate the templates) 504 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 505 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 506 507 // shorthand for what is an EXTREMELY long name 508 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 509 510 } // namespace cupm 511 512 } // namespace device 513 514 } // namespace Petsc 515 516 #endif // __cplusplus 517 518 #endif // PETSCDEVICECONTEXTCUDA_HPP 519