1 #pragma once 2 3 #include <petsc/private/cupmblasinterface.hpp> 4 #include <petsc/private/petscadvancedmacros.h> 5 6 namespace Petsc 7 { 8 9 namespace device 10 { 11 12 namespace cupm 13 { 14 15 namespace impl 16 { 17 18 #define PetscCallCUPMSOLVER_(__abort_fn__, __comm__, ...) \ 19 do { \ 20 PetscStackUpdateLine; \ 21 const cupmSolverError_t cupmsolver_stat_p_ = __VA_ARGS__; \ 22 if (PetscUnlikely(cupmsolver_stat_p_ != CUPMSOLVER_STATUS_SUCCESS)) { \ 23 if (((cupmsolver_stat_p_ == CUPMSOLVER_STATUS_NOT_INITIALIZED) || (cupmsolver_stat_p_ == CUPMSOLVER_STATUS_ALLOC_FAILED) || (cupmsolver_stat_p_ == CUPMSOLVER_STATUS_INTERNAL_ERROR)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \ 24 __abort_fn__(__comm__, PETSC_ERR_GPU_RESOURCE, \ 25 "%s error %d (%s). " \ 26 "This indicates the GPU may have run out resources", \ 27 cupmSolverName(), static_cast<PetscErrorCode>(cupmsolver_stat_p_), cupmSolverGetErrorName(cupmsolver_stat_p_)); \ 28 } \ 29 __abort_fn__(__comm__, PETSC_ERR_GPU, "%s error %d (%s)", cupmSolverName(), static_cast<PetscErrorCode>(cupmsolver_stat_p_), cupmSolverGetErrorName(cupmsolver_stat_p_)); \ 30 } \ 31 } while (0) 32 33 #define PetscCallCUPMSOLVER(...) PetscCallCUPMSOLVER_(SETERRQ, PETSC_COMM_SELF, __VA_ARGS__) 34 #define PetscCallCUPMSOLVERAbort(comm_, ...) PetscCallCUPMSOLVER_(SETERRABORT, comm_, __VA_ARGS__) 35 36 #if !defined(PetscConcat3) 37 #define PetscConcat3(a, b, c) PetscConcat(PetscConcat(a, b), c) 38 #endif 39 40 #if PetscDefined(USE_COMPLEX) 41 #define PETSC_CUPMSOLVER_FP_TYPE_SPECIAL un 42 #else 43 #define PETSC_CUPMSOLVER_FP_TYPE_SPECIAL or 44 #endif // USE_COMPLEX 45 46 #define PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupm_name, their_prefix, fp_type, suffix) PETSC_CUPM_ALIAS_FUNCTION(cupm_name, PetscConcat3(their_prefix, fp_type, suffix)) 47 48 template <DeviceType> 49 struct SolverInterfaceImpl; 50 51 #if PetscDefined(HAVE_CUDA) 52 template <> 53 struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL SolverInterfaceImpl<DeviceType::CUDA> : BlasInterface<DeviceType::CUDA> { 54 // typedefs 55 using cupmSolverHandle_t = cusolverDnHandle_t; 56 using cupmSolverError_t = cusolverStatus_t; 57 using cupmSolverFillMode_t = cublasFillMode_t; 58 using cupmSolverOperation_t = cublasOperation_t; 59 60 // error codes 61 static const auto CUPMSOLVER_STATUS_SUCCESS = CUSOLVER_STATUS_SUCCESS; 62 static const auto CUPMSOLVER_STATUS_NOT_INITIALIZED = CUSOLVER_STATUS_NOT_INITIALIZED; 63 static const auto CUPMSOLVER_STATUS_ALLOC_FAILED = CUSOLVER_STATUS_ALLOC_FAILED; 64 static const auto CUPMSOLVER_STATUS_INTERNAL_ERROR = CUSOLVER_STATUS_INTERNAL_ERROR; 65 66 // enums 67 // Why do these exist just to alias the CUBLAS versions? Because AMD -- in their boundless 68 // wisdom -- decided to do so for hipSOLVER... 69 // https://github.com/ROCmSoftwarePlatform/hipSOLVER/blob/develop/library/include/internal/hipsolver-types.h 70 static const auto CUPMSOLVER_OP_T = CUBLAS_OP_T; 71 static const auto CUPMSOLVER_OP_N = CUBLAS_OP_N; 72 static const auto CUPMSOLVER_OP_C = CUBLAS_OP_C; 73 static const auto CUPMSOLVER_FILL_MODE_LOWER = CUBLAS_FILL_MODE_LOWER; 74 static const auto CUPMSOLVER_FILL_MODE_UPPER = CUBLAS_FILL_MODE_UPPER; 75 static const auto CUPMSOLVER_SIDE_LEFT = CUBLAS_SIDE_LEFT; 76 static const auto CUPMSOLVER_SIDE_RIGHT = CUBLAS_SIDE_RIGHT; 77 78 // utility functions 79 PETSC_CUPM_ALIAS_FUNCTION(cupmSolverCreate, cusolverDnCreate) 80 PETSC_CUPM_ALIAS_FUNCTION(cupmSolverDestroy, cusolverDnDestroy) 81 PETSC_CUPM_ALIAS_FUNCTION(cupmSolverSetStream, cusolverDnSetStream) 82 PETSC_CUPM_ALIAS_FUNCTION(cupmSolverGetStream, cusolverDnGetStream) 83 84 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotrf_bufferSize, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, potrf_bufferSize) 85 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotrf, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, potrf) 86 87 using cupmBlasInt_t = typename BlasInterface<DeviceType::CUDA>::cupmBlasInt_t; 88 using cupmScalar_t = typename Interface<DeviceType::CUDA>::cupmScalar_t; 89 90 // to match hipSOLVER version (rocm 5.4.3, CUDA 12.0.1): 91 // 92 // hipsolverStatus_t hipsolverDpotrs_bufferSize( 93 // hipsolverHandle_t handle, hipsolverFillMode_t uplo, int n, int nrhs, double *A, int lda, 94 // double *B, int ldb, int *lwork 95 // ) 96 // 97 // hipsolverStatus_t hipsolverDpotrs( 98 // hipsolverHandle_t handle, hipsolverFillMode_t uplo, int n, int nrhs, double *A, int lda, 99 // double *B, int ldb, double *work, int lwork, int *devInfo 100 // ) cupmSolverXpotrs_bufferSizePetsc::device::cupm::impl::SolverInterfaceImpl101 PETSC_NODISCARD static cupmSolverError_t cupmSolverXpotrs_bufferSize(cupmSolverHandle_t /* handle */, cupmSolverFillMode_t /* uplo */, cupmBlasInt_t /* n */, cupmBlasInt_t /* nrhs */, cupmScalar_t * /* A */, cupmBlasInt_t /* lda */, cupmScalar_t * /* B */, cupmBlasInt_t /* ldb */, cupmBlasInt_t *lwork) noexcept 102 { 103 *lwork = 0; 104 return CUPMSOLVER_STATUS_SUCCESS; 105 } 106 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTIONPetsc::device::cupm::impl::SolverInterfaceImpl107 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotrs_p, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, potrs) 108 109 PETSC_NODISCARD static cupmSolverError_t cupmSolverXpotrs(cupmSolverHandle_t handle, cupmSolverFillMode_t uplo, cupmBlasInt_t n, cupmBlasInt_t nrhs, const cupmScalar_t *A, cupmBlasInt_t lda, cupmScalar_t *B, cupmBlasInt_t ldb, cupmScalar_t * /* work */, cupmBlasInt_t /* lwork */, cupmBlasInt_t *dev_info) noexcept 110 { 111 return cupmSolverXpotrs_p(handle, uplo, n, nrhs, A, lda, B, ldb, dev_info); 112 } 113 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTIONPetsc::device::cupm::impl::SolverInterfaceImpl114 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotri_bufferSize, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, potri_bufferSize) 115 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotri, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, potri) 116 117 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXsytrf_bufferSize, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, sytrf_bufferSize) 118 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXsytrf, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, sytrf) 119 120 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgetrf_bufferSize, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, getrf_bufferSize) 121 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgetrf_p, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, getrf) 122 // to match hipSOLVER version (rocm 5.4.3, CUDA 12.0.1): 123 // 124 // hipsolverStatus_t hipsolverDgetrf( 125 // hipsolverHandle_t handle, int m, int n, double *A, int lda, double *work, int lwork, 126 // int *devIpiv, int *devInfo 127 // ) 128 PETSC_NODISCARD static cupmSolverError_t cupmSolverXgetrf(cupmSolverHandle_t handle, cupmBlasInt_t m, cupmBlasInt_t n, cupmScalar_t *A, cupmBlasInt_t lda, cupmScalar_t *work, cupmBlasInt_t /* lwork */, cupmBlasInt_t *dev_ipiv, cupmBlasInt_t *dev_info) noexcept 129 { 130 return cupmSolverXgetrf_p(handle, m, n, A, lda, work, dev_ipiv, dev_info); 131 } 132 133 // to match hipSOLVER version (rocm 5.4.3, CUDA 12.0.1): 134 // 135 // hipsolverStatus_t hipsolverDgetrs_bufferSize( 136 // hipsolverHandle_t handle, hipsolverOperation_t trans, int n, int nrhs, double *A, 137 // int lda, int *devIpiv, double *B, int ldb, int *lwork 138 // ) 139 // 140 // hipsolverStatus_t hipsolverDgetrs( 141 // hipsolverHandle_t handle, hipsolverOperation_t trans, int n, int nrhs, double *A, 142 // int lda, int *devIpiv, double *B, int ldb, double *work, int lwork, int *devInfo 143 // ) cupmSolverXgetrs_bufferSizePetsc::device::cupm::impl::SolverInterfaceImpl144 PETSC_NODISCARD static cupmSolverError_t cupmSolverXgetrs_bufferSize(cupmSolverHandle_t /* handle */, cupmSolverOperation_t /* op */, cupmBlasInt_t /* n */, cupmBlasInt_t /* nrhs */, cupmScalar_t * /* A */, cupmBlasInt_t /* lda */, cupmBlasInt_t * /* devIpiv */, cupmScalar_t * /* B */, cupmBlasInt_t /* ldb */, cupmBlasInt_t *lwork) noexcept 145 { 146 *lwork = 0; 147 return CUPMSOLVER_STATUS_SUCCESS; 148 } 149 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTIONPetsc::device::cupm::impl::SolverInterfaceImpl150 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgetrs_p, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, getrs) 151 152 PETSC_NODISCARD static cupmSolverError_t cupmSolverXgetrs(cupmSolverHandle_t handle, cupmSolverOperation_t op, cupmBlasInt_t n, cupmBlasInt_t nrhs, cupmScalar_t *A, cupmBlasInt_t lda, cupmBlasInt_t *dev_ipiv, cupmScalar_t *B, cupmBlasInt_t ldb, cupmScalar_t * /* work */, cupmBlasInt_t /* lwork */, cupmBlasInt_t *dev_info) noexcept 153 { 154 return cupmSolverXgetrs_p(handle, op, n, nrhs, A, lda, dev_ipiv, B, ldb, dev_info); 155 } 156 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTIONPetsc::device::cupm::impl::SolverInterfaceImpl157 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgeqrf_bufferSize, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, geqrf_bufferSize) 158 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgeqrf, cusolverDn, PETSC_CUPMBLAS_FP_TYPE_U, geqrf) 159 160 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXormqr_bufferSize, cusolverDn, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_U, PETSC_CUPMSOLVER_FP_TYPE_SPECIAL), mqr_bufferSize) 161 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXormqr, cusolverDn, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_U, PETSC_CUPMSOLVER_FP_TYPE_SPECIAL), mqr) 162 163 PETSC_NODISCARD static const char *cupmSolverGetErrorName(cupmSolverError_t status) noexcept { return PetscCUSolverGetErrorName(status); } 164 }; 165 #endif 166 167 #if PetscDefined(HAVE_HIP) 168 template <> 169 struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL SolverInterfaceImpl<DeviceType::HIP> : BlasInterface<DeviceType::HIP> { 170 // typedefs 171 using cupmSolverHandle_t = hipsolverHandle_t; 172 using cupmSolverError_t = hipsolverStatus_t; 173 using cupmSolverFillMode_t = hipsolverFillMode_t; 174 using cupmSolverOperation_t = hipsolverOperation_t; 175 176 // error codes 177 static const auto CUPMSOLVER_STATUS_SUCCESS = HIPSOLVER_STATUS_SUCCESS; 178 static const auto CUPMSOLVER_STATUS_NOT_INITIALIZED = HIPSOLVER_STATUS_NOT_INITIALIZED; 179 static const auto CUPMSOLVER_STATUS_ALLOC_FAILED = HIPSOLVER_STATUS_ALLOC_FAILED; 180 static const auto CUPMSOLVER_STATUS_INTERNAL_ERROR = HIPSOLVER_STATUS_INTERNAL_ERROR; 181 182 // enums 183 static const auto CUPMSOLVER_OP_T = HIPSOLVER_OP_T; 184 static const auto CUPMSOLVER_OP_N = HIPSOLVER_OP_N; 185 static const auto CUPMSOLVER_OP_C = HIPSOLVER_OP_C; 186 static const auto CUPMSOLVER_FILL_MODE_LOWER = HIPSOLVER_FILL_MODE_LOWER; 187 static const auto CUPMSOLVER_FILL_MODE_UPPER = HIPSOLVER_FILL_MODE_UPPER; 188 static const auto CUPMSOLVER_SIDE_LEFT = HIPSOLVER_SIDE_LEFT; 189 static const auto CUPMSOLVER_SIDE_RIGHT = HIPSOLVER_SIDE_RIGHT; 190 PETSC_CUPM_ALIAS_FUNCTIONPetsc::device::cupm::impl::SolverInterfaceImpl191 PETSC_CUPM_ALIAS_FUNCTION(cupmSolverCreate, hipsolverCreate) 192 PETSC_CUPM_ALIAS_FUNCTION(cupmSolverDestroy, hipsolverDestroy) 193 PETSC_CUPM_ALIAS_FUNCTION(cupmSolverSetStream, hipsolverSetStream) 194 PETSC_CUPM_ALIAS_FUNCTION(cupmSolverGetStream, hipsolverGetStream) 195 196 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotrf_bufferSize, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, potrf_bufferSize) 197 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotrf, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, potrf) 198 199 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotrs_bufferSize, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, potrs_bufferSize) 200 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotrs, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, potrs) 201 202 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotri_bufferSize, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, potri_bufferSize) 203 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXpotri, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, potri) 204 205 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXsytrf_bufferSize, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, sytrf_bufferSize) 206 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXsytrf, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, sytrf) 207 208 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgetrf_bufferSize, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, getrf_bufferSize) 209 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgetrf, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, getrf) 210 211 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgetrs_bufferSize, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, getrs_bufferSize) 212 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgetrs, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, getrs) 213 214 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgeqrf_bufferSize, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, geqrf_bufferSize) 215 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXgeqrf, hipsolver, PETSC_CUPMBLAS_FP_TYPE_U, geqrf) 216 217 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXormqr_bufferSize, hipsolver, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_U, PETSC_CUPMSOLVER_FP_TYPE_SPECIAL), mqr_bufferSize) 218 PETSC_CUPMSOLVER_ALIAS_BLAS_FUNCTION(cupmSolverXormqr, hipsolver, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_U, PETSC_CUPMSOLVER_FP_TYPE_SPECIAL), mqr) 219 220 PETSC_NODISCARD static const char *cupmSolverGetErrorName(cupmSolverError_t status) noexcept { return PetscHIPSolverGetErrorName(status); } 221 }; 222 #endif 223 224 #define PETSC_CUPMSOLVER_IMPL_CLASS_HEADER(T) \ 225 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(T); \ 226 /* introspection */ \ 227 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverGetErrorName; \ 228 /* types */ \ 229 using cupmSolverHandle_t = typename ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverHandle_t; \ 230 using cupmSolverError_t = typename ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverError_t; \ 231 using cupmSolverFillMode_t = typename ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverFillMode_t; \ 232 using cupmSolverOperation_t = typename ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverOperation_t; \ 233 /* error codes */ \ 234 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_STATUS_SUCCESS; \ 235 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_STATUS_NOT_INITIALIZED; \ 236 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_STATUS_ALLOC_FAILED; \ 237 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_STATUS_INTERNAL_ERROR; \ 238 /* values */ \ 239 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_OP_T; \ 240 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_OP_N; \ 241 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_OP_C; \ 242 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_FILL_MODE_LOWER; \ 243 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_FILL_MODE_UPPER; \ 244 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_SIDE_LEFT; \ 245 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::CUPMSOLVER_SIDE_RIGHT; \ 246 /* utility functions */ \ 247 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverCreate; \ 248 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverDestroy; \ 249 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverGetStream; \ 250 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverSetStream; \ 251 /* blas functions */ \ 252 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXpotrf_bufferSize; \ 253 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXpotrf; \ 254 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXpotrs_bufferSize; \ 255 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXpotrs; \ 256 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXpotri_bufferSize; \ 257 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXpotri; \ 258 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXsytrf_bufferSize; \ 259 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXsytrf; \ 260 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXgetrf_bufferSize; \ 261 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXgetrf; \ 262 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXgetrs_bufferSize; \ 263 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXgetrs; \ 264 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXgeqrf_bufferSize; \ 265 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXgeqrf; \ 266 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXormqr_bufferSize; \ 267 using ::Petsc::device::cupm::impl::SolverInterfaceImpl<T>::cupmSolverXormqr 268 269 template <DeviceType T> 270 struct SolverInterface : SolverInterfaceImpl<T> { cupmSolverNamePetsc::device::cupm::impl::SolverInterface271 PETSC_NODISCARD static constexpr const char *cupmSolverName() noexcept { return T == DeviceType::CUDA ? "cusolverDn" : "hipsolver"; } 272 }; 273 274 #define PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T) \ 275 PETSC_CUPMSOLVER_IMPL_CLASS_HEADER(T); \ 276 using ::Petsc::device::cupm::impl::SolverInterface<T>::cupmSolverName 277 278 #if PetscDefined(HAVE_CUDA) 279 extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL SolverInterface<DeviceType::CUDA>; 280 #endif 281 282 #if PetscDefined(HAVE_HIP) 283 extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL SolverInterface<DeviceType::HIP>; 284 #endif 285 286 } // namespace impl 287 288 } // namespace cupm 289 290 } // namespace device 291 292 } // namespace Petsc 293