xref: /petsc/include/petsc/private/cupmblasinterface.hpp (revision 58bddbc0aeb8e2276be3739270a4176cb222ba3a)
1 #pragma once
2 
3 #include <petsc/private/cupminterface.hpp>
4 #include <petsc/private/petscadvancedmacros.h>
5 
6 #include <limits> // std::numeric_limits
7 
8 namespace Petsc
9 {
10 
11 namespace device
12 {
13 
14 namespace cupm
15 {
16 
17 namespace impl
18 {
19 
// PetscCallCUPMBLAS_() - Implementation macro: evaluate a CUDA/HIP BLAS call and, if it did not
// return CUPMBLAS_STATUS_SUCCESS, raise a PETSc error through __abort_fn__
//
// input params:
// __abort_fn__ - the error-raising macro to use, e.g. SETERRQ or SETERRABORT
// __comm__     - the communicator handed to __abort_fn__
// ...          - the BLAS call expression, evaluated exactly once
//
// notes:
// CUPMBLAS_STATUS_NOT_INITIALIZED and CUPMBLAS_STATUS_ALLOC_FAILED are reported as
// PETSC_ERR_GPU_RESOURCE when PetscDeviceInitialized(PETSC_DEVICE_CUPM()) is true; every other
// failure is reported as PETSC_ERR_GPU. The two branches are mutually exclusive, so exactly one
// error is raised even if __abort_fn__ were to return control here.
//
// requires cupmBlasError_t, the CUPMBLAS_STATUS_* values, cupmBlasName() and
// cupmBlasGetErrorName() to be visible at the point of use (e.g. via
// PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING()).
#define PetscCallCUPMBLAS_(__abort_fn__, __comm__, ...) \
  do { \
    PetscStackUpdateLine; \
    const cupmBlasError_t cberr_p_ = __VA_ARGS__; \
    if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
      if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
        __abort_fn__(__comm__, PETSC_ERR_GPU_RESOURCE, \
                     "%s error %d (%s). Reports not initialized or alloc failed; " \
                     "this indicates the GPU may have run out of resources", \
                     cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
      } else { \
        __abort_fn__(__comm__, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
      } \
    } \
  } while (0)

// PetscCallCUPMBLAS() - Check a CUDA/HIP BLAS call, returning the error (SETERRQ) on failure
#define PetscCallCUPMBLAS(...)             PetscCallCUPMBLAS_(SETERRQ, PETSC_COMM_SELF, __VA_ARGS__)
// PetscCallCUPMBLASAbort() - Check a CUDA/HIP BLAS call where an error code cannot be returned
// (e.g. in constructors/destructors), aborting on comm_ (SETERRABORT) on failure
#define PetscCallCUPMBLASAbort(comm_, ...) PetscCallCUPMBLAS_(SETERRABORT, comm_, __VA_ARGS__)
37 
38 // given cupmBlas<T>axpy() then
39 // T = PETSC_CUPBLAS_FP_TYPE
40 // given cupmBlas<T><u>nrm2() then
41 // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
42 // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
43 #if PetscDefined(USE_COMPLEX)
44   #if PetscDefined(USE_REAL_SINGLE)
45     #define PETSC_CUPMBLAS_FP_TYPE_U       C
46     #define PETSC_CUPMBLAS_FP_TYPE_L       c
47     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
48     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
49   #elif PetscDefined(USE_REAL_DOUBLE)
50     #define PETSC_CUPMBLAS_FP_TYPE_U       Z
51     #define PETSC_CUPMBLAS_FP_TYPE_L       z
52     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
53     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
54   #endif
55   #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
56   #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
57 #else
58   #if PetscDefined(USE_REAL_SINGLE)
59     #define PETSC_CUPMBLAS_FP_TYPE_U S
60     #define PETSC_CUPMBLAS_FP_TYPE_L s
61   #elif PetscDefined(USE_REAL_DOUBLE)
62     #define PETSC_CUPMBLAS_FP_TYPE_U D
63     #define PETSC_CUPMBLAS_FP_TYPE_L d
64   #endif
65   #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
66   #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
67   #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
68   #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
69 #endif // USE_COMPLEX
70 
71 #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
72   #error "Unsupported floating-point type for CUDA/HIP BLAS"
73 #endif
74 
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Build the name of a blas function whose
// return type differs from its input type (e.g. nrm2 or asum of a complex vector)
//
// input param:
// func - base suffix of the blas function, e.g. nrm2
//
// notes:
// The resulting name is <input letter><return letter><func>, where
// - PETSC_CUPMBLAS_FP_INPUT_TYPE is the blas floating point input type letter ("S" for real or
//   complex single, "D" for real or complex double), and
// - PETSC_CUPMBLAS_FP_RETURN_TYPE is the blas floating point output type letter ("c" for
//   complex single, "z" for complex double, and <absolutely nothing> for real single/double).
// Both must be defined wherever this macro is expanded. nvidia/amd are not consistent about
// upper- vs lower-case letters here, which is why the two letters are configured separately.
//
// example usage:
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE  S
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
//
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE  D
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) \
  PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)
101 
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build the Iamax and
// Iamin names, which follow their own naming scheme
//
// input param:
// func - base suffix of the blas function, either amax or amin
//
// notes:
// The macro name stands for "I" ## "floating point type", which is exactly what it pastes
// together.
//
// requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point type
// letter ("c" for complex single, "z" for complex double, "s" for real single, and "d" for
// real double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE_L s
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
//
// #define PETSC_CUPMBLAS_FP_TYPE_L z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))
123 
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Build the name of a "standard" blas
// function, i.e. <type letter><func>
//
// input param:
// func - base suffix of the blas function, e.g. axpy, scal
//
// notes:
// PETSC_CUPMBLAS_FP_TYPE must be defined wherever this macro is expanded, holding the blas
// floating-point letter: "S" (real single), "D" (real double), "C" (complex single) or "Z"
// (complex double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE S
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
//
// #define PETSC_CUPMBLAS_FP_TYPE Z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) \
  PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)
141 
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - Alias a CUDA/HIP blas function whose suffix
// differs from the one we expose (in case CUDA/HIP don't agree with our suffix, one can
// provide both here)
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
//                IFPTYPE
// our_suffix   - the suffix of the alias function (the alias is named cupmBlasX<our_suffix>)
// their_suffix - the suffix of the function being aliased
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
// prefix, plus any other definitions required by the chosen builder macro. The alias itself is
// declared through PETSC_CUPM_ALIAS_FUNCTION(); see that macro for the exact expansion.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX  cublas
// #define PETSC_CUPMBLAS_FP_TYPE C
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
// template <typename... T>
// static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
// {
//   return cublasCdotc(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
  PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlasX, our_suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix)))
168 
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
//                IFPTYPE
// suffix       - the common suffix between CUDA and HIP of the alias function
//
// notes:
// see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(); this macro just calls that one with "suffix"
// as both "our_suffix" and "their_suffix"
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX  cublas
// #define PETSC_CUPMBLAS_FP_TYPE S
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy) -> cupmBlasXaxpy() forwarding to cublasSaxpy()
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)
180 
// PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP blas library (non-BLAS-kernel) function
//
// input params:
// suffix - the common suffix between CUDA and HIP of the alias function
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
// prefix. The alias is declared through PETSC_CUPM_ALIAS_FUNCTION(); see that macro for the
// precise expansion.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX hipblas
// PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
// template <typename... T>
// static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
// {
//   return hipblasCreate(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix))
199 
200 template <DeviceType>
201 struct BlasInterfaceImpl;
202 
// cupmBlasHandleWrapper - thin wrapper around a vendor library handle of type T.
//
// Exists because HIP (for whatever reason) has elected to define both hipblasHandle_t and
// hipsolverHandle_t as void *, so the raw handle types cannot be told apart during overload
// resolution and hence need to be wrapped in this mess. Each (T, I) pair is a distinct C++ type,
// and the wrapper converts implicitly to and from the raw handle.
template <typename T, std::size_t I>
class cupmBlasHandleWrapper {
public:
  constexpr cupmBlasHandleWrapper() noexcept = default;

  // Non-explicit, so a raw T converts implicitly to the wrapper.
  constexpr cupmBlasHandleWrapper(T h) noexcept : handle_{std::move(h)}
  {
    // Checked here rather than at class scope, where the class is still incomplete.
    static_assert(std::is_standard_layout<cupmBlasHandleWrapper<T, I>>::value, "");
  }

  // Reset the wrapped handle to null.
  cupmBlasHandleWrapper &operator=(std::nullptr_t) noexcept
  {
    handle_ = nullptr;
    return *this;
  }

  // Implicit conversion back to the raw handle, so the wrapper can be passed directly to the
  // vendor API.
  operator T() const noexcept(std::is_nothrow_copy_constructible<T>::value) { return handle_; }

  // Pointer to the wrapped handle.
  const T *ptr_to() const noexcept { return &handle_; }
  T       *ptr_to() noexcept { return &handle_; }

private:
  T handle_{};
};
226 
227 #if PetscDefined(HAVE_CUDA)
228   #define PETSC_CUPMBLAS_PREFIX         cublas
229   #define PETSC_CUPMBLAS_PREFIX_U       CUBLAS
230   #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
231   #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
232   #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
233 template <>
234 struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterfaceImpl<DeviceType::CUDA> : Interface<DeviceType::CUDA> {
235   // typedefs
236   using cupmBlasHandle_t      = cupmBlasHandleWrapper<cublasHandle_t, 0>;
237   using cupmBlasError_t       = cublasStatus_t;
238   using cupmBlasInt_t         = int;
239   using cupmBlasPointerMode_t = cublasPointerMode_t;
240 
241   // values
242   static const auto CUPMBLAS_STATUS_SUCCESS         = CUBLAS_STATUS_SUCCESS;
243   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED;
244   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = CUBLAS_STATUS_ALLOC_FAILED;
245   static const auto CUPMBLAS_POINTER_MODE_HOST      = CUBLAS_POINTER_MODE_HOST;
246   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = CUBLAS_POINTER_MODE_DEVICE;
247   static const auto CUPMBLAS_OP_T                   = CUBLAS_OP_T;
248   static const auto CUPMBLAS_OP_N                   = CUBLAS_OP_N;
249   static const auto CUPMBLAS_OP_C                   = CUBLAS_OP_C;
250   static const auto CUPMBLAS_FILL_MODE_LOWER        = CUBLAS_FILL_MODE_LOWER;
251   static const auto CUPMBLAS_FILL_MODE_UPPER        = CUBLAS_FILL_MODE_UPPER;
252   static const auto CUPMBLAS_SIDE_LEFT              = CUBLAS_SIDE_LEFT;
253   static const auto CUPMBLAS_DIAG_NON_UNIT          = CUBLAS_DIAG_NON_UNIT;
254 
255   // utility functions
256   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
PETSC_CUPMBLAS_ALIAS_FUNCTIONPetsc::device::cupm::impl::BlasInterfaceImpl257   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
258   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
259   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
260   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
261   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
262 
263   // level 1 BLAS
264   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
265   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, copy)
266   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
267   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
268   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
269   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
270   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
271   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
272   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
273 
274   // level 2 BLAS
275   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
276   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trmv)
277   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsv)
278   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gbmv)
279   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbmv)
280   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbsv)
281   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, hemv, PetscIfPetscDefined(USE_COMPLEX, hemv, symv))
282 
283   // level 3 BLAS
284   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
285   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)
286 
287   // BLAS extensions
288   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
289 
290   PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscCUBLASGetErrorName(status); }
291 };
292   #undef PETSC_CUPMBLAS_PREFIX
293   #undef PETSC_CUPMBLAS_PREFIX_U
294   #undef PETSC_CUPMBLAS_FP_TYPE
295   #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
296   #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
297 #endif // PetscDefined(HAVE_CUDA)
298 
299 #if PetscDefined(HAVE_HIP)
300   #define PETSC_CUPMBLAS_PREFIX         hipblas
301   #define PETSC_CUPMBLAS_PREFIX_U       HIPBLAS
302   #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
303   #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
304   #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
305 template <>
306 struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterfaceImpl<DeviceType::HIP> : Interface<DeviceType::HIP> {
307   // typedefs
308   using cupmBlasHandle_t      = cupmBlasHandleWrapper<hipblasHandle_t, 0>;
309   using cupmBlasError_t       = hipblasStatus_t;
310   using cupmBlasInt_t         = int; // rocblas will have its own
311   using cupmBlasPointerMode_t = hipblasPointerMode_t;
312 
313   // values
314   static const auto CUPMBLAS_STATUS_SUCCESS         = HIPBLAS_STATUS_SUCCESS;
315   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED;
316   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = HIPBLAS_STATUS_ALLOC_FAILED;
317   static const auto CUPMBLAS_POINTER_MODE_HOST      = HIPBLAS_POINTER_MODE_HOST;
318   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = HIPBLAS_POINTER_MODE_DEVICE;
319   static const auto CUPMBLAS_OP_T                   = HIPBLAS_OP_T;
320   static const auto CUPMBLAS_OP_N                   = HIPBLAS_OP_N;
321   static const auto CUPMBLAS_OP_C                   = HIPBLAS_OP_C;
322   static const auto CUPMBLAS_FILL_MODE_LOWER        = HIPBLAS_FILL_MODE_LOWER;
323   static const auto CUPMBLAS_FILL_MODE_UPPER        = HIPBLAS_FILL_MODE_UPPER;
324   static const auto CUPMBLAS_SIDE_LEFT              = HIPBLAS_SIDE_LEFT;
325   static const auto CUPMBLAS_DIAG_NON_UNIT          = HIPBLAS_DIAG_NON_UNIT;
326 
327   // utility functions
328   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
PETSC_CUPMBLAS_ALIAS_FUNCTIONPetsc::device::cupm::impl::BlasInterfaceImpl329   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
330   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
331   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
332   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
333   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
334 
335   // level 1 BLAS
336   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
337   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, copy)
338   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
339   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
340   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
341   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
342   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
343   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
344   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
345 
346   // level 2 BLAS
347   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
348   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trmv)
349   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsv)
350   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gbmv)
351   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbmv)
352   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbsv)
353   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, hemv, PetscIfPetscDefined(USE_COMPLEX, hemv, symv))
354 
355   // level 3 BLAS
356   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
357   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)
358 
359   // BLAS extensions
360   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
361 
362   PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscHIPBLASGetErrorName(status); }
363 };
364   #undef PETSC_CUPMBLAS_PREFIX
365   #undef PETSC_CUPMBLAS_PREFIX_U
366   #undef PETSC_CUPMBLAS_FP_TYPE
367   #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
368   #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
369 #endif // PetscDefined(HAVE_HIP)
370 
// PETSC_CUPMBLAS_IMPL_CLASS_HEADER() - Import every type, value, and function alias of
// BlasInterfaceImpl<T> into a class deriving from it
//
// input param:
// T - the DeviceType whose BlasInterfaceImpl is imported
//
// notes:
// Also expands PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T) for the base CUPM interface. The
// expansion has no trailing semicolon; supply one at the point of use.
#define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T) \
  PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  /* types */ \
  using cupmBlasHandle_t      = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasHandle_t; \
  using cupmBlasError_t       = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasError_t; \
  using cupmBlasInt_t         = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasInt_t; \
  using cupmBlasPointerMode_t = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasPointerMode_t; \
  /* introspection */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetErrorName; \
  /* status codes and enumerators */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_SUCCESS; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_NOT_INITIALIZED; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_ALLOC_FAILED; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_HOST; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_DEVICE; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_T; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_N; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_C; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_LOWER; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_UPPER; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_SIDE_LEFT; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_DIAG_NON_UNIT; \
  /* handle management and pointer mode */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasCreate; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasDestroy; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetStream; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetStream; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetPointerMode; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetPointerMode; \
  /* level 1 BLAS */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXaxpy; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXcopy; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXscal; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdot; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdotu; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXswap; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXnrm2; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXamax; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXasum; \
  /* level 2 BLAS */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemv; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrmv; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrsv; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgbmv; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtbmv; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtbsv; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXhemv; \
  /* level 3 BLAS */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemm; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrsm; \
  /* BLAS extensions */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgeam
423 
424 // The actual interface class
425 template <DeviceType T>
426 struct BlasInterface : BlasInterfaceImpl<T> {
427   PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T);
428 
cupmBlasNamePetsc::device::cupm::impl::BlasInterface429   PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }
430 
PetscCUPMBlasSetPointerModeFromPointerPetsc::device::cupm::impl::BlasInterface431   static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
432   {
433     auto mtype = PETSC_MEMTYPE_HOST;
434 
435     PetscFunctionBegin;
436     PetscCall(PetscCUPMGetMemType(ptr, &mtype));
437     PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST));
438     PetscFunctionReturn(PETSC_SUCCESS);
439   }
440 
checkCupmBlasIntCastPetsc::device::cupm::impl::BlasInterface441   static PetscErrorCode checkCupmBlasIntCast(PetscInt x) noexcept
442   {
443     PetscFunctionBegin;
444     PetscCheck((std::is_same<PetscInt, cupmBlasInt_t>::value) || (x <= std::numeric_limits<cupmBlasInt_t>::max()), PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is too big for %s, which may be restricted to 32-bit integers", x, cupmBlasName());
445     PetscCheck(x >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Passing negative integer (%" PetscInt_FMT ") to %s routine", x, cupmBlasName());
446     PetscFunctionReturn(PETSC_SUCCESS);
447   }
448 
PetscCUPMBlasIntCastPetsc::device::cupm::impl::BlasInterface449   static PetscErrorCode PetscCUPMBlasIntCast(PetscInt x, cupmBlasInt_t *y) noexcept
450   {
451     PetscFunctionBegin;
452     *y = static_cast<cupmBlasInt_t>(x);
453     PetscCall(checkCupmBlasIntCast(x));
454     PetscFunctionReturn(PETSC_SUCCESS);
455   }
456 
457   class CUPMBlasPointerModeGuard {
458   public:
CUPMBlasPointerModeGuard(const cupmBlasHandle_t & handle,cupmBlasPointerMode_t mode)459     CUPMBlasPointerModeGuard(const cupmBlasHandle_t &handle, cupmBlasPointerMode_t mode) noexcept : handle_{handle}
460     {
461       PetscFunctionBegin;
462       PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasGetPointerMode(handle, &this->old_));
463       if (this->old_ == mode) {
464         this->set_ = false;
465       } else {
466         this->set_ = true;
467         PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasSetPointerMode(handle, mode));
468       }
469       PetscFunctionReturnVoid();
470     }
471 
CUPMBlasPointerModeGuard(const cupmBlasHandle_t & handle,PetscMemType mtype)472     CUPMBlasPointerModeGuard(const cupmBlasHandle_t &handle, PetscMemType mtype) noexcept : CUPMBlasPointerModeGuard{handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST} { }
473 
~CUPMBlasPointerModeGuard()474     ~CUPMBlasPointerModeGuard() noexcept
475     {
476       PetscFunctionBegin;
477       if (this->set_) PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasSetPointerMode(this->handle_, this->old_));
478       PetscFunctionReturnVoid();
479     }
480 
481   private:
482     cupmBlasHandle_t      handle_;
483     cupmBlasPointerMode_t old_;
484     bool                  set_;
485   };
486 };
487 
// PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING() - Import the complete CUPM BLAS interface
// (everything from BlasInterfaceImpl<T> plus the helpers of BlasInterface<T>) into a class
// deriving from BlasInterface<T>. The expansion has no trailing semicolon; supply one at the
// point of use.
#define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(T) \
  PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T); \
  using CUPMBlasPointerModeGuard = typename ::Petsc::device::cupm::impl::BlasInterface<T>::CUPMBlasPointerModeGuard; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::cupmBlasName; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasSetPointerModeFromPointer; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::checkCupmBlasIntCast; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasIntCast
495 
496 #if PetscDefined(HAVE_CUDA)
497 extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterface<DeviceType::CUDA>;
498 #endif
499 
500 #if PetscDefined(HAVE_HIP)
501 extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterface<DeviceType::HIP>;
502 #endif
503 
504 } // namespace impl
505 
506 } // namespace cupm
507 
508 } // namespace device
509 
510 } // namespace Petsc
511