1d52a580bSJunchao Zhang /*
2d52a580bSJunchao Zhang Defines the basic matrix operations for the AIJ (compressed row)
  matrix storage format using the HIPSPARSE library.
4d52a580bSJunchao Zhang Portions of this code are under:
5d52a580bSJunchao Zhang Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
6d52a580bSJunchao Zhang */
7d52a580bSJunchao Zhang #include <petscconf.h>
8d52a580bSJunchao Zhang #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
9d52a580bSJunchao Zhang #include <../src/mat/impls/sbaij/seq/sbaij.h>
10d52a580bSJunchao Zhang #include <../src/mat/impls/dense/seq/dense.h> // MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal()
11d52a580bSJunchao Zhang #include <../src/vec/vec/impls/dvecimpl.h>
12d52a580bSJunchao Zhang #include <petsc/private/vecimpl.h>
13d52a580bSJunchao Zhang #undef VecType
14d52a580bSJunchao Zhang #include <../src/mat/impls/aij/seq/seqhipsparse/hipsparsematimpl.h>
15d52a580bSJunchao Zhang #include <thrust/adjacent_difference.h>
16d52a580bSJunchao Zhang #include <thrust/iterator/transform_iterator.h>
17d52a580bSJunchao Zhang #if PETSC_CPP_VERSION >= 14
18d52a580bSJunchao Zhang #define PETSC_HAVE_THRUST_ASYNC 1
19d52a580bSJunchao Zhang #include <thrust/async/for_each.h>
20d52a580bSJunchao Zhang #endif
21d52a580bSJunchao Zhang #include <thrust/iterator/constant_iterator.h>
22d52a580bSJunchao Zhang #include <thrust/iterator/discard_iterator.h>
23d52a580bSJunchao Zhang #include <thrust/binary_search.h>
24d52a580bSJunchao Zhang #include <thrust/remove.h>
25d52a580bSJunchao Zhang #include <thrust/sort.h>
26d52a580bSJunchao Zhang #include <thrust/unique.h>
27d52a580bSJunchao Zhang
/*
  Name tables consumed by PetscOptionsEnum(). The position of each entry is the enum value it maps to;
  the trailing entries are the enum type name and the common prefix, followed by a null terminator.
*/
const char *const MatHIPSPARSEStorageFormats[] = {
  "CSR",
  "ELL",
  "HYB",
  "MatHIPSPARSEStorageFormat",
  "MAT_HIPSPARSE_",
  nullptr,
};
const char *const MatHIPSPARSESpMVAlgorithms[] = {
  "MV_ALG_DEFAULT",
  "COOMV_ALG",
  "CSRMV_ALG1",
  "CSRMV_ALG2",
  "SPMV_ALG_DEFAULT",
  "SPMV_COO_ALG1",
  "SPMV_COO_ALG2",
  "SPMV_CSR_ALG1",
  "SPMV_CSR_ALG2",
  "hipsparseSpMVAlg_t",
  "HIPSPARSE_",
  nullptr,
};
const char *const MatHIPSPARSESpMMAlgorithms[] = {
  "ALG_DEFAULT",
  "COO_ALG1",
  "COO_ALG2",
  "COO_ALG3",
  "CSR_ALG1",
  "COO_ALG4",
  "CSR_ALG2",
  "hipsparseSpMMAlg_t",
  "HIPSPARSE_SPMM_",
  nullptr,
};
//const char *const MatHIPSPARSECsr2CscAlgorithms[] = {"INVALID"/*HIPSPARSE does not have enum 0! We created one*/, "ALG1", "ALG2", "hipsparseCsr2CscAlg_t", "HIPSPARSE_CSR2CSC_", 0};
32d52a580bSJunchao Zhang
33d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, const MatFactorInfo *);
34d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, const MatFactorInfo *);
35d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat, Mat, const MatFactorInfo *);
36d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, IS, const MatFactorInfo *);
37d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, IS, const MatFactorInfo *);
38d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat, Mat, const MatFactorInfo *);
39d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE(Mat, Vec, Vec);
40d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat, Vec, Vec);
41d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec);
42d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat, Vec, Vec);
43d52a580bSJunchao Zhang static PetscErrorCode MatSetFromOptions_SeqAIJHIPSPARSE(Mat, PetscOptionItems PetscOptionsObject);
44d52a580bSJunchao Zhang static PetscErrorCode MatAXPY_SeqAIJHIPSPARSE(Mat, PetscScalar, Mat, MatStructure);
45d52a580bSJunchao Zhang static PetscErrorCode MatScale_SeqAIJHIPSPARSE(Mat, PetscScalar);
46d52a580bSJunchao Zhang static PetscErrorCode MatMult_SeqAIJHIPSPARSE(Mat, Vec, Vec);
47d52a580bSJunchao Zhang static PetscErrorCode MatMultAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec);
48d52a580bSJunchao Zhang static PetscErrorCode MatMultTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec);
49d52a580bSJunchao Zhang static PetscErrorCode MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec);
50d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec);
51d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec);
52d52a580bSJunchao Zhang static PetscErrorCode MatMultAddKernel_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec, PetscBool, PetscBool);
53d52a580bSJunchao Zhang static PetscErrorCode CsrMatrix_Destroy(CsrMatrix **);
54d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct **);
55d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct **, MatHIPSPARSEStorageFormat);
56d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors **);
57d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSE_Destroy(Mat);
58d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSECopyFromGPU(Mat);
59d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat);
60d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEInvalidateTranspose(Mat, PetscBool);
61d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat, PetscInt, const PetscInt[], PetscScalar[]);
62d52a580bSJunchao Zhang static PetscErrorCode MatBindToCPU_SeqAIJHIPSPARSE(Mat, PetscBool);
63d52a580bSJunchao Zhang static PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat, PetscCount, PetscInt[], PetscInt[]);
64d52a580bSJunchao Zhang static PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE(Mat, const PetscScalar[], InsertMode);
65d52a580bSJunchao Zhang
66d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatProductSetFromOptions_SeqAIJ_SeqDense(Mat);
67d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJHIPSPARSE(Mat, MatType, MatReuse, Mat *);
68d52a580bSJunchao Zhang
69d52a580bSJunchao Zhang /*
70d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetStream(Mat A, const hipStream_t stream)
71d52a580bSJunchao Zhang {
72d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr;
73d52a580bSJunchao Zhang
74d52a580bSJunchao Zhang PetscFunctionBegin;
75d52a580bSJunchao Zhang PetscCheck(hipsparsestruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr");
76d52a580bSJunchao Zhang hipsparsestruct->stream = stream;
77d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetStream(hipsparsestruct->handle, hipsparsestruct->stream));
78d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
79d52a580bSJunchao Zhang }
80d52a580bSJunchao Zhang
81d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetHandle(Mat A, const hipsparseHandle_t handle)
82d52a580bSJunchao Zhang {
83d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr;
84d52a580bSJunchao Zhang
85d52a580bSJunchao Zhang PetscFunctionBegin;
86d52a580bSJunchao Zhang PetscCheck(hipsparsestruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr");
87d52a580bSJunchao Zhang if (hipsparsestruct->handle != handle) {
88d52a580bSJunchao Zhang if (hipsparsestruct->handle) PetscCallHIPSPARSE(hipsparseDestroy(hipsparsestruct->handle));
89d52a580bSJunchao Zhang hipsparsestruct->handle = handle;
90d52a580bSJunchao Zhang }
91d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(hipsparsestruct->handle, HIPSPARSE_POINTER_MODE_DEVICE));
92d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
93d52a580bSJunchao Zhang }
94d52a580bSJunchao Zhang
95d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSEClearHandle(Mat A)
96d52a580bSJunchao Zhang {
97d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr;
98d52a580bSJunchao Zhang PetscBool flg;
99d52a580bSJunchao Zhang
100d52a580bSJunchao Zhang PetscFunctionBegin;
101d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
102d52a580bSJunchao Zhang if (!flg || !hipsparsestruct) PetscFunctionReturn(PETSC_SUCCESS);
103d52a580bSJunchao Zhang if (hipsparsestruct->handle) hipsparsestruct->handle = 0;
104d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
105d52a580bSJunchao Zhang }
106d52a580bSJunchao Zhang */
107d52a580bSJunchao Zhang
MatHIPSPARSESetFormat_SeqAIJHIPSPARSE(Mat A,MatHIPSPARSEFormatOperation op,MatHIPSPARSEStorageFormat format)108d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatHIPSPARSESetFormat_SeqAIJHIPSPARSE(Mat A, MatHIPSPARSEFormatOperation op, MatHIPSPARSEStorageFormat format)
109d52a580bSJunchao Zhang {
110d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
111d52a580bSJunchao Zhang
112d52a580bSJunchao Zhang PetscFunctionBegin;
113d52a580bSJunchao Zhang switch (op) {
114d52a580bSJunchao Zhang case MAT_HIPSPARSE_MULT:
115d52a580bSJunchao Zhang hipsparsestruct->format = format;
116d52a580bSJunchao Zhang break;
117d52a580bSJunchao Zhang case MAT_HIPSPARSE_ALL:
118d52a580bSJunchao Zhang hipsparsestruct->format = format;
119d52a580bSJunchao Zhang break;
120d52a580bSJunchao Zhang default:
121d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "unsupported operation %d for MatHIPSPARSEFormatOperation. MAT_HIPSPARSE_MULT and MAT_HIPSPARSE_ALL are currently supported.", op);
122d52a580bSJunchao Zhang }
123d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
124d52a580bSJunchao Zhang }
125d52a580bSJunchao Zhang
126d52a580bSJunchao Zhang /*@
127d52a580bSJunchao Zhang MatHIPSPARSESetFormat - Sets the storage format of `MATSEQHIPSPARSE` matrices for a particular
128d52a580bSJunchao Zhang operation. Only the `MatMult()` operation can use different GPU storage formats
129d52a580bSJunchao Zhang
130d52a580bSJunchao Zhang Not Collective
131d52a580bSJunchao Zhang
132d52a580bSJunchao Zhang Input Parameters:
133d52a580bSJunchao Zhang + A - Matrix of type `MATSEQAIJHIPSPARSE`
134d52a580bSJunchao Zhang . op - `MatHIPSPARSEFormatOperation`. `MATSEQAIJHIPSPARSE` matrices support `MAT_HIPSPARSE_MULT` and `MAT_HIPSPARSE_ALL`.
135d52a580bSJunchao Zhang `MATMPIAIJHIPSPARSE` matrices support `MAT_HIPSPARSE_MULT_DIAG`, `MAT_HIPSPARSE_MULT_OFFDIAG`, and `MAT_HIPSPARSE_ALL`.
136d52a580bSJunchao Zhang - format - `MatHIPSPARSEStorageFormat` (one of `MAT_HIPSPARSE_CSR`, `MAT_HIPSPARSE_ELL`, `MAT_HIPSPARSE_HYB`.)
137d52a580bSJunchao Zhang
138d52a580bSJunchao Zhang Level: intermediate
139d52a580bSJunchao Zhang
140d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MATSEQAIJHIPSPARSE`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
141d52a580bSJunchao Zhang @*/
MatHIPSPARSESetFormat(Mat A,MatHIPSPARSEFormatOperation op,MatHIPSPARSEStorageFormat format)142d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetFormat(Mat A, MatHIPSPARSEFormatOperation op, MatHIPSPARSEStorageFormat format)
143d52a580bSJunchao Zhang {
144d52a580bSJunchao Zhang PetscFunctionBegin;
145d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
146d52a580bSJunchao Zhang PetscTryMethod(A, "MatHIPSPARSESetFormat_C", (Mat, MatHIPSPARSEFormatOperation, MatHIPSPARSEStorageFormat), (A, op, format));
147d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
148d52a580bSJunchao Zhang }
149d52a580bSJunchao Zhang
MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE(Mat A,PetscBool use_cpu)150d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE(Mat A, PetscBool use_cpu)
151d52a580bSJunchao Zhang {
152d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
153d52a580bSJunchao Zhang
154d52a580bSJunchao Zhang PetscFunctionBegin;
155d52a580bSJunchao Zhang hipsparsestruct->use_cpu_solve = use_cpu;
156d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
157d52a580bSJunchao Zhang }
158d52a580bSJunchao Zhang
159d52a580bSJunchao Zhang /*@
160d52a580bSJunchao Zhang MatHIPSPARSESetUseCPUSolve - Sets use CPU `MatSolve()`.
161d52a580bSJunchao Zhang
162d52a580bSJunchao Zhang Input Parameters:
163d52a580bSJunchao Zhang + A - Matrix of type `MATSEQAIJHIPSPARSE`
164d52a580bSJunchao Zhang - use_cpu - set flag for using the built-in CPU `MatSolve()`
165d52a580bSJunchao Zhang
166d52a580bSJunchao Zhang Level: intermediate
167d52a580bSJunchao Zhang
168d52a580bSJunchao Zhang Notes:
169d52a580bSJunchao Zhang The hipSparse LU solver currently computes the factors with the built-in CPU method
170d52a580bSJunchao Zhang and moves the factors to the GPU for the solve. We have observed better performance keeping the data on the CPU and computing the solve there.
171d52a580bSJunchao Zhang This method to specifies if the solve is done on the CPU or GPU (GPU is the default).
172d52a580bSJunchao Zhang
173d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSolve()`, `MATSEQAIJHIPSPARSE`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
174d52a580bSJunchao Zhang @*/
MatHIPSPARSESetUseCPUSolve(Mat A,PetscBool use_cpu)175d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetUseCPUSolve(Mat A, PetscBool use_cpu)
176d52a580bSJunchao Zhang {
177d52a580bSJunchao Zhang PetscFunctionBegin;
178d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
179d52a580bSJunchao Zhang PetscTryMethod(A, "MatHIPSPARSESetUseCPUSolve_C", (Mat, PetscBool), (A, use_cpu));
180d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
181d52a580bSJunchao Zhang }
182d52a580bSJunchao Zhang
MatSetOption_SeqAIJHIPSPARSE(Mat A,MatOption op,PetscBool flg)183d52a580bSJunchao Zhang static PetscErrorCode MatSetOption_SeqAIJHIPSPARSE(Mat A, MatOption op, PetscBool flg)
184d52a580bSJunchao Zhang {
185d52a580bSJunchao Zhang PetscFunctionBegin;
186d52a580bSJunchao Zhang switch (op) {
187d52a580bSJunchao Zhang case MAT_FORM_EXPLICIT_TRANSPOSE:
188d52a580bSJunchao Zhang /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
189d52a580bSJunchao Zhang if (A->form_explicit_transpose && !flg) PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE));
190d52a580bSJunchao Zhang A->form_explicit_transpose = flg;
191d52a580bSJunchao Zhang break;
192d52a580bSJunchao Zhang default:
193d52a580bSJunchao Zhang PetscCall(MatSetOption_SeqAIJ(A, op, flg));
194d52a580bSJunchao Zhang break;
195d52a580bSJunchao Zhang }
196d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
197d52a580bSJunchao Zhang }
198d52a580bSJunchao Zhang
MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat B,Mat A,const MatFactorInfo * info)199d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat B, Mat A, const MatFactorInfo *info)
200d52a580bSJunchao Zhang {
201d52a580bSJunchao Zhang PetscBool row_identity, col_identity;
202d52a580bSJunchao Zhang Mat_SeqAIJ *b = (Mat_SeqAIJ *)B->data;
203d52a580bSJunchao Zhang IS isrow = b->row, iscol = b->col;
204d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)B->spptr;
205d52a580bSJunchao Zhang
206d52a580bSJunchao Zhang PetscFunctionBegin;
207d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
208d52a580bSJunchao Zhang PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
209d52a580bSJunchao Zhang B->offloadmask = PETSC_OFFLOAD_CPU;
210d52a580bSJunchao Zhang /* determine which version of MatSolve needs to be used. */
211d52a580bSJunchao Zhang PetscCall(ISIdentity(isrow, &row_identity));
212d52a580bSJunchao Zhang PetscCall(ISIdentity(iscol, &col_identity));
213d52a580bSJunchao Zhang if (!hipsparsestruct->use_cpu_solve) {
214d52a580bSJunchao Zhang if (row_identity && col_identity) {
215d52a580bSJunchao Zhang B->ops->solve = MatSolve_SeqAIJHIPSPARSE_NaturalOrdering;
216d52a580bSJunchao Zhang B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering;
217d52a580bSJunchao Zhang } else {
218d52a580bSJunchao Zhang B->ops->solve = MatSolve_SeqAIJHIPSPARSE;
219d52a580bSJunchao Zhang B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE;
220d52a580bSJunchao Zhang }
221d52a580bSJunchao Zhang }
222d52a580bSJunchao Zhang B->ops->matsolve = NULL;
223d52a580bSJunchao Zhang B->ops->matsolvetranspose = NULL;
224d52a580bSJunchao Zhang
225d52a580bSJunchao Zhang /* get the triangular factors */
226d52a580bSJunchao Zhang if (!hipsparsestruct->use_cpu_solve) PetscCall(MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(B));
227d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
228d52a580bSJunchao Zhang }
229d52a580bSJunchao Zhang
MatSetFromOptions_SeqAIJHIPSPARSE(Mat A,PetscOptionItems PetscOptionsObject)230d52a580bSJunchao Zhang static PetscErrorCode MatSetFromOptions_SeqAIJHIPSPARSE(Mat A, PetscOptionItems PetscOptionsObject)
231d52a580bSJunchao Zhang {
232d52a580bSJunchao Zhang MatHIPSPARSEStorageFormat format;
233d52a580bSJunchao Zhang PetscBool flg;
234d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
235d52a580bSJunchao Zhang
236d52a580bSJunchao Zhang PetscFunctionBegin;
237d52a580bSJunchao Zhang PetscOptionsHeadBegin(PetscOptionsObject, "SeqAIJHIPSPARSE options");
238d52a580bSJunchao Zhang if (A->factortype == MAT_FACTOR_NONE) {
239d52a580bSJunchao Zhang PetscCall(PetscOptionsEnum("-mat_hipsparse_mult_storage_format", "sets storage format of (seq)aijhipsparse gpu matrices for SpMV", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparsestruct->format, (PetscEnum *)&format, &flg));
240d52a580bSJunchao Zhang if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_MULT, format));
241d52a580bSJunchao Zhang PetscCall(PetscOptionsEnum("-mat_hipsparse_storage_format", "sets storage format of (seq)aijhipsparse gpu matrices for SpMV and TriSolve", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparsestruct->format, (PetscEnum *)&format, &flg));
242d52a580bSJunchao Zhang if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_ALL, format));
243d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-mat_hipsparse_use_cpu_solve", "Use CPU (I)LU solve", "MatHIPSPARSESetUseCPUSolve", hipsparsestruct->use_cpu_solve, &hipsparsestruct->use_cpu_solve, &flg));
244d52a580bSJunchao Zhang if (flg) PetscCall(MatHIPSPARSESetUseCPUSolve(A, hipsparsestruct->use_cpu_solve));
245d52a580bSJunchao Zhang PetscCall(
246d52a580bSJunchao Zhang PetscOptionsEnum("-mat_hipsparse_spmv_alg", "sets hipSPARSE algorithm used in sparse-mat dense-vector multiplication (SpMV)", "hipsparseSpMVAlg_t", MatHIPSPARSESpMVAlgorithms, (PetscEnum)hipsparsestruct->spmvAlg, (PetscEnum *)&hipsparsestruct->spmvAlg, &flg));
247d52a580bSJunchao Zhang /* If user did use this option, check its consistency with hipSPARSE, since PetscOptionsEnum() sets enum values based on their position in MatHIPSPARSESpMVAlgorithms[] */
248d52a580bSJunchao Zhang PetscCheck(!flg || HIPSPARSE_CSRMV_ALG1 == 2, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseSpMVAlg_t has been changed but PETSc has not been updated accordingly");
249d52a580bSJunchao Zhang PetscCall(
250d52a580bSJunchao Zhang PetscOptionsEnum("-mat_hipsparse_spmm_alg", "sets hipSPARSE algorithm used in sparse-mat dense-mat multiplication (SpMM)", "hipsparseSpMMAlg_t", MatHIPSPARSESpMMAlgorithms, (PetscEnum)hipsparsestruct->spmmAlg, (PetscEnum *)&hipsparsestruct->spmmAlg, &flg));
251d52a580bSJunchao Zhang PetscCheck(!flg || HIPSPARSE_SPMM_CSR_ALG1 == 4, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseSpMMAlg_t has been changed but PETSc has not been updated accordingly");
252d52a580bSJunchao Zhang /*
253d52a580bSJunchao Zhang PetscCall(PetscOptionsEnum("-mat_hipsparse_csr2csc_alg", "sets hipSPARSE algorithm used in converting CSR matrices to CSC matrices", "hipsparseCsr2CscAlg_t", MatHIPSPARSECsr2CscAlgorithms, (PetscEnum)hipsparsestruct->csr2cscAlg, (PetscEnum*)&hipsparsestruct->csr2cscAlg, &flg));
254d52a580bSJunchao Zhang PetscCheck(!flg || HIPSPARSE_CSR2CSC_ALG1 == 1, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseCsr2CscAlg_t has been changed but PETSc has not been updated accordingly");
255d52a580bSJunchao Zhang */
256d52a580bSJunchao Zhang }
257d52a580bSJunchao Zhang PetscOptionsHeadEnd();
258d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
259d52a580bSJunchao Zhang }
260d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(Mat A)261d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(Mat A)
262d52a580bSJunchao Zhang {
263d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
264d52a580bSJunchao Zhang PetscInt n = A->rmap->n;
265d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
266d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
267d52a580bSJunchao Zhang const PetscInt *ai = a->i, *aj = a->j, *vi;
268d52a580bSJunchao Zhang const MatScalar *aa = a->a, *v;
269d52a580bSJunchao Zhang PetscInt *AiLo, *AjLo;
270d52a580bSJunchao Zhang PetscInt i, nz, nzLower, offset, rowOffset;
271d52a580bSJunchao Zhang
272d52a580bSJunchao Zhang PetscFunctionBegin;
273d52a580bSJunchao Zhang if (!n) PetscFunctionReturn(PETSC_SUCCESS);
274d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
275d52a580bSJunchao Zhang try {
276d52a580bSJunchao Zhang /* first figure out the number of nonzeros in the lower triangular matrix including 1's on the diagonal. */
277d52a580bSJunchao Zhang nzLower = n + ai[n] - ai[1];
278d52a580bSJunchao Zhang if (!loTriFactor) {
279d52a580bSJunchao Zhang PetscScalar *AALo;
280d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AALo, nzLower * sizeof(PetscScalar)));
281d52a580bSJunchao Zhang
282d52a580bSJunchao Zhang /* Allocate Space for the lower triangular matrix */
283d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AiLo, (n + 1) * sizeof(PetscInt)));
284d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AjLo, nzLower * sizeof(PetscInt)));
285d52a580bSJunchao Zhang
286d52a580bSJunchao Zhang /* Fill the lower triangular matrix */
287d52a580bSJunchao Zhang AiLo[0] = (PetscInt)0;
288d52a580bSJunchao Zhang AiLo[n] = nzLower;
289d52a580bSJunchao Zhang AjLo[0] = (PetscInt)0;
290d52a580bSJunchao Zhang AALo[0] = (MatScalar)1.0;
291d52a580bSJunchao Zhang v = aa;
292d52a580bSJunchao Zhang vi = aj;
293d52a580bSJunchao Zhang offset = 1;
294d52a580bSJunchao Zhang rowOffset = 1;
295d52a580bSJunchao Zhang for (i = 1; i < n; i++) {
296d52a580bSJunchao Zhang nz = ai[i + 1] - ai[i];
297d52a580bSJunchao Zhang /* additional 1 for the term on the diagonal */
298d52a580bSJunchao Zhang AiLo[i] = rowOffset;
299d52a580bSJunchao Zhang rowOffset += nz + 1;
300d52a580bSJunchao Zhang
301d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AjLo[offset], vi, nz));
302d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AALo[offset], v, nz));
303d52a580bSJunchao Zhang offset += nz;
304d52a580bSJunchao Zhang AjLo[offset] = (PetscInt)i;
305d52a580bSJunchao Zhang AALo[offset] = (MatScalar)1.0;
306d52a580bSJunchao Zhang offset += 1;
307d52a580bSJunchao Zhang v += nz;
308d52a580bSJunchao Zhang vi += nz;
309d52a580bSJunchao Zhang }
310d52a580bSJunchao Zhang
311d52a580bSJunchao Zhang /* allocate space for the triangular factor information */
312d52a580bSJunchao Zhang PetscCall(PetscNew(&loTriFactor));
313d52a580bSJunchao Zhang loTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
314d52a580bSJunchao Zhang /* Create the matrix description */
315d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactor->descr));
316d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
317d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
318d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactor->descr, HIPSPARSE_FILL_MODE_LOWER));
319d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactor->descr, HIPSPARSE_DIAG_TYPE_UNIT));
320d52a580bSJunchao Zhang
321d52a580bSJunchao Zhang /* set the operation */
322d52a580bSJunchao Zhang loTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
323d52a580bSJunchao Zhang
324d52a580bSJunchao Zhang /* set the matrix */
325d52a580bSJunchao Zhang loTriFactor->csrMat = new CsrMatrix;
326d52a580bSJunchao Zhang loTriFactor->csrMat->num_rows = n;
327d52a580bSJunchao Zhang loTriFactor->csrMat->num_cols = n;
328d52a580bSJunchao Zhang loTriFactor->csrMat->num_entries = nzLower;
329d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets = new THRUSTINTARRAY32(n + 1);
330d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(nzLower);
331d52a580bSJunchao Zhang loTriFactor->csrMat->values = new THRUSTARRAY(nzLower);
332d52a580bSJunchao Zhang
333d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->assign(AiLo, AiLo + n + 1);
334d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices->assign(AjLo, AjLo + nzLower);
335d52a580bSJunchao Zhang loTriFactor->csrMat->values->assign(AALo, AALo + nzLower);
336d52a580bSJunchao Zhang
337d52a580bSJunchao Zhang /* Create the solve analysis information */
338d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
339d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactor->solveInfo));
340d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
341d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, &loTriFactor->solveBufferSize));
342d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&loTriFactor->solveBuffer, loTriFactor->solveBufferSize));
343d52a580bSJunchao Zhang
344d52a580bSJunchao Zhang /* perform the solve analysis */
345d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
346d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, loTriFactor->solvePolicy, loTriFactor->solveBuffer));
347d52a580bSJunchao Zhang
348d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
349d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
350d52a580bSJunchao Zhang
351d52a580bSJunchao Zhang /* assign the pointer */
352d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtr = loTriFactor;
353d52a580bSJunchao Zhang loTriFactor->AA_h = AALo;
354d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AiLo));
355d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AjLo));
356d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((n + 1 + nzLower) * sizeof(int) + nzLower * sizeof(PetscScalar)));
357d52a580bSJunchao Zhang } else { /* update values only */
358d52a580bSJunchao Zhang if (!loTriFactor->AA_h) PetscCallHIP(hipHostMalloc((void **)&loTriFactor->AA_h, nzLower * sizeof(PetscScalar)));
359d52a580bSJunchao Zhang /* Fill the lower triangular matrix */
360d52a580bSJunchao Zhang loTriFactor->AA_h[0] = 1.0;
361d52a580bSJunchao Zhang v = aa;
362d52a580bSJunchao Zhang vi = aj;
363d52a580bSJunchao Zhang offset = 1;
364d52a580bSJunchao Zhang for (i = 1; i < n; i++) {
365d52a580bSJunchao Zhang nz = ai[i + 1] - ai[i];
366d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&loTriFactor->AA_h[offset], v, nz));
367d52a580bSJunchao Zhang offset += nz;
368d52a580bSJunchao Zhang loTriFactor->AA_h[offset] = 1.0;
369d52a580bSJunchao Zhang offset += 1;
370d52a580bSJunchao Zhang v += nz;
371d52a580bSJunchao Zhang }
372d52a580bSJunchao Zhang loTriFactor->csrMat->values->assign(loTriFactor->AA_h, loTriFactor->AA_h + nzLower);
373d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(nzLower * sizeof(PetscScalar)));
374d52a580bSJunchao Zhang }
375d52a580bSJunchao Zhang } catch (char *ex) {
376d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
377d52a580bSJunchao Zhang }
378d52a580bSJunchao Zhang }
379d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
380d52a580bSJunchao Zhang }
381d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEBuildILUUpperTriMatrix(Mat A)382d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildILUUpperTriMatrix(Mat A)
383d52a580bSJunchao Zhang {
384d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
385d52a580bSJunchao Zhang PetscInt n = A->rmap->n;
386d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
387d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
388d52a580bSJunchao Zhang const PetscInt *aj = a->j, *adiag, *vi;
389d52a580bSJunchao Zhang const MatScalar *aa = a->a, *v;
390d52a580bSJunchao Zhang PetscInt *AiUp, *AjUp;
391d52a580bSJunchao Zhang PetscInt i, nz, nzUpper, offset;
392d52a580bSJunchao Zhang
393d52a580bSJunchao Zhang PetscFunctionBegin;
394d52a580bSJunchao Zhang if (!n) PetscFunctionReturn(PETSC_SUCCESS);
395d52a580bSJunchao Zhang PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, &adiag, NULL));
396d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
397d52a580bSJunchao Zhang try {
398d52a580bSJunchao Zhang /* next, figure out the number of nonzeros in the upper triangular matrix. */
399d52a580bSJunchao Zhang nzUpper = adiag[0] - adiag[n];
400d52a580bSJunchao Zhang if (!upTriFactor) {
401d52a580bSJunchao Zhang PetscScalar *AAUp;
402d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AAUp, nzUpper * sizeof(PetscScalar)));
403d52a580bSJunchao Zhang
404d52a580bSJunchao Zhang /* Allocate Space for the upper triangular matrix */
405d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AiUp, (n + 1) * sizeof(PetscInt)));
406d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AjUp, nzUpper * sizeof(PetscInt)));
407d52a580bSJunchao Zhang
408d52a580bSJunchao Zhang /* Fill the upper triangular matrix */
409d52a580bSJunchao Zhang AiUp[0] = (PetscInt)0;
410d52a580bSJunchao Zhang AiUp[n] = nzUpper;
411d52a580bSJunchao Zhang offset = nzUpper;
412d52a580bSJunchao Zhang for (i = n - 1; i >= 0; i--) {
413d52a580bSJunchao Zhang v = aa + adiag[i + 1] + 1;
414d52a580bSJunchao Zhang vi = aj + adiag[i + 1] + 1;
415d52a580bSJunchao Zhang nz = adiag[i] - adiag[i + 1] - 1; /* number of elements NOT on the diagonal */
416d52a580bSJunchao Zhang offset -= (nz + 1); /* decrement the offset */
417d52a580bSJunchao Zhang
418d52a580bSJunchao Zhang /* first, set the diagonal elements */
419d52a580bSJunchao Zhang AjUp[offset] = (PetscInt)i;
420d52a580bSJunchao Zhang AAUp[offset] = (MatScalar)1. / v[nz];
421d52a580bSJunchao Zhang AiUp[i] = AiUp[i + 1] - (nz + 1);
422d52a580bSJunchao Zhang
423d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AjUp[offset + 1], vi, nz));
424d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AAUp[offset + 1], v, nz));
425d52a580bSJunchao Zhang }
426d52a580bSJunchao Zhang
427d52a580bSJunchao Zhang /* allocate space for the triangular factor information */
428d52a580bSJunchao Zhang PetscCall(PetscNew(&upTriFactor));
429d52a580bSJunchao Zhang upTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
430d52a580bSJunchao Zhang
431d52a580bSJunchao Zhang /* Create the matrix description */
432d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactor->descr));
433d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
434d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
435d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER));
436d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactor->descr, HIPSPARSE_DIAG_TYPE_NON_UNIT));
437d52a580bSJunchao Zhang
438d52a580bSJunchao Zhang /* set the operation */
439d52a580bSJunchao Zhang upTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
440d52a580bSJunchao Zhang
441d52a580bSJunchao Zhang /* set the matrix */
442d52a580bSJunchao Zhang upTriFactor->csrMat = new CsrMatrix;
443d52a580bSJunchao Zhang upTriFactor->csrMat->num_rows = n;
444d52a580bSJunchao Zhang upTriFactor->csrMat->num_cols = n;
445d52a580bSJunchao Zhang upTriFactor->csrMat->num_entries = nzUpper;
446d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets = new THRUSTINTARRAY32(n + 1);
447d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(nzUpper);
448d52a580bSJunchao Zhang upTriFactor->csrMat->values = new THRUSTARRAY(nzUpper);
449d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + n + 1);
450d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices->assign(AjUp, AjUp + nzUpper);
451d52a580bSJunchao Zhang upTriFactor->csrMat->values->assign(AAUp, AAUp + nzUpper);
452d52a580bSJunchao Zhang
453d52a580bSJunchao Zhang /* Create the solve analysis information */
454d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
455d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactor->solveInfo));
456d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
457d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, &upTriFactor->solveBufferSize));
458d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&upTriFactor->solveBuffer, upTriFactor->solveBufferSize));
459d52a580bSJunchao Zhang
460d52a580bSJunchao Zhang /* perform the solve analysis */
461d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
462d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, upTriFactor->solvePolicy, upTriFactor->solveBuffer));
463d52a580bSJunchao Zhang
464d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
465d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
466d52a580bSJunchao Zhang
467d52a580bSJunchao Zhang /* assign the pointer */
468d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtr = upTriFactor;
469d52a580bSJunchao Zhang upTriFactor->AA_h = AAUp;
470d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AiUp));
471d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AjUp));
472d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((n + 1 + nzUpper) * sizeof(int) + nzUpper * sizeof(PetscScalar)));
473d52a580bSJunchao Zhang } else {
474d52a580bSJunchao Zhang if (!upTriFactor->AA_h) PetscCallHIP(hipHostMalloc((void **)&upTriFactor->AA_h, nzUpper * sizeof(PetscScalar)));
475d52a580bSJunchao Zhang /* Fill the upper triangular matrix */
476d52a580bSJunchao Zhang offset = nzUpper;
477d52a580bSJunchao Zhang for (i = n - 1; i >= 0; i--) {
478d52a580bSJunchao Zhang v = aa + adiag[i + 1] + 1;
479d52a580bSJunchao Zhang nz = adiag[i] - adiag[i + 1] - 1; /* number of elements NOT on the diagonal */
480d52a580bSJunchao Zhang offset -= (nz + 1); /* decrement the offset */
481d52a580bSJunchao Zhang
482d52a580bSJunchao Zhang /* first, set the diagonal elements */
483d52a580bSJunchao Zhang upTriFactor->AA_h[offset] = 1. / v[nz];
484d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&upTriFactor->AA_h[offset + 1], v, nz));
485d52a580bSJunchao Zhang }
486d52a580bSJunchao Zhang upTriFactor->csrMat->values->assign(upTriFactor->AA_h, upTriFactor->AA_h + nzUpper);
487d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(nzUpper * sizeof(PetscScalar)));
488d52a580bSJunchao Zhang }
489d52a580bSJunchao Zhang } catch (char *ex) {
490d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
491d52a580bSJunchao Zhang }
492d52a580bSJunchao Zhang }
493d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
494d52a580bSJunchao Zhang }
495d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat A)496d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat A)
497d52a580bSJunchao Zhang {
498d52a580bSJunchao Zhang PetscBool row_identity, col_identity;
499d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
500d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
501d52a580bSJunchao Zhang IS isrow = a->row, iscol = a->icol;
502d52a580bSJunchao Zhang PetscInt n = A->rmap->n;
503d52a580bSJunchao Zhang
504d52a580bSJunchao Zhang PetscFunctionBegin;
505d52a580bSJunchao Zhang PetscCheck(hipsparseTriFactors, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
506d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(A));
507d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEBuildILUUpperTriMatrix(A));
508d52a580bSJunchao Zhang
509d52a580bSJunchao Zhang if (!hipsparseTriFactors->workVector) hipsparseTriFactors->workVector = new THRUSTARRAY(n);
510d52a580bSJunchao Zhang hipsparseTriFactors->nnz = a->nz;
511d52a580bSJunchao Zhang
512d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_BOTH;
513d52a580bSJunchao Zhang /* lower triangular indices */
514d52a580bSJunchao Zhang PetscCall(ISIdentity(isrow, &row_identity));
515d52a580bSJunchao Zhang if (!row_identity && !hipsparseTriFactors->rpermIndices) {
516d52a580bSJunchao Zhang const PetscInt *r;
517d52a580bSJunchao Zhang
518d52a580bSJunchao Zhang PetscCall(ISGetIndices(isrow, &r));
519d52a580bSJunchao Zhang hipsparseTriFactors->rpermIndices = new THRUSTINTARRAY(n);
520d52a580bSJunchao Zhang hipsparseTriFactors->rpermIndices->assign(r, r + n);
521d52a580bSJunchao Zhang PetscCall(ISRestoreIndices(isrow, &r));
522d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
523d52a580bSJunchao Zhang }
524d52a580bSJunchao Zhang /* upper triangular indices */
525d52a580bSJunchao Zhang PetscCall(ISIdentity(iscol, &col_identity));
526d52a580bSJunchao Zhang if (!col_identity && !hipsparseTriFactors->cpermIndices) {
527d52a580bSJunchao Zhang const PetscInt *c;
528d52a580bSJunchao Zhang
529d52a580bSJunchao Zhang PetscCall(ISGetIndices(iscol, &c));
530d52a580bSJunchao Zhang hipsparseTriFactors->cpermIndices = new THRUSTINTARRAY(n);
531d52a580bSJunchao Zhang hipsparseTriFactors->cpermIndices->assign(c, c + n);
532d52a580bSJunchao Zhang PetscCall(ISRestoreIndices(iscol, &c));
533d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
534d52a580bSJunchao Zhang }
535d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
536d52a580bSJunchao Zhang }
537d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEBuildICCTriMatrices(Mat A)538d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildICCTriMatrices(Mat A)
539d52a580bSJunchao Zhang {
540d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
541d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
542d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
543d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
544d52a580bSJunchao Zhang PetscInt *AiUp, *AjUp;
545d52a580bSJunchao Zhang PetscScalar *AAUp;
546d52a580bSJunchao Zhang PetscScalar *AALo;
547d52a580bSJunchao Zhang PetscInt nzUpper = a->nz, n = A->rmap->n, i, offset, nz, j;
548d52a580bSJunchao Zhang Mat_SeqSBAIJ *b = (Mat_SeqSBAIJ *)A->data;
549d52a580bSJunchao Zhang const PetscInt *ai = b->i, *aj = b->j, *vj;
550d52a580bSJunchao Zhang const MatScalar *aa = b->a, *v;
551d52a580bSJunchao Zhang
552d52a580bSJunchao Zhang PetscFunctionBegin;
553d52a580bSJunchao Zhang if (!n) PetscFunctionReturn(PETSC_SUCCESS);
554d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
555d52a580bSJunchao Zhang try {
556d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AAUp, nzUpper * sizeof(PetscScalar)));
557d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AALo, nzUpper * sizeof(PetscScalar)));
558d52a580bSJunchao Zhang if (!upTriFactor && !loTriFactor) {
559d52a580bSJunchao Zhang /* Allocate Space for the upper triangular matrix */
560d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AiUp, (n + 1) * sizeof(PetscInt)));
561d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AjUp, nzUpper * sizeof(PetscInt)));
562d52a580bSJunchao Zhang
563d52a580bSJunchao Zhang /* Fill the upper triangular matrix */
564d52a580bSJunchao Zhang AiUp[0] = (PetscInt)0;
565d52a580bSJunchao Zhang AiUp[n] = nzUpper;
566d52a580bSJunchao Zhang offset = 0;
567d52a580bSJunchao Zhang for (i = 0; i < n; i++) {
568d52a580bSJunchao Zhang /* set the pointers */
569d52a580bSJunchao Zhang v = aa + ai[i];
570d52a580bSJunchao Zhang vj = aj + ai[i];
571d52a580bSJunchao Zhang nz = ai[i + 1] - ai[i] - 1; /* exclude diag[i] */
572d52a580bSJunchao Zhang
573d52a580bSJunchao Zhang /* first, set the diagonal elements */
574d52a580bSJunchao Zhang AjUp[offset] = (PetscInt)i;
575d52a580bSJunchao Zhang AAUp[offset] = (MatScalar)1.0 / v[nz];
576d52a580bSJunchao Zhang AiUp[i] = offset;
577d52a580bSJunchao Zhang AALo[offset] = (MatScalar)1.0 / v[nz];
578d52a580bSJunchao Zhang
579d52a580bSJunchao Zhang offset += 1;
580d52a580bSJunchao Zhang if (nz > 0) {
581d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AjUp[offset], vj, nz));
582d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AAUp[offset], v, nz));
583d52a580bSJunchao Zhang for (j = offset; j < offset + nz; j++) {
584d52a580bSJunchao Zhang AAUp[j] = -AAUp[j];
585d52a580bSJunchao Zhang AALo[j] = AAUp[j] / v[nz];
586d52a580bSJunchao Zhang }
587d52a580bSJunchao Zhang offset += nz;
588d52a580bSJunchao Zhang }
589d52a580bSJunchao Zhang }
590d52a580bSJunchao Zhang
591d52a580bSJunchao Zhang /* allocate space for the triangular factor information */
592d52a580bSJunchao Zhang PetscCall(PetscNew(&upTriFactor));
593d52a580bSJunchao Zhang upTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
594d52a580bSJunchao Zhang
595d52a580bSJunchao Zhang /* Create the matrix description */
596d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactor->descr));
597d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
598d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
599d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER));
600d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactor->descr, HIPSPARSE_DIAG_TYPE_UNIT));
601d52a580bSJunchao Zhang
602d52a580bSJunchao Zhang /* set the matrix */
603d52a580bSJunchao Zhang upTriFactor->csrMat = new CsrMatrix;
604d52a580bSJunchao Zhang upTriFactor->csrMat->num_rows = A->rmap->n;
605d52a580bSJunchao Zhang upTriFactor->csrMat->num_cols = A->cmap->n;
606d52a580bSJunchao Zhang upTriFactor->csrMat->num_entries = a->nz;
607d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets = new THRUSTINTARRAY32(A->rmap->n + 1);
608d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(a->nz);
609d52a580bSJunchao Zhang upTriFactor->csrMat->values = new THRUSTARRAY(a->nz);
610d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + A->rmap->n + 1);
611d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices->assign(AjUp, AjUp + a->nz);
612d52a580bSJunchao Zhang upTriFactor->csrMat->values->assign(AAUp, AAUp + a->nz);
613d52a580bSJunchao Zhang
614d52a580bSJunchao Zhang /* set the operation */
615d52a580bSJunchao Zhang upTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
616d52a580bSJunchao Zhang
617d52a580bSJunchao Zhang /* Create the solve analysis information */
618d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
619d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactor->solveInfo));
620d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
621d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, &upTriFactor->solveBufferSize));
622d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&upTriFactor->solveBuffer, upTriFactor->solveBufferSize));
623d52a580bSJunchao Zhang
624d52a580bSJunchao Zhang /* perform the solve analysis */
625d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
626d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, upTriFactor->solvePolicy, upTriFactor->solveBuffer));
627d52a580bSJunchao Zhang
628d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
629d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
630d52a580bSJunchao Zhang
631d52a580bSJunchao Zhang /* assign the pointer */
632d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtr = upTriFactor;
633d52a580bSJunchao Zhang
634d52a580bSJunchao Zhang /* allocate space for the triangular factor information */
635d52a580bSJunchao Zhang PetscCall(PetscNew(&loTriFactor));
636d52a580bSJunchao Zhang loTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
637d52a580bSJunchao Zhang
638d52a580bSJunchao Zhang /* Create the matrix description */
639d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactor->descr));
640d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
641d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
642d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER));
643d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactor->descr, HIPSPARSE_DIAG_TYPE_NON_UNIT));
644d52a580bSJunchao Zhang
645d52a580bSJunchao Zhang /* set the operation */
646d52a580bSJunchao Zhang loTriFactor->solveOp = HIPSPARSE_OPERATION_TRANSPOSE;
647d52a580bSJunchao Zhang
648d52a580bSJunchao Zhang /* set the matrix */
649d52a580bSJunchao Zhang loTriFactor->csrMat = new CsrMatrix;
650d52a580bSJunchao Zhang loTriFactor->csrMat->num_rows = A->rmap->n;
651d52a580bSJunchao Zhang loTriFactor->csrMat->num_cols = A->cmap->n;
652d52a580bSJunchao Zhang loTriFactor->csrMat->num_entries = a->nz;
653d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets = new THRUSTINTARRAY32(A->rmap->n + 1);
654d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(a->nz);
655d52a580bSJunchao Zhang loTriFactor->csrMat->values = new THRUSTARRAY(a->nz);
656d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + A->rmap->n + 1);
657d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices->assign(AjUp, AjUp + a->nz);
658d52a580bSJunchao Zhang loTriFactor->csrMat->values->assign(AALo, AALo + a->nz);
659d52a580bSJunchao Zhang
660d52a580bSJunchao Zhang /* Create the solve analysis information */
661d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
662d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactor->solveInfo));
663d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
664d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, &loTriFactor->solveBufferSize));
665d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&loTriFactor->solveBuffer, loTriFactor->solveBufferSize));
666d52a580bSJunchao Zhang
667d52a580bSJunchao Zhang /* perform the solve analysis */
668d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
669d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, loTriFactor->solvePolicy, loTriFactor->solveBuffer));
670d52a580bSJunchao Zhang
671d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
672d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
673d52a580bSJunchao Zhang
674d52a580bSJunchao Zhang /* assign the pointer */
675d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtr = loTriFactor;
676d52a580bSJunchao Zhang
677d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(2 * (((A->rmap->n + 1) + (a->nz)) * sizeof(int) + (a->nz) * sizeof(PetscScalar))));
678d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AiUp));
679d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AjUp));
680d52a580bSJunchao Zhang } else {
681d52a580bSJunchao Zhang /* Fill the upper triangular matrix */
682d52a580bSJunchao Zhang offset = 0;
683d52a580bSJunchao Zhang for (i = 0; i < n; i++) {
684d52a580bSJunchao Zhang /* set the pointers */
685d52a580bSJunchao Zhang v = aa + ai[i];
686d52a580bSJunchao Zhang nz = ai[i + 1] - ai[i] - 1; /* exclude diag[i] */
687d52a580bSJunchao Zhang
688d52a580bSJunchao Zhang /* first, set the diagonal elements */
689d52a580bSJunchao Zhang AAUp[offset] = 1.0 / v[nz];
690d52a580bSJunchao Zhang AALo[offset] = 1.0 / v[nz];
691d52a580bSJunchao Zhang
692d52a580bSJunchao Zhang offset += 1;
693d52a580bSJunchao Zhang if (nz > 0) {
694d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AAUp[offset], v, nz));
695d52a580bSJunchao Zhang for (j = offset; j < offset + nz; j++) {
696d52a580bSJunchao Zhang AAUp[j] = -AAUp[j];
697d52a580bSJunchao Zhang AALo[j] = AAUp[j] / v[nz];
698d52a580bSJunchao Zhang }
699d52a580bSJunchao Zhang offset += nz;
700d52a580bSJunchao Zhang }
701d52a580bSJunchao Zhang }
702d52a580bSJunchao Zhang PetscCheck(upTriFactor, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
703d52a580bSJunchao Zhang PetscCheck(loTriFactor, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
704d52a580bSJunchao Zhang upTriFactor->csrMat->values->assign(AAUp, AAUp + a->nz);
705d52a580bSJunchao Zhang loTriFactor->csrMat->values->assign(AALo, AALo + a->nz);
706d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(2 * (a->nz) * sizeof(PetscScalar)));
707d52a580bSJunchao Zhang }
708d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AAUp));
709d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AALo));
710d52a580bSJunchao Zhang } catch (char *ex) {
711d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
712d52a580bSJunchao Zhang }
713d52a580bSJunchao Zhang }
714d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
715d52a580bSJunchao Zhang }
716d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(Mat A)717d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(Mat A)
718d52a580bSJunchao Zhang {
719d52a580bSJunchao Zhang PetscBool perm_identity;
720d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
721d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
722d52a580bSJunchao Zhang IS ip = a->row;
723d52a580bSJunchao Zhang PetscInt n = A->rmap->n;
724d52a580bSJunchao Zhang
725d52a580bSJunchao Zhang PetscFunctionBegin;
726d52a580bSJunchao Zhang PetscCheck(hipsparseTriFactors, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
727d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEBuildICCTriMatrices(A));
728d52a580bSJunchao Zhang if (!hipsparseTriFactors->workVector) hipsparseTriFactors->workVector = new THRUSTARRAY(n);
729d52a580bSJunchao Zhang hipsparseTriFactors->nnz = (a->nz - n) * 2 + n;
730d52a580bSJunchao Zhang
731d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_BOTH;
732d52a580bSJunchao Zhang /* lower triangular indices */
733d52a580bSJunchao Zhang PetscCall(ISIdentity(ip, &perm_identity));
734d52a580bSJunchao Zhang if (!perm_identity) {
735d52a580bSJunchao Zhang IS iip;
736d52a580bSJunchao Zhang const PetscInt *irip, *rip;
737d52a580bSJunchao Zhang
738d52a580bSJunchao Zhang PetscCall(ISInvertPermutation(ip, PETSC_DECIDE, &iip));
739d52a580bSJunchao Zhang PetscCall(ISGetIndices(iip, &irip));
740d52a580bSJunchao Zhang PetscCall(ISGetIndices(ip, &rip));
741d52a580bSJunchao Zhang hipsparseTriFactors->rpermIndices = new THRUSTINTARRAY(n);
742d52a580bSJunchao Zhang hipsparseTriFactors->cpermIndices = new THRUSTINTARRAY(n);
743d52a580bSJunchao Zhang hipsparseTriFactors->rpermIndices->assign(rip, rip + n);
744d52a580bSJunchao Zhang hipsparseTriFactors->cpermIndices->assign(irip, irip + n);
745d52a580bSJunchao Zhang PetscCall(ISRestoreIndices(iip, &irip));
746d52a580bSJunchao Zhang PetscCall(ISDestroy(&iip));
747d52a580bSJunchao Zhang PetscCall(ISRestoreIndices(ip, &rip));
748d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(2. * n * sizeof(PetscInt)));
749d52a580bSJunchao Zhang }
750d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
751d52a580bSJunchao Zhang }
752d52a580bSJunchao Zhang
MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat B,Mat A,const MatFactorInfo * info)753d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat B, Mat A, const MatFactorInfo *info)
754d52a580bSJunchao Zhang {
755d52a580bSJunchao Zhang PetscBool perm_identity;
756d52a580bSJunchao Zhang Mat_SeqAIJ *b = (Mat_SeqAIJ *)B->data;
757d52a580bSJunchao Zhang IS ip = b->row;
758d52a580bSJunchao Zhang
759d52a580bSJunchao Zhang PetscFunctionBegin;
760d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
761d52a580bSJunchao Zhang PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info));
762d52a580bSJunchao Zhang B->offloadmask = PETSC_OFFLOAD_CPU;
763d52a580bSJunchao Zhang /* determine which version of MatSolve needs to be used. */
764d52a580bSJunchao Zhang PetscCall(ISIdentity(ip, &perm_identity));
765d52a580bSJunchao Zhang if (perm_identity) {
766d52a580bSJunchao Zhang B->ops->solve = MatSolve_SeqAIJHIPSPARSE_NaturalOrdering;
767d52a580bSJunchao Zhang B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering;
768d52a580bSJunchao Zhang B->ops->matsolve = NULL;
769d52a580bSJunchao Zhang B->ops->matsolvetranspose = NULL;
770d52a580bSJunchao Zhang } else {
771d52a580bSJunchao Zhang B->ops->solve = MatSolve_SeqAIJHIPSPARSE;
772d52a580bSJunchao Zhang B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE;
773d52a580bSJunchao Zhang B->ops->matsolve = NULL;
774d52a580bSJunchao Zhang B->ops->matsolvetranspose = NULL;
775d52a580bSJunchao Zhang }
776d52a580bSJunchao Zhang
777d52a580bSJunchao Zhang /* get the triangular factors */
778d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(B));
779d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
780d52a580bSJunchao Zhang }
781d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(Mat A)782d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(Mat A)
783d52a580bSJunchao Zhang {
784d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
785d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
786d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
787d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT;
788d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT;
789d52a580bSJunchao Zhang hipsparseIndexBase_t indexBase;
790d52a580bSJunchao Zhang hipsparseMatrixType_t matrixType;
791d52a580bSJunchao Zhang hipsparseFillMode_t fillMode;
792d52a580bSJunchao Zhang hipsparseDiagType_t diagType;
793d52a580bSJunchao Zhang
794d52a580bSJunchao Zhang PetscFunctionBegin;
795d52a580bSJunchao Zhang /* allocate space for the transpose of the lower triangular factor */
796d52a580bSJunchao Zhang PetscCall(PetscNew(&loTriFactorT));
797d52a580bSJunchao Zhang loTriFactorT->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
798d52a580bSJunchao Zhang
799d52a580bSJunchao Zhang /* set the matrix descriptors of the lower triangular factor */
800d52a580bSJunchao Zhang matrixType = hipsparseGetMatType(loTriFactor->descr);
801d52a580bSJunchao Zhang indexBase = hipsparseGetMatIndexBase(loTriFactor->descr);
802d52a580bSJunchao Zhang fillMode = hipsparseGetMatFillMode(loTriFactor->descr) == HIPSPARSE_FILL_MODE_UPPER ? HIPSPARSE_FILL_MODE_LOWER : HIPSPARSE_FILL_MODE_UPPER;
803d52a580bSJunchao Zhang diagType = hipsparseGetMatDiagType(loTriFactor->descr);
804d52a580bSJunchao Zhang
805d52a580bSJunchao Zhang /* Create the matrix description */
806d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactorT->descr));
807d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactorT->descr, indexBase));
808d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactorT->descr, matrixType));
809d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactorT->descr, fillMode));
810d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactorT->descr, diagType));
811d52a580bSJunchao Zhang
812d52a580bSJunchao Zhang /* set the operation */
813d52a580bSJunchao Zhang loTriFactorT->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
814d52a580bSJunchao Zhang
815d52a580bSJunchao Zhang /* allocate GPU space for the CSC of the lower triangular factor*/
816d52a580bSJunchao Zhang loTriFactorT->csrMat = new CsrMatrix;
817d52a580bSJunchao Zhang loTriFactorT->csrMat->num_rows = loTriFactor->csrMat->num_cols;
818d52a580bSJunchao Zhang loTriFactorT->csrMat->num_cols = loTriFactor->csrMat->num_rows;
819d52a580bSJunchao Zhang loTriFactorT->csrMat->num_entries = loTriFactor->csrMat->num_entries;
820d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets = new THRUSTINTARRAY32(loTriFactorT->csrMat->num_rows + 1);
821d52a580bSJunchao Zhang loTriFactorT->csrMat->column_indices = new THRUSTINTARRAY32(loTriFactorT->csrMat->num_entries);
822d52a580bSJunchao Zhang loTriFactorT->csrMat->values = new THRUSTARRAY(loTriFactorT->csrMat->num_entries);
823d52a580bSJunchao Zhang
824d52a580bSJunchao Zhang /* compute the transpose of the lower triangular factor, i.e. the CSC */
825d52a580bSJunchao Zhang /* Csr2cscEx2 is not implemented in ROCm-5.2.0 and is planned for implementation in hipsparse with future releases of ROCm
826d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0)
827d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCsr2cscEx2_bufferSize(hipsparseTriFactors->handle, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_cols, loTriFactor->csrMat->num_entries, loTriFactor->csrMat->values->data().get(),
828d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactorT->csrMat->values->data().get(), loTriFactorT->csrMat->row_offsets->data().get(),
829d52a580bSJunchao Zhang loTriFactorT->csrMat->column_indices->data().get(), hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, &loTriFactor->csr2cscBufferSize));
830d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&loTriFactor->csr2cscBuffer, loTriFactor->csr2cscBufferSize));
831d52a580bSJunchao Zhang #endif
832d52a580bSJunchao Zhang */
833d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
834d52a580bSJunchao Zhang
835d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparseTriFactors->handle, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_cols, loTriFactor->csrMat->num_entries, loTriFactor->csrMat->values->data().get(), loTriFactor->csrMat->row_offsets->data().get(),
836d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices->data().get(), loTriFactorT->csrMat->values->data().get(),
837d52a580bSJunchao Zhang #if 0 /* when Csr2cscEx2 is implemented in hipSparse PETSC_PKG_HIP_VERSION_GE(5, 2, 0)*/
838d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(),
839d52a580bSJunchao Zhang hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, loTriFactor->csr2cscBuffer));
840d52a580bSJunchao Zhang #else
841d52a580bSJunchao Zhang loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->csrMat->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
842d52a580bSJunchao Zhang #endif
843d52a580bSJunchao Zhang
844d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
845d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
846d52a580bSJunchao Zhang
847d52a580bSJunchao Zhang /* Create the solve analysis information */
848d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
849d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactorT->solveInfo));
850d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
851d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, &loTriFactorT->solveBufferSize));
852d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&loTriFactorT->solveBuffer, loTriFactorT->solveBufferSize));
853d52a580bSJunchao Zhang
854d52a580bSJunchao Zhang /* perform the solve analysis */
855d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
856d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer));
857d52a580bSJunchao Zhang
858d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
859d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
860d52a580bSJunchao Zhang
861d52a580bSJunchao Zhang /* assign the pointer */
862d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtrTranspose = loTriFactorT;
863d52a580bSJunchao Zhang
864d52a580bSJunchao Zhang /*********************************************/
865d52a580bSJunchao Zhang /* Now the Transpose of the Upper Tri Factor */
866d52a580bSJunchao Zhang /*********************************************/
867d52a580bSJunchao Zhang
868d52a580bSJunchao Zhang /* allocate space for the transpose of the upper triangular factor */
869d52a580bSJunchao Zhang PetscCall(PetscNew(&upTriFactorT));
870d52a580bSJunchao Zhang upTriFactorT->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
871d52a580bSJunchao Zhang
872d52a580bSJunchao Zhang /* set the matrix descriptors of the upper triangular factor */
873d52a580bSJunchao Zhang matrixType = hipsparseGetMatType(upTriFactor->descr);
874d52a580bSJunchao Zhang indexBase = hipsparseGetMatIndexBase(upTriFactor->descr);
875d52a580bSJunchao Zhang fillMode = hipsparseGetMatFillMode(upTriFactor->descr) == HIPSPARSE_FILL_MODE_UPPER ? HIPSPARSE_FILL_MODE_LOWER : HIPSPARSE_FILL_MODE_UPPER;
876d52a580bSJunchao Zhang diagType = hipsparseGetMatDiagType(upTriFactor->descr);
877d52a580bSJunchao Zhang
878d52a580bSJunchao Zhang /* Create the matrix description */
879d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactorT->descr));
880d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactorT->descr, indexBase));
881d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactorT->descr, matrixType));
882d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactorT->descr, fillMode));
883d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactorT->descr, diagType));
884d52a580bSJunchao Zhang
885d52a580bSJunchao Zhang /* set the operation */
886d52a580bSJunchao Zhang upTriFactorT->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
887d52a580bSJunchao Zhang
888d52a580bSJunchao Zhang /* allocate GPU space for the CSC of the upper triangular factor*/
889d52a580bSJunchao Zhang upTriFactorT->csrMat = new CsrMatrix;
890d52a580bSJunchao Zhang upTriFactorT->csrMat->num_rows = upTriFactor->csrMat->num_cols;
891d52a580bSJunchao Zhang upTriFactorT->csrMat->num_cols = upTriFactor->csrMat->num_rows;
892d52a580bSJunchao Zhang upTriFactorT->csrMat->num_entries = upTriFactor->csrMat->num_entries;
893d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets = new THRUSTINTARRAY32(upTriFactorT->csrMat->num_rows + 1);
894d52a580bSJunchao Zhang upTriFactorT->csrMat->column_indices = new THRUSTINTARRAY32(upTriFactorT->csrMat->num_entries);
895d52a580bSJunchao Zhang upTriFactorT->csrMat->values = new THRUSTARRAY(upTriFactorT->csrMat->num_entries);
896d52a580bSJunchao Zhang
897d52a580bSJunchao Zhang /* compute the transpose of the upper triangular factor, i.e. the CSC */
898d52a580bSJunchao Zhang /* Csr2cscEx2 is not implemented in ROCm-5.2.0 and is planned for implementation in hipsparse with future releases of ROCm
899d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0)
900d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCsr2cscEx2_bufferSize(hipsparseTriFactors->handle, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_cols, upTriFactor->csrMat->num_entries, upTriFactor->csrMat->values->data().get(),
901d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactorT->csrMat->values->data().get(), upTriFactorT->csrMat->row_offsets->data().get(),
902d52a580bSJunchao Zhang upTriFactorT->csrMat->column_indices->data().get(), hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, &upTriFactor->csr2cscBufferSize));
903d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&upTriFactor->csr2cscBuffer, upTriFactor->csr2cscBufferSize));
904d52a580bSJunchao Zhang #endif
905d52a580bSJunchao Zhang */
906d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
907d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparseTriFactors->handle, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_cols, upTriFactor->csrMat->num_entries, upTriFactor->csrMat->values->data().get(), upTriFactor->csrMat->row_offsets->data().get(),
908d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices->data().get(), upTriFactorT->csrMat->values->data().get(),
909d52a580bSJunchao Zhang #if 0 /* when Csr2cscEx2 is implemented in hipSparse PETSC_PKG_HIP_VERSION_GE(5, 2, 0)*/
910d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(),
911d52a580bSJunchao Zhang hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, upTriFactor->csr2cscBuffer));
912d52a580bSJunchao Zhang #else
913d52a580bSJunchao Zhang upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->csrMat->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
914d52a580bSJunchao Zhang #endif
915d52a580bSJunchao Zhang
916d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
917d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
918d52a580bSJunchao Zhang
919d52a580bSJunchao Zhang /* Create the solve analysis information */
920d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
921d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactorT->solveInfo));
922d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
923d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, &upTriFactorT->solveBufferSize));
924d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&upTriFactorT->solveBuffer, upTriFactorT->solveBufferSize));
925d52a580bSJunchao Zhang
926d52a580bSJunchao Zhang /* perform the solve analysis */
927d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
928d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, upTriFactorT->solvePolicy, upTriFactorT->solveBuffer));
929d52a580bSJunchao Zhang
930d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
931d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
932d52a580bSJunchao Zhang
933d52a580bSJunchao Zhang /* assign the pointer */
934d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtrTranspose = upTriFactorT;
935d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
936d52a580bSJunchao Zhang }
937d52a580bSJunchao Zhang
938d52a580bSJunchao Zhang struct PetscScalarToPetscInt {
operator ()PetscScalarToPetscInt939d52a580bSJunchao Zhang __host__ __device__ PetscInt operator()(PetscScalar s) { return (PetscInt)PetscRealPart(s); }
940d52a580bSJunchao Zhang };
941d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEFormExplicitTranspose(Mat A)942d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEFormExplicitTranspose(Mat A)
943d52a580bSJunchao Zhang {
944d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
945d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *matstruct, *matstructT;
946d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
947d52a580bSJunchao Zhang hipsparseIndexBase_t indexBase;
948d52a580bSJunchao Zhang
949d52a580bSJunchao Zhang PetscFunctionBegin;
950d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
951d52a580bSJunchao Zhang matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat;
952d52a580bSJunchao Zhang PetscCheck(matstruct, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing mat struct");
953d52a580bSJunchao Zhang matstructT = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->matTranspose;
954d52a580bSJunchao Zhang PetscCheck(!A->transupdated || matstructT, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing matTranspose struct");
955d52a580bSJunchao Zhang if (A->transupdated) PetscFunctionReturn(PETSC_SUCCESS);
956d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
957d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
958d52a580bSJunchao Zhang if (hipsparsestruct->format != MAT_HIPSPARSE_CSR) PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE));
959d52a580bSJunchao Zhang if (!hipsparsestruct->matTranspose) { /* create hipsparse matrix */
960d52a580bSJunchao Zhang matstructT = new Mat_SeqAIJHIPSPARSEMultStruct;
961d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&matstructT->descr));
962d52a580bSJunchao Zhang indexBase = hipsparseGetMatIndexBase(matstruct->descr);
963d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(matstructT->descr, indexBase));
964d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(matstructT->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
965d52a580bSJunchao Zhang
966d52a580bSJunchao Zhang /* set alpha and beta */
967d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstructT->alpha_one, sizeof(PetscScalar)));
968d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstructT->beta_zero, sizeof(PetscScalar)));
969d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstructT->beta_one, sizeof(PetscScalar)));
970d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstructT->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
971d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstructT->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
972d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstructT->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
973d52a580bSJunchao Zhang
974d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
975d52a580bSJunchao Zhang CsrMatrix *matrixT = new CsrMatrix;
976d52a580bSJunchao Zhang matstructT->mat = matrixT;
977d52a580bSJunchao Zhang matrixT->num_rows = A->cmap->n;
978d52a580bSJunchao Zhang matrixT->num_cols = A->rmap->n;
979d52a580bSJunchao Zhang matrixT->num_entries = a->nz;
980d52a580bSJunchao Zhang matrixT->row_offsets = new THRUSTINTARRAY32(matrixT->num_rows + 1);
981d52a580bSJunchao Zhang matrixT->column_indices = new THRUSTINTARRAY32(a->nz);
982d52a580bSJunchao Zhang matrixT->values = new THRUSTARRAY(a->nz);
983d52a580bSJunchao Zhang
984d52a580bSJunchao Zhang if (!hipsparsestruct->rowoffsets_gpu) hipsparsestruct->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
985d52a580bSJunchao Zhang hipsparsestruct->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
986d52a580bSJunchao Zhang
987d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&matstructT->matDescr, matrixT->num_rows, matrixT->num_cols, matrixT->num_entries, matrixT->row_offsets->data().get(), matrixT->column_indices->data().get(), matrixT->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx type due to THRUSTINTARRAY32 */
988d52a580bSJunchao Zhang indexBase, hipsparse_scalartype));
989d52a580bSJunchao Zhang } else if (hipsparsestruct->format == MAT_HIPSPARSE_ELL || hipsparsestruct->format == MAT_HIPSPARSE_HYB) {
990d52a580bSJunchao Zhang CsrMatrix *temp = new CsrMatrix;
991d52a580bSJunchao Zhang CsrMatrix *tempT = new CsrMatrix;
992d52a580bSJunchao Zhang /* First convert HYB to CSR */
993d52a580bSJunchao Zhang temp->num_rows = A->rmap->n;
994d52a580bSJunchao Zhang temp->num_cols = A->cmap->n;
995d52a580bSJunchao Zhang temp->num_entries = a->nz;
996d52a580bSJunchao Zhang temp->row_offsets = new THRUSTINTARRAY32(A->rmap->n + 1);
997d52a580bSJunchao Zhang temp->column_indices = new THRUSTINTARRAY32(a->nz);
998d52a580bSJunchao Zhang temp->values = new THRUSTARRAY(a->nz);
999d52a580bSJunchao Zhang
1000d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_hyb2csr(hipsparsestruct->handle, matstruct->descr, (hipsparseHybMat_t)matstruct->mat, temp->values->data().get(), temp->row_offsets->data().get(), temp->column_indices->data().get()));
1001d52a580bSJunchao Zhang
1002d52a580bSJunchao Zhang /* Next, convert CSR to CSC (i.e. the matrix transpose) */
1003d52a580bSJunchao Zhang tempT->num_rows = A->rmap->n;
1004d52a580bSJunchao Zhang tempT->num_cols = A->cmap->n;
1005d52a580bSJunchao Zhang tempT->num_entries = a->nz;
1006d52a580bSJunchao Zhang tempT->row_offsets = new THRUSTINTARRAY32(A->rmap->n + 1);
1007d52a580bSJunchao Zhang tempT->column_indices = new THRUSTINTARRAY32(a->nz);
1008d52a580bSJunchao Zhang tempT->values = new THRUSTARRAY(a->nz);
1009d52a580bSJunchao Zhang
1010d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparsestruct->handle, temp->num_rows, temp->num_cols, temp->num_entries, temp->values->data().get(), temp->row_offsets->data().get(), temp->column_indices->data().get(), tempT->values->data().get(),
1011d52a580bSJunchao Zhang tempT->column_indices->data().get(), tempT->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
1012d52a580bSJunchao Zhang
1013d52a580bSJunchao Zhang /* Last, convert CSC to HYB */
1014d52a580bSJunchao Zhang hipsparseHybMat_t hybMat;
1015d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateHybMat(&hybMat));
1016d52a580bSJunchao Zhang hipsparseHybPartition_t partition = hipsparsestruct->format == MAT_HIPSPARSE_ELL ? HIPSPARSE_HYB_PARTITION_MAX : HIPSPARSE_HYB_PARTITION_AUTO;
1017d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2hyb(hipsparsestruct->handle, A->rmap->n, A->cmap->n, matstructT->descr, tempT->values->data().get(), tempT->row_offsets->data().get(), tempT->column_indices->data().get(), hybMat, 0, partition));
1018d52a580bSJunchao Zhang
1019d52a580bSJunchao Zhang /* assign the pointer */
1020d52a580bSJunchao Zhang matstructT->mat = hybMat;
1021d52a580bSJunchao Zhang A->transupdated = PETSC_TRUE;
1022d52a580bSJunchao Zhang /* delete temporaries */
1023d52a580bSJunchao Zhang if (tempT) {
1024d52a580bSJunchao Zhang if (tempT->values) delete (THRUSTARRAY *)tempT->values;
1025d52a580bSJunchao Zhang if (tempT->column_indices) delete (THRUSTINTARRAY32 *)tempT->column_indices;
1026d52a580bSJunchao Zhang if (tempT->row_offsets) delete (THRUSTINTARRAY32 *)tempT->row_offsets;
1027d52a580bSJunchao Zhang delete (CsrMatrix *)tempT;
1028d52a580bSJunchao Zhang }
1029d52a580bSJunchao Zhang if (temp) {
1030d52a580bSJunchao Zhang if (temp->values) delete (THRUSTARRAY *)temp->values;
1031d52a580bSJunchao Zhang if (temp->column_indices) delete (THRUSTINTARRAY32 *)temp->column_indices;
1032d52a580bSJunchao Zhang if (temp->row_offsets) delete (THRUSTINTARRAY32 *)temp->row_offsets;
1033d52a580bSJunchao Zhang delete (CsrMatrix *)temp;
1034d52a580bSJunchao Zhang }
1035d52a580bSJunchao Zhang }
1036d52a580bSJunchao Zhang }
1037d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) { /* transpose mat struct may be already present, update data */
1038d52a580bSJunchao Zhang CsrMatrix *matrix = (CsrMatrix *)matstruct->mat;
1039d52a580bSJunchao Zhang CsrMatrix *matrixT = (CsrMatrix *)matstructT->mat;
1040d52a580bSJunchao Zhang PetscCheck(matrix, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix");
1041d52a580bSJunchao Zhang PetscCheck(matrix->row_offsets, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix rows");
1042d52a580bSJunchao Zhang PetscCheck(matrix->column_indices, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix cols");
1043d52a580bSJunchao Zhang PetscCheck(matrix->values, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix values");
1044d52a580bSJunchao Zhang PetscCheck(matrixT, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT");
1045d52a580bSJunchao Zhang PetscCheck(matrixT->row_offsets, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT rows");
1046d52a580bSJunchao Zhang PetscCheck(matrixT->column_indices, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT cols");
1047d52a580bSJunchao Zhang PetscCheck(matrixT->values, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT values");
1048d52a580bSJunchao Zhang if (!hipsparsestruct->rowoffsets_gpu) { /* this may be absent when we did not construct the transpose with csr2csc */
1049d52a580bSJunchao Zhang hipsparsestruct->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
1050d52a580bSJunchao Zhang hipsparsestruct->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
1051d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
1052d52a580bSJunchao Zhang }
1053d52a580bSJunchao Zhang if (!hipsparsestruct->csr2csc_i) {
1054d52a580bSJunchao Zhang THRUSTARRAY csr2csc_a(matrix->num_entries);
1055d52a580bSJunchao Zhang PetscCallThrust(thrust::sequence(thrust::device, csr2csc_a.begin(), csr2csc_a.end(), 0.0));
1056d52a580bSJunchao Zhang
1057d52a580bSJunchao Zhang indexBase = hipsparseGetMatIndexBase(matstruct->descr);
1058d52a580bSJunchao Zhang if (matrix->num_entries) {
1059d52a580bSJunchao Zhang /* This routine is known to give errors with CUDA-11, but works fine with CUDA-10
1060d52a580bSJunchao Zhang Need to verify this for ROCm.
1061d52a580bSJunchao Zhang */
1062d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparsestruct->handle, A->rmap->n, A->cmap->n, matrix->num_entries, csr2csc_a.data().get(), hipsparsestruct->rowoffsets_gpu->data().get(), matrix->column_indices->data().get(), matrixT->values->data().get(),
1063d52a580bSJunchao Zhang matrixT->column_indices->data().get(), matrixT->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
1064d52a580bSJunchao Zhang } else {
1065d52a580bSJunchao Zhang matrixT->row_offsets->assign(matrixT->row_offsets->size(), indexBase);
1066d52a580bSJunchao Zhang }
1067d52a580bSJunchao Zhang
1068d52a580bSJunchao Zhang hipsparsestruct->csr2csc_i = new THRUSTINTARRAY(matrix->num_entries);
1069d52a580bSJunchao Zhang PetscCallThrust(thrust::transform(thrust::device, matrixT->values->begin(), matrixT->values->end(), hipsparsestruct->csr2csc_i->begin(), PetscScalarToPetscInt()));
1070d52a580bSJunchao Zhang }
1071d52a580bSJunchao Zhang PetscCallThrust(
1072d52a580bSJunchao Zhang thrust::copy(thrust::device, thrust::make_permutation_iterator(matrix->values->begin(), hipsparsestruct->csr2csc_i->begin()), thrust::make_permutation_iterator(matrix->values->begin(), hipsparsestruct->csr2csc_i->end()), matrixT->values->begin()));
1073d52a580bSJunchao Zhang }
1074d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
1075d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
1076d52a580bSJunchao Zhang /* the compressed row indices is not used for matTranspose */
1077d52a580bSJunchao Zhang matstructT->cprowIndices = NULL;
1078d52a580bSJunchao Zhang /* assign the pointer */
1079d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSE *)A->spptr)->matTranspose = matstructT;
1080d52a580bSJunchao Zhang A->transupdated = PETSC_TRUE;
1081d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1082d52a580bSJunchao Zhang }
1083d52a580bSJunchao Zhang
1084d52a580bSJunchao Zhang /* Why do we need to analyze the transposed matrix again? Can't we just use op(A) = HIPSPARSE_OPERATION_TRANSPOSE in MatSolve_SeqAIJHIPSPARSE? */
MatSolveTranspose_SeqAIJHIPSPARSE(Mat A,Vec bb,Vec xx)1085d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE(Mat A, Vec bb, Vec xx)
1086d52a580bSJunchao Zhang {
1087d52a580bSJunchao Zhang PetscInt n = xx->map->n;
1088d52a580bSJunchao Zhang const PetscScalar *barray;
1089d52a580bSJunchao Zhang PetscScalar *xarray;
1090d52a580bSJunchao Zhang thrust::device_ptr<const PetscScalar> bGPU;
1091d52a580bSJunchao Zhang thrust::device_ptr<PetscScalar> xGPU;
1092d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1093d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1094d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1095d52a580bSJunchao Zhang THRUSTARRAY *tempGPU = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1096d52a580bSJunchao Zhang
1097d52a580bSJunchao Zhang PetscFunctionBegin;
1098d52a580bSJunchao Zhang /* Analyze the matrix and create the transpose ... on the fly */
1099d52a580bSJunchao Zhang if (!loTriFactorT && !upTriFactorT) {
1100d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(A));
1101d52a580bSJunchao Zhang loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1102d52a580bSJunchao Zhang upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1103d52a580bSJunchao Zhang }
1104d52a580bSJunchao Zhang
1105d52a580bSJunchao Zhang /* Get the GPU pointers */
1106d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1107d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(bb, &barray));
1108d52a580bSJunchao Zhang xGPU = thrust::device_pointer_cast(xarray);
1109d52a580bSJunchao Zhang bGPU = thrust::device_pointer_cast(barray);
1110d52a580bSJunchao Zhang
1111d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
1112d52a580bSJunchao Zhang /* First, reorder with the row permutation */
1113d52a580bSJunchao Zhang thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->begin()), thrust::make_permutation_iterator(bGPU + n, hipsparseTriFactors->rpermIndices->end()), xGPU);
1114d52a580bSJunchao Zhang
1115d52a580bSJunchao Zhang /* First, solve U */
1116d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
1117d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, xarray, tempGPU->data().get(), upTriFactorT->solvePolicy, upTriFactorT->solveBuffer));
1118d52a580bSJunchao Zhang
1119d52a580bSJunchao Zhang /* Then, solve L */
1120d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
1121d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, tempGPU->data().get(), xarray, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer));
1122d52a580bSJunchao Zhang
1123d52a580bSJunchao Zhang /* Last, copy the solution, xGPU, into a temporary with the column permutation ... can't be done in place. */
1124d52a580bSJunchao Zhang thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(xGPU, hipsparseTriFactors->cpermIndices->begin()), thrust::make_permutation_iterator(xGPU + n, hipsparseTriFactors->cpermIndices->end()), tempGPU->begin());
1125d52a580bSJunchao Zhang
1126d52a580bSJunchao Zhang /* Copy the temporary to the full solution. */
1127d52a580bSJunchao Zhang thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), tempGPU->begin(), tempGPU->end(), xGPU);
1128d52a580bSJunchao Zhang
1129d52a580bSJunchao Zhang /* restore */
1130d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1131d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1132d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
1133d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1134d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1135d52a580bSJunchao Zhang }
1136d52a580bSJunchao Zhang
MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat A,Vec bb,Vec xx)1137d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat A, Vec bb, Vec xx)
1138d52a580bSJunchao Zhang {
1139d52a580bSJunchao Zhang const PetscScalar *barray;
1140d52a580bSJunchao Zhang PetscScalar *xarray;
1141d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1142d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1143d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1144d52a580bSJunchao Zhang THRUSTARRAY *tempGPU = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1145d52a580bSJunchao Zhang
1146d52a580bSJunchao Zhang PetscFunctionBegin;
1147d52a580bSJunchao Zhang /* Analyze the matrix and create the transpose ... on the fly */
1148d52a580bSJunchao Zhang if (!loTriFactorT && !upTriFactorT) {
1149d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(A));
1150d52a580bSJunchao Zhang loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1151d52a580bSJunchao Zhang upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1152d52a580bSJunchao Zhang }
1153d52a580bSJunchao Zhang
1154d52a580bSJunchao Zhang /* Get the GPU pointers */
1155d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1156d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(bb, &barray));
1157d52a580bSJunchao Zhang
1158d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
1159d52a580bSJunchao Zhang /* First, solve U */
1160d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
1161d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, barray, tempGPU->data().get(), upTriFactorT->solvePolicy, upTriFactorT->solveBuffer));
1162d52a580bSJunchao Zhang
1163d52a580bSJunchao Zhang /* Then, solve L */
1164d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
1165d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, tempGPU->data().get(), xarray, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer));
1166d52a580bSJunchao Zhang
1167d52a580bSJunchao Zhang /* restore */
1168d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1169d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1170d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
1171d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1172d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1173d52a580bSJunchao Zhang }
1174d52a580bSJunchao Zhang
MatSolve_SeqAIJHIPSPARSE(Mat A,Vec bb,Vec xx)1175d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE(Mat A, Vec bb, Vec xx)
1176d52a580bSJunchao Zhang {
1177d52a580bSJunchao Zhang const PetscScalar *barray;
1178d52a580bSJunchao Zhang PetscScalar *xarray;
1179d52a580bSJunchao Zhang thrust::device_ptr<const PetscScalar> bGPU;
1180d52a580bSJunchao Zhang thrust::device_ptr<PetscScalar> xGPU;
1181d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1182d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
1183d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
1184d52a580bSJunchao Zhang THRUSTARRAY *tempGPU = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1185d52a580bSJunchao Zhang
1186d52a580bSJunchao Zhang PetscFunctionBegin;
1187d52a580bSJunchao Zhang /* Get the GPU pointers */
1188d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1189d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(bb, &barray));
1190d52a580bSJunchao Zhang xGPU = thrust::device_pointer_cast(xarray);
1191d52a580bSJunchao Zhang bGPU = thrust::device_pointer_cast(barray);
1192d52a580bSJunchao Zhang
1193d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
1194d52a580bSJunchao Zhang /* First, reorder with the row permutation */
1195d52a580bSJunchao Zhang thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->begin()), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->end()), tempGPU->begin());
1196d52a580bSJunchao Zhang
1197d52a580bSJunchao Zhang /* Next, solve L */
1198d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
1199d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, tempGPU->data().get(), xarray, loTriFactor->solvePolicy, loTriFactor->solveBuffer));
1200d52a580bSJunchao Zhang
1201d52a580bSJunchao Zhang /* Then, solve U */
1202d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
1203d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, xarray, tempGPU->data().get(), upTriFactor->solvePolicy, upTriFactor->solveBuffer));
1204d52a580bSJunchao Zhang
1205d52a580bSJunchao Zhang /* Last, reorder with the column permutation */
1206d52a580bSJunchao Zhang thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(tempGPU->begin(), hipsparseTriFactors->cpermIndices->begin()), thrust::make_permutation_iterator(tempGPU->begin(), hipsparseTriFactors->cpermIndices->end()), xGPU);
1207d52a580bSJunchao Zhang
1208d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1209d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1210d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
1211d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1212d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1213d52a580bSJunchao Zhang }
1214d52a580bSJunchao Zhang
MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat A,Vec bb,Vec xx)1215d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat A, Vec bb, Vec xx)
1216d52a580bSJunchao Zhang {
1217d52a580bSJunchao Zhang const PetscScalar *barray;
1218d52a580bSJunchao Zhang PetscScalar *xarray;
1219d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1220d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
1221d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
1222d52a580bSJunchao Zhang THRUSTARRAY *tempGPU = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1223d52a580bSJunchao Zhang
1224d52a580bSJunchao Zhang PetscFunctionBegin;
1225d52a580bSJunchao Zhang /* Get the GPU pointers */
1226d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1227d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(bb, &barray));
1228d52a580bSJunchao Zhang
1229d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
1230d52a580bSJunchao Zhang /* First, solve L */
1231d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
1232d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, barray, tempGPU->data().get(), loTriFactor->solvePolicy, loTriFactor->solveBuffer));
1233d52a580bSJunchao Zhang
1234d52a580bSJunchao Zhang /* Next, solve U */
1235d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
1236d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, tempGPU->data().get(), xarray, upTriFactor->solvePolicy, upTriFactor->solveBuffer));
1237d52a580bSJunchao Zhang
1238d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1239d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1240d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
1241d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1242d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1243d52a580bSJunchao Zhang }
1244d52a580bSJunchao Zhang
1245d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1246d52a580bSJunchao Zhang /* hipsparseSpSV_solve() and related functions first appeared in ROCm-4.5.0*/
MatSolve_SeqAIJHIPSPARSE_ILU0(Mat fact,Vec b,Vec x)1247d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x)
1248d52a580bSJunchao Zhang {
1249d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1250d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data;
1251d52a580bSJunchao Zhang const PetscScalar *barray;
1252d52a580bSJunchao Zhang PetscScalar *xarray;
1253d52a580bSJunchao Zhang
1254d52a580bSJunchao Zhang PetscFunctionBegin;
1255d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(x, &xarray));
1256d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(b, &barray));
1257d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
1258d52a580bSJunchao Zhang
1259d52a580bSJunchao Zhang /* Solve L*y = b */
1260d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
1261d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
1262d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1263d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
1264d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()!
1265d52a580bSJunchao Zhang #else
1266d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
1267d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()!
1268d52a580bSJunchao Zhang #endif
1269d52a580bSJunchao Zhang /* Solve U*x = y */
1270d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
1271d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1272d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */
1273d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U));
1274d52a580bSJunchao Zhang #else
1275d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */
1276d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U));
1277d52a580bSJunchao Zhang #endif
1278d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(b, &barray));
1279d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
1280d52a580bSJunchao Zhang
1281d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
1282d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n));
1283d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1284d52a580bSJunchao Zhang }
1285d52a580bSJunchao Zhang
MatSolveTranspose_SeqAIJHIPSPARSE_ILU0(Mat fact,Vec b,Vec x)1286d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x)
1287d52a580bSJunchao Zhang {
1288d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1289d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data;
1290d52a580bSJunchao Zhang const PetscScalar *barray;
1291d52a580bSJunchao Zhang PetscScalar *xarray;
1292d52a580bSJunchao Zhang
1293d52a580bSJunchao Zhang PetscFunctionBegin;
1294d52a580bSJunchao Zhang if (!fs->createdTransposeSpSVDescr) { /* Call MatSolveTranspose() for the first time */
1295d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Lt));
1296d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* The matrix is still L. We only do transpose solve with it */
1297d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, &fs->spsvBufferSize_Lt));
1298d52a580bSJunchao Zhang
1299d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Ut));
1300d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, &fs->spsvBufferSize_Ut));
1301d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Lt, fs->spsvBufferSize_Lt));
1302d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Ut, fs->spsvBufferSize_Ut));
1303d52a580bSJunchao Zhang fs->createdTransposeSpSVDescr = PETSC_TRUE;
1304d52a580bSJunchao Zhang }
1305d52a580bSJunchao Zhang
1306d52a580bSJunchao Zhang if (!fs->updatedTransposeSpSVAnalysis) {
1307d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1308d52a580bSJunchao Zhang
1309d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut));
1310d52a580bSJunchao Zhang fs->updatedTransposeSpSVAnalysis = PETSC_TRUE;
1311d52a580bSJunchao Zhang }
1312d52a580bSJunchao Zhang
1313d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(x, &xarray));
1314d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(b, &barray));
1315d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
1316d52a580bSJunchao Zhang
1317d52a580bSJunchao Zhang /* Solve Ut*y = b */
1318d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
1319d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
1320d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1321d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
1322d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut));
1323d52a580bSJunchao Zhang #else
1324d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
1325d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut));
1326d52a580bSJunchao Zhang #endif
1327d52a580bSJunchao Zhang /* Solve Lt*x = y */
1328d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
1329d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1330d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1331d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
1332d52a580bSJunchao Zhang #else
1333d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1334d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1335d52a580bSJunchao Zhang #endif
1336d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(b, &barray));
1337d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
1338d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
1339d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n));
1340d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1341d52a580bSJunchao Zhang }
1342d52a580bSJunchao Zhang
MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0(Mat fact,Mat A,const MatFactorInfo * info)1343d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0(Mat fact, Mat A, const MatFactorInfo *info)
1344d52a580bSJunchao Zhang {
1345d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1346d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data;
1347d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
1348d52a580bSJunchao Zhang CsrMatrix *Acsr;
1349d52a580bSJunchao Zhang PetscInt m, nz;
1350d52a580bSJunchao Zhang PetscBool flg;
1351d52a580bSJunchao Zhang
1352d52a580bSJunchao Zhang PetscFunctionBegin;
1353d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) {
1354d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1355d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1356d52a580bSJunchao Zhang }
1357d52a580bSJunchao Zhang
1358d52a580bSJunchao Zhang /* Copy A's value to fact */
1359d52a580bSJunchao Zhang m = fact->rmap->n;
1360d52a580bSJunchao Zhang nz = aij->nz;
1361d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
1362d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Acusp->mat->mat;
1363d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrVal, Acsr->values->data().get(), sizeof(PetscScalar) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1364d52a580bSJunchao Zhang
1365d52a580bSJunchao Zhang /* Factorize fact inplace */
1366d52a580bSJunchao Zhang if (m)
1367d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrilu02(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */
1368d52a580bSJunchao Zhang fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, fs->policy_M, fs->factBuffer_M));
1369d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) {
1370d52a580bSJunchao Zhang int numerical_zero;
1371d52a580bSJunchao Zhang hipsparseStatus_t status;
1372d52a580bSJunchao Zhang status = hipsparseXcsrilu02_zeroPivot(fs->handle, fs->ilu0Info_M, &numerical_zero);
1373d52a580bSJunchao Zhang PetscAssert(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Numerical zero pivot detected in csrilu02: A(%d,%d) is zero", numerical_zero, numerical_zero);
1374d52a580bSJunchao Zhang }
1375d52a580bSJunchao Zhang
1376d52a580bSJunchao Zhang /* hipsparseSpSV_analysis() is numeric, i.e., it requires valid matrix values, therefore, we do it after hipsparseXcsrilu02() */
1377d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
1378d52a580bSJunchao Zhang
1379d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U));
1380d52a580bSJunchao Zhang
1381d52a580bSJunchao Zhang /* L, U values have changed, reset the flag to indicate we need to redo hipsparseSpSV_analysis() for transpose solve */
1382d52a580bSJunchao Zhang fs->updatedTransposeSpSVAnalysis = PETSC_FALSE;
1383d52a580bSJunchao Zhang
1384d52a580bSJunchao Zhang fact->offloadmask = PETSC_OFFLOAD_GPU;
1385d52a580bSJunchao Zhang fact->ops->solve = MatSolve_SeqAIJHIPSPARSE_ILU0;
1386d52a580bSJunchao Zhang fact->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE_ILU0;
1387d52a580bSJunchao Zhang fact->ops->matsolve = NULL;
1388d52a580bSJunchao Zhang fact->ops->matsolvetranspose = NULL;
1389d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(fs->numericFactFlops));
1390d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1391d52a580bSJunchao Zhang }
1392d52a580bSJunchao Zhang
MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(Mat fact,Mat A,IS isrow,IS iscol,const MatFactorInfo * info)1393d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(Mat fact, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1394d52a580bSJunchao Zhang {
1395d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1396d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data;
1397d52a580bSJunchao Zhang PetscInt m, nz;
1398d52a580bSJunchao Zhang
1399d52a580bSJunchao Zhang PetscFunctionBegin;
1400d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) {
1401d52a580bSJunchao Zhang PetscBool flg, diagDense;
1402d52a580bSJunchao Zhang
1403d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1404d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1405d52a580bSJunchao Zhang PetscCheck(A->rmap->n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must be square matrix, rows %" PetscInt_FMT " columns %" PetscInt_FMT, A->rmap->n, A->cmap->n);
1406d52a580bSJunchao Zhang PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, NULL, &diagDense));
1407d52a580bSJunchao Zhang PetscCheck(diagDense, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix is missing diagonal entries");
1408d52a580bSJunchao Zhang }
1409d52a580bSJunchao Zhang
1410d52a580bSJunchao Zhang /* Free the old stale stuff */
1411d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&fs));
1412d52a580bSJunchao Zhang
1413d52a580bSJunchao Zhang /* Copy over A's meta data to fact. Note that we also allocated fact's i,j,a on host,
1414d52a580bSJunchao Zhang but they will not be used. Allocate them just for easy debugging.
1415d52a580bSJunchao Zhang */
1416d52a580bSJunchao Zhang PetscCall(MatDuplicateNoCreate_SeqAIJ(fact, A, MAT_DO_NOT_COPY_VALUES, PETSC_TRUE /*malloc*/));
1417d52a580bSJunchao Zhang
1418d52a580bSJunchao Zhang fact->offloadmask = PETSC_OFFLOAD_BOTH;
1419d52a580bSJunchao Zhang fact->factortype = MAT_FACTOR_ILU;
1420d52a580bSJunchao Zhang fact->info.factor_mallocs = 0;
1421d52a580bSJunchao Zhang fact->info.fill_ratio_given = info->fill;
1422d52a580bSJunchao Zhang fact->info.fill_ratio_needed = 1.0;
1423d52a580bSJunchao Zhang
1424d52a580bSJunchao Zhang aij->row = NULL;
1425d52a580bSJunchao Zhang aij->col = NULL;
1426d52a580bSJunchao Zhang
1427d52a580bSJunchao Zhang /* ====================================================================== */
1428d52a580bSJunchao Zhang /* Copy A's i, j to fact and also allocate the value array of fact. */
1429d52a580bSJunchao Zhang /* We'll do in-place factorization on fact */
1430d52a580bSJunchao Zhang /* ====================================================================== */
1431d52a580bSJunchao Zhang const int *Ai, *Aj;
1432d52a580bSJunchao Zhang
1433d52a580bSJunchao Zhang m = fact->rmap->n;
1434d52a580bSJunchao Zhang nz = aij->nz;
1435d52a580bSJunchao Zhang
1436d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrRowPtr, sizeof(int) * (m + 1)));
1437d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrColIdx, sizeof(int) * nz));
1438d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrVal, sizeof(PetscScalar) * nz));
1439d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetIJ(A, PETSC_FALSE, &Ai, &Aj)); /* Do not use compressed Ai */
1440d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrRowPtr, Ai, sizeof(int) * (m + 1), hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1441d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrColIdx, Aj, sizeof(int) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1442d52a580bSJunchao Zhang
1443d52a580bSJunchao Zhang /* ====================================================================== */
1444d52a580bSJunchao Zhang /* Create descriptors for M, L, U */
1445d52a580bSJunchao Zhang /* ====================================================================== */
1446d52a580bSJunchao Zhang hipsparseFillMode_t fillMode;
1447d52a580bSJunchao Zhang hipsparseDiagType_t diagType;
1448d52a580bSJunchao Zhang
1449d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&fs->matDescr_M));
1450d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(fs->matDescr_M, HIPSPARSE_INDEX_BASE_ZERO));
1451d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(fs->matDescr_M, HIPSPARSE_MATRIX_TYPE_GENERAL));
1452d52a580bSJunchao Zhang
1453d52a580bSJunchao Zhang /* https://docs.amd.com/bundle/hipSPARSE-Documentation---hipSPARSE-documentation/page/usermanual.html/#hipsparse_8h_1a79e036b6c0680cb37e2aa53d3542a054
1454d52a580bSJunchao Zhang hipsparseDiagType_t: This type indicates if the matrix diagonal entries are unity. The diagonal elements are always
1455d52a580bSJunchao Zhang assumed to be present, but if HIPSPARSE_DIAG_TYPE_UNIT is passed to an API routine, then the routine assumes that
1456d52a580bSJunchao Zhang all diagonal entries are unity and will not read or modify those entries. Note that in this case the routine
1457d52a580bSJunchao Zhang assumes the diagonal entries are equal to one, regardless of what those entries are actually set to in memory.
1458d52a580bSJunchao Zhang */
1459d52a580bSJunchao Zhang fillMode = HIPSPARSE_FILL_MODE_LOWER;
1460d52a580bSJunchao Zhang diagType = HIPSPARSE_DIAG_TYPE_UNIT;
1461d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_L, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
1462d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode)));
1463d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType)));
1464d52a580bSJunchao Zhang
1465d52a580bSJunchao Zhang fillMode = HIPSPARSE_FILL_MODE_UPPER;
1466d52a580bSJunchao Zhang diagType = HIPSPARSE_DIAG_TYPE_NON_UNIT;
1467d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_U, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
1468d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_U, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode)));
1469d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_U, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType)));
1470d52a580bSJunchao Zhang
1471d52a580bSJunchao Zhang /* ========================================================================= */
1472d52a580bSJunchao Zhang /* Query buffer sizes for csrilu0, SpSV and allocate buffers */
1473d52a580bSJunchao Zhang /* ========================================================================= */
1474d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrilu02Info(&fs->ilu0Info_M));
1475d52a580bSJunchao Zhang if (m)
1476d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrilu02_bufferSize(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */
1477d52a580bSJunchao Zhang fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, &fs->factBufferSize_M));
1478d52a580bSJunchao Zhang
1479d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->X, sizeof(PetscScalar) * m));
1480d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->Y, sizeof(PetscScalar) * m));
1481d52a580bSJunchao Zhang
1482d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_X, m, fs->X, hipsparse_scalartype));
1483d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_Y, m, fs->Y, hipsparse_scalartype));
1484d52a580bSJunchao Zhang
1485d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_L));
1486d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, &fs->spsvBufferSize_L));
1487d52a580bSJunchao Zhang
1488d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_U));
1489d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, &fs->spsvBufferSize_U));
1490d52a580bSJunchao Zhang
1491d52a580bSJunchao Zhang /* It appears spsvBuffer_L/U can not be shared (i.e., the same) for our case, but factBuffer_M can share with either of spsvBuffer_L/U.
1492d52a580bSJunchao Zhang To save memory, we make factBuffer_M share with the bigger of spsvBuffer_L/U.
1493d52a580bSJunchao Zhang */
1494d52a580bSJunchao Zhang if (fs->spsvBufferSize_L > fs->spsvBufferSize_U) {
1495d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_L, (size_t)fs->factBufferSize_M)));
1496d52a580bSJunchao Zhang fs->spsvBuffer_L = fs->factBuffer_M;
1497d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_U, fs->spsvBufferSize_U));
1498d52a580bSJunchao Zhang } else {
1499d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_U, (size_t)fs->factBufferSize_M)));
1500d52a580bSJunchao Zhang fs->spsvBuffer_U = fs->factBuffer_M;
1501d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_L, fs->spsvBufferSize_L));
1502d52a580bSJunchao Zhang }
1503d52a580bSJunchao Zhang
1504d52a580bSJunchao Zhang /* ========================================================================== */
1505d52a580bSJunchao Zhang /* Perform analysis of ilu0 on M, SpSv on L and U */
1506d52a580bSJunchao Zhang /* The lower(upper) triangular part of M has the same sparsity pattern as L(U)*/
1507d52a580bSJunchao Zhang /* ========================================================================== */
1508d52a580bSJunchao Zhang int structural_zero;
1509d52a580bSJunchao Zhang
1510d52a580bSJunchao Zhang fs->policy_M = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
1511d52a580bSJunchao Zhang if (m)
1512d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrilu02_analysis(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */
1513d52a580bSJunchao Zhang fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, fs->policy_M, fs->factBuffer_M));
1514d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) {
1515d52a580bSJunchao Zhang /* Function hipsparseXcsrilu02_zeroPivot() is a blocking call. It calls hipDeviceSynchronize() to make sure all previous kernels are done. */
1516d52a580bSJunchao Zhang hipsparseStatus_t status;
1517d52a580bSJunchao Zhang status = hipsparseXcsrilu02_zeroPivot(fs->handle, fs->ilu0Info_M, &structural_zero);
1518d52a580bSJunchao Zhang PetscCheck(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Structural zero pivot detected in csrilu02: A(%d,%d) is missing", structural_zero, structural_zero);
1519d52a580bSJunchao Zhang }
1520d52a580bSJunchao Zhang
1521d52a580bSJunchao Zhang /* Estimate FLOPs of the numeric factorization */
1522d52a580bSJunchao Zhang {
1523d52a580bSJunchao Zhang Mat_SeqAIJ *Aseq = (Mat_SeqAIJ *)A->data;
1524d52a580bSJunchao Zhang PetscInt *Ai, nzRow, nzLeft;
1525d52a580bSJunchao Zhang PetscLogDouble flops = 0.0;
1526d52a580bSJunchao Zhang const PetscInt *Adiag;
1527d52a580bSJunchao Zhang
1528d52a580bSJunchao Zhang PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, &Adiag, NULL));
1529d52a580bSJunchao Zhang Ai = Aseq->i;
1530d52a580bSJunchao Zhang for (PetscInt i = 0; i < m; i++) {
1531d52a580bSJunchao Zhang if (Ai[i] < Adiag[i] && Adiag[i] < Ai[i + 1]) { /* There are nonzeros left to the diagonal of row i */
1532d52a580bSJunchao Zhang nzRow = Ai[i + 1] - Ai[i];
1533d52a580bSJunchao Zhang nzLeft = Adiag[i] - Ai[i];
1534d52a580bSJunchao Zhang /* We want to eliminate nonzeros left to the diagonal one by one. Assume each time, nonzeros right
1535d52a580bSJunchao Zhang and include the eliminated one will be updated, which incurs a multiplication and an addition.
1536d52a580bSJunchao Zhang */
1537d52a580bSJunchao Zhang nzLeft = (nzRow - 1) / 2;
1538d52a580bSJunchao Zhang flops += nzLeft * (2.0 * nzRow - nzLeft + 1);
1539d52a580bSJunchao Zhang }
1540d52a580bSJunchao Zhang }
1541d52a580bSJunchao Zhang fs->numericFactFlops = flops;
1542d52a580bSJunchao Zhang }
1543d52a580bSJunchao Zhang fact->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0;
1544d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1545d52a580bSJunchao Zhang }
1546d52a580bSJunchao Zhang
MatSolve_SeqAIJHIPSPARSE_ICC0(Mat fact,Vec b,Vec x)1547d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ICC0(Mat fact, Vec b, Vec x)
1548d52a580bSJunchao Zhang {
1549d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1550d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data;
1551d52a580bSJunchao Zhang const PetscScalar *barray;
1552d52a580bSJunchao Zhang PetscScalar *xarray;
1553d52a580bSJunchao Zhang
1554d52a580bSJunchao Zhang PetscFunctionBegin;
1555d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(x, &xarray));
1556d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(b, &barray));
1557d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
1558d52a580bSJunchao Zhang
1559d52a580bSJunchao Zhang /* Solve L*y = b */
1560d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
1561d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
1562d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1563d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
1564d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L));
1565d52a580bSJunchao Zhang #else
1566d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
1567d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
1568d52a580bSJunchao Zhang #endif
1569d52a580bSJunchao Zhang /* Solve Lt*x = y */
1570d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
1571d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1572d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1573d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
1574d52a580bSJunchao Zhang #else
1575d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1576d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1577d52a580bSJunchao Zhang #endif
1578d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(b, &barray));
1579d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
1580d52a580bSJunchao Zhang
1581d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
1582d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n));
1583d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1584d52a580bSJunchao Zhang }
1585d52a580bSJunchao Zhang
MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0(Mat fact,Mat A,const MatFactorInfo * info)1586d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0(Mat fact, Mat A, const MatFactorInfo *info)
1587d52a580bSJunchao Zhang {
1588d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1589d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data;
1590d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
1591d52a580bSJunchao Zhang CsrMatrix *Acsr;
1592d52a580bSJunchao Zhang PetscInt m, nz;
1593d52a580bSJunchao Zhang PetscBool flg;
1594d52a580bSJunchao Zhang
1595d52a580bSJunchao Zhang PetscFunctionBegin;
1596d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) {
1597d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1598d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1599d52a580bSJunchao Zhang }
1600d52a580bSJunchao Zhang
1601d52a580bSJunchao Zhang /* Copy A's value to fact */
1602d52a580bSJunchao Zhang m = fact->rmap->n;
1603d52a580bSJunchao Zhang nz = aij->nz;
1604d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
1605d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Acusp->mat->mat;
1606d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrVal, Acsr->values->data().get(), sizeof(PetscScalar) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1607d52a580bSJunchao Zhang
1608d52a580bSJunchao Zhang /* Factorize fact inplace */
1609d52a580bSJunchao Zhang /* Function csric02() only takes the lower triangular part of matrix A to perform factorization.
1610d52a580bSJunchao Zhang The matrix type must be HIPSPARSE_MATRIX_TYPE_GENERAL, the fill mode and diagonal type are ignored,
1611d52a580bSJunchao Zhang and the strictly upper triangular part is ignored and never touched. It does not matter if A is Hermitian or not.
1612d52a580bSJunchao Zhang In other words, from the point of view of csric02() A is Hermitian and only the lower triangular part is provided.
1613d52a580bSJunchao Zhang */
1614d52a580bSJunchao Zhang if (m) PetscCallHIPSPARSE(hipsparseXcsric02(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, fs->policy_M, fs->factBuffer_M));
1615d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) {
1616d52a580bSJunchao Zhang int numerical_zero;
1617d52a580bSJunchao Zhang hipsparseStatus_t status;
1618d52a580bSJunchao Zhang status = hipsparseXcsric02_zeroPivot(fs->handle, fs->ic0Info_M, &numerical_zero);
1619d52a580bSJunchao Zhang PetscAssert(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Numerical zero pivot detected in csric02: A(%d,%d) is zero", numerical_zero, numerical_zero);
1620d52a580bSJunchao Zhang }
1621d52a580bSJunchao Zhang
1622d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
1623d52a580bSJunchao Zhang
1624d52a580bSJunchao Zhang /* Note that hipsparse reports this error if we use double and HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
1625d52a580bSJunchao Zhang ** On entry to hipsparseSpSV_analysis(): conjugate transpose (opA) is not supported for matA data type, current -> CUDA_R_64F
1626d52a580bSJunchao Zhang */
1627d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1628d52a580bSJunchao Zhang
1629d52a580bSJunchao Zhang fact->offloadmask = PETSC_OFFLOAD_GPU;
1630d52a580bSJunchao Zhang fact->ops->solve = MatSolve_SeqAIJHIPSPARSE_ICC0;
1631d52a580bSJunchao Zhang fact->ops->solvetranspose = MatSolve_SeqAIJHIPSPARSE_ICC0;
1632d52a580bSJunchao Zhang fact->ops->matsolve = NULL;
1633d52a580bSJunchao Zhang fact->ops->matsolvetranspose = NULL;
1634d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(fs->numericFactFlops));
1635d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1636d52a580bSJunchao Zhang }
1637d52a580bSJunchao Zhang
MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(Mat fact,Mat A,IS perm,const MatFactorInfo * info)1638d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(Mat fact, Mat A, IS perm, const MatFactorInfo *info)
1639d52a580bSJunchao Zhang {
1640d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1641d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data;
1642d52a580bSJunchao Zhang PetscInt m, nz;
1643d52a580bSJunchao Zhang
1644d52a580bSJunchao Zhang PetscFunctionBegin;
1645d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) {
1646d52a580bSJunchao Zhang PetscBool flg, diagDense;
1647d52a580bSJunchao Zhang
1648d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1649d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1650d52a580bSJunchao Zhang PetscCheck(A->rmap->n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must be square matrix, rows %" PetscInt_FMT " columns %" PetscInt_FMT, A->rmap->n, A->cmap->n);
1651d52a580bSJunchao Zhang PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, NULL, &diagDense));
1652d52a580bSJunchao Zhang PetscCheck(diagDense, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix is missing diagonal entries");
1653d52a580bSJunchao Zhang }
1654d52a580bSJunchao Zhang
1655d52a580bSJunchao Zhang /* Free the old stale stuff */
1656d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&fs));
1657d52a580bSJunchao Zhang
1658d52a580bSJunchao Zhang /* Copy over A's meta data to fact. Note that we also allocated fact's i,j,a on host,
1659d52a580bSJunchao Zhang but they will not be used. Allocate them just for easy debugging.
1660d52a580bSJunchao Zhang */
1661d52a580bSJunchao Zhang PetscCall(MatDuplicateNoCreate_SeqAIJ(fact, A, MAT_DO_NOT_COPY_VALUES, PETSC_TRUE /*malloc*/));
1662d52a580bSJunchao Zhang
1663d52a580bSJunchao Zhang fact->offloadmask = PETSC_OFFLOAD_BOTH;
1664d52a580bSJunchao Zhang fact->factortype = MAT_FACTOR_ICC;
1665d52a580bSJunchao Zhang fact->info.factor_mallocs = 0;
1666d52a580bSJunchao Zhang fact->info.fill_ratio_given = info->fill;
1667d52a580bSJunchao Zhang fact->info.fill_ratio_needed = 1.0;
1668d52a580bSJunchao Zhang
1669d52a580bSJunchao Zhang aij->row = NULL;
1670d52a580bSJunchao Zhang aij->col = NULL;
1671d52a580bSJunchao Zhang
1672d52a580bSJunchao Zhang /* ====================================================================== */
1673d52a580bSJunchao Zhang /* Copy A's i, j to fact and also allocate the value array of fact. */
1674d52a580bSJunchao Zhang /* We'll do in-place factorization on fact */
1675d52a580bSJunchao Zhang /* ====================================================================== */
1676d52a580bSJunchao Zhang const int *Ai, *Aj;
1677d52a580bSJunchao Zhang
1678d52a580bSJunchao Zhang m = fact->rmap->n;
1679d52a580bSJunchao Zhang nz = aij->nz;
1680d52a580bSJunchao Zhang
1681d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrRowPtr, sizeof(int) * (m + 1)));
1682d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrColIdx, sizeof(int) * nz));
1683d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrVal, sizeof(PetscScalar) * nz));
1684d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetIJ(A, PETSC_FALSE, &Ai, &Aj)); /* Do not use compressed Ai */
1685d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrRowPtr, Ai, sizeof(int) * (m + 1), hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1686d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrColIdx, Aj, sizeof(int) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1687d52a580bSJunchao Zhang
1688d52a580bSJunchao Zhang /* ====================================================================== */
1689d52a580bSJunchao Zhang /* Create mat descriptors for M, L */
1690d52a580bSJunchao Zhang /* ====================================================================== */
1691d52a580bSJunchao Zhang hipsparseFillMode_t fillMode;
1692d52a580bSJunchao Zhang hipsparseDiagType_t diagType;
1693d52a580bSJunchao Zhang
1694d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&fs->matDescr_M));
1695d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(fs->matDescr_M, HIPSPARSE_INDEX_BASE_ZERO));
1696d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(fs->matDescr_M, HIPSPARSE_MATRIX_TYPE_GENERAL));
1697d52a580bSJunchao Zhang
1698d52a580bSJunchao Zhang /* https://docs.amd.com/bundle/hipSPARSE-Documentation---hipSPARSE-documentation/page/usermanual.html/#hipsparse_8h_1a79e036b6c0680cb37e2aa53d3542a054
1699d52a580bSJunchao Zhang hipsparseDiagType_t: This type indicates if the matrix diagonal entries are unity. The diagonal elements are always
1700d52a580bSJunchao Zhang assumed to be present, but if HIPSPARSE_DIAG_TYPE_UNIT is passed to an API routine, then the routine assumes that
1701d52a580bSJunchao Zhang all diagonal entries are unity and will not read or modify those entries. Note that in this case the routine
1702d52a580bSJunchao Zhang assumes the diagonal entries are equal to one, regardless of what those entries are actually set to in memory.
1703d52a580bSJunchao Zhang */
1704d52a580bSJunchao Zhang fillMode = HIPSPARSE_FILL_MODE_LOWER;
1705d52a580bSJunchao Zhang diagType = HIPSPARSE_DIAG_TYPE_NON_UNIT;
1706d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_L, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
1707d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode)));
1708d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType)));
1709d52a580bSJunchao Zhang
1710d52a580bSJunchao Zhang /* ========================================================================= */
1711d52a580bSJunchao Zhang /* Query buffer sizes for csric0, SpSV of L and Lt, and allocate buffers */
1712d52a580bSJunchao Zhang /* ========================================================================= */
1713d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsric02Info(&fs->ic0Info_M));
1714d52a580bSJunchao Zhang if (m) PetscCallHIPSPARSE(hipsparseXcsric02_bufferSize(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, &fs->factBufferSize_M));
1715d52a580bSJunchao Zhang
1716d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->X, sizeof(PetscScalar) * m));
1717d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->Y, sizeof(PetscScalar) * m));
1718d52a580bSJunchao Zhang
1719d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_X, m, fs->X, hipsparse_scalartype));
1720d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_Y, m, fs->Y, hipsparse_scalartype));
1721d52a580bSJunchao Zhang
1722d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_L));
1723d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, &fs->spsvBufferSize_L));
1724d52a580bSJunchao Zhang
1725d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Lt));
1726d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, &fs->spsvBufferSize_Lt));
1727d52a580bSJunchao Zhang
1728d52a580bSJunchao Zhang /* To save device memory, we make the factorization buffer share with one of the solver buffer.
1729d52a580bSJunchao Zhang See also comments in `MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0()`.
1730d52a580bSJunchao Zhang */
1731d52a580bSJunchao Zhang if (fs->spsvBufferSize_L > fs->spsvBufferSize_Lt) {
1732d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_L, (size_t)fs->factBufferSize_M)));
1733d52a580bSJunchao Zhang fs->spsvBuffer_L = fs->factBuffer_M;
1734d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Lt, fs->spsvBufferSize_Lt));
1735d52a580bSJunchao Zhang } else {
1736d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_Lt, (size_t)fs->factBufferSize_M)));
1737d52a580bSJunchao Zhang fs->spsvBuffer_Lt = fs->factBuffer_M;
1738d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_L, fs->spsvBufferSize_L));
1739d52a580bSJunchao Zhang }
1740d52a580bSJunchao Zhang
1741d52a580bSJunchao Zhang /* ========================================================================== */
1742d52a580bSJunchao Zhang /* Perform analysis of ic0 on M */
1743d52a580bSJunchao Zhang /* The lower triangular part of M has the same sparsity pattern as L */
1744d52a580bSJunchao Zhang /* ========================================================================== */
1745d52a580bSJunchao Zhang int structural_zero;
1746d52a580bSJunchao Zhang
1747d52a580bSJunchao Zhang fs->policy_M = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
1748d52a580bSJunchao Zhang if (m) PetscCallHIPSPARSE(hipsparseXcsric02_analysis(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, fs->policy_M, fs->factBuffer_M));
1749d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) {
1750d52a580bSJunchao Zhang hipsparseStatus_t status;
1751d52a580bSJunchao Zhang /* Function hipsparseXcsric02_zeroPivot() is a blocking call. It calls hipDeviceSynchronize() to make sure all previous kernels are done. */
1752d52a580bSJunchao Zhang status = hipsparseXcsric02_zeroPivot(fs->handle, fs->ic0Info_M, &structural_zero);
1753d52a580bSJunchao Zhang PetscCheck(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Structural zero pivot detected in csric02: A(%d,%d) is missing", structural_zero, structural_zero);
1754d52a580bSJunchao Zhang }
1755d52a580bSJunchao Zhang
1756d52a580bSJunchao Zhang /* Estimate FLOPs of the numeric factorization */
1757d52a580bSJunchao Zhang {
1758d52a580bSJunchao Zhang Mat_SeqAIJ *Aseq = (Mat_SeqAIJ *)A->data;
1759d52a580bSJunchao Zhang PetscInt *Ai, nzRow, nzLeft;
1760d52a580bSJunchao Zhang PetscLogDouble flops = 0.0;
1761d52a580bSJunchao Zhang
1762d52a580bSJunchao Zhang Ai = Aseq->i;
1763d52a580bSJunchao Zhang for (PetscInt i = 0; i < m; i++) {
1764d52a580bSJunchao Zhang nzRow = Ai[i + 1] - Ai[i];
1765d52a580bSJunchao Zhang if (nzRow > 1) {
1766d52a580bSJunchao Zhang /* We want to eliminate nonzeros left to the diagonal one by one. Assume each time, nonzeros right
1767d52a580bSJunchao Zhang and include the eliminated one will be updated, which incurs a multiplication and an addition.
1768d52a580bSJunchao Zhang */
1769d52a580bSJunchao Zhang nzLeft = (nzRow - 1) / 2;
1770d52a580bSJunchao Zhang flops += nzLeft * (2.0 * nzRow - nzLeft + 1);
1771d52a580bSJunchao Zhang }
1772d52a580bSJunchao Zhang }
1773d52a580bSJunchao Zhang fs->numericFactFlops = flops;
1774d52a580bSJunchao Zhang }
1775d52a580bSJunchao Zhang fact->ops->choleskyfactornumeric = MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0;
1776d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1777d52a580bSJunchao Zhang }
1778d52a580bSJunchao Zhang #endif
1779d52a580bSJunchao Zhang
MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat B,Mat A,IS isrow,IS iscol,const MatFactorInfo * info)1780d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1781d52a580bSJunchao Zhang {
1782d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1783d52a580bSJunchao Zhang
1784d52a580bSJunchao Zhang PetscFunctionBegin;
1785d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1786d52a580bSJunchao Zhang PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE;
1787d52a580bSJunchao Zhang if (!info->factoronhost) {
1788d52a580bSJunchao Zhang PetscCall(ISIdentity(isrow, &row_identity));
1789d52a580bSJunchao Zhang PetscCall(ISIdentity(iscol, &col_identity));
1790d52a580bSJunchao Zhang }
1791d52a580bSJunchao Zhang if (!info->levels && row_identity && col_identity) PetscCall(MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(B, A, isrow, iscol, info));
1792d52a580bSJunchao Zhang else
1793d52a580bSJunchao Zhang #endif
1794d52a580bSJunchao Zhang {
1795d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1796d52a580bSJunchao Zhang PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
1797d52a580bSJunchao Zhang B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJHIPSPARSE;
1798d52a580bSJunchao Zhang }
1799d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1800d52a580bSJunchao Zhang }
1801d52a580bSJunchao Zhang
MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat B,Mat A,IS isrow,IS iscol,const MatFactorInfo * info)1802d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1803d52a580bSJunchao Zhang {
1804d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1805d52a580bSJunchao Zhang
1806d52a580bSJunchao Zhang PetscFunctionBegin;
1807d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1808d52a580bSJunchao Zhang PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
1809d52a580bSJunchao Zhang B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJHIPSPARSE;
1810d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1811d52a580bSJunchao Zhang }
1812d52a580bSJunchao Zhang
MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat B,Mat A,IS perm,const MatFactorInfo * info)1813d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS perm, const MatFactorInfo *info)
1814d52a580bSJunchao Zhang {
1815d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1816d52a580bSJunchao Zhang
1817d52a580bSJunchao Zhang PetscFunctionBegin;
1818d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1819d52a580bSJunchao Zhang PetscBool perm_identity = PETSC_FALSE;
1820d52a580bSJunchao Zhang if (!info->factoronhost) PetscCall(ISIdentity(perm, &perm_identity));
1821d52a580bSJunchao Zhang if (!info->levels && perm_identity) PetscCall(MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(B, A, perm, info));
1822d52a580bSJunchao Zhang else
1823d52a580bSJunchao Zhang #endif
1824d52a580bSJunchao Zhang {
1825d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1826d52a580bSJunchao Zhang PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info));
1827d52a580bSJunchao Zhang B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJHIPSPARSE;
1828d52a580bSJunchao Zhang }
1829d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1830d52a580bSJunchao Zhang }
1831d52a580bSJunchao Zhang
MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat B,Mat A,IS perm,const MatFactorInfo * info)1832d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS perm, const MatFactorInfo *info)
1833d52a580bSJunchao Zhang {
1834d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1835d52a580bSJunchao Zhang
1836d52a580bSJunchao Zhang PetscFunctionBegin;
1837d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1838d52a580bSJunchao Zhang PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info));
1839d52a580bSJunchao Zhang B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJHIPSPARSE;
1840d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1841d52a580bSJunchao Zhang }
1842d52a580bSJunchao Zhang
MatFactorGetSolverType_seqaij_hipsparse(Mat A,MatSolverType * type)1843d52a580bSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_seqaij_hipsparse(Mat A, MatSolverType *type)
1844d52a580bSJunchao Zhang {
1845d52a580bSJunchao Zhang PetscFunctionBegin;
1846d52a580bSJunchao Zhang *type = MATSOLVERHIPSPARSE;
1847d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1848d52a580bSJunchao Zhang }
1849d52a580bSJunchao Zhang
1850d52a580bSJunchao Zhang /*MC
  MATSOLVERHIPSPARSE = "hipsparse" - A matrix solver type providing triangular solvers for sequential matrices
  of type `MATSEQAIJHIPSPARSE` on a single GPU. Currently supported
  algorithms are ILU(k) and ICC(k). Typically, deeper factorizations (larger k) result in poorer
  performance in the triangular solves. Full LU and Cholesky decompositions can be solved through the
  hipSPARSE triangular solve algorithm. However, the performance can be quite poor, and thus these
  algorithms are not recommended. This class does NOT support direct solver operations.
1857d52a580bSJunchao Zhang
1858d52a580bSJunchao Zhang Level: beginner
1859d52a580bSJunchao Zhang
1860d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MATSEQAIJHIPSPARSE`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJHIPSPARSE()`, `MATAIJHIPSPARSE`, `MatCreateAIJHIPSPARSE()`, `MatHIPSPARSESetFormat()`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
1861d52a580bSJunchao Zhang M*/
1862d52a580bSJunchao Zhang
MatGetFactor_seqaijhipsparse_hipsparse(Mat A,MatFactorType ftype,Mat * B)1863d52a580bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_seqaijhipsparse_hipsparse(Mat A, MatFactorType ftype, Mat *B)
1864d52a580bSJunchao Zhang {
1865d52a580bSJunchao Zhang PetscInt n = A->rmap->n;
1866d52a580bSJunchao Zhang
1867d52a580bSJunchao Zhang PetscFunctionBegin;
1868d52a580bSJunchao Zhang PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
1869d52a580bSJunchao Zhang PetscCall(MatSetSizes(*B, n, n, n, n));
1870d52a580bSJunchao Zhang (*B)->factortype = ftype;
1871d52a580bSJunchao Zhang PetscCall(MatSetType(*B, MATSEQAIJHIPSPARSE));
1872d52a580bSJunchao Zhang
1873d52a580bSJunchao Zhang if (A->boundtocpu && A->bindingpropagates) PetscCall(MatBindToCPU(*B, PETSC_TRUE));
1874d52a580bSJunchao Zhang if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) {
1875d52a580bSJunchao Zhang PetscCall(MatSetBlockSizesFromMats(*B, A, A));
1876d52a580bSJunchao Zhang if (!A->boundtocpu) {
1877d52a580bSJunchao Zhang (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJHIPSPARSE;
1878d52a580bSJunchao Zhang (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJHIPSPARSE;
1879d52a580bSJunchao Zhang } else {
1880d52a580bSJunchao Zhang (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJ;
1881d52a580bSJunchao Zhang (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJ;
1882d52a580bSJunchao Zhang }
1883d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
1884d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU]));
1885d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT]));
1886d52a580bSJunchao Zhang } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) {
1887d52a580bSJunchao Zhang if (!A->boundtocpu) {
1888d52a580bSJunchao Zhang (*B)->ops->iccfactorsymbolic = MatICCFactorSymbolic_SeqAIJHIPSPARSE;
1889d52a580bSJunchao Zhang (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE;
1890d52a580bSJunchao Zhang } else {
1891d52a580bSJunchao Zhang (*B)->ops->iccfactorsymbolic = MatICCFactorSymbolic_SeqAIJ;
1892d52a580bSJunchao Zhang (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJ;
1893d52a580bSJunchao Zhang }
1894d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY]));
1895d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC]));
1896d52a580bSJunchao Zhang } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Factor type not supported for HIPSPARSE Matrix Types");
1897d52a580bSJunchao Zhang
1898d52a580bSJunchao Zhang PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
1899d52a580bSJunchao Zhang (*B)->canuseordering = PETSC_TRUE;
1900d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_seqaij_hipsparse));
1901d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1902d52a580bSJunchao Zhang }
1903d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSECopyFromGPU(Mat A)1904d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSECopyFromGPU(Mat A)
1905d52a580bSJunchao Zhang {
1906d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1907d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
1908d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1909d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1910d52a580bSJunchao Zhang #endif
1911d52a580bSJunchao Zhang
1912d52a580bSJunchao Zhang PetscFunctionBegin;
1913d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_GPU) {
1914d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyFromGPU, A, 0, 0, 0));
1915d52a580bSJunchao Zhang if (A->factortype == MAT_FACTOR_NONE) {
1916d52a580bSJunchao Zhang CsrMatrix *matrix = (CsrMatrix *)cusp->mat->mat;
1917d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(a->a, matrix->values->data().get(), a->nz * sizeof(PetscScalar), hipMemcpyDeviceToHost));
1918d52a580bSJunchao Zhang }
1919d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1920d52a580bSJunchao Zhang else if (fs->csrVal) {
1921d52a580bSJunchao Zhang /* We have a factorized matrix on device and are able to copy it to host */
1922d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(a->a, fs->csrVal, a->nz * sizeof(PetscScalar), hipMemcpyDeviceToHost));
1923d52a580bSJunchao Zhang }
1924d52a580bSJunchao Zhang #endif
1925d52a580bSJunchao Zhang else
1926d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "No support for copying this type of factorized matrix from device to host");
1927d52a580bSJunchao Zhang PetscCall(PetscLogGpuToCpu(a->nz * sizeof(PetscScalar)));
1928d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyFromGPU, A, 0, 0, 0));
1929d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_BOTH;
1930d52a580bSJunchao Zhang }
1931d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1932d52a580bSJunchao Zhang }
1933d52a580bSJunchao Zhang
MatSeqAIJGetArray_SeqAIJHIPSPARSE(Mat A,PetscScalar * array[])1934d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArray_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1935d52a580bSJunchao Zhang {
1936d52a580bSJunchao Zhang PetscFunctionBegin;
1937d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
1938d52a580bSJunchao Zhang *array = ((Mat_SeqAIJ *)A->data)->a;
1939d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1940d52a580bSJunchao Zhang }
1941d52a580bSJunchao Zhang
MatSeqAIJRestoreArray_SeqAIJHIPSPARSE(Mat A,PetscScalar * array[])1942d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1943d52a580bSJunchao Zhang {
1944d52a580bSJunchao Zhang PetscFunctionBegin;
1945d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_CPU;
1946d52a580bSJunchao Zhang *array = NULL;
1947d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1948d52a580bSJunchao Zhang }
1949d52a580bSJunchao Zhang
MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE(Mat A,const PetscScalar * array[])1950d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE(Mat A, const PetscScalar *array[])
1951d52a580bSJunchao Zhang {
1952d52a580bSJunchao Zhang PetscFunctionBegin;
1953d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
1954d52a580bSJunchao Zhang *array = ((Mat_SeqAIJ *)A->data)->a;
1955d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1956d52a580bSJunchao Zhang }
1957d52a580bSJunchao Zhang
MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE(Mat A,const PetscScalar * array[])1958d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE(Mat A, const PetscScalar *array[])
1959d52a580bSJunchao Zhang {
1960d52a580bSJunchao Zhang PetscFunctionBegin;
1961d52a580bSJunchao Zhang *array = NULL;
1962d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1963d52a580bSJunchao Zhang }
1964d52a580bSJunchao Zhang
MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE(Mat A,PetscScalar * array[])1965d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1966d52a580bSJunchao Zhang {
1967d52a580bSJunchao Zhang PetscFunctionBegin;
1968d52a580bSJunchao Zhang *array = ((Mat_SeqAIJ *)A->data)->a;
1969d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1970d52a580bSJunchao Zhang }
1971d52a580bSJunchao Zhang
MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE(Mat A,PetscScalar * array[])1972d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1973d52a580bSJunchao Zhang {
1974d52a580bSJunchao Zhang PetscFunctionBegin;
1975d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_CPU;
1976d52a580bSJunchao Zhang *array = NULL;
1977d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
1978d52a580bSJunchao Zhang }
1979d52a580bSJunchao Zhang
MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE(Mat A,const PetscInt ** i,const PetscInt ** j,PetscScalar ** a,PetscMemType * mtype)1980d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
1981d52a580bSJunchao Zhang {
1982d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp;
1983d52a580bSJunchao Zhang CsrMatrix *matrix;
1984d52a580bSJunchao Zhang
1985d52a580bSJunchao Zhang PetscFunctionBegin;
1986d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
1987d52a580bSJunchao Zhang PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Not for factored matrix");
1988d52a580bSJunchao Zhang cusp = static_cast<Mat_SeqAIJHIPSPARSE *>(A->spptr);
1989d52a580bSJunchao Zhang PetscCheck(cusp != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "cusp is NULL");
1990d52a580bSJunchao Zhang matrix = (CsrMatrix *)cusp->mat->mat;
1991d52a580bSJunchao Zhang
1992d52a580bSJunchao Zhang if (i) {
1993d52a580bSJunchao Zhang #if !defined(PETSC_USE_64BIT_INDICES)
1994d52a580bSJunchao Zhang *i = matrix->row_offsets->data().get();
1995d52a580bSJunchao Zhang #else
1996d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSparse does not supported 64-bit indices");
1997d52a580bSJunchao Zhang #endif
1998d52a580bSJunchao Zhang }
1999d52a580bSJunchao Zhang if (j) {
2000d52a580bSJunchao Zhang #if !defined(PETSC_USE_64BIT_INDICES)
2001d52a580bSJunchao Zhang *j = matrix->column_indices->data().get();
2002d52a580bSJunchao Zhang #else
2003d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSparse does not supported 64-bit indices");
2004d52a580bSJunchao Zhang #endif
2005d52a580bSJunchao Zhang }
2006d52a580bSJunchao Zhang if (a) *a = matrix->values->data().get();
2007d52a580bSJunchao Zhang if (mtype) *mtype = PETSC_MEMTYPE_HIP;
2008d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2009d52a580bSJunchao Zhang }
2010d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSECopyToGPU(Mat A)2011d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSECopyToGPU(Mat A)
2012d52a580bSJunchao Zhang {
2013d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2014d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *matstruct = hipsparsestruct->mat;
2015d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
2016d52a580bSJunchao Zhang PetscBool both = PETSC_TRUE;
2017d52a580bSJunchao Zhang PetscInt m = A->rmap->n, *ii, *ridx, tmp;
2018d52a580bSJunchao Zhang
2019d52a580bSJunchao Zhang PetscFunctionBegin;
2020d52a580bSJunchao Zhang PetscCheck(!A->boundtocpu, PETSC_COMM_SELF, PETSC_ERR_GPU, "Cannot copy to GPU");
2021d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
2022d52a580bSJunchao Zhang if (A->nonzerostate == hipsparsestruct->nonzerostate && hipsparsestruct->format == MAT_HIPSPARSE_CSR) { /* Copy values only */
2023d52a580bSJunchao Zhang CsrMatrix *matrix;
2024d52a580bSJunchao Zhang matrix = (CsrMatrix *)hipsparsestruct->mat->mat;
2025d52a580bSJunchao Zhang
2026d52a580bSJunchao Zhang PetscCheck(!a->nz || a->a, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR values");
2027d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
2028d52a580bSJunchao Zhang matrix->values->assign(a->a, a->a + a->nz);
2029d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
2030d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(a->nz * sizeof(PetscScalar)));
2031d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
2032d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE));
2033d52a580bSJunchao Zhang } else {
2034d52a580bSJunchao Zhang PetscInt nnz;
2035d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
2036d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&hipsparsestruct->mat, hipsparsestruct->format));
2037d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE));
2038d52a580bSJunchao Zhang delete hipsparsestruct->workVector;
2039d52a580bSJunchao Zhang delete hipsparsestruct->rowoffsets_gpu;
2040d52a580bSJunchao Zhang hipsparsestruct->workVector = NULL;
2041d52a580bSJunchao Zhang hipsparsestruct->rowoffsets_gpu = NULL;
2042d52a580bSJunchao Zhang try {
2043d52a580bSJunchao Zhang if (a->compressedrow.use) {
2044d52a580bSJunchao Zhang m = a->compressedrow.nrows;
2045d52a580bSJunchao Zhang ii = a->compressedrow.i;
2046d52a580bSJunchao Zhang ridx = a->compressedrow.rindex;
2047d52a580bSJunchao Zhang } else {
2048d52a580bSJunchao Zhang m = A->rmap->n;
2049d52a580bSJunchao Zhang ii = a->i;
2050d52a580bSJunchao Zhang ridx = NULL;
2051d52a580bSJunchao Zhang }
2052d52a580bSJunchao Zhang PetscCheck(ii, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR row data");
2053d52a580bSJunchao Zhang if (!a->a) {
2054d52a580bSJunchao Zhang nnz = ii[m];
2055d52a580bSJunchao Zhang both = PETSC_FALSE;
2056d52a580bSJunchao Zhang } else nnz = a->nz;
2057d52a580bSJunchao Zhang PetscCheck(!nnz || a->j, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR column data");
2058d52a580bSJunchao Zhang
2059d52a580bSJunchao Zhang /* create hipsparse matrix */
2060d52a580bSJunchao Zhang hipsparsestruct->nrows = m;
2061d52a580bSJunchao Zhang matstruct = new Mat_SeqAIJHIPSPARSEMultStruct;
2062d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&matstruct->descr));
2063d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(matstruct->descr, HIPSPARSE_INDEX_BASE_ZERO));
2064d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(matstruct->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
2065d52a580bSJunchao Zhang
2066d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstruct->alpha_one, sizeof(PetscScalar)));
2067d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstruct->beta_zero, sizeof(PetscScalar)));
2068d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstruct->beta_one, sizeof(PetscScalar)));
2069d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstruct->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
2070d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstruct->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
2071d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstruct->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
2072d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(hipsparsestruct->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2073d52a580bSJunchao Zhang
2074d52a580bSJunchao Zhang /* Build a hybrid/ellpack matrix if this option is chosen for the storage */
2075d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
2076d52a580bSJunchao Zhang /* set the matrix */
2077d52a580bSJunchao Zhang CsrMatrix *mat = new CsrMatrix;
2078d52a580bSJunchao Zhang mat->num_rows = m;
2079d52a580bSJunchao Zhang mat->num_cols = A->cmap->n;
2080d52a580bSJunchao Zhang mat->num_entries = nnz;
2081d52a580bSJunchao Zhang mat->row_offsets = new THRUSTINTARRAY32(m + 1);
2082d52a580bSJunchao Zhang mat->column_indices = new THRUSTINTARRAY32(nnz);
2083d52a580bSJunchao Zhang mat->values = new THRUSTARRAY(nnz);
2084d52a580bSJunchao Zhang mat->row_offsets->assign(ii, ii + m + 1);
2085d52a580bSJunchao Zhang mat->column_indices->assign(a->j, a->j + nnz);
2086d52a580bSJunchao Zhang if (a->a) mat->values->assign(a->a, a->a + nnz);
2087d52a580bSJunchao Zhang
2088d52a580bSJunchao Zhang /* assign the pointer */
2089d52a580bSJunchao Zhang matstruct->mat = mat;
2090d52a580bSJunchao Zhang if (mat->num_rows) { /* hipsparse errors on empty matrices! */
2091d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&matstruct->matDescr, mat->num_rows, mat->num_cols, mat->num_entries, mat->row_offsets->data().get(), mat->column_indices->data().get(), mat->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx types due to THRUSTINTARRAY32 */
2092d52a580bSJunchao Zhang HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2093d52a580bSJunchao Zhang }
2094d52a580bSJunchao Zhang } else if (hipsparsestruct->format == MAT_HIPSPARSE_ELL || hipsparsestruct->format == MAT_HIPSPARSE_HYB) {
2095d52a580bSJunchao Zhang CsrMatrix *mat = new CsrMatrix;
2096d52a580bSJunchao Zhang mat->num_rows = m;
2097d52a580bSJunchao Zhang mat->num_cols = A->cmap->n;
2098d52a580bSJunchao Zhang mat->num_entries = nnz;
2099d52a580bSJunchao Zhang mat->row_offsets = new THRUSTINTARRAY32(m + 1);
2100d52a580bSJunchao Zhang mat->column_indices = new THRUSTINTARRAY32(nnz);
2101d52a580bSJunchao Zhang mat->values = new THRUSTARRAY(nnz);
2102d52a580bSJunchao Zhang mat->row_offsets->assign(ii, ii + m + 1);
2103d52a580bSJunchao Zhang mat->column_indices->assign(a->j, a->j + nnz);
2104d52a580bSJunchao Zhang if (a->a) mat->values->assign(a->a, a->a + nnz);
2105d52a580bSJunchao Zhang
2106d52a580bSJunchao Zhang hipsparseHybMat_t hybMat;
2107d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateHybMat(&hybMat));
2108d52a580bSJunchao Zhang hipsparseHybPartition_t partition = hipsparsestruct->format == MAT_HIPSPARSE_ELL ? HIPSPARSE_HYB_PARTITION_MAX : HIPSPARSE_HYB_PARTITION_AUTO;
2109d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2hyb(hipsparsestruct->handle, mat->num_rows, mat->num_cols, matstruct->descr, mat->values->data().get(), mat->row_offsets->data().get(), mat->column_indices->data().get(), hybMat, 0, partition));
2110d52a580bSJunchao Zhang /* assign the pointer */
2111d52a580bSJunchao Zhang matstruct->mat = hybMat;
2112d52a580bSJunchao Zhang
2113d52a580bSJunchao Zhang if (mat) {
2114d52a580bSJunchao Zhang if (mat->values) delete (THRUSTARRAY *)mat->values;
2115d52a580bSJunchao Zhang if (mat->column_indices) delete (THRUSTINTARRAY32 *)mat->column_indices;
2116d52a580bSJunchao Zhang if (mat->row_offsets) delete (THRUSTINTARRAY32 *)mat->row_offsets;
2117d52a580bSJunchao Zhang delete (CsrMatrix *)mat;
2118d52a580bSJunchao Zhang }
2119d52a580bSJunchao Zhang }
2120d52a580bSJunchao Zhang
2121d52a580bSJunchao Zhang /* assign the compressed row indices */
2122d52a580bSJunchao Zhang if (a->compressedrow.use) {
2123d52a580bSJunchao Zhang hipsparsestruct->workVector = new THRUSTARRAY(m);
2124d52a580bSJunchao Zhang matstruct->cprowIndices = new THRUSTINTARRAY(m);
2125d52a580bSJunchao Zhang matstruct->cprowIndices->assign(ridx, ridx + m);
2126d52a580bSJunchao Zhang tmp = m;
2127d52a580bSJunchao Zhang } else {
2128d52a580bSJunchao Zhang hipsparsestruct->workVector = NULL;
2129d52a580bSJunchao Zhang matstruct->cprowIndices = NULL;
2130d52a580bSJunchao Zhang tmp = 0;
2131d52a580bSJunchao Zhang }
2132d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(((m + 1) + (a->nz)) * sizeof(int) + tmp * sizeof(PetscInt) + (3 + (a->nz)) * sizeof(PetscScalar)));
2133d52a580bSJunchao Zhang
2134d52a580bSJunchao Zhang /* assign the pointer */
2135d52a580bSJunchao Zhang hipsparsestruct->mat = matstruct;
2136d52a580bSJunchao Zhang } catch (char *ex) {
2137d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
2138d52a580bSJunchao Zhang }
2139d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
2140d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
2141d52a580bSJunchao Zhang hipsparsestruct->nonzerostate = A->nonzerostate;
2142d52a580bSJunchao Zhang }
2143d52a580bSJunchao Zhang if (both) A->offloadmask = PETSC_OFFLOAD_BOTH;
2144d52a580bSJunchao Zhang }
2145d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2146d52a580bSJunchao Zhang }
2147d52a580bSJunchao Zhang
2148d52a580bSJunchao Zhang struct VecHIPPlusEquals {
2149d52a580bSJunchao Zhang template <typename Tuple>
operator ()VecHIPPlusEquals2150d52a580bSJunchao Zhang __host__ __device__ void operator()(Tuple t)
2151d52a580bSJunchao Zhang {
2152d52a580bSJunchao Zhang thrust::get<1>(t) = thrust::get<1>(t) + thrust::get<0>(t);
2153d52a580bSJunchao Zhang }
2154d52a580bSJunchao Zhang };
2155d52a580bSJunchao Zhang
2156d52a580bSJunchao Zhang struct VecHIPEquals {
2157d52a580bSJunchao Zhang template <typename Tuple>
operator ()VecHIPEquals2158d52a580bSJunchao Zhang __host__ __device__ void operator()(Tuple t)
2159d52a580bSJunchao Zhang {
2160d52a580bSJunchao Zhang thrust::get<1>(t) = thrust::get<0>(t);
2161d52a580bSJunchao Zhang }
2162d52a580bSJunchao Zhang };
2163d52a580bSJunchao Zhang
2164d52a580bSJunchao Zhang struct VecHIPEqualsReverse {
2165d52a580bSJunchao Zhang template <typename Tuple>
operator ()VecHIPEqualsReverse2166d52a580bSJunchao Zhang __host__ __device__ void operator()(Tuple t)
2167d52a580bSJunchao Zhang {
2168d52a580bSJunchao Zhang thrust::get<0>(t) = thrust::get<1>(t);
2169d52a580bSJunchao Zhang }
2170d52a580bSJunchao Zhang };
2171d52a580bSJunchao Zhang
2172d52a580bSJunchao Zhang struct MatProductCtx_MatMatHipsparse {
2173d52a580bSJunchao Zhang PetscBool cisdense;
2174d52a580bSJunchao Zhang PetscScalar *Bt;
2175d52a580bSJunchao Zhang Mat X;
2176d52a580bSJunchao Zhang PetscBool reusesym; /* Hipsparse does not have split symbolic and numeric phases for sparse matmat operations */
2177d52a580bSJunchao Zhang PetscLogDouble flops;
2178d52a580bSJunchao Zhang CsrMatrix *Bcsr;
2179d52a580bSJunchao Zhang hipsparseSpMatDescr_t matSpBDescr;
2180d52a580bSJunchao Zhang PetscBool initialized; /* C = alpha op(A) op(B) + beta C */
2181d52a580bSJunchao Zhang hipsparseDnMatDescr_t matBDescr;
2182d52a580bSJunchao Zhang hipsparseDnMatDescr_t matCDescr;
2183d52a580bSJunchao Zhang PetscInt Blda, Clda; /* Record leading dimensions of B and C here to detect changes*/
2184d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2185d52a580bSJunchao Zhang void *dBuffer4, *dBuffer5;
2186d52a580bSJunchao Zhang #endif
2187d52a580bSJunchao Zhang size_t mmBufferSize;
2188d52a580bSJunchao Zhang void *mmBuffer, *mmBuffer2; /* SpGEMM WorkEstimation buffer */
2189d52a580bSJunchao Zhang hipsparseSpGEMMDescr_t spgemmDesc;
2190d52a580bSJunchao Zhang };
2191d52a580bSJunchao Zhang
MatProductCtxDestroy_MatMatHipsparse(PetscCtxRt data)2192d52a580bSJunchao Zhang static PetscErrorCode MatProductCtxDestroy_MatMatHipsparse(PetscCtxRt data)
2193d52a580bSJunchao Zhang {
2194d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata = *(MatProductCtx_MatMatHipsparse **)data;
2195d52a580bSJunchao Zhang
2196d52a580bSJunchao Zhang PetscFunctionBegin;
2197d52a580bSJunchao Zhang PetscCallHIP(hipFree(mmdata->Bt));
2198d52a580bSJunchao Zhang delete mmdata->Bcsr;
2199d52a580bSJunchao Zhang if (mmdata->matSpBDescr) PetscCallHIPSPARSE(hipsparseDestroySpMat(mmdata->matSpBDescr));
2200d52a580bSJunchao Zhang if (mmdata->matBDescr) PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matBDescr));
2201d52a580bSJunchao Zhang if (mmdata->matCDescr) PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matCDescr));
2202d52a580bSJunchao Zhang if (mmdata->spgemmDesc) PetscCallHIPSPARSE(hipsparseSpGEMM_destroyDescr(mmdata->spgemmDesc));
2203d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2204d52a580bSJunchao Zhang if (mmdata->dBuffer4) PetscCallHIP(hipFree(mmdata->dBuffer4));
2205d52a580bSJunchao Zhang if (mmdata->dBuffer5) PetscCallHIP(hipFree(mmdata->dBuffer5));
2206d52a580bSJunchao Zhang #endif
2207d52a580bSJunchao Zhang if (mmdata->mmBuffer) PetscCallHIP(hipFree(mmdata->mmBuffer));
2208d52a580bSJunchao Zhang if (mmdata->mmBuffer2) PetscCallHIP(hipFree(mmdata->mmBuffer2));
2209d52a580bSJunchao Zhang PetscCall(MatDestroy(&mmdata->X));
2210d52a580bSJunchao Zhang PetscCall(PetscFree(*(void **)data));
2211d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2212d52a580bSJunchao Zhang }
2213d52a580bSJunchao Zhang
MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)2214d52a580bSJunchao Zhang static PetscErrorCode MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)
2215d52a580bSJunchao Zhang {
2216d52a580bSJunchao Zhang Mat_Product *product = C->product;
2217d52a580bSJunchao Zhang Mat A, B;
2218d52a580bSJunchao Zhang PetscInt m, n, blda, clda;
2219d52a580bSJunchao Zhang PetscBool flg, biship;
2220d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp;
2221d52a580bSJunchao Zhang hipsparseOperation_t opA;
2222d52a580bSJunchao Zhang const PetscScalar *barray;
2223d52a580bSJunchao Zhang PetscScalar *carray;
2224d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata;
2225d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *mat;
2226d52a580bSJunchao Zhang CsrMatrix *csrmat;
2227d52a580bSJunchao Zhang
2228d52a580bSJunchao Zhang PetscFunctionBegin;
2229d52a580bSJunchao Zhang MatCheckProduct(C, 1);
2230d52a580bSJunchao Zhang PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data empty");
2231d52a580bSJunchao Zhang mmdata = (MatProductCtx_MatMatHipsparse *)product->data;
2232d52a580bSJunchao Zhang A = product->A;
2233d52a580bSJunchao Zhang B = product->B;
2234d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2235d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2236d52a580bSJunchao Zhang /* currently CopyToGpu does not copy if the matrix is bound to CPU
2237d52a580bSJunchao Zhang Instead of silently accepting the wrong answer, I prefer to raise the error */
2238d52a580bSJunchao Zhang PetscCheck(!A->boundtocpu, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases");
2239d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
2240d52a580bSJunchao Zhang cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2241d52a580bSJunchao Zhang switch (product->type) {
2242d52a580bSJunchao Zhang case MATPRODUCT_AB:
2243d52a580bSJunchao Zhang case MATPRODUCT_PtAP:
2244d52a580bSJunchao Zhang mat = cusp->mat;
2245d52a580bSJunchao Zhang opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
2246d52a580bSJunchao Zhang m = A->rmap->n;
2247d52a580bSJunchao Zhang n = B->cmap->n;
2248d52a580bSJunchao Zhang break;
2249d52a580bSJunchao Zhang case MATPRODUCT_AtB:
2250d52a580bSJunchao Zhang if (!A->form_explicit_transpose) {
2251d52a580bSJunchao Zhang mat = cusp->mat;
2252d52a580bSJunchao Zhang opA = HIPSPARSE_OPERATION_TRANSPOSE;
2253d52a580bSJunchao Zhang } else {
2254d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
2255d52a580bSJunchao Zhang mat = cusp->matTranspose;
2256d52a580bSJunchao Zhang opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
2257d52a580bSJunchao Zhang }
2258d52a580bSJunchao Zhang m = A->cmap->n;
2259d52a580bSJunchao Zhang n = B->cmap->n;
2260d52a580bSJunchao Zhang break;
2261d52a580bSJunchao Zhang case MATPRODUCT_ABt:
2262d52a580bSJunchao Zhang case MATPRODUCT_RARt:
2263d52a580bSJunchao Zhang mat = cusp->mat;
2264d52a580bSJunchao Zhang opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
2265d52a580bSJunchao Zhang m = A->rmap->n;
2266d52a580bSJunchao Zhang n = B->rmap->n;
2267d52a580bSJunchao Zhang break;
2268d52a580bSJunchao Zhang default:
2269d52a580bSJunchao Zhang SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2270d52a580bSJunchao Zhang }
2271d52a580bSJunchao Zhang PetscCheck(mat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
2272d52a580bSJunchao Zhang csrmat = (CsrMatrix *)mat->mat;
2273d52a580bSJunchao Zhang /* if the user passed a CPU matrix, copy the data to the GPU */
2274d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQDENSEHIP, &biship));
2275d52a580bSJunchao Zhang if (!biship) PetscCall(MatConvert(B, MATSEQDENSEHIP, MAT_INPLACE_MATRIX, &B));
2276d52a580bSJunchao Zhang PetscCall(MatDenseGetArrayReadAndMemType(B, &barray, nullptr));
2277d52a580bSJunchao Zhang PetscCall(MatDenseGetLDA(B, &blda));
2278d52a580bSJunchao Zhang if (product->type == MATPRODUCT_RARt || product->type == MATPRODUCT_PtAP) {
2279d52a580bSJunchao Zhang PetscCall(MatDenseGetArrayWriteAndMemType(mmdata->X, &carray, nullptr));
2280d52a580bSJunchao Zhang PetscCall(MatDenseGetLDA(mmdata->X, &clda));
2281d52a580bSJunchao Zhang } else {
2282d52a580bSJunchao Zhang PetscCall(MatDenseGetArrayWriteAndMemType(C, &carray, nullptr));
2283d52a580bSJunchao Zhang PetscCall(MatDenseGetLDA(C, &clda));
2284d52a580bSJunchao Zhang }
2285d52a580bSJunchao Zhang
2286d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
2287d52a580bSJunchao Zhang hipsparseOperation_t opB = (product->type == MATPRODUCT_ABt || product->type == MATPRODUCT_RARt) ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE;
2288d52a580bSJunchao Zhang /* (re)allocate mmBuffer if not initialized or LDAs are different */
2289d52a580bSJunchao Zhang if (!mmdata->initialized || mmdata->Blda != blda || mmdata->Clda != clda) {
2290d52a580bSJunchao Zhang size_t mmBufferSize;
2291d52a580bSJunchao Zhang if (mmdata->initialized && mmdata->Blda != blda) {
2292d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matBDescr));
2293d52a580bSJunchao Zhang mmdata->matBDescr = NULL;
2294d52a580bSJunchao Zhang }
2295d52a580bSJunchao Zhang if (!mmdata->matBDescr) {
2296d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnMat(&mmdata->matBDescr, B->rmap->n, B->cmap->n, blda, (void *)barray, hipsparse_scalartype, HIPSPARSE_ORDER_COL));
2297d52a580bSJunchao Zhang mmdata->Blda = blda;
2298d52a580bSJunchao Zhang }
2299d52a580bSJunchao Zhang if (mmdata->initialized && mmdata->Clda != clda) {
2300d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matCDescr));
2301d52a580bSJunchao Zhang mmdata->matCDescr = NULL;
2302d52a580bSJunchao Zhang }
2303d52a580bSJunchao Zhang if (!mmdata->matCDescr) { /* matCDescr is for C or mmdata->X */
2304d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnMat(&mmdata->matCDescr, m, n, clda, (void *)carray, hipsparse_scalartype, HIPSPARSE_ORDER_COL));
2305d52a580bSJunchao Zhang mmdata->Clda = clda;
2306d52a580bSJunchao Zhang }
2307d52a580bSJunchao Zhang if (!mat->matDescr) {
2308d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&mat->matDescr, csrmat->num_rows, csrmat->num_cols, csrmat->num_entries, csrmat->row_offsets->data().get(), csrmat->column_indices->data().get(), csrmat->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx types due to THRUSTINTARRAY32 */
2309d52a580bSJunchao Zhang HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2310d52a580bSJunchao Zhang }
2311d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMM_bufferSize(cusp->handle, opA, opB, mat->alpha_one, mat->matDescr, mmdata->matBDescr, mat->beta_zero, mmdata->matCDescr, hipsparse_scalartype, cusp->spmmAlg, &mmBufferSize));
2312d52a580bSJunchao Zhang if ((mmdata->mmBuffer && mmdata->mmBufferSize < mmBufferSize) || !mmdata->mmBuffer) {
2313d52a580bSJunchao Zhang PetscCallHIP(hipFree(mmdata->mmBuffer));
2314d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&mmdata->mmBuffer, mmBufferSize));
2315d52a580bSJunchao Zhang mmdata->mmBufferSize = mmBufferSize;
2316d52a580bSJunchao Zhang }
2317d52a580bSJunchao Zhang mmdata->initialized = PETSC_TRUE;
2318d52a580bSJunchao Zhang } else {
2319d52a580bSJunchao Zhang /* to be safe, always update pointers of the mats */
2320d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetValues(mat->matDescr, csrmat->values->data().get()));
2321d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnMatSetValues(mmdata->matBDescr, (void *)barray));
2322d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnMatSetValues(mmdata->matCDescr, (void *)carray));
2323d52a580bSJunchao Zhang }
2324d52a580bSJunchao Zhang
2325d52a580bSJunchao Zhang /* do hipsparseSpMM, which supports transpose on B */
2326d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMM(cusp->handle, opA, opB, mat->alpha_one, mat->matDescr, mmdata->matBDescr, mat->beta_zero, mmdata->matCDescr, hipsparse_scalartype, cusp->spmmAlg, mmdata->mmBuffer));
2327d52a580bSJunchao Zhang
2328d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
2329d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(n * 2.0 * csrmat->num_entries));
2330d52a580bSJunchao Zhang PetscCall(MatDenseRestoreArrayReadAndMemType(B, &barray));
2331d52a580bSJunchao Zhang if (product->type == MATPRODUCT_RARt) {
2332d52a580bSJunchao Zhang PetscCall(MatDenseRestoreArrayWriteAndMemType(mmdata->X, &carray));
2333d52a580bSJunchao Zhang PetscCall(MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal(B, mmdata->X, C, PETSC_FALSE, PETSC_FALSE));
2334d52a580bSJunchao Zhang } else if (product->type == MATPRODUCT_PtAP) {
2335d52a580bSJunchao Zhang PetscCall(MatDenseRestoreArrayWriteAndMemType(mmdata->X, &carray));
2336d52a580bSJunchao Zhang PetscCall(MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal(B, mmdata->X, C, PETSC_TRUE, PETSC_FALSE));
2337d52a580bSJunchao Zhang } else PetscCall(MatDenseRestoreArrayWriteAndMemType(C, &carray));
2338d52a580bSJunchao Zhang if (mmdata->cisdense) PetscCall(MatConvert(C, MATSEQDENSE, MAT_INPLACE_MATRIX, &C));
2339d52a580bSJunchao Zhang if (!biship) PetscCall(MatConvert(B, MATSEQDENSE, MAT_INPLACE_MATRIX, &B));
2340d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2341d52a580bSJunchao Zhang }
2342d52a580bSJunchao Zhang
MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)2343d52a580bSJunchao Zhang static PetscErrorCode MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)
2344d52a580bSJunchao Zhang {
2345d52a580bSJunchao Zhang Mat_Product *product = C->product;
2346d52a580bSJunchao Zhang Mat A, B;
2347d52a580bSJunchao Zhang PetscInt m, n;
2348d52a580bSJunchao Zhang PetscBool cisdense, flg;
2349d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata;
2350d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp;
2351d52a580bSJunchao Zhang
2352d52a580bSJunchao Zhang PetscFunctionBegin;
2353d52a580bSJunchao Zhang MatCheckProduct(C, 1);
2354d52a580bSJunchao Zhang PetscCheck(!C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data not empty");
2355d52a580bSJunchao Zhang A = product->A;
2356d52a580bSJunchao Zhang B = product->B;
2357d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2358d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2359d52a580bSJunchao Zhang cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2360d52a580bSJunchao Zhang PetscCheck(cusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2361d52a580bSJunchao Zhang switch (product->type) {
2362d52a580bSJunchao Zhang case MATPRODUCT_AB:
2363d52a580bSJunchao Zhang m = A->rmap->n;
2364d52a580bSJunchao Zhang n = B->cmap->n;
2365d52a580bSJunchao Zhang break;
2366d52a580bSJunchao Zhang case MATPRODUCT_AtB:
2367d52a580bSJunchao Zhang m = A->cmap->n;
2368d52a580bSJunchao Zhang n = B->cmap->n;
2369d52a580bSJunchao Zhang break;
2370d52a580bSJunchao Zhang case MATPRODUCT_ABt:
2371d52a580bSJunchao Zhang m = A->rmap->n;
2372d52a580bSJunchao Zhang n = B->rmap->n;
2373d52a580bSJunchao Zhang break;
2374d52a580bSJunchao Zhang case MATPRODUCT_PtAP:
2375d52a580bSJunchao Zhang m = B->cmap->n;
2376d52a580bSJunchao Zhang n = B->cmap->n;
2377d52a580bSJunchao Zhang break;
2378d52a580bSJunchao Zhang case MATPRODUCT_RARt:
2379d52a580bSJunchao Zhang m = B->rmap->n;
2380d52a580bSJunchao Zhang n = B->rmap->n;
2381d52a580bSJunchao Zhang break;
2382d52a580bSJunchao Zhang default:
2383d52a580bSJunchao Zhang SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2384d52a580bSJunchao Zhang }
2385d52a580bSJunchao Zhang PetscCall(MatSetSizes(C, m, n, m, n));
2386d52a580bSJunchao Zhang /* if C is of type MATSEQDENSE (CPU), perform the operation on the GPU and then copy on the CPU */
2387d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)C, MATSEQDENSE, &cisdense));
2388d52a580bSJunchao Zhang PetscCall(MatSetType(C, MATSEQDENSEHIP));
2389d52a580bSJunchao Zhang
2390d52a580bSJunchao Zhang /* product data */
2391d52a580bSJunchao Zhang PetscCall(PetscNew(&mmdata));
2392d52a580bSJunchao Zhang mmdata->cisdense = cisdense;
2393d52a580bSJunchao Zhang /* for these products we need intermediate storage */
2394d52a580bSJunchao Zhang if (product->type == MATPRODUCT_RARt || product->type == MATPRODUCT_PtAP) {
2395d52a580bSJunchao Zhang PetscCall(MatCreate(PetscObjectComm((PetscObject)C), &mmdata->X));
2396d52a580bSJunchao Zhang PetscCall(MatSetType(mmdata->X, MATSEQDENSEHIP));
2397d52a580bSJunchao Zhang /* do not preallocate, since the first call to MatDenseHIPGetArray will preallocate on the GPU for us */
2398d52a580bSJunchao Zhang if (product->type == MATPRODUCT_RARt) PetscCall(MatSetSizes(mmdata->X, A->rmap->n, B->rmap->n, A->rmap->n, B->rmap->n));
2399d52a580bSJunchao Zhang else PetscCall(MatSetSizes(mmdata->X, A->rmap->n, B->cmap->n, A->rmap->n, B->cmap->n));
2400d52a580bSJunchao Zhang }
2401d52a580bSJunchao Zhang C->product->data = mmdata;
2402d52a580bSJunchao Zhang C->product->destroy = MatProductCtxDestroy_MatMatHipsparse;
2403d52a580bSJunchao Zhang C->ops->productnumeric = MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP;
2404d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2405d52a580bSJunchao Zhang }
2406d52a580bSJunchao Zhang
MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)2407d52a580bSJunchao Zhang static PetscErrorCode MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)
2408d52a580bSJunchao Zhang {
2409d52a580bSJunchao Zhang Mat_Product *product = C->product;
2410d52a580bSJunchao Zhang Mat A, B;
2411d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp, *Bcusp, *Ccusp;
2412d52a580bSJunchao Zhang Mat_SeqAIJ *c = (Mat_SeqAIJ *)C->data;
2413d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *Amat, *Bmat, *Cmat;
2414d52a580bSJunchao Zhang CsrMatrix *Acsr, *Bcsr, *Ccsr;
2415d52a580bSJunchao Zhang PetscBool flg;
2416d52a580bSJunchao Zhang MatProductType ptype;
2417d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata;
2418d52a580bSJunchao Zhang hipsparseSpMatDescr_t BmatSpDescr;
2419d52a580bSJunchao Zhang hipsparseOperation_t opA = HIPSPARSE_OPERATION_NON_TRANSPOSE, opB = HIPSPARSE_OPERATION_NON_TRANSPOSE; /* hipSPARSE spgemm doesn't support transpose yet */
2420d52a580bSJunchao Zhang
2421d52a580bSJunchao Zhang PetscFunctionBegin;
2422d52a580bSJunchao Zhang MatCheckProduct(C, 1);
2423d52a580bSJunchao Zhang PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data empty");
2424d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)C, MATSEQAIJHIPSPARSE, &flg));
2425d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for C of type %s", ((PetscObject)C)->type_name);
2426d52a580bSJunchao Zhang mmdata = (MatProductCtx_MatMatHipsparse *)C->product->data;
2427d52a580bSJunchao Zhang A = product->A;
2428d52a580bSJunchao Zhang B = product->B;
2429d52a580bSJunchao Zhang if (mmdata->reusesym) { /* this happens when api_user is true, meaning that the matrix values have been already computed in the MatProductSymbolic phase */
2430d52a580bSJunchao Zhang mmdata->reusesym = PETSC_FALSE;
2431d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr;
2432d52a580bSJunchao Zhang PetscCheck(Ccusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2433d52a580bSJunchao Zhang Cmat = Ccusp->mat;
2434d52a580bSJunchao Zhang PetscCheck(Cmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C mult struct for product type %s", MatProductTypes[C->product->type]);
2435d52a580bSJunchao Zhang Ccsr = (CsrMatrix *)Cmat->mat;
2436d52a580bSJunchao Zhang PetscCheck(Ccsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C CSR struct");
2437d52a580bSJunchao Zhang goto finalize;
2438d52a580bSJunchao Zhang }
2439d52a580bSJunchao Zhang if (!c->nz) goto finalize;
2440d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2441d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2442d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQAIJHIPSPARSE, &flg));
2443d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for B of type %s", ((PetscObject)B)->type_name);
2444d52a580bSJunchao Zhang PetscCheck(!A->boundtocpu, PetscObjectComm((PetscObject)C), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases");
2445d52a580bSJunchao Zhang PetscCheck(!B->boundtocpu, PetscObjectComm((PetscObject)C), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases");
2446d52a580bSJunchao Zhang Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2447d52a580bSJunchao Zhang Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr;
2448d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr;
2449d52a580bSJunchao Zhang PetscCheck(Acusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2450d52a580bSJunchao Zhang PetscCheck(Bcusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2451d52a580bSJunchao Zhang PetscCheck(Ccusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2452d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
2453d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
2454d52a580bSJunchao Zhang
2455d52a580bSJunchao Zhang ptype = product->type;
2456d52a580bSJunchao Zhang if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
2457d52a580bSJunchao Zhang ptype = MATPRODUCT_AB;
2458d52a580bSJunchao Zhang PetscCheck(product->symbolic_used_the_fact_A_is_symmetric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Symbolic should have been built using the fact that A is symmetric");
2459d52a580bSJunchao Zhang }
2460d52a580bSJunchao Zhang if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
2461d52a580bSJunchao Zhang ptype = MATPRODUCT_AB;
2462d52a580bSJunchao Zhang PetscCheck(product->symbolic_used_the_fact_B_is_symmetric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Symbolic should have been built using the fact that B is symmetric");
2463d52a580bSJunchao Zhang }
2464d52a580bSJunchao Zhang switch (ptype) {
2465d52a580bSJunchao Zhang case MATPRODUCT_AB:
2466d52a580bSJunchao Zhang Amat = Acusp->mat;
2467d52a580bSJunchao Zhang Bmat = Bcusp->mat;
2468d52a580bSJunchao Zhang break;
2469d52a580bSJunchao Zhang case MATPRODUCT_AtB:
2470d52a580bSJunchao Zhang Amat = Acusp->matTranspose;
2471d52a580bSJunchao Zhang Bmat = Bcusp->mat;
2472d52a580bSJunchao Zhang break;
2473d52a580bSJunchao Zhang case MATPRODUCT_ABt:
2474d52a580bSJunchao Zhang Amat = Acusp->mat;
2475d52a580bSJunchao Zhang Bmat = Bcusp->matTranspose;
2476d52a580bSJunchao Zhang break;
2477d52a580bSJunchao Zhang default:
2478d52a580bSJunchao Zhang SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2479d52a580bSJunchao Zhang }
2480d52a580bSJunchao Zhang Cmat = Ccusp->mat;
2481d52a580bSJunchao Zhang PetscCheck(Amat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A mult struct for product type %s", MatProductTypes[ptype]);
2482d52a580bSJunchao Zhang PetscCheck(Bmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B mult struct for product type %s", MatProductTypes[ptype]);
2483d52a580bSJunchao Zhang PetscCheck(Cmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C mult struct for product type %s", MatProductTypes[ptype]);
2484d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Amat->mat;
2485d52a580bSJunchao Zhang Bcsr = mmdata->Bcsr ? mmdata->Bcsr : (CsrMatrix *)Bmat->mat; /* B may be in compressed row storage */
2486d52a580bSJunchao Zhang Ccsr = (CsrMatrix *)Cmat->mat;
2487d52a580bSJunchao Zhang PetscCheck(Acsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A CSR struct");
2488d52a580bSJunchao Zhang PetscCheck(Bcsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B CSR struct");
2489d52a580bSJunchao Zhang PetscCheck(Ccsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C CSR struct");
2490d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
2491d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0)
2492d52a580bSJunchao Zhang BmatSpDescr = mmdata->Bcsr ? mmdata->matSpBDescr : Bmat->matDescr; /* B may be in compressed row storage */
2493d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2494d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2495d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2496d52a580bSJunchao Zhang #else
2497d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, mmdata->mmBuffer));
2498d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_copy(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2499d52a580bSJunchao Zhang #endif
2500d52a580bSJunchao Zhang #else
2501d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr_spgemm(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->values->data().get(), Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr,
2502d52a580bSJunchao Zhang Bcsr->num_entries, Bcsr->values->data().get(), Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->values->data().get(), Ccsr->row_offsets->data().get(),
2503d52a580bSJunchao Zhang Ccsr->column_indices->data().get()));
2504d52a580bSJunchao Zhang #endif
2505d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(mmdata->flops));
2506d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP());
2507d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
2508d52a580bSJunchao Zhang C->offloadmask = PETSC_OFFLOAD_GPU;
2509d52a580bSJunchao Zhang finalize:
2510d52a580bSJunchao Zhang /* shorter version of MatAssemblyEnd_SeqAIJ */
2511d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded, %" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
2512d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
2513d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
2514d52a580bSJunchao Zhang c->reallocs = 0;
2515d52a580bSJunchao Zhang C->info.mallocs += 0;
2516d52a580bSJunchao Zhang C->info.nz_unneeded = 0;
2517d52a580bSJunchao Zhang C->assembled = C->was_assembled = PETSC_TRUE;
2518d52a580bSJunchao Zhang C->num_ass++;
2519d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2520d52a580bSJunchao Zhang }
2521d52a580bSJunchao Zhang
MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)2522d52a580bSJunchao Zhang static PetscErrorCode MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)
2523d52a580bSJunchao Zhang {
2524d52a580bSJunchao Zhang Mat_Product *product = C->product;
2525d52a580bSJunchao Zhang Mat A, B;
2526d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp, *Bcusp, *Ccusp;
2527d52a580bSJunchao Zhang Mat_SeqAIJ *a, *b, *c;
2528d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *Amat, *Bmat, *Cmat;
2529d52a580bSJunchao Zhang CsrMatrix *Acsr, *Bcsr, *Ccsr;
2530d52a580bSJunchao Zhang PetscInt i, j, m, n, k;
2531d52a580bSJunchao Zhang PetscBool flg;
2532d52a580bSJunchao Zhang MatProductType ptype;
2533d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata;
2534d52a580bSJunchao Zhang PetscLogDouble flops;
2535d52a580bSJunchao Zhang PetscBool biscompressed, ciscompressed;
2536d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0)
2537d52a580bSJunchao Zhang int64_t C_num_rows1, C_num_cols1, C_nnz1;
2538d52a580bSJunchao Zhang hipsparseSpMatDescr_t BmatSpDescr;
2539d52a580bSJunchao Zhang #else
2540d52a580bSJunchao Zhang int cnz;
2541d52a580bSJunchao Zhang #endif
2542d52a580bSJunchao Zhang hipsparseOperation_t opA = HIPSPARSE_OPERATION_NON_TRANSPOSE, opB = HIPSPARSE_OPERATION_NON_TRANSPOSE; /* hipSPARSE spgemm doesn't support transpose yet */
2543d52a580bSJunchao Zhang
2544d52a580bSJunchao Zhang PetscFunctionBegin;
2545d52a580bSJunchao Zhang MatCheckProduct(C, 1);
2546d52a580bSJunchao Zhang PetscCheck(!C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data not empty");
2547d52a580bSJunchao Zhang A = product->A;
2548d52a580bSJunchao Zhang B = product->B;
2549d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2550d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2551d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQAIJHIPSPARSE, &flg));
2552d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for B of type %s", ((PetscObject)B)->type_name);
2553d52a580bSJunchao Zhang a = (Mat_SeqAIJ *)A->data;
2554d52a580bSJunchao Zhang b = (Mat_SeqAIJ *)B->data;
2555d52a580bSJunchao Zhang /* product data */
2556d52a580bSJunchao Zhang PetscCall(PetscNew(&mmdata));
2557d52a580bSJunchao Zhang C->product->data = mmdata;
2558d52a580bSJunchao Zhang C->product->destroy = MatProductCtxDestroy_MatMatHipsparse;
2559d52a580bSJunchao Zhang
2560d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
2561d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
2562d52a580bSJunchao Zhang Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; /* Access spptr after MatSeqAIJHIPSPARSECopyToGPU, not before */
2563d52a580bSJunchao Zhang Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr;
2564d52a580bSJunchao Zhang PetscCheck(Acusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2565d52a580bSJunchao Zhang PetscCheck(Bcusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2566d52a580bSJunchao Zhang
2567d52a580bSJunchao Zhang ptype = product->type;
2568d52a580bSJunchao Zhang if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
2569d52a580bSJunchao Zhang ptype = MATPRODUCT_AB;
2570d52a580bSJunchao Zhang product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
2571d52a580bSJunchao Zhang }
2572d52a580bSJunchao Zhang if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
2573d52a580bSJunchao Zhang ptype = MATPRODUCT_AB;
2574d52a580bSJunchao Zhang product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
2575d52a580bSJunchao Zhang }
2576d52a580bSJunchao Zhang biscompressed = PETSC_FALSE;
2577d52a580bSJunchao Zhang ciscompressed = PETSC_FALSE;
2578d52a580bSJunchao Zhang switch (ptype) {
2579d52a580bSJunchao Zhang case MATPRODUCT_AB:
2580d52a580bSJunchao Zhang m = A->rmap->n;
2581d52a580bSJunchao Zhang n = B->cmap->n;
2582d52a580bSJunchao Zhang k = A->cmap->n;
2583d52a580bSJunchao Zhang Amat = Acusp->mat;
2584d52a580bSJunchao Zhang Bmat = Bcusp->mat;
2585d52a580bSJunchao Zhang if (a->compressedrow.use) ciscompressed = PETSC_TRUE;
2586d52a580bSJunchao Zhang if (b->compressedrow.use) biscompressed = PETSC_TRUE;
2587d52a580bSJunchao Zhang break;
2588d52a580bSJunchao Zhang case MATPRODUCT_AtB:
2589d52a580bSJunchao Zhang m = A->cmap->n;
2590d52a580bSJunchao Zhang n = B->cmap->n;
2591d52a580bSJunchao Zhang k = A->rmap->n;
2592d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
2593d52a580bSJunchao Zhang Amat = Acusp->matTranspose;
2594d52a580bSJunchao Zhang Bmat = Bcusp->mat;
2595d52a580bSJunchao Zhang if (b->compressedrow.use) biscompressed = PETSC_TRUE;
2596d52a580bSJunchao Zhang break;
2597d52a580bSJunchao Zhang case MATPRODUCT_ABt:
2598d52a580bSJunchao Zhang m = A->rmap->n;
2599d52a580bSJunchao Zhang n = B->rmap->n;
2600d52a580bSJunchao Zhang k = A->cmap->n;
2601d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(B));
2602d52a580bSJunchao Zhang Amat = Acusp->mat;
2603d52a580bSJunchao Zhang Bmat = Bcusp->matTranspose;
2604d52a580bSJunchao Zhang if (a->compressedrow.use) ciscompressed = PETSC_TRUE;
2605d52a580bSJunchao Zhang break;
2606d52a580bSJunchao Zhang default:
2607d52a580bSJunchao Zhang SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2608d52a580bSJunchao Zhang }
2609d52a580bSJunchao Zhang
2610d52a580bSJunchao Zhang /* create hipsparse matrix */
2611d52a580bSJunchao Zhang PetscCall(MatSetSizes(C, m, n, m, n));
2612d52a580bSJunchao Zhang PetscCall(MatSetType(C, MATSEQAIJHIPSPARSE));
2613d52a580bSJunchao Zhang c = (Mat_SeqAIJ *)C->data;
2614d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr;
2615d52a580bSJunchao Zhang Cmat = new Mat_SeqAIJHIPSPARSEMultStruct;
2616d52a580bSJunchao Zhang Ccsr = new CsrMatrix;
2617d52a580bSJunchao Zhang
2618d52a580bSJunchao Zhang c->compressedrow.use = ciscompressed;
2619d52a580bSJunchao Zhang if (c->compressedrow.use) { /* if a is in compressed row, than c will be in compressed row format */
2620d52a580bSJunchao Zhang c->compressedrow.nrows = a->compressedrow.nrows;
2621d52a580bSJunchao Zhang PetscCall(PetscMalloc2(c->compressedrow.nrows + 1, &c->compressedrow.i, c->compressedrow.nrows, &c->compressedrow.rindex));
2622d52a580bSJunchao Zhang PetscCall(PetscArraycpy(c->compressedrow.rindex, a->compressedrow.rindex, c->compressedrow.nrows));
2623d52a580bSJunchao Zhang Ccusp->workVector = new THRUSTARRAY(c->compressedrow.nrows);
2624d52a580bSJunchao Zhang Cmat->cprowIndices = new THRUSTINTARRAY(c->compressedrow.nrows);
2625d52a580bSJunchao Zhang Cmat->cprowIndices->assign(c->compressedrow.rindex, c->compressedrow.rindex + c->compressedrow.nrows);
2626d52a580bSJunchao Zhang } else {
2627d52a580bSJunchao Zhang c->compressedrow.nrows = 0;
2628d52a580bSJunchao Zhang c->compressedrow.i = NULL;
2629d52a580bSJunchao Zhang c->compressedrow.rindex = NULL;
2630d52a580bSJunchao Zhang Ccusp->workVector = NULL;
2631d52a580bSJunchao Zhang Cmat->cprowIndices = NULL;
2632d52a580bSJunchao Zhang }
2633d52a580bSJunchao Zhang Ccusp->nrows = ciscompressed ? c->compressedrow.nrows : m;
2634d52a580bSJunchao Zhang Ccusp->mat = Cmat;
2635d52a580bSJunchao Zhang Ccusp->mat->mat = Ccsr;
2636d52a580bSJunchao Zhang Ccsr->num_rows = Ccusp->nrows;
2637d52a580bSJunchao Zhang Ccsr->num_cols = n;
2638d52a580bSJunchao Zhang Ccsr->row_offsets = new THRUSTINTARRAY32(Ccusp->nrows + 1);
2639d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&Cmat->descr));
2640d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(Cmat->descr, HIPSPARSE_INDEX_BASE_ZERO));
2641d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(Cmat->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
2642d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->alpha_one, sizeof(PetscScalar)));
2643d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->beta_zero, sizeof(PetscScalar)));
2644d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->beta_one, sizeof(PetscScalar)));
2645d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
2646d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
2647d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
2648d52a580bSJunchao Zhang if (!Ccsr->num_rows || !Ccsr->num_cols || !a->nz || !b->nz) { /* hipsparse raise errors in different calls when matrices have zero rows/columns! */
2649d52a580bSJunchao Zhang thrust::fill(thrust::device, Ccsr->row_offsets->begin(), Ccsr->row_offsets->end(), 0);
2650d52a580bSJunchao Zhang c->nz = 0;
2651d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2652d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz);
2653d52a580bSJunchao Zhang goto finalizesym;
2654d52a580bSJunchao Zhang }
2655d52a580bSJunchao Zhang
2656d52a580bSJunchao Zhang PetscCheck(Amat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A mult struct for product type %s", MatProductTypes[ptype]);
2657d52a580bSJunchao Zhang PetscCheck(Bmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B mult struct for product type %s", MatProductTypes[ptype]);
2658d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Amat->mat;
2659d52a580bSJunchao Zhang if (!biscompressed) {
2660d52a580bSJunchao Zhang Bcsr = (CsrMatrix *)Bmat->mat;
2661d52a580bSJunchao Zhang BmatSpDescr = Bmat->matDescr;
2662d52a580bSJunchao Zhang } else { /* we need to use row offsets for the full matrix */
2663d52a580bSJunchao Zhang CsrMatrix *cBcsr = (CsrMatrix *)Bmat->mat;
2664d52a580bSJunchao Zhang Bcsr = new CsrMatrix;
2665d52a580bSJunchao Zhang Bcsr->num_rows = B->rmap->n;
2666d52a580bSJunchao Zhang Bcsr->num_cols = cBcsr->num_cols;
2667d52a580bSJunchao Zhang Bcsr->num_entries = cBcsr->num_entries;
2668d52a580bSJunchao Zhang Bcsr->column_indices = cBcsr->column_indices;
2669d52a580bSJunchao Zhang Bcsr->values = cBcsr->values;
2670d52a580bSJunchao Zhang if (!Bcusp->rowoffsets_gpu) {
2671d52a580bSJunchao Zhang Bcusp->rowoffsets_gpu = new THRUSTINTARRAY32(B->rmap->n + 1);
2672d52a580bSJunchao Zhang Bcusp->rowoffsets_gpu->assign(b->i, b->i + B->rmap->n + 1);
2673d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((B->rmap->n + 1) * sizeof(PetscInt)));
2674d52a580bSJunchao Zhang }
2675d52a580bSJunchao Zhang Bcsr->row_offsets = Bcusp->rowoffsets_gpu;
2676d52a580bSJunchao Zhang mmdata->Bcsr = Bcsr;
2677d52a580bSJunchao Zhang if (Bcsr->num_rows && Bcsr->num_cols) {
2678d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&mmdata->matSpBDescr, Bcsr->num_rows, Bcsr->num_cols, Bcsr->num_entries, Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Bcsr->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2679d52a580bSJunchao Zhang }
2680d52a580bSJunchao Zhang BmatSpDescr = mmdata->matSpBDescr;
2681d52a580bSJunchao Zhang }
2682d52a580bSJunchao Zhang PetscCheck(Acsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A CSR struct");
2683d52a580bSJunchao Zhang PetscCheck(Bcsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B CSR struct");
2684d52a580bSJunchao Zhang /* precompute flops count */
2685d52a580bSJunchao Zhang if (ptype == MATPRODUCT_AB) {
2686d52a580bSJunchao Zhang for (i = 0, flops = 0; i < A->rmap->n; i++) {
2687d52a580bSJunchao Zhang const PetscInt st = a->i[i];
2688d52a580bSJunchao Zhang const PetscInt en = a->i[i + 1];
2689d52a580bSJunchao Zhang for (j = st; j < en; j++) {
2690d52a580bSJunchao Zhang const PetscInt brow = a->j[j];
2691d52a580bSJunchao Zhang flops += 2. * (b->i[brow + 1] - b->i[brow]);
2692d52a580bSJunchao Zhang }
2693d52a580bSJunchao Zhang }
2694d52a580bSJunchao Zhang } else if (ptype == MATPRODUCT_AtB) {
2695d52a580bSJunchao Zhang for (i = 0, flops = 0; i < A->rmap->n; i++) {
2696d52a580bSJunchao Zhang const PetscInt anzi = a->i[i + 1] - a->i[i];
2697d52a580bSJunchao Zhang const PetscInt bnzi = b->i[i + 1] - b->i[i];
2698d52a580bSJunchao Zhang flops += (2. * anzi) * bnzi;
2699d52a580bSJunchao Zhang }
2700d52a580bSJunchao Zhang } else flops = 0.; /* TODO */
2701d52a580bSJunchao Zhang
2702d52a580bSJunchao Zhang mmdata->flops = flops;
2703d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
2704d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0)
2705d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2706d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&Cmat->matDescr, Ccsr->num_rows, Ccsr->num_cols, 0, Ccsr->row_offsets->data().get(), NULL, NULL, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2707d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_createDescr(&mmdata->spgemmDesc));
2708d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2709d52a580bSJunchao Zhang {
2710d52a580bSJunchao Zhang /* hipsparseSpGEMMreuse has more reasonable APIs than hipsparseSpGEMM, so we prefer to use it.
2711d52a580bSJunchao Zhang We follow the sample code at https://github.com/ROCmSoftwarePlatform/hipSPARSE/blob/develop/clients/include/testing_spgemmreuse_csr.hpp
2712d52a580bSJunchao Zhang */
2713d52a580bSJunchao Zhang void *dBuffer1 = NULL;
2714d52a580bSJunchao Zhang void *dBuffer2 = NULL;
2715d52a580bSJunchao Zhang void *dBuffer3 = NULL;
2716d52a580bSJunchao Zhang /* dBuffer4, dBuffer5 are needed by hipsparseSpGEMMreuse_compute, and therefore are stored in mmdata */
2717d52a580bSJunchao Zhang size_t bufferSize1 = 0;
2718d52a580bSJunchao Zhang size_t bufferSize2 = 0;
2719d52a580bSJunchao Zhang size_t bufferSize3 = 0;
2720d52a580bSJunchao Zhang size_t bufferSize4 = 0;
2721d52a580bSJunchao Zhang size_t bufferSize5 = 0;
2722d52a580bSJunchao Zhang
2723d52a580bSJunchao Zhang /* ask bufferSize1 bytes for external memory */
2724d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_workEstimation(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize1, NULL));
2725d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&dBuffer1, bufferSize1));
2726d52a580bSJunchao Zhang /* inspect the matrices A and B to understand the memory requirement for the next step */
2727d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_workEstimation(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize1, dBuffer1));
2728d52a580bSJunchao Zhang
2729d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_nnz(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize2, NULL, &bufferSize3, NULL, &bufferSize4, NULL));
2730d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&dBuffer2, bufferSize2));
2731d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&dBuffer3, bufferSize3));
2732d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&mmdata->dBuffer4, bufferSize4));
2733d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_nnz(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize2, dBuffer2, &bufferSize3, dBuffer3, &bufferSize4, mmdata->dBuffer4));
2734d52a580bSJunchao Zhang PetscCallHIP(hipFree(dBuffer1));
2735d52a580bSJunchao Zhang PetscCallHIP(hipFree(dBuffer2));
2736d52a580bSJunchao Zhang
2737d52a580bSJunchao Zhang /* get matrix C non-zero entries C_nnz1 */
2738d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatGetSize(Cmat->matDescr, &C_num_rows1, &C_num_cols1, &C_nnz1));
2739d52a580bSJunchao Zhang c->nz = (PetscInt)C_nnz1;
2740d52a580bSJunchao Zhang /* allocate matrix C */
2741d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2742d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2743d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz);
2744d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2745d52a580bSJunchao Zhang /* update matC with the new pointers */
2746d52a580bSJunchao Zhang if (c->nz) { /* 5.5.1 has a bug with nz = 0, exposed by mat_tests_ex123_2_hypre */
2747d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCsrSetPointers(Cmat->matDescr, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get()));
2748d52a580bSJunchao Zhang
2749d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_copy(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize5, NULL));
2750d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&mmdata->dBuffer5, bufferSize5));
2751d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_copy(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize5, mmdata->dBuffer5));
2752d52a580bSJunchao Zhang PetscCallHIP(hipFree(dBuffer3));
2753d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2754d52a580bSJunchao Zhang }
2755d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Buffer sizes for type %s, result %" PetscInt_FMT " x %" PetscInt_FMT " (k %" PetscInt_FMT ", nzA %" PetscInt_FMT ", nzB %" PetscInt_FMT ", nzC %" PetscInt_FMT ") are: %ldKB %ldKB\n", MatProductTypes[ptype], m, n, k, a->nz, b->nz, c->nz, bufferSize4 / 1024, bufferSize5 / 1024));
2756d52a580bSJunchao Zhang }
2757d52a580bSJunchao Zhang #else
2758d52a580bSJunchao Zhang size_t bufSize2;
2759d52a580bSJunchao Zhang /* ask bufferSize bytes for external memory */
2760d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_workEstimation(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufSize2, NULL));
2761d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&mmdata->mmBuffer2, bufSize2));
2762d52a580bSJunchao Zhang /* inspect the matrices A and B to understand the memory requirement for the next step */
2763d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_workEstimation(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufSize2, mmdata->mmBuffer2));
2764d52a580bSJunchao Zhang /* ask bufferSize again bytes for external memory */
2765d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, NULL));
2766d52a580bSJunchao Zhang /* Similar to CUSPARSE, we need both buffers to perform the operations properly!
2767d52a580bSJunchao Zhang mmdata->mmBuffer2 does not appear anywhere in the compute/copy API
2768d52a580bSJunchao Zhang it only appears for the workEstimation stuff, but it seems it is needed in compute, so probably the address
2769d52a580bSJunchao Zhang is stored in the descriptor! What a messy API... */
2770d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&mmdata->mmBuffer, mmdata->mmBufferSize));
2771d52a580bSJunchao Zhang /* compute the intermediate product of A * B */
2772d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, mmdata->mmBuffer));
2773d52a580bSJunchao Zhang /* get matrix C non-zero entries C_nnz1 */
2774d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatGetSize(Cmat->matDescr, &C_num_rows1, &C_num_cols1, &C_nnz1));
2775d52a580bSJunchao Zhang c->nz = (PetscInt)C_nnz1;
2776d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Buffer sizes for type %s, result %" PetscInt_FMT " x %" PetscInt_FMT " (k %" PetscInt_FMT ", nzA %" PetscInt_FMT ", nzB %" PetscInt_FMT ", nzC %" PetscInt_FMT ") are: %ldKB %ldKB\n", MatProductTypes[ptype], m, n, k, a->nz, b->nz, c->nz, bufSize2 / 1024,
2777d52a580bSJunchao Zhang mmdata->mmBufferSize / 1024));
2778d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2779d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2780d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz);
2781d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2782d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCsrSetPointers(Cmat->matDescr, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get()));
2783d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_copy(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2784d52a580bSJunchao Zhang #endif
2785d52a580bSJunchao Zhang #else
2786d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_HOST));
2787d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrgemmNnz(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr, Bcsr->num_entries,
2788d52a580bSJunchao Zhang Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->row_offsets->data().get(), &cnz));
2789d52a580bSJunchao Zhang c->nz = cnz;
2790d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2791d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2792d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz);
2793d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2794d52a580bSJunchao Zhang
2795d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2796d52a580bSJunchao Zhang /* with the old gemm interface (removed from 11.0 on) we cannot compute the symbolic factorization only.
2797d52a580bSJunchao Zhang I have tried using the gemm2 interface (alpha * A * B + beta * D), which allows to do symbolic by passing NULL for values, but it seems quite buggy when
2798d52a580bSJunchao Zhang D is NULL, despite the fact that CUSPARSE documentation claims it is supported! */
2799d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr_spgemm(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->values->data().get(), Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr,
2800d52a580bSJunchao Zhang Bcsr->num_entries, Bcsr->values->data().get(), Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->values->data().get(), Ccsr->row_offsets->data().get(),
2801d52a580bSJunchao Zhang Ccsr->column_indices->data().get()));
2802d52a580bSJunchao Zhang #endif
2803d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(mmdata->flops));
2804d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
2805d52a580bSJunchao Zhang finalizesym:
2806d52a580bSJunchao Zhang c->free_a = PETSC_TRUE;
2807d52a580bSJunchao Zhang PetscCall(PetscShmgetAllocateArray(c->nz, sizeof(PetscInt), (void **)&c->j));
2808d52a580bSJunchao Zhang PetscCall(PetscShmgetAllocateArray(m + 1, sizeof(PetscInt), (void **)&c->i));
2809d52a580bSJunchao Zhang c->free_ij = PETSC_TRUE;
2810d52a580bSJunchao Zhang if (PetscDefined(USE_64BIT_INDICES)) { /* 32 to 64-bit conversion on the GPU and then copy to host (lazy) */
2811d52a580bSJunchao Zhang PetscInt *d_i = c->i;
2812d52a580bSJunchao Zhang THRUSTINTARRAY ii(Ccsr->row_offsets->size());
2813d52a580bSJunchao Zhang THRUSTINTARRAY jj(Ccsr->column_indices->size());
2814d52a580bSJunchao Zhang ii = *Ccsr->row_offsets;
2815d52a580bSJunchao Zhang jj = *Ccsr->column_indices;
2816d52a580bSJunchao Zhang if (ciscompressed) d_i = c->compressedrow.i;
2817d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(d_i, ii.data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2818d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->j, jj.data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2819d52a580bSJunchao Zhang } else {
2820d52a580bSJunchao Zhang PetscInt *d_i = c->i;
2821d52a580bSJunchao Zhang if (ciscompressed) d_i = c->compressedrow.i;
2822d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(d_i, Ccsr->row_offsets->data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2823d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->j, Ccsr->column_indices->data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2824d52a580bSJunchao Zhang }
2825d52a580bSJunchao Zhang if (ciscompressed) { /* need to expand host row offsets */
2826d52a580bSJunchao Zhang PetscInt r = 0;
2827d52a580bSJunchao Zhang c->i[0] = 0;
2828d52a580bSJunchao Zhang for (k = 0; k < c->compressedrow.nrows; k++) {
2829d52a580bSJunchao Zhang const PetscInt next = c->compressedrow.rindex[k];
2830d52a580bSJunchao Zhang const PetscInt old = c->compressedrow.i[k];
2831d52a580bSJunchao Zhang for (; r < next; r++) c->i[r + 1] = old;
2832d52a580bSJunchao Zhang }
2833d52a580bSJunchao Zhang for (; r < m; r++) c->i[r + 1] = c->compressedrow.i[c->compressedrow.nrows];
2834d52a580bSJunchao Zhang }
2835d52a580bSJunchao Zhang PetscCall(PetscLogGpuToCpu((Ccsr->column_indices->size() + Ccsr->row_offsets->size()) * sizeof(PetscInt)));
2836d52a580bSJunchao Zhang PetscCall(PetscMalloc1(m, &c->ilen));
2837d52a580bSJunchao Zhang PetscCall(PetscMalloc1(m, &c->imax));
2838d52a580bSJunchao Zhang c->maxnz = c->nz;
2839d52a580bSJunchao Zhang c->nonzerorowcnt = 0;
2840d52a580bSJunchao Zhang c->rmax = 0;
2841d52a580bSJunchao Zhang for (k = 0; k < m; k++) {
2842d52a580bSJunchao Zhang const PetscInt nn = c->i[k + 1] - c->i[k];
2843d52a580bSJunchao Zhang c->ilen[k] = c->imax[k] = nn;
2844d52a580bSJunchao Zhang c->nonzerorowcnt += (PetscInt)!!nn;
2845d52a580bSJunchao Zhang c->rmax = PetscMax(c->rmax, nn);
2846d52a580bSJunchao Zhang }
2847d52a580bSJunchao Zhang PetscCall(PetscMalloc1(c->nz, &c->a));
2848d52a580bSJunchao Zhang Ccsr->num_entries = c->nz;
2849d52a580bSJunchao Zhang
2850d52a580bSJunchao Zhang C->nonzerostate++;
2851d52a580bSJunchao Zhang PetscCall(PetscLayoutSetUp(C->rmap));
2852d52a580bSJunchao Zhang PetscCall(PetscLayoutSetUp(C->cmap));
2853d52a580bSJunchao Zhang Ccusp->nonzerostate = C->nonzerostate;
2854d52a580bSJunchao Zhang C->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
2855d52a580bSJunchao Zhang C->preallocated = PETSC_TRUE;
2856d52a580bSJunchao Zhang C->assembled = PETSC_FALSE;
2857d52a580bSJunchao Zhang C->was_assembled = PETSC_FALSE;
2858d52a580bSJunchao Zhang if (product->api_user && A->offloadmask == PETSC_OFFLOAD_BOTH && B->offloadmask == PETSC_OFFLOAD_BOTH) { /* flag the matrix C values as computed, so that the numeric phase will only call MatAssembly */
2859d52a580bSJunchao Zhang mmdata->reusesym = PETSC_TRUE;
2860d52a580bSJunchao Zhang C->offloadmask = PETSC_OFFLOAD_GPU;
2861d52a580bSJunchao Zhang }
2862d52a580bSJunchao Zhang C->ops->productnumeric = MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE;
2863d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2864d52a580bSJunchao Zhang }
2865d52a580bSJunchao Zhang
2866d52a580bSJunchao Zhang /* handles sparse or dense B */
MatProductSetFromOptions_SeqAIJHIPSPARSE(Mat mat)2867d52a580bSJunchao Zhang static PetscErrorCode MatProductSetFromOptions_SeqAIJHIPSPARSE(Mat mat)
2868d52a580bSJunchao Zhang {
2869d52a580bSJunchao Zhang Mat_Product *product = mat->product;
2870d52a580bSJunchao Zhang PetscBool isdense = PETSC_FALSE, Biscusp = PETSC_FALSE, Ciscusp = PETSC_TRUE;
2871d52a580bSJunchao Zhang
2872d52a580bSJunchao Zhang PetscFunctionBegin;
2873d52a580bSJunchao Zhang MatCheckProduct(mat, 1);
2874d52a580bSJunchao Zhang PetscCall(PetscObjectBaseTypeCompare((PetscObject)product->B, MATSEQDENSE, &isdense));
2875d52a580bSJunchao Zhang if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJHIPSPARSE, &Biscusp));
2876d52a580bSJunchao Zhang if (product->type == MATPRODUCT_ABC) {
2877d52a580bSJunchao Zhang Ciscusp = PETSC_FALSE;
2878d52a580bSJunchao Zhang if (!product->C->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJHIPSPARSE, &Ciscusp));
2879d52a580bSJunchao Zhang }
2880d52a580bSJunchao Zhang if (Biscusp && Ciscusp) { /* we can always select the CPU backend */
2881d52a580bSJunchao Zhang PetscBool usecpu = PETSC_FALSE;
2882d52a580bSJunchao Zhang switch (product->type) {
2883d52a580bSJunchao Zhang case MATPRODUCT_AB:
2884d52a580bSJunchao Zhang if (product->api_user) {
2885d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
2886d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
2887d52a580bSJunchao Zhang PetscOptionsEnd();
2888d52a580bSJunchao Zhang } else {
2889d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
2890d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
2891d52a580bSJunchao Zhang PetscOptionsEnd();
2892d52a580bSJunchao Zhang }
2893d52a580bSJunchao Zhang break;
2894d52a580bSJunchao Zhang case MATPRODUCT_AtB:
2895d52a580bSJunchao Zhang if (product->api_user) {
2896d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
2897d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
2898d52a580bSJunchao Zhang PetscOptionsEnd();
2899d52a580bSJunchao Zhang } else {
2900d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
2901d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
2902d52a580bSJunchao Zhang PetscOptionsEnd();
2903d52a580bSJunchao Zhang }
2904d52a580bSJunchao Zhang break;
2905d52a580bSJunchao Zhang case MATPRODUCT_PtAP:
2906d52a580bSJunchao Zhang if (product->api_user) {
2907d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
2908d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
2909d52a580bSJunchao Zhang PetscOptionsEnd();
2910d52a580bSJunchao Zhang } else {
2911d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
2912d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
2913d52a580bSJunchao Zhang PetscOptionsEnd();
2914d52a580bSJunchao Zhang }
2915d52a580bSJunchao Zhang break;
2916d52a580bSJunchao Zhang case MATPRODUCT_RARt:
2917d52a580bSJunchao Zhang if (product->api_user) {
2918d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatRARt", "Mat");
2919d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-matrart_backend_cpu", "Use CPU code", "MatRARt", usecpu, &usecpu, NULL));
2920d52a580bSJunchao Zhang PetscOptionsEnd();
2921d52a580bSJunchao Zhang } else {
2922d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_RARt", "Mat");
2923d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatRARt", usecpu, &usecpu, NULL));
2924d52a580bSJunchao Zhang PetscOptionsEnd();
2925d52a580bSJunchao Zhang }
2926d52a580bSJunchao Zhang break;
2927d52a580bSJunchao Zhang case MATPRODUCT_ABC:
2928d52a580bSJunchao Zhang if (product->api_user) {
2929d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMatMult", "Mat");
2930d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-matmatmatmult_backend_cpu", "Use CPU code", "MatMatMatMult", usecpu, &usecpu, NULL));
2931d52a580bSJunchao Zhang PetscOptionsEnd();
2932d52a580bSJunchao Zhang } else {
2933d52a580bSJunchao Zhang PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_ABC", "Mat");
2934d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMatMult", usecpu, &usecpu, NULL));
2935d52a580bSJunchao Zhang PetscOptionsEnd();
2936d52a580bSJunchao Zhang }
2937d52a580bSJunchao Zhang break;
2938d52a580bSJunchao Zhang default:
2939d52a580bSJunchao Zhang break;
2940d52a580bSJunchao Zhang }
2941d52a580bSJunchao Zhang if (usecpu) Biscusp = Ciscusp = PETSC_FALSE;
2942d52a580bSJunchao Zhang }
2943d52a580bSJunchao Zhang /* dispatch */
2944d52a580bSJunchao Zhang if (isdense) {
2945d52a580bSJunchao Zhang switch (product->type) {
2946d52a580bSJunchao Zhang case MATPRODUCT_AB:
2947d52a580bSJunchao Zhang case MATPRODUCT_AtB:
2948d52a580bSJunchao Zhang case MATPRODUCT_ABt:
2949d52a580bSJunchao Zhang case MATPRODUCT_PtAP:
2950d52a580bSJunchao Zhang case MATPRODUCT_RARt:
2951d52a580bSJunchao Zhang if (product->A->boundtocpu) PetscCall(MatProductSetFromOptions_SeqAIJ_SeqDense(mat));
2952d52a580bSJunchao Zhang else mat->ops->productsymbolic = MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP;
2953d52a580bSJunchao Zhang break;
2954d52a580bSJunchao Zhang case MATPRODUCT_ABC:
2955d52a580bSJunchao Zhang mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
2956d52a580bSJunchao Zhang break;
2957d52a580bSJunchao Zhang default:
2958d52a580bSJunchao Zhang break;
2959d52a580bSJunchao Zhang }
2960d52a580bSJunchao Zhang } else if (Biscusp && Ciscusp) {
2961d52a580bSJunchao Zhang switch (product->type) {
2962d52a580bSJunchao Zhang case MATPRODUCT_AB:
2963d52a580bSJunchao Zhang case MATPRODUCT_AtB:
2964d52a580bSJunchao Zhang case MATPRODUCT_ABt:
2965d52a580bSJunchao Zhang mat->ops->productsymbolic = MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE;
2966d52a580bSJunchao Zhang break;
2967d52a580bSJunchao Zhang case MATPRODUCT_PtAP:
2968d52a580bSJunchao Zhang case MATPRODUCT_RARt:
2969d52a580bSJunchao Zhang case MATPRODUCT_ABC:
2970d52a580bSJunchao Zhang mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
2971d52a580bSJunchao Zhang break;
2972d52a580bSJunchao Zhang default:
2973d52a580bSJunchao Zhang break;
2974d52a580bSJunchao Zhang }
2975d52a580bSJunchao Zhang } else PetscCall(MatProductSetFromOptions_SeqAIJ(mat)); /* fallback for AIJ */
2976d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2977d52a580bSJunchao Zhang }
2978d52a580bSJunchao Zhang
MatMult_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy)2979d52a580bSJunchao Zhang static PetscErrorCode MatMult_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
2980d52a580bSJunchao Zhang {
2981d52a580bSJunchao Zhang PetscFunctionBegin;
2982d52a580bSJunchao Zhang PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_FALSE, PETSC_FALSE));
2983d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2984d52a580bSJunchao Zhang }
2985d52a580bSJunchao Zhang
MatMultAdd_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy,Vec zz)2986d52a580bSJunchao Zhang static PetscErrorCode MatMultAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
2987d52a580bSJunchao Zhang {
2988d52a580bSJunchao Zhang PetscFunctionBegin;
2989d52a580bSJunchao Zhang PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_FALSE, PETSC_FALSE));
2990d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2991d52a580bSJunchao Zhang }
2992d52a580bSJunchao Zhang
MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy)2993d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
2994d52a580bSJunchao Zhang {
2995d52a580bSJunchao Zhang PetscFunctionBegin;
2996d52a580bSJunchao Zhang PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_TRUE, PETSC_TRUE));
2997d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
2998d52a580bSJunchao Zhang }
2999d52a580bSJunchao Zhang
MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy,Vec zz)3000d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
3001d52a580bSJunchao Zhang {
3002d52a580bSJunchao Zhang PetscFunctionBegin;
3003d52a580bSJunchao Zhang PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_TRUE, PETSC_TRUE));
3004d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3005d52a580bSJunchao Zhang }
3006d52a580bSJunchao Zhang
MatMultTranspose_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy)3007d52a580bSJunchao Zhang static PetscErrorCode MatMultTranspose_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
3008d52a580bSJunchao Zhang {
3009d52a580bSJunchao Zhang PetscFunctionBegin;
3010d52a580bSJunchao Zhang PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_TRUE, PETSC_FALSE));
3011d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3012d52a580bSJunchao Zhang }
3013d52a580bSJunchao Zhang
ScatterAdd(PetscInt n,PetscInt * idx,const PetscScalar * x,PetscScalar * y)3014d52a580bSJunchao Zhang __global__ static void ScatterAdd(PetscInt n, PetscInt *idx, const PetscScalar *x, PetscScalar *y)
3015d52a580bSJunchao Zhang {
3016d52a580bSJunchao Zhang int i = blockIdx.x * blockDim.x + threadIdx.x;
3017d52a580bSJunchao Zhang if (i < n) y[idx[i]] += x[i];
3018d52a580bSJunchao Zhang }
3019d52a580bSJunchao Zhang
3020d52a580bSJunchao Zhang /* z = op(A) x + y. If trans & !herm, op = ^T; if trans & herm, op = ^H; if !trans, op = no-op */
MatMultAddKernel_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy,Vec zz,PetscBool trans,PetscBool herm)3021d52a580bSJunchao Zhang static PetscErrorCode MatMultAddKernel_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz, PetscBool trans, PetscBool herm)
3022d52a580bSJunchao Zhang {
3023d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
3024d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3025d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *matstruct;
3026d52a580bSJunchao Zhang PetscScalar *xarray, *zarray, *dptr, *beta, *xptr;
3027d52a580bSJunchao Zhang hipsparseOperation_t opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
3028d52a580bSJunchao Zhang PetscBool compressed;
3029d52a580bSJunchao Zhang PetscInt nx, ny;
3030d52a580bSJunchao Zhang
3031d52a580bSJunchao Zhang PetscFunctionBegin;
3032d52a580bSJunchao Zhang PetscCheck(!herm || trans, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Hermitian and not transpose not supported");
3033d52a580bSJunchao Zhang if (!a->nz) {
3034d52a580bSJunchao Zhang if (yy) PetscCall(VecSeq_HIP::Copy(yy, zz));
3035d52a580bSJunchao Zhang else PetscCall(VecSeq_HIP::Set(zz, 0));
3036d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3037d52a580bSJunchao Zhang }
3038d52a580bSJunchao Zhang /* The line below is necessary due to the operations that modify the matrix on the CPU (axpy, scale, etc) */
3039d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3040d52a580bSJunchao Zhang if (!trans) {
3041d52a580bSJunchao Zhang matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat;
3042d52a580bSJunchao Zhang PetscCheck(matstruct, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "SeqAIJHIPSPARSE does not have a 'mat' (need to fix)");
3043d52a580bSJunchao Zhang } else {
3044d52a580bSJunchao Zhang if (herm || !A->form_explicit_transpose) {
3045d52a580bSJunchao Zhang opA = herm ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE : HIPSPARSE_OPERATION_TRANSPOSE;
3046d52a580bSJunchao Zhang matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat;
3047d52a580bSJunchao Zhang } else {
3048d52a580bSJunchao Zhang if (!hipsparsestruct->matTranspose) PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
3049d52a580bSJunchao Zhang matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->matTranspose;
3050d52a580bSJunchao Zhang }
3051d52a580bSJunchao Zhang }
3052d52a580bSJunchao Zhang /* Does the matrix use compressed rows (i.e., drop zero rows)? */
3053d52a580bSJunchao Zhang compressed = matstruct->cprowIndices ? PETSC_TRUE : PETSC_FALSE;
3054d52a580bSJunchao Zhang try {
3055d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(xx, (const PetscScalar **)&xarray));
3056d52a580bSJunchao Zhang if (yy == zz) PetscCall(VecHIPGetArray(zz, &zarray)); /* read & write zz, so need to get up-to-date zarray on GPU */
3057d52a580bSJunchao Zhang else PetscCall(VecHIPGetArrayWrite(zz, &zarray)); /* write zz, so no need to init zarray on GPU */
3058d52a580bSJunchao Zhang
3059d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
3060d52a580bSJunchao Zhang if (opA == HIPSPARSE_OPERATION_NON_TRANSPOSE) {
3061d52a580bSJunchao Zhang /* z = A x + beta y.
3062d52a580bSJunchao Zhang If A is compressed (with less rows), then Ax is shorter than the full z, so we need a work vector to store Ax.
3063d52a580bSJunchao Zhang When A is non-compressed, and z = y, we can set beta=1 to compute y = Ax + y in one call.
3064d52a580bSJunchao Zhang */
3065d52a580bSJunchao Zhang xptr = xarray;
3066d52a580bSJunchao Zhang dptr = compressed ? hipsparsestruct->workVector->data().get() : zarray;
3067d52a580bSJunchao Zhang beta = (yy == zz && !compressed) ? matstruct->beta_one : matstruct->beta_zero;
3068d52a580bSJunchao Zhang /* Get length of x, y for y=Ax. ny might be shorter than the work vector's allocated length, since the work vector is
3069d52a580bSJunchao Zhang allocated to accommodate different uses. So we get the length info directly from mat.
3070d52a580bSJunchao Zhang */
3071d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
3072d52a580bSJunchao Zhang CsrMatrix *mat = (CsrMatrix *)matstruct->mat;
3073d52a580bSJunchao Zhang nx = mat->num_cols;
3074d52a580bSJunchao Zhang ny = mat->num_rows;
3075d52a580bSJunchao Zhang }
3076d52a580bSJunchao Zhang } else {
3077d52a580bSJunchao Zhang /* z = A^T x + beta y
3078d52a580bSJunchao Zhang If A is compressed, then we need a work vector as the shorter version of x to compute A^T x.
3079d52a580bSJunchao Zhang Note A^Tx is of full length, so we set beta to 1.0 if y exists.
3080d52a580bSJunchao Zhang */
3081d52a580bSJunchao Zhang xptr = compressed ? hipsparsestruct->workVector->data().get() : xarray;
3082d52a580bSJunchao Zhang dptr = zarray;
3083d52a580bSJunchao Zhang beta = yy ? matstruct->beta_one : matstruct->beta_zero;
3084d52a580bSJunchao Zhang if (compressed) { /* Scatter x to work vector */
3085d52a580bSJunchao Zhang thrust::device_ptr<PetscScalar> xarr = thrust::device_pointer_cast(xarray);
3086d52a580bSJunchao Zhang thrust::for_each(
3087d52a580bSJunchao Zhang #if PetscDefined(HAVE_THRUST_ASYNC)
3088d52a580bSJunchao Zhang thrust::hip::par.on(PetscDefaultHipStream),
3089d52a580bSJunchao Zhang #endif
3090d52a580bSJunchao Zhang thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(xarr, matstruct->cprowIndices->begin()))),
3091d52a580bSJunchao Zhang thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(xarr, matstruct->cprowIndices->begin()))) + matstruct->cprowIndices->size(), VecHIPEqualsReverse());
3092d52a580bSJunchao Zhang }
3093d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
3094d52a580bSJunchao Zhang CsrMatrix *mat = (CsrMatrix *)matstruct->mat;
3095d52a580bSJunchao Zhang nx = mat->num_rows;
3096d52a580bSJunchao Zhang ny = mat->num_cols;
3097d52a580bSJunchao Zhang }
3098d52a580bSJunchao Zhang }
3099d52a580bSJunchao Zhang /* csr_spmv does y = alpha op(A) x + beta y */
3100d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
3101*5a884c48SSatish Balay #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0) && !PETSC_PKG_HIP_VERSION_EQ(7, 2, 0)
3102d52a580bSJunchao Zhang PetscCheck(opA >= 0 && opA <= 2, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE API on hipsparseOperation_t has changed and PETSc has not been updated accordingly");
3103d52a580bSJunchao Zhang if (!matstruct->hipSpMV[opA].initialized) { /* built on demand */
3104d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&matstruct->hipSpMV[opA].vecXDescr, nx, xptr, hipsparse_scalartype));
3105d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&matstruct->hipSpMV[opA].vecYDescr, ny, dptr, hipsparse_scalartype));
3106d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMV_bufferSize(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->matDescr, matstruct->hipSpMV[opA].vecXDescr, beta, matstruct->hipSpMV[opA].vecYDescr, hipsparse_scalartype, hipsparsestruct->spmvAlg,
3107d52a580bSJunchao Zhang &matstruct->hipSpMV[opA].spmvBufferSize));
3108d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&matstruct->hipSpMV[opA].spmvBuffer, matstruct->hipSpMV[opA].spmvBufferSize));
3109d52a580bSJunchao Zhang matstruct->hipSpMV[opA].initialized = PETSC_TRUE;
3110d52a580bSJunchao Zhang } else {
3111d52a580bSJunchao Zhang /* x, y's value pointers might change between calls, but their shape is kept, so we just update pointers */
3112d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(matstruct->hipSpMV[opA].vecXDescr, xptr));
3113d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(matstruct->hipSpMV[opA].vecYDescr, dptr));
3114d52a580bSJunchao Zhang }
3115d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMV(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->matDescr, /* built in MatSeqAIJHIPSPARSECopyToGPU() or MatSeqAIJHIPSPARSEFormExplicitTranspose() */
3116d52a580bSJunchao Zhang matstruct->hipSpMV[opA].vecXDescr, beta, matstruct->hipSpMV[opA].vecYDescr, hipsparse_scalartype, hipsparsestruct->spmvAlg, matstruct->hipSpMV[opA].spmvBuffer));
3117d52a580bSJunchao Zhang #else
3118d52a580bSJunchao Zhang CsrMatrix *mat = (CsrMatrix *)matstruct->mat;
3119*5a884c48SSatish Balay nx = mat->num_rows; /* nx,ny are set before the #if block, set them again to avoid set-but-not-used warning */
3120*5a884c48SSatish Balay ny = mat->num_cols;
3121*5a884c48SSatish Balay PetscCallHIPSPARSE(hipsparse_csr_spmv(hipsparsestruct->handle, opA, nx, ny, mat->num_entries, matstruct->alpha_one, matstruct->descr, mat->values->data().get(), mat->row_offsets->data().get(), mat->column_indices->data().get(), xptr, beta, dptr));
3122d52a580bSJunchao Zhang #endif
3123d52a580bSJunchao Zhang } else {
3124d52a580bSJunchao Zhang if (hipsparsestruct->nrows) {
3125d52a580bSJunchao Zhang hipsparseHybMat_t hybMat = (hipsparseHybMat_t)matstruct->mat;
3126d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_hyb_spmv(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->descr, hybMat, xptr, beta, dptr));
3127d52a580bSJunchao Zhang }
3128d52a580bSJunchao Zhang }
3129d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
3130d52a580bSJunchao Zhang
3131d52a580bSJunchao Zhang if (opA == HIPSPARSE_OPERATION_NON_TRANSPOSE) {
3132d52a580bSJunchao Zhang if (yy) { /* MatMultAdd: zz = A*xx + yy */
3133d52a580bSJunchao Zhang if (compressed) { /* A is compressed. We first copy yy to zz, then ScatterAdd the work vector to zz */
3134d52a580bSJunchao Zhang PetscCall(VecSeq_HIP::Copy(yy, zz)); /* zz = yy */
3135d52a580bSJunchao Zhang } else if (zz != yy) { /* A is not compressed. zz already contains A*xx, and we just need to add yy */
3136d52a580bSJunchao Zhang PetscCall(VecSeq_HIP::AXPY(zz, 1.0, yy)); /* zz += yy */
3137d52a580bSJunchao Zhang }
3138d52a580bSJunchao Zhang } else if (compressed) { /* MatMult: zz = A*xx. A is compressed, so we zero zz first, then ScatterAdd the work vector to zz */
3139d52a580bSJunchao Zhang PetscCall(VecSeq_HIP::Set(zz, 0));
3140d52a580bSJunchao Zhang }
3141d52a580bSJunchao Zhang
3142d52a580bSJunchao Zhang /* ScatterAdd the result from work vector into the full vector when A is compressed */
3143d52a580bSJunchao Zhang if (compressed) {
3144d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
3145d52a580bSJunchao Zhang /* I wanted to make this for_each asynchronous but failed. thrust::async::for_each() returns an event (internally registered)
3146d52a580bSJunchao Zhang and in the destructor of the scope, it will call hipStreamSynchronize() on this stream. One has to store all events to
3147d52a580bSJunchao Zhang prevent that. So I just add a ScatterAdd kernel.
3148d52a580bSJunchao Zhang */
3149d52a580bSJunchao Zhang #if 0
3150d52a580bSJunchao Zhang thrust::device_ptr<PetscScalar> zptr = thrust::device_pointer_cast(zarray);
3151d52a580bSJunchao Zhang thrust::async::for_each(thrust::hip::par.on(hipsparsestruct->stream),
3152d52a580bSJunchao Zhang thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(zptr, matstruct->cprowIndices->begin()))),
3153d52a580bSJunchao Zhang thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(zptr, matstruct->cprowIndices->begin()))) + matstruct->cprowIndices->size(),
3154d52a580bSJunchao Zhang VecHIPPlusEquals());
3155d52a580bSJunchao Zhang #else
3156d52a580bSJunchao Zhang PetscInt n = matstruct->cprowIndices->size();
3157d52a580bSJunchao Zhang hipLaunchKernelGGL(ScatterAdd, dim3((n + 255) / 256), dim3(256), 0, PetscDefaultHipStream, n, matstruct->cprowIndices->data().get(), hipsparsestruct->workVector->data().get(), zarray);
3158d52a580bSJunchao Zhang #endif
3159d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
3160d52a580bSJunchao Zhang }
3161d52a580bSJunchao Zhang } else {
3162d52a580bSJunchao Zhang if (yy && yy != zz) PetscCall(VecSeq_HIP::AXPY(zz, 1.0, yy)); /* zz += yy */
3163d52a580bSJunchao Zhang }
3164d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(xx, (const PetscScalar **)&xarray));
3165d52a580bSJunchao Zhang if (yy == zz) PetscCall(VecHIPRestoreArray(zz, &zarray));
3166d52a580bSJunchao Zhang else PetscCall(VecHIPRestoreArrayWrite(zz, &zarray));
3167d52a580bSJunchao Zhang } catch (char *ex) {
3168d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
3169d52a580bSJunchao Zhang }
3170d52a580bSJunchao Zhang if (yy) PetscCall(PetscLogGpuFlops(2.0 * a->nz));
3171d52a580bSJunchao Zhang else PetscCall(PetscLogGpuFlops(2.0 * a->nz - a->nonzerorowcnt));
3172d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3173d52a580bSJunchao Zhang }
3174d52a580bSJunchao Zhang
MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy,Vec zz)3175d52a580bSJunchao Zhang static PetscErrorCode MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
3176d52a580bSJunchao Zhang {
3177d52a580bSJunchao Zhang PetscFunctionBegin;
3178d52a580bSJunchao Zhang PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_TRUE, PETSC_FALSE));
3179d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3180d52a580bSJunchao Zhang }
3181d52a580bSJunchao Zhang
MatAssemblyEnd_SeqAIJHIPSPARSE(Mat A,MatAssemblyType mode)3182d52a580bSJunchao Zhang static PetscErrorCode MatAssemblyEnd_SeqAIJHIPSPARSE(Mat A, MatAssemblyType mode)
3183d52a580bSJunchao Zhang {
3184d52a580bSJunchao Zhang PetscFunctionBegin;
3185d52a580bSJunchao Zhang PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
3186d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3187d52a580bSJunchao Zhang }
3188d52a580bSJunchao Zhang
3189d52a580bSJunchao Zhang /*@
3190d52a580bSJunchao Zhang MatCreateSeqAIJHIPSPARSE - Creates a sparse matrix in `MATAIJHIPSPARSE` (compressed row) format.
3191d52a580bSJunchao Zhang This matrix will ultimately pushed down to AMD GPUs and use the HIPSPARSE library for calculations.
3192d52a580bSJunchao Zhang
3193d52a580bSJunchao Zhang Collective
3194d52a580bSJunchao Zhang
3195d52a580bSJunchao Zhang Input Parameters:
3196d52a580bSJunchao Zhang + comm - MPI communicator, set to `PETSC_COMM_SELF`
3197d52a580bSJunchao Zhang . m - number of rows
3198d52a580bSJunchao Zhang . n - number of columns
3199d52a580bSJunchao Zhang . nz - number of nonzeros per row (same for all rows), ignored if `nnz` is set
3200d52a580bSJunchao Zhang - nnz - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
3201d52a580bSJunchao Zhang
3202d52a580bSJunchao Zhang Output Parameter:
3203d52a580bSJunchao Zhang . A - the matrix
3204d52a580bSJunchao Zhang
3205d52a580bSJunchao Zhang Level: intermediate
3206d52a580bSJunchao Zhang
3207d52a580bSJunchao Zhang Notes:
3208d52a580bSJunchao Zhang It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
3209d52a580bSJunchao Zhang `MatXXXXSetPreallocation()` paradgm instead of this routine directly.
3210d52a580bSJunchao Zhang [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation`]
3211d52a580bSJunchao Zhang
3212d52a580bSJunchao Zhang The AIJ format (compressed row storage), is fully compatible with standard Fortran
3213d52a580bSJunchao Zhang storage. That is, the stored row and column indices can begin at
3214d52a580bSJunchao Zhang either one (as in Fortran) or zero.
3215d52a580bSJunchao Zhang
3216d52a580bSJunchao Zhang Specify the preallocated storage with either `nz` or `nnz` (not both).
3217d52a580bSJunchao Zhang Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
3218d52a580bSJunchao Zhang allocation.
3219d52a580bSJunchao Zhang
3220d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATSEQAIJHIPSPARSE`, `MATAIJHIPSPARSE`
3221d52a580bSJunchao Zhang @*/
MatCreateSeqAIJHIPSPARSE(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt nz,const PetscInt nnz[],Mat * A)3222d52a580bSJunchao Zhang PetscErrorCode MatCreateSeqAIJHIPSPARSE(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
3223d52a580bSJunchao Zhang {
3224d52a580bSJunchao Zhang PetscFunctionBegin;
3225d52a580bSJunchao Zhang PetscCall(MatCreate(comm, A));
3226d52a580bSJunchao Zhang PetscCall(MatSetSizes(*A, m, n, m, n));
3227d52a580bSJunchao Zhang PetscCall(MatSetType(*A, MATSEQAIJHIPSPARSE));
3228d52a580bSJunchao Zhang PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
3229d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3230d52a580bSJunchao Zhang }
3231d52a580bSJunchao Zhang
MatDestroy_SeqAIJHIPSPARSE(Mat A)3232d52a580bSJunchao Zhang static PetscErrorCode MatDestroy_SeqAIJHIPSPARSE(Mat A)
3233d52a580bSJunchao Zhang {
3234d52a580bSJunchao Zhang PetscFunctionBegin;
3235d52a580bSJunchao Zhang if (A->factortype == MAT_FACTOR_NONE) PetscCall(MatSeqAIJHIPSPARSE_Destroy(A));
3236d52a580bSJunchao Zhang else PetscCall(MatSeqAIJHIPSPARSETriFactors_Destroy((Mat_SeqAIJHIPSPARSETriFactors **)&A->spptr));
3237d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", NULL));
3238d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetFormat_C", NULL));
3239d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetUseCPUSolve_C", NULL));
3240d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", NULL));
3241d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", NULL));
3242d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", NULL));
3243d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
3244d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
3245d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
3246d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijhipsparse_hypre_C", NULL));
3247d52a580bSJunchao Zhang PetscCall(MatDestroy_SeqAIJ(A));
3248d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3249d52a580bSJunchao Zhang }
3250d52a580bSJunchao Zhang
MatDuplicate_SeqAIJHIPSPARSE(Mat A,MatDuplicateOption cpvalues,Mat * B)3251d52a580bSJunchao Zhang static PetscErrorCode MatDuplicate_SeqAIJHIPSPARSE(Mat A, MatDuplicateOption cpvalues, Mat *B)
3252d52a580bSJunchao Zhang {
3253d52a580bSJunchao Zhang PetscFunctionBegin;
3254d52a580bSJunchao Zhang PetscCall(MatDuplicate_SeqAIJ(A, cpvalues, B));
3255d52a580bSJunchao Zhang PetscCall(MatConvert_SeqAIJ_SeqAIJHIPSPARSE(*B, MATSEQAIJHIPSPARSE, MAT_INPLACE_MATRIX, B));
3256d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3257d52a580bSJunchao Zhang }
3258d52a580bSJunchao Zhang
/*
  Y <- a*X + Y.

  If X and Y do not share the same axpy implementation (e.g. one of them is bound to the CPU), the
  whole operation is done by the host routine MatAXPY_SeqAIJ(). Otherwise both matrices are copied to
  the GPU (only CSR storage is accepted) and the declared nonzero structure selects the kernel:
    - SUBSET_NONZERO_PATTERN : hipsparse csr spgeam computes a*X + 1*Y, writing into Y's value array;
    - SAME_NONZERO_PATTERN   : one hipBLAS axpy over the two value arrays;
    - otherwise              : host fallback MatAXPY_SeqAIJ().
  A structure not declared SAME is promoted to SAME when X and Y have equal nonzero counts, no
  compressed-row storage, and identical CSR row offsets and column indices on the device.
*/
static PetscErrorCode MatAXPY_SeqAIJHIPSPARSE(Mat Y, PetscScalar a, Mat X, MatStructure str)
{
  Mat_SeqAIJ          *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
  Mat_SeqAIJHIPSPARSE *cy;
  Mat_SeqAIJHIPSPARSE *cx;
  PetscScalar         *ay;
  const PetscScalar   *ax;
  CsrMatrix           *csry, *csrx;

  PetscFunctionBegin;
  cy = (Mat_SeqAIJHIPSPARSE *)Y->spptr;
  cx = (Mat_SeqAIJHIPSPARSE *)X->spptr;
  /* Different axpy implementations: Y's host values will change, so its cached transpose is marked stale */
  if (X->ops->axpy != Y->ops->axpy) {
    PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(Y, PETSC_FALSE));
    PetscCall(MatAXPY_SeqAIJ(Y, a, X, str));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  /* if we are here, it means both matrices are bound to GPU */
  PetscCall(MatSeqAIJHIPSPARSECopyToGPU(Y));
  PetscCall(MatSeqAIJHIPSPARSECopyToGPU(X));
  PetscCheck(cy->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)Y), PETSC_ERR_GPU, "only MAT_HIPSPARSE_CSR supported");
  PetscCheck(cx->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)X), PETSC_ERR_GPU, "only MAT_HIPSPARSE_CSR supported");
  csry = (CsrMatrix *)cy->mat->mat;
  csrx = (CsrMatrix *)cx->mat->mat;
  /* see if we can turn this into a hipblas axpy */
  /* (compare the device CSR structure of X and Y; identical structure means the value arrays line up entry by entry) */
  if (str != SAME_NONZERO_PATTERN && x->nz == y->nz && !x->compressedrow.use && !y->compressedrow.use) {
    bool eq = thrust::equal(thrust::device, csry->row_offsets->begin(), csry->row_offsets->end(), csrx->row_offsets->begin());
    if (eq) eq = thrust::equal(thrust::device, csry->column_indices->begin(), csry->column_indices->end(), csrx->column_indices->begin());
    if (eq) str = SAME_NONZERO_PATTERN;
  }
  /* spgeam is buggy with one column */
  /* (so single-column matrices with a non-SAME structure always take the host fallback below) */
  if (Y->cmap->n == 1 && str != SAME_NONZERO_PATTERN) str = DIFFERENT_NONZERO_PATTERN;
  if (str == SUBSET_NONZERO_PATTERN) {
    PetscScalar b = 1.0; /* coefficient of Y in a*X + b*Y */
#if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
    size_t bufferSize;
    void  *buffer;
#endif

    PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(X, &ax));
    PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay));
    /* a and b live on the host, so switch the handle to host pointer mode for spgeam and restore it afterwards */
    PetscCallHIPSPARSE(hipsparseSetPointerMode(cy->handle, HIPSPARSE_POINTER_MODE_HOST));
#if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
    /* Newer hipSPARSE: query the spgeam workspace size, allocate it, run, then free it. Output C is Y itself (same descr/arrays). */
    PetscCallHIPSPARSE(hipsparse_csr_spgeam_bufferSize(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(),
                                                       csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get(), &bufferSize));
    PetscCallHIP(hipMalloc(&buffer, bufferSize));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallHIPSPARSE(hipsparse_csr_spgeam(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(),
                                            csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get(), buffer));
    PetscCall(PetscLogGpuFlops(x->nz + y->nz));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCallHIP(hipFree(buffer));
#else
    /* Older hipSPARSE: spgeam without an explicit workspace argument */
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallHIPSPARSE(hipsparse_csr_spgeam(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(),
                                            csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get()));
    PetscCall(PetscLogGpuFlops(x->nz + y->nz));
    PetscCall(PetscLogGpuTimeEnd());
#endif
    PetscCallHIPSPARSE(hipsparseSetPointerMode(cy->handle, HIPSPARSE_POINTER_MODE_DEVICE));
    PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(X, &ax));
    PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay));
  } else if (str == SAME_NONZERO_PATTERN) {
    /* Same structure: the device value arrays correspond entry by entry, so a dense axpy of length nz suffices */
    hipblasHandle_t hipblasv2handle;
    PetscBLASInt    one = 1, bnz = 1;

    PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(X, &ax));
    PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay));
    PetscCall(PetscHIPBLASGetHandle(&hipblasv2handle));
    PetscCall(PetscBLASIntCast(x->nz, &bnz));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallHIPBLAS(hipblasXaxpy(hipblasv2handle, bnz, &a, ax, one, ay, one));
    PetscCall(PetscLogGpuFlops(2.0 * bnz));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(X, &ax));
    PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay));
  } else {
    /* General structure: delegate to the host implementation */
    PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(Y, PETSC_FALSE));
    PetscCall(MatAXPY_SeqAIJ(Y, a, X, str));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
3341d52a580bSJunchao Zhang
MatScale_SeqAIJHIPSPARSE(Mat Y,PetscScalar a)3342d52a580bSJunchao Zhang static PetscErrorCode MatScale_SeqAIJHIPSPARSE(Mat Y, PetscScalar a)
3343d52a580bSJunchao Zhang {
3344d52a580bSJunchao Zhang Mat_SeqAIJ *y = (Mat_SeqAIJ *)Y->data;
3345d52a580bSJunchao Zhang PetscScalar *ay;
3346d52a580bSJunchao Zhang hipblasHandle_t hipblasv2handle;
3347d52a580bSJunchao Zhang PetscBLASInt one = 1, bnz = 1;
3348d52a580bSJunchao Zhang
3349d52a580bSJunchao Zhang PetscFunctionBegin;
3350d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay));
3351d52a580bSJunchao Zhang PetscCall(PetscHIPBLASGetHandle(&hipblasv2handle));
3352d52a580bSJunchao Zhang PetscCall(PetscBLASIntCast(y->nz, &bnz));
3353d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
3354d52a580bSJunchao Zhang PetscCallHIPBLAS(hipblasXscal(hipblasv2handle, bnz, &a, ay, one));
3355d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(bnz));
3356d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
3357d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay));
3358d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3359d52a580bSJunchao Zhang }
3360d52a580bSJunchao Zhang
MatZeroEntries_SeqAIJHIPSPARSE(Mat A)3361d52a580bSJunchao Zhang static PetscErrorCode MatZeroEntries_SeqAIJHIPSPARSE(Mat A)
3362d52a580bSJunchao Zhang {
3363d52a580bSJunchao Zhang PetscBool both = PETSC_FALSE;
3364d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
3365d52a580bSJunchao Zhang
3366d52a580bSJunchao Zhang PetscFunctionBegin;
3367d52a580bSJunchao Zhang if (A->factortype == MAT_FACTOR_NONE) {
3368d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *spptr = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3369d52a580bSJunchao Zhang if (spptr->mat) {
3370d52a580bSJunchao Zhang CsrMatrix *matrix = (CsrMatrix *)spptr->mat->mat;
3371d52a580bSJunchao Zhang if (matrix->values) {
3372d52a580bSJunchao Zhang both = PETSC_TRUE;
3373d52a580bSJunchao Zhang thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), 0.);
3374d52a580bSJunchao Zhang }
3375d52a580bSJunchao Zhang }
3376d52a580bSJunchao Zhang if (spptr->matTranspose) {
3377d52a580bSJunchao Zhang CsrMatrix *matrix = (CsrMatrix *)spptr->matTranspose->mat;
3378d52a580bSJunchao Zhang if (matrix->values) thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), 0.);
3379d52a580bSJunchao Zhang }
3380d52a580bSJunchao Zhang }
3381d52a580bSJunchao Zhang //PetscCall(MatZeroEntries_SeqAIJ(A));
3382d52a580bSJunchao Zhang PetscCall(PetscArrayzero(a->a, a->i[A->rmap->n]));
3383d52a580bSJunchao Zhang if (both) A->offloadmask = PETSC_OFFLOAD_BOTH;
3384d52a580bSJunchao Zhang else A->offloadmask = PETSC_OFFLOAD_CPU;
3385d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3386d52a580bSJunchao Zhang }
3387d52a580bSJunchao Zhang
MatGetCurrentMemType_SeqAIJHIPSPARSE(PETSC_UNUSED Mat A,PetscMemType * m)3388d52a580bSJunchao Zhang static PetscErrorCode MatGetCurrentMemType_SeqAIJHIPSPARSE(PETSC_UNUSED Mat A, PetscMemType *m)
3389d52a580bSJunchao Zhang {
3390d52a580bSJunchao Zhang PetscFunctionBegin;
3391d52a580bSJunchao Zhang *m = PETSC_MEMTYPE_HIP;
3392d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3393d52a580bSJunchao Zhang }
3394d52a580bSJunchao Zhang
/*
  Binds (flg = PETSC_TRUE) or unbinds (flg = PETSC_FALSE) a SeqAIJHIPSPARSE matrix to the CPU.

  For factored matrices only the boundtocpu flag is recorded.
  Binding: the values are brought back from the GPU, the operation table is switched to the host
  SeqAIJ implementations, the SeqAIJ array-access hooks (a->ops) are cleared, and the GPU-specific
  composed functions are removed.
  Unbinding: the HIPSPARSE implementations and array-access hooks are installed and the GPU-specific
  composed functions are (re)composed.
  Inode routines are enabled only when bound to the CPU and inode information exists.
*/
static PetscErrorCode MatBindToCPU_SeqAIJHIPSPARSE(Mat A, PetscBool flg)
{
  Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;

  PetscFunctionBegin;
  if (A->factortype != MAT_FACTOR_NONE) {
    A->boundtocpu = flg;
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  if (flg) {
    /* make the host copy current before the host-only operations take over */
    PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));

    A->ops->scale                     = MatScale_SeqAIJ;
    A->ops->axpy                      = MatAXPY_SeqAIJ;
    A->ops->zeroentries               = MatZeroEntries_SeqAIJ;
    A->ops->mult                      = MatMult_SeqAIJ;
    A->ops->multadd                   = MatMultAdd_SeqAIJ;
    A->ops->multtranspose             = MatMultTranspose_SeqAIJ;
    A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJ;
    A->ops->multhermitiantranspose    = NULL;
    A->ops->multhermitiantransposeadd = NULL;
    A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJ;
    A->ops->getcurrentmemtype         = NULL;
    /* drop the device-aware SeqAIJ array accessors installed when unbound */
    PetscCall(PetscMemzero(a->ops, sizeof(Mat_SeqAIJOps)));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", NULL));
  } else {
    A->ops->scale                     = MatScale_SeqAIJHIPSPARSE;
    A->ops->axpy                      = MatAXPY_SeqAIJHIPSPARSE;
    A->ops->zeroentries               = MatZeroEntries_SeqAIJHIPSPARSE;
    A->ops->mult                      = MatMult_SeqAIJHIPSPARSE;
    A->ops->multadd                   = MatMultAdd_SeqAIJHIPSPARSE;
    A->ops->multtranspose             = MatMultTranspose_SeqAIJHIPSPARSE;
    A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJHIPSPARSE;
    A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJHIPSPARSE;
    A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE;
    A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJHIPSPARSE;
    A->ops->getcurrentmemtype         = MatGetCurrentMemType_SeqAIJHIPSPARSE;
    /* SeqAIJ array-access hooks routed through the HIPSPARSE implementations */
    a->ops->getarray          = MatSeqAIJGetArray_SeqAIJHIPSPARSE;
    a->ops->restorearray      = MatSeqAIJRestoreArray_SeqAIJHIPSPARSE;
    a->ops->getarrayread      = MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE;
    a->ops->restorearrayread  = MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE;
    a->ops->getarraywrite     = MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE;
    a->ops->restorearraywrite = MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE;
    a->ops->getcsrandmemtype  = MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE;
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", MatSeqAIJCopySubArray_SeqAIJHIPSPARSE));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", MatProductSetFromOptions_SeqAIJHIPSPARSE));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", MatProductSetFromOptions_SeqAIJHIPSPARSE));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJHIPSPARSE));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJHIPSPARSE));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", MatProductSetFromOptions_SeqAIJHIPSPARSE));
  }
  A->boundtocpu = flg;
  /* inode-based kernels are host-only; enable them only when bound to the CPU and inode data exists */
  if (flg && a->inode.size_csr) a->inode.use = PETSC_TRUE;
  else a->inode.use = PETSC_FALSE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
3456d52a580bSJunchao Zhang
/*
  Converts a SeqAIJ matrix to MATSEQAIJHIPSPARSE.

  reuse = MAT_INITIAL_MATRIX : *newmat receives a duplicate of A (values copied) that is then converted.
  reuse = MAT_REUSE_MATRIX   : A's values are copied into the existing *newmat (same nonzero pattern).
  otherwise (in-place)       : *newmat itself is converted; callers in this file pass &A.
  The HIP device is initialized first, the default vector type is set to VECHIP, the HIPSPARSE
  context (spptr) is created when missing (except on reuse), the operation table is switched to the
  HIPSPARSE implementations, and the type name and HIPSPARSE-specific composed functions are set.
  mtype is not read.
*/
PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJHIPSPARSE(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
{
  Mat B;

  PetscFunctionBegin;
  PetscCall(PetscDeviceInitialize(PETSC_DEVICE_HIP)); /* first use of HIPSPARSE may be via MatConvert */
  if (reuse == MAT_INITIAL_MATRIX) {
    PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
  } else if (reuse == MAT_REUSE_MATRIX) {
    PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
  }
  B = *newmat;
  PetscCall(PetscFree(B->defaultvectype));
  PetscCall(PetscStrallocpy(VECHIP, &B->defaultvectype));
  if (reuse != MAT_REUSE_MATRIX && !B->spptr) {
    if (B->factortype == MAT_FACTOR_NONE) {
      /* unfactored matrix: device-matrix context with its own hipSPARSE handle on PETSc's default HIP stream */
      Mat_SeqAIJHIPSPARSE *spptr;
      PetscCall(PetscNew(&spptr));
      PetscCallHIPSPARSE(hipsparseCreate(&spptr->handle));
      PetscCallHIPSPARSE(hipsparseSetStream(spptr->handle, PetscDefaultHipStream));
      spptr->format = MAT_HIPSPARSE_CSR;
#if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
      spptr->spmvAlg = HIPSPARSE_SPMV_CSR_ALG1;
#else
      spptr->spmvAlg = HIPSPARSE_CSRMV_ALG1; /* default, since we only support csr */
#endif
      spptr->spmmAlg = HIPSPARSE_SPMM_CSR_ALG1; /* default, only support column-major dense matrix B */
      //spptr->csr2cscAlg = HIPSPARSE_CSR2CSC_ALG1;

      B->spptr = spptr;
    } else {
      /* factored matrix: triangular-factor context, also with a handle on the default HIP stream */
      Mat_SeqAIJHIPSPARSETriFactors *spptr;

      PetscCall(PetscNew(&spptr));
      PetscCallHIPSPARSE(hipsparseCreate(&spptr->handle));
      PetscCallHIPSPARSE(hipsparseSetStream(spptr->handle, PetscDefaultHipStream));
      B->spptr = spptr;
    }
    /* nothing has been placed on the device yet */
    B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
  }
  B->ops->assemblyend       = MatAssemblyEnd_SeqAIJHIPSPARSE;
  B->ops->destroy           = MatDestroy_SeqAIJHIPSPARSE;
  B->ops->setoption         = MatSetOption_SeqAIJHIPSPARSE;
  B->ops->setfromoptions    = MatSetFromOptions_SeqAIJHIPSPARSE;
  B->ops->bindtocpu         = MatBindToCPU_SeqAIJHIPSPARSE;
  B->ops->duplicate         = MatDuplicate_SeqAIJHIPSPARSE;
  B->ops->getcurrentmemtype = MatGetCurrentMemType_SeqAIJHIPSPARSE;

  /* install the GPU implementations of the remaining operations */
  PetscCall(MatBindToCPU_SeqAIJHIPSPARSE(B, PETSC_FALSE));
  PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATSEQAIJHIPSPARSE));
  PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatHIPSPARSESetFormat_C", MatHIPSPARSESetFormat_SeqAIJHIPSPARSE));
#if defined(PETSC_HAVE_HYPRE)
  PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_seqaijhipsparse_hypre_C", MatConvert_AIJ_HYPRE));
#endif
  PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatHIPSPARSESetUseCPUSolve_C", MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE));
  PetscFunctionReturn(PETSC_SUCCESS);
}
3514d52a580bSJunchao Zhang
MatCreate_SeqAIJHIPSPARSE(Mat B)3515d52a580bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJHIPSPARSE(Mat B)
3516d52a580bSJunchao Zhang {
3517d52a580bSJunchao Zhang PetscFunctionBegin;
3518d52a580bSJunchao Zhang PetscCall(MatCreate_SeqAIJ(B));
3519d52a580bSJunchao Zhang PetscCall(MatConvert_SeqAIJ_SeqAIJHIPSPARSE(B, MATSEQAIJHIPSPARSE, MAT_INPLACE_MATRIX, &B));
3520d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3521d52a580bSJunchao Zhang }
3522d52a580bSJunchao Zhang
3523d52a580bSJunchao Zhang /*MC
3524d52a580bSJunchao Zhang MATSEQAIJHIPSPARSE - MATAIJHIPSPARSE = "(seq)aijhipsparse" - A matrix type to be used for sparse matrices on AMD GPUs
3525d52a580bSJunchao Zhang
  A matrix type whose data resides on the GPU. These matrices can be stored in
  CSR, ELL, or hybrid (HYB) format.
  All matrix calculations are performed on AMD (or, through HIP, NVIDIA) GPUs using the HIPSPARSE library.
3529d52a580bSJunchao Zhang
3530d52a580bSJunchao Zhang Options Database Keys:
3531d52a580bSJunchao Zhang + -mat_type aijhipsparse - sets the matrix type to `MATSEQAIJHIPSPARSE`
3532d52a580bSJunchao Zhang . -mat_hipsparse_storage_format csr - sets the storage format of matrices (for `MatMult()` and factors in `MatSolve()`).
3533d52a580bSJunchao Zhang Other options include ell (ellpack) or hyb (hybrid).
3534d52a580bSJunchao Zhang . -mat_hipsparse_mult_storage_format csr - sets the storage format of matrices (for `MatMult()`). Other options include ell (ellpack) or hyb (hybrid).
3535d52a580bSJunchao Zhang - -mat_hipsparse_use_cpu_solve - Do `MatSolve()` on the CPU
3536d52a580bSJunchao Zhang
3537d52a580bSJunchao Zhang Level: beginner
3538d52a580bSJunchao Zhang
3539d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJHIPSPARSE()`, `MATAIJHIPSPARSE`, `MatCreateAIJHIPSPARSE()`, `MatHIPSPARSESetFormat()`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
3540d52a580bSJunchao Zhang M*/
3541d52a580bSJunchao Zhang
MatSolverTypeRegister_HIPSPARSE(void)3542d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_HIPSPARSE(void)
3543d52a580bSJunchao Zhang {
3544d52a580bSJunchao Zhang PetscFunctionBegin;
3545d52a580bSJunchao Zhang PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_LU, MatGetFactor_seqaijhipsparse_hipsparse));
3546d52a580bSJunchao Zhang PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_CHOLESKY, MatGetFactor_seqaijhipsparse_hipsparse));
3547d52a580bSJunchao Zhang PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_ILU, MatGetFactor_seqaijhipsparse_hipsparse));
3548d52a580bSJunchao Zhang PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_ICC, MatGetFactor_seqaijhipsparse_hipsparse));
3549d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3550d52a580bSJunchao Zhang }
3551d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSE_Destroy(Mat mat)3552d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSE_Destroy(Mat mat)
3553d52a580bSJunchao Zhang {
3554d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = static_cast<Mat_SeqAIJHIPSPARSE *>(mat->spptr);
3555d52a580bSJunchao Zhang
3556d52a580bSJunchao Zhang PetscFunctionBegin;
3557d52a580bSJunchao Zhang if (cusp) {
3558d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->mat, cusp->format));
3559d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->matTranspose, cusp->format));
3560d52a580bSJunchao Zhang delete cusp->workVector;
3561d52a580bSJunchao Zhang delete cusp->rowoffsets_gpu;
3562d52a580bSJunchao Zhang delete cusp->csr2csc_i;
3563d52a580bSJunchao Zhang delete cusp->coords;
3564d52a580bSJunchao Zhang if (cusp->handle) PetscCallHIPSPARSE(hipsparseDestroy(cusp->handle));
3565d52a580bSJunchao Zhang PetscCall(PetscFree(mat->spptr));
3566d52a580bSJunchao Zhang }
3567d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3568d52a580bSJunchao Zhang }
3569d52a580bSJunchao Zhang
CsrMatrix_Destroy(CsrMatrix ** mat)3570d52a580bSJunchao Zhang static PetscErrorCode CsrMatrix_Destroy(CsrMatrix **mat)
3571d52a580bSJunchao Zhang {
3572d52a580bSJunchao Zhang PetscFunctionBegin;
3573d52a580bSJunchao Zhang if (*mat) {
3574d52a580bSJunchao Zhang delete (*mat)->values;
3575d52a580bSJunchao Zhang delete (*mat)->column_indices;
3576d52a580bSJunchao Zhang delete (*mat)->row_offsets;
3577d52a580bSJunchao Zhang delete *mat;
3578d52a580bSJunchao Zhang *mat = 0;
3579d52a580bSJunchao Zhang }
3580d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3581d52a580bSJunchao Zhang }
3582d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct ** trifactor)3583d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct **trifactor)
3584d52a580bSJunchao Zhang {
3585d52a580bSJunchao Zhang PetscFunctionBegin;
3586d52a580bSJunchao Zhang if (*trifactor) {
3587d52a580bSJunchao Zhang if ((*trifactor)->descr) PetscCallHIPSPARSE(hipsparseDestroyMatDescr((*trifactor)->descr));
3588d52a580bSJunchao Zhang if ((*trifactor)->solveInfo) PetscCallHIPSPARSE(hipsparseDestroyCsrsvInfo((*trifactor)->solveInfo));
3589d52a580bSJunchao Zhang PetscCall(CsrMatrix_Destroy(&(*trifactor)->csrMat));
3590d52a580bSJunchao Zhang if ((*trifactor)->solveBuffer) PetscCallHIP(hipFree((*trifactor)->solveBuffer));
3591d52a580bSJunchao Zhang if ((*trifactor)->AA_h) PetscCallHIP(hipHostFree((*trifactor)->AA_h));
3592d52a580bSJunchao Zhang if ((*trifactor)->csr2cscBuffer) PetscCallHIP(hipFree((*trifactor)->csr2cscBuffer));
3593d52a580bSJunchao Zhang PetscCall(PetscFree(*trifactor));
3594d52a580bSJunchao Zhang }
3595d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3596d52a580bSJunchao Zhang }
3597d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct ** matstruct,MatHIPSPARSEStorageFormat format)3598d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct **matstruct, MatHIPSPARSEStorageFormat format)
3599d52a580bSJunchao Zhang {
3600d52a580bSJunchao Zhang CsrMatrix *mat;
3601d52a580bSJunchao Zhang
3602d52a580bSJunchao Zhang PetscFunctionBegin;
3603d52a580bSJunchao Zhang if (*matstruct) {
3604d52a580bSJunchao Zhang if ((*matstruct)->mat) {
3605d52a580bSJunchao Zhang if (format == MAT_HIPSPARSE_ELL || format == MAT_HIPSPARSE_HYB) {
3606d52a580bSJunchao Zhang hipsparseHybMat_t hybMat = (hipsparseHybMat_t)(*matstruct)->mat;
3607d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyHybMat(hybMat));
3608d52a580bSJunchao Zhang } else {
3609d52a580bSJunchao Zhang mat = (CsrMatrix *)(*matstruct)->mat;
3610d52a580bSJunchao Zhang PetscCall(CsrMatrix_Destroy(&mat));
3611d52a580bSJunchao Zhang }
3612d52a580bSJunchao Zhang }
3613d52a580bSJunchao Zhang if ((*matstruct)->descr) PetscCallHIPSPARSE(hipsparseDestroyMatDescr((*matstruct)->descr));
3614d52a580bSJunchao Zhang delete (*matstruct)->cprowIndices;
3615d52a580bSJunchao Zhang if ((*matstruct)->alpha_one) PetscCallHIP(hipFree((*matstruct)->alpha_one));
3616d52a580bSJunchao Zhang if ((*matstruct)->beta_zero) PetscCallHIP(hipFree((*matstruct)->beta_zero));
3617d52a580bSJunchao Zhang if ((*matstruct)->beta_one) PetscCallHIP(hipFree((*matstruct)->beta_one));
3618d52a580bSJunchao Zhang
3619d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *mdata = *matstruct;
3620d52a580bSJunchao Zhang if (mdata->matDescr) PetscCallHIPSPARSE(hipsparseDestroySpMat(mdata->matDescr));
3621d52a580bSJunchao Zhang for (int i = 0; i < 3; i++) {
3622d52a580bSJunchao Zhang if (mdata->hipSpMV[i].initialized) {
3623d52a580bSJunchao Zhang PetscCallHIP(hipFree(mdata->hipSpMV[i].spmvBuffer));
3624d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyDnVec(mdata->hipSpMV[i].vecXDescr));
3625d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyDnVec(mdata->hipSpMV[i].vecYDescr));
3626d52a580bSJunchao Zhang }
3627d52a580bSJunchao Zhang }
3628d52a580bSJunchao Zhang delete *matstruct;
3629d52a580bSJunchao Zhang *matstruct = NULL;
3630d52a580bSJunchao Zhang }
3631d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3632d52a580bSJunchao Zhang }
3633d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSETriFactors_Reset(Mat_SeqAIJHIPSPARSETriFactors_p * trifactors)3634d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Reset(Mat_SeqAIJHIPSPARSETriFactors_p *trifactors)
3635d52a580bSJunchao Zhang {
3636d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = *trifactors;
3637d52a580bSJunchao Zhang
3638d52a580bSJunchao Zhang PetscFunctionBegin;
3639d52a580bSJunchao Zhang if (fs) {
3640d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->loTriFactorPtr));
3641d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->upTriFactorPtr));
3642d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->loTriFactorPtrTranspose));
3643d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->upTriFactorPtrTranspose));
3644d52a580bSJunchao Zhang delete fs->rpermIndices;
3645d52a580bSJunchao Zhang delete fs->cpermIndices;
3646d52a580bSJunchao Zhang delete fs->workVector;
3647d52a580bSJunchao Zhang fs->rpermIndices = NULL;
3648d52a580bSJunchao Zhang fs->cpermIndices = NULL;
3649d52a580bSJunchao Zhang fs->workVector = NULL;
3650d52a580bSJunchao Zhang fs->init_dev_prop = PETSC_FALSE;
3651d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
3652d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->csrRowPtr));
3653d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->csrColIdx));
3654d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->csrVal));
3655d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->X));
3656d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->Y));
3657d52a580bSJunchao Zhang // PetscCallHIP(hipFree(fs->factBuffer_M)); /* No needed since factBuffer_M shares with one of spsvBuffer_L/U */
3658d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->spsvBuffer_L));
3659d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->spsvBuffer_U));
3660d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->spsvBuffer_Lt));
3661d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->spsvBuffer_Ut));
3662d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyMatDescr(fs->matDescr_M));
3663d52a580bSJunchao Zhang if (fs->spMatDescr_L) PetscCallHIPSPARSE(hipsparseDestroySpMat(fs->spMatDescr_L));
3664d52a580bSJunchao Zhang if (fs->spMatDescr_U) PetscCallHIPSPARSE(hipsparseDestroySpMat(fs->spMatDescr_U));
3665d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_L));
3666d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_Lt));
3667d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_U));
3668d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_Ut));
3669d52a580bSJunchao Zhang if (fs->dnVecDescr_X) PetscCallHIPSPARSE(hipsparseDestroyDnVec(fs->dnVecDescr_X));
3670d52a580bSJunchao Zhang if (fs->dnVecDescr_Y) PetscCallHIPSPARSE(hipsparseDestroyDnVec(fs->dnVecDescr_Y));
3671d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyCsrilu02Info(fs->ilu0Info_M));
3672d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyCsric02Info(fs->ic0Info_M));
3673d52a580bSJunchao Zhang
3674d52a580bSJunchao Zhang fs->createdTransposeSpSVDescr = PETSC_FALSE;
3675d52a580bSJunchao Zhang fs->updatedTransposeSpSVAnalysis = PETSC_FALSE;
3676d52a580bSJunchao Zhang #endif
3677d52a580bSJunchao Zhang }
3678d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3679d52a580bSJunchao Zhang }
3680d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors ** trifactors)3681d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors **trifactors)
3682d52a580bSJunchao Zhang {
3683d52a580bSJunchao Zhang hipsparseHandle_t handle;
3684d52a580bSJunchao Zhang
3685d52a580bSJunchao Zhang PetscFunctionBegin;
3686d52a580bSJunchao Zhang if (*trifactors) {
3687d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(trifactors));
3688d52a580bSJunchao Zhang if ((handle = (*trifactors)->handle)) PetscCallHIPSPARSE(hipsparseDestroy(handle));
3689d52a580bSJunchao Zhang PetscCall(PetscFree(*trifactors));
3690d52a580bSJunchao Zhang }
3691d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3692d52a580bSJunchao Zhang }
3693d52a580bSJunchao Zhang
3694d52a580bSJunchao Zhang struct IJCompare {
operator ()IJCompare3695d52a580bSJunchao Zhang __host__ __device__ inline bool operator()(const thrust::tuple<PetscInt, PetscInt> &t1, const thrust::tuple<PetscInt, PetscInt> &t2)
3696d52a580bSJunchao Zhang {
3697d52a580bSJunchao Zhang if (t1.get<0>() < t2.get<0>()) return true;
3698d52a580bSJunchao Zhang if (t1.get<0>() == t2.get<0>()) return t1.get<1>() < t2.get<1>();
3699d52a580bSJunchao Zhang return false;
3700d52a580bSJunchao Zhang }
3701d52a580bSJunchao Zhang };
3702d52a580bSJunchao Zhang
MatSeqAIJHIPSPARSEInvalidateTranspose(Mat A,PetscBool destroy)3703d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEInvalidateTranspose(Mat A, PetscBool destroy)
3704d52a580bSJunchao Zhang {
3705d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3706d52a580bSJunchao Zhang
3707d52a580bSJunchao Zhang PetscFunctionBegin;
3708d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3709d52a580bSJunchao Zhang if (!cusp) PetscFunctionReturn(PETSC_SUCCESS);
3710d52a580bSJunchao Zhang if (destroy) {
3711d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->matTranspose, cusp->format));
3712d52a580bSJunchao Zhang delete cusp->csr2csc_i;
3713d52a580bSJunchao Zhang cusp->csr2csc_i = NULL;
3714d52a580bSJunchao Zhang }
3715d52a580bSJunchao Zhang A->transupdated = PETSC_FALSE;
3716d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3717d52a580bSJunchao Zhang }
3718d52a580bSJunchao Zhang
MatCOOStructDestroy_SeqAIJHIPSPARSE(PetscCtxRt data)3719d52a580bSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_SeqAIJHIPSPARSE(PetscCtxRt data)
3720d52a580bSJunchao Zhang {
3721d52a580bSJunchao Zhang MatCOOStruct_SeqAIJ *coo = *(MatCOOStruct_SeqAIJ **)data;
3722d52a580bSJunchao Zhang
3723d52a580bSJunchao Zhang PetscFunctionBegin;
3724d52a580bSJunchao Zhang PetscCallHIP(hipFree(coo->perm));
3725d52a580bSJunchao Zhang PetscCallHIP(hipFree(coo->jmap));
3726d52a580bSJunchao Zhang PetscCall(PetscFree(coo));
3727d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3728d52a580bSJunchao Zhang }
3729d52a580bSJunchao Zhang
MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat mat,PetscCount coo_n,PetscInt coo_i[],PetscInt coo_j[])3730d52a580bSJunchao Zhang static PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
3731d52a580bSJunchao Zhang {
3732d52a580bSJunchao Zhang PetscBool dev_ij = PETSC_FALSE;
3733d52a580bSJunchao Zhang PetscMemType mtype = PETSC_MEMTYPE_HOST;
3734d52a580bSJunchao Zhang PetscInt *i, *j;
3735d52a580bSJunchao Zhang PetscContainer container_h;
3736d52a580bSJunchao Zhang MatCOOStruct_SeqAIJ *coo_h, *coo_d;
3737d52a580bSJunchao Zhang
3738d52a580bSJunchao Zhang PetscFunctionBegin;
3739d52a580bSJunchao Zhang PetscCall(PetscGetMemType(coo_i, &mtype));
3740d52a580bSJunchao Zhang if (PetscMemTypeDevice(mtype)) {
3741d52a580bSJunchao Zhang dev_ij = PETSC_TRUE;
3742d52a580bSJunchao Zhang PetscCall(PetscMalloc2(coo_n, &i, coo_n, &j));
3743d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(i, coo_i, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost));
3744d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(j, coo_j, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost));
3745d52a580bSJunchao Zhang } else {
3746d52a580bSJunchao Zhang i = coo_i;
3747d52a580bSJunchao Zhang j = coo_j;
3748d52a580bSJunchao Zhang }
3749d52a580bSJunchao Zhang PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, i, j));
3750d52a580bSJunchao Zhang if (dev_ij) PetscCall(PetscFree2(i, j));
3751d52a580bSJunchao Zhang mat->offloadmask = PETSC_OFFLOAD_CPU;
3752d52a580bSJunchao Zhang // Create the GPU memory
3753d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(mat));
3754d52a580bSJunchao Zhang
3755d52a580bSJunchao Zhang // Copy the COO struct to device
3756d52a580bSJunchao Zhang PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
3757d52a580bSJunchao Zhang PetscCall(PetscContainerGetPointer(container_h, &coo_h));
3758d52a580bSJunchao Zhang PetscCall(PetscMalloc1(1, &coo_d));
3759d52a580bSJunchao Zhang *coo_d = *coo_h; // do a shallow copy and then amend some fields that need to be different
3760d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&coo_d->jmap, (coo_h->nz + 1) * sizeof(PetscCount)));
3761d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(coo_d->jmap, coo_h->jmap, (coo_h->nz + 1) * sizeof(PetscCount), hipMemcpyHostToDevice));
3762d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&coo_d->perm, coo_h->Atot * sizeof(PetscCount)));
3763d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(coo_d->perm, coo_h->perm, coo_h->Atot * sizeof(PetscCount), hipMemcpyHostToDevice));
3764d52a580bSJunchao Zhang
3765d52a580bSJunchao Zhang // Put the COO struct in a container and then attach that to the matrix
3766d52a580bSJunchao Zhang PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJHIPSPARSE));
3767d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3768d52a580bSJunchao Zhang }
3769d52a580bSJunchao Zhang
MatAddCOOValues(const PetscScalar kv[],PetscCount nnz,const PetscCount jmap[],const PetscCount perm[],InsertMode imode,PetscScalar a[])3770d52a580bSJunchao Zhang __global__ static void MatAddCOOValues(const PetscScalar kv[], PetscCount nnz, const PetscCount jmap[], const PetscCount perm[], InsertMode imode, PetscScalar a[])
3771d52a580bSJunchao Zhang {
3772d52a580bSJunchao Zhang PetscCount i = blockIdx.x * blockDim.x + threadIdx.x;
3773d52a580bSJunchao Zhang const PetscCount grid_size = gridDim.x * blockDim.x;
3774d52a580bSJunchao Zhang for (; i < nnz; i += grid_size) {
3775d52a580bSJunchao Zhang PetscScalar sum = 0.0;
3776d52a580bSJunchao Zhang for (PetscCount k = jmap[i]; k < jmap[i + 1]; k++) sum += kv[perm[k]];
3777d52a580bSJunchao Zhang a[i] = (imode == INSERT_VALUES ? 0.0 : a[i]) + sum;
3778d52a580bSJunchao Zhang }
3779d52a580bSJunchao Zhang }
3780d52a580bSJunchao Zhang
MatSetValuesCOO_SeqAIJHIPSPARSE(Mat A,const PetscScalar v[],InsertMode imode)3781d52a580bSJunchao Zhang static PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE(Mat A, const PetscScalar v[], InsertMode imode)
3782d52a580bSJunchao Zhang {
3783d52a580bSJunchao Zhang Mat_SeqAIJ *seq = (Mat_SeqAIJ *)A->data;
3784d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *dev = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3785d52a580bSJunchao Zhang PetscCount Annz = seq->nz;
3786d52a580bSJunchao Zhang PetscMemType memtype;
3787d52a580bSJunchao Zhang const PetscScalar *v1 = v;
3788d52a580bSJunchao Zhang PetscScalar *Aa;
3789d52a580bSJunchao Zhang PetscContainer container;
3790d52a580bSJunchao Zhang MatCOOStruct_SeqAIJ *coo;
3791d52a580bSJunchao Zhang
3792d52a580bSJunchao Zhang PetscFunctionBegin;
3793d52a580bSJunchao Zhang if (!dev->mat) PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3794d52a580bSJunchao Zhang
3795d52a580bSJunchao Zhang PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
3796d52a580bSJunchao Zhang PetscCall(PetscContainerGetPointer(container, &coo));
3797d52a580bSJunchao Zhang
3798d52a580bSJunchao Zhang PetscCall(PetscGetMemType(v, &memtype));
3799d52a580bSJunchao Zhang if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
3800d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&v1, coo->n * sizeof(PetscScalar)));
3801d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy((void *)v1, v, coo->n * sizeof(PetscScalar), hipMemcpyHostToDevice));
3802d52a580bSJunchao Zhang }
3803d52a580bSJunchao Zhang
3804d52a580bSJunchao Zhang if (imode == INSERT_VALUES) PetscCall(MatSeqAIJHIPSPARSEGetArrayWrite(A, &Aa));
3805d52a580bSJunchao Zhang else PetscCall(MatSeqAIJHIPSPARSEGetArray(A, &Aa));
3806d52a580bSJunchao Zhang
3807d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
3808d52a580bSJunchao Zhang if (Annz) {
3809d52a580bSJunchao Zhang hipLaunchKernelGGL(HIP_KERNEL_NAME(MatAddCOOValues), dim3((Annz + 255) / 256), dim3(256), 0, PetscDefaultHipStream, v1, Annz, coo->jmap, coo->perm, imode, Aa);
3810d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError());
3811d52a580bSJunchao Zhang }
3812d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
3813d52a580bSJunchao Zhang
3814d52a580bSJunchao Zhang if (imode == INSERT_VALUES) PetscCall(MatSeqAIJHIPSPARSERestoreArrayWrite(A, &Aa));
3815d52a580bSJunchao Zhang else PetscCall(MatSeqAIJHIPSPARSERestoreArray(A, &Aa));
3816d52a580bSJunchao Zhang
3817d52a580bSJunchao Zhang if (PetscMemTypeHost(memtype)) PetscCallHIP(hipFree((void *)v1));
3818d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3819d52a580bSJunchao Zhang }
3820d52a580bSJunchao Zhang
3821d52a580bSJunchao Zhang /*@C
3822d52a580bSJunchao Zhang MatSeqAIJHIPSPARSEGetIJ - returns the device row storage `i` and `j` indices for `MATSEQAIJHIPSPARSE` matrices.
3823d52a580bSJunchao Zhang
3824d52a580bSJunchao Zhang Not Collective
3825d52a580bSJunchao Zhang
3826d52a580bSJunchao Zhang Input Parameters:
3827d52a580bSJunchao Zhang + A - the matrix
3828d52a580bSJunchao Zhang - compressed - `PETSC_TRUE` or `PETSC_FALSE` indicating the matrix data structure should be always returned in compressed form
3829d52a580bSJunchao Zhang
3830d52a580bSJunchao Zhang Output Parameters:
3831d52a580bSJunchao Zhang + i - the CSR row pointers
3832d52a580bSJunchao Zhang - j - the CSR column indices
3833d52a580bSJunchao Zhang
3834d52a580bSJunchao Zhang Level: developer
3835d52a580bSJunchao Zhang
3836d52a580bSJunchao Zhang Note:
3837d52a580bSJunchao Zhang When compressed is true, the CSR structure does not contain empty rows
3838d52a580bSJunchao Zhang
3839d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSERestoreIJ()`, `MatSeqAIJHIPSPARSEGetArrayRead()`
3840d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSEGetIJ(Mat A,PetscBool compressed,const int * i[],const int * j[])3841d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetIJ(Mat A, PetscBool compressed, const int *i[], const int *j[])
3842d52a580bSJunchao Zhang {
3843d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3844d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
3845d52a580bSJunchao Zhang CsrMatrix *csr;
3846d52a580bSJunchao Zhang
3847d52a580bSJunchao Zhang PetscFunctionBegin;
3848d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3849d52a580bSJunchao Zhang if (!i || !j) PetscFunctionReturn(PETSC_SUCCESS);
3850d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3851d52a580bSJunchao Zhang PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
3852d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3853d52a580bSJunchao Zhang PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
3854d52a580bSJunchao Zhang csr = (CsrMatrix *)cusp->mat->mat;
3855d52a580bSJunchao Zhang if (i) {
3856d52a580bSJunchao Zhang if (!compressed && a->compressedrow.use) { /* need full row offset */
3857d52a580bSJunchao Zhang if (!cusp->rowoffsets_gpu) {
3858d52a580bSJunchao Zhang cusp->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
3859d52a580bSJunchao Zhang cusp->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
3860d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
3861d52a580bSJunchao Zhang }
3862d52a580bSJunchao Zhang *i = cusp->rowoffsets_gpu->data().get();
3863d52a580bSJunchao Zhang } else *i = csr->row_offsets->data().get();
3864d52a580bSJunchao Zhang }
3865d52a580bSJunchao Zhang if (j) *j = csr->column_indices->data().get();
3866d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3867d52a580bSJunchao Zhang }
3868d52a580bSJunchao Zhang
3869d52a580bSJunchao Zhang /*@C
3870d52a580bSJunchao Zhang MatSeqAIJHIPSPARSERestoreIJ - restore the device row storage `i` and `j` indices obtained with `MatSeqAIJHIPSPARSEGetIJ()`
3871d52a580bSJunchao Zhang
3872d52a580bSJunchao Zhang Not Collective
3873d52a580bSJunchao Zhang
3874d52a580bSJunchao Zhang Input Parameters:
3875d52a580bSJunchao Zhang + A - the matrix
3876d52a580bSJunchao Zhang . compressed - `PETSC_TRUE` or `PETSC_FALSE` indicating the matrix data structure should be always returned in compressed form
3877d52a580bSJunchao Zhang . i - the CSR row pointers
3878d52a580bSJunchao Zhang - j - the CSR column indices
3879d52a580bSJunchao Zhang
3880d52a580bSJunchao Zhang Level: developer
3881d52a580bSJunchao Zhang
3882d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetIJ()`
3883d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSERestoreIJ(Mat A,PetscBool compressed,const int * i[],const int * j[])3884d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreIJ(Mat A, PetscBool compressed, const int *i[], const int *j[])
3885d52a580bSJunchao Zhang {
3886d52a580bSJunchao Zhang PetscFunctionBegin;
3887d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3888d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3889d52a580bSJunchao Zhang if (i) *i = NULL;
3890d52a580bSJunchao Zhang if (j) *j = NULL;
3891d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3892d52a580bSJunchao Zhang }
3893d52a580bSJunchao Zhang
3894d52a580bSJunchao Zhang /*@C
3895d52a580bSJunchao Zhang MatSeqAIJHIPSPARSEGetArrayRead - gives read-only access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored
3896d52a580bSJunchao Zhang
3897d52a580bSJunchao Zhang Not Collective
3898d52a580bSJunchao Zhang
3899d52a580bSJunchao Zhang Input Parameter:
3900d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix
3901d52a580bSJunchao Zhang
3902d52a580bSJunchao Zhang Output Parameter:
3903d52a580bSJunchao Zhang . a - pointer to the device data
3904d52a580bSJunchao Zhang
3905d52a580bSJunchao Zhang Level: developer
3906d52a580bSJunchao Zhang
3907d52a580bSJunchao Zhang Note:
3908d52a580bSJunchao Zhang May trigger host-device copies if the up-to-date matrix data is on host
3909d52a580bSJunchao Zhang
3910d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`, `MatSeqAIJHIPSPARSEGetArrayWrite()`, `MatSeqAIJHIPSPARSERestoreArrayRead()`
3911d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSEGetArrayRead(Mat A,const PetscScalar * a[])3912d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArrayRead(Mat A, const PetscScalar *a[])
3913d52a580bSJunchao Zhang {
3914d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3915d52a580bSJunchao Zhang CsrMatrix *csr;
3916d52a580bSJunchao Zhang
3917d52a580bSJunchao Zhang PetscFunctionBegin;
3918d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3919d52a580bSJunchao Zhang PetscAssertPointer(a, 2);
3920d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3921d52a580bSJunchao Zhang PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
3922d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3923d52a580bSJunchao Zhang PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
3924d52a580bSJunchao Zhang csr = (CsrMatrix *)cusp->mat->mat;
3925d52a580bSJunchao Zhang PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory");
3926d52a580bSJunchao Zhang *a = csr->values->data().get();
3927d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3928d52a580bSJunchao Zhang }
3929d52a580bSJunchao Zhang
3930d52a580bSJunchao Zhang /*@C
3931d52a580bSJunchao Zhang MatSeqAIJHIPSPARSERestoreArrayRead - restore the read-only access array obtained from `MatSeqAIJHIPSPARSEGetArrayRead()`
3932d52a580bSJunchao Zhang
3933d52a580bSJunchao Zhang Not Collective
3934d52a580bSJunchao Zhang
3935d52a580bSJunchao Zhang Input Parameters:
3936d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix
3937d52a580bSJunchao Zhang - a - pointer to the device data
3938d52a580bSJunchao Zhang
3939d52a580bSJunchao Zhang Level: developer
3940d52a580bSJunchao Zhang
3941d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayRead()`
3942d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSERestoreArrayRead(Mat A,const PetscScalar * a[])3943d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArrayRead(Mat A, const PetscScalar *a[])
3944d52a580bSJunchao Zhang {
3945d52a580bSJunchao Zhang PetscFunctionBegin;
3946d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3947d52a580bSJunchao Zhang PetscAssertPointer(a, 2);
3948d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3949d52a580bSJunchao Zhang *a = NULL;
3950d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3951d52a580bSJunchao Zhang }
3952d52a580bSJunchao Zhang
3953d52a580bSJunchao Zhang /*@C
3954d52a580bSJunchao Zhang MatSeqAIJHIPSPARSEGetArray - gives read-write access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored
3955d52a580bSJunchao Zhang
3956d52a580bSJunchao Zhang Not Collective
3957d52a580bSJunchao Zhang
3958d52a580bSJunchao Zhang Input Parameter:
3959d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix
3960d52a580bSJunchao Zhang
3961d52a580bSJunchao Zhang Output Parameter:
3962d52a580bSJunchao Zhang . a - pointer to the device data
3963d52a580bSJunchao Zhang
3964d52a580bSJunchao Zhang Level: developer
3965d52a580bSJunchao Zhang
3966d52a580bSJunchao Zhang Note:
3967d52a580bSJunchao Zhang May trigger host-device copies if up-to-date matrix data is on host
3968d52a580bSJunchao Zhang
3969d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayRead()`, `MatSeqAIJHIPSPARSEGetArrayWrite()`, `MatSeqAIJHIPSPARSERestoreArray()`
3970d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSEGetArray(Mat A,PetscScalar * a[])3971d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArray(Mat A, PetscScalar *a[])
3972d52a580bSJunchao Zhang {
3973d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3974d52a580bSJunchao Zhang CsrMatrix *csr;
3975d52a580bSJunchao Zhang
3976d52a580bSJunchao Zhang PetscFunctionBegin;
3977d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3978d52a580bSJunchao Zhang PetscAssertPointer(a, 2);
3979d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3980d52a580bSJunchao Zhang PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
3981d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3982d52a580bSJunchao Zhang PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
3983d52a580bSJunchao Zhang csr = (CsrMatrix *)cusp->mat->mat;
3984d52a580bSJunchao Zhang PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory");
3985d52a580bSJunchao Zhang *a = csr->values->data().get();
3986d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_GPU;
3987d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE));
3988d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
3989d52a580bSJunchao Zhang }
3990d52a580bSJunchao Zhang /*@C
3991d52a580bSJunchao Zhang MatSeqAIJHIPSPARSERestoreArray - restore the read-write access array obtained from `MatSeqAIJHIPSPARSEGetArray()`
3992d52a580bSJunchao Zhang
3993d52a580bSJunchao Zhang Not Collective
3994d52a580bSJunchao Zhang
3995d52a580bSJunchao Zhang Input Parameters:
3996d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix
3997d52a580bSJunchao Zhang - a - pointer to the device data
3998d52a580bSJunchao Zhang
3999d52a580bSJunchao Zhang Level: developer
4000d52a580bSJunchao Zhang
4001d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`
4002d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSERestoreArray(Mat A,PetscScalar * a[])4003d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArray(Mat A, PetscScalar *a[])
4004d52a580bSJunchao Zhang {
4005d52a580bSJunchao Zhang PetscFunctionBegin;
4006d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4007d52a580bSJunchao Zhang PetscAssertPointer(a, 2);
4008d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4009d52a580bSJunchao Zhang PetscCall(PetscObjectStateIncrease((PetscObject)A));
4010d52a580bSJunchao Zhang *a = NULL;
4011d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
4012d52a580bSJunchao Zhang }
4013d52a580bSJunchao Zhang
4014d52a580bSJunchao Zhang /*@C
4015d52a580bSJunchao Zhang MatSeqAIJHIPSPARSEGetArrayWrite - gives write access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored
4016d52a580bSJunchao Zhang
4017d52a580bSJunchao Zhang Not Collective
4018d52a580bSJunchao Zhang
4019d52a580bSJunchao Zhang Input Parameter:
4020d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix
4021d52a580bSJunchao Zhang
4022d52a580bSJunchao Zhang Output Parameter:
4023d52a580bSJunchao Zhang . a - pointer to the device data
4024d52a580bSJunchao Zhang
4025d52a580bSJunchao Zhang Level: developer
4026d52a580bSJunchao Zhang
4027d52a580bSJunchao Zhang Note:
4028d52a580bSJunchao Zhang Does not trigger host-device copies and flags data validity on the GPU
4029d52a580bSJunchao Zhang
4030d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`, `MatSeqAIJHIPSPARSEGetArrayRead()`, `MatSeqAIJHIPSPARSERestoreArrayWrite()`
4031d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSEGetArrayWrite(Mat A,PetscScalar * a[])4032d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArrayWrite(Mat A, PetscScalar *a[])
4033d52a580bSJunchao Zhang {
4034d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
4035d52a580bSJunchao Zhang CsrMatrix *csr;
4036d52a580bSJunchao Zhang
4037d52a580bSJunchao Zhang PetscFunctionBegin;
4038d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4039d52a580bSJunchao Zhang PetscAssertPointer(a, 2);
4040d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4041d52a580bSJunchao Zhang PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4042d52a580bSJunchao Zhang PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4043d52a580bSJunchao Zhang csr = (CsrMatrix *)cusp->mat->mat;
4044d52a580bSJunchao Zhang PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory");
4045d52a580bSJunchao Zhang *a = csr->values->data().get();
4046d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_GPU;
4047d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE));
4048d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
4049d52a580bSJunchao Zhang }
4050d52a580bSJunchao Zhang
4051d52a580bSJunchao Zhang /*@C
4052d52a580bSJunchao Zhang MatSeqAIJHIPSPARSERestoreArrayWrite - restore the write-only access array obtained from `MatSeqAIJHIPSPARSEGetArrayWrite()`
4053d52a580bSJunchao Zhang
4054d52a580bSJunchao Zhang Not Collective
4055d52a580bSJunchao Zhang
4056d52a580bSJunchao Zhang Input Parameters:
4057d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix
4058d52a580bSJunchao Zhang - a - pointer to the device data
4059d52a580bSJunchao Zhang
4060d52a580bSJunchao Zhang Level: developer
4061d52a580bSJunchao Zhang
4062d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayWrite()`
4063d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSERestoreArrayWrite(Mat A,PetscScalar * a[])4064d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArrayWrite(Mat A, PetscScalar *a[])
4065d52a580bSJunchao Zhang {
4066d52a580bSJunchao Zhang PetscFunctionBegin;
4067d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4068d52a580bSJunchao Zhang PetscAssertPointer(a, 2);
4069d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4070d52a580bSJunchao Zhang PetscCall(PetscObjectStateIncrease((PetscObject)A));
4071d52a580bSJunchao Zhang *a = NULL;
4072d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
4073d52a580bSJunchao Zhang }
4074d52a580bSJunchao Zhang
4075d52a580bSJunchao Zhang struct IJCompare4 {
operator ()IJCompare44076d52a580bSJunchao Zhang __host__ __device__ inline bool operator()(const thrust::tuple<int, int, PetscScalar, int> &t1, const thrust::tuple<int, int, PetscScalar, int> &t2)
4077d52a580bSJunchao Zhang {
4078d52a580bSJunchao Zhang if (t1.get<0>() < t2.get<0>()) return true;
4079d52a580bSJunchao Zhang if (t1.get<0>() == t2.get<0>()) return t1.get<1>() < t2.get<1>();
4080d52a580bSJunchao Zhang return false;
4081d52a580bSJunchao Zhang }
4082d52a580bSJunchao Zhang };
4083d52a580bSJunchao Zhang
4084d52a580bSJunchao Zhang struct Shift {
4085d52a580bSJunchao Zhang int _shift;
4086d52a580bSJunchao Zhang
ShiftShift4087d52a580bSJunchao Zhang Shift(int shift) : _shift(shift) { }
operator ()Shift4088d52a580bSJunchao Zhang __host__ __device__ inline int operator()(const int &c) { return c + _shift; }
4089d52a580bSJunchao Zhang };
4090d52a580bSJunchao Zhang
4091d52a580bSJunchao Zhang /* merges two SeqAIJHIPSPARSE matrices A, B by concatenating their rows. [A';B']' operation in MATLAB notation */
MatSeqAIJHIPSPARSEMergeMats(Mat A,Mat B,MatReuse reuse,Mat * C)4092d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
4093d52a580bSJunchao Zhang {
4094d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data, *b = (Mat_SeqAIJ *)B->data, *c;
4095d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr, *Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr, *Ccusp;
4096d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *Cmat;
4097d52a580bSJunchao Zhang CsrMatrix *Acsr, *Bcsr, *Ccsr;
4098d52a580bSJunchao Zhang PetscInt Annz, Bnnz;
4099d52a580bSJunchao Zhang PetscInt i, m, n, zero = 0;
4100d52a580bSJunchao Zhang
4101d52a580bSJunchao Zhang PetscFunctionBegin;
4102d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4103d52a580bSJunchao Zhang PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
4104d52a580bSJunchao Zhang PetscAssertPointer(C, 4);
4105d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4106d52a580bSJunchao Zhang PetscCheckTypeName(B, MATSEQAIJHIPSPARSE);
4107d52a580bSJunchao Zhang PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
4108d52a580bSJunchao Zhang PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
4109d52a580bSJunchao Zhang PetscCheck(Acusp->format != MAT_HIPSPARSE_ELL && Acusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4110d52a580bSJunchao Zhang PetscCheck(Bcusp->format != MAT_HIPSPARSE_ELL && Bcusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4111d52a580bSJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) {
4112d52a580bSJunchao Zhang m = A->rmap->n;
4113d52a580bSJunchao Zhang n = A->cmap->n + B->cmap->n;
4114d52a580bSJunchao Zhang PetscCall(MatCreate(PETSC_COMM_SELF, C));
4115d52a580bSJunchao Zhang PetscCall(MatSetSizes(*C, m, n, m, n));
4116d52a580bSJunchao Zhang PetscCall(MatSetType(*C, MATSEQAIJHIPSPARSE));
4117d52a580bSJunchao Zhang c = (Mat_SeqAIJ *)(*C)->data;
4118d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)(*C)->spptr;
4119d52a580bSJunchao Zhang Cmat = new Mat_SeqAIJHIPSPARSEMultStruct;
4120d52a580bSJunchao Zhang Ccsr = new CsrMatrix;
4121d52a580bSJunchao Zhang Cmat->cprowIndices = NULL;
4122d52a580bSJunchao Zhang c->compressedrow.use = PETSC_FALSE;
4123d52a580bSJunchao Zhang c->compressedrow.nrows = 0;
4124d52a580bSJunchao Zhang c->compressedrow.i = NULL;
4125d52a580bSJunchao Zhang c->compressedrow.rindex = NULL;
4126d52a580bSJunchao Zhang Ccusp->workVector = NULL;
4127d52a580bSJunchao Zhang Ccusp->nrows = m;
4128d52a580bSJunchao Zhang Ccusp->mat = Cmat;
4129d52a580bSJunchao Zhang Ccusp->mat->mat = Ccsr;
4130d52a580bSJunchao Zhang Ccsr->num_rows = m;
4131d52a580bSJunchao Zhang Ccsr->num_cols = n;
4132d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&Cmat->descr));
4133d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(Cmat->descr, HIPSPARSE_INDEX_BASE_ZERO));
4134d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(Cmat->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
4135d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->alpha_one, sizeof(PetscScalar)));
4136d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->beta_zero, sizeof(PetscScalar)));
4137d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->beta_one, sizeof(PetscScalar)));
4138d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4139d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
4140d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4141d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
4142d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
4143d52a580bSJunchao Zhang PetscCheck(Acusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4144d52a580bSJunchao Zhang PetscCheck(Bcusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4145d52a580bSJunchao Zhang
4146d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Acusp->mat->mat;
4147d52a580bSJunchao Zhang Bcsr = (CsrMatrix *)Bcusp->mat->mat;
4148d52a580bSJunchao Zhang Annz = (PetscInt)Acsr->column_indices->size();
4149d52a580bSJunchao Zhang Bnnz = (PetscInt)Bcsr->column_indices->size();
4150d52a580bSJunchao Zhang c->nz = Annz + Bnnz;
4151d52a580bSJunchao Zhang Ccsr->row_offsets = new THRUSTINTARRAY32(m + 1);
4152d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
4153d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz);
4154d52a580bSJunchao Zhang Ccsr->num_entries = c->nz;
4155d52a580bSJunchao Zhang Ccusp->coords = new THRUSTINTARRAY(c->nz);
4156d52a580bSJunchao Zhang if (c->nz) {
4157d52a580bSJunchao Zhang auto Acoo = new THRUSTINTARRAY32(Annz);
4158d52a580bSJunchao Zhang auto Bcoo = new THRUSTINTARRAY32(Bnnz);
4159d52a580bSJunchao Zhang auto Ccoo = new THRUSTINTARRAY32(c->nz);
4160d52a580bSJunchao Zhang THRUSTINTARRAY32 *Aroff, *Broff;
4161d52a580bSJunchao Zhang
4162d52a580bSJunchao Zhang if (a->compressedrow.use) { /* need full row offset */
4163d52a580bSJunchao Zhang if (!Acusp->rowoffsets_gpu) {
4164d52a580bSJunchao Zhang Acusp->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
4165d52a580bSJunchao Zhang Acusp->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
4166d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
4167d52a580bSJunchao Zhang }
4168d52a580bSJunchao Zhang Aroff = Acusp->rowoffsets_gpu;
4169d52a580bSJunchao Zhang } else Aroff = Acsr->row_offsets;
4170d52a580bSJunchao Zhang if (b->compressedrow.use) { /* need full row offset */
4171d52a580bSJunchao Zhang if (!Bcusp->rowoffsets_gpu) {
4172d52a580bSJunchao Zhang Bcusp->rowoffsets_gpu = new THRUSTINTARRAY32(B->rmap->n + 1);
4173d52a580bSJunchao Zhang Bcusp->rowoffsets_gpu->assign(b->i, b->i + B->rmap->n + 1);
4174d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((B->rmap->n + 1) * sizeof(PetscInt)));
4175d52a580bSJunchao Zhang }
4176d52a580bSJunchao Zhang Broff = Bcusp->rowoffsets_gpu;
4177d52a580bSJunchao Zhang } else Broff = Bcsr->row_offsets;
4178d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
4179d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsr2coo(Acusp->handle, Aroff->data().get(), Annz, m, Acoo->data().get(), HIPSPARSE_INDEX_BASE_ZERO));
4180d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsr2coo(Bcusp->handle, Broff->data().get(), Bnnz, m, Bcoo->data().get(), HIPSPARSE_INDEX_BASE_ZERO));
4181d52a580bSJunchao Zhang /* Issues when using bool with large matrices on SUMMIT 10.2.89 */
4182d52a580bSJunchao Zhang auto Aperm = thrust::make_constant_iterator(1);
4183d52a580bSJunchao Zhang auto Bperm = thrust::make_constant_iterator(0);
4184d52a580bSJunchao Zhang auto Bcib = thrust::make_transform_iterator(Bcsr->column_indices->begin(), Shift(A->cmap->n));
4185d52a580bSJunchao Zhang auto Bcie = thrust::make_transform_iterator(Bcsr->column_indices->end(), Shift(A->cmap->n));
4186d52a580bSJunchao Zhang auto wPerm = new THRUSTINTARRAY32(Annz + Bnnz);
4187d52a580bSJunchao Zhang auto Azb = thrust::make_zip_iterator(thrust::make_tuple(Acoo->begin(), Acsr->column_indices->begin(), Acsr->values->begin(), Aperm));
4188d52a580bSJunchao Zhang auto Aze = thrust::make_zip_iterator(thrust::make_tuple(Acoo->end(), Acsr->column_indices->end(), Acsr->values->end(), Aperm));
4189d52a580bSJunchao Zhang auto Bzb = thrust::make_zip_iterator(thrust::make_tuple(Bcoo->begin(), Bcib, Bcsr->values->begin(), Bperm));
4190d52a580bSJunchao Zhang auto Bze = thrust::make_zip_iterator(thrust::make_tuple(Bcoo->end(), Bcie, Bcsr->values->end(), Bperm));
4191d52a580bSJunchao Zhang auto Czb = thrust::make_zip_iterator(thrust::make_tuple(Ccoo->begin(), Ccsr->column_indices->begin(), Ccsr->values->begin(), wPerm->begin()));
4192d52a580bSJunchao Zhang auto p1 = Ccusp->coords->begin();
4193d52a580bSJunchao Zhang auto p2 = Ccusp->coords->begin();
4194d52a580bSJunchao Zhang thrust::advance(p2, Annz);
4195d52a580bSJunchao Zhang PetscCallThrust(thrust::merge(thrust::device, Azb, Aze, Bzb, Bze, Czb, IJCompare4()));
4196d52a580bSJunchao Zhang auto cci = thrust::make_counting_iterator(zero);
4197d52a580bSJunchao Zhang auto cce = thrust::make_counting_iterator(c->nz);
4198d52a580bSJunchao Zhang #if 0 //Errors on SUMMIT cuda 11.1.0
4199d52a580bSJunchao Zhang PetscCallThrust(thrust::partition_copy(thrust::device, cci, cce, wPerm->begin(), p1, p2, thrust::identity<int>()));
4200d52a580bSJunchao Zhang #else
4201*5a884c48SSatish Balay auto pred = [](const int &x) { return x; };
4202d52a580bSJunchao Zhang PetscCallThrust(thrust::copy_if(thrust::device, cci, cce, wPerm->begin(), p1, pred));
4203d52a580bSJunchao Zhang PetscCallThrust(thrust::remove_copy_if(thrust::device, cci, cce, wPerm->begin(), p2, pred));
4204d52a580bSJunchao Zhang #endif
4205d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcoo2csr(Ccusp->handle, Ccoo->data().get(), c->nz, m, Ccsr->row_offsets->data().get(), HIPSPARSE_INDEX_BASE_ZERO));
4206d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
4207d52a580bSJunchao Zhang delete wPerm;
4208d52a580bSJunchao Zhang delete Acoo;
4209d52a580bSJunchao Zhang delete Bcoo;
4210d52a580bSJunchao Zhang delete Ccoo;
4211d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&Cmat->matDescr, Ccsr->num_rows, Ccsr->num_cols, Ccsr->num_entries, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
4212d52a580bSJunchao Zhang
4213d52a580bSJunchao Zhang if (A->form_explicit_transpose && B->form_explicit_transpose) { /* if A and B have the transpose, generate C transpose too */
4214d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
4215d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(B));
4216d52a580bSJunchao Zhang PetscBool AT = Acusp->matTranspose ? PETSC_TRUE : PETSC_FALSE, BT = Bcusp->matTranspose ? PETSC_TRUE : PETSC_FALSE;
4217d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *CmatT = new Mat_SeqAIJHIPSPARSEMultStruct;
4218d52a580bSJunchao Zhang CsrMatrix *CcsrT = new CsrMatrix;
4219d52a580bSJunchao Zhang CsrMatrix *AcsrT = AT ? (CsrMatrix *)Acusp->matTranspose->mat : NULL;
4220d52a580bSJunchao Zhang CsrMatrix *BcsrT = BT ? (CsrMatrix *)Bcusp->matTranspose->mat : NULL;
4221d52a580bSJunchao Zhang
4222d52a580bSJunchao Zhang (*C)->form_explicit_transpose = PETSC_TRUE;
4223d52a580bSJunchao Zhang (*C)->transupdated = PETSC_TRUE;
4224d52a580bSJunchao Zhang Ccusp->rowoffsets_gpu = NULL;
4225d52a580bSJunchao Zhang CmatT->cprowIndices = NULL;
4226d52a580bSJunchao Zhang CmatT->mat = CcsrT;
4227d52a580bSJunchao Zhang CcsrT->num_rows = n;
4228d52a580bSJunchao Zhang CcsrT->num_cols = m;
4229d52a580bSJunchao Zhang CcsrT->num_entries = c->nz;
4230d52a580bSJunchao Zhang CcsrT->row_offsets = new THRUSTINTARRAY32(n + 1);
4231d52a580bSJunchao Zhang CcsrT->column_indices = new THRUSTINTARRAY32(c->nz);
4232d52a580bSJunchao Zhang CcsrT->values = new THRUSTARRAY(c->nz);
4233d52a580bSJunchao Zhang
4234d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
4235d52a580bSJunchao Zhang auto rT = CcsrT->row_offsets->begin();
4236d52a580bSJunchao Zhang if (AT) {
4237d52a580bSJunchao Zhang rT = thrust::copy(AcsrT->row_offsets->begin(), AcsrT->row_offsets->end(), rT);
4238d52a580bSJunchao Zhang thrust::advance(rT, -1);
4239d52a580bSJunchao Zhang }
4240d52a580bSJunchao Zhang if (BT) {
4241d52a580bSJunchao Zhang auto titb = thrust::make_transform_iterator(BcsrT->row_offsets->begin(), Shift(a->nz));
4242d52a580bSJunchao Zhang auto tite = thrust::make_transform_iterator(BcsrT->row_offsets->end(), Shift(a->nz));
4243d52a580bSJunchao Zhang thrust::copy(titb, tite, rT);
4244d52a580bSJunchao Zhang }
4245d52a580bSJunchao Zhang auto cT = CcsrT->column_indices->begin();
4246d52a580bSJunchao Zhang if (AT) cT = thrust::copy(AcsrT->column_indices->begin(), AcsrT->column_indices->end(), cT);
4247d52a580bSJunchao Zhang if (BT) thrust::copy(BcsrT->column_indices->begin(), BcsrT->column_indices->end(), cT);
4248d52a580bSJunchao Zhang auto vT = CcsrT->values->begin();
4249d52a580bSJunchao Zhang if (AT) vT = thrust::copy(AcsrT->values->begin(), AcsrT->values->end(), vT);
4250d52a580bSJunchao Zhang if (BT) thrust::copy(BcsrT->values->begin(), BcsrT->values->end(), vT);
4251d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
4252d52a580bSJunchao Zhang
4253d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&CmatT->descr));
4254d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(CmatT->descr, HIPSPARSE_INDEX_BASE_ZERO));
4255d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(CmatT->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
4256d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&CmatT->alpha_one, sizeof(PetscScalar)));
4257d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&CmatT->beta_zero, sizeof(PetscScalar)));
4258d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&CmatT->beta_one, sizeof(PetscScalar)));
4259d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(CmatT->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4260d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(CmatT->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
4261d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(CmatT->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4262d52a580bSJunchao Zhang
4263d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&CmatT->matDescr, CcsrT->num_rows, CcsrT->num_cols, CcsrT->num_entries, CcsrT->row_offsets->data().get(), CcsrT->column_indices->data().get(), CcsrT->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
4264d52a580bSJunchao Zhang Ccusp->matTranspose = CmatT;
4265d52a580bSJunchao Zhang }
4266d52a580bSJunchao Zhang }
4267d52a580bSJunchao Zhang
4268d52a580bSJunchao Zhang c->free_a = PETSC_TRUE;
4269d52a580bSJunchao Zhang PetscCall(PetscShmgetAllocateArray(c->nz, sizeof(PetscInt), (void **)&c->j));
4270d52a580bSJunchao Zhang PetscCall(PetscShmgetAllocateArray(m + 1, sizeof(PetscInt), (void **)&c->i));
4271d52a580bSJunchao Zhang c->free_ij = PETSC_TRUE;
4272d52a580bSJunchao Zhang if (PetscDefined(USE_64BIT_INDICES)) { /* 32 to 64-bit conversion on the GPU and then copy to host (lazy) */
4273d52a580bSJunchao Zhang THRUSTINTARRAY ii(Ccsr->row_offsets->size());
4274d52a580bSJunchao Zhang THRUSTINTARRAY jj(Ccsr->column_indices->size());
4275d52a580bSJunchao Zhang ii = *Ccsr->row_offsets;
4276d52a580bSJunchao Zhang jj = *Ccsr->column_indices;
4277d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->i, ii.data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4278d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->j, jj.data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4279d52a580bSJunchao Zhang } else {
4280d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->i, Ccsr->row_offsets->data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4281d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->j, Ccsr->column_indices->data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4282d52a580bSJunchao Zhang }
4283d52a580bSJunchao Zhang PetscCall(PetscLogGpuToCpu((Ccsr->column_indices->size() + Ccsr->row_offsets->size()) * sizeof(PetscInt)));
4284d52a580bSJunchao Zhang PetscCall(PetscMalloc1(m, &c->ilen));
4285d52a580bSJunchao Zhang PetscCall(PetscMalloc1(m, &c->imax));
4286d52a580bSJunchao Zhang c->maxnz = c->nz;
4287d52a580bSJunchao Zhang c->nonzerorowcnt = 0;
4288d52a580bSJunchao Zhang c->rmax = 0;
4289d52a580bSJunchao Zhang for (i = 0; i < m; i++) {
4290d52a580bSJunchao Zhang const PetscInt nn = c->i[i + 1] - c->i[i];
4291d52a580bSJunchao Zhang c->ilen[i] = c->imax[i] = nn;
4292d52a580bSJunchao Zhang c->nonzerorowcnt += (PetscInt)!!nn;
4293d52a580bSJunchao Zhang c->rmax = PetscMax(c->rmax, nn);
4294d52a580bSJunchao Zhang }
4295d52a580bSJunchao Zhang PetscCall(PetscMalloc1(c->nz, &c->a));
4296d52a580bSJunchao Zhang (*C)->nonzerostate++;
4297d52a580bSJunchao Zhang PetscCall(PetscLayoutSetUp((*C)->rmap));
4298d52a580bSJunchao Zhang PetscCall(PetscLayoutSetUp((*C)->cmap));
4299d52a580bSJunchao Zhang Ccusp->nonzerostate = (*C)->nonzerostate;
4300d52a580bSJunchao Zhang (*C)->preallocated = PETSC_TRUE;
4301d52a580bSJunchao Zhang } else {
4302d52a580bSJunchao Zhang PetscCheck((*C)->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, (*C)->rmap->n, B->rmap->n);
4303d52a580bSJunchao Zhang c = (Mat_SeqAIJ *)(*C)->data;
4304d52a580bSJunchao Zhang if (c->nz) {
4305d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)(*C)->spptr;
4306d52a580bSJunchao Zhang PetscCheck(Ccusp->coords, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing coords");
4307d52a580bSJunchao Zhang PetscCheck(Ccusp->format != MAT_HIPSPARSE_ELL && Ccusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4308d52a580bSJunchao Zhang PetscCheck(Ccusp->nonzerostate == (*C)->nonzerostate, PETSC_COMM_SELF, PETSC_ERR_COR, "Wrong nonzerostate");
4309d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
4310d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
4311d52a580bSJunchao Zhang PetscCheck(Acusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4312d52a580bSJunchao Zhang PetscCheck(Bcusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4313d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Acusp->mat->mat;
4314d52a580bSJunchao Zhang Bcsr = (CsrMatrix *)Bcusp->mat->mat;
4315d52a580bSJunchao Zhang Ccsr = (CsrMatrix *)Ccusp->mat->mat;
4316d52a580bSJunchao Zhang PetscCheck(Acsr->num_entries == (PetscInt)Acsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "A nnz %" PetscInt_FMT " != %" PetscInt_FMT, Acsr->num_entries, (PetscInt)Acsr->values->size());
4317d52a580bSJunchao Zhang PetscCheck(Bcsr->num_entries == (PetscInt)Bcsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "B nnz %" PetscInt_FMT " != %" PetscInt_FMT, Bcsr->num_entries, (PetscInt)Bcsr->values->size());
4318d52a580bSJunchao Zhang PetscCheck(Ccsr->num_entries == (PetscInt)Ccsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "C nnz %" PetscInt_FMT " != %" PetscInt_FMT, Ccsr->num_entries, (PetscInt)Ccsr->values->size());
4319d52a580bSJunchao Zhang PetscCheck(Ccsr->num_entries == Acsr->num_entries + Bcsr->num_entries, PETSC_COMM_SELF, PETSC_ERR_COR, "C nnz %" PetscInt_FMT " != %" PetscInt_FMT " + %" PetscInt_FMT, Ccsr->num_entries, Acsr->num_entries, Bcsr->num_entries);
4320d52a580bSJunchao Zhang PetscCheck(Ccusp->coords->size() == Ccsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "permSize %" PetscInt_FMT " != %" PetscInt_FMT, (PetscInt)Ccusp->coords->size(), (PetscInt)Ccsr->values->size());
4321d52a580bSJunchao Zhang auto pmid = Ccusp->coords->begin();
4322d52a580bSJunchao Zhang thrust::advance(pmid, Acsr->num_entries);
4323d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin());
4324d52a580bSJunchao Zhang auto zibait = thrust::make_zip_iterator(thrust::make_tuple(Acsr->values->begin(), thrust::make_permutation_iterator(Ccsr->values->begin(), Ccusp->coords->begin())));
4325d52a580bSJunchao Zhang auto zieait = thrust::make_zip_iterator(thrust::make_tuple(Acsr->values->end(), thrust::make_permutation_iterator(Ccsr->values->begin(), pmid)));
4326d52a580bSJunchao Zhang thrust::for_each(zibait, zieait, VecHIPEquals());
4327d52a580bSJunchao Zhang auto zibbit = thrust::make_zip_iterator(thrust::make_tuple(Bcsr->values->begin(), thrust::make_permutation_iterator(Ccsr->values->begin(), pmid)));
4328d52a580bSJunchao Zhang auto ziebit = thrust::make_zip_iterator(thrust::make_tuple(Bcsr->values->end(), thrust::make_permutation_iterator(Ccsr->values->begin(), Ccusp->coords->end())));
4329d52a580bSJunchao Zhang thrust::for_each(zibbit, ziebit, VecHIPEquals());
4330d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(*C, PETSC_FALSE));
4331d52a580bSJunchao Zhang if (A->form_explicit_transpose && B->form_explicit_transpose && (*C)->form_explicit_transpose) {
4332d52a580bSJunchao Zhang PetscCheck(Ccusp->matTranspose, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing transpose Mat_SeqAIJHIPSPARSEMultStruct");
4333d52a580bSJunchao Zhang PetscBool AT = Acusp->matTranspose ? PETSC_TRUE : PETSC_FALSE, BT = Bcusp->matTranspose ? PETSC_TRUE : PETSC_FALSE;
4334d52a580bSJunchao Zhang CsrMatrix *AcsrT = AT ? (CsrMatrix *)Acusp->matTranspose->mat : NULL;
4335d52a580bSJunchao Zhang CsrMatrix *BcsrT = BT ? (CsrMatrix *)Bcusp->matTranspose->mat : NULL;
4336d52a580bSJunchao Zhang CsrMatrix *CcsrT = (CsrMatrix *)Ccusp->matTranspose->mat;
4337d52a580bSJunchao Zhang auto vT = CcsrT->values->begin();
4338d52a580bSJunchao Zhang if (AT) vT = thrust::copy(AcsrT->values->begin(), AcsrT->values->end(), vT);
4339d52a580bSJunchao Zhang if (BT) thrust::copy(BcsrT->values->begin(), BcsrT->values->end(), vT);
4340d52a580bSJunchao Zhang (*C)->transupdated = PETSC_TRUE;
4341d52a580bSJunchao Zhang }
4342d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd());
4343d52a580bSJunchao Zhang }
4344d52a580bSJunchao Zhang }
4345d52a580bSJunchao Zhang PetscCall(PetscObjectStateIncrease((PetscObject)*C));
4346d52a580bSJunchao Zhang (*C)->assembled = PETSC_TRUE;
4347d52a580bSJunchao Zhang (*C)->was_assembled = PETSC_FALSE;
4348d52a580bSJunchao Zhang (*C)->offloadmask = PETSC_OFFLOAD_GPU;
4349d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
4350d52a580bSJunchao Zhang }
4351d52a580bSJunchao Zhang
MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat A,PetscInt n,const PetscInt idx[],PetscScalar v[])4352d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat A, PetscInt n, const PetscInt idx[], PetscScalar v[])
4353d52a580bSJunchao Zhang {
4354d52a580bSJunchao Zhang bool dmem;
4355d52a580bSJunchao Zhang const PetscScalar *av;
4356d52a580bSJunchao Zhang
4357d52a580bSJunchao Zhang PetscFunctionBegin;
4358d52a580bSJunchao Zhang dmem = isHipMem(v);
4359d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(A, &av));
4360d52a580bSJunchao Zhang if (n && idx) {
4361d52a580bSJunchao Zhang THRUSTINTARRAY widx(n);
4362d52a580bSJunchao Zhang widx.assign(idx, idx + n);
4363d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
4364d52a580bSJunchao Zhang
4365d52a580bSJunchao Zhang THRUSTARRAY *w = NULL;
4366d52a580bSJunchao Zhang thrust::device_ptr<PetscScalar> dv;
4367d52a580bSJunchao Zhang if (dmem) dv = thrust::device_pointer_cast(v);
4368d52a580bSJunchao Zhang else {
4369d52a580bSJunchao Zhang w = new THRUSTARRAY(n);
4370d52a580bSJunchao Zhang dv = w->data();
4371d52a580bSJunchao Zhang }
4372d52a580bSJunchao Zhang thrust::device_ptr<const PetscScalar> dav = thrust::device_pointer_cast(av);
4373d52a580bSJunchao Zhang
4374d52a580bSJunchao Zhang auto zibit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.begin()), dv));
4375d52a580bSJunchao Zhang auto zieit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.end()), dv + n));
4376d52a580bSJunchao Zhang thrust::for_each(zibit, zieit, VecHIPEquals());
4377d52a580bSJunchao Zhang if (w) PetscCallHIP(hipMemcpy(v, w->data().get(), n * sizeof(PetscScalar), hipMemcpyDeviceToHost));
4378d52a580bSJunchao Zhang delete w;
4379d52a580bSJunchao Zhang } else PetscCallHIP(hipMemcpy(v, av, n * sizeof(PetscScalar), dmem ? hipMemcpyDeviceToDevice : hipMemcpyDeviceToHost));
4380d52a580bSJunchao Zhang
4381d52a580bSJunchao Zhang if (!dmem) PetscCall(PetscLogCpuToGpu(n * sizeof(PetscScalar)));
4382d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(A, &av));
4383d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS);
4384d52a580bSJunchao Zhang }
4385