1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <petsc/private/cpp/array.hpp> 9 10 #include "../segmentedmempool.hpp" 11 #include "cupmallocator.hpp" 12 #include "cupmstream.hpp" 13 #include "cupmevent.hpp" 14 15 #if defined(__cplusplus) 16 namespace Petsc { 17 18 namespace device { 19 20 namespace cupm { 21 22 namespace impl { 23 24 template <DeviceType T> 25 class DeviceContext : BlasInterface<T> { 26 public: 27 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T); 28 29 private: 30 template <typename H, std::size_t> 31 struct HandleTag { 32 using type = H; 33 }; 34 35 using stream_tag = HandleTag<cupmStream_t, 0>; 36 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 37 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 38 39 using stream_type = CUPMStream<T>; 40 using event_type = CUPMEvent<T>; 41 42 public: 43 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 44 // header, but since we are using the power of templates it must be declared part of 45 // this class to have easy access the same typedefs. Technically one can make a 46 // templated struct outside the class but it's more code for the same result. 47 struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 48 stream_type stream{}; 49 cupmEvent_t event{}; 50 cupmEvent_t begin{}; // timer-only 51 cupmEvent_t end{}; // timer-only 52 #if PetscDefined(USE_DEBUG) 53 PetscBool timerInUse{}; 54 #endif 55 cupmBlasHandle_t blas{}; 56 cupmSolverHandle_t solver{}; 57 58 constexpr PetscDeviceContext_IMPLS() noexcept = default; 59 60 PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { 61 return this->stream.get_stream(); 62 } 63 64 PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { 65 return this->blas; 66 } 67 68 PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { 69 return this->solver; 70 } 71 }; 72 73 private: 74 static bool initialized_; 75 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 76 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 77 78 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { 79 return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); 80 } 81 82 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { 83 return static_cast<CUPMEvent<T> *>(event->data); 84 } 85 86 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { 87 return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; 88 } 89 90 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { 91 return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; 92 } 93 94 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 95 // handles 96 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { 97 return 0; 98 } 99 100 PETSC_NODISCARD static PetscErrorCode create_handle_(cupmBlasHandle_t &handle) noexcept { 101 PetscLogEvent event; 102 103 PetscFunctionBegin; 104 if (PetscLikely(handle)) PetscFunctionReturn(0); 105 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 106 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 107 for (auto i = 0; i < 3; ++i) { 108 auto cberr = cupmBlasCreate(&handle); 109 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 110 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 111 if (i != 2) { 112 PetscCall(PetscSleep(3)); 113 continue; 114 } 115 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 116 } 117 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 118 PetscCall(PetscLogEventResume_Internal(event)); 119 PetscFunctionReturn(0); 120 } 121 122 PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept { 123 const auto dci = impls_cast_(dctx); 124 auto &handle = blashandles_[dctx->device->deviceId]; 125 126 PetscFunctionBegin; 127 PetscCall(create_handle_(handle)); 128 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 129 dci->blas = handle; 130 PetscFunctionReturn(0); 131 } 132 133 PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle)) { 134 PetscLogEvent event; 135 136 PetscFunctionBegin; 137 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 138 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 139 PetscCall(cupmBlasInterface_t::InitializeHandle(handle)); 140 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 141 PetscCall(PetscLogEventResume_Internal(event)); 142 PetscFunctionReturn(0); 143 } 144 145 PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept { 146 const auto dci = impls_cast_(dctx); 147 auto &handle = solverhandles_[dctx->device->deviceId]; 148 149 PetscFunctionBegin; 150 PetscCall(create_handle_(handle)); 151 PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream())); 152 dci->solver = handle; 153 PetscFunctionReturn(0); 154 } 155 156 PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept { 157 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 158 159 PetscFunctionBegin; 160 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 161 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 162 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 163 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 164 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 165 PetscFunctionReturn(0); 166 } 167 168 PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { 169 return check_current_device_(dctx, dctx); 170 } 171 172 PETSC_NODISCARD static PetscErrorCode finalize_() noexcept { 173 PetscFunctionBegin; 174 for (auto &&handle : blashandles_) { 175 if (handle) { 176 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 177 handle = nullptr; 178 } 179 } 180 for (auto &&handle : solverhandles_) { 181 if (handle) { 182 PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 183 handle = nullptr; 184 } 185 } 186 initialized_ = false; 187 PetscFunctionReturn(0); 188 } 189 190 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 191 PETSC_NODISCARD static PoolType &default_pool_() noexcept { 192 static PoolType pool; 193 return pool; 194 } 195 196 PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept { 197 PetscFunctionBegin; 198 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 199 PetscFunctionReturn(0); 200 } 201 202 public: 203 // All of these functions MUST be static in order to be callable from C, otherwise they 204 // get the implicit 'this' pointer tacked on 205 PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 206 PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType)); 207 PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 208 PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *)); 209 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext)); 210 PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 211 template <typename Handle_t> 212 PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *)); 213 PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 214 PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *)); 215 PETSC_CXX_COMPAT_DECL(PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, void **)); 216 PETSC_CXX_COMPAT_DECL(PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **)); 217 PETSC_CXX_COMPAT_DECL(PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode)); 218 PETSC_CXX_COMPAT_DECL(PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t)); 219 PETSC_CXX_COMPAT_DECL(PetscErrorCode createEvent(PetscDeviceContext, PetscEvent)); 220 PETSC_CXX_COMPAT_DECL(PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent)); 221 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent)); 222 223 // not a PetscDeviceContext method, this registers the class 224 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize()); 225 226 // clang-format off 227 const _DeviceContextOps ops = { 228 destroy, 229 changeStreamType, 230 setUp, 231 query, 232 waitForContext, 233 synchronize, 234 getHandle<blas_tag>, 235 getHandle<solver_tag>, 236 getHandle<stream_tag>, 237 beginTimer, 238 endTimer, 239 memAlloc, 240 memFree, 241 memCopy, 242 memSet, 243 createEvent, 244 recordEvent, 245 waitForEvent 246 }; 247 // clang-format on 248 }; 249 250 // not a PetscDeviceContext method, this initializes the CLASS 251 template <DeviceType T> 252 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize()) { 253 PetscFunctionBegin; 254 if (PetscUnlikely(!initialized_)) { 255 cupmMemPool_t mempool; 256 uint64_t threshold = UINT64_MAX; 257 258 initialized_ = true; 259 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, 0)); 260 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 261 blashandles_.fill(nullptr); 262 solverhandles_.fill(nullptr); 263 PetscCall(PetscRegisterFinalize(finalize_)); 264 } 265 PetscFunctionReturn(0); 266 } 267 268 template <DeviceType T> 269 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) { 270 PetscFunctionBegin; 271 if (const auto dci = impls_cast_(dctx)) { 272 PetscCall(dci->stream.destroy()); 273 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(std::move(dci->event))); 274 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 275 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 276 delete dci; 277 dctx->data = nullptr; 278 } 279 PetscFunctionReturn(0); 280 } 281 282 template <DeviceType T> 283 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) { 284 const auto dci = impls_cast_(dctx); 285 286 PetscFunctionBegin; 287 PetscCall(dci->stream.destroy()); 288 // set these to null so they aren't usable until setup is called again 289 dci->blas = nullptr; 290 dci->solver = nullptr; 291 PetscFunctionReturn(0); 292 } 293 294 template <DeviceType T> 295 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) { 296 const auto dci = impls_cast_(dctx); 297 auto &event = dci->event; 298 299 PetscFunctionBegin; 300 PetscCall(check_current_device_(dctx)); 301 PetscCall(dci->stream.change_type(dctx->streamType)); 302 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 303 #if PetscDefined(USE_DEBUG) 304 dci->timerInUse = PETSC_FALSE; 305 #endif 306 PetscFunctionReturn(0); 307 } 308 309 template <DeviceType T> 310 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) { 311 PetscFunctionBegin; 312 PetscCall(check_current_device_(dctx)); 313 switch (const auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 314 case cupmSuccess: *idle = PETSC_TRUE; break; 315 case cupmErrorNotReady: *idle = PETSC_FALSE; break; 316 default: PetscCallCUPM(cerr); PetscUnreachable(); 317 } 318 PetscFunctionReturn(0); 319 } 320 321 template <DeviceType T> 322 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) { 323 const auto dcib = impls_cast_(dctxb); 324 const auto event = dcib->event; 325 326 PetscFunctionBegin; 327 PetscCall(check_current_device_(dctxa, dctxb)); 328 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 329 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 330 PetscFunctionReturn(0); 331 } 332 333 template <DeviceType T> 334 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) { 335 auto idle = PETSC_TRUE; 336 337 PetscFunctionBegin; 338 PetscCall(query(dctx, &idle)); 339 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 340 PetscFunctionReturn(0); 341 } 342 343 template <DeviceType T> 344 template <typename handle_t> 345 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) { 346 PetscFunctionBegin; 347 PetscCall(initialize_handle_(handle_t{}, dctx)); 348 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 349 PetscFunctionReturn(0); 350 } 351 352 template <DeviceType T> 353 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) { 354 const auto dci = impls_cast_(dctx); 355 356 PetscFunctionBegin; 357 PetscCall(check_current_device_(dctx)); 358 #if PetscDefined(USE_DEBUG) 359 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 360 dci->timerInUse = PETSC_TRUE; 361 #endif 362 if (!dci->begin) { 363 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 364 PetscCallCUPM(cupmEventCreate(&dci->begin)); 365 PetscCallCUPM(cupmEventCreate(&dci->end)); 366 } 367 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 368 PetscFunctionReturn(0); 369 } 370 371 template <DeviceType T> 372 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) { 373 float gtime; 374 const auto dci = impls_cast_(dctx); 375 const auto end = dci->end; 376 377 PetscFunctionBegin; 378 PetscCall(check_current_device_(dctx)); 379 #if PetscDefined(USE_DEBUG) 380 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 381 dci->timerInUse = PETSC_FALSE; 382 #endif 383 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 384 PetscCallCUPM(cupmEventSynchronize(end)); 385 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 386 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 387 PetscFunctionReturn(0); 388 } 389 390 template <DeviceType T> 391 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, void **dest)) { 392 const auto &stream = impls_cast_(dctx)->stream; 393 394 PetscFunctionBegin; 395 PetscCall(check_current_device_(dctx)); 396 PetscCall(check_memtype_(mtype, "allocating")); 397 if (PetscMemTypeHost(mtype)) { 398 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream)); 399 if (clear) std::memset(*dest, 0, n); 400 } else { 401 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream)); 402 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 403 } 404 PetscFunctionReturn(0); 405 } 406 407 template <DeviceType T> 408 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr)) { 409 const auto &stream = impls_cast_(dctx)->stream; 410 411 PetscFunctionBegin; 412 PetscCall(check_current_device_(dctx)); 413 PetscCall(check_memtype_(mtype, "freeing")); 414 if (!*ptr) PetscFunctionReturn(0); 415 if (PetscMemTypeHost(mtype)) { 416 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 417 // if ptr exists still exists the pool didn't own it 418 if (*ptr) { 419 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 420 421 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 422 if (registered) { 423 PetscCallCUPM(cupmFreeHost(*ptr)); 424 } else if (managed) { 425 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 426 } 427 } 428 } else { 429 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 430 // if ptr exists still exists the pool didn't own it 431 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 432 } 433 PetscFunctionReturn(0); 434 } 435 436 template <DeviceType T> 437 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode)) { 438 const auto stream = impls_cast_(dctx)->stream.get_stream(); 439 440 PetscFunctionBegin; 441 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 442 if (mode == PETSC_DEVICE_COPY_HTOH) { 443 // yes this is faster 444 if (cupmStreamQuery(stream) == cupmSuccess) { 445 PetscCall(PetscMemcpy(dest, src, n)); 446 PetscFunctionReturn(0); 447 } 448 // in case cupmStreamQuery() did not return cupmErrorNotReady 449 PetscCallCUPM(cupmGetLastError()); 450 } 451 PetscCall(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 452 PetscFunctionReturn(0); 453 } 454 455 template <DeviceType T> 456 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n)) { 457 auto vint = static_cast<int>(v); 458 459 PetscFunctionBegin; 460 PetscCall(check_current_device_(dctx)); 461 PetscCall(check_memtype_(mtype, "zeroing")); 462 if (PetscMemTypeHost(mtype)) { 463 // must call public sync to prune the dependency graph 464 PetscCall(PetscDeviceContextSynchronize(dctx)); 465 std::memset(ptr, vint, n); 466 } else { 467 PetscCallCUPM(cupmMemsetAsync(ptr, vint, n, impls_cast_(dctx)->stream.get_stream())); 468 } 469 PetscFunctionReturn(0); 470 } 471 472 template <DeviceType T> 473 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event)) { 474 PetscFunctionBegin; 475 PetscCallCXX(event->data = new event_type()); 476 event->destroy = [](PetscEvent event) { 477 PetscFunctionBegin; 478 delete event_cast_(event); 479 event->data = nullptr; 480 PetscFunctionReturn(0); 481 }; 482 PetscFunctionReturn(0); 483 } 484 485 template <DeviceType T> 486 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event)) { 487 PetscFunctionBegin; 488 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 489 PetscFunctionReturn(0); 490 } 491 492 template <DeviceType T> 493 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event)) { 494 PetscFunctionBegin; 495 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 496 PetscFunctionReturn(0); 497 } 498 499 // initialize the static member variables 500 template <DeviceType T> 501 bool DeviceContext<T>::initialized_ = false; 502 503 template <DeviceType T> 504 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 505 506 template <DeviceType T> 507 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 508 509 } // namespace impl 510 511 // shorten this one up a bit (and instantiate the templates) 512 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 513 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 514 515 // shorthand for what is an EXTREMELY long name 516 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 517 518 } // namespace cupm 519 520 } // namespace device 521 522 } // namespace Petsc 523 524 #endif // __cplusplus 525 526 #endif // PETSCDEVICECONTEXTCUDA_HPP 527