1 #pragma once
2
3 #include <petsc/private/deviceimpl.h>
4 #include <petsc/private/cupmsolverinterface.hpp>
5 #include <petsc/private/logimpl.h>
6
7 #include <petsc/private/cpp/array.hpp>
8
9 #include "../segmentedmempool.hpp"
10 #include "cupmallocator.hpp"
11 #include "cupmstream.hpp"
12 #include "cupmevent.hpp"
13
14 namespace Petsc
15 {
16
17 namespace device
18 {
19
20 namespace cupm
21 {
22
23 namespace impl
24 {
25
26 template <DeviceType T>
27 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL DeviceContext : SolverInterface<T> {
28 public:
29 PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
30
31 private:
32 template <typename H, std::size_t>
33 struct HandleTag {
34 using type = H;
35 };
36
37 using stream_tag = HandleTag<cupmStream_t, 0>;
38 using blas_tag = HandleTag<cupmBlasHandle_t, 1>;
39 using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
40
41 using stream_type = CUPMStream<T>;
42 using event_type = CUPMEvent<T>;
43
44 public:
45 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46 // header, but since we are using the power of templates it must be declared part of
47 // this class to have easy access the same typedefs. Technically one can make a
48 // templated struct outside the class but it's more code for the same result.
49 struct PetscDeviceContext_IMPLS {
50 stream_type stream{};
51 cupmEvent_t event{};
52 cupmEvent_t begin{}; // timer-only
53 cupmEvent_t end{}; // timer-only
54 #if PetscDefined(USE_DEBUG)
55 PetscBool timerInUse{};
56 PetscBool EnergyMeterInUse{};
57 #endif
58 cupmBlasHandle_t blas{};
59 cupmSolverHandle_t solver{};
60 #if PetscDefined(HAVE_CUDA)
61 nvmlDevice_t nvmlHandle{};
62 unsigned long long energymeterbegin{};
63 unsigned long long energymeterend{};
64 #endif
65
66 constexpr PetscDeviceContext_IMPLS() noexcept = default;
67
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS68 PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
69
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS70 PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
71
getPetsc::device::cupm::impl::DeviceContext::PetscDeviceContext_IMPLS72 PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
73 };
74
75 private:
76 static bool initialized_;
77
78 static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_;
79 static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
80
impls_cast_(PetscDeviceContext ptr)81 PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
82
event_cast_(PetscEvent event)83 PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
84
CUPMBLAS_HANDLE_CREATE()85 PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
86
CUPMSOLVER_HANDLE_CREATE()87 PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
88
89 // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
90 // handles
initialize_handle_(stream_tag,PetscDeviceContext)91 static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
92
initialize_handle_(blas_tag,PetscDeviceContext dctx)93 static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
94 {
95 const auto dci = impls_cast_(dctx);
96 auto &handle = blashandles_[dctx->device->deviceId];
97
98 PetscFunctionBegin;
99 if (!handle) {
100 PetscCall(PetscLogEventsPause());
101 PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
102 for (auto i = 0; i < 3; ++i) {
103 const auto cberr = cupmBlasCreate(handle.ptr_to());
104 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
105 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
106 if (i != 2) {
107 PetscCall(PetscSleep(3));
108 continue;
109 }
110 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
111 }
112 PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
113 PetscCall(PetscLogEventsResume());
114 }
115 PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
116 dci->blas = handle;
117 PetscFunctionReturn(PETSC_SUCCESS);
118 }
119
initialize_handle_(solver_tag,PetscDeviceContext dctx)120 static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
121 {
122 const auto dci = impls_cast_(dctx);
123 auto &handle = solverhandles_[dctx->device->deviceId];
124
125 PetscFunctionBegin;
126 if (!handle) {
127 PetscCall(PetscLogEventsPause());
128 PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
129 for (auto i = 0; i < 3; ++i) {
130 const auto cerr = cupmSolverCreate(&handle);
131 if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
132 if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
133 if (i < 2) {
134 PetscCall(PetscSleep(3));
135 continue;
136 }
137 PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
138 }
139 PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
140 PetscCall(PetscLogEventsResume());
141 }
142 PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
143 dci->solver = handle;
144 PetscFunctionReturn(PETSC_SUCCESS);
145 }
146
check_current_device_(PetscDeviceContext dctxl,PetscDeviceContext dctxr)147 static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
148 {
149 const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
150
151 PetscFunctionBegin;
152 PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
153 PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
154 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
155 PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
156 PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
157 PetscFunctionReturn(PETSC_SUCCESS);
158 }
159
check_current_device_(PetscDeviceContext dctx)160 static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
161
finalize_()162 static PetscErrorCode finalize_() noexcept
163 {
164 PetscFunctionBegin;
165 for (auto &&handle : blashandles_) {
166 if (handle) {
167 PetscCallCUPMBLAS(cupmBlasDestroy(handle));
168 handle = nullptr;
169 }
170 }
171 for (auto &&handle : solverhandles_) {
172 if (handle) {
173 PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
174 handle = nullptr;
175 }
176 }
177 initialized_ = false;
178 PetscFunctionReturn(PETSC_SUCCESS);
179 }
180
181 template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
default_pool_()182 PETSC_NODISCARD static PoolType &default_pool_() noexcept
183 {
184 static PoolType pool;
185 return pool;
186 }
187
check_memtype_(PetscMemType mtype,const char mess[])188 static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
189 {
190 PetscFunctionBegin;
191 PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
192 PetscFunctionReturn(PETSC_SUCCESS);
193 }
194
195 public:
196 // All of these functions MUST be static in order to be callable from C, otherwise they
197 // get the implicit 'this' pointer tacked on
198 static PetscErrorCode destroy(PetscDeviceContext) noexcept;
199 static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
200 static PetscErrorCode setUp(PetscDeviceContext) noexcept;
201 static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
202 static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
203 static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
204 template <typename Handle_t>
205 static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
206 template <typename Handle_t>
207 static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
208 static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
209 static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
210 static PetscErrorCode getPower(PetscDeviceContext, PetscLogDouble *) noexcept;
211 static PetscErrorCode beginEnergyMeter(PetscDeviceContext) noexcept;
212 static PetscErrorCode endEnergyMeter(PetscDeviceContext, PetscLogDouble *) noexcept;
213 static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
214 static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
215 static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
216 static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
217 static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
218 static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
219 static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
220
221 // not a PetscDeviceContext method, this registers the class
222 static PetscErrorCode initialize(PetscDevice) noexcept;
223
224 // clang-format off
225 static constexpr _DeviceContextOps ops = {
226 PetscDesignatedInitializer(destroy, destroy),
227 PetscDesignatedInitializer(changestreamtype, changeStreamType),
228 PetscDesignatedInitializer(setup, setUp),
229 PetscDesignatedInitializer(query, query),
230 PetscDesignatedInitializer(waitforcontext, waitForContext),
231 PetscDesignatedInitializer(synchronize, synchronize),
232 PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
233 PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
234 PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
235 PetscDesignatedInitializer(begintimer, beginTimer),
236 PetscDesignatedInitializer(endtimer, endTimer),
237 #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
238 PetscDesignatedInitializer(getpower, getPower),
239 #else
240 PetscDesignatedInitializer(getpower, nullptr),
241 #endif
242 #if PetscDefined(HAVE_CUDA)
243 PetscDesignatedInitializer(beginenergymeter, beginEnergyMeter),
244 PetscDesignatedInitializer(endenergymeter, endEnergyMeter),
245 #else
246 PetscDesignatedInitializer(beginenergymeter, nullptr),
247 PetscDesignatedInitializer(endenergymeter, nullptr),
248 #endif
249 PetscDesignatedInitializer(memalloc, memAlloc),
250 PetscDesignatedInitializer(memfree, memFree),
251 PetscDesignatedInitializer(memcopy, memCopy),
252 PetscDesignatedInitializer(memset, memSet),
253 PetscDesignatedInitializer(createevent, createEvent),
254 PetscDesignatedInitializer(recordevent, recordEvent),
255 PetscDesignatedInitializer(waitforevent, waitForEvent)
256 };
257 // clang-format on
258 };
259
260 // not a PetscDeviceContext method, this initializes the CLASS
261 template <DeviceType T>
initialize(PetscDevice device)262 inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
263 {
264 PetscFunctionBegin;
265 if (PetscUnlikely(!initialized_)) {
266 uint64_t threshold = UINT64_MAX;
267 cupmMemPool_t mempool;
268
269 initialized_ = true;
270 PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
271 PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
272 blashandles_.fill(nullptr);
273 solverhandles_.fill(nullptr);
274 PetscCall(PetscRegisterFinalize(finalize_));
275 }
276 PetscFunctionReturn(PETSC_SUCCESS);
277 }
278
279 template <DeviceType T>
destroy(PetscDeviceContext dctx)280 inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
281 {
282 PetscFunctionBegin;
283 if (const auto dci = impls_cast_(dctx)) {
284 PetscCall(dci->stream.destroy());
285 if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
286 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
287 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
288 delete dci;
289 dctx->data = nullptr;
290 }
291 PetscFunctionReturn(PETSC_SUCCESS);
292 }
293
294 template <DeviceType T>
changeStreamType(PetscDeviceContext dctx,PETSC_UNUSED PetscStreamType stype)295 inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
296 {
297 const auto dci = impls_cast_(dctx);
298
299 PetscFunctionBegin;
300 PetscCall(dci->stream.destroy());
301 // set these to null so they aren't usable until setup is called again
302 dci->blas = nullptr;
303 dci->solver = nullptr;
304 PetscFunctionReturn(PETSC_SUCCESS);
305 }
306
307 template <DeviceType T>
setUp(PetscDeviceContext dctx)308 inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
309 {
310 const auto dci = impls_cast_(dctx);
311 auto &event = dci->event;
312
313 PetscFunctionBegin;
314 PetscCall(check_current_device_(dctx));
315 PetscCall(dci->stream.change_type(dctx->streamType));
316 if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
317 #if PetscDefined(USE_DEBUG)
318 dci->timerInUse = PETSC_FALSE;
319 #endif
320 PetscFunctionReturn(PETSC_SUCCESS);
321 }
322
323 template <DeviceType T>
query(PetscDeviceContext dctx,PetscBool * idle)324 inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
325 {
326 PetscFunctionBegin;
327 PetscCall(check_current_device_(dctx));
328 switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
329 case cupmSuccess:
330 *idle = PETSC_TRUE;
331 break;
332 case cupmErrorNotReady:
333 *idle = PETSC_FALSE;
334 // reset the error
335 cerr = cupmGetLastError();
336 static_cast<void>(cerr);
337 break;
338 default:
339 PetscCallCUPM(cerr);
340 PetscUnreachable();
341 }
342 PetscFunctionReturn(PETSC_SUCCESS);
343 }
344
345 template <DeviceType T>
waitForContext(PetscDeviceContext dctxa,PetscDeviceContext dctxb)346 inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
347 {
348 const auto dcib = impls_cast_(dctxb);
349 const auto event = dcib->event;
350
351 PetscFunctionBegin;
352 PetscCall(check_current_device_(dctxa, dctxb));
353 PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
354 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
355 PetscFunctionReturn(PETSC_SUCCESS);
356 }
357
358 template <DeviceType T>
synchronize(PetscDeviceContext dctx)359 inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
360 {
361 auto idle = PETSC_TRUE;
362
363 PetscFunctionBegin;
364 PetscCall(query(dctx, &idle));
365 if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
366 PetscFunctionReturn(PETSC_SUCCESS);
367 }
368
369 template <DeviceType T>
370 template <typename handle_t>
getHandle(PetscDeviceContext dctx,void * handle)371 inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
372 {
373 PetscFunctionBegin;
374 PetscCall(initialize_handle_(handle_t{}, dctx));
375 *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
376 PetscFunctionReturn(PETSC_SUCCESS);
377 }
378
379 template <DeviceType T>
380 template <typename handle_t>
getHandlePtr(PetscDeviceContext dctx,void ** handle)381 inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
382 {
383 using handle_type = typename handle_t::type;
384
385 PetscFunctionBegin;
386 PetscCall(initialize_handle_(handle_t{}, dctx));
387 *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
388 PetscFunctionReturn(PETSC_SUCCESS);
389 }
390
391 template <DeviceType T>
beginTimer(PetscDeviceContext dctx)392 inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
393 {
394 const auto dci = impls_cast_(dctx);
395
396 PetscFunctionBegin;
397 PetscCall(check_current_device_(dctx));
398 #if PetscDefined(USE_DEBUG)
399 PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
400 dci->timerInUse = PETSC_TRUE;
401 #endif
402 if (!dci->begin) {
403 PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
404 PetscCallCUPM(cupmEventCreate(&dci->begin));
405 PetscCallCUPM(cupmEventCreate(&dci->end));
406 }
407 PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
408 PetscFunctionReturn(PETSC_SUCCESS);
409 }
410
411 template <DeviceType T>
endTimer(PetscDeviceContext dctx,PetscLogDouble * elapsed)412 inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
413 {
414 float gtime;
415 const auto dci = impls_cast_(dctx);
416 const auto end = dci->end;
417
418 PetscFunctionBegin;
419 PetscCall(check_current_device_(dctx));
420 #if PetscDefined(USE_DEBUG)
421 PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
422 dci->timerInUse = PETSC_FALSE;
423 #endif
424 PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
425 PetscCallCUPM(cupmEventSynchronize(end));
426 PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end));
427 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
428 PetscFunctionReturn(PETSC_SUCCESS);
429 }
430
431 #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
432 template <DeviceType T>
getPower(PetscDeviceContext dctx,PetscLogDouble * power)433 inline PetscErrorCode DeviceContext<T>::getPower(PetscDeviceContext dctx, PetscLogDouble *power) noexcept
434 {
435 const auto dci = impls_cast_(dctx);
436 nvmlFieldValue_t values[1];
437
438 PetscFunctionBegin;
439 PetscCall(check_current_device_(dctx));
440 PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
441 values[0].fieldId = NVML_FI_DEV_POWER_INSTANT;
442 if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
443 PetscCallNVML(nvmlDeviceGetFieldValues(dci->nvmlHandle, 1, values));
444 *power = static_cast<util::remove_pointer_t<decltype(power)>>(values[0].value.uiVal);
445 PetscFunctionReturn(PETSC_SUCCESS);
446 }
447 #endif
448
449 #if PetscDefined(HAVE_CUDA)
450 template <DeviceType T>
beginEnergyMeter(PetscDeviceContext dctx)451 inline PetscErrorCode DeviceContext<T>::beginEnergyMeter(PetscDeviceContext dctx) noexcept
452 {
453 const auto dci = impls_cast_(dctx);
454
455 PetscFunctionBegin;
456 PetscCall(check_current_device_(dctx));
457 #if PetscDefined(USE_DEBUG)
458 PetscCheck(!dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterEnd()?");
459 dci->EnergyMeterInUse = PETSC_TRUE;
460 #endif
461 if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
462 PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterbegin));
463 PetscFunctionReturn(PETSC_SUCCESS);
464 }
465
466 template <DeviceType T>
endEnergyMeter(PetscDeviceContext dctx,PetscLogDouble * energy)467 inline PetscErrorCode DeviceContext<T>::endEnergyMeter(PetscDeviceContext dctx, PetscLogDouble *energy) noexcept
468 {
469 const auto dci = impls_cast_(dctx);
470
471 PetscFunctionBegin;
472 PetscCall(check_current_device_(dctx));
473 #if PetscDefined(USE_DEBUG)
474 PetscCheck(dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterBegin()?");
475 dci->EnergyMeterInUse = PETSC_FALSE;
476 #endif
477 PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
478 PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterend));
479 *energy = static_cast<util::remove_pointer_t<decltype(energy)>>(dci->energymeterend - dci->energymeterbegin) / 1000; // convert to Joule
480 PetscFunctionReturn(PETSC_SUCCESS);
481 }
482 #endif
483
484 template <DeviceType T>
memAlloc(PetscDeviceContext dctx,PetscBool clear,PetscMemType mtype,std::size_t n,std::size_t alignment,void ** dest)485 inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
486 {
487 const auto &stream = impls_cast_(dctx)->stream;
488
489 PetscFunctionBegin;
490 PetscCall(check_current_device_(dctx));
491 PetscCall(check_memtype_(mtype, "allocating"));
492 if (PetscMemTypeHost(mtype)) {
493 PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
494 } else {
495 PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
496 }
497 if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
498 PetscFunctionReturn(PETSC_SUCCESS);
499 }
500
501 template <DeviceType T>
memFree(PetscDeviceContext dctx,PetscMemType mtype,void ** ptr)502 inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
503 {
504 const auto &stream = impls_cast_(dctx)->stream;
505
506 PetscFunctionBegin;
507 PetscCall(check_current_device_(dctx));
508 PetscCall(check_memtype_(mtype, "freeing"));
509 if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
510 if (PetscMemTypeHost(mtype)) {
511 PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
512 // if ptr exists still exists the pool didn't own it
513 if (*ptr) {
514 auto registered = PETSC_FALSE, managed = PETSC_FALSE;
515
516 PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed));
517 if (registered) {
518 PetscCallCUPM(cupmFreeHost(*ptr));
519 } else if (managed) {
520 PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
521 }
522 }
523 } else {
524 PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
525 // if ptr still exists the pool didn't own it
526 if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
527 }
528 PetscFunctionReturn(PETSC_SUCCESS);
529 }
530
531 template <DeviceType T>
memCopy(PetscDeviceContext dctx,void * PETSC_RESTRICT dest,const void * PETSC_RESTRICT src,std::size_t n,PetscDeviceCopyMode mode)532 inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
533 {
534 const auto stream = impls_cast_(dctx)->stream.get_stream();
535
536 PetscFunctionBegin;
537 // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
538 if (mode == PETSC_DEVICE_COPY_HTOH) {
539 const auto cerr = cupmStreamQuery(stream);
540
541 // yes this is faster
542 if (cerr == cupmSuccess) {
543 PetscCall(PetscMemcpy(dest, src, n));
544 PetscFunctionReturn(PETSC_SUCCESS);
545 } else if (cerr == cupmErrorNotReady) {
546 auto PETSC_UNUSED unused = cupmGetLastError();
547
548 static_cast<void>(unused);
549 } else {
550 PetscCallCUPM(cerr);
551 }
552 }
553 PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
554 PetscFunctionReturn(PETSC_SUCCESS);
555 }
556
557 template <DeviceType T>
memSet(PetscDeviceContext dctx,PetscMemType mtype,void * ptr,PetscInt v,std::size_t n)558 inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
559 {
560 PetscFunctionBegin;
561 PetscCall(check_current_device_(dctx));
562 PetscCall(check_memtype_(mtype, "zeroing"));
563 PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
564 PetscFunctionReturn(PETSC_SUCCESS);
565 }
566
567 template <DeviceType T>
createEvent(PetscDeviceContext,PetscEvent event)568 inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
569 {
570 PetscFunctionBegin;
571 PetscCallCXX(event->data = new event_type{});
572 event->destroy = [](PetscEvent event) {
573 PetscFunctionBegin;
574 delete event_cast_(event);
575 event->data = nullptr;
576 PetscFunctionReturn(PETSC_SUCCESS);
577 };
578 PetscFunctionReturn(PETSC_SUCCESS);
579 }
580
581 template <DeviceType T>
recordEvent(PetscDeviceContext dctx,PetscEvent event)582 inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
583 {
584 PetscFunctionBegin;
585 PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
586 PetscFunctionReturn(PETSC_SUCCESS);
587 }
588
589 template <DeviceType T>
waitForEvent(PetscDeviceContext dctx,PetscEvent event)590 inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
591 {
592 PetscFunctionBegin;
593 PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
594 PetscFunctionReturn(PETSC_SUCCESS);
595 }
596
597 // initialize the static member variables
598 template <DeviceType T>
599 bool DeviceContext<T>::initialized_ = false;
600
601 template <DeviceType T>
602 std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
603
604 template <DeviceType T>
605 std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
606
607 template <DeviceType T>
608 constexpr _DeviceContextOps DeviceContext<T>::ops;
609
610 } // namespace impl
611
612 // shorten this one up a bit (and instantiate the templates)
613 using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
614 using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>;
615
616 // shorthand for what is an EXTREMELY long name
617 #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
618
619 } // namespace cupm
620
621 } // namespace device
622
623 } // namespace Petsc
624