147d993e7Ssuyashtn /* Portions of this code are under:
247d993e7Ssuyashtn Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
347d993e7Ssuyashtn */
4a4963045SJacob Faibussowitsch #pragma once
547d993e7Ssuyashtn
647d993e7Ssuyashtn #include <petscpkg_version.h>
76d54fb17SJacob Faibussowitsch #include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp> /* for VecSeq_CUPM */
847f8145dSJacob Faibussowitsch #include <../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp>
96d54fb17SJacob Faibussowitsch #include <petsc/private/petsclegacycupmblas.h>
1047d993e7Ssuyashtn
1147d993e7Ssuyashtn #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0)
1247d993e7Ssuyashtn #include <hipsparse/hipsparse.h>
1347d993e7Ssuyashtn #else /* PETSC_PKG_HIP_VERSION_GE(5,2,0) */
1447d993e7Ssuyashtn #include <hipsparse.h>
1547d993e7Ssuyashtn #endif /* PETSC_PKG_HIP_VERSION_GE(5,2,0) */
1647d993e7Ssuyashtn #include "hip/hip_runtime.h"
1747d993e7Ssuyashtn
1847d993e7Ssuyashtn #include <algorithm>
1947d993e7Ssuyashtn #include <vector>
2047d993e7Ssuyashtn
2147d993e7Ssuyashtn #include <thrust/device_vector.h>
2247d993e7Ssuyashtn #include <thrust/device_ptr.h>
2347d993e7Ssuyashtn #include <thrust/device_malloc_allocator.h>
2447d993e7Ssuyashtn #include <thrust/transform.h>
2547d993e7Ssuyashtn #include <thrust/functional.h>
2647d993e7Ssuyashtn #include <thrust/sequence.h>
2747d993e7Ssuyashtn #include <thrust/system/system_error.h>
2847d993e7Ssuyashtn
2947d993e7Ssuyashtn #if defined(PETSC_USE_COMPLEX)
3047d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE)
3147d993e7Ssuyashtn const hipComplex PETSC_HIPSPARSE_ONE = {1.0f, 0.0f};
3247d993e7Ssuyashtn const hipComplex PETSC_HIPSPARSE_ZERO = {0.0f, 0.0f};
3347d993e7Ssuyashtn #define hipsparseXcsrilu02_bufferSize(a, b, c, d, e, f, g, h, i) hipsparseCcsrilu02_bufferSize(a, b, c, d, (hipComplex *)e, f, g, h, i)
3447d993e7Ssuyashtn #define hipsparseXcsrilu02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseCcsrilu02_analysis(a, b, c, d, (hipComplex *)e, f, g, h, i, j)
3547d993e7Ssuyashtn #define hipsparseXcsrilu02(a, b, c, d, e, f, g, h, i, j) hipsparseCcsrilu02(a, b, c, d, (hipComplex *)e, f, g, h, i, j)
3647d993e7Ssuyashtn #define hipsparseXcsric02_bufferSize(a, b, c, d, e, f, g, h, i) hipsparseCcsric02_bufferSize(a, b, c, d, (hipComplex *)e, f, g, h, i)
3747d993e7Ssuyashtn #define hipsparseXcsric02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseCcsric02_analysis(a, b, c, d, (hipComplex *)e, f, g, h, i, j)
3847d993e7Ssuyashtn #define hipsparseXcsric02(a, b, c, d, e, f, g, h, i, j) hipsparseCcsric02(a, b, c, d, (hipComplex *)e, f, g, h, i, j)
3947d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE)
4047d993e7Ssuyashtn const hipDoubleComplex PETSC_HIPSPARSE_ONE = {1.0, 0.0};
4147d993e7Ssuyashtn const hipDoubleComplex PETSC_HIPSPARSE_ZERO = {0.0, 0.0};
4247d993e7Ssuyashtn #define hipsparseXcsrilu02_bufferSize(a, b, c, d, e, f, g, h, i) hipsparseZcsrilu02_bufferSize(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i)
4347d993e7Ssuyashtn #define hipsparseXcsrilu02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseZcsrilu02_analysis(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j)
4447d993e7Ssuyashtn #define hipsparseXcsrilu02(a, b, c, d, e, f, g, h, i, j) hipsparseZcsrilu02(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j)
4547d993e7Ssuyashtn #define hipsparseXcsric02_bufferSize(a, b, c, d, e, f, g, h, i) hipsparseZcsric02_bufferSize(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i)
4647d993e7Ssuyashtn #define hipsparseXcsric02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseZcsric02_analysis(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j)
4747d993e7Ssuyashtn #define hipsparseXcsric02(a, b, c, d, e, f, g, h, i, j) hipsparseZcsric02(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j)
4847d993e7Ssuyashtn #endif /* Single or double */
4947d993e7Ssuyashtn #else /* not complex */
5047d993e7Ssuyashtn const PetscScalar PETSC_HIPSPARSE_ONE = 1.0;
5147d993e7Ssuyashtn const PetscScalar PETSC_HIPSPARSE_ZERO = 0.0;
5247d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE)
5347d993e7Ssuyashtn #define hipsparseXcsrilu02_bufferSize hipsparseScsrilu02_bufferSize
5447d993e7Ssuyashtn #define hipsparseXcsrilu02_analysis hipsparseScsrilu02_analysis
5547d993e7Ssuyashtn #define hipsparseXcsrilu02 hipsparseScsrilu02
5647d993e7Ssuyashtn #define hipsparseXcsric02_bufferSize hipsparseScsric02_bufferSize
5747d993e7Ssuyashtn #define hipsparseXcsric02_analysis hipsparseScsric02_analysis
5847d993e7Ssuyashtn #define hipsparseXcsric02 hipsparseScsric02
5947d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE)
6047d993e7Ssuyashtn #define hipsparseXcsrilu02_bufferSize hipsparseDcsrilu02_bufferSize
6147d993e7Ssuyashtn #define hipsparseXcsrilu02_analysis hipsparseDcsrilu02_analysis
6247d993e7Ssuyashtn #define hipsparseXcsrilu02 hipsparseDcsrilu02
6347d993e7Ssuyashtn #define hipsparseXcsric02_bufferSize hipsparseDcsric02_bufferSize
6447d993e7Ssuyashtn #define hipsparseXcsric02_analysis hipsparseDcsric02_analysis
6547d993e7Ssuyashtn #define hipsparseXcsric02 hipsparseDcsric02
6647d993e7Ssuyashtn #endif /* Single or double */
6747d993e7Ssuyashtn #endif /* complex or not */
6847d993e7Ssuyashtn
/*
  Version-neutral "csrsv" spellings: the info type and its create/destroy routines are
  forwarded to their csrsv2 counterparts.
*/
#define csrsvInfo_t               csrsv2Info_t
#define hipsparseCreateCsrsvInfo  hipsparseCreateCsrsv2Info
#define hipsparseDestroyCsrsvInfo hipsparseDestroyCsrsv2Info

/*
  Precision-generic triangular-solve entry points. One branch is taken per
  (complex?, single/double) combination. The complex branches cast the scalar pointer
  arguments to the matching hipSPARSE complex type; the real branches are plain aliases.
  If neither precision macro is defined, none of the hipsparseXcsrsv_* names exist.
*/
#if defined(PETSC_USE_COMPLEX) && defined(PETSC_USE_REAL_SINGLE)
  #define hipsparseXcsrsv_buffsize(a, b, c, d, e, f, g, h, i, j)          hipsparseCcsrsv2_bufferSize(a, b, c, d, e, (hipComplex *)(f), g, h, i, j)
  #define hipsparseXcsrsv_analysis(a, b, c, d, e, f, g, h, i, j, k)       hipsparseCcsrsv2_analysis(a, b, c, d, e, (const hipComplex *)(f), g, h, i, j, k)
  #define hipsparseXcsrsv_solve(a, b, c, d, e, f, g, h, i, j, k, l, m, n) hipsparseCcsrsv2_solve(a, b, c, d, (const hipComplex *)(e), f, (const hipComplex *)(g), h, i, j, (const hipComplex *)(k), (hipComplex *)(l), m, n)
#elif defined(PETSC_USE_COMPLEX) && defined(PETSC_USE_REAL_DOUBLE)
  #define hipsparseXcsrsv_buffsize(a, b, c, d, e, f, g, h, i, j)          hipsparseZcsrsv2_bufferSize(a, b, c, d, e, (hipDoubleComplex *)(f), g, h, i, j)
  #define hipsparseXcsrsv_analysis(a, b, c, d, e, f, g, h, i, j, k)       hipsparseZcsrsv2_analysis(a, b, c, d, e, (const hipDoubleComplex *)(f), g, h, i, j, k)
  #define hipsparseXcsrsv_solve(a, b, c, d, e, f, g, h, i, j, k, l, m, n) hipsparseZcsrsv2_solve(a, b, c, d, (const hipDoubleComplex *)(e), f, (const hipDoubleComplex *)(g), h, i, j, (const hipDoubleComplex *)(k), (hipDoubleComplex *)(l), m, n)
#elif !defined(PETSC_USE_COMPLEX) && defined(PETSC_USE_REAL_SINGLE)
  #define hipsparseXcsrsv_buffsize hipsparseScsrsv2_bufferSize
  #define hipsparseXcsrsv_analysis hipsparseScsrsv2_analysis
  #define hipsparseXcsrsv_solve    hipsparseScsrsv2_solve
#elif !defined(PETSC_USE_COMPLEX) && defined(PETSC_USE_REAL_DOUBLE)
  #define hipsparseXcsrsv_buffsize hipsparseDcsrsv2_bufferSize
  #define hipsparseXcsrsv_analysis hipsparseDcsrsv2_analysis
  #define hipsparseXcsrsv_solve    hipsparseDcsrsv2_solve
#endif /* complex/real x single/double */
9347d993e7Ssuyashtn
9447d993e7Ssuyashtn #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
9547d993e7Ssuyashtn // #define cusparse_csr2csc cusparseCsr2cscEx2
9647d993e7Ssuyashtn #if defined(PETSC_USE_COMPLEX)
9747d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE)
9847d993e7Ssuyashtn #define hipsparse_scalartype HIP_C_32F
9947d993e7Ssuyashtn #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) hipsparseCcsrgeam2(a, b, c, (hipComplex *)d, e, f, (hipComplex *)g, h, i, (hipComplex *)j, k, l, (hipComplex *)m, n, o, p, (hipComplex *)q, r, s, t)
10047d993e7Ssuyashtn #define hipsparse_csr_spgeam_bufferSize(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \
10147d993e7Ssuyashtn hipsparseCcsrgeam2_bufferSizeExt(a, b, c, (hipComplex *)d, e, f, (hipComplex *)g, h, i, (hipComplex *)j, k, l, (hipComplex *)m, n, o, p, (hipComplex *)q, r, s, t)
10247d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE)
10347d993e7Ssuyashtn #define hipsparse_scalartype HIP_C_64F
10447d993e7Ssuyashtn #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \
10547d993e7Ssuyashtn hipsparseZcsrgeam2(a, b, c, (hipDoubleComplex *)d, e, f, (hipDoubleComplex *)g, h, i, (hipDoubleComplex *)j, k, l, (hipDoubleComplex *)m, n, o, p, (hipDoubleComplex *)q, r, s, t)
10647d993e7Ssuyashtn #define hipsparse_csr_spgeam_bufferSize(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \
10747d993e7Ssuyashtn hipsparseZcsrgeam2_bufferSizeExt(a, b, c, (hipDoubleComplex *)d, e, f, (hipDoubleComplex *)g, h, i, (hipDoubleComplex *)j, k, l, (hipDoubleComplex *)m, n, o, p, (hipDoubleComplex *)q, r, s, t)
10847d993e7Ssuyashtn #endif /* Single or double */
10947d993e7Ssuyashtn #else /* not complex */
11047d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE)
11147d993e7Ssuyashtn #define hipsparse_scalartype HIP_R_32F
11247d993e7Ssuyashtn #define hipsparse_csr_spgeam hipsparseScsrgeam2
11347d993e7Ssuyashtn #define hipsparse_csr_spgeam_bufferSize hipsparseScsrgeam2_bufferSizeExt
11447d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE)
11547d993e7Ssuyashtn #define hipsparse_scalartype HIP_R_64F
11647d993e7Ssuyashtn #define hipsparse_csr_spgeam hipsparseDcsrgeam2
11747d993e7Ssuyashtn #define hipsparse_csr_spgeam_bufferSize hipsparseDcsrgeam2_bufferSizeExt
11847d993e7Ssuyashtn #endif /* Single or double */
11947d993e7Ssuyashtn #endif /* complex or not */
12047d993e7Ssuyashtn #endif /* PETSC_PKG_HIP_VERSION_GE(4, 5, 0) */
12147d993e7Ssuyashtn
/*
  Precision-generic names for the legacy hipSPARSE kernels: csrmv, csrmm, csr2csc,
  hybmv, csr2hyb, hyb2csr and csrgemm, plus the hipsparse_scalartype data-type tag.

  hipsparse_scalartype is also defined by the HIP >= 4.5.0 section earlier in this
  header with the same replacement text, so the repeated definition is benign.

  In the complex variants every scalar pointer argument is cast to the hipSPARSE complex
  pointer type. All such arguments are parenthesized before the cast, including those of
  hipsparse_csr_spgemm (arguments i, n and r), so an expression argument is cast as a whole.
*/
#if defined(PETSC_USE_COMPLEX)
  #if defined(PETSC_USE_REAL_SINGLE)
    #define hipsparse_scalartype                                               HIP_C_32F
    #define hipsparse_csr_spmv(a, b, c, d, e, f, g, h, i, j, k, l, m)          hipsparseCcsrmv((a), (b), (c), (d), (e), (hipComplex *)(f), (g), (hipComplex *)(h), (i), (j), (hipComplex *)(k), (hipComplex *)(l), (hipComplex *)(m))
    #define hipsparse_csr_spmm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) hipsparseCcsrmm((a), (b), (c), (d), (e), (f), (hipComplex *)(g), (h), (hipComplex *)(i), (j), (k), (hipComplex *)(l), (m), (hipComplex *)(n), (hipComplex *)(o), (p))
    #define hipsparse_csr2csc(a, b, c, d, e, f, g, h, i, j, k, l)              hipsparseCcsr2csc((a), (b), (c), (d), (hipComplex *)(e), (f), (g), (hipComplex *)(h), (i), (j), (k), (l))
    #define hipsparse_hyb_spmv(a, b, c, d, e, f, g, h)                         hipsparseChybmv((a), (b), (hipComplex *)(c), (d), (e), (hipComplex *)(f), (hipComplex *)(g), (hipComplex *)(h))
    #define hipsparse_csr2hyb(a, b, c, d, e, f, g, h, i, j)                    hipsparseCcsr2hyb((a), (b), (c), (d), (hipComplex *)(e), (f), (g), (h), (i), (j))
    #define hipsparse_hyb2csr(a, b, c, d, e, f)                                hipsparseChyb2csr((a), (b), (c), (hipComplex *)(d), (e), (f))
    #define hipsparse_csr_spgemm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \
      hipsparseCcsrgemm(a, b, c, d, e, f, g, h, (hipComplex *)(i), j, k, l, m, (hipComplex *)(n), o, p, q, (hipComplex *)(r), s, t)
  // #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s) hipsparseCcsrgeam(a, b, c, (hipComplex *)d, e, f, (hipComplex *)g, h, i, (hipComplex *)j, k, l, (hipComplex *)m, n, o, p, (hipComplex *)q, r, s)
  #elif defined(PETSC_USE_REAL_DOUBLE)
    #define hipsparse_scalartype                                      HIP_C_64F
    #define hipsparse_csr_spmv(a, b, c, d, e, f, g, h, i, j, k, l, m) hipsparseZcsrmv((a), (b), (c), (d), (e), (hipDoubleComplex *)(f), (g), (hipDoubleComplex *)(h), (i), (j), (hipDoubleComplex *)(k), (hipDoubleComplex *)(l), (hipDoubleComplex *)(m))
    #define hipsparse_csr_spmm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) \
      hipsparseZcsrmm((a), (b), (c), (d), (e), (f), (hipDoubleComplex *)(g), (h), (hipDoubleComplex *)(i), (j), (k), (hipDoubleComplex *)(l), (m), (hipDoubleComplex *)(n), (hipDoubleComplex *)(o), (p))
    #define hipsparse_csr2csc(a, b, c, d, e, f, g, h, i, j, k, l) hipsparseZcsr2csc((a), (b), (c), (d), (hipDoubleComplex *)(e), (f), (g), (hipDoubleComplex *)(h), (i), (j), (k), (l))
    #define hipsparse_hyb_spmv(a, b, c, d, e, f, g, h)            hipsparseZhybmv((a), (b), (hipDoubleComplex *)(c), (d), (e), (hipDoubleComplex *)(f), (hipDoubleComplex *)(g), (hipDoubleComplex *)(h))
    #define hipsparse_csr2hyb(a, b, c, d, e, f, g, h, i, j)       hipsparseZcsr2hyb((a), (b), (c), (d), (hipDoubleComplex *)(e), (f), (g), (h), (i), (j))
    #define hipsparse_hyb2csr(a, b, c, d, e, f)                   hipsparseZhyb2csr((a), (b), (c), (hipDoubleComplex *)(d), (e), (f))
    #define hipsparse_csr_spgemm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \
      hipsparseZcsrgemm(a, b, c, d, e, f, g, h, (hipDoubleComplex *)(i), j, k, l, m, (hipDoubleComplex *)(n), o, p, q, (hipDoubleComplex *)(r), s, t)
  // #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s) hipsparseZcsrgeam(a, b, c, (hipDoubleComplex *)d, e, f, (hipDoubleComplex *)g, h, i, (hipDoubleComplex *)j, k, l, (hipDoubleComplex *)m, n, o, p, (hipDoubleComplex *)q, r, s)
  #endif /* Single or double */
#else    /* not complex */
  #if defined(PETSC_USE_REAL_SINGLE)
    #define hipsparse_scalartype HIP_R_32F
    #define hipsparse_csr_spmv   hipsparseScsrmv
    #define hipsparse_csr_spmm   hipsparseScsrmm
    #define hipsparse_csr2csc    hipsparseScsr2csc
    #define hipsparse_hyb_spmv   hipsparseShybmv
    #define hipsparse_csr2hyb    hipsparseScsr2hyb
    #define hipsparse_hyb2csr    hipsparseShyb2csr
    #define hipsparse_csr_spgemm hipsparseScsrgemm
  // #define hipsparse_csr_spgeam hipsparseScsrgeam
  #elif defined(PETSC_USE_REAL_DOUBLE)
    #define hipsparse_scalartype HIP_R_64F
    #define hipsparse_csr_spmv   hipsparseDcsrmv
    #define hipsparse_csr_spmm   hipsparseDcsrmm
    #define hipsparse_csr2csc    hipsparseDcsr2csc
    #define hipsparse_hyb_spmv   hipsparseDhybmv
    #define hipsparse_csr2hyb    hipsparseDcsr2hyb
    #define hipsparse_hyb2csr    hipsparseDhyb2csr
    #define hipsparse_csr_spgemm hipsparseDcsrgemm
  // #define hipsparse_csr_spgeam hipsparseDcsrgeam
  #endif /* Single or double */
#endif   /* complex or not */
16847d993e7Ssuyashtn
/* Shorthand for the thrust device containers used by the structs in this header:
   scalar values, PetscInt indices, and plain int (32-bit) indices */
#define THRUSTARRAY      thrust::device_vector<PetscScalar>
#define THRUSTINTARRAY   thrust::device_vector<PetscInt>
#define THRUSTINTARRAY32 thrust::device_vector<int>
17247d993e7Ssuyashtn
/*
  A CSR matrix: its dimensions and entry count, plus pointers to the thrust device
  vectors that hold the row offsets, column indices and values.

  The index arrays are THRUSTINTARRAY32, i.e. device vectors of int, while the counts
  are PetscInt. The struct has no constructor or destructor, so the pointed-to vectors
  are managed by whoever creates and destroys the CsrMatrix.
*/
struct CsrMatrix {
  PetscInt          num_rows;       /* number of rows */
  PetscInt          num_cols;       /* number of columns */
  PetscInt          num_entries;    /* number of stored entries */
  THRUSTINTARRAY32 *row_offsets;    /* device vector of int row offsets */
  THRUSTINTARRAY32 *column_indices; /* device vector of int column indices */
  THRUSTARRAY      *values;         /* device vector of PetscScalar values */
};
18247d993e7Ssuyashtn
/*
  Holds the data needed to perform a MatSolve with one triangular factor: the factor in
  CSR form, its hipSPARSE descriptor and solve settings, and the associated work buffers.
*/
struct Mat_SeqAIJHIPSPARSETriFactorStruct {
  /* Data needed for triangular solve */
  hipsparseMatDescr_t    descr;             /* hipSPARSE matrix descriptor of the factor */
  hipsparseOperation_t   solveOp;           /* operation applied to the factor in the solve */
  CsrMatrix             *csrMat;            /* the triangular factor stored as CSR */
  csrsvInfo_t            solveInfo;         /* csrsv info object (csrsvInfo_t is defined as csrsv2Info_t above) */
  hipsparseSolvePolicy_t solvePolicy;       /* whether level information is generated and used */
  int                    solveBufferSize;   /* size of solveBuffer (an int, unlike csr2cscBufferSize) */
  void                  *solveBuffer;       /* work buffer for the triangular solve */
  size_t                 csr2cscBufferSize; /* size of csr2cscBuffer, used to transpose the triangular factor.
                                               NOTE(review): the original comment said "only used for CUDA >= 11.0",
                                               which looks inherited from the CUDA backend; confirm the HIP condition */
  void                  *csr2cscBuffer;     /* work buffer for the csr2csc transpose */
  PetscScalar           *AA_h;              /* managed host buffer for moving values to the GPU */
};
19747d993e7Ssuyashtn
19847d993e7Ssuyashtn /* This is a larger struct holding all the triangular factors for a solve, transpose solve, and any indices used in a reordering */
19947d993e7Ssuyashtn struct Mat_SeqAIJHIPSPARSETriFactors {
20047d993e7Ssuyashtn Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorPtr; /* pointer for lower triangular (factored matrix) on GPU */
20147d993e7Ssuyashtn Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorPtr; /* pointer for upper triangular (factored matrix) on GPU */
20247d993e7Ssuyashtn Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorPtrTranspose; /* pointer for lower triangular (factored matrix) on GPU for the transpose (useful for BiCG) */
20347d993e7Ssuyashtn Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorPtrTranspose; /* pointer for upper triangular (factored matrix) on GPU for the transpose (useful for BiCG)*/
20447d993e7Ssuyashtn THRUSTINTARRAY *rpermIndices; /* indices used for any reordering */
20547d993e7Ssuyashtn THRUSTINTARRAY *cpermIndices; /* indices used for any reordering */
20647d993e7Ssuyashtn THRUSTARRAY *workVector;
20747d993e7Ssuyashtn hipsparseHandle_t handle; /* a handle to the hipsparse library */
20847d993e7Ssuyashtn PetscInt nnz; /* number of nonzeros ... need this for accurate logging between ICC and ILU */
20947d993e7Ssuyashtn hipDeviceProp_t dev_prop;
21047d993e7Ssuyashtn PetscBool init_dev_prop;
21147d993e7Ssuyashtn
21247d993e7Ssuyashtn /* csrilu0/csric0 appeared in earlier versions of AMD ROCm^{TM}, but we use it along with hipsparseSpSV,
21347d993e7Ssuyashtn which first appeared in hipsparse with ROCm-4.5.0.
21447d993e7Ssuyashtn */
21547d993e7Ssuyashtn #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
21647d993e7Ssuyashtn PetscScalar *csrVal;
21747d993e7Ssuyashtn int *csrRowPtr, *csrColIdx; /* a,i,j of M. Using int since some hipsparse APIs only support 32-bit indices */
21847d993e7Ssuyashtn
21947d993e7Ssuyashtn /* Mixed mat descriptor types? yes, different hipsparse APIs use different types */
22047d993e7Ssuyashtn hipsparseMatDescr_t matDescr_M;
22147d993e7Ssuyashtn hipsparseSpMatDescr_t spMatDescr_L, spMatDescr_U;
22247d993e7Ssuyashtn hipsparseSpSVDescr_t spsvDescr_L, spsvDescr_Lt, spsvDescr_U, spsvDescr_Ut;
22347d993e7Ssuyashtn
22447d993e7Ssuyashtn hipsparseDnVecDescr_t dnVecDescr_X, dnVecDescr_Y;
22547d993e7Ssuyashtn PetscScalar *X, *Y; /* data array of dnVec X and Y */
22647d993e7Ssuyashtn
22747d993e7Ssuyashtn /* Mixed size types? yes */
22847d993e7Ssuyashtn int factBufferSize_M; /* M ~= LU or LLt */
22947d993e7Ssuyashtn size_t spsvBufferSize_L, spsvBufferSize_Lt, spsvBufferSize_U, spsvBufferSize_Ut;
23047d993e7Ssuyashtn /* hipsparse needs various buffers for factorization and solve of L, U, Lt, or Ut.
23147d993e7Ssuyashtn To save memory, we share the factorization buffer with one of spsvBuffer_L/U.
23247d993e7Ssuyashtn */
23347d993e7Ssuyashtn void *factBuffer_M, *spsvBuffer_L, *spsvBuffer_U, *spsvBuffer_Lt, *spsvBuffer_Ut;
23447d993e7Ssuyashtn
23547d993e7Ssuyashtn csrilu02Info_t ilu0Info_M;
23647d993e7Ssuyashtn csric02Info_t ic0Info_M;
23747d993e7Ssuyashtn int structural_zero, numerical_zero;
23847d993e7Ssuyashtn hipsparseSolvePolicy_t policy_M;
23947d993e7Ssuyashtn
24047d993e7Ssuyashtn /* In MatSolveTranspose() for ILU0, we use the two flags to do on-demand solve */
24147d993e7Ssuyashtn PetscBool createdTransposeSpSVDescr; /* Have we created SpSV descriptors for Lt, Ut? */
24247d993e7Ssuyashtn PetscBool updatedTransposeSpSVAnalysis; /* Have we updated SpSV analysis with the latest L, U values? */
24347d993e7Ssuyashtn
24447d993e7Ssuyashtn PetscLogDouble numericFactFlops; /* Estimated FLOPs in ILU0/ICC0 numeric factorization */
24547d993e7Ssuyashtn #endif
24647d993e7Ssuyashtn };
24747d993e7Ssuyashtn
24847d993e7Ssuyashtn struct Mat_HipsparseSpMV {
24947d993e7Ssuyashtn PetscBool initialized; /* Don't rely on spmvBuffer != NULL to test if the struct is initialized, */
25047d993e7Ssuyashtn size_t spmvBufferSize; /* since I'm not sure if smvBuffer can be NULL even after hipsparseSpMV_bufferSize() */
25147d993e7Ssuyashtn void *spmvBuffer;
25247d993e7Ssuyashtn hipsparseDnVecDescr_t vecXDescr, vecYDescr; /* descriptor for the dense vectors in y=op(A)x */
25347d993e7Ssuyashtn };
25447d993e7Ssuyashtn
25547d993e7Ssuyashtn /* This is struct holding the relevant data needed to a MatMult */
25647d993e7Ssuyashtn struct Mat_SeqAIJHIPSPARSEMultStruct {
25747d993e7Ssuyashtn void *mat; /* opaque pointer to a matrix. This could be either a hipsparseHybMat_t or a CsrMatrix */
25847d993e7Ssuyashtn hipsparseMatDescr_t descr; /* Data needed to describe the matrix for a multiply */
25947d993e7Ssuyashtn THRUSTINTARRAY *cprowIndices; /* compressed row indices used in the parallel SpMV */
26047d993e7Ssuyashtn PetscScalar *alpha_one; /* pointer to a device "scalar" storing the alpha parameter in the SpMV */
26147d993e7Ssuyashtn PetscScalar *beta_zero; /* pointer to a device "scalar" storing the beta parameter in the SpMV as zero*/
26247d993e7Ssuyashtn PetscScalar *beta_one; /* pointer to a device "scalar" storing the beta parameter in the SpMV as one */
26347d993e7Ssuyashtn hipsparseSpMatDescr_t matDescr; /* descriptor for the matrix, used by SpMV and SpMM */
26447d993e7Ssuyashtn Mat_HipsparseSpMV hipSpMV[3]; /* different Mat_CusparseSpMV structs for non-transpose, transpose, conj-transpose */
Mat_SeqAIJHIPSPARSEMultStructMat_SeqAIJHIPSPARSEMultStruct26547d993e7Ssuyashtn Mat_SeqAIJHIPSPARSEMultStruct() : matDescr(NULL)
26647d993e7Ssuyashtn {
26747d993e7Ssuyashtn for (int i = 0; i < 3; i++) hipSpMV[i].initialized = PETSC_FALSE;
26847d993e7Ssuyashtn }
26947d993e7Ssuyashtn };
27047d993e7Ssuyashtn
/*
  The top-level GPU-side data of a matrix: the structs used for SpMV and for SpMV with
  the transpose, plus work storage, the storage format, stream/handle and the state
  used to build the transpose.

  The original comments note that this is used as a C struct and is calloc'ed by
  PetscNewLog() (so presumably no constructor may be relied upon - TODO confirm).
*/
struct Mat_SeqAIJHIPSPARSE {
  Mat_SeqAIJHIPSPARSEMultStruct *mat;            /* pointer to the matrix on the GPU */
  Mat_SeqAIJHIPSPARSEMultStruct *matTranspose;   /* pointer to the matrix on the GPU (for the transpose ... useful for BiCG) */
  THRUSTARRAY                   *workVector;     /* pointer to a workvector to which we can copy the relevant indices of a vector we want to multiply */
  THRUSTINTARRAY32              *rowoffsets_gpu; /* rowoffsets on GPU in non-compressed-row format. It is used to convert CSR to CSC */
  PetscInt                       nrows;          /* number of rows of the matrix seen by GPU */
  MatHIPSPARSEStorageFormat      format;         /* the storage format for the matrix on the device */
  PetscBool                      use_cpu_solve;  /* Use AIJ_Seq (I)LU solve */
  hipStream_t                    stream;         /* a stream for the parallel SpMV ... this is not owned and should not be deleted */
  hipsparseHandle_t              handle;         /* a handle to the hipsparse library ... this may not be owned (if we're working in parallel i.e. multiGPUs) */
  PetscObjectState               nonzerostate;   /* track nonzero state to possibly recreate the GPU matrix */
  size_t                         csr2cscBufferSize; /* stuff used to compute the matTranspose above */
  void                          *csr2cscBuffer;     /* work buffer belonging to csr2cscBufferSize */
  // hipsparseCsr2CscAlg_t csr2cscAlg; /* algorithms can be selected from command line options */
  hipsparseSpMVAlg_t spmvAlg;   /* hipSPARSE SpMV algorithm selection */
  hipsparseSpMMAlg_t spmmAlg;   /* hipSPARSE SpMM algorithm selection */
  THRUSTINTARRAY    *csr2csc_i; /* PetscInt index array; by its name it belongs to the csr2csc conversion - TODO confirm its exact role */
  THRUSTINTARRAY    *coords;    /* permutation array used in MatSeqAIJHIPSPARSEMergeMats */
};
29147d993e7Ssuyashtn
29247d993e7Ssuyashtn typedef struct Mat_SeqAIJHIPSPARSETriFactors *Mat_SeqAIJHIPSPARSETriFactors_p;
29347d993e7Ssuyashtn
29447d993e7Ssuyashtn PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSECopyToGPU(Mat);
29547d993e7Ssuyashtn PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSEMergeMats(Mat, Mat, MatReuse, Mat *);
29647d993e7Ssuyashtn PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Reset(Mat_SeqAIJHIPSPARSETriFactors_p *);
29747d993e7Ssuyashtn
2986d54fb17SJacob Faibussowitsch using VecSeq_HIP = Petsc::vec::cupm::impl::VecSeq_CUPM<Petsc::device::cupm::DeviceType::HIP>;
2996d54fb17SJacob Faibussowitsch
isHipMem(const void * data)30047d993e7Ssuyashtn static inline bool isHipMem(const void *data)
30147d993e7Ssuyashtn {
3026d54fb17SJacob Faibussowitsch using namespace Petsc::device::cupm;
3036d54fb17SJacob Faibussowitsch auto mtype = PETSC_MEMTYPE_HOST;
3046d54fb17SJacob Faibussowitsch
3056d54fb17SJacob Faibussowitsch PetscFunctionBegin;
3066d54fb17SJacob Faibussowitsch PetscCallAbort(PETSC_COMM_SELF, impl::Interface<DeviceType::HIP>::PetscCUPMGetMemType(data, &mtype));
3076d54fb17SJacob Faibussowitsch PetscFunctionReturn(PetscMemTypeDevice(mtype));
30847d993e7Ssuyashtn }
309