1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmsolverinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <petsc/private/cpp/array.hpp> 9 10 #include "../segmentedmempool.hpp" 11 #include "cupmallocator.hpp" 12 #include "cupmstream.hpp" 13 #include "cupmevent.hpp" 14 15 namespace Petsc 16 { 17 18 namespace device 19 { 20 21 namespace cupm 22 { 23 24 namespace impl 25 { 26 27 template <DeviceType T> 28 class DeviceContext : SolverInterface<T> { 29 public: 30 PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); 31 32 private: 33 template <typename H, std::size_t> 34 struct HandleTag { 35 using type = H; 36 }; 37 38 using stream_tag = HandleTag<cupmStream_t, 0>; 39 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 40 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 41 42 using stream_type = CUPMStream<T>; 43 using event_type = CUPMEvent<T>; 44 45 public: 46 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 47 // header, but since we are using the power of templates it must be declared part of 48 // this class to have easy access the same typedefs. Technically one can make a 49 // templated struct outside the class but it's more code for the same result. 50 struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 51 stream_type stream{}; 52 cupmEvent_t event{}; 53 cupmEvent_t begin{}; // timer-only 54 cupmEvent_t end{}; // timer-only 55 #if PetscDefined(USE_DEBUG) 56 PetscBool timerInUse{}; 57 #endif 58 cupmBlasHandle_t blas{}; 59 cupmSolverHandle_t solver{}; 60 61 constexpr PetscDeviceContext_IMPLS() noexcept = default; 62 63 PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); } 64 65 PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; } 66 67 PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; } 68 }; 69 70 private: 71 static bool initialized_; 72 73 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 74 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 75 76 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 77 78 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 79 80 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 81 82 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 83 84 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 85 // handles 86 static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; } 87 88 static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept 89 { 90 const auto dci = impls_cast_(dctx); 91 auto &handle = blashandles_[dctx->device->deviceId]; 92 93 PetscFunctionBegin; 94 if (!handle) { 95 PetscLogEvent event; 96 97 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 98 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 99 for (auto i = 0; i < 3; ++i) { 100 const auto cberr = cupmBlasCreate(handle.ptr_to()); 101 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 102 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 103 if (i != 2) { 104 PetscCall(PetscSleep(3)); 105 continue; 106 } 107 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 108 } 109 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 110 PetscCall(PetscLogEventResume_Internal(event)); 111 } 112 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 113 dci->blas = handle; 114 PetscFunctionReturn(PETSC_SUCCESS); 115 } 116 117 static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept 118 { 119 const auto dci = impls_cast_(dctx); 120 auto &handle = solverhandles_[dctx->device->deviceId]; 121 122 PetscFunctionBegin; 123 if (!handle) { 124 PetscLogEvent event; 125 126 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 127 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 128 for (auto i = 0; i < 3; ++i) { 129 const auto cerr = cupmSolverCreate(&handle); 130 if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break; 131 if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr); 132 if (i < 2) { 133 PetscCall(PetscSleep(3)); 134 continue; 135 } 136 PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName()); 137 } 138 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 139 PetscCall(PetscLogEventResume_Internal(event)); 140 } 141 PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream())); 142 dci->solver = handle; 143 PetscFunctionReturn(PETSC_SUCCESS); 144 } 145 146 static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 147 { 148 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 149 150 PetscFunctionBegin; 151 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 152 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 153 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 154 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 155 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 156 PetscFunctionReturn(PETSC_SUCCESS); 157 } 158 159 static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 160 161 static PetscErrorCode finalize_() noexcept 162 { 163 PetscFunctionBegin; 164 for (auto &&handle : blashandles_) { 165 if (handle) { 166 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 167 handle = nullptr; 168 } 169 } 170 171 for (auto &&handle : solverhandles_) { 172 if (handle) { 173 PetscCallCUPMSOLVER(cupmSolverDestroy(handle)); 174 handle = nullptr; 175 } 176 } 177 initialized_ = false; 178 PetscFunctionReturn(PETSC_SUCCESS); 179 } 180 181 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 182 PETSC_NODISCARD static PoolType &default_pool_() noexcept 183 { 184 static PoolType pool; 185 return pool; 186 } 187 188 static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 189 { 190 PetscFunctionBegin; 191 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 192 PetscFunctionReturn(PETSC_SUCCESS); 193 } 194 195 public: 196 // All of these functions MUST be static in order to be callable from C, otherwise they 197 // get the implicit 'this' pointer tacked on 198 static PetscErrorCode destroy(PetscDeviceContext) noexcept; 199 static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept; 200 static PetscErrorCode setUp(PetscDeviceContext) noexcept; 201 static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept; 202 static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept; 203 static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 204 template <typename Handle_t> 205 static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept; 206 static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 207 static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept; 208 static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept; 209 static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept; 210 static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept; 211 static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept; 212 static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept; 213 static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept; 214 static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept; 215 216 // not a PetscDeviceContext method, this registers the class 217 static PetscErrorCode initialize(PetscDevice) noexcept; 218 219 // clang-format off 220 static constexpr _DeviceContextOps ops = { 221 PetscDesignatedInitializer(destroy, destroy), 222 PetscDesignatedInitializer(changestreamtype, changeStreamType), 223 PetscDesignatedInitializer(setup, setUp), 224 PetscDesignatedInitializer(query, query), 225 PetscDesignatedInitializer(waitforcontext, waitForContext), 226 PetscDesignatedInitializer(synchronize, synchronize), 227 PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>), 228 PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>), 229 PetscDesignatedInitializer(getstreamhandle, getHandle<stream_tag>), 230 PetscDesignatedInitializer(begintimer, beginTimer), 231 PetscDesignatedInitializer(endtimer, endTimer), 232 PetscDesignatedInitializer(memalloc, memAlloc), 233 PetscDesignatedInitializer(memfree, memFree), 234 PetscDesignatedInitializer(memcopy, memCopy), 235 PetscDesignatedInitializer(memset, memSet), 236 PetscDesignatedInitializer(createevent, createEvent), 237 PetscDesignatedInitializer(recordevent, recordEvent), 238 PetscDesignatedInitializer(waitforevent, waitForEvent) 239 }; 240 // clang-format on 241 }; 242 243 // not a PetscDeviceContext method, this initializes the CLASS 244 template <DeviceType T> 245 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept 246 { 247 PetscFunctionBegin; 248 if (PetscUnlikely(!initialized_)) { 249 uint64_t threshold = UINT64_MAX; 250 cupmMemPool_t mempool; 251 252 initialized_ = true; 253 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId))); 254 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 255 blashandles_.fill(nullptr); 256 solverhandles_.fill(nullptr); 257 PetscCall(PetscRegisterFinalize(finalize_)); 258 } 259 PetscFunctionReturn(PETSC_SUCCESS); 260 } 261 262 template <DeviceType T> 263 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept 264 { 265 PetscFunctionBegin; 266 if (const auto dci = impls_cast_(dctx)) { 267 PetscCall(dci->stream.destroy()); 268 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event)); 269 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 270 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 271 delete dci; 272 dctx->data = nullptr; 273 } 274 PetscFunctionReturn(PETSC_SUCCESS); 275 } 276 277 template <DeviceType T> 278 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 279 { 280 const auto dci = impls_cast_(dctx); 281 282 PetscFunctionBegin; 283 PetscCall(dci->stream.destroy()); 284 // set these to null so they aren't usable until setup is called again 285 dci->blas = nullptr; 286 dci->solver = nullptr; 287 PetscFunctionReturn(PETSC_SUCCESS); 288 } 289 290 template <DeviceType T> 291 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept 292 { 293 const auto dci = impls_cast_(dctx); 294 auto &event = dci->event; 295 296 PetscFunctionBegin; 297 PetscCall(check_current_device_(dctx)); 298 PetscCall(dci->stream.change_type(dctx->streamType)); 299 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 300 #if PetscDefined(USE_DEBUG) 301 dci->timerInUse = PETSC_FALSE; 302 #endif 303 PetscFunctionReturn(PETSC_SUCCESS); 304 } 305 306 template <DeviceType T> 307 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 308 { 309 PetscFunctionBegin; 310 PetscCall(check_current_device_(dctx)); 311 switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 312 case cupmSuccess: 313 *idle = PETSC_TRUE; 314 break; 315 case cupmErrorNotReady: 316 *idle = PETSC_FALSE; 317 // reset the error 318 cerr = cupmGetLastError(); 319 static_cast<void>(cerr); 320 break; 321 default: 322 PetscCallCUPM(cerr); 323 PetscUnreachable(); 324 } 325 PetscFunctionReturn(PETSC_SUCCESS); 326 } 327 328 template <DeviceType T> 329 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 330 { 331 const auto dcib = impls_cast_(dctxb); 332 const auto event = dcib->event; 333 334 PetscFunctionBegin; 335 PetscCall(check_current_device_(dctxa, dctxb)); 336 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 337 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 338 PetscFunctionReturn(PETSC_SUCCESS); 339 } 340 341 template <DeviceType T> 342 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept 343 { 344 auto idle = PETSC_TRUE; 345 346 PetscFunctionBegin; 347 PetscCall(query(dctx, &idle)); 348 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 349 PetscFunctionReturn(PETSC_SUCCESS); 350 } 351 352 template <DeviceType T> 353 template <typename handle_t> 354 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 355 { 356 PetscFunctionBegin; 357 PetscCall(initialize_handle_(handle_t{}, dctx)); 358 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 359 PetscFunctionReturn(PETSC_SUCCESS); 360 } 361 362 template <DeviceType T> 363 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 364 { 365 const auto dci = impls_cast_(dctx); 366 367 PetscFunctionBegin; 368 PetscCall(check_current_device_(dctx)); 369 #if PetscDefined(USE_DEBUG) 370 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 371 dci->timerInUse = PETSC_TRUE; 372 #endif 373 if (!dci->begin) { 374 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 375 PetscCallCUPM(cupmEventCreate(&dci->begin)); 376 PetscCallCUPM(cupmEventCreate(&dci->end)); 377 } 378 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 379 PetscFunctionReturn(PETSC_SUCCESS); 380 } 381 382 template <DeviceType T> 383 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 384 { 385 float gtime; 386 const auto dci = impls_cast_(dctx); 387 const auto end = dci->end; 388 389 PetscFunctionBegin; 390 PetscCall(check_current_device_(dctx)); 391 #if PetscDefined(USE_DEBUG) 392 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 393 dci->timerInUse = PETSC_FALSE; 394 #endif 395 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 396 PetscCallCUPM(cupmEventSynchronize(end)); 397 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 398 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 399 PetscFunctionReturn(PETSC_SUCCESS); 400 } 401 402 template <DeviceType T> 403 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept 404 { 405 const auto &stream = impls_cast_(dctx)->stream; 406 407 PetscFunctionBegin; 408 PetscCall(check_current_device_(dctx)); 409 PetscCall(check_memtype_(mtype, "allocating")); 410 if (PetscMemTypeHost(mtype)) { 411 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 412 } else { 413 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 414 } 415 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 416 PetscFunctionReturn(PETSC_SUCCESS); 417 } 418 419 template <DeviceType T> 420 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept 421 { 422 const auto &stream = impls_cast_(dctx)->stream; 423 424 PetscFunctionBegin; 425 PetscCall(check_current_device_(dctx)); 426 PetscCall(check_memtype_(mtype, "freeing")); 427 if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS); 428 if (PetscMemTypeHost(mtype)) { 429 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 430 // if ptr exists still exists the pool didn't own it 431 if (*ptr) { 432 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 433 434 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 435 if (registered) { 436 PetscCallCUPM(cupmFreeHost(*ptr)); 437 } else if (managed) { 438 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 439 } 440 } 441 } else { 442 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 443 // if ptr still exists the pool didn't own it 444 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 445 } 446 PetscFunctionReturn(PETSC_SUCCESS); 447 } 448 449 template <DeviceType T> 450 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept 451 { 452 const auto stream = impls_cast_(dctx)->stream.get_stream(); 453 454 PetscFunctionBegin; 455 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 456 if (mode == PETSC_DEVICE_COPY_HTOH) { 457 const auto cerr = cupmStreamQuery(stream); 458 459 // yes this is faster 460 if (cerr == cupmSuccess) { 461 PetscCall(PetscMemcpy(dest, src, n)); 462 PetscFunctionReturn(PETSC_SUCCESS); 463 } else if (cerr == cupmErrorNotReady) { 464 auto PETSC_UNUSED unused = cupmGetLastError(); 465 466 static_cast<void>(unused); 467 } else { 468 PetscCallCUPM(cerr); 469 } 470 } 471 PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 472 PetscFunctionReturn(PETSC_SUCCESS); 473 } 474 475 template <DeviceType T> 476 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept 477 { 478 PetscFunctionBegin; 479 PetscCall(check_current_device_(dctx)); 480 PetscCall(check_memtype_(mtype, "zeroing")); 481 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 482 PetscFunctionReturn(PETSC_SUCCESS); 483 } 484 485 template <DeviceType T> 486 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept 487 { 488 PetscFunctionBegin; 489 PetscCallCXX(event->data = new event_type()); 490 event->destroy = [](PetscEvent event) { 491 PetscFunctionBegin; 492 delete event_cast_(event); 493 event->data = nullptr; 494 PetscFunctionReturn(PETSC_SUCCESS); 495 }; 496 PetscFunctionReturn(PETSC_SUCCESS); 497 } 498 499 template <DeviceType T> 500 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 501 { 502 PetscFunctionBegin; 503 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 504 PetscFunctionReturn(PETSC_SUCCESS); 505 } 506 507 template <DeviceType T> 508 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 509 { 510 PetscFunctionBegin; 511 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 512 PetscFunctionReturn(PETSC_SUCCESS); 513 } 514 515 // initialize the static member variables 516 template <DeviceType T> 517 bool DeviceContext<T>::initialized_ = false; 518 519 template <DeviceType T> 520 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 521 522 template <DeviceType T> 523 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 524 525 template <DeviceType T> 526 constexpr _DeviceContextOps DeviceContext<T>::ops; 527 528 } // namespace impl 529 530 // shorten this one up a bit (and instantiate the templates) 531 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 532 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 533 534 // shorthand for what is an EXTREMELY long name 535 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 536 537 } // namespace cupm 538 539 } // namespace device 540 541 } // namespace Petsc 542 543 #endif // PETSCDEVICECONTEXTCUDA_HPP 544