xref: /petsc/include/petsc/private/cupmobject.hpp (revision 51b144c619aff302b570817d6f78637b8418d403)
1 #pragma once
2 
3 #include <petsc/private/deviceimpl.h>
4 #include <petsc/private/cupmsolverinterface.hpp>
5 
6 #include <cstring> // std::memset
7 
PetscStrFreeAllocpy(const char target[],char ** dest)8 inline PetscErrorCode PetscStrFreeAllocpy(const char target[], char **dest) noexcept
9 {
10   PetscFunctionBegin;
11   PetscAssertPointer(dest, 2);
12   if (*dest) {
13     PetscAssertPointer(*dest, 2);
14     PetscCall(PetscFree(*dest));
15   }
16   PetscCall(PetscStrallocpy(target, dest));
17   PetscFunctionReturn(PETSC_SUCCESS);
18 }
19 
20 namespace Petsc
21 {
22 
23 namespace device
24 {
25 
26 namespace cupm
27 {
28 
29 namespace impl
30 {
31 
32 namespace
33 {
34 
35 // ==========================================================================================
36 // UseCUPMHostAllocGuard
37 //
38 // A simple RAII helper for PetscMallocSet[CUDA|HIP]Host(). it exists because integrating the
39 // regular versions would be an enormous pain to square with the templated types...
40 // ==========================================================================================
41 template <DeviceType T>
42 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL UseCUPMHostAllocGuard : Interface<T> {
43 public:
44   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
45 
46   UseCUPMHostAllocGuard(bool) noexcept;
47   ~UseCUPMHostAllocGuard() noexcept;
48 
49   PETSC_NODISCARD bool value() const noexcept;
50 
51 private:
52   // would have loved to just do
53   //
54   // const auto oldmalloc = PetscTrMalloc;
55   //
56   // but in order to use auto the member needs to be static; in order to be static it must
57   // also be constexpr -- which in turn requires an initializer (also implicitly required by
58   // auto). But constexpr needs a constant expression initializer, so we can't initialize it
59   // with global (mutable) variables...
60 #define DECLTYPE_AUTO(left, right) decltype(right) left = right
61   const DECLTYPE_AUTO(oldmalloc_, PetscTrMalloc);
62   const DECLTYPE_AUTO(oldfree_, PetscTrFree);
63   const DECLTYPE_AUTO(oldrealloc_, PetscTrRealloc);
64 #undef DECLTYPE_AUTO
65   bool v_;
66 };
67 
68 // ==========================================================================================
69 // UseCUPMHostAllocGuard -- Public API
70 // ==========================================================================================
71 
72 template <DeviceType T>
UseCUPMHostAllocGuard(bool useit)73 inline UseCUPMHostAllocGuard<T>::UseCUPMHostAllocGuard(bool useit) noexcept : v_(useit)
74 {
75   PetscFunctionBegin;
76   if (useit) {
77     // all unused arguments are un-named, this saves having to add PETSC_UNUSED to them all
78     PetscTrMalloc = [](std::size_t sz, PetscBool clear, int, const char *, const char *, void **ptr) {
79       PetscFunctionBegin;
80       PetscCallCUPM(cupmMallocHost(ptr, sz));
81       if (clear) std::memset(*ptr, 0, sz);
82       PetscFunctionReturn(PETSC_SUCCESS);
83     };
84     PetscTrFree = [](void *ptr, int, const char *, const char *) {
85       PetscFunctionBegin;
86       PetscCallCUPM(cupmFreeHost(ptr));
87       PetscFunctionReturn(PETSC_SUCCESS);
88     };
89     PetscTrRealloc = [](std::size_t, int, const char *, const char *, void **) {
90       // REVIEW ME: can be implemented by malloc->copy->free?
91       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "%s has no realloc()", cupmName());
92     };
93   }
94   PetscFunctionReturnVoid();
95 }
96 
97 template <DeviceType T>
~UseCUPMHostAllocGuard()98 inline UseCUPMHostAllocGuard<T>::~UseCUPMHostAllocGuard() noexcept
99 {
100   PetscFunctionBegin;
101   if (value()) {
102     PetscTrMalloc  = oldmalloc_;
103     PetscTrFree    = oldfree_;
104     PetscTrRealloc = oldrealloc_;
105   }
106   PetscFunctionReturnVoid();
107 }
108 
109 template <DeviceType T>
value() const110 inline bool UseCUPMHostAllocGuard<T>::value() const noexcept
111 {
112   return v_;
113 }
114 
115 } // anonymous namespace
116 
117 template <DeviceType T, PetscMemType MemoryType, PetscMemoryAccessMode AccessMode>
118 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL RestoreableArray : Interface<T> {
119 public:
120   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
121 
122   static constexpr auto memory_type = MemoryType;
123   static constexpr auto access_type = AccessMode;
124 
125   using value_type        = PetscScalar;
126   using pointer_type      = value_type *;
127   using cupm_pointer_type = cupmScalar_t *;
128 
129   PETSC_NODISCARD pointer_type      data() const noexcept;
130   PETSC_NODISCARD cupm_pointer_type cupmdata() const noexcept;
131 
132   operator pointer_type() const noexcept;
133   // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
134   // we make a dummy template parameter to allow SFINAE to nix it for us
135   template <typename U = pointer_type, typename = util::enable_if_t<!std::is_same<U, cupm_pointer_type>::value>>
136   operator cupm_pointer_type() const noexcept;
137 
138 protected:
139   constexpr explicit RestoreableArray(PetscDeviceContext) noexcept;
140 
141   value_type        *ptr_  = nullptr;
142   PetscDeviceContext dctx_ = nullptr;
143 };
144 
145 // ==========================================================================================
146 // RestoreableArray - Static Variables
147 // ==========================================================================================
148 
149 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
150 const PetscMemType RestoreableArray<T, MT, MA>::memory_type;
151 
152 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
153 const PetscMemoryAccessMode RestoreableArray<T, MT, MA>::access_type;
154 
155 // ==========================================================================================
156 // RestoreableArray - Public API
157 // ==========================================================================================
158 
159 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
RestoreableArray(PetscDeviceContext dctx)160 constexpr inline RestoreableArray<T, MT, MA>::RestoreableArray(PetscDeviceContext dctx) noexcept : dctx_{dctx}
161 {
162 }
163 
164 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
data() const165 inline typename RestoreableArray<T, MT, MA>::pointer_type RestoreableArray<T, MT, MA>::data() const noexcept
166 {
167   return ptr_;
168 }
169 
170 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
cupmdata() const171 inline typename RestoreableArray<T, MT, MA>::cupm_pointer_type RestoreableArray<T, MT, MA>::cupmdata() const noexcept
172 {
173   return cupmScalarPtrCast(data());
174 }
175 
176 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
operator pointer_type() const177 inline RestoreableArray<T, MT, MA>::operator pointer_type() const noexcept
178 {
179   return data();
180 }
181 
182 // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
183 // we make a dummy template parameter to allow SFINAE to nix it for us
184 template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
185 template <typename U, typename>
operator cupm_pointer_type() const186 inline RestoreableArray<T, MT, MA>::operator cupm_pointer_type() const noexcept
187 {
188   return cupmdata();
189 }
190 
191 template <DeviceType T>
192 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL CUPMObject : SolverInterface<T> {
193 protected:
194   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
195 
196 private:
197   // The final stop in the GetHandles_/GetFromHandles_ chain. This retrieves the various
198   // compute handles and ensure the given PetscDeviceContext is of the right type
199   static PetscErrorCode GetFromHandleDispatch_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
200   static PetscErrorCode GetHandleDispatch_(PetscDeviceContext *, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
201 
202 protected:
203   PETSC_NODISCARD static constexpr PetscRandomType PETSCDEVICERAND() noexcept;
204 
205   // Helper routines to retrieve various combinations of handles. The first set (GetHandles_)
206   // gets a PetscDeviceContext along with it, while the second set (GetHandlesFrom_) assumes
207   // you've gotten the PetscDeviceContext already, and retrieves the handles from it. All of
208   // them check that the PetscDeviceContext is of the appropriate type
209   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t * = nullptr, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
210 
211   // triple
212   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t *, cupmStream_t *) noexcept;
213   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
214 
215   // double
216   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *) noexcept;
217   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmStream_t *) noexcept;
218 
219   // single
220   static PetscErrorCode GetHandles_(cupmBlasHandle_t *) noexcept;
221   static PetscErrorCode GetHandles_(cupmSolverHandle_t *) noexcept;
222   static PetscErrorCode GetHandles_(cupmStream_t *) noexcept;
223 
224   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
225   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmSolverHandle_t *, cupmStream_t * = nullptr) noexcept;
226   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmStream_t *) noexcept;
227 
228   // disallow implicit conversion
229   template <typename U>
230   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(U) noexcept = delete;
231   // utility for using cupmHostAlloc()
232   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(bool) noexcept;
233 
234   // A debug check to ensure that a given pointer-memtype pairing taken from user-land is
235   // actually correct. Errors on mismatch
236   static PetscErrorCode CheckPointerMatchesMemType_(const void *, PetscMemType) noexcept;
237 };
238 
239 template <DeviceType T>
PETSCDEVICERAND()240 inline constexpr PetscRandomType CUPMObject<T>::PETSCDEVICERAND() noexcept
241 {
242   // REVIEW ME: HIP default rng?
243   return T == DeviceType::CUDA ? PETSCCURAND : PETSCRANDER48;
244 }
245 
246 template <DeviceType T>
GetFromHandleDispatch_(PetscDeviceContext dctx,cupmBlasHandle_t * blas_handle,cupmSolverHandle_t * solver_handle,cupmStream_t * stream_handle)247 inline PetscErrorCode CUPMObject<T>::GetFromHandleDispatch_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream_handle) noexcept
248 {
249   PetscFunctionBegin;
250   PetscValidDeviceContext(dctx, 1);
251   if (blas_handle) {
252     PetscAssertPointer(blas_handle, 2);
253     *blas_handle = nullptr;
254   }
255   if (solver_handle) {
256     PetscAssertPointer(solver_handle, 3);
257     *solver_handle = nullptr;
258   }
259   if (stream_handle) {
260     PetscAssertPointer(stream_handle, 4);
261     *stream_handle = nullptr;
262   }
263   if (PetscDefined(USE_DEBUG)) {
264     PetscDeviceType dtype;
265 
266     PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype));
267     PetscCheckCompatibleDeviceTypes(PETSC_DEVICE_CUPM(), -1, dtype, 1);
268   }
269   if (blas_handle) PetscCall(PetscDeviceContextGetBLASHandle_Internal(dctx, blas_handle));
270   if (solver_handle) PetscCall(PetscDeviceContextGetSOLVERHandle_Internal(dctx, solver_handle));
271   if (stream_handle) {
272     cupmStream_t *stream = nullptr;
273 
274     PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, (void **)&stream));
275     *stream_handle = *stream;
276   }
277   PetscFunctionReturn(PETSC_SUCCESS);
278 }
279 
280 template <DeviceType T>
GetHandleDispatch_(PetscDeviceContext * dctx,cupmBlasHandle_t * blas_handle,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)281 inline PetscErrorCode CUPMObject<T>::GetHandleDispatch_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
282 {
283   PetscDeviceContext dctx_loc = nullptr;
284 
285   PetscFunctionBegin;
286   // silence uninitialized variable warnings
287   if (dctx) *dctx = nullptr;
288   PetscCall(PetscDeviceContextGetCurrentContext(&dctx_loc));
289   PetscCall(GetFromHandleDispatch_(dctx_loc, blas_handle, solver_handle, stream));
290   if (dctx) *dctx = dctx_loc;
291   PetscFunctionReturn(PETSC_SUCCESS);
292 }
293 
294 template <DeviceType T>
GetHandles_(PetscDeviceContext * dctx,cupmBlasHandle_t * blas_handle,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)295 inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
296 {
297   return GetHandleDispatch_(dctx, blas_handle, solver_handle, stream);
298 }
299 
300 template <DeviceType T>
GetHandles_(PetscDeviceContext * dctx,cupmBlasHandle_t * blas_handle,cupmStream_t * stream)301 inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmStream_t *stream) noexcept
302 {
303   return GetHandleDispatch_(dctx, blas_handle, nullptr, stream);
304 }
305 
306 template <DeviceType T>
GetHandles_(PetscDeviceContext * dctx,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)307 inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
308 {
309   return GetHandleDispatch_(dctx, nullptr, solver_handle, stream);
310 }
311 
312 template <DeviceType T>
GetHandles_(PetscDeviceContext * dctx,cupmStream_t * stream)313 inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmStream_t *stream) noexcept
314 {
315   return GetHandleDispatch_(dctx, nullptr, nullptr, stream);
316 }
317 
318 template <DeviceType T>
GetHandles_(cupmBlasHandle_t * handle)319 inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmBlasHandle_t *handle) noexcept
320 {
321   return GetHandleDispatch_(nullptr, handle, nullptr, nullptr);
322 }
323 
324 template <DeviceType T>
GetHandles_(cupmSolverHandle_t * handle)325 inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmSolverHandle_t *handle) noexcept
326 {
327   return GetHandleDispatch_(nullptr, nullptr, handle, nullptr);
328 }
329 
330 template <DeviceType T>
GetHandles_(cupmStream_t * stream)331 inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmStream_t *stream) noexcept
332 {
333   return GetHandleDispatch_(nullptr, nullptr, nullptr, stream);
334 }
335 
336 template <DeviceType T>
GetHandlesFrom_(PetscDeviceContext dctx,cupmBlasHandle_t * blas_handle,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)337 inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
338 {
339   return GetFromHandleDispatch_(dctx, blas_handle, solver_handle, stream);
340 }
341 
342 template <DeviceType T>
GetHandlesFrom_(PetscDeviceContext dctx,cupmSolverHandle_t * solver_handle,cupmStream_t * stream)343 inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
344 {
345   return GetFromHandleDispatch_(dctx, nullptr, solver_handle, stream);
346 }
347 
348 template <DeviceType T>
GetHandlesFrom_(PetscDeviceContext dctx,cupmStream_t * stream)349 inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmStream_t *stream) noexcept
350 {
351   return GetFromHandleDispatch_(dctx, nullptr, nullptr, stream);
352 }
353 
354 template <DeviceType T>
UseCUPMHostAlloc(bool b)355 inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(bool b) noexcept
356 {
357   return {b};
358 }
359 
360 template <DeviceType T>
CheckPointerMatchesMemType_(const void * ptr,PetscMemType mtype)361 inline PetscErrorCode CUPMObject<T>::CheckPointerMatchesMemType_(const void *ptr, PetscMemType mtype) noexcept
362 {
363   PetscFunctionBegin;
364   if (PetscDefined(USE_DEBUG) && ptr) {
365     PetscMemType ptr_mtype;
366 
367     PetscCall(PetscCUPMGetMemType(ptr, &ptr_mtype));
368     if (mtype == PETSC_MEMTYPE_HOST) {
369       PetscCheck(PetscMemTypeHost(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
370     } else if (mtype == PETSC_MEMTYPE_DEVICE) {
371       // generic "device" memory should only care if the actual memtype is also generically
372       // "device"
373       PetscCheck(PetscMemTypeDevice(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
374     } else {
375       PetscCheck(mtype == ptr_mtype, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
376     }
377   }
378   PetscFunctionReturn(PETSC_SUCCESS);
379 }
380 
// Convenience header for classes deriving from CUPMObject<T>: expands to the solver interface
// typedefs/usings plus using-declarations for CUPMObject's protected helpers. The last
// using-declaration deliberately has no trailing semicolon, it is supplied at the point of
// use, i.e. write 'PETSC_CUPMOBJECT_HEADER(T);'.
#define PETSC_CUPMOBJECT_HEADER(T) \
  PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::CheckPointerMatchesMemType_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandles_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandlesFrom_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::PETSCDEVICERAND; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::UseCUPMHostAlloc
388 
389 } // namespace impl
390 
391 } // namespace cupm
392 
393 } // namespace device
394 
395 } // namespace Petsc
396