1 #pragma once
2
3 #include <petsc/private/deviceimpl.h>
4 #include <petsc/private/cupmsolverinterface.hpp>
5
6 #include <cstring> // std::memset
7
PetscStrFreeAllocpy(const char target[],char ** dest)8 inline PetscErrorCode PetscStrFreeAllocpy(const char target[], char **dest) noexcept
9 {
10 PetscFunctionBegin;
11 PetscAssertPointer(dest, 2);
12 if (*dest) {
13 PetscAssertPointer(*dest, 2);
14 PetscCall(PetscFree(*dest));
15 }
16 PetscCall(PetscStrallocpy(target, dest));
17 PetscFunctionReturn(PETSC_SUCCESS);
18 }
19
20 namespace Petsc
21 {
22
23 namespace device
24 {
25
26 namespace cupm
27 {
28
29 namespace impl
30 {
31
32 namespace
33 {
34
35 // ==========================================================================================
36 // UseCUPMHostAllocGuard
37 //
38 // A simple RAII helper for PetscMallocSet[CUDA|HIP]Host(). it exists because integrating the
39 // regular versions would be an enormous pain to square with the templated types...
40 // ==========================================================================================
41 template <DeviceType T>
42 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL UseCUPMHostAllocGuard : Interface<T> {
43 public:
44 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
45
46 UseCUPMHostAllocGuard(bool) noexcept;
47 ~UseCUPMHostAllocGuard() noexcept;
48
49 PETSC_NODISCARD bool value() const noexcept;
50
51 private:
52 // would have loved to just do
53 //
54 // const auto oldmalloc = PetscTrMalloc;
55 //
56 // but in order to use auto the member needs to be static; in order to be static it must
57 // also be constexpr -- which in turn requires an initializer (also implicitly required by
58 // auto). But constexpr needs a constant expression initializer, so we can't initialize it
59 // with global (mutable) variables...
60 #define DECLTYPE_AUTO(left, right) decltype(right) left = right
61 const DECLTYPE_AUTO(oldmalloc_, PetscTrMalloc);
62 const DECLTYPE_AUTO(oldfree_, PetscTrFree);
63 const DECLTYPE_AUTO(oldrealloc_, PetscTrRealloc);
64 #undef DECLTYPE_AUTO
65 bool v_;
66 };
67
68 // ==========================================================================================
69 // UseCUPMHostAllocGuard -- Public API
70 // ==========================================================================================
71
72 template <DeviceType T>
UseCUPMHostAllocGuard(bool useit)73 inline UseCUPMHostAllocGuard<T>::UseCUPMHostAllocGuard(bool useit) noexcept : v_(useit)
74 {
75 PetscFunctionBegin;
76 if (useit) {
77 // all unused arguments are un-named, this saves having to add PETSC_UNUSED to them all
78 PetscTrMalloc = [](std::size_t sz, PetscBool clear, int, const char *, const char *, void **ptr) {
79 PetscFunctionBegin;
80 PetscCallCUPM(cupmMallocHost(ptr, sz));
81 if (clear) std::memset(*ptr, 0, sz);
82 PetscFunctionReturn(PETSC_SUCCESS);
83 };
84 PetscTrFree = [](void *ptr, int, const char *, const char *) {
85 PetscFunctionBegin;
86 PetscCallCUPM(cupmFreeHost(ptr));
87 PetscFunctionReturn(PETSC_SUCCESS);
88 };
89 PetscTrRealloc = [](std::size_t, int, const char *, const char *, void **) {
90 // REVIEW ME: can be implemented by malloc->copy->free?
91 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "%s has no realloc()", cupmName());
92 };
93 }
94 PetscFunctionReturnVoid();
95 }
96
97 template <DeviceType T>
~UseCUPMHostAllocGuard()98 inline UseCUPMHostAllocGuard<T>::~UseCUPMHostAllocGuard() noexcept
99 {
100 PetscFunctionBegin;
101 if (value()) {
102 PetscTrMalloc = oldmalloc_;
103 PetscTrFree = oldfree_;
104 PetscTrRealloc = oldrealloc_;
105 }
106 PetscFunctionReturnVoid();
107 }
108
109 template <DeviceType T>
value() const110 inline bool UseCUPMHostAllocGuard<T>::value() const noexcept
111 {
112 return v_;
113 }
114
115 } // anonymous namespace
116
117 template <DeviceType T, PetscMemType MemoryType, PetscMemoryAccessMode AccessMode>
118 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL RestoreableArray : Interface<T> {
119 public:
120 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
121
122 static constexpr auto memory_type = MemoryType;
123 static constexpr auto access_type = AccessMode;
124
125 using value_type = PetscScalar;
126 using pointer_type = value_type *;
127 using cupm_pointer_type = cupmScalar_t *;
128
129 PETSC_NODISCARD pointer_type data() const noexcept;
130 PETSC_NODISCARD cupm_pointer_type cupmdata() const noexcept;
131
132 operator pointer_type() const noexcept;
133 // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
134 // we make a dummy template parameter to allow SFINAE to nix it for us
135 template <typename U = pointer_type, typename = util::enable_if_t<!std::is_same<U, cupm_pointer_type>::value>>
136 operator cupm_pointer_type() const noexcept;
137
138 protected:
139 constexpr explicit RestoreableArray(PetscDeviceContext) noexcept;
140
141 value_type *ptr_ = nullptr;
142 PetscDeviceContext dctx_ = nullptr;
143 };
144
145 // ==========================================================================================
146 // RestoreableArray - Static Variables
147 // ==========================================================================================
148
149 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
150 const PetscMemType RestoreableArray<T, MT, MA>::memory_type;
151
152 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
153 const PetscMemoryAccessMode RestoreableArray<T, MT, MA>::access_type;
154
155 // ==========================================================================================
156 // RestoreableArray - Public API
157 // ==========================================================================================
158
159 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
RestoreableArray(PetscDeviceContext dctx)160 constexpr inline RestoreableArray<T, MT, MA>::RestoreableArray(PetscDeviceContext dctx) noexcept : dctx_{dctx}
161 {
162 }
163
164 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
data() const165 inline typename RestoreableArray<T, MT, MA>::pointer_type RestoreableArray<T, MT, MA>::data() const noexcept
166 {
167 return ptr_;
168 }
169
170 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
cupmdata() const171 inline typename RestoreableArray<T, MT, MA>::cupm_pointer_type RestoreableArray<T, MT, MA>::cupmdata() const noexcept
172 {
173 return cupmScalarPtrCast(data());
174 }
175
176 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
operator pointer_type() const177 inline RestoreableArray<T, MT, MA>::operator pointer_type() const noexcept
178 {
179 return data();
180 }
181
182 // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
183 // we make a dummy template parameter to allow SFINAE to nix it for us
184 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
185 template <typename U, typename>
operator cupm_pointer_type() const186 inline RestoreableArray<T, MT, MA>::operator cupm_pointer_type() const noexcept
187 {
188 return cupmdata();
189 }
190
191 template <DeviceType T>
192 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL CUPMObject : SolverInterface<T> {
193 protected:
194 PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
195
196 private:
197 // The final stop in the GetHandles_/GetFromHandles_ chain. This retrieves the various
198 // compute handles and ensure the given PetscDeviceContext is of the right type
199 static PetscErrorCode GetFromHandleDispatch_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
200 static PetscErrorCode GetHandleDispatch_(PetscDeviceContext *, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
201
202 protected:
203 PETSC_NODISCARD static constexpr PetscRandomType PETSCDEVICERAND() noexcept;
204
205 // Helper routines to retrieve various combinations of handles. The first set (GetHandles_)
206 // gets a PetscDeviceContext along with it, while the second set (GetHandlesFrom_) assumes
207 // you've gotten the PetscDeviceContext already, and retrieves the handles from it. All of
208 // them check that the PetscDeviceContext is of the appropriate type
209 static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t * = nullptr, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
210
211 // triple
212 static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t *, cupmStream_t *) noexcept;
213 static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
214
215 // double
216 static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *) noexcept;
217 static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmStream_t *) noexcept;
218
219 // single
220 static PetscErrorCode GetHandles_(cupmBlasHandle_t *) noexcept;
221 static PetscErrorCode GetHandles_(cupmSolverHandle_t *) noexcept;
222 static PetscErrorCode GetHandles_(cupmStream_t *) noexcept;
223
224 static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
225 static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmSolverHandle_t *, cupmStream_t * = nullptr) noexcept;
226 static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmStream_t *) noexcept;
227
228 // disallow implicit conversion
229 template <typename U>
230 PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(U) noexcept = delete;
231 // utility for using cupmHostAlloc()
232 PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(bool) noexcept;
233
234 // A debug check to ensure that a given pointer-memtype pairing taken from user-land is
235 // actually correct. Errors on mismatch
236 static PetscErrorCode CheckPointerMatchesMemType_(const void *, PetscMemType) noexcept;
237 };
238
239 template <DeviceType T>
PETSCDEVICERAND()240 inline constexpr PetscRandomType CUPMObject<T>::PETSCDEVICERAND() noexcept
241 {
242 // REVIEW ME: HIP default rng?
243 return T == DeviceType::CUDA ? PETSCCURAND : PETSCRANDER48;
244 }
245
246 template <DeviceType T>
GetFromHandleDispatch_(PetscDeviceContext dctx,cupmBlasHandle_t * blas_handle,cupmSolverHandle_t * solver_handle,cupmStream_t * stream_handle)247 inline PetscErrorCode CUPMObject<T>::GetFromHandleDispatch_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream_handle) noexcept
248 {
249 PetscFunctionBegin;
250 PetscValidDeviceContext(dctx, 1);
251 if (blas_handle) {
252 PetscAssertPointer(blas_handle, 2);
253 *blas_handle = nullptr;
254 }
255 if (solver_handle) {
256 PetscAssertPointer(solver_handle, 3);
257 *solver_handle = nullptr;
258 }
259 if (stream_handle) {
260 PetscAssertPointer(stream_handle, 4);
261 *stream_handle = nullptr;
262 }
263 if (PetscDefined(USE_DEBUG)) {
264 PetscDeviceType dtype;
265
266 PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype));
267 PetscCheckCompatibleDeviceTypes(PETSC_DEVICE_CUPM(), -1, dtype, 1);
268 }
269 if (blas_handle) PetscCall(PetscDeviceContextGetBLASHandle_Internal(dctx, blas_handle));
270 if (solver_handle) PetscCall(PetscDeviceContextGetSOLVERHandle_Internal(dctx, solver_handle));
271 if (stream_handle) {
272 cupmStream_t *stream = nullptr;
273
274 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, (void **)&stream));
275 *stream_handle = *stream;
276 }
277 PetscFunctionReturn(PETSC_SUCCESS);
278 }
279
280 template <DeviceType T>
GetHandleDispatch_(PetscDeviceContext * dctx,cupmBlasHandle_t * blas_handle,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)281 inline PetscErrorCode CUPMObject<T>::GetHandleDispatch_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
282 {
283 PetscDeviceContext dctx_loc = nullptr;
284
285 PetscFunctionBegin;
286 // silence uninitialized variable warnings
287 if (dctx) *dctx = nullptr;
288 PetscCall(PetscDeviceContextGetCurrentContext(&dctx_loc));
289 PetscCall(GetFromHandleDispatch_(dctx_loc, blas_handle, solver_handle, stream));
290 if (dctx) *dctx = dctx_loc;
291 PetscFunctionReturn(PETSC_SUCCESS);
292 }
293
294 template <DeviceType T>
GetHandles_(PetscDeviceContext * dctx,cupmBlasHandle_t * blas_handle,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)295 inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
296 {
297 return GetHandleDispatch_(dctx, blas_handle, solver_handle, stream);
298 }
299
300 template <DeviceType T>
GetHandles_(PetscDeviceContext * dctx,cupmBlasHandle_t * blas_handle,cupmStream_t * stream)301 inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmStream_t *stream) noexcept
302 {
303 return GetHandleDispatch_(dctx, blas_handle, nullptr, stream);
304 }
305
306 template <DeviceType T>
GetHandles_(PetscDeviceContext * dctx,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)307 inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
308 {
309 return GetHandleDispatch_(dctx, nullptr, solver_handle, stream);
310 }
311
312 template <DeviceType T>
GetHandles_(PetscDeviceContext * dctx,cupmStream_t * stream)313 inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmStream_t *stream) noexcept
314 {
315 return GetHandleDispatch_(dctx, nullptr, nullptr, stream);
316 }
317
318 template <DeviceType T>
GetHandles_(cupmBlasHandle_t * handle)319 inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmBlasHandle_t *handle) noexcept
320 {
321 return GetHandleDispatch_(nullptr, handle, nullptr, nullptr);
322 }
323
324 template <DeviceType T>
GetHandles_(cupmSolverHandle_t * handle)325 inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmSolverHandle_t *handle) noexcept
326 {
327 return GetHandleDispatch_(nullptr, nullptr, handle, nullptr);
328 }
329
330 template <DeviceType T>
GetHandles_(cupmStream_t * stream)331 inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmStream_t *stream) noexcept
332 {
333 return GetHandleDispatch_(nullptr, nullptr, nullptr, stream);
334 }
335
336 template <DeviceType T>
GetHandlesFrom_(PetscDeviceContext dctx,cupmBlasHandle_t * blas_handle,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)337 inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
338 {
339 return GetFromHandleDispatch_(dctx, blas_handle, solver_handle, stream);
340 }
341
342 template <DeviceType T>
GetHandlesFrom_(PetscDeviceContext dctx,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)343 inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
344 {
345 return GetFromHandleDispatch_(dctx, nullptr, solver_handle, stream);
346 }
347
348 template <DeviceType T>
GetHandlesFrom_(PetscDeviceContext dctx,cupmStream_t * stream)349 inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmStream_t *stream) noexcept
350 {
351 return GetFromHandleDispatch_(dctx, nullptr, nullptr, stream);
352 }
353
354 template <DeviceType T>
UseCUPMHostAlloc(bool b)355 inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(bool b) noexcept
356 {
357 return {b};
358 }
359
360 template <DeviceType T>
CheckPointerMatchesMemType_(const void * ptr,PetscMemType mtype)361 inline PetscErrorCode CUPMObject<T>::CheckPointerMatchesMemType_(const void *ptr, PetscMemType mtype) noexcept
362 {
363 PetscFunctionBegin;
364 if (PetscDefined(USE_DEBUG) && ptr) {
365 PetscMemType ptr_mtype;
366
367 PetscCall(PetscCUPMGetMemType(ptr, &ptr_mtype));
368 if (mtype == PETSC_MEMTYPE_HOST) {
369 PetscCheck(PetscMemTypeHost(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
370 } else if (mtype == PETSC_MEMTYPE_DEVICE) {
371 // generic "device" memory should only care if the actual memtype is also generically
372 // "device"
373 PetscCheck(PetscMemTypeDevice(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
374 } else {
375 PetscCheck(mtype == ptr_mtype, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
376 }
377 }
378 PetscFunctionReturn(PETSC_SUCCESS);
379 }
380
// Brings the solver-interface typedefs and the CUPMObject<T> helper routines into the
// current scope via using-declarations (presumably meant for the body of a class deriving
// from CUPMObject<T> -- confirm against users). The expansion deliberately ends without a
// semicolon; the macro's user supplies it.
#define PETSC_CUPMOBJECT_HEADER(T) \
  PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandles_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandlesFrom_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::CheckPointerMatchesMemType_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::PETSCDEVICERAND; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::UseCUPMHostAlloc
388
389 } // namespace impl
390
391 } // namespace cupm
392
393 } // namespace device
394
395 } // namespace Petsc
396