/*
   Portions of this code are under:
   Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
*/

/*
   Private implementation header for the hipsparse-backed sequential AIJ matrix (Mat_SeqAIJHIPSPARSE).

   Contents:
   - scalar-type dispatch macros (hipsparseX*, hipsparse_*) that map a single name onto the
     S (single real), D (double real), C (single complex) or Z (double complex) hipsparse routine
     selected by PETSC_USE_COMPLEX / PETSC_USE_REAL_SINGLE / PETSC_USE_REAL_DOUBLE;
   - device-side data structures for MatMult (Mat_SeqAIJHIPSPARSEMultStruct, Mat_SeqAIJHIPSPARSE)
     and for triangular solves (Mat_SeqAIJHIPSPARSETriFactorStruct, Mat_SeqAIJHIPSPARSETriFactors);
   - a few PETSC_INTERN prototypes, the VecSeq_HIP alias and the isHipMem() helper.
*/
#ifndef PETSC_HIPSPARSEMATIMPL_H
#define PETSC_HIPSPARSEMATIMPL_H

#include <petscpkg_version.h>
#include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp> /* for VecSeq_CUPM */
#include <petsc/private/petsclegacycupmblas.h>

/* The hipsparse header moved to hipsparse/hipsparse.h in HIP 5.2.0 */
#if PETSC_PKG_HIP_VERSION_GE(5, 2, 0)
  #include <hipsparse/hipsparse.h>
#else /* PETSC_PKG_HIP_VERSION_GE(5,2,0) */
  #include <hipsparse.h>
#endif /* PETSC_PKG_HIP_VERSION_GE(5,2,0) */
#include "hip/hip_runtime.h"

#include <algorithm>
#include <vector>

#include <thrust/device_vector.h>
#include <thrust/device_ptr.h>
#include <thrust/device_malloc_allocator.h>
#include <thrust/transform.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/system/system_error.h>

/*
   ILU(0) / IC(0) factorization wrappers (csrilu02 / csric02) and the scalar constants
   PETSC_HIPSPARSE_ONE / PETSC_HIPSPARSE_ZERO, expressed in the scalar type hipsparse expects.

   Callers use the hipsparseX* names. In the complex case the value-array argument (e) is cast to
   hipComplex* / hipDoubleComplex*. NOTE(review): these casts assume PetscScalar is layout-compatible
   with the HIP complex types -- TODO confirm. Any precision other than PETSC_USE_REAL_SINGLE /
   PETSC_USE_REAL_DOUBLE leaves these macros (and the constants, in the complex case) undefined.

   NOTE: namespace-scope const objects have internal linkage in C++, so every translation unit that
   includes this header gets its own copy of PETSC_HIPSPARSE_ONE / PETSC_HIPSPARSE_ZERO.
*/
#if defined(PETSC_USE_COMPLEX)
  #if defined(PETSC_USE_REAL_SINGLE)
const hipComplex PETSC_HIPSPARSE_ONE  = {1.0f, 0.0f};
const hipComplex PETSC_HIPSPARSE_ZERO = {0.0f, 0.0f};
    #define hipsparseXcsrilu02_bufferSize(a, b, c, d, e, f, g, h, i)  hipsparseCcsrilu02_bufferSize(a, b, c, d, (hipComplex *)e, f, g, h, i)
    #define hipsparseXcsrilu02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseCcsrilu02_analysis(a, b, c, d, (hipComplex *)e, f, g, h, i, j)
    #define hipsparseXcsrilu02(a, b, c, d, e, f, g, h, i, j)          hipsparseCcsrilu02(a, b, c, d, (hipComplex *)e, f, g, h, i, j)
    #define hipsparseXcsric02_bufferSize(a, b, c, d, e, f, g, h, i)   hipsparseCcsric02_bufferSize(a, b, c, d, (hipComplex *)e, f, g, h, i)
    #define hipsparseXcsric02_analysis(a, b, c, d, e, f, g, h, i, j)  hipsparseCcsric02_analysis(a, b, c, d, (hipComplex *)e, f, g, h, i, j)
    #define hipsparseXcsric02(a, b, c, d, e, f, g, h, i, j)           hipsparseCcsric02(a, b, c, d, (hipComplex *)e, f, g, h, i, j)
  #elif defined(PETSC_USE_REAL_DOUBLE)
const hipDoubleComplex PETSC_HIPSPARSE_ONE  = {1.0, 0.0};
const hipDoubleComplex PETSC_HIPSPARSE_ZERO = {0.0, 0.0};
    #define hipsparseXcsrilu02_bufferSize(a, b, c, d, e, f, g, h, i)  hipsparseZcsrilu02_bufferSize(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i)
    #define hipsparseXcsrilu02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseZcsrilu02_analysis(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j)
    #define hipsparseXcsrilu02(a, b, c, d, e, f, g, h, i, j)          hipsparseZcsrilu02(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j)
    #define hipsparseXcsric02_bufferSize(a, b, c, d, e, f, g, h, i)   hipsparseZcsric02_bufferSize(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i)
    #define hipsparseXcsric02_analysis(a, b, c, d, e, f, g, h, i, j)  hipsparseZcsric02_analysis(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j)
    #define hipsparseXcsric02(a, b, c, d, e, f, g, h, i, j)           hipsparseZcsric02(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j)
  #endif /* Single or double */
#else    /* not complex */
/* Real case: PetscScalar already matches the float/double expected by the S/D routines, so no casts are needed */
const PetscScalar PETSC_HIPSPARSE_ONE  = 1.0;
const PetscScalar PETSC_HIPSPARSE_ZERO = 0.0;
  #if defined(PETSC_USE_REAL_SINGLE)
    #define hipsparseXcsrilu02_bufferSize hipsparseScsrilu02_bufferSize
    #define hipsparseXcsrilu02_analysis   hipsparseScsrilu02_analysis
    #define hipsparseXcsrilu02            hipsparseScsrilu02
    #define hipsparseXcsric02_bufferSize  hipsparseScsric02_bufferSize
    #define hipsparseXcsric02_analysis    hipsparseScsric02_analysis
    #define hipsparseXcsric02             hipsparseScsric02
  #elif defined(PETSC_USE_REAL_DOUBLE)
    #define hipsparseXcsrilu02_bufferSize hipsparseDcsrilu02_bufferSize
    #define hipsparseXcsrilu02_analysis   hipsparseDcsrilu02_analysis
    #define hipsparseXcsrilu02            hipsparseDcsrilu02
    #define hipsparseXcsric02_bufferSize  hipsparseDcsric02_bufferSize
    #define hipsparseXcsric02_analysis    hipsparseDcsric02_analysis
    #define hipsparseXcsric02             hipsparseDcsric02
  #endif /* Single or double */
#endif   /* complex or not */

/*
   Sparse triangular solve (csrsv2) wrappers.
   csrsvInfo_t and its create/destroy functions are aliases for the hipsparse csrsv2 names;
   hipsparseXcsrsv_{buffsize,analysis,solve} dispatch on the scalar type as above, casting the
   value array, alpha, input vector and output vector arguments in the complex case.
*/
#define csrsvInfo_t              csrsv2Info_t
#define hipsparseCreateCsrsvInfo  hipsparseCreateCsrsv2Info
#define hipsparseDestroyCsrsvInfo hipsparseDestroyCsrsv2Info
#if defined(PETSC_USE_COMPLEX)
  #if defined(PETSC_USE_REAL_SINGLE)
    #define hipsparseXcsrsv_buffsize(a, b, c, d, e, f, g, h, i, j)          hipsparseCcsrsv2_bufferSize(a, b, c, d, e, (hipComplex *)(f), g, h, i, j)
    #define hipsparseXcsrsv_analysis(a, b, c, d, e, f, g, h, i, j, k)       hipsparseCcsrsv2_analysis(a, b, c, d, e, (const hipComplex *)(f), g, h, i, j, k)
    #define hipsparseXcsrsv_solve(a, b, c, d, e, f, g, h, i, j, k, l, m, n) hipsparseCcsrsv2_solve(a, b, c, d, (const hipComplex *)(e), f, (const hipComplex *)(g), h, i, j, (const hipComplex *)(k), (hipComplex *)(l), m, n)
  #elif defined(PETSC_USE_REAL_DOUBLE)
    #define hipsparseXcsrsv_buffsize(a, b, c, d, e, f, g, h, i, j)          hipsparseZcsrsv2_bufferSize(a, b, c, d, e, (hipDoubleComplex *)(f), g, h, i, j)
    #define hipsparseXcsrsv_analysis(a, b, c, d, e, f, g, h, i, j, k)       hipsparseZcsrsv2_analysis(a, b, c, d, e, (const hipDoubleComplex *)(f), g, h, i, j, k)
    #define hipsparseXcsrsv_solve(a, b, c, d, e, f, g, h, i, j, k, l, m, n) hipsparseZcsrsv2_solve(a, b, c, d, (const hipDoubleComplex *)(e), f, (const hipDoubleComplex *)(g), h, i, j, (const hipDoubleComplex *)(k), (hipDoubleComplex *)(l), m, n)
  #endif /* Single or double */
#else    /* not complex */
  #if defined(PETSC_USE_REAL_SINGLE)
    #define hipsparseXcsrsv_buffsize hipsparseScsrsv2_bufferSize
    #define hipsparseXcsrsv_analysis hipsparseScsrsv2_analysis
    #define hipsparseXcsrsv_solve    hipsparseScsrsv2_solve
  #elif defined(PETSC_USE_REAL_DOUBLE)
    #define hipsparseXcsrsv_buffsize hipsparseDcsrsv2_bufferSize
    #define hipsparseXcsrsv_analysis hipsparseDcsrsv2_analysis
    #define hipsparseXcsrsv_solve    hipsparseDcsrsv2_solve
  #endif /* Single or double */
#endif   /* not complex */

/*
   HIP >= 4.5.0 only: hipsparse_scalartype (the HIP data-type tag matching PetscScalar) and the
   csrgeam2-based sparse matrix addition wrappers hipsparse_csr_spgeam / hipsparse_csr_spgeam_bufferSize.

   NOTE(review): hipsparse_scalartype is defined again, with the same replacement list, in the next
   block. An identical macro redefinition is permitted by the preprocessor, so this is redundant but
   harmless; if either definition is edited, keep the two in sync.
*/
#if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
  // #define cusparse_csr2csc cusparseCsr2cscEx2
  #if defined(PETSC_USE_COMPLEX)
    #if defined(PETSC_USE_REAL_SINGLE)
      #define hipsparse_scalartype HIP_C_32F
      #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) hipsparseCcsrgeam2(a, b, c, (hipComplex *)d, e, f, (hipComplex *)g, h, i, (hipComplex *)j, k, l, (hipComplex *)m, n, o, p, (hipComplex *)q, r, s, t)
      #define hipsparse_csr_spgeam_bufferSize(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \
        hipsparseCcsrgeam2_bufferSizeExt(a, b, c, (hipComplex *)d, e, f, (hipComplex *)g, h, i, (hipComplex *)j, k, l, (hipComplex *)m, n, o, p, (hipComplex *)q, r, s, t)
    #elif defined(PETSC_USE_REAL_DOUBLE)
      #define hipsparse_scalartype HIP_C_64F
      #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \
        hipsparseZcsrgeam2(a, b, c, (hipDoubleComplex *)d, e, f, (hipDoubleComplex *)g, h, i, (hipDoubleComplex *)j, k, l, (hipDoubleComplex *)m, n, o, p, (hipDoubleComplex *)q, r, s, t)
      #define hipsparse_csr_spgeam_bufferSize(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \
        hipsparseZcsrgeam2_bufferSizeExt(a, b, c, (hipDoubleComplex *)d, e, f, (hipDoubleComplex *)g, h, i, (hipDoubleComplex *)j, k, l, (hipDoubleComplex *)m, n, o, p, (hipDoubleComplex *)q, r, s, t)
    #endif /* Single or double */
  #else    /* not complex */
    #if defined(PETSC_USE_REAL_SINGLE)
      #define hipsparse_scalartype            HIP_R_32F
      #define hipsparse_csr_spgeam            hipsparseScsrgeam2
      #define hipsparse_csr_spgeam_bufferSize hipsparseScsrgeam2_bufferSizeExt
    #elif defined(PETSC_USE_REAL_DOUBLE)
      #define hipsparse_scalartype            HIP_R_64F
      #define hipsparse_csr_spgeam            hipsparseDcsrgeam2
      #define hipsparse_csr_spgeam_bufferSize hipsparseDcsrgeam2_bufferSizeExt
    #endif /* Single or double */
  #endif   /* complex or not */
#endif     /* PETSC_PKG_HIP_VERSION_GE(4, 5, 0) */

/*
   Version-independent wrappers: csrmv (SpMV), csrmm (SpMM), csr2csc (transpose), hybmv and the
   csr<->hyb conversions, and csrgemm (sparse-sparse product). In the complex case, scalar and value
   pointers are cast to hipComplex* / hipDoubleComplex*.
   The commented-out hipsparse_csr_spgeam lines are the older csrgeam mappings; the active
   hipsparse_csr_spgeam (HIP >= 4.5.0 only, see above) uses csrgeam2.
*/
#if defined(PETSC_USE_COMPLEX)
  #if defined(PETSC_USE_REAL_SINGLE)
    #define hipsparse_scalartype HIP_C_32F
    #define hipsparse_csr_spmv(a, b, c, d, e, f, g, h, i, j, k, l, m)                         hipsparseCcsrmv((a), (b), (c), (d), (e), (hipComplex *)(f), (g), (hipComplex *)(h), (i), (j), (hipComplex *)(k), (hipComplex *)(l), (hipComplex *)(m))
    #define hipsparse_csr_spmm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)                hipsparseCcsrmm((a), (b), (c), (d), (e), (f), (hipComplex *)(g), (h), (hipComplex *)(i), (j), (k), (hipComplex *)(l), (m), (hipComplex *)(n), (hipComplex *)(o), (p))
    #define hipsparse_csr2csc(a, b, c, d, e, f, g, h, i, j, k, l)                             hipsparseCcsr2csc((a), (b), (c), (d), (hipComplex *)(e), (f), (g), (hipComplex *)(h), (i), (j), (k), (l))
    #define hipsparse_hyb_spmv(a, b, c, d, e, f, g, h)                                        hipsparseChybmv((a), (b), (hipComplex *)(c), (d), (e), (hipComplex *)(f), (hipComplex *)(g), (hipComplex *)(h))
    #define hipsparse_csr2hyb(a, b, c, d, e, f, g, h, i, j)                                   hipsparseCcsr2hyb((a), (b), (c), (d), (hipComplex *)(e), (f), (g), (h), (i), (j))
    #define hipsparse_hyb2csr(a, b, c, d, e, f)                                               hipsparseChyb2csr((a), (b), (c), (hipComplex *)(d), (e), (f))
    #define hipsparse_csr_spgemm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) hipsparseCcsrgemm(a, b, c, d, e, f, g, h, (hipComplex *)i, j, k, l, m, (hipComplex *)n, o, p, q, (hipComplex *)r, s, t)
// #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s) hipsparseCcsrgeam(a, b, c, (hipComplex *)d, e, f, (hipComplex *)g, h, i, (hipComplex *)j, k, l, (hipComplex *)m, n, o, p, (hipComplex *)q, r, s)
  #elif defined(PETSC_USE_REAL_DOUBLE)
    #define hipsparse_scalartype HIP_C_64F
    #define hipsparse_csr_spmv(a, b, c, d, e, f, g, h, i, j, k, l, m) hipsparseZcsrmv((a), (b), (c), (d), (e), (hipDoubleComplex *)(f), (g), (hipDoubleComplex *)(h), (i), (j), (hipDoubleComplex *)(k), (hipDoubleComplex *)(l), (hipDoubleComplex *)(m))
    #define hipsparse_csr_spmm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) \
      hipsparseZcsrmm((a), (b), (c), (d), (e), (f), (hipDoubleComplex *)(g), (h), (hipDoubleComplex *)(i), (j), (k), (hipDoubleComplex *)(l), (m), (hipDoubleComplex *)(n), (hipDoubleComplex *)(o), (p))
    #define hipsparse_csr2csc(a, b, c, d, e, f, g, h, i, j, k, l)                             hipsparseZcsr2csc((a), (b), (c), (d), (hipDoubleComplex *)(e), (f), (g), (hipDoubleComplex *)(h), (i), (j), (k), (l))
    #define hipsparse_hyb_spmv(a, b, c, d, e, f, g, h)                                        hipsparseZhybmv((a), (b), (hipDoubleComplex *)(c), (d), (e), (hipDoubleComplex *)(f), (hipDoubleComplex *)(g), (hipDoubleComplex *)(h))
    #define hipsparse_csr2hyb(a, b, c, d, e, f, g, h, i, j)                                   hipsparseZcsr2hyb((a), (b), (c), (d), (hipDoubleComplex *)(e), (f), (g), (h), (i), (j))
    #define hipsparse_hyb2csr(a, b, c, d, e, f)                                               hipsparseZhyb2csr((a), (b), (c), (hipDoubleComplex *)(d), (e), (f))
    #define hipsparse_csr_spgemm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) hipsparseZcsrgemm(a, b, c, d, e, f, g, h, (hipDoubleComplex *)i, j, k, l, m, (hipDoubleComplex *)n, o, p, q, (hipDoubleComplex *)r, s, t)
// #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s) hipsparseZcsrgeam(a, b, c, (hipDoubleComplex *)d, e, f, (hipDoubleComplex *)g, h, i, (hipDoubleComplex *)j, k, l, (hipDoubleComplex *)m, n, o, p, (hipDoubleComplex *)q, r, s)
  #endif /* Single or double */
#else    /* not complex */
  #if defined(PETSC_USE_REAL_SINGLE)
    #define hipsparse_scalartype HIP_R_32F
    #define hipsparse_csr_spmv   hipsparseScsrmv
    #define hipsparse_csr_spmm   hipsparseScsrmm
    #define hipsparse_csr2csc    hipsparseScsr2csc
    #define hipsparse_hyb_spmv   hipsparseShybmv
    #define hipsparse_csr2hyb    hipsparseScsr2hyb
    #define hipsparse_hyb2csr    hipsparseShyb2csr
    #define hipsparse_csr_spgemm hipsparseScsrgemm
// #define hipsparse_csr_spgeam hipsparseScsrgeam
  #elif defined(PETSC_USE_REAL_DOUBLE)
    #define hipsparse_scalartype HIP_R_64F
    #define hipsparse_csr_spmv   hipsparseDcsrmv
    #define hipsparse_csr_spmm   hipsparseDcsrmm
    #define hipsparse_csr2csc    hipsparseDcsr2csc
    #define hipsparse_hyb_spmv   hipsparseDhybmv
    #define hipsparse_csr2hyb    hipsparseDcsr2hyb
    #define hipsparse_hyb2csr    hipsparseDhyb2csr
    #define hipsparse_csr_spgemm hipsparseDcsrgemm
// #define hipsparse_csr_spgeam hipsparseDcsrgeam
  #endif /* Single or double */
#endif   /* complex or not */

/* Shorthands for thrust device vectors of int, PetscInt and PetscScalar entries */
#define THRUSTINTARRAY32 thrust::device_vector<int>
#define THRUSTINTARRAY   thrust::device_vector<PetscInt>
#define THRUSTARRAY      thrust::device_vector<PetscScalar>

/*
   A CSR matrix whose arrays live in thrust device vectors: row offsets and column indices use int
   (THRUSTINTARRAY32), values use PetscScalar. The struct stores raw pointers to those vectors.
   NOTE(review): allocation/deletion of the vectors is handled outside this header -- confirm ownership.
*/
struct CsrMatrix {
  PetscInt          num_rows;
  PetscInt          num_cols;
  PetscInt          num_entries;
  THRUSTINTARRAY32 *row_offsets;
  THRUSTINTARRAY32 *column_indices;
  THRUSTARRAY      *values;
};

/* Holds the data needed for one triangular factor in a MatSolve (csrsv2-based triangular solve) */
struct Mat_SeqAIJHIPSPARSETriFactorStruct {
  /* Data needed for triangular solve */
  hipsparseMatDescr_t    descr;
  hipsparseOperation_t   solveOp;
  CsrMatrix             *csrMat;
  csrsvInfo_t            solveInfo;
  hipsparseSolvePolicy_t solvePolicy; /* whether level information is generated and used */
  int                    solveBufferSize;
  void                  *solveBuffer;
  size_t                 csr2cscBufferSize; /* to transpose the triangular factor. NOTE(review): an inherited note said "only used for CUDA >= 11.0"; confirm its use under HIP */
  void                  *csr2cscBuffer;
  PetscScalar           *AA_h; /* managed host buffer for moving values to the GPU */
};

/* Holds all the triangular factors for a solve and a transpose solve, plus any indices used in a reordering */
struct Mat_SeqAIJHIPSPARSETriFactors {
  Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorPtr;          /* pointer for lower triangular (factored matrix) on GPU */
  Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorPtr;          /* pointer for upper triangular (factored matrix) on GPU */
  Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorPtrTranspose; /* pointer for lower triangular (factored matrix) on GPU for the transpose (useful for BiCG) */
  Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorPtrTranspose; /* pointer for upper triangular (factored matrix) on GPU for the transpose (useful for BiCG) */
  THRUSTINTARRAY                     *rpermIndices;            /* row indices used for any reordering */
  THRUSTINTARRAY                     *cpermIndices;            /* column indices used for any reordering */
  THRUSTARRAY                        *workVector;
  hipsparseHandle_t                   handle;   /* a handle to the hipsparse library */
  PetscInt                            nnz;      /* number of nonzeros ... needed for accurate logging between ICC and ILU */
  PetscScalar                        *a_band_d; /* GPU data for banded CSR LU factorization matrix diag(L)=1 */
  int                                *i_band_d; /* this could be optimized away */
  hipDeviceProp_t                     dev_prop;
  PetscBool                           init_dev_prop;

  /* csrilu0/csric0 appeared in earlier versions of AMD ROCm^{TM}, but we use them along with hipsparseSpSV,
     which first appeared in hipsparse with ROCm-4.5.0. Hence the members below require HIP >= 4.5.0.
  */
  PetscBool factorizeOnDevice; /* Do factorization on device or not */
#if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
  PetscScalar *csrVal;
  int         *csrRowPtr, *csrColIdx; /* a,i,j of M. Using int since some hipsparse APIs only support 32-bit indices */

  /* Mixed mat descriptor types? Yes, different hipsparse APIs use different types */
  hipsparseMatDescr_t   matDescr_M;
  hipsparseSpMatDescr_t spMatDescr_L, spMatDescr_U;
  hipsparseSpSVDescr_t  spsvDescr_L, spsvDescr_Lt, spsvDescr_U, spsvDescr_Ut;

  hipsparseDnVecDescr_t dnVecDescr_X, dnVecDescr_Y;
  PetscScalar          *X, *Y; /* data array of dnVec X and Y */

  /* Mixed size types? Yes: the factorization buffer size is an int, the SpSV buffer sizes are size_t */
  int    factBufferSize_M; /* M ~= LU or LLt */
  size_t spsvBufferSize_L, spsvBufferSize_Lt, spsvBufferSize_U, spsvBufferSize_Ut;
  /* hipsparse needs various buffers for factorization and solve of L, U, Lt, or Ut.
     To save memory, we share the factorization buffer with one of spsvBuffer_L/U.
  */
  void *factBuffer_M, *spsvBuffer_L, *spsvBuffer_U, *spsvBuffer_Lt, *spsvBuffer_Ut;

  csrilu02Info_t         ilu0Info_M;
  csric02Info_t          ic0Info_M;
  int                    structural_zero, numerical_zero;
  hipsparseSolvePolicy_t policy_M;

  /* In MatSolveTranspose() for ILU0, we use the two flags to do on-demand solve */
  PetscBool createdTransposeSpSVDescr;    /* Have we created SpSV descriptors for Lt, Ut? */
  PetscBool updatedTransposeSpSVAnalysis; /* Have we updated SpSV analysis with the latest L, U values? */

  PetscLogDouble numericFactFlops; /* Estimated FLOPs in ILU0/ICC0 numeric factorization */
#endif
};

/* Per-operation (non-transpose / transpose / conj-transpose) SpMV state: work buffer and dense-vector descriptors */
struct Mat_HipsparseSpMV {
  PetscBool             initialized;    /* Don't rely on spmvBuffer != NULL to test if the struct is initialized, */
  size_t                spmvBufferSize; /* since I'm not sure if spmvBuffer can be NULL even after hipsparseSpMV_bufferSize() */
  void                 *spmvBuffer;
  hipsparseDnVecDescr_t vecXDescr, vecYDescr; /* descriptor for the dense vectors in y=op(A)x */
};

/* Holds the data needed for a MatMult */
struct Mat_SeqAIJHIPSPARSEMultStruct {
  void                 *mat;          /* opaque pointer to a matrix. This could be either a hipsparseHybMat_t or a CsrMatrix */
  hipsparseMatDescr_t   descr;        /* Data needed to describe the matrix for a multiply */
  THRUSTINTARRAY       *cprowIndices; /* compressed row indices used in the parallel SpMV */
  PetscScalar          *alpha_one;    /* pointer to a device "scalar" storing the alpha parameter in the SpMV */
  PetscScalar          *beta_zero;    /* pointer to a device "scalar" storing the beta parameter in the SpMV as zero */
  PetscScalar          *beta_one;     /* pointer to a device "scalar" storing the beta parameter in the SpMV as one */
  hipsparseSpMatDescr_t matDescr;     /* descriptor for the matrix, used by SpMV and SpMM */
  Mat_HipsparseSpMV     hipSpMV[3];   /* different Mat_HipsparseSpMV structs for non-transpose, transpose, conj-transpose */
  /* Only matDescr (set to NULL) and hipSpMV[*].initialized (set to PETSC_FALSE) are initialized here;
     the remaining members are left for the code that creates the struct to set. */
  Mat_SeqAIJHIPSPARSEMultStruct() : matDescr(NULL)
  {
    for (int i = 0; i < 3; i++) hipSpMV[i].initialized = PETSC_FALSE;
  }
};

/* A larger struct holding all the matrices for a SpMV and a SpMV transpose */
struct Mat_SeqAIJHIPSPARSE {
  Mat_SeqAIJHIPSPARSEMultStruct *mat;            /* pointer to the matrix on the GPU */
  Mat_SeqAIJHIPSPARSEMultStruct *matTranspose;   /* pointer to the matrix on the GPU (for the transpose ... useful for BiCG) */
  THRUSTARRAY                   *workVector;     /* pointer to a workvector to which we can copy the relevant indices of a vector we want to multiply */
  THRUSTINTARRAY32              *rowoffsets_gpu; /* rowoffsets on GPU in non-compressed-row format. It is used to convert CSR to CSC */
  PetscInt                       nrows;          /* number of rows of the matrix seen by GPU */
  MatHIPSPARSEStorageFormat      format;         /* the storage format for the matrix on the device */
  PetscBool                      use_cpu_solve;  /* Use AIJ_Seq (I)LU solve */
  hipStream_t                    stream;         /* a stream for the parallel SpMV ... this is not owned and should not be deleted */
  hipsparseHandle_t              handle;         /* a handle to the hipsparse library ... this may not be owned (if we're working in parallel i.e. multiGPUs) */
  PetscObjectState               nonzerostate;   /* track nonzero state to possibly recreate the GPU matrix */
  size_t                         csr2cscBufferSize; /* stuff used to compute the matTranspose above */
  void                          *csr2cscBuffer;     /* NOTE(review): an older note said this struct "is used as a C struct and is calloc'ed by PetscNewLog()"; it now has a user-declared constructor (below), so confirm how it is allocated */
  // hipsparseCsr2CscAlg_t          csr2cscAlg; /* algorithms can be selected from command line options */
  hipsparseSpMVAlg_t         spmvAlg;
  hipsparseSpMMAlg_t         spmmAlg;
  THRUSTINTARRAY            *csr2csc_i;
  PetscSplitCSRDataStructure deviceMat; /* Matrix on device for, eg, assembly */
  THRUSTINTARRAY            *cooPerm;   /* permutation array that sorts the input coo entries by row and col */
  THRUSTINTARRAY            *cooPerm_a; /* ordered array that indicates the i-th nonzero (after sorting) is the j-th unique nonzero */

  /* Stuff for extended COO support */
  PetscBool   use_extended_coo; /* Use extended COO format */
  PetscCount *jmap_d;           /* perm[disp+jmap[i]..disp+jmap[i+1]) gives indices of entries in v[] associated with i-th nonzero of the matrix */
  PetscCount *perm_d;

  /* Only the extended-COO members are initialized here (to off / NULL); the other members are not touched */
  Mat_SeqAIJHIPSPARSE() : use_extended_coo(PETSC_FALSE), jmap_d(NULL), perm_d(NULL) { }
};

/* Pointer alias; MatSeqAIJHIPSPARSETriFactors_Reset() takes the address of one */
typedef struct Mat_SeqAIJHIPSPARSETriFactors *Mat_SeqAIJHIPSPARSETriFactors_p;

PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSECopyToGPU(Mat);
PETSC_INTERN PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE_Basic(Mat, PetscCount, PetscInt[], PetscInt[]);
PETSC_INTERN PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE_Basic(Mat, const PetscScalar[], InsertMode);
PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSEMergeMats(Mat, Mat, MatReuse, Mat *);
PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Reset(Mat_SeqAIJHIPSPARSETriFactors_p *);

/* The HIP instantiation of the CUPM sequential vector implementation */
using VecSeq_HIP = Petsc::vec::cupm::impl::VecSeq_CUPM<Petsc::device::cupm::DeviceType::HIP>;

/*
   isHipMem - Reports whether a pointer refers to device memory, as classified by the HIP CUPM interface.

   Input Parameter:
.  data - the pointer to classify

   Returns PetscMemTypeDevice() of the memory type reported by PetscCUPMGetMemType(); mtype defaults to
   PETSC_MEMTYPE_HOST before the query. Because the function returns bool rather than PetscErrorCode,
   a failed query aborts via PetscCallAbort() on PETSC_COMM_SELF instead of propagating an error code.
*/
static inline bool isHipMem(const void *data)
{
  using namespace Petsc::device::cupm;
  auto mtype = PETSC_MEMTYPE_HOST;

  PetscFunctionBegin;
  PetscCallAbort(PETSC_COMM_SELF, impl::Interface<DeviceType::HIP>::PetscCUPMGetMemType(data, &mtype));
  PetscFunctionReturn(PetscMemTypeDevice(mtype));
}

#endif // PETSC_HIPSPARSEMATIMPL_H