xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 5fa70555f2cfa5f8527759fb2fd8b5523acdf153)
1 #pragma once
2 
3 #include <petsc/private/deviceimpl.h>
4 #include <petsc/private/cupmsolverinterface.hpp>
5 #include <petsc/private/logimpl.h>
6 
7 #include <petsc/private/cpp/array.hpp>
8 
9 #include "../segmentedmempool.hpp"
10 #include "cupmallocator.hpp"
11 #include "cupmstream.hpp"
12 #include "cupmevent.hpp"
13 
14 namespace Petsc
15 {
16 
17 namespace device
18 {
19 
20 namespace cupm
21 {
22 
23 namespace impl
24 {
25 
26 template <DeviceType T>
27 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL DeviceContext : SolverInterface<T> {
28 public:
29   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
30 
31 private:
32   template <typename H, std::size_t>
33   struct HandleTag {
34     using type = H;
35   };
36 
37   using stream_tag = HandleTag<cupmStream_t, 0>;
38   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
39   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
40 
41   using stream_type = CUPMStream<T>;
42   using event_type  = CUPMEvent<T>;
43 
44 public:
45   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46   // header, but since we are using the power of templates it must be declared part of
47   // this class to have easy access the same typedefs. Technically one can make a
48   // templated struct outside the class but it's more code for the same result.
49   struct PetscDeviceContext_IMPLS {
50     stream_type stream{};
51     cupmEvent_t event{};
52     cupmEvent_t begin{}; // timer-only
53     cupmEvent_t end{};   // timer-only
54 #if PetscDefined(USE_DEBUG)
55     PetscBool timerInUse{};
56     PetscBool EnergyMeterInUse{};
57 #endif
58     cupmBlasHandle_t   blas{};
59     cupmSolverHandle_t solver{};
60 #if PetscDefined(HAVE_CUDA)
61     nvmlDevice_t       nvmlHandle{};
62     unsigned long long energymeterbegin{};
63     unsigned long long energymeterend{};
64 #endif
65 
66     constexpr PetscDeviceContext_IMPLS() noexcept = default;
67 
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS68     PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
69 
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS70     PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
71 
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS72     PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
73   };
74 
75 private:
76   static bool initialized_;
77 
78   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
79   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
80 
impls_cast_(PetscDeviceContext ptr)81   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
82 
event_cast_(PetscEvent event)83   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
84 
CUPMBLAS_HANDLE_CREATE()85   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
86 
CUPMSOLVER_HANDLE_CREATE()87   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
88 
89   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
90   // handles
initialize_handle_(stream_tag,PetscDeviceContext)91   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
92 
initialize_handle_(blas_tag,PetscDeviceContext dctx)93   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
94   {
95     const auto dci    = impls_cast_(dctx);
96     auto      &handle = blashandles_[dctx->device->deviceId];
97 
98     PetscFunctionBegin;
99     if (!handle) {
100       PetscCall(PetscLogEventsPause());
101       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
102       for (auto i = 0; i < 3; ++i) {
103         const auto cberr = cupmBlasCreate(handle.ptr_to());
104         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
105         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
106         if (i != 2) {
107           PetscCall(PetscSleep(3));
108           continue;
109         }
110         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
111       }
112       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
113       PetscCall(PetscLogEventsResume());
114     }
115     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
116     dci->blas = handle;
117     PetscFunctionReturn(PETSC_SUCCESS);
118   }
119 
initialize_handle_(solver_tag,PetscDeviceContext dctx)120   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
121   {
122     const auto dci    = impls_cast_(dctx);
123     auto      &handle = solverhandles_[dctx->device->deviceId];
124 
125     PetscFunctionBegin;
126     if (!handle) {
127       PetscCall(PetscLogEventsPause());
128       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
129       for (auto i = 0; i < 3; ++i) {
130         const auto cerr = cupmSolverCreate(&handle);
131         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
132         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
133         if (i < 2) {
134           PetscCall(PetscSleep(3));
135           continue;
136         }
137         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
138       }
139       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
140       PetscCall(PetscLogEventsResume());
141     }
142     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
143     dci->solver = handle;
144     PetscFunctionReturn(PETSC_SUCCESS);
145   }
146 
check_current_device_(PetscDeviceContext dctxl,PetscDeviceContext dctxr)147   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
148   {
149     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
150 
151     PetscFunctionBegin;
152     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
153                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
154     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
155     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
156     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
157     PetscFunctionReturn(PETSC_SUCCESS);
158   }
159 
check_current_device_(PetscDeviceContext dctx)160   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
161 
finalize_()162   static PetscErrorCode finalize_() noexcept
163   {
164     PetscFunctionBegin;
165     for (auto &&handle : blashandles_) {
166       if (handle) {
167         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
168         handle = nullptr;
169       }
170     }
171     for (auto &&handle : solverhandles_) {
172       if (handle) {
173         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
174         handle = nullptr;
175       }
176     }
177     initialized_ = false;
178     PetscFunctionReturn(PETSC_SUCCESS);
179   }
180 
181   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
default_pool_()182   PETSC_NODISCARD static PoolType &default_pool_() noexcept
183   {
184     static PoolType pool;
185     return pool;
186   }
187 
check_memtype_(PetscMemType mtype,const char mess[])188   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
189   {
190     PetscFunctionBegin;
191     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
192     PetscFunctionReturn(PETSC_SUCCESS);
193   }
194 
195 public:
196   // All of these functions MUST be static in order to be callable from C, otherwise they
197   // get the implicit 'this' pointer tacked on
198   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
199   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
200   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
201   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
202   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
203   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
204   template <typename Handle_t>
205   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
206   template <typename Handle_t>
207   static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
208   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
209   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
210   static PetscErrorCode getPower(PetscDeviceContext, PetscLogDouble *) noexcept;
211   static PetscErrorCode beginEnergyMeter(PetscDeviceContext) noexcept;
212   static PetscErrorCode endEnergyMeter(PetscDeviceContext, PetscLogDouble *) noexcept;
213   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
214   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
215   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
216   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
217   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
218   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
219   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
220 
221   // not a PetscDeviceContext method, this registers the class
222   static PetscErrorCode initialize(PetscDevice) noexcept;
223 
224   // clang-format off
225   static constexpr _DeviceContextOps ops = {
226     PetscDesignatedInitializer(destroy, destroy),
227     PetscDesignatedInitializer(changestreamtype, changeStreamType),
228     PetscDesignatedInitializer(setup, setUp),
229     PetscDesignatedInitializer(query, query),
230     PetscDesignatedInitializer(waitforcontext, waitForContext),
231     PetscDesignatedInitializer(synchronize, synchronize),
232     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
233     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
234     PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
235     PetscDesignatedInitializer(begintimer, beginTimer),
236     PetscDesignatedInitializer(endtimer, endTimer),
237 #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
238     PetscDesignatedInitializer(getpower, getPower),
239 #else
240     PetscDesignatedInitializer(getpower, nullptr),
241 #endif
242 #if PetscDefined(HAVE_CUDA)
243     PetscDesignatedInitializer(beginenergymeter, beginEnergyMeter),
244     PetscDesignatedInitializer(endenergymeter, endEnergyMeter),
245 #else
246     PetscDesignatedInitializer(beginenergymeter, nullptr),
247     PetscDesignatedInitializer(endenergymeter, nullptr),
248 #endif
249     PetscDesignatedInitializer(memalloc, memAlloc),
250     PetscDesignatedInitializer(memfree, memFree),
251     PetscDesignatedInitializer(memcopy, memCopy),
252     PetscDesignatedInitializer(memset, memSet),
253     PetscDesignatedInitializer(createevent, createEvent),
254     PetscDesignatedInitializer(recordevent, recordEvent),
255     PetscDesignatedInitializer(waitforevent, waitForEvent)
256   };
257   // clang-format on
258 };
259 
260 // not a PetscDeviceContext method, this initializes the CLASS
261 template <DeviceType T>
initialize(PetscDevice device)262 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
263 {
264   PetscFunctionBegin;
265   if (PetscUnlikely(!initialized_)) {
266     uint64_t      threshold = UINT64_MAX;
267     cupmMemPool_t mempool;
268 
269     initialized_ = true;
270     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
271     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
272     blashandles_.fill(nullptr);
273     solverhandles_.fill(nullptr);
274     PetscCall(PetscRegisterFinalize(finalize_));
275   }
276   PetscFunctionReturn(PETSC_SUCCESS);
277 }
278 
279 template <DeviceType T>
destroy(PetscDeviceContext dctx)280 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
281 {
282   PetscFunctionBegin;
283   if (const auto dci = impls_cast_(dctx)) {
284     PetscCall(dci->stream.destroy());
285     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
286     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
287     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
288     delete dci;
289     dctx->data = nullptr;
290   }
291   PetscFunctionReturn(PETSC_SUCCESS);
292 }
293 
294 template <DeviceType T>
changeStreamType(PetscDeviceContext dctx,PETSC_UNUSED PetscStreamType stype)295 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
296 {
297   const auto dci = impls_cast_(dctx);
298 
299   PetscFunctionBegin;
300   PetscCall(dci->stream.destroy());
301   // set these to null so they aren't usable until setup is called again
302   dci->blas   = nullptr;
303   dci->solver = nullptr;
304   PetscFunctionReturn(PETSC_SUCCESS);
305 }
306 
307 template <DeviceType T>
setUp(PetscDeviceContext dctx)308 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
309 {
310   const auto dci   = impls_cast_(dctx);
311   auto      &event = dci->event;
312 
313   PetscFunctionBegin;
314   PetscCall(check_current_device_(dctx));
315   PetscCall(dci->stream.change_type(dctx->streamType));
316   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
317 #if PetscDefined(USE_DEBUG)
318   dci->timerInUse = PETSC_FALSE;
319 #endif
320   PetscFunctionReturn(PETSC_SUCCESS);
321 }
322 
323 template <DeviceType T>
query(PetscDeviceContext dctx,PetscBool * idle)324 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
325 {
326   PetscFunctionBegin;
327   PetscCall(check_current_device_(dctx));
328   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
329   case cupmSuccess:
330     *idle = PETSC_TRUE;
331     break;
332   case cupmErrorNotReady:
333     *idle = PETSC_FALSE;
334     // reset the error
335     cerr = cupmGetLastError();
336     static_cast<void>(cerr);
337     break;
338   default:
339     PetscCallCUPM(cerr);
340     PetscUnreachable();
341   }
342   PetscFunctionReturn(PETSC_SUCCESS);
343 }
344 
345 template <DeviceType T>
waitForContext(PetscDeviceContext dctxa,PetscDeviceContext dctxb)346 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
347 {
348   const auto dcib  = impls_cast_(dctxb);
349   const auto event = dcib->event;
350 
351   PetscFunctionBegin;
352   PetscCall(check_current_device_(dctxa, dctxb));
353   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
354   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
355   PetscFunctionReturn(PETSC_SUCCESS);
356 }
357 
358 template <DeviceType T>
synchronize(PetscDeviceContext dctx)359 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
360 {
361   auto idle = PETSC_TRUE;
362 
363   PetscFunctionBegin;
364   PetscCall(query(dctx, &idle));
365   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
366   PetscFunctionReturn(PETSC_SUCCESS);
367 }
368 
369 template <DeviceType T>
370 template <typename handle_t>
getHandle(PetscDeviceContext dctx,void * handle)371 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
372 {
373   PetscFunctionBegin;
374   PetscCall(initialize_handle_(handle_t{}, dctx));
375   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
376   PetscFunctionReturn(PETSC_SUCCESS);
377 }
378 
379 template <DeviceType T>
380 template <typename handle_t>
getHandlePtr(PetscDeviceContext dctx,void ** handle)381 inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
382 {
383   using handle_type = typename handle_t::type;
384 
385   PetscFunctionBegin;
386   PetscCall(initialize_handle_(handle_t{}, dctx));
387   *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
388   PetscFunctionReturn(PETSC_SUCCESS);
389 }
390 
391 template <DeviceType T>
beginTimer(PetscDeviceContext dctx)392 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
393 {
394   const auto dci = impls_cast_(dctx);
395 
396   PetscFunctionBegin;
397   PetscCall(check_current_device_(dctx));
398 #if PetscDefined(USE_DEBUG)
399   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
400   dci->timerInUse = PETSC_TRUE;
401 #endif
402   if (!dci->begin) {
403     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
404     PetscCallCUPM(cupmEventCreate(&dci->begin));
405     PetscCallCUPM(cupmEventCreate(&dci->end));
406   }
407   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
408   PetscFunctionReturn(PETSC_SUCCESS);
409 }
410 
411 template <DeviceType T>
endTimer(PetscDeviceContext dctx,PetscLogDouble * elapsed)412 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
413 {
414   float      gtime;
415   const auto dci = impls_cast_(dctx);
416   const auto end = dci->end;
417 
418   PetscFunctionBegin;
419   PetscCall(check_current_device_(dctx));
420 #if PetscDefined(USE_DEBUG)
421   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
422   dci->timerInUse = PETSC_FALSE;
423 #endif
424   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
425   PetscCallCUPM(cupmEventSynchronize(end));
426   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
427   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
428   PetscFunctionReturn(PETSC_SUCCESS);
429 }
430 
431 #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
432 template <DeviceType T>
getPower(PetscDeviceContext dctx,PetscLogDouble * power)433 inline PetscErrorCode DeviceContext<T>::getPower(PetscDeviceContext dctx, PetscLogDouble *power) noexcept
434 {
435   const auto       dci = impls_cast_(dctx);
436   nvmlFieldValue_t values[1];
437 
438   PetscFunctionBegin;
439   PetscCall(check_current_device_(dctx));
440   PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
441   values[0].fieldId = NVML_FI_DEV_POWER_INSTANT;
442   if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
443   PetscCallNVML(nvmlDeviceGetFieldValues(dci->nvmlHandle, 1, values));
444   *power = static_cast<util::remove_pointer_t<decltype(power)>>(values[0].value.uiVal);
445   PetscFunctionReturn(PETSC_SUCCESS);
446 }
447 #endif
448 
449 #if PetscDefined(HAVE_CUDA)
450 template <DeviceType T>
beginEnergyMeter(PetscDeviceContext dctx)451 inline PetscErrorCode DeviceContext<T>::beginEnergyMeter(PetscDeviceContext dctx) noexcept
452 {
453   const auto dci = impls_cast_(dctx);
454 
455   PetscFunctionBegin;
456   PetscCall(check_current_device_(dctx));
457   #if PetscDefined(USE_DEBUG)
458   PetscCheck(!dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterEnd()?");
459   dci->EnergyMeterInUse = PETSC_TRUE;
460   #endif
461   if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
462   PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterbegin));
463   PetscFunctionReturn(PETSC_SUCCESS);
464 }
465 
466 template <DeviceType T>
endEnergyMeter(PetscDeviceContext dctx,PetscLogDouble * energy)467 inline PetscErrorCode DeviceContext<T>::endEnergyMeter(PetscDeviceContext dctx, PetscLogDouble *energy) noexcept
468 {
469   const auto dci = impls_cast_(dctx);
470 
471   PetscFunctionBegin;
472   PetscCall(check_current_device_(dctx));
473   #if PetscDefined(USE_DEBUG)
474   PetscCheck(dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterBegin()?");
475   dci->EnergyMeterInUse = PETSC_FALSE;
476   #endif
477   PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
478   PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterend));
479   *energy = static_cast<util::remove_pointer_t<decltype(energy)>>(dci->energymeterend - dci->energymeterbegin) / 1000; // convert to Joule
480   PetscFunctionReturn(PETSC_SUCCESS);
481 }
482 #endif
483 
484 template <DeviceType T>
memAlloc(PetscDeviceContext dctx,PetscBool clear,PetscMemType mtype,std::size_t n,std::size_t alignment,void ** dest)485 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
486 {
487   const auto &stream = impls_cast_(dctx)->stream;
488 
489   PetscFunctionBegin;
490   PetscCall(check_current_device_(dctx));
491   PetscCall(check_memtype_(mtype, "allocating"));
492   if (PetscMemTypeHost(mtype)) {
493     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
494   } else {
495     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
496   }
497   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
498   PetscFunctionReturn(PETSC_SUCCESS);
499 }
500 
501 template <DeviceType T>
memFree(PetscDeviceContext dctx,PetscMemType mtype,void ** ptr)502 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
503 {
504   const auto &stream = impls_cast_(dctx)->stream;
505 
506   PetscFunctionBegin;
507   PetscCall(check_current_device_(dctx));
508   PetscCall(check_memtype_(mtype, "freeing"));
509   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
510   if (PetscMemTypeHost(mtype)) {
511     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
512     // if ptr exists still exists the pool didn't own it
513     if (*ptr) {
514       auto registered = PETSC_FALSE, managed = PETSC_FALSE;
515 
516       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
517       if (registered) {
518         PetscCallCUPM(cupmFreeHost(*ptr));
519       } else if (managed) {
520         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
521       }
522     }
523   } else {
524     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
525     // if ptr still exists the pool didn't own it
526     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
527   }
528   PetscFunctionReturn(PETSC_SUCCESS);
529 }
530 
531 template <DeviceType T>
memCopy(PetscDeviceContext dctx,void * PETSC_RESTRICT dest,const void * PETSC_RESTRICT src,std::size_t n,PetscDeviceCopyMode mode)532 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
533 {
534   const auto stream = impls_cast_(dctx)->stream.get_stream();
535 
536   PetscFunctionBegin;
537   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
538   if (mode == PETSC_DEVICE_COPY_HTOH) {
539     const auto cerr = cupmStreamQuery(stream);
540 
541     // yes this is faster
542     if (cerr == cupmSuccess) {
543       PetscCall(PetscMemcpy(dest, src, n));
544       PetscFunctionReturn(PETSC_SUCCESS);
545     } else if (cerr == cupmErrorNotReady) {
546       auto PETSC_UNUSED unused = cupmGetLastError();
547 
548       static_cast<void>(unused);
549     } else {
550       PetscCallCUPM(cerr);
551     }
552   }
553   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
554   PetscFunctionReturn(PETSC_SUCCESS);
555 }
556 
557 template <DeviceType T>
memSet(PetscDeviceContext dctx,PetscMemType mtype,void * ptr,PetscInt v,std::size_t n)558 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
559 {
560   PetscFunctionBegin;
561   PetscCall(check_current_device_(dctx));
562   PetscCall(check_memtype_(mtype, "zeroing"));
563   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
564   PetscFunctionReturn(PETSC_SUCCESS);
565 }
566 
567 template <DeviceType T>
createEvent(PetscDeviceContext,PetscEvent event)568 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
569 {
570   PetscFunctionBegin;
571   PetscCallCXX(event->data = new event_type{});
572   event->destroy = [](PetscEvent event) {
573     PetscFunctionBegin;
574     delete event_cast_(event);
575     event->data = nullptr;
576     PetscFunctionReturn(PETSC_SUCCESS);
577   };
578   PetscFunctionReturn(PETSC_SUCCESS);
579 }
580 
581 template <DeviceType T>
recordEvent(PetscDeviceContext dctx,PetscEvent event)582 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
583 {
584   PetscFunctionBegin;
585   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
586   PetscFunctionReturn(PETSC_SUCCESS);
587 }
588 
589 template <DeviceType T>
waitForEvent(PetscDeviceContext dctx,PetscEvent event)590 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
591 {
592   PetscFunctionBegin;
593   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
594   PetscFunctionReturn(PETSC_SUCCESS);
595 }
596 
597 // initialize the static member variables
598 template <DeviceType T>
599 bool DeviceContext<T>::initialized_ = false;
600 
601 template <DeviceType T>
602 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
603 
604 template <DeviceType T>
605 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
606 
607 template <DeviceType T>
608 constexpr _DeviceContextOps DeviceContext<T>::ops;
609 
610 } // namespace impl
611 
612 // shorten this one up a bit (and instantiate the templates)
613 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
614 using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;
615 
616 // shorthand for what is an EXTREMELY long name
617 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
618 
619 } // namespace cupm
620 
621 } // namespace device
622 
623 } // namespace Petsc
624