1 #pragma once 2 3 #include <petsc/private/deviceimpl.h> 4 #include <petsc/private/cupmsolverinterface.hpp> 5 #include <petsc/private/logimpl.h> 6 7 #include <petsc/private/cpp/array.hpp> 8 9 #include "../segmentedmempool.hpp" 10 #include "cupmallocator.hpp" 11 #include "cupmstream.hpp" 12 #include "cupmevent.hpp" 13 14 namespace Petsc 15 { 16 17 namespace device 18 { 19 20 namespace cupm 21 { 22 23 namespace impl 24 { 25 26 template <DeviceType T> 27 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL DeviceContext : SolverInterface<T> { 28 public: 29 PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); 30 31 private: 32 template <typename H, std::size_t> 33 struct HandleTag { 34 using type = H; 35 }; 36 37 using stream_tag = HandleTag<cupmStream_t, 0>; 38 using blas_tag = HandleTag<cupmBlasHandle_t, 1>; 39 using solver_tag = HandleTag<cupmSolverHandle_t, 2>; 40 41 using stream_type = CUPMStream<T>; 42 using event_type = CUPMEvent<T>; 43 44 public: 45 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 46 // header, but since we are using the power of templates it must be declared part of 47 // this class to have easy access the same typedefs. Technically one can make a 48 // templated struct outside the class but it's more code for the same result. 49 struct PetscDeviceContext_IMPLS { 50 stream_type stream{}; 51 cupmEvent_t event{}; 52 cupmEvent_t begin{}; // timer-only 53 cupmEvent_t end{}; // timer-only 54 #if PetscDefined(USE_DEBUG) 55 PetscBool timerInUse{}; 56 PetscBool EnergyMeterInUse{}; 57 #endif 58 cupmBlasHandle_t blas{}; 59 cupmSolverHandle_t solver{}; 60 #if PetscDefined(HAVE_CUDA) 61 nvmlDevice_t nvmlHandle{}; 62 unsigned long long energymeterbegin{}; 63 unsigned long long energymeterend{}; 64 #endif 65 66 constexpr PetscDeviceContext_IMPLS() noexcept = default; 67 68 PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); } 69 70 PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; } 71 72 PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; } 73 }; 74 75 private: 76 static bool initialized_; 77 78 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_; 79 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_; 80 81 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); } 82 83 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); } 84 85 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; } 86 87 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; } 88 89 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other 90 // handles 91 static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; } 92 93 static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept 94 { 95 const auto dci = impls_cast_(dctx); 96 auto &handle = blashandles_[dctx->device->deviceId]; 97 98 PetscFunctionBegin; 99 if (!handle) { 100 PetscCall(PetscLogEventsPause()); 101 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 102 for (auto i = 0; i < 3; ++i) { 103 const auto cberr = cupmBlasCreate(handle.ptr_to()); 104 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 105 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 106 if (i != 2) { 107 PetscCall(PetscSleep(3)); 108 continue; 109 } 110 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName()); 111 } 112 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0)); 113 PetscCall(PetscLogEventsResume()); 114 } 115 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream())); 116 dci->blas = handle; 117 PetscFunctionReturn(PETSC_SUCCESS); 118 } 119 120 static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept 121 { 122 const auto dci = impls_cast_(dctx); 123 auto &handle = solverhandles_[dctx->device->deviceId]; 124 125 PetscFunctionBegin; 126 if (!handle) { 127 PetscCall(PetscLogEventsPause()); 128 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 129 for (auto i = 0; i < 3; ++i) { 130 const auto cerr = cupmSolverCreate(&handle); 131 if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break; 132 if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr); 133 if (i < 2) { 134 PetscCall(PetscSleep(3)); 135 continue; 136 } 137 PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName()); 138 } 139 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0)); 140 PetscCall(PetscLogEventsResume()); 141 } 142 PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream())); 143 dci->solver = handle; 144 PetscFunctionReturn(PETSC_SUCCESS); 145 } 146 147 static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept 148 { 149 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId; 150 151 PetscFunctionBegin; 152 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")", 153 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr); 154 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl)); 155 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr)); 156 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl))); 157 PetscFunctionReturn(PETSC_SUCCESS); 158 } 159 160 static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); } 161 162 static PetscErrorCode finalize_() noexcept 163 { 164 PetscFunctionBegin; 165 for (auto &&handle : blashandles_) { 166 if (handle) { 167 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 168 handle = nullptr; 169 } 170 } 171 for (auto &&handle : solverhandles_) { 172 if (handle) { 173 PetscCallCUPMSOLVER(cupmSolverDestroy(handle)); 174 handle = nullptr; 175 } 176 } 177 initialized_ = false; 178 PetscFunctionReturn(PETSC_SUCCESS); 179 } 180 181 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>> 182 PETSC_NODISCARD static PoolType &default_pool_() noexcept 183 { 184 static PoolType pool; 185 return pool; 186 } 187 188 static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept 189 { 190 PetscFunctionBegin; 191 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess); 192 PetscFunctionReturn(PETSC_SUCCESS); 193 } 194 195 public: 196 // All of these functions MUST be static in order to be callable from C, otherwise they 197 // get the implicit 'this' pointer tacked on 198 static PetscErrorCode destroy(PetscDeviceContext) noexcept; 199 static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept; 200 static PetscErrorCode setUp(PetscDeviceContext) noexcept; 201 static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept; 202 static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept; 203 static PetscErrorCode synchronize(PetscDeviceContext) noexcept; 204 template <typename Handle_t> 205 static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept; 206 template <typename Handle_t> 207 static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept; 208 static PetscErrorCode beginTimer(PetscDeviceContext) noexcept; 209 static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept; 210 static PetscErrorCode getPower(PetscDeviceContext, PetscLogDouble *) noexcept; 211 static PetscErrorCode beginEnergyMeter(PetscDeviceContext) noexcept; 212 static PetscErrorCode endEnergyMeter(PetscDeviceContext, PetscLogDouble *) noexcept; 213 static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept; 214 static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept; 215 static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept; 216 static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept; 217 static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept; 218 static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept; 219 static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept; 220 221 // not a PetscDeviceContext method, this registers the class 222 static PetscErrorCode initialize(PetscDevice) noexcept; 223 224 // clang-format off 225 static constexpr _DeviceContextOps ops = { 226 PetscDesignatedInitializer(destroy, destroy), 227 PetscDesignatedInitializer(changestreamtype, changeStreamType), 228 PetscDesignatedInitializer(setup, setUp), 229 PetscDesignatedInitializer(query, query), 230 PetscDesignatedInitializer(waitforcontext, waitForContext), 231 PetscDesignatedInitializer(synchronize, synchronize), 232 PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>), 233 PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>), 234 PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>), 235 PetscDesignatedInitializer(begintimer, beginTimer), 236 PetscDesignatedInitializer(endtimer, endTimer), 237 #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS) 238 PetscDesignatedInitializer(getpower, getPower), 239 #else 240 PetscDesignatedInitializer(getpower, nullptr), 241 #endif 242 #if PetscDefined(HAVE_CUDA) 243 PetscDesignatedInitializer(beginenergymeter, beginEnergyMeter), 244 PetscDesignatedInitializer(endenergymeter, endEnergyMeter), 245 #else 246 PetscDesignatedInitializer(beginenergymeter, nullptr), 247 PetscDesignatedInitializer(endenergymeter, nullptr), 248 #endif 249 PetscDesignatedInitializer(memalloc, memAlloc), 250 PetscDesignatedInitializer(memfree, memFree), 251 PetscDesignatedInitializer(memcopy, memCopy), 252 PetscDesignatedInitializer(memset, memSet), 253 PetscDesignatedInitializer(createevent, createEvent), 254 PetscDesignatedInitializer(recordevent, recordEvent), 255 PetscDesignatedInitializer(waitforevent, waitForEvent) 256 }; 257 // clang-format on 258 }; 259 260 // not a PetscDeviceContext method, this initializes the CLASS 261 template <DeviceType T> 262 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept 263 { 264 PetscFunctionBegin; 265 if (PetscUnlikely(!initialized_)) { 266 uint64_t threshold = UINT64_MAX; 267 cupmMemPool_t mempool; 268 269 initialized_ = true; 270 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId))); 271 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold)); 272 blashandles_.fill(nullptr); 273 solverhandles_.fill(nullptr); 274 PetscCall(PetscRegisterFinalize(finalize_)); 275 } 276 PetscFunctionReturn(PETSC_SUCCESS); 277 } 278 279 template <DeviceType T> 280 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept 281 { 282 PetscFunctionBegin; 283 if (const auto dci = impls_cast_(dctx)) { 284 PetscCall(dci->stream.destroy()); 285 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event)); 286 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 287 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 288 delete dci; 289 dctx->data = nullptr; 290 } 291 PetscFunctionReturn(PETSC_SUCCESS); 292 } 293 294 template <DeviceType T> 295 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept 296 { 297 const auto dci = impls_cast_(dctx); 298 299 PetscFunctionBegin; 300 PetscCall(dci->stream.destroy()); 301 // set these to null so they aren't usable until setup is called again 302 dci->blas = nullptr; 303 dci->solver = nullptr; 304 PetscFunctionReturn(PETSC_SUCCESS); 305 } 306 307 template <DeviceType T> 308 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept 309 { 310 const auto dci = impls_cast_(dctx); 311 auto &event = dci->event; 312 313 PetscFunctionBegin; 314 PetscCall(check_current_device_(dctx)); 315 PetscCall(dci->stream.change_type(dctx->streamType)); 316 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event)); 317 #if PetscDefined(USE_DEBUG) 318 dci->timerInUse = PETSC_FALSE; 319 #endif 320 PetscFunctionReturn(PETSC_SUCCESS); 321 } 322 323 template <DeviceType T> 324 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept 325 { 326 PetscFunctionBegin; 327 PetscCall(check_current_device_(dctx)); 328 switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) { 329 case cupmSuccess: 330 *idle = PETSC_TRUE; 331 break; 332 case cupmErrorNotReady: 333 *idle = PETSC_FALSE; 334 // reset the error 335 cerr = cupmGetLastError(); 336 static_cast<void>(cerr); 337 break; 338 default: 339 PetscCallCUPM(cerr); 340 PetscUnreachable(); 341 } 342 PetscFunctionReturn(PETSC_SUCCESS); 343 } 344 345 template <DeviceType T> 346 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept 347 { 348 const auto dcib = impls_cast_(dctxb); 349 const auto event = dcib->event; 350 351 PetscFunctionBegin; 352 PetscCall(check_current_device_(dctxa, dctxb)); 353 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream())); 354 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0)); 355 PetscFunctionReturn(PETSC_SUCCESS); 356 } 357 358 template <DeviceType T> 359 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept 360 { 361 auto idle = PETSC_TRUE; 362 363 PetscFunctionBegin; 364 PetscCall(query(dctx, &idle)); 365 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream())); 366 PetscFunctionReturn(PETSC_SUCCESS); 367 } 368 369 template <DeviceType T> 370 template <typename handle_t> 371 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept 372 { 373 PetscFunctionBegin; 374 PetscCall(initialize_handle_(handle_t{}, dctx)); 375 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{}); 376 PetscFunctionReturn(PETSC_SUCCESS); 377 } 378 379 template <DeviceType T> 380 template <typename handle_t> 381 inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept 382 { 383 using handle_type = typename handle_t::type; 384 385 PetscFunctionBegin; 386 PetscCall(initialize_handle_(handle_t{}, dctx)); 387 *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{}))); 388 PetscFunctionReturn(PETSC_SUCCESS); 389 } 390 391 template <DeviceType T> 392 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept 393 { 394 const auto dci = impls_cast_(dctx); 395 396 PetscFunctionBegin; 397 PetscCall(check_current_device_(dctx)); 398 #if PetscDefined(USE_DEBUG) 399 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?"); 400 dci->timerInUse = PETSC_TRUE; 401 #endif 402 if (!dci->begin) { 403 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event"); 404 PetscCallCUPM(cupmEventCreate(&dci->begin)); 405 PetscCallCUPM(cupmEventCreate(&dci->end)); 406 } 407 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream())); 408 PetscFunctionReturn(PETSC_SUCCESS); 409 } 410 411 template <DeviceType T> 412 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept 413 { 414 float gtime; 415 const auto dci = impls_cast_(dctx); 416 const auto end = dci->end; 417 418 PetscFunctionBegin; 419 PetscCall(check_current_device_(dctx)); 420 #if PetscDefined(USE_DEBUG) 421 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?"); 422 dci->timerInUse = PETSC_FALSE; 423 #endif 424 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream())); 425 PetscCallCUPM(cupmEventSynchronize(end)); 426 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end)); 427 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 428 PetscFunctionReturn(PETSC_SUCCESS); 429 } 430 431 #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS) 432 template <DeviceType T> 433 inline PetscErrorCode DeviceContext<T>::getPower(PetscDeviceContext dctx, PetscLogDouble *power) noexcept 434 { 435 const auto dci = impls_cast_(dctx); 436 nvmlFieldValue_t values[1]; 437 438 PetscFunctionBegin; 439 PetscCall(check_current_device_(dctx)); 440 PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream())); 441 values[0].fieldId = NVML_FI_DEV_POWER_INSTANT; 442 if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle)); 443 PetscCallNVML(nvmlDeviceGetFieldValues(dci->nvmlHandle, 1, values)); 444 *power = static_cast<util::remove_pointer_t<decltype(power)>>(values[0].value.uiVal); 445 PetscFunctionReturn(PETSC_SUCCESS); 446 } 447 #endif 448 449 #if PetscDefined(HAVE_CUDA) 450 template <DeviceType T> 451 inline PetscErrorCode DeviceContext<T>::beginEnergyMeter(PetscDeviceContext dctx) noexcept 452 { 453 const auto dci = impls_cast_(dctx); 454 455 PetscFunctionBegin; 456 PetscCall(check_current_device_(dctx)); 457 #if PetscDefined(USE_DEBUG) 458 PetscCheck(!dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterEnd()?"); 459 dci->EnergyMeterInUse = PETSC_TRUE; 460 #endif 461 if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle)); 462 PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterbegin)); 463 PetscFunctionReturn(PETSC_SUCCESS); 464 } 465 466 template <DeviceType T> 467 inline PetscErrorCode DeviceContext<T>::endEnergyMeter(PetscDeviceContext dctx, PetscLogDouble *energy) noexcept 468 { 469 const auto dci = impls_cast_(dctx); 470 471 PetscFunctionBegin; 472 PetscCall(check_current_device_(dctx)); 473 #if PetscDefined(USE_DEBUG) 474 PetscCheck(dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterBegin()?"); 475 dci->EnergyMeterInUse = PETSC_FALSE; 476 #endif 477 PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream())); 478 PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterend)); 479 *energy = static_cast<util::remove_pointer_t<decltype(energy)>>(dci->energymeterend - dci->energymeterbegin) / 1000; // convert to Joule 480 PetscFunctionReturn(PETSC_SUCCESS); 481 } 482 #endif 483 484 template <DeviceType T> 485 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept 486 { 487 const auto &stream = impls_cast_(dctx)->stream; 488 489 PetscFunctionBegin; 490 PetscCall(check_current_device_(dctx)); 491 PetscCall(check_memtype_(mtype, "allocating")); 492 if (PetscMemTypeHost(mtype)) { 493 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 494 } else { 495 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment)); 496 } 497 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream())); 498 PetscFunctionReturn(PETSC_SUCCESS); 499 } 500 501 template <DeviceType T> 502 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept 503 { 504 const auto &stream = impls_cast_(dctx)->stream; 505 506 PetscFunctionBegin; 507 PetscCall(check_current_device_(dctx)); 508 PetscCall(check_memtype_(mtype, "freeing")); 509 if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS); 510 if (PetscMemTypeHost(mtype)) { 511 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 512 // if ptr exists still exists the pool didn't own it 513 if (*ptr) { 514 auto registered = PETSC_FALSE, managed = PETSC_FALSE; 515 516 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed)); 517 if (registered) { 518 PetscCallCUPM(cupmFreeHost(*ptr)); 519 } else if (managed) { 520 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 521 } 522 } 523 } else { 524 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream)); 525 // if ptr still exists the pool didn't own it 526 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream())); 527 } 528 PetscFunctionReturn(PETSC_SUCCESS); 529 } 530 531 template <DeviceType T> 532 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept 533 { 534 const auto stream = impls_cast_(dctx)->stream.get_stream(); 535 536 PetscFunctionBegin; 537 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)... 538 if (mode == PETSC_DEVICE_COPY_HTOH) { 539 const auto cerr = cupmStreamQuery(stream); 540 541 // yes this is faster 542 if (cerr == cupmSuccess) { 543 PetscCall(PetscMemcpy(dest, src, n)); 544 PetscFunctionReturn(PETSC_SUCCESS); 545 } else if (cerr == cupmErrorNotReady) { 546 auto PETSC_UNUSED unused = cupmGetLastError(); 547 548 static_cast<void>(unused); 549 } else { 550 PetscCallCUPM(cerr); 551 } 552 } 553 PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream)); 554 PetscFunctionReturn(PETSC_SUCCESS); 555 } 556 557 template <DeviceType T> 558 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept 559 { 560 PetscFunctionBegin; 561 PetscCall(check_current_device_(dctx)); 562 PetscCall(check_memtype_(mtype, "zeroing")); 563 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream())); 564 PetscFunctionReturn(PETSC_SUCCESS); 565 } 566 567 template <DeviceType T> 568 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept 569 { 570 PetscFunctionBegin; 571 PetscCallCXX(event->data = new event_type{}); 572 event->destroy = [](PetscEvent event) { 573 PetscFunctionBegin; 574 delete event_cast_(event); 575 event->data = nullptr; 576 PetscFunctionReturn(PETSC_SUCCESS); 577 }; 578 PetscFunctionReturn(PETSC_SUCCESS); 579 } 580 581 template <DeviceType T> 582 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 583 { 584 PetscFunctionBegin; 585 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event))); 586 PetscFunctionReturn(PETSC_SUCCESS); 587 } 588 589 template <DeviceType T> 590 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept 591 { 592 PetscFunctionBegin; 593 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event))); 594 PetscFunctionReturn(PETSC_SUCCESS); 595 } 596 597 // initialize the static member variables 598 template <DeviceType T> 599 bool DeviceContext<T>::initialized_ = false; 600 601 template <DeviceType T> 602 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 603 604 template <DeviceType T> 605 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 606 607 template <DeviceType T> 608 constexpr _DeviceContextOps DeviceContext<T>::ops; 609 610 } // namespace impl 611 612 // shorten this one up a bit (and instantiate the templates) 613 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>; 614 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>; 615 616 // shorthand for what is an EXTREMELY long name 617 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 618 619 } // namespace cupm 620 621 } // namespace device 622 623 } // namespace Petsc 624