1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmsolverinterface.hpp> 6 #include <petsc/private/logimpl.h> 7 8 #include <petsc/private/cpp/array.hpp> 9 10 #include "../segmentedmempool.hpp" 11 #include "cupmallocator.hpp" 12 #include "cupmstream.hpp" 13 #include "cupmevent.hpp" 14 15 #if defined(__cplusplus) 16 17 namespace Petsc 18 { 19 20 namespace device 21 { 22 23 namespace cupm 24 { 25 26 namespace impl 27 { 28 29 template <DeviceType T> 30 class DeviceContext : SolverInterface<T> { 31 public: 32 PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); 33 34 private: 35 template <typename H, std::size_t> 36 struct HandleTag { 37 using type = H; 38 }; 39 40 using stream_tag = HandleTag<cupmStream_t, 0>; 41 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 42 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 43 44 using stream_type = CUPMStream<T>; 45 using event_type = CUPMEvent<T>; 46 47 public: 48 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 49 // header, but since we are using the power of templates it must be declared part of 50 // this class to have easy access the same typedefs. Technically one can make a 51 // templated struct outside the class but it's more code for the same result. 52 struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> { 53 stream_type stream{}; 54 cupmEvent_t event{}; 55 cupmEvent_t begin{}; // timer-only 56 cupmEvent_t end{}; // timer-only 57 #if PetscDefined(USE_DEBUG) 58 PetscBool timerInUse{}; 59 #endif 60 cupmBlasHandle_t blas{}; 61 cupmSolverHandle_t solver{}; 62 63 constexpr PetscDeviceContext_IMPLS() noexcept = default; 64 65 PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); } 66 67 PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; } 68 69 PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; } 70 }; 71 72 private: 73 static bool initialized_; 74 75 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 76 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 77 78 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 79 80 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 81 82 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 83 84 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 85 86 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 87 // handles 88 static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; } 89 90 static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept 91 { 92 const auto dci = impls_cast_(dctx); 93 auto &handle = blashandles_[dctx->device->deviceId]; 94 95 PetscFunctionBegin; 96 if (!handle) { 97 PetscLogEvent event; 98 99 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 100 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 101 for (auto i = 0; i < 3; ++i) { 102 const auto cberr = cupmBlasCreate(handle.ptr_to()); 103 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 104 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 105 if (i != 2) { 106 PetscCall(PetscSleep(3)); 107 continue; 108 } 109 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 110 } 111 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 112 PetscCall(PetscLogEventResume_Internal(event)); 113 } 114 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 115 dci->blas = handle; 116 PetscFunctionReturn(PETSC_SUCCESS); 117 } 118 119 static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept 120 { 121 const auto dci = impls_cast_(dctx); 122 auto &handle = solverhandles_[dctx->device->deviceId]; 123 124 PetscFunctionBegin; 125 if (!handle) { 126 PetscLogEvent event; 127 128 PetscCall(PetscLogPauseCurrentEvent_Internal(&event)); 129 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 130 for (auto i = 0; i < 3; ++i) { 131 const auto cerr = cupmSolverCreate(&handle); 132 if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break; 133 if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr); 134 if (i < 2) { 135 PetscCall(PetscSleep(3)); 136 continue; 137 } 138 PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName()); 139 } 140 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 141 PetscCall(PetscLogEventResume_Internal(event)); 142 } 143 PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream())); 144 dci->solver = handle; 145 PetscFunctionReturn(PETSC_SUCCESS); 146 } 147 148 static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 149 { 150 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 151 152 PetscFunctionBegin; 153 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 154 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 155 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 156 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 157 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 158 PetscFunctionReturn(PETSC_SUCCESS); 159 } 160 161 static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 162 163 static PetscErrorCode finalize_() noexcept 164 { 165 PetscFunctionBegin; 166 for (auto &&handle : blashandles_) { 167 if (handle) { 168 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 169 handle = nullptr; 170 } 171 } 172 173 for (auto &&handle : solverhandles_) { 174 if (handle) { 175 PetscCallCUPMSOLVER(cupmSolverDestroy(handle)); 176 handle = nullptr; 177 } 178 } 179 initialized_ = false; 180 PetscFunctionReturn(PETSC_SUCCESS); 181 } 182 183 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 184 PETSC_NODISCARD static PoolType &default_pool_() noexcept 185 { 186 static PoolType pool; 187 return pool; 188 } 189 190 static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 191 { 192 PetscFunctionBegin; 193 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 194 PetscFunctionReturn(PETSC_SUCCESS); 195 } 196 197 public: 198 // All of these functions MUST be static in order to be callable from C, otherwise they 199 // get the implicit 'this' pointer tacked on 200 static PetscErrorCode destroy(PetscDeviceContext) noexcept; 201 static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept; 202 static PetscErrorCode setUp(PetscDeviceContext) noexcept; 203 static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept; 204 static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept; 205 static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 206 template <typename Handle_t> 207 static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept; 208 template <typename Handle_t> 209 static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept; 210 static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 211 static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept; 212 static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept; 213 static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept; 214 static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept; 215 static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept; 216 static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept; 217 static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept; 218 static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept; 219 220 // not a PetscDeviceContext method, this registers the class 221 static PetscErrorCode initialize(PetscDevice) noexcept; 222 223 // clang-format off 224 static constexpr _DeviceContextOps ops = { 225 PetscDesignatedInitializer(destroy, destroy), 226 PetscDesignatedInitializer(changestreamtype, changeStreamType), 227 PetscDesignatedInitializer(setup, setUp), 228 PetscDesignatedInitializer(query, query), 229 PetscDesignatedInitializer(waitforcontext, waitForContext), 230 PetscDesignatedInitializer(synchronize, synchronize), 231 PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>), 232 PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>), 233 PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>), 234 PetscDesignatedInitializer(begintimer, beginTimer), 235 PetscDesignatedInitializer(endtimer, endTimer), 236 PetscDesignatedInitializer(memalloc, memAlloc), 237 PetscDesignatedInitializer(memfree, memFree), 238 PetscDesignatedInitializer(memcopy, memCopy), 239 PetscDesignatedInitializer(memset, memSet), 240 PetscDesignatedInitializer(createevent, createEvent), 241 PetscDesignatedInitializer(recordevent, recordEvent), 242 PetscDesignatedInitializer(waitforevent, waitForEvent) 243 }; 244 // clang-format on 245 }; 246 247 // not a PetscDeviceContext method, this initializes the CLASS 248 template <DeviceType T> 249 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept 250 { 251 PetscFunctionBegin; 252 if (PetscUnlikely(!initialized_)) { 253 uint64_t threshold = UINT64_MAX; 254 cupmMemPool_t mempool; 255 256 initialized_ = true; 257 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId))); 258 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 259 blashandles_.fill(nullptr); 260 solverhandles_.fill(nullptr); 261 PetscCall(PetscRegisterFinalize(finalize_)); 262 } 263 PetscFunctionReturn(PETSC_SUCCESS); 264 } 265 266 template <DeviceType T> 267 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept 268 { 269 PetscFunctionBegin; 270 if (const auto dci = impls_cast_(dctx)) { 271 PetscCall(dci->stream.destroy()); 272 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event)); 273 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 274 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 275 delete dci; 276 dctx->data = nullptr; 277 } 278 PetscFunctionReturn(PETSC_SUCCESS); 279 } 280 281 template <DeviceType T> 282 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 283 { 284 const auto dci = impls_cast_(dctx); 285 286 PetscFunctionBegin; 287 PetscCall(dci->stream.destroy()); 288 // set these to null so they aren't usable until setup is called again 289 dci->blas = nullptr; 290 dci->solver = nullptr; 291 PetscFunctionReturn(PETSC_SUCCESS); 292 } 293 294 template <DeviceType T> 295 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept 296 { 297 const auto dci = impls_cast_(dctx); 298 auto &event = dci->event; 299 300 PetscFunctionBegin; 301 PetscCall(check_current_device_(dctx)); 302 PetscCall(dci->stream.change_type(dctx->streamType)); 303 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 304 #if PetscDefined(USE_DEBUG) 305 dci->timerInUse = PETSC_FALSE; 306 #endif 307 PetscFunctionReturn(PETSC_SUCCESS); 308 } 309 310 template <DeviceType T> 311 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 312 { 313 PetscFunctionBegin; 314 PetscCall(check_current_device_(dctx)); 315 switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 316 case cupmSuccess: 317 *idle = PETSC_TRUE; 318 break; 319 case cupmErrorNotReady: 320 *idle = PETSC_FALSE; 321 // reset the error 322 cerr = cupmGetLastError(); 323 static_cast<void>(cerr); 324 break; 325 default: 326 PetscCallCUPM(cerr); 327 PetscUnreachable(); 328 } 329 PetscFunctionReturn(PETSC_SUCCESS); 330 } 331 332 template <DeviceType T> 333 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 334 { 335 const auto dcib = impls_cast_(dctxb); 336 const auto event = dcib->event; 337 338 PetscFunctionBegin; 339 PetscCall(check_current_device_(dctxa, dctxb)); 340 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 341 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 342 PetscFunctionReturn(PETSC_SUCCESS); 343 } 344 345 template <DeviceType T> 346 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept 347 { 348 auto idle = PETSC_TRUE; 349 350 PetscFunctionBegin; 351 PetscCall(query(dctx, &idle)); 352 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 353 PetscFunctionReturn(PETSC_SUCCESS); 354 } 355 356 template <DeviceType T> 357 template <typename handle_t> 358 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 359 { 360 PetscFunctionBegin; 361 PetscCall(initialize_handle_(handle_t{}, dctx)); 362 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 363 PetscFunctionReturn(PETSC_SUCCESS); 364 } 365 366 template <DeviceType T> 367 template <typename handle_t> 368 inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept 369 { 370 using handle_type = typename handle_t::type; 371 372 PetscFunctionBegin; 373 PetscCall(initialize_handle_(handle_t{}, dctx)); 374 *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{}))); 375 PetscFunctionReturn(PETSC_SUCCESS); 376 } 377 378 template <DeviceType T> 379 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 380 { 381 const auto dci = impls_cast_(dctx); 382 383 PetscFunctionBegin; 384 PetscCall(check_current_device_(dctx)); 385 #if PetscDefined(USE_DEBUG) 386 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 387 dci->timerInUse = PETSC_TRUE; 388 #endif 389 if (!dci->begin) { 390 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 391 PetscCallCUPM(cupmEventCreate(&dci->begin)); 392 PetscCallCUPM(cupmEventCreate(&dci->end)); 393 } 394 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 395 PetscFunctionReturn(PETSC_SUCCESS); 396 } 397 398 template <DeviceType T> 399 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 400 { 401 float gtime; 402 const auto dci = impls_cast_(dctx); 403 const auto end = dci->end; 404 405 PetscFunctionBegin; 406 PetscCall(check_current_device_(dctx)); 407 #if PetscDefined(USE_DEBUG) 408 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 409 dci->timerInUse = PETSC_FALSE; 410 #endif 411 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 412 PetscCallCUPM(cupmEventSynchronize(end)); 413 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 414 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 415 PetscFunctionReturn(PETSC_SUCCESS); 416 } 417 418 template <DeviceType T> 419 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept 420 { 421 const auto &stream = impls_cast_(dctx)->stream; 422 423 PetscFunctionBegin; 424 PetscCall(check_current_device_(dctx)); 425 PetscCall(check_memtype_(mtype, "allocating")); 426 if (PetscMemTypeHost(mtype)) { 427 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 428 } else { 429 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 430 } 431 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 432 PetscFunctionReturn(PETSC_SUCCESS); 433 } 434 435 template <DeviceType T> 436 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept 437 { 438 const auto &stream = impls_cast_(dctx)->stream; 439 440 PetscFunctionBegin; 441 PetscCall(check_current_device_(dctx)); 442 PetscCall(check_memtype_(mtype, "freeing")); 443 if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS); 444 if (PetscMemTypeHost(mtype)) { 445 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 446 // if ptr exists still exists the pool didn't own it 447 if (*ptr) { 448 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 449 450 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 451 if (registered) { 452 PetscCallCUPM(cupmFreeHost(*ptr)); 453 } else if (managed) { 454 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 455 } 456 } 457 } else { 458 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 459 // if ptr still exists the pool didn't own it 460 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 461 } 462 PetscFunctionReturn(PETSC_SUCCESS); 463 } 464 465 template <DeviceType T> 466 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept 467 { 468 const auto stream = impls_cast_(dctx)->stream.get_stream(); 469 470 PetscFunctionBegin; 471 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 472 if (mode == PETSC_DEVICE_COPY_HTOH) { 473 const auto cerr = cupmStreamQuery(stream); 474 475 // yes this is faster 476 if (cerr == cupmSuccess) { 477 PetscCall(PetscMemcpy(dest, src, n)); 478 PetscFunctionReturn(PETSC_SUCCESS); 479 } else if (cerr == cupmErrorNotReady) { 480 auto PETSC_UNUSED unused = cupmGetLastError(); 481 482 static_cast<void>(unused); 483 } else { 484 PetscCallCUPM(cerr); 485 } 486 } 487 PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 488 PetscFunctionReturn(PETSC_SUCCESS); 489 } 490 491 template <DeviceType T> 492 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept 493 { 494 PetscFunctionBegin; 495 PetscCall(check_current_device_(dctx)); 496 PetscCall(check_memtype_(mtype, "zeroing")); 497 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 498 PetscFunctionReturn(PETSC_SUCCESS); 499 } 500 501 template <DeviceType T> 502 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept 503 { 504 PetscFunctionBegin; 505 PetscCallCXX(event->data = new event_type()); 506 event->destroy = [](PetscEvent event) { 507 PetscFunctionBegin; 508 delete event_cast_(event); 509 event->data = nullptr; 510 PetscFunctionReturn(PETSC_SUCCESS); 511 }; 512 PetscFunctionReturn(PETSC_SUCCESS); 513 } 514 515 template <DeviceType T> 516 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 517 { 518 PetscFunctionBegin; 519 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 520 PetscFunctionReturn(PETSC_SUCCESS); 521 } 522 523 template <DeviceType T> 524 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 525 { 526 PetscFunctionBegin; 527 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 528 PetscFunctionReturn(PETSC_SUCCESS); 529 } 530 531 // initialize the static member variables 532 template <DeviceType T> 533 bool DeviceContext<T>::initialized_ = false; 534 535 template <DeviceType T> 536 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 537 538 template <DeviceType T> 539 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 540 541 template <DeviceType T> 542 constexpr _DeviceContextOps DeviceContext<T>::ops; 543 544 } // namespace impl 545 546 // shorten this one up a bit (and instantiate the templates) 547 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 548 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 549 550 // shorthand for what is an EXTREMELY long name 551 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 552 553 } // namespace cupm 554 555 } // namespace device 556 557 } // namespace Petsc 558 559 #endif // __cplusplus 560 561 #endif // PETSCDEVICECONTEXTCUDA_HPP 562