1 #pragma once 2 3 #include <petsc/private/cupminterface.hpp> 4 #include <petsc/private/petscadvancedmacros.h> 5 6 #include <limits> // std::numeric_limits 7 8 namespace Petsc 9 { 10 11 namespace device 12 { 13 14 namespace cupm 15 { 16 17 namespace impl 18 { 19 20 #define PetscCallCUPMBLAS_(__abort_fn__, __comm__, ...) \ 21 do { \ 22 PetscStackUpdateLine; \ 23 const cupmBlasError_t cberr_p_ = __VA_ARGS__; \ 24 if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \ 25 if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \ 26 __abort_fn__(__comm__, PETSC_ERR_GPU_RESOURCE, \ 27 "%s error %d (%s). Reports not initialized or alloc failed; " \ 28 "this indicates the GPU may have run out resources", \ 29 cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \ 30 } \ 31 __abort_fn__(__comm__, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \ 32 } \ 33 } while (0) 34 35 #define PetscCallCUPMBLAS(...) PetscCallCUPMBLAS_(SETERRQ, PETSC_COMM_SELF, __VA_ARGS__) 36 #define PetscCallCUPMBLASAbort(comm_, ...) PetscCallCUPMBLAS_(SETERRABORT, comm_, __VA_ARGS__) 37 38 // given cupmBlas<T>axpy() then 39 // T = PETSC_CUPBLAS_FP_TYPE 40 // given cupmBlas<T><u>nrm2() then 41 // T = PETSC_CUPMBLAS_FP_INPUT_TYPE 42 // u = PETSC_CUPMBLAS_FP_RETURN_TYPE 43 #if PetscDefined(USE_COMPLEX) 44 #if PetscDefined(USE_REAL_SINGLE) 45 #define PETSC_CUPMBLAS_FP_TYPE_U C 46 #define PETSC_CUPMBLAS_FP_TYPE_L c 47 #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S 48 #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s 49 #elif PetscDefined(USE_REAL_DOUBLE) 50 #define PETSC_CUPMBLAS_FP_TYPE_U Z 51 #define PETSC_CUPMBLAS_FP_TYPE_L z 52 #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D 53 #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d 54 #endif 55 #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U 56 #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L 57 #else 58 #if PetscDefined(USE_REAL_SINGLE) 59 #define PETSC_CUPMBLAS_FP_TYPE_U S 60 #define PETSC_CUPMBLAS_FP_TYPE_L s 61 #elif PetscDefined(USE_REAL_DOUBLE) 62 #define PETSC_CUPMBLAS_FP_TYPE_U D 63 #define PETSC_CUPMBLAS_FP_TYPE_L d 64 #endif 65 #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U 66 #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L 67 #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U 68 #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L 69 #endif // USE_COMPLEX 70 71 #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128) 72 #error "Unsupported floating-point type for CUDA/HIP BLAS" 73 #endif 74 75 // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Helper macro to build a "modified" 76 // blas function whose return type does not match the input type 77 // 78 // input param: 79 // func - base suffix of the blas function, e.g. nrm2 80 // 81 // notes: 82 // requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the blas floating point input type 83 // letter ("S" for real/complex single, "D" for real/complex double). 84 // 85 // requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the blas floating point output type 86 // letter ("c" for complex single, "z" for complex double and <absolutely nothing> for real 87 // single/double). 88 // 89 // In their infinite wisdom nvidia/amd have made the upper-case vs lower-case scheme 90 // infuriatingly inconsistent... 91 // 92 // example usage: 93 // #define PETSC_CUPMBLAS_FP_INPUT_TYPE S 94 // #define PETSC_CUPMBLAS_FP_RETURN_TYPE 95 // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2 96 // 97 // #define PETSC_CUPMBLAS_FP_INPUT_TYPE D 98 // #define PETSC_CUPMBLAS_FP_RETURN_TYPE z 99 // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2 100 #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func) 101 102 // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build Iamax and Iamin 103 // because they are both extra special 104 // 105 // input param: 106 // func - base suffix of the blas function, either amax or amin 107 // 108 // notes: 109 // The macro name literally stands for "I" ## "floating point type" because shockingly enough, 110 // that's what it does. 111 // 112 // requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point input type 113 // letter ("s" for complex single, "z" for complex double, "s" for real single, and "d" for 114 // real double). 115 // 116 // example usage: 117 // #define PETSC_CUPMBLAS_FP_TYPE_L s 118 // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax 119 // 120 // #define PETSC_CUPMBLAS_FP_TYPE_L z 121 // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin 122 #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func)) 123 124 // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Helper macro to build a "standard" 125 // blas function name 126 // 127 // input param: 128 // func - base suffix of the blas function, e.g. axpy, scal 129 // 130 // notes: 131 // requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for 132 // complex single, "Z" for complex double, "S" for real single, "D" for real double). 133 // 134 // example usage: 135 // #define PETSC_CUPMBLAS_FP_TYPE S 136 // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy 137 // 138 // #define PETSC_CUPMBLAS_FP_TYPE Z 139 // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy 140 #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func) 141 142 // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - In case CUDA/HIP don't agree with our suffix 143 // one can provide both here 144 // 145 // input params: 146 // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or 147 // IFPTYPE 148 // our_suffix - the suffix of the alias function 149 // their_suffix - the suffix of the function being aliased 150 // 151 // notes: 152 // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function 153 // prefix. requires any other specific definitions required by the specific builder macro to 154 // also be defined. See PETSC_CUPM_ALIAS_FUNCTION_EXACT() for the exact expansion of the 155 // function alias. 156 // 157 // example usage: 158 // #define PETSC_CUPMBLAS_PREFIX cublas 159 // #define PETSC_CUPMBLAS_FP_TYPE C 160 // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) -> 161 // template <typename... T> 162 // static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection* 163 // { 164 // return cublasCdotc(std::forward<T>(args)...); 165 // } 166 #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \ 167 PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlasX, our_suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix))) 168 169 // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function 170 // 171 // input params: 172 // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or 173 // IFPTYPE 174 // suffix - the common suffix between CUDA and HIP of the alias function 175 // 176 // notes: 177 // see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(), this macro just calls that one with "suffix" as 178 // "our_prefix" and "their_prefix" 179 #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix) 180 181 // PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP library function 182 // 183 // input params: 184 // suffix - the common suffix between CUDA and HIP of the alias function 185 // 186 // notes: 187 // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library 188 // prefix. see PETSC_CUPMM_ALIAS_FUNCTION_EXACT() for the precise expansion of this macro. 189 // 190 // example usage: 191 // #define PETSC_CUPMBLAS_PREFIX hipblas 192 // PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) -> 193 // template <typename... T> 194 // static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection* 195 // { 196 // return hipblasCreate(std::forward<T>(args)...); 197 // } 198 #define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix)) 199 200 template <DeviceType> 201 struct BlasInterfaceImpl; 202 203 // Exists because HIP (for whatever godforsaken reason) has elected to define both their 204 // hipBlasHandle_t and hipSolverHandle_t as void *. So we cannot disambiguate them for overload 205 // resolution and hence need to wrap their types int this mess. 206 template <typename T, std::size_t I> 207 class cupmBlasHandleWrapper { 208 public: 209 constexpr cupmBlasHandleWrapper() noexcept = default; cupmBlasHandleWrapper(T h)210 constexpr cupmBlasHandleWrapper(T h) noexcept : handle_{std::move(h)} { static_assert(std::is_standard_layout<cupmBlasHandleWrapper<T, I>>::value, ""); } 211 operator =(std::nullptr_t)212 cupmBlasHandleWrapper &operator=(std::nullptr_t) noexcept 213 { 214 handle_ = nullptr; 215 return *this; 216 } 217 operator T() const218 operator T() const { return handle_; } 219 ptr_to() const220 const T *ptr_to() const { return &handle_; } ptr_to()221 T *ptr_to() { return &handle_; } 222 223 private: 224 T handle_{}; 225 }; 226 227 #if PetscDefined(HAVE_CUDA) 228 #define PETSC_CUPMBLAS_PREFIX cublas 229 #define PETSC_CUPMBLAS_PREFIX_U CUBLAS 230 #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U 231 #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U 232 #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L 233 template <> 234 struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterfaceImpl<DeviceType::CUDA> : Interface<DeviceType::CUDA> { 235 // typedefs 236 using cupmBlasHandle_t = cupmBlasHandleWrapper<cublasHandle_t, 0>; 237 using cupmBlasError_t = cublasStatus_t; 238 using cupmBlasInt_t = int; 239 using cupmBlasPointerMode_t = cublasPointerMode_t; 240 241 // values 242 static const auto CUPMBLAS_STATUS_SUCCESS = CUBLAS_STATUS_SUCCESS; 243 static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED; 244 static const auto CUPMBLAS_STATUS_ALLOC_FAILED = CUBLAS_STATUS_ALLOC_FAILED; 245 static const auto CUPMBLAS_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST; 246 static const auto CUPMBLAS_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE; 247 static const auto CUPMBLAS_OP_T = CUBLAS_OP_T; 248 static const auto CUPMBLAS_OP_N = CUBLAS_OP_N; 249 static const auto CUPMBLAS_OP_C = CUBLAS_OP_C; 250 static const auto CUPMBLAS_FILL_MODE_LOWER = CUBLAS_FILL_MODE_LOWER; 251 static const auto CUPMBLAS_FILL_MODE_UPPER = CUBLAS_FILL_MODE_UPPER; 252 static const auto CUPMBLAS_SIDE_LEFT = CUBLAS_SIDE_LEFT; 253 static const auto CUPMBLAS_DIAG_NON_UNIT = CUBLAS_DIAG_NON_UNIT; 254 255 // utility functions 256 PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) PETSC_CUPMBLAS_ALIAS_FUNCTIONPetsc::device::cupm::impl::BlasInterfaceImpl257 PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy) 258 PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream) 259 PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream) 260 PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode) 261 PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode) 262 263 // level 1 BLAS 264 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy) 265 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, copy) 266 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal) 267 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot)) 268 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot)) 269 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap) 270 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2) 271 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax) 272 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum) 273 274 // level 2 BLAS 275 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv) 276 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trmv) 277 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsv) 278 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gbmv) 279 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbmv) 280 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbsv) 281 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, hemv, PetscIfPetscDefined(USE_COMPLEX, hemv, symv)) 282 283 // level 3 BLAS 284 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm) 285 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm) 286 287 // BLAS extensions 288 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam) 289 290 PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscCUBLASGetErrorName(status); } 291 }; 292 #undef PETSC_CUPMBLAS_PREFIX 293 #undef PETSC_CUPMBLAS_PREFIX_U 294 #undef PETSC_CUPMBLAS_FP_TYPE 295 #undef PETSC_CUPMBLAS_FP_INPUT_TYPE 296 #undef PETSC_CUPMBLAS_FP_RETURN_TYPE 297 #endif // PetscDefined(HAVE_CUDA) 298 299 #if PetscDefined(HAVE_HIP) 300 #define PETSC_CUPMBLAS_PREFIX hipblas 301 #define PETSC_CUPMBLAS_PREFIX_U HIPBLAS 302 #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U 303 #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U 304 #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L 305 template <> 306 struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterfaceImpl<DeviceType::HIP> : Interface<DeviceType::HIP> { 307 // typedefs 308 using cupmBlasHandle_t = cupmBlasHandleWrapper<hipblasHandle_t, 0>; 309 using cupmBlasError_t = hipblasStatus_t; 310 using cupmBlasInt_t = int; // rocblas will have its own 311 using cupmBlasPointerMode_t = hipblasPointerMode_t; 312 313 // values 314 static const auto CUPMBLAS_STATUS_SUCCESS = HIPBLAS_STATUS_SUCCESS; 315 static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED; 316 static const auto CUPMBLAS_STATUS_ALLOC_FAILED = HIPBLAS_STATUS_ALLOC_FAILED; 317 static const auto CUPMBLAS_POINTER_MODE_HOST = HIPBLAS_POINTER_MODE_HOST; 318 static const auto CUPMBLAS_POINTER_MODE_DEVICE = HIPBLAS_POINTER_MODE_DEVICE; 319 static const auto CUPMBLAS_OP_T = HIPBLAS_OP_T; 320 static const auto CUPMBLAS_OP_N = HIPBLAS_OP_N; 321 static const auto CUPMBLAS_OP_C = HIPBLAS_OP_C; 322 static const auto CUPMBLAS_FILL_MODE_LOWER = HIPBLAS_FILL_MODE_LOWER; 323 static const auto CUPMBLAS_FILL_MODE_UPPER = HIPBLAS_FILL_MODE_UPPER; 324 static const auto CUPMBLAS_SIDE_LEFT = HIPBLAS_SIDE_LEFT; 325 static const auto CUPMBLAS_DIAG_NON_UNIT = HIPBLAS_DIAG_NON_UNIT; 326 327 // utility functions 328 PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) PETSC_CUPMBLAS_ALIAS_FUNCTIONPetsc::device::cupm::impl::BlasInterfaceImpl329 PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy) 330 PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream) 331 PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream) 332 PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode) 333 PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode) 334 335 // level 1 BLAS 336 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy) 337 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, copy) 338 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal) 339 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot)) 340 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot)) 341 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap) 342 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2) 343 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax) 344 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum) 345 346 // level 2 BLAS 347 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv) 348 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trmv) 349 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsv) 350 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gbmv) 351 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbmv) 352 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbsv) 353 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, hemv, PetscIfPetscDefined(USE_COMPLEX, hemv, symv)) 354 355 // level 3 BLAS 356 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm) 357 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm) 358 359 // BLAS extensions 360 PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam) 361 362 PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscHIPBLASGetErrorName(status); } 363 }; 364 #undef PETSC_CUPMBLAS_PREFIX 365 #undef PETSC_CUPMBLAS_PREFIX_U 366 #undef PETSC_CUPMBLAS_FP_TYPE 367 #undef PETSC_CUPMBLAS_FP_INPUT_TYPE 368 #undef PETSC_CUPMBLAS_FP_RETURN_TYPE 369 #endif // PetscDefined(HAVE_HIP) 370 371 #define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T) \ 372 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); \ 373 /* introspection */ \ 374 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetErrorName; \ 375 /* types */ \ 376 using cupmBlasHandle_t = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasHandle_t; \ 377 using cupmBlasError_t = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasError_t; \ 378 using cupmBlasInt_t = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasInt_t; \ 379 using cupmBlasPointerMode_t = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasPointerMode_t; \ 380 /* values */ \ 381 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_SUCCESS; \ 382 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_NOT_INITIALIZED; \ 383 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_ALLOC_FAILED; \ 384 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_HOST; \ 385 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_DEVICE; \ 386 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_T; \ 387 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_N; \ 388 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_C; \ 389 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_LOWER; \ 390 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_UPPER; \ 391 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_SIDE_LEFT; \ 392 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_DIAG_NON_UNIT; \ 393 /* utility functions */ \ 394 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasCreate; \ 395 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasDestroy; \ 396 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetStream; \ 397 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetStream; \ 398 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetPointerMode; \ 399 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetPointerMode; \ 400 /* level 1 BLAS */ \ 401 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXaxpy; \ 402 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXcopy; \ 403 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXscal; \ 404 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdot; \ 405 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdotu; \ 406 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXswap; \ 407 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXnrm2; \ 408 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXamax; \ 409 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXasum; \ 410 /* level 2 BLAS */ \ 411 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemv; \ 412 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrmv; \ 413 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrsv; \ 414 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgbmv; \ 415 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtbmv; \ 416 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtbsv; \ 417 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXhemv; \ 418 /* level 3 BLAS */ \ 419 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemm; \ 420 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrsm; \ 421 /* BLAS extensions */ \ 422 using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgeam 423 424 // The actual interface class 425 template <DeviceType T> 426 struct BlasInterface : BlasInterfaceImpl<T> { 427 PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T); 428 cupmBlasNamePetsc::device::cupm::impl::BlasInterface429 PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; } 430 PetscCUPMBlasSetPointerModeFromPointerPetsc::device::cupm::impl::BlasInterface431 static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept 432 { 433 auto mtype = PETSC_MEMTYPE_HOST; 434 435 PetscFunctionBegin; 436 PetscCall(PetscCUPMGetMemType(ptr, &mtype)); 437 PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST)); 438 PetscFunctionReturn(PETSC_SUCCESS); 439 } 440 checkCupmBlasIntCastPetsc::device::cupm::impl::BlasInterface441 static PetscErrorCode checkCupmBlasIntCast(PetscInt x) noexcept 442 { 443 PetscFunctionBegin; 444 PetscCheck((std::is_same<PetscInt, cupmBlasInt_t>::value) || (x <= std::numeric_limits<cupmBlasInt_t>::max()), PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is too big for %s, which may be restricted to 32-bit integers", x, cupmBlasName()); 445 PetscCheck(x >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Passing negative integer (%" PetscInt_FMT ") to %s routine", x, cupmBlasName()); 446 PetscFunctionReturn(PETSC_SUCCESS); 447 } 448 PetscCUPMBlasIntCastPetsc::device::cupm::impl::BlasInterface449 static PetscErrorCode PetscCUPMBlasIntCast(PetscInt x, cupmBlasInt_t *y) noexcept 450 { 451 PetscFunctionBegin; 452 *y = static_cast<cupmBlasInt_t>(x); 453 PetscCall(checkCupmBlasIntCast(x)); 454 PetscFunctionReturn(PETSC_SUCCESS); 455 } 456 457 class CUPMBlasPointerModeGuard { 458 public: CUPMBlasPointerModeGuard(const cupmBlasHandle_t & handle,cupmBlasPointerMode_t mode)459 CUPMBlasPointerModeGuard(const cupmBlasHandle_t &handle, cupmBlasPointerMode_t mode) noexcept : handle_{handle} 460 { 461 PetscFunctionBegin; 462 PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasGetPointerMode(handle, &this->old_)); 463 if (this->old_ == mode) { 464 this->set_ = false; 465 } else { 466 this->set_ = true; 467 PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasSetPointerMode(handle, mode)); 468 } 469 PetscFunctionReturnVoid(); 470 } 471 CUPMBlasPointerModeGuard(const cupmBlasHandle_t & handle,PetscMemType mtype)472 CUPMBlasPointerModeGuard(const cupmBlasHandle_t &handle, PetscMemType mtype) noexcept : CUPMBlasPointerModeGuard{handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST} { } 473 ~CUPMBlasPointerModeGuard()474 ~CUPMBlasPointerModeGuard() noexcept 475 { 476 PetscFunctionBegin; 477 if (this->set_) PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasSetPointerMode(this->handle_, this->old_)); 478 PetscFunctionReturnVoid(); 479 } 480 481 private: 482 cupmBlasHandle_t handle_; 483 cupmBlasPointerMode_t old_; 484 bool set_; 485 }; 486 }; 487 488 #define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(T) \ 489 PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T); \ 490 using ::Petsc::device::cupm::impl::BlasInterface<T>::cupmBlasName; \ 491 using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasSetPointerModeFromPointer; \ 492 using ::Petsc::device::cupm::impl::BlasInterface<T>::checkCupmBlasIntCast; \ 493 using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasIntCast; \ 494 using CUPMBlasPointerModeGuard = typename ::Petsc::device::cupm::impl::BlasInterface<T>::CUPMBlasPointerModeGuard 495 496 #if PetscDefined(HAVE_CUDA) 497 extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterface<DeviceType::CUDA>; 498 #endif 499 500 #if PetscDefined(HAVE_HIP) 501 extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterface<DeviceType::HIP>; 502 #endif 503 504 } // namespace impl 505 506 } // namespace cupm 507 508 } // namespace device 509 510 } // namespace Petsc 511