1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmsolverinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <petsc/private/cpp/array.hpp> 9 10 #include "../segmentedmempool.hpp" 11 #include "cupmallocator.hpp" 12 #include "cupmstream.hpp" 13 #include "cupmevent.hpp" 14 15 namespace Petsc 16 { 17 18 namespace device 19 { 20 21 namespace cupm 22 { 23 24 namespace impl 25 { 26 27 template <DeviceType T> 28 class DeviceContext : SolverInterface<T> { 29 public: 30 PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); 31 32 private: 33 template <typename H, std::size_t> 34 struct HandleTag { 35 using type = H; 36 }; 37 38 using stream_tag = HandleTag<cupmStream_t, 0>; 39 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 40 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 41 42 using stream_type = CUPMStream<T>; 43 using event_type = CUPMEvent<T>; 44 45 public: 46 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 47 // header, but since we are using the power of templates it must be declared part of 48 // this class to have easy access the same typedefs. Technically one can make a 49 // templated struct outside the class but it's more code for the same result. 50 struct PetscDeviceContext_IMPLS { 51 stream_type stream{}; 52 cupmEvent_t event{}; 53 cupmEvent_t begin{}; // timer-only 54 cupmEvent_t end{}; // timer-only 55 #if PetscDefined(USE_DEBUG) 56 PetscBool timerInUse{}; 57 #endif 58 cupmBlasHandle_t blas{}; 59 cupmSolverHandle_t solver{}; 60 61 constexpr PetscDeviceContext_IMPLS() noexcept = default; 62 63 PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); } 64 65 PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; } 66 67 PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; } 68 }; 69 70 private: 71 static bool initialized_; 72 73 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 74 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 75 76 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 77 78 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 79 80 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 81 82 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 83 84 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 85 // handles 86 static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; } 87 88 static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept 89 { 90 const auto dci = impls_cast_(dctx); 91 auto &handle = blashandles_[dctx->device->deviceId]; 92 93 PetscFunctionBegin; 94 if (!handle) { 95 PetscCall(PetscLogEventsPause()); 96 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 97 for (auto i = 0; i < 3; ++i) { 98 const auto cberr = cupmBlasCreate(handle.ptr_to()); 99 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 100 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 101 if (i != 2) { 102 PetscCall(PetscSleep(3)); 103 continue; 104 } 105 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 106 } 107 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 108 PetscCall(PetscLogEventsResume()); 109 } 110 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 111 dci->blas = handle; 112 PetscFunctionReturn(PETSC_SUCCESS); 113 } 114 115 static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept 116 { 117 const auto dci = impls_cast_(dctx); 118 auto &handle = solverhandles_[dctx->device->deviceId]; 119 120 PetscFunctionBegin; 121 if (!handle) { 122 PetscCall(PetscLogEventsPause()); 123 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 124 for (auto i = 0; i < 3; ++i) { 125 const auto cerr = cupmSolverCreate(&handle); 126 if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break; 127 if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr); 128 if (i < 2) { 129 PetscCall(PetscSleep(3)); 130 continue; 131 } 132 PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName()); 133 } 134 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 135 PetscCall(PetscLogEventsResume()); 136 } 137 PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream())); 138 dci->solver = handle; 139 PetscFunctionReturn(PETSC_SUCCESS); 140 } 141 142 static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 143 { 144 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 145 146 PetscFunctionBegin; 147 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 148 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 149 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 150 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 151 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 152 PetscFunctionReturn(PETSC_SUCCESS); 153 } 154 155 static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 156 157 static PetscErrorCode finalize_() noexcept 158 { 159 PetscFunctionBegin; 160 for (auto &&handle : blashandles_) { 161 if (handle) { 162 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 163 handle = nullptr; 164 } 165 } 166 for (auto &&handle : solverhandles_) { 167 if (handle) { 168 PetscCallCUPMSOLVER(cupmSolverDestroy(handle)); 169 handle = nullptr; 170 } 171 } 172 initialized_ = false; 173 PetscFunctionReturn(PETSC_SUCCESS); 174 } 175 176 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 177 PETSC_NODISCARD static PoolType &default_pool_() noexcept 178 { 179 static PoolType pool; 180 return pool; 181 } 182 183 static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 184 { 185 PetscFunctionBegin; 186 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 187 PetscFunctionReturn(PETSC_SUCCESS); 188 } 189 190 public: 191 // All of these functions MUST be static in order to be callable from C, otherwise they 192 // get the implicit 'this' pointer tacked on 193 static PetscErrorCode destroy(PetscDeviceContext) noexcept; 194 static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept; 195 static PetscErrorCode setUp(PetscDeviceContext) noexcept; 196 static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept; 197 static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept; 198 static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 199 template <typename Handle_t> 200 static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept; 201 template <typename Handle_t> 202 static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept; 203 static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 204 static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept; 205 static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept; 206 static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept; 207 static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept; 208 static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept; 209 static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept; 210 static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept; 211 static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept; 212 213 // not a PetscDeviceContext method, this registers the class 214 static PetscErrorCode initialize(PetscDevice) noexcept; 215 216 // clang-format off 217 static constexpr _DeviceContextOps ops = { 218 PetscDesignatedInitializer(destroy, destroy), 219 PetscDesignatedInitializer(changestreamtype, changeStreamType), 220 PetscDesignatedInitializer(setup, setUp), 221 PetscDesignatedInitializer(query, query), 222 PetscDesignatedInitializer(waitforcontext, waitForContext), 223 PetscDesignatedInitializer(synchronize, synchronize), 224 PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>), 225 PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>), 226 PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>), 227 PetscDesignatedInitializer(begintimer, beginTimer), 228 PetscDesignatedInitializer(endtimer, endTimer), 229 PetscDesignatedInitializer(memalloc, memAlloc), 230 PetscDesignatedInitializer(memfree, memFree), 231 PetscDesignatedInitializer(memcopy, memCopy), 232 PetscDesignatedInitializer(memset, memSet), 233 PetscDesignatedInitializer(createevent, createEvent), 234 PetscDesignatedInitializer(recordevent, recordEvent), 235 PetscDesignatedInitializer(waitforevent, waitForEvent) 236 }; 237 // clang-format on 238 }; 239 240 // not a PetscDeviceContext method, this initializes the CLASS 241 template <DeviceType T> 242 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept 243 { 244 PetscFunctionBegin; 245 if (PetscUnlikely(!initialized_)) { 246 uint64_t threshold = UINT64_MAX; 247 cupmMemPool_t mempool; 248 249 initialized_ = true; 250 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId))); 251 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 252 blashandles_.fill(nullptr); 253 solverhandles_.fill(nullptr); 254 PetscCall(PetscRegisterFinalize(finalize_)); 255 } 256 PetscFunctionReturn(PETSC_SUCCESS); 257 } 258 259 template <DeviceType T> 260 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept 261 { 262 PetscFunctionBegin; 263 if (const auto dci = impls_cast_(dctx)) { 264 PetscCall(dci->stream.destroy()); 265 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event)); 266 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 267 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 268 delete dci; 269 dctx->data = nullptr; 270 } 271 PetscFunctionReturn(PETSC_SUCCESS); 272 } 273 274 template <DeviceType T> 275 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 276 { 277 const auto dci = impls_cast_(dctx); 278 279 PetscFunctionBegin; 280 PetscCall(dci->stream.destroy()); 281 // set these to null so they aren't usable until setup is called again 282 dci->blas = nullptr; 283 dci->solver = nullptr; 284 PetscFunctionReturn(PETSC_SUCCESS); 285 } 286 287 template <DeviceType T> 288 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept 289 { 290 const auto dci = impls_cast_(dctx); 291 auto &event = dci->event; 292 293 PetscFunctionBegin; 294 PetscCall(check_current_device_(dctx)); 295 PetscCall(dci->stream.change_type(dctx->streamType)); 296 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 297 #if PetscDefined(USE_DEBUG) 298 dci->timerInUse = PETSC_FALSE; 299 #endif 300 PetscFunctionReturn(PETSC_SUCCESS); 301 } 302 303 template <DeviceType T> 304 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 305 { 306 PetscFunctionBegin; 307 PetscCall(check_current_device_(dctx)); 308 switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 309 case cupmSuccess: 310 *idle = PETSC_TRUE; 311 break; 312 case cupmErrorNotReady: 313 *idle = PETSC_FALSE; 314 // reset the error 315 cerr = cupmGetLastError(); 316 static_cast<void>(cerr); 317 break; 318 default: 319 PetscCallCUPM(cerr); 320 PetscUnreachable(); 321 } 322 PetscFunctionReturn(PETSC_SUCCESS); 323 } 324 325 template <DeviceType T> 326 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 327 { 328 const auto dcib = impls_cast_(dctxb); 329 const auto event = dcib->event; 330 331 PetscFunctionBegin; 332 PetscCall(check_current_device_(dctxa, dctxb)); 333 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 334 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 335 PetscFunctionReturn(PETSC_SUCCESS); 336 } 337 338 template <DeviceType T> 339 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept 340 { 341 auto idle = PETSC_TRUE; 342 343 PetscFunctionBegin; 344 PetscCall(query(dctx, &idle)); 345 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 346 PetscFunctionReturn(PETSC_SUCCESS); 347 } 348 349 template <DeviceType T> 350 template <typename handle_t> 351 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 352 { 353 PetscFunctionBegin; 354 PetscCall(initialize_handle_(handle_t{}, dctx)); 355 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 356 PetscFunctionReturn(PETSC_SUCCESS); 357 } 358 359 template <DeviceType T> 360 template <typename handle_t> 361 inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept 362 { 363 using handle_type = typename handle_t::type; 364 365 PetscFunctionBegin; 366 PetscCall(initialize_handle_(handle_t{}, dctx)); 367 *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{}))); 368 PetscFunctionReturn(PETSC_SUCCESS); 369 } 370 371 template <DeviceType T> 372 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 373 { 374 const auto dci = impls_cast_(dctx); 375 376 PetscFunctionBegin; 377 PetscCall(check_current_device_(dctx)); 378 #if PetscDefined(USE_DEBUG) 379 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 380 dci->timerInUse = PETSC_TRUE; 381 #endif 382 if (!dci->begin) { 383 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 384 PetscCallCUPM(cupmEventCreate(&dci->begin)); 385 PetscCallCUPM(cupmEventCreate(&dci->end)); 386 } 387 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 388 PetscFunctionReturn(PETSC_SUCCESS); 389 } 390 391 template <DeviceType T> 392 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 393 { 394 float gtime; 395 const auto dci = impls_cast_(dctx); 396 const auto end = dci->end; 397 398 PetscFunctionBegin; 399 PetscCall(check_current_device_(dctx)); 400 #if PetscDefined(USE_DEBUG) 401 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 402 dci->timerInUse = PETSC_FALSE; 403 #endif 404 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 405 PetscCallCUPM(cupmEventSynchronize(end)); 406 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 407 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 408 PetscFunctionReturn(PETSC_SUCCESS); 409 } 410 411 template <DeviceType T> 412 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept 413 { 414 const auto &stream = impls_cast_(dctx)->stream; 415 416 PetscFunctionBegin; 417 PetscCall(check_current_device_(dctx)); 418 PetscCall(check_memtype_(mtype, "allocating")); 419 if (PetscMemTypeHost(mtype)) { 420 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 421 } else { 422 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 423 } 424 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 425 PetscFunctionReturn(PETSC_SUCCESS); 426 } 427 428 template <DeviceType T> 429 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept 430 { 431 const auto &stream = impls_cast_(dctx)->stream; 432 433 PetscFunctionBegin; 434 PetscCall(check_current_device_(dctx)); 435 PetscCall(check_memtype_(mtype, "freeing")); 436 if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS); 437 if (PetscMemTypeHost(mtype)) { 438 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 439 // if ptr exists still exists the pool didn't own it 440 if (*ptr) { 441 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 442 443 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 444 if (registered) { 445 PetscCallCUPM(cupmFreeHost(*ptr)); 446 } else if (managed) { 447 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 448 } 449 } 450 } else { 451 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 452 // if ptr still exists the pool didn't own it 453 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 454 } 455 PetscFunctionReturn(PETSC_SUCCESS); 456 } 457 458 template <DeviceType T> 459 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept 460 { 461 const auto stream = impls_cast_(dctx)->stream.get_stream(); 462 463 PetscFunctionBegin; 464 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 465 if (mode == PETSC_DEVICE_COPY_HTOH) { 466 const auto cerr = cupmStreamQuery(stream); 467 468 // yes this is faster 469 if (cerr == cupmSuccess) { 470 PetscCall(PetscMemcpy(dest, src, n)); 471 PetscFunctionReturn(PETSC_SUCCESS); 472 } else if (cerr == cupmErrorNotReady) { 473 auto PETSC_UNUSED unused = cupmGetLastError(); 474 475 static_cast<void>(unused); 476 } else { 477 PetscCallCUPM(cerr); 478 } 479 } 480 PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 481 PetscFunctionReturn(PETSC_SUCCESS); 482 } 483 484 template <DeviceType T> 485 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept 486 { 487 PetscFunctionBegin; 488 PetscCall(check_current_device_(dctx)); 489 PetscCall(check_memtype_(mtype, "zeroing")); 490 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 491 PetscFunctionReturn(PETSC_SUCCESS); 492 } 493 494 template <DeviceType T> 495 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept 496 { 497 PetscFunctionBegin; 498 PetscCallCXX(event->data = new event_type{}); 499 event->destroy = [](PetscEvent event) { 500 PetscFunctionBegin; 501 delete event_cast_(event); 502 event->data = nullptr; 503 PetscFunctionReturn(PETSC_SUCCESS); 504 }; 505 PetscFunctionReturn(PETSC_SUCCESS); 506 } 507 508 template <DeviceType T> 509 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 510 { 511 PetscFunctionBegin; 512 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 513 PetscFunctionReturn(PETSC_SUCCESS); 514 } 515 516 template <DeviceType T> 517 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 518 { 519 PetscFunctionBegin; 520 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 521 PetscFunctionReturn(PETSC_SUCCESS); 522 } 523 524 // initialize the static member variables 525 template <DeviceType T> 526 bool DeviceContext<T>::initialized_ = false; 527 528 template <DeviceType T> 529 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 530 531 template <DeviceType T> 532 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 533 534 template <DeviceType T> 535 constexpr _DeviceContextOps DeviceContext<T>::ops; 536 537 } // namespace impl 538 539 // shorten this one up a bit (and instantiate the templates) 540 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 541 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 542 543 // shorthand for what is an EXTREMELY long name 544 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 545 546 } // namespace cupm 547 548 } // namespace device 549 550 } // namespace Petsc 551 552 #endif // PETSCDEVICECONTEXTCUDA_HPP 553