1 #pragma once 2 3 #include <petsc/private/deviceimpl.h> 4 #include <petsc/private/cupmsolverinterface.hpp> 5 #include <petsc/private/logimpl.h> 6 7 #include <petsc/private/cpp/array.hpp> 8 9 #include "../segmentedmempool.hpp" 10 #include "cupmallocator.hpp" 11 #include "cupmstream.hpp" 12 #include "cupmevent.hpp" 13 14 namespace Petsc 15 { 16 17 namespace device 18 { 19 20 namespace cupm 21 { 22 23 namespace impl 24 { 25 26 template <DeviceType T> 27 class DeviceContext : SolverInterface<T> { 28 public: 29 PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); 30 31 private: 32 template <typename H, std::size_t> 33 struct HandleTag { 34 using type = H; 35 }; 36 37 using stream_tag = HandleTag<cupmStream_t, 0>; 38 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 39 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 40 41 using stream_type = CUPMStream<T>; 42 using event_type = CUPMEvent<T>; 43 44 public: 45 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 46 // header, but since we are using the power of templates it must be declared part of 47 // this class to have easy access the same typedefs. Technically one can make a 48 // templated struct outside the class but it's more code for the same result. 49 struct PetscDeviceContext_IMPLS { 50 stream_type stream{}; 51 cupmEvent_t event{}; 52 cupmEvent_t begin{}; // timer-only 53 cupmEvent_t end{}; // timer-only 54 #if PetscDefined(USE_DEBUG) 55 PetscBool timerInUse{}; 56 #endif 57 cupmBlasHandle_t blas{}; 58 cupmSolverHandle_t solver{}; 59 60 constexpr PetscDeviceContext_IMPLS() noexcept = default; 61 62 PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); } 63 64 PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; } 65 66 PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; } 67 }; 68 69 private: 70 static bool initialized_; 71 72 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 73 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 74 75 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 76 77 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 78 79 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 80 81 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 82 83 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 84 // handles 85 static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; } 86 87 static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept 88 { 89 const auto dci = impls_cast_(dctx); 90 auto &handle = blashandles_[dctx->device->deviceId]; 91 92 PetscFunctionBegin; 93 if (!handle) { 94 PetscCall(PetscLogEventsPause()); 95 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 96 for (auto i = 0; i < 3; ++i) { 97 const auto cberr = cupmBlasCreate(handle.ptr_to()); 98 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 99 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 100 if (i != 2) { 101 PetscCall(PetscSleep(3)); 102 continue; 103 } 104 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 105 } 106 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 107 PetscCall(PetscLogEventsResume()); 108 } 109 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 110 dci->blas = handle; 111 PetscFunctionReturn(PETSC_SUCCESS); 112 } 113 114 static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept 115 { 116 const auto dci = impls_cast_(dctx); 117 auto &handle = solverhandles_[dctx->device->deviceId]; 118 119 PetscFunctionBegin; 120 if (!handle) { 121 PetscCall(PetscLogEventsPause()); 122 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 123 for (auto i = 0; i < 3; ++i) { 124 const auto cerr = cupmSolverCreate(&handle); 125 if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break; 126 if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr); 127 if (i < 2) { 128 PetscCall(PetscSleep(3)); 129 continue; 130 } 131 PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName()); 132 } 133 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 134 PetscCall(PetscLogEventsResume()); 135 } 136 PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream())); 137 dci->solver = handle; 138 PetscFunctionReturn(PETSC_SUCCESS); 139 } 140 141 static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 142 { 143 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 144 145 PetscFunctionBegin; 146 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 147 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 148 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 149 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 150 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 151 PetscFunctionReturn(PETSC_SUCCESS); 152 } 153 154 static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 155 156 static PetscErrorCode finalize_() noexcept 157 { 158 PetscFunctionBegin; 159 for (auto &&handle : blashandles_) { 160 if (handle) { 161 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 162 handle = nullptr; 163 } 164 } 165 for (auto &&handle : solverhandles_) { 166 if (handle) { 167 PetscCallCUPMSOLVER(cupmSolverDestroy(handle)); 168 handle = nullptr; 169 } 170 } 171 initialized_ = false; 172 PetscFunctionReturn(PETSC_SUCCESS); 173 } 174 175 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 176 PETSC_NODISCARD static PoolType &default_pool_() noexcept 177 { 178 static PoolType pool; 179 return pool; 180 } 181 182 static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 183 { 184 PetscFunctionBegin; 185 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 186 PetscFunctionReturn(PETSC_SUCCESS); 187 } 188 189 public: 190 // All of these functions MUST be static in order to be callable from C, otherwise they 191 // get the implicit 'this' pointer tacked on 192 static PetscErrorCode destroy(PetscDeviceContext) noexcept; 193 static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept; 194 static PetscErrorCode setUp(PetscDeviceContext) noexcept; 195 static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept; 196 static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept; 197 static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 198 template <typename Handle_t> 199 static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept; 200 template <typename Handle_t> 201 static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept; 202 static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 203 static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept; 204 static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept; 205 static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept; 206 static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept; 207 static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept; 208 static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept; 209 static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept; 210 static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept; 211 212 // not a PetscDeviceContext method, this registers the class 213 static PetscErrorCode initialize(PetscDevice) noexcept; 214 215 // clang-format off 216 static constexpr _DeviceContextOps ops = { 217 PetscDesignatedInitializer(destroy, destroy), 218 PetscDesignatedInitializer(changestreamtype, changeStreamType), 219 PetscDesignatedInitializer(setup, setUp), 220 PetscDesignatedInitializer(query, query), 221 PetscDesignatedInitializer(waitforcontext, waitForContext), 222 PetscDesignatedInitializer(synchronize, synchronize), 223 PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>), 224 PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>), 225 PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>), 226 PetscDesignatedInitializer(begintimer, beginTimer), 227 PetscDesignatedInitializer(endtimer, endTimer), 228 PetscDesignatedInitializer(memalloc, memAlloc), 229 PetscDesignatedInitializer(memfree, memFree), 230 PetscDesignatedInitializer(memcopy, memCopy), 231 PetscDesignatedInitializer(memset, memSet), 232 PetscDesignatedInitializer(createevent, createEvent), 233 PetscDesignatedInitializer(recordevent, recordEvent), 234 PetscDesignatedInitializer(waitforevent, waitForEvent) 235 }; 236 // clang-format on 237 }; 238 239 // not a PetscDeviceContext method, this initializes the CLASS 240 template <DeviceType T> 241 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept 242 { 243 PetscFunctionBegin; 244 if (PetscUnlikely(!initialized_)) { 245 uint64_t threshold = UINT64_MAX; 246 cupmMemPool_t mempool; 247 248 initialized_ = true; 249 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId))); 250 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 251 blashandles_.fill(nullptr); 252 solverhandles_.fill(nullptr); 253 PetscCall(PetscRegisterFinalize(finalize_)); 254 } 255 PetscFunctionReturn(PETSC_SUCCESS); 256 } 257 258 template <DeviceType T> 259 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept 260 { 261 PetscFunctionBegin; 262 if (const auto dci = impls_cast_(dctx)) { 263 PetscCall(dci->stream.destroy()); 264 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event)); 265 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 266 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 267 delete dci; 268 dctx->data = nullptr; 269 } 270 PetscFunctionReturn(PETSC_SUCCESS); 271 } 272 273 template <DeviceType T> 274 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 275 { 276 const auto dci = impls_cast_(dctx); 277 278 PetscFunctionBegin; 279 PetscCall(dci->stream.destroy()); 280 // set these to null so they aren't usable until setup is called again 281 dci->blas = nullptr; 282 dci->solver = nullptr; 283 PetscFunctionReturn(PETSC_SUCCESS); 284 } 285 286 template <DeviceType T> 287 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept 288 { 289 const auto dci = impls_cast_(dctx); 290 auto &event = dci->event; 291 292 PetscFunctionBegin; 293 PetscCall(check_current_device_(dctx)); 294 PetscCall(dci->stream.change_type(dctx->streamType)); 295 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 296 #if PetscDefined(USE_DEBUG) 297 dci->timerInUse = PETSC_FALSE; 298 #endif 299 PetscFunctionReturn(PETSC_SUCCESS); 300 } 301 302 template <DeviceType T> 303 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 304 { 305 PetscFunctionBegin; 306 PetscCall(check_current_device_(dctx)); 307 switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 308 case cupmSuccess: 309 *idle = PETSC_TRUE; 310 break; 311 case cupmErrorNotReady: 312 *idle = PETSC_FALSE; 313 // reset the error 314 cerr = cupmGetLastError(); 315 static_cast<void>(cerr); 316 break; 317 default: 318 PetscCallCUPM(cerr); 319 PetscUnreachable(); 320 } 321 PetscFunctionReturn(PETSC_SUCCESS); 322 } 323 324 template <DeviceType T> 325 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 326 { 327 const auto dcib = impls_cast_(dctxb); 328 const auto event = dcib->event; 329 330 PetscFunctionBegin; 331 PetscCall(check_current_device_(dctxa, dctxb)); 332 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 333 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 334 PetscFunctionReturn(PETSC_SUCCESS); 335 } 336 337 template <DeviceType T> 338 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept 339 { 340 auto idle = PETSC_TRUE; 341 342 PetscFunctionBegin; 343 PetscCall(query(dctx, &idle)); 344 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 345 PetscFunctionReturn(PETSC_SUCCESS); 346 } 347 348 template <DeviceType T> 349 template <typename handle_t> 350 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 351 { 352 PetscFunctionBegin; 353 PetscCall(initialize_handle_(handle_t{}, dctx)); 354 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 355 PetscFunctionReturn(PETSC_SUCCESS); 356 } 357 358 template <DeviceType T> 359 template <typename handle_t> 360 inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept 361 { 362 using handle_type = typename handle_t::type; 363 364 PetscFunctionBegin; 365 PetscCall(initialize_handle_(handle_t{}, dctx)); 366 *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{}))); 367 PetscFunctionReturn(PETSC_SUCCESS); 368 } 369 370 template <DeviceType T> 371 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 372 { 373 const auto dci = impls_cast_(dctx); 374 375 PetscFunctionBegin; 376 PetscCall(check_current_device_(dctx)); 377 #if PetscDefined(USE_DEBUG) 378 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 379 dci->timerInUse = PETSC_TRUE; 380 #endif 381 if (!dci->begin) { 382 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 383 PetscCallCUPM(cupmEventCreate(&dci->begin)); 384 PetscCallCUPM(cupmEventCreate(&dci->end)); 385 } 386 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 387 PetscFunctionReturn(PETSC_SUCCESS); 388 } 389 390 template <DeviceType T> 391 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 392 { 393 float gtime; 394 const auto dci = impls_cast_(dctx); 395 const auto end = dci->end; 396 397 PetscFunctionBegin; 398 PetscCall(check_current_device_(dctx)); 399 #if PetscDefined(USE_DEBUG) 400 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 401 dci->timerInUse = PETSC_FALSE; 402 #endif 403 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 404 PetscCallCUPM(cupmEventSynchronize(end)); 405 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 406 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 407 PetscFunctionReturn(PETSC_SUCCESS); 408 } 409 410 template <DeviceType T> 411 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept 412 { 413 const auto &stream = impls_cast_(dctx)->stream; 414 415 PetscFunctionBegin; 416 PetscCall(check_current_device_(dctx)); 417 PetscCall(check_memtype_(mtype, "allocating")); 418 if (PetscMemTypeHost(mtype)) { 419 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 420 } else { 421 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 422 } 423 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 424 PetscFunctionReturn(PETSC_SUCCESS); 425 } 426 427 template <DeviceType T> 428 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept 429 { 430 const auto &stream = impls_cast_(dctx)->stream; 431 432 PetscFunctionBegin; 433 PetscCall(check_current_device_(dctx)); 434 PetscCall(check_memtype_(mtype, "freeing")); 435 if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS); 436 if (PetscMemTypeHost(mtype)) { 437 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 438 // if ptr exists still exists the pool didn't own it 439 if (*ptr) { 440 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 441 442 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 443 if (registered) { 444 PetscCallCUPM(cupmFreeHost(*ptr)); 445 } else if (managed) { 446 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 447 } 448 } 449 } else { 450 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 451 // if ptr still exists the pool didn't own it 452 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 453 } 454 PetscFunctionReturn(PETSC_SUCCESS); 455 } 456 457 template <DeviceType T> 458 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept 459 { 460 const auto stream = impls_cast_(dctx)->stream.get_stream(); 461 462 PetscFunctionBegin; 463 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 464 if (mode == PETSC_DEVICE_COPY_HTOH) { 465 const auto cerr = cupmStreamQuery(stream); 466 467 // yes this is faster 468 if (cerr == cupmSuccess) { 469 PetscCall(PetscMemcpy(dest, src, n)); 470 PetscFunctionReturn(PETSC_SUCCESS); 471 } else if (cerr == cupmErrorNotReady) { 472 auto PETSC_UNUSED unused = cupmGetLastError(); 473 474 static_cast<void>(unused); 475 } else { 476 PetscCallCUPM(cerr); 477 } 478 } 479 PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 480 PetscFunctionReturn(PETSC_SUCCESS); 481 } 482 483 template <DeviceType T> 484 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept 485 { 486 PetscFunctionBegin; 487 PetscCall(check_current_device_(dctx)); 488 PetscCall(check_memtype_(mtype, "zeroing")); 489 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 490 PetscFunctionReturn(PETSC_SUCCESS); 491 } 492 493 template <DeviceType T> 494 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept 495 { 496 PetscFunctionBegin; 497 PetscCallCXX(event->data = new event_type{}); 498 event->destroy = [](PetscEvent event) { 499 PetscFunctionBegin; 500 delete event_cast_(event); 501 event->data = nullptr; 502 PetscFunctionReturn(PETSC_SUCCESS); 503 }; 504 PetscFunctionReturn(PETSC_SUCCESS); 505 } 506 507 template <DeviceType T> 508 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 509 { 510 PetscFunctionBegin; 511 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 512 PetscFunctionReturn(PETSC_SUCCESS); 513 } 514 515 template <DeviceType T> 516 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 517 { 518 PetscFunctionBegin; 519 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 520 PetscFunctionReturn(PETSC_SUCCESS); 521 } 522 523 // initialize the static member variables 524 template <DeviceType T> 525 bool DeviceContext<T>::initialized_ = false; 526 527 template <DeviceType T> 528 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 529 530 template <DeviceType T> 531 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 532 533 template <DeviceType T> 534 constexpr _DeviceContextOps DeviceContext<T>::ops; 535 536 } // namespace impl 537 538 // shorten this one up a bit (and instantiate the templates) 539 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 540 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 541 542 // shorthand for what is an EXTREMELY long name 543 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 544 545 } // namespace cupm 546 547 } // namespace device 548 549 } // namespace Petsc 550