xref: /petsc/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cxx (revision 5a884c48ab0c46bab83cd9bb8710f380fa6d8bcf)
1d52a580bSJunchao Zhang /*
2d52a580bSJunchao Zhang   Defines the basic matrix operations for the AIJ (compressed row)
3d52a580bSJunchao Zhang   matrix storage format using the HIPSPARSE library,
4d52a580bSJunchao Zhang   Portions of this code are under:
5d52a580bSJunchao Zhang   Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
6d52a580bSJunchao Zhang */
#include <petscconf.h>
#include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
#include <../src/mat/impls/sbaij/seq/sbaij.h>
#include <../src/mat/impls/dense/seq/dense.h> // MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal()
#include <../src/vec/vec/impls/dvecimpl.h>
#include <petsc/private/vecimpl.h>
#undef VecType
#include <../src/mat/impls/aij/seq/seqhipsparse/hipsparsematimpl.h>
#include <thrust/adjacent_difference.h>
#include <thrust/iterator/transform_iterator.h>
#if PETSC_CPP_VERSION >= 14
  #define PETSC_HAVE_THRUST_ASYNC 1
  #include <thrust/async/for_each.h>
#endif
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/binary_search.h>
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <exception>
27d52a580bSJunchao Zhang 
/*
  Choice tables handed to PetscOptionsEnum() in MatSetFromOptions_SeqAIJHIPSPARSE().

  Layout of each table (PETSc enum-list convention):
    the enum value names, in enum order,
    then the enum type name,
    then the common enum prefix,
    then a null terminator.

  PetscOptionsEnum() converts a chosen name into an enum value by its POSITION in the table.
  The order therefore has to track the corresponding hipSPARSE enum.
*/
const char *const MatHIPSPARSEStorageFormats[] = {
  "CSR",
  "ELL",
  "HYB",
  "MatHIPSPARSEStorageFormat", /* enum type name */
  "MAT_HIPSPARSE_",            /* enum prefix */
  nullptr,
};

/* positions must follow hipsparseSpMVAlg_t; e.g. "CSRMV_ALG1" is expected at index 2 */
const char *const MatHIPSPARSESpMVAlgorithms[] = {
  "MV_ALG_DEFAULT",
  "COOMV_ALG",
  "CSRMV_ALG1",
  "CSRMV_ALG2",
  "SPMV_ALG_DEFAULT",
  "SPMV_COO_ALG1",
  "SPMV_COO_ALG2",
  "SPMV_CSR_ALG1",
  "SPMV_CSR_ALG2",
  "hipsparseSpMVAlg_t", /* enum type name */
  "HIPSPARSE_",         /* enum prefix */
  nullptr,
};

/* positions must follow hipsparseSpMMAlg_t; e.g. "CSR_ALG1" is expected at index 4 */
const char *const MatHIPSPARSESpMMAlgorithms[] = {
  "ALG_DEFAULT",
  "COO_ALG1",
  "COO_ALG2",
  "COO_ALG3",
  "CSR_ALG1",
  "COO_ALG4",
  "CSR_ALG2",
  "hipsparseSpMMAlg_t", /* enum type name */
  "HIPSPARSE_SPMM_",    /* enum prefix */
  nullptr,
};

//const char *const MatHIPSPARSECsr2CscAlgorithms[] = {"INVALID"/*HIPSPARSE does not have enum 0! We created one*/, "ALG1", "ALG2", "hipsparseCsr2CscAlg_t", "HIPSPARSE_CSR2CSC_", 0};
32d52a580bSJunchao Zhang 
33d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, const MatFactorInfo *);
34d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, const MatFactorInfo *);
35d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat, Mat, const MatFactorInfo *);
36d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, IS, const MatFactorInfo *);
37d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, IS, const MatFactorInfo *);
38d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat, Mat, const MatFactorInfo *);
39d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE(Mat, Vec, Vec);
40d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat, Vec, Vec);
41d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec);
42d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat, Vec, Vec);
43d52a580bSJunchao Zhang static PetscErrorCode MatSetFromOptions_SeqAIJHIPSPARSE(Mat, PetscOptionItems PetscOptionsObject);
44d52a580bSJunchao Zhang static PetscErrorCode MatAXPY_SeqAIJHIPSPARSE(Mat, PetscScalar, Mat, MatStructure);
45d52a580bSJunchao Zhang static PetscErrorCode MatScale_SeqAIJHIPSPARSE(Mat, PetscScalar);
46d52a580bSJunchao Zhang static PetscErrorCode MatMult_SeqAIJHIPSPARSE(Mat, Vec, Vec);
47d52a580bSJunchao Zhang static PetscErrorCode MatMultAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec);
48d52a580bSJunchao Zhang static PetscErrorCode MatMultTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec);
49d52a580bSJunchao Zhang static PetscErrorCode MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec);
50d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec);
51d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec);
52d52a580bSJunchao Zhang static PetscErrorCode MatMultAddKernel_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec, PetscBool, PetscBool);
53d52a580bSJunchao Zhang static PetscErrorCode CsrMatrix_Destroy(CsrMatrix **);
54d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct **);
55d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct **, MatHIPSPARSEStorageFormat);
56d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors **);
57d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSE_Destroy(Mat);
58d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSECopyFromGPU(Mat);
59d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat);
60d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEInvalidateTranspose(Mat, PetscBool);
61d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat, PetscInt, const PetscInt[], PetscScalar[]);
62d52a580bSJunchao Zhang static PetscErrorCode MatBindToCPU_SeqAIJHIPSPARSE(Mat, PetscBool);
63d52a580bSJunchao Zhang static PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat, PetscCount, PetscInt[], PetscInt[]);
64d52a580bSJunchao Zhang static PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE(Mat, const PetscScalar[], InsertMode);
65d52a580bSJunchao Zhang 
66d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatProductSetFromOptions_SeqAIJ_SeqDense(Mat);
67d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJHIPSPARSE(Mat, MatType, MatReuse, Mat *);
68d52a580bSJunchao Zhang 
69d52a580bSJunchao Zhang /*
70d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetStream(Mat A, const hipStream_t stream)
71d52a580bSJunchao Zhang {
72d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr;
73d52a580bSJunchao Zhang 
74d52a580bSJunchao Zhang   PetscFunctionBegin;
75d52a580bSJunchao Zhang   PetscCheck(hipsparsestruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr");
76d52a580bSJunchao Zhang   hipsparsestruct->stream = stream;
77d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetStream(hipsparsestruct->handle, hipsparsestruct->stream));
78d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
79d52a580bSJunchao Zhang }
80d52a580bSJunchao Zhang 
81d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetHandle(Mat A, const hipsparseHandle_t handle)
82d52a580bSJunchao Zhang {
83d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr;
84d52a580bSJunchao Zhang 
85d52a580bSJunchao Zhang   PetscFunctionBegin;
86d52a580bSJunchao Zhang   PetscCheck(hipsparsestruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr");
87d52a580bSJunchao Zhang   if (hipsparsestruct->handle != handle) {
88d52a580bSJunchao Zhang     if (hipsparsestruct->handle) PetscCallHIPSPARSE(hipsparseDestroy(hipsparsestruct->handle));
89d52a580bSJunchao Zhang     hipsparsestruct->handle = handle;
90d52a580bSJunchao Zhang   }
91d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(hipsparsestruct->handle, HIPSPARSE_POINTER_MODE_DEVICE));
92d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
93d52a580bSJunchao Zhang }
94d52a580bSJunchao Zhang 
95d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSEClearHandle(Mat A)
96d52a580bSJunchao Zhang {
97d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr;
98d52a580bSJunchao Zhang   PetscBool            flg;
99d52a580bSJunchao Zhang 
100d52a580bSJunchao Zhang   PetscFunctionBegin;
101d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
102d52a580bSJunchao Zhang   if (!flg || !hipsparsestruct) PetscFunctionReturn(PETSC_SUCCESS);
103d52a580bSJunchao Zhang   if (hipsparsestruct->handle) hipsparsestruct->handle = 0;
104d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
105d52a580bSJunchao Zhang }
106d52a580bSJunchao Zhang */
107d52a580bSJunchao Zhang 
MatHIPSPARSESetFormat_SeqAIJHIPSPARSE(Mat A,MatHIPSPARSEFormatOperation op,MatHIPSPARSEStorageFormat format)108d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatHIPSPARSESetFormat_SeqAIJHIPSPARSE(Mat A, MatHIPSPARSEFormatOperation op, MatHIPSPARSEStorageFormat format)
109d52a580bSJunchao Zhang {
110d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
111d52a580bSJunchao Zhang 
112d52a580bSJunchao Zhang   PetscFunctionBegin;
113d52a580bSJunchao Zhang   switch (op) {
114d52a580bSJunchao Zhang   case MAT_HIPSPARSE_MULT:
115d52a580bSJunchao Zhang     hipsparsestruct->format = format;
116d52a580bSJunchao Zhang     break;
117d52a580bSJunchao Zhang   case MAT_HIPSPARSE_ALL:
118d52a580bSJunchao Zhang     hipsparsestruct->format = format;
119d52a580bSJunchao Zhang     break;
120d52a580bSJunchao Zhang   default:
121d52a580bSJunchao Zhang     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "unsupported operation %d for MatHIPSPARSEFormatOperation. MAT_HIPSPARSE_MULT and MAT_HIPSPARSE_ALL are currently supported.", op);
122d52a580bSJunchao Zhang   }
123d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
124d52a580bSJunchao Zhang }
125d52a580bSJunchao Zhang 
126d52a580bSJunchao Zhang /*@
127d52a580bSJunchao Zhang   MatHIPSPARSESetFormat - Sets the storage format of `MATSEQHIPSPARSE` matrices for a particular
128d52a580bSJunchao Zhang   operation. Only the `MatMult()` operation can use different GPU storage formats
129d52a580bSJunchao Zhang 
130d52a580bSJunchao Zhang   Not Collective
131d52a580bSJunchao Zhang 
132d52a580bSJunchao Zhang   Input Parameters:
133d52a580bSJunchao Zhang + A      - Matrix of type `MATSEQAIJHIPSPARSE`
134d52a580bSJunchao Zhang . op     - `MatHIPSPARSEFormatOperation`. `MATSEQAIJHIPSPARSE` matrices support `MAT_HIPSPARSE_MULT` and `MAT_HIPSPARSE_ALL`.
135d52a580bSJunchao Zhang          `MATMPIAIJHIPSPARSE` matrices support `MAT_HIPSPARSE_MULT_DIAG`, `MAT_HIPSPARSE_MULT_OFFDIAG`, and `MAT_HIPSPARSE_ALL`.
136d52a580bSJunchao Zhang - format - `MatHIPSPARSEStorageFormat` (one of `MAT_HIPSPARSE_CSR`, `MAT_HIPSPARSE_ELL`, `MAT_HIPSPARSE_HYB`.)
137d52a580bSJunchao Zhang 
138d52a580bSJunchao Zhang   Level: intermediate
139d52a580bSJunchao Zhang 
140d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MATSEQAIJHIPSPARSE`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
141d52a580bSJunchao Zhang @*/
MatHIPSPARSESetFormat(Mat A,MatHIPSPARSEFormatOperation op,MatHIPSPARSEStorageFormat format)142d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetFormat(Mat A, MatHIPSPARSEFormatOperation op, MatHIPSPARSEStorageFormat format)
143d52a580bSJunchao Zhang {
144d52a580bSJunchao Zhang   PetscFunctionBegin;
145d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
146d52a580bSJunchao Zhang   PetscTryMethod(A, "MatHIPSPARSESetFormat_C", (Mat, MatHIPSPARSEFormatOperation, MatHIPSPARSEStorageFormat), (A, op, format));
147d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
148d52a580bSJunchao Zhang }
149d52a580bSJunchao Zhang 
MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE(Mat A,PetscBool use_cpu)150d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE(Mat A, PetscBool use_cpu)
151d52a580bSJunchao Zhang {
152d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
153d52a580bSJunchao Zhang 
154d52a580bSJunchao Zhang   PetscFunctionBegin;
155d52a580bSJunchao Zhang   hipsparsestruct->use_cpu_solve = use_cpu;
156d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
157d52a580bSJunchao Zhang }
158d52a580bSJunchao Zhang 
159d52a580bSJunchao Zhang /*@
160d52a580bSJunchao Zhang   MatHIPSPARSESetUseCPUSolve - Sets use CPU `MatSolve()`.
161d52a580bSJunchao Zhang 
162d52a580bSJunchao Zhang   Input Parameters:
163d52a580bSJunchao Zhang + A       - Matrix of type `MATSEQAIJHIPSPARSE`
164d52a580bSJunchao Zhang - use_cpu - set flag for using the built-in CPU `MatSolve()`
165d52a580bSJunchao Zhang 
166d52a580bSJunchao Zhang   Level: intermediate
167d52a580bSJunchao Zhang 
168d52a580bSJunchao Zhang   Notes:
169d52a580bSJunchao Zhang   The hipSparse LU solver currently computes the factors with the built-in CPU method
170d52a580bSJunchao Zhang   and moves the factors to the GPU for the solve. We have observed better performance keeping the data on the CPU and computing the solve there.
171d52a580bSJunchao Zhang   This method to specifies if the solve is done on the CPU or GPU (GPU is the default).
172d52a580bSJunchao Zhang 
173d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSolve()`, `MATSEQAIJHIPSPARSE`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
174d52a580bSJunchao Zhang @*/
MatHIPSPARSESetUseCPUSolve(Mat A,PetscBool use_cpu)175d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetUseCPUSolve(Mat A, PetscBool use_cpu)
176d52a580bSJunchao Zhang {
177d52a580bSJunchao Zhang   PetscFunctionBegin;
178d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
179d52a580bSJunchao Zhang   PetscTryMethod(A, "MatHIPSPARSESetUseCPUSolve_C", (Mat, PetscBool), (A, use_cpu));
180d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
181d52a580bSJunchao Zhang }
182d52a580bSJunchao Zhang 
MatSetOption_SeqAIJHIPSPARSE(Mat A,MatOption op,PetscBool flg)183d52a580bSJunchao Zhang static PetscErrorCode MatSetOption_SeqAIJHIPSPARSE(Mat A, MatOption op, PetscBool flg)
184d52a580bSJunchao Zhang {
185d52a580bSJunchao Zhang   PetscFunctionBegin;
186d52a580bSJunchao Zhang   switch (op) {
187d52a580bSJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
188d52a580bSJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
189d52a580bSJunchao Zhang     if (A->form_explicit_transpose && !flg) PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE));
190d52a580bSJunchao Zhang     A->form_explicit_transpose = flg;
191d52a580bSJunchao Zhang     break;
192d52a580bSJunchao Zhang   default:
193d52a580bSJunchao Zhang     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
194d52a580bSJunchao Zhang     break;
195d52a580bSJunchao Zhang   }
196d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
197d52a580bSJunchao Zhang }
198d52a580bSJunchao Zhang 
MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat B,Mat A,const MatFactorInfo * info)199d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat B, Mat A, const MatFactorInfo *info)
200d52a580bSJunchao Zhang {
201d52a580bSJunchao Zhang   PetscBool            row_identity, col_identity;
202d52a580bSJunchao Zhang   Mat_SeqAIJ          *b     = (Mat_SeqAIJ *)B->data;
203d52a580bSJunchao Zhang   IS                   isrow = b->row, iscol = b->col;
204d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)B->spptr;
205d52a580bSJunchao Zhang 
206d52a580bSJunchao Zhang   PetscFunctionBegin;
207d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
208d52a580bSJunchao Zhang   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
209d52a580bSJunchao Zhang   B->offloadmask = PETSC_OFFLOAD_CPU;
210d52a580bSJunchao Zhang   /* determine which version of MatSolve needs to be used. */
211d52a580bSJunchao Zhang   PetscCall(ISIdentity(isrow, &row_identity));
212d52a580bSJunchao Zhang   PetscCall(ISIdentity(iscol, &col_identity));
213d52a580bSJunchao Zhang   if (!hipsparsestruct->use_cpu_solve) {
214d52a580bSJunchao Zhang     if (row_identity && col_identity) {
215d52a580bSJunchao Zhang       B->ops->solve          = MatSolve_SeqAIJHIPSPARSE_NaturalOrdering;
216d52a580bSJunchao Zhang       B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering;
217d52a580bSJunchao Zhang     } else {
218d52a580bSJunchao Zhang       B->ops->solve          = MatSolve_SeqAIJHIPSPARSE;
219d52a580bSJunchao Zhang       B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE;
220d52a580bSJunchao Zhang     }
221d52a580bSJunchao Zhang   }
222d52a580bSJunchao Zhang   B->ops->matsolve          = NULL;
223d52a580bSJunchao Zhang   B->ops->matsolvetranspose = NULL;
224d52a580bSJunchao Zhang 
225d52a580bSJunchao Zhang   /* get the triangular factors */
226d52a580bSJunchao Zhang   if (!hipsparsestruct->use_cpu_solve) PetscCall(MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(B));
227d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
228d52a580bSJunchao Zhang }
229d52a580bSJunchao Zhang 
MatSetFromOptions_SeqAIJHIPSPARSE(Mat A,PetscOptionItems PetscOptionsObject)230d52a580bSJunchao Zhang static PetscErrorCode MatSetFromOptions_SeqAIJHIPSPARSE(Mat A, PetscOptionItems PetscOptionsObject)
231d52a580bSJunchao Zhang {
232d52a580bSJunchao Zhang   MatHIPSPARSEStorageFormat format;
233d52a580bSJunchao Zhang   PetscBool                 flg;
234d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE      *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
235d52a580bSJunchao Zhang 
236d52a580bSJunchao Zhang   PetscFunctionBegin;
237d52a580bSJunchao Zhang   PetscOptionsHeadBegin(PetscOptionsObject, "SeqAIJHIPSPARSE options");
238d52a580bSJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
239d52a580bSJunchao Zhang     PetscCall(PetscOptionsEnum("-mat_hipsparse_mult_storage_format", "sets storage format of (seq)aijhipsparse gpu matrices for SpMV", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparsestruct->format, (PetscEnum *)&format, &flg));
240d52a580bSJunchao Zhang     if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_MULT, format));
241d52a580bSJunchao Zhang     PetscCall(PetscOptionsEnum("-mat_hipsparse_storage_format", "sets storage format of (seq)aijhipsparse gpu matrices for SpMV and TriSolve", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparsestruct->format, (PetscEnum *)&format, &flg));
242d52a580bSJunchao Zhang     if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_ALL, format));
243d52a580bSJunchao Zhang     PetscCall(PetscOptionsBool("-mat_hipsparse_use_cpu_solve", "Use CPU (I)LU solve", "MatHIPSPARSESetUseCPUSolve", hipsparsestruct->use_cpu_solve, &hipsparsestruct->use_cpu_solve, &flg));
244d52a580bSJunchao Zhang     if (flg) PetscCall(MatHIPSPARSESetUseCPUSolve(A, hipsparsestruct->use_cpu_solve));
245d52a580bSJunchao Zhang     PetscCall(
246d52a580bSJunchao Zhang       PetscOptionsEnum("-mat_hipsparse_spmv_alg", "sets hipSPARSE algorithm used in sparse-mat dense-vector multiplication (SpMV)", "hipsparseSpMVAlg_t", MatHIPSPARSESpMVAlgorithms, (PetscEnum)hipsparsestruct->spmvAlg, (PetscEnum *)&hipsparsestruct->spmvAlg, &flg));
247d52a580bSJunchao Zhang     /* If user did use this option, check its consistency with hipSPARSE, since PetscOptionsEnum() sets enum values based on their position in MatHIPSPARSESpMVAlgorithms[] */
248d52a580bSJunchao Zhang     PetscCheck(!flg || HIPSPARSE_CSRMV_ALG1 == 2, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseSpMVAlg_t has been changed but PETSc has not been updated accordingly");
249d52a580bSJunchao Zhang     PetscCall(
250d52a580bSJunchao Zhang       PetscOptionsEnum("-mat_hipsparse_spmm_alg", "sets hipSPARSE algorithm used in sparse-mat dense-mat multiplication (SpMM)", "hipsparseSpMMAlg_t", MatHIPSPARSESpMMAlgorithms, (PetscEnum)hipsparsestruct->spmmAlg, (PetscEnum *)&hipsparsestruct->spmmAlg, &flg));
251d52a580bSJunchao Zhang     PetscCheck(!flg || HIPSPARSE_SPMM_CSR_ALG1 == 4, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseSpMMAlg_t has been changed but PETSc has not been updated accordingly");
252d52a580bSJunchao Zhang     /*
253d52a580bSJunchao Zhang     PetscCall(PetscOptionsEnum("-mat_hipsparse_csr2csc_alg", "sets hipSPARSE algorithm used in converting CSR matrices to CSC matrices", "hipsparseCsr2CscAlg_t", MatHIPSPARSECsr2CscAlgorithms, (PetscEnum)hipsparsestruct->csr2cscAlg, (PetscEnum*)&hipsparsestruct->csr2cscAlg, &flg));
254d52a580bSJunchao Zhang     PetscCheck(!flg || HIPSPARSE_CSR2CSC_ALG1 == 1, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseCsr2CscAlg_t has been changed but PETSc has not been updated accordingly");
255d52a580bSJunchao Zhang     */
256d52a580bSJunchao Zhang   }
257d52a580bSJunchao Zhang   PetscOptionsHeadEnd();
258d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
259d52a580bSJunchao Zhang }
260d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(Mat A)261d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(Mat A)
262d52a580bSJunchao Zhang {
263d52a580bSJunchao Zhang   Mat_SeqAIJ                         *a                   = (Mat_SeqAIJ *)A->data;
264d52a580bSJunchao Zhang   PetscInt                            n                   = A->rmap->n;
265d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
266d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
267d52a580bSJunchao Zhang   const PetscInt                     *ai = a->i, *aj = a->j, *vi;
268d52a580bSJunchao Zhang   const MatScalar                    *aa = a->a, *v;
269d52a580bSJunchao Zhang   PetscInt                           *AiLo, *AjLo;
270d52a580bSJunchao Zhang   PetscInt                            i, nz, nzLower, offset, rowOffset;
271d52a580bSJunchao Zhang 
272d52a580bSJunchao Zhang   PetscFunctionBegin;
273d52a580bSJunchao Zhang   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
274d52a580bSJunchao Zhang   if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
275d52a580bSJunchao Zhang     try {
276d52a580bSJunchao Zhang       /* first figure out the number of nonzeros in the lower triangular matrix including 1's on the diagonal. */
277d52a580bSJunchao Zhang       nzLower = n + ai[n] - ai[1];
278d52a580bSJunchao Zhang       if (!loTriFactor) {
279d52a580bSJunchao Zhang         PetscScalar *AALo;
280d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AALo, nzLower * sizeof(PetscScalar)));
281d52a580bSJunchao Zhang 
282d52a580bSJunchao Zhang         /* Allocate Space for the lower triangular matrix */
283d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AiLo, (n + 1) * sizeof(PetscInt)));
284d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AjLo, nzLower * sizeof(PetscInt)));
285d52a580bSJunchao Zhang 
286d52a580bSJunchao Zhang         /* Fill the lower triangular matrix */
287d52a580bSJunchao Zhang         AiLo[0]   = (PetscInt)0;
288d52a580bSJunchao Zhang         AiLo[n]   = nzLower;
289d52a580bSJunchao Zhang         AjLo[0]   = (PetscInt)0;
290d52a580bSJunchao Zhang         AALo[0]   = (MatScalar)1.0;
291d52a580bSJunchao Zhang         v         = aa;
292d52a580bSJunchao Zhang         vi        = aj;
293d52a580bSJunchao Zhang         offset    = 1;
294d52a580bSJunchao Zhang         rowOffset = 1;
295d52a580bSJunchao Zhang         for (i = 1; i < n; i++) {
296d52a580bSJunchao Zhang           nz = ai[i + 1] - ai[i];
297d52a580bSJunchao Zhang           /* additional 1 for the term on the diagonal */
298d52a580bSJunchao Zhang           AiLo[i] = rowOffset;
299d52a580bSJunchao Zhang           rowOffset += nz + 1;
300d52a580bSJunchao Zhang 
301d52a580bSJunchao Zhang           PetscCall(PetscArraycpy(&AjLo[offset], vi, nz));
302d52a580bSJunchao Zhang           PetscCall(PetscArraycpy(&AALo[offset], v, nz));
303d52a580bSJunchao Zhang           offset += nz;
304d52a580bSJunchao Zhang           AjLo[offset] = (PetscInt)i;
305d52a580bSJunchao Zhang           AALo[offset] = (MatScalar)1.0;
306d52a580bSJunchao Zhang           offset += 1;
307d52a580bSJunchao Zhang           v += nz;
308d52a580bSJunchao Zhang           vi += nz;
309d52a580bSJunchao Zhang         }
310d52a580bSJunchao Zhang 
311d52a580bSJunchao Zhang         /* allocate space for the triangular factor information */
312d52a580bSJunchao Zhang         PetscCall(PetscNew(&loTriFactor));
313d52a580bSJunchao Zhang         loTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
314d52a580bSJunchao Zhang         /* Create the matrix description */
315d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactor->descr));
316d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
317d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
318d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactor->descr, HIPSPARSE_FILL_MODE_LOWER));
319d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactor->descr, HIPSPARSE_DIAG_TYPE_UNIT));
320d52a580bSJunchao Zhang 
321d52a580bSJunchao Zhang         /* set the operation */
322d52a580bSJunchao Zhang         loTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
323d52a580bSJunchao Zhang 
324d52a580bSJunchao Zhang         /* set the matrix */
325d52a580bSJunchao Zhang         loTriFactor->csrMat                 = new CsrMatrix;
326d52a580bSJunchao Zhang         loTriFactor->csrMat->num_rows       = n;
327d52a580bSJunchao Zhang         loTriFactor->csrMat->num_cols       = n;
328d52a580bSJunchao Zhang         loTriFactor->csrMat->num_entries    = nzLower;
329d52a580bSJunchao Zhang         loTriFactor->csrMat->row_offsets    = new THRUSTINTARRAY32(n + 1);
330d52a580bSJunchao Zhang         loTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(nzLower);
331d52a580bSJunchao Zhang         loTriFactor->csrMat->values         = new THRUSTARRAY(nzLower);
332d52a580bSJunchao Zhang 
333d52a580bSJunchao Zhang         loTriFactor->csrMat->row_offsets->assign(AiLo, AiLo + n + 1);
334d52a580bSJunchao Zhang         loTriFactor->csrMat->column_indices->assign(AjLo, AjLo + nzLower);
335d52a580bSJunchao Zhang         loTriFactor->csrMat->values->assign(AALo, AALo + nzLower);
336d52a580bSJunchao Zhang 
337d52a580bSJunchao Zhang         /* Create the solve analysis information */
338d52a580bSJunchao Zhang         PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
339d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactor->solveInfo));
340d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
341d52a580bSJunchao Zhang                                                     loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, &loTriFactor->solveBufferSize));
342d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc(&loTriFactor->solveBuffer, loTriFactor->solveBufferSize));
343d52a580bSJunchao Zhang 
344d52a580bSJunchao Zhang         /* perform the solve analysis */
345d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
346d52a580bSJunchao Zhang                                                     loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, loTriFactor->solvePolicy, loTriFactor->solveBuffer));
347d52a580bSJunchao Zhang 
348d52a580bSJunchao Zhang         PetscCallHIP(WaitForHIP());
349d52a580bSJunchao Zhang         PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
350d52a580bSJunchao Zhang 
351d52a580bSJunchao Zhang         /* assign the pointer */
352d52a580bSJunchao Zhang         ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtr = loTriFactor;
353d52a580bSJunchao Zhang         loTriFactor->AA_h                                           = AALo;
354d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AiLo));
355d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AjLo));
356d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu((n + 1 + nzLower) * sizeof(int) + nzLower * sizeof(PetscScalar)));
357d52a580bSJunchao Zhang       } else { /* update values only */
358d52a580bSJunchao Zhang         if (!loTriFactor->AA_h) PetscCallHIP(hipHostMalloc((void **)&loTriFactor->AA_h, nzLower * sizeof(PetscScalar)));
359d52a580bSJunchao Zhang         /* Fill the lower triangular matrix */
360d52a580bSJunchao Zhang         loTriFactor->AA_h[0] = 1.0;
361d52a580bSJunchao Zhang         v                    = aa;
362d52a580bSJunchao Zhang         vi                   = aj;
363d52a580bSJunchao Zhang         offset               = 1;
364d52a580bSJunchao Zhang         for (i = 1; i < n; i++) {
365d52a580bSJunchao Zhang           nz = ai[i + 1] - ai[i];
366d52a580bSJunchao Zhang           PetscCall(PetscArraycpy(&loTriFactor->AA_h[offset], v, nz));
367d52a580bSJunchao Zhang           offset += nz;
368d52a580bSJunchao Zhang           loTriFactor->AA_h[offset] = 1.0;
369d52a580bSJunchao Zhang           offset += 1;
370d52a580bSJunchao Zhang           v += nz;
371d52a580bSJunchao Zhang         }
372d52a580bSJunchao Zhang         loTriFactor->csrMat->values->assign(loTriFactor->AA_h, loTriFactor->AA_h + nzLower);
373d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu(nzLower * sizeof(PetscScalar)));
374d52a580bSJunchao Zhang       }
375d52a580bSJunchao Zhang     } catch (char *ex) {
376d52a580bSJunchao Zhang       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
377d52a580bSJunchao Zhang     }
378d52a580bSJunchao Zhang   }
379d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
380d52a580bSJunchao Zhang }
381d52a580bSJunchao Zhang 
/*
  Build (first call) or refresh (later calls) the upper triangular ILU factor of A on the GPU.

  The host factor in A is traversed through the diagonal markers adiag: for row i, the
  nz = adiag[i] - adiag[i+1] - 1 off-diagonal entries start at adiag[i+1] + 1, and the entry
  right after them (v[nz], i.e. the host entry at adiag[i]) is used for the diagonal. In the
  device CSR every row stores its diagonal first, as 1/v[nz], followed by the off-diagonal entries.

  - First call (no upTriFactor yet): builds the CSR pattern and values, creates the hipsparse
    descriptor (upper, non-unit diagonal), and runs the csrsv buffer-size query and analysis.
    The host values buffer is kept in upTriFactor->AA_h; the pattern buffers are freed.
  - Later calls: only recompute the values into AA_h and re-upload them.

  Does nothing for an empty matrix, or when A->offloadmask is neither PETSC_OFFLOAD_UNALLOCATED
  nor PETSC_OFFLOAD_CPU.
*/
static PetscErrorCode MatSeqAIJHIPSPARSEBuildILUUpperTriMatrix(Mat A)
{
  Mat_SeqAIJ                         *a                   = (Mat_SeqAIJ *)A->data;
  PetscInt                            n                   = A->rmap->n;
  Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
  Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
  const PetscInt                     *aj                  = a->j, *adiag, *vi;
  const MatScalar                    *aa                  = a->a, *v;
  PetscInt                           *AiUp, *AjUp;
  PetscInt                            i, nz, nzUpper, offset;

  PetscFunctionBegin;
  if (!n) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, &adiag, NULL));
  /* only (re)upload when the device copy is missing or older than the host copy */
  if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
    try {
      /* next, figure out the number of nonzeros in the upper triangular matrix. */
      /* (the U rows are addressed with decreasing adiag, so adiag[0] - adiag[n] is the total count) */
      nzUpper = adiag[0] - adiag[n];
      if (!upTriFactor) {
        /* first call: build pattern + values in pinned host buffers, then create the device factor */
        PetscScalar *AAUp;
        PetscCallHIP(hipHostMalloc((void **)&AAUp, nzUpper * sizeof(PetscScalar)));

        /* Allocate Space for the upper triangular matrix */
        PetscCallHIP(hipHostMalloc((void **)&AiUp, (n + 1) * sizeof(PetscInt)));
        PetscCallHIP(hipHostMalloc((void **)&AjUp, nzUpper * sizeof(PetscInt)));

        /* Fill the upper triangular matrix */
        /* rows are filled from the last one backwards, so offset walks down from nzUpper */
        AiUp[0] = (PetscInt)0;
        AiUp[n] = nzUpper;
        offset  = nzUpper;
        for (i = n - 1; i >= 0; i--) {
          v  = aa + adiag[i + 1] + 1;
          vi = aj + adiag[i + 1] + 1;
          nz = adiag[i] - adiag[i + 1] - 1; /* number of elements NOT on the diagonal */
          offset -= (nz + 1);               /* decrement the offset */

          /* first, set the diagonal elements */
          AjUp[offset] = (PetscInt)i;
          AAUp[offset] = (MatScalar)1. / v[nz];
          AiUp[i]      = AiUp[i + 1] - (nz + 1);

          /* then the off-diagonal entries of row i, right after its diagonal */
          PetscCall(PetscArraycpy(&AjUp[offset + 1], vi, nz));
          PetscCall(PetscArraycpy(&AAUp[offset + 1], v, nz));
        }

        /* allocate space for the triangular factor information */
        PetscCall(PetscNew(&upTriFactor));
        upTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;

        /* Create the matrix description */
        PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactor->descr));
        PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
        PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
        PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER));
        PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactor->descr, HIPSPARSE_DIAG_TYPE_NON_UNIT));

        /* set the operation */
        upTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;

        /* set the matrix */
        /* indices are copied into 32-bit thrust arrays (THRUSTINTARRAY32) on the device */
        upTriFactor->csrMat                 = new CsrMatrix;
        upTriFactor->csrMat->num_rows       = n;
        upTriFactor->csrMat->num_cols       = n;
        upTriFactor->csrMat->num_entries    = nzUpper;
        upTriFactor->csrMat->row_offsets    = new THRUSTINTARRAY32(n + 1);
        upTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(nzUpper);
        upTriFactor->csrMat->values         = new THRUSTARRAY(nzUpper);
        upTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + n + 1);
        upTriFactor->csrMat->column_indices->assign(AjUp, AjUp + nzUpper);
        upTriFactor->csrMat->values->assign(AAUp, AAUp + nzUpper);

        /* Create the solve analysis information */
        PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
        PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactor->solveInfo));
        PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
                                                    upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, &upTriFactor->solveBufferSize));
        PetscCallHIP(hipMalloc(&upTriFactor->solveBuffer, upTriFactor->solveBufferSize));

        /* perform the solve analysis */
        PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
                                                    upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, upTriFactor->solvePolicy, upTriFactor->solveBuffer));

        /* wait for the device before closing the logging event */
        PetscCallHIP(WaitForHIP());
        PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));

        /* assign the pointer */
        /* AAUp is kept as AA_h so later refreshes can reuse it; the pattern buffers are no longer needed */
        ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtr = upTriFactor;
        upTriFactor->AA_h                                           = AAUp;
        PetscCallHIP(hipHostFree(AiUp));
        PetscCallHIP(hipHostFree(AjUp));
        PetscCall(PetscLogCpuToGpu((n + 1 + nzUpper) * sizeof(int) + nzUpper * sizeof(PetscScalar)));
      } else {
        /* refresh: the pattern is unchanged, only recompute and re-upload the values */
        if (!upTriFactor->AA_h) PetscCallHIP(hipHostMalloc((void **)&upTriFactor->AA_h, nzUpper * sizeof(PetscScalar)));
        /* Fill the upper triangular matrix */
        offset = nzUpper;
        for (i = n - 1; i >= 0; i--) {
          v  = aa + adiag[i + 1] + 1;
          nz = adiag[i] - adiag[i + 1] - 1; /* number of elements NOT on the diagonal */
          offset -= (nz + 1);               /* decrement the offset */

          /* first, set the diagonal elements */
          upTriFactor->AA_h[offset] = 1. / v[nz];
          PetscCall(PetscArraycpy(&upTriFactor->AA_h[offset + 1], v, nz));
        }
        upTriFactor->csrMat->values->assign(upTriFactor->AA_h, upTriFactor->AA_h + nzUpper);
        PetscCall(PetscLogCpuToGpu(nzUpper * sizeof(PetscScalar)));
      }
    } catch (char *ex) {
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
495d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat A)496d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat A)
497d52a580bSJunchao Zhang {
498d52a580bSJunchao Zhang   PetscBool                      row_identity, col_identity;
499d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a                   = (Mat_SeqAIJ *)A->data;
500d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
501d52a580bSJunchao Zhang   IS                             isrow = a->row, iscol = a->icol;
502d52a580bSJunchao Zhang   PetscInt                       n = A->rmap->n;
503d52a580bSJunchao Zhang 
504d52a580bSJunchao Zhang   PetscFunctionBegin;
505d52a580bSJunchao Zhang   PetscCheck(hipsparseTriFactors, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
506d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(A));
507d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEBuildILUUpperTriMatrix(A));
508d52a580bSJunchao Zhang 
509d52a580bSJunchao Zhang   if (!hipsparseTriFactors->workVector) hipsparseTriFactors->workVector = new THRUSTARRAY(n);
510d52a580bSJunchao Zhang   hipsparseTriFactors->nnz = a->nz;
511d52a580bSJunchao Zhang 
512d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_BOTH;
513d52a580bSJunchao Zhang   /* lower triangular indices */
514d52a580bSJunchao Zhang   PetscCall(ISIdentity(isrow, &row_identity));
515d52a580bSJunchao Zhang   if (!row_identity && !hipsparseTriFactors->rpermIndices) {
516d52a580bSJunchao Zhang     const PetscInt *r;
517d52a580bSJunchao Zhang 
518d52a580bSJunchao Zhang     PetscCall(ISGetIndices(isrow, &r));
519d52a580bSJunchao Zhang     hipsparseTriFactors->rpermIndices = new THRUSTINTARRAY(n);
520d52a580bSJunchao Zhang     hipsparseTriFactors->rpermIndices->assign(r, r + n);
521d52a580bSJunchao Zhang     PetscCall(ISRestoreIndices(isrow, &r));
522d52a580bSJunchao Zhang     PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
523d52a580bSJunchao Zhang   }
524d52a580bSJunchao Zhang   /* upper triangular indices */
525d52a580bSJunchao Zhang   PetscCall(ISIdentity(iscol, &col_identity));
526d52a580bSJunchao Zhang   if (!col_identity && !hipsparseTriFactors->cpermIndices) {
527d52a580bSJunchao Zhang     const PetscInt *c;
528d52a580bSJunchao Zhang 
529d52a580bSJunchao Zhang     PetscCall(ISGetIndices(iscol, &c));
530d52a580bSJunchao Zhang     hipsparseTriFactors->cpermIndices = new THRUSTINTARRAY(n);
531d52a580bSJunchao Zhang     hipsparseTriFactors->cpermIndices->assign(c, c + n);
532d52a580bSJunchao Zhang     PetscCall(ISRestoreIndices(iscol, &c));
533d52a580bSJunchao Zhang     PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
534d52a580bSJunchao Zhang   }
535d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
536d52a580bSJunchao Zhang }
537d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSEBuildICCTriMatrices(Mat A)538d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildICCTriMatrices(Mat A)
539d52a580bSJunchao Zhang {
540d52a580bSJunchao Zhang   Mat_SeqAIJ                         *a                   = (Mat_SeqAIJ *)A->data;
541d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
542d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
543d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
544d52a580bSJunchao Zhang   PetscInt                           *AiUp, *AjUp;
545d52a580bSJunchao Zhang   PetscScalar                        *AAUp;
546d52a580bSJunchao Zhang   PetscScalar                        *AALo;
547d52a580bSJunchao Zhang   PetscInt                            nzUpper = a->nz, n = A->rmap->n, i, offset, nz, j;
548d52a580bSJunchao Zhang   Mat_SeqSBAIJ                       *b  = (Mat_SeqSBAIJ *)A->data;
549d52a580bSJunchao Zhang   const PetscInt                     *ai = b->i, *aj = b->j, *vj;
550d52a580bSJunchao Zhang   const MatScalar                    *aa = b->a, *v;
551d52a580bSJunchao Zhang 
552d52a580bSJunchao Zhang   PetscFunctionBegin;
553d52a580bSJunchao Zhang   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
554d52a580bSJunchao Zhang   if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
555d52a580bSJunchao Zhang     try {
556d52a580bSJunchao Zhang       PetscCallHIP(hipHostMalloc((void **)&AAUp, nzUpper * sizeof(PetscScalar)));
557d52a580bSJunchao Zhang       PetscCallHIP(hipHostMalloc((void **)&AALo, nzUpper * sizeof(PetscScalar)));
558d52a580bSJunchao Zhang       if (!upTriFactor && !loTriFactor) {
559d52a580bSJunchao Zhang         /* Allocate Space for the upper triangular matrix */
560d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AiUp, (n + 1) * sizeof(PetscInt)));
561d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AjUp, nzUpper * sizeof(PetscInt)));
562d52a580bSJunchao Zhang 
563d52a580bSJunchao Zhang         /* Fill the upper triangular matrix */
564d52a580bSJunchao Zhang         AiUp[0] = (PetscInt)0;
565d52a580bSJunchao Zhang         AiUp[n] = nzUpper;
566d52a580bSJunchao Zhang         offset  = 0;
567d52a580bSJunchao Zhang         for (i = 0; i < n; i++) {
568d52a580bSJunchao Zhang           /* set the pointers */
569d52a580bSJunchao Zhang           v  = aa + ai[i];
570d52a580bSJunchao Zhang           vj = aj + ai[i];
571d52a580bSJunchao Zhang           nz = ai[i + 1] - ai[i] - 1; /* exclude diag[i] */
572d52a580bSJunchao Zhang 
573d52a580bSJunchao Zhang           /* first, set the diagonal elements */
574d52a580bSJunchao Zhang           AjUp[offset] = (PetscInt)i;
575d52a580bSJunchao Zhang           AAUp[offset] = (MatScalar)1.0 / v[nz];
576d52a580bSJunchao Zhang           AiUp[i]      = offset;
577d52a580bSJunchao Zhang           AALo[offset] = (MatScalar)1.0 / v[nz];
578d52a580bSJunchao Zhang 
579d52a580bSJunchao Zhang           offset += 1;
580d52a580bSJunchao Zhang           if (nz > 0) {
581d52a580bSJunchao Zhang             PetscCall(PetscArraycpy(&AjUp[offset], vj, nz));
582d52a580bSJunchao Zhang             PetscCall(PetscArraycpy(&AAUp[offset], v, nz));
583d52a580bSJunchao Zhang             for (j = offset; j < offset + nz; j++) {
584d52a580bSJunchao Zhang               AAUp[j] = -AAUp[j];
585d52a580bSJunchao Zhang               AALo[j] = AAUp[j] / v[nz];
586d52a580bSJunchao Zhang             }
587d52a580bSJunchao Zhang             offset += nz;
588d52a580bSJunchao Zhang           }
589d52a580bSJunchao Zhang         }
590d52a580bSJunchao Zhang 
591d52a580bSJunchao Zhang         /* allocate space for the triangular factor information */
592d52a580bSJunchao Zhang         PetscCall(PetscNew(&upTriFactor));
593d52a580bSJunchao Zhang         upTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
594d52a580bSJunchao Zhang 
595d52a580bSJunchao Zhang         /* Create the matrix description */
596d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactor->descr));
597d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
598d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
599d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER));
600d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactor->descr, HIPSPARSE_DIAG_TYPE_UNIT));
601d52a580bSJunchao Zhang 
602d52a580bSJunchao Zhang         /* set the matrix */
603d52a580bSJunchao Zhang         upTriFactor->csrMat                 = new CsrMatrix;
604d52a580bSJunchao Zhang         upTriFactor->csrMat->num_rows       = A->rmap->n;
605d52a580bSJunchao Zhang         upTriFactor->csrMat->num_cols       = A->cmap->n;
606d52a580bSJunchao Zhang         upTriFactor->csrMat->num_entries    = a->nz;
607d52a580bSJunchao Zhang         upTriFactor->csrMat->row_offsets    = new THRUSTINTARRAY32(A->rmap->n + 1);
608d52a580bSJunchao Zhang         upTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(a->nz);
609d52a580bSJunchao Zhang         upTriFactor->csrMat->values         = new THRUSTARRAY(a->nz);
610d52a580bSJunchao Zhang         upTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + A->rmap->n + 1);
611d52a580bSJunchao Zhang         upTriFactor->csrMat->column_indices->assign(AjUp, AjUp + a->nz);
612d52a580bSJunchao Zhang         upTriFactor->csrMat->values->assign(AAUp, AAUp + a->nz);
613d52a580bSJunchao Zhang 
614d52a580bSJunchao Zhang         /* set the operation */
615d52a580bSJunchao Zhang         upTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
616d52a580bSJunchao Zhang 
617d52a580bSJunchao Zhang         /* Create the solve analysis information */
618d52a580bSJunchao Zhang         PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
619d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactor->solveInfo));
620d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
621d52a580bSJunchao Zhang                                                     upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, &upTriFactor->solveBufferSize));
622d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc(&upTriFactor->solveBuffer, upTriFactor->solveBufferSize));
623d52a580bSJunchao Zhang 
624d52a580bSJunchao Zhang         /* perform the solve analysis */
625d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
626d52a580bSJunchao Zhang                                                     upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, upTriFactor->solvePolicy, upTriFactor->solveBuffer));
627d52a580bSJunchao Zhang 
628d52a580bSJunchao Zhang         PetscCallHIP(WaitForHIP());
629d52a580bSJunchao Zhang         PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
630d52a580bSJunchao Zhang 
631d52a580bSJunchao Zhang         /* assign the pointer */
632d52a580bSJunchao Zhang         ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtr = upTriFactor;
633d52a580bSJunchao Zhang 
634d52a580bSJunchao Zhang         /* allocate space for the triangular factor information */
635d52a580bSJunchao Zhang         PetscCall(PetscNew(&loTriFactor));
636d52a580bSJunchao Zhang         loTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
637d52a580bSJunchao Zhang 
638d52a580bSJunchao Zhang         /* Create the matrix description */
639d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactor->descr));
640d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
641d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
642d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER));
643d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactor->descr, HIPSPARSE_DIAG_TYPE_NON_UNIT));
644d52a580bSJunchao Zhang 
645d52a580bSJunchao Zhang         /* set the operation */
646d52a580bSJunchao Zhang         loTriFactor->solveOp = HIPSPARSE_OPERATION_TRANSPOSE;
647d52a580bSJunchao Zhang 
648d52a580bSJunchao Zhang         /* set the matrix */
649d52a580bSJunchao Zhang         loTriFactor->csrMat                 = new CsrMatrix;
650d52a580bSJunchao Zhang         loTriFactor->csrMat->num_rows       = A->rmap->n;
651d52a580bSJunchao Zhang         loTriFactor->csrMat->num_cols       = A->cmap->n;
652d52a580bSJunchao Zhang         loTriFactor->csrMat->num_entries    = a->nz;
653d52a580bSJunchao Zhang         loTriFactor->csrMat->row_offsets    = new THRUSTINTARRAY32(A->rmap->n + 1);
654d52a580bSJunchao Zhang         loTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(a->nz);
655d52a580bSJunchao Zhang         loTriFactor->csrMat->values         = new THRUSTARRAY(a->nz);
656d52a580bSJunchao Zhang         loTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + A->rmap->n + 1);
657d52a580bSJunchao Zhang         loTriFactor->csrMat->column_indices->assign(AjUp, AjUp + a->nz);
658d52a580bSJunchao Zhang         loTriFactor->csrMat->values->assign(AALo, AALo + a->nz);
659d52a580bSJunchao Zhang 
660d52a580bSJunchao Zhang         /* Create the solve analysis information */
661d52a580bSJunchao Zhang         PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
662d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactor->solveInfo));
663d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
664d52a580bSJunchao Zhang                                                     loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, &loTriFactor->solveBufferSize));
665d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc(&loTriFactor->solveBuffer, loTriFactor->solveBufferSize));
666d52a580bSJunchao Zhang 
667d52a580bSJunchao Zhang         /* perform the solve analysis */
668d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
669d52a580bSJunchao Zhang                                                     loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, loTriFactor->solvePolicy, loTriFactor->solveBuffer));
670d52a580bSJunchao Zhang 
671d52a580bSJunchao Zhang         PetscCallHIP(WaitForHIP());
672d52a580bSJunchao Zhang         PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
673d52a580bSJunchao Zhang 
674d52a580bSJunchao Zhang         /* assign the pointer */
675d52a580bSJunchao Zhang         ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtr = loTriFactor;
676d52a580bSJunchao Zhang 
677d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu(2 * (((A->rmap->n + 1) + (a->nz)) * sizeof(int) + (a->nz) * sizeof(PetscScalar))));
678d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AiUp));
679d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AjUp));
680d52a580bSJunchao Zhang       } else {
681d52a580bSJunchao Zhang         /* Fill the upper triangular matrix */
682d52a580bSJunchao Zhang         offset = 0;
683d52a580bSJunchao Zhang         for (i = 0; i < n; i++) {
684d52a580bSJunchao Zhang           /* set the pointers */
685d52a580bSJunchao Zhang           v  = aa + ai[i];
686d52a580bSJunchao Zhang           nz = ai[i + 1] - ai[i] - 1; /* exclude diag[i] */
687d52a580bSJunchao Zhang 
688d52a580bSJunchao Zhang           /* first, set the diagonal elements */
689d52a580bSJunchao Zhang           AAUp[offset] = 1.0 / v[nz];
690d52a580bSJunchao Zhang           AALo[offset] = 1.0 / v[nz];
691d52a580bSJunchao Zhang 
692d52a580bSJunchao Zhang           offset += 1;
693d52a580bSJunchao Zhang           if (nz > 0) {
694d52a580bSJunchao Zhang             PetscCall(PetscArraycpy(&AAUp[offset], v, nz));
695d52a580bSJunchao Zhang             for (j = offset; j < offset + nz; j++) {
696d52a580bSJunchao Zhang               AAUp[j] = -AAUp[j];
697d52a580bSJunchao Zhang               AALo[j] = AAUp[j] / v[nz];
698d52a580bSJunchao Zhang             }
699d52a580bSJunchao Zhang             offset += nz;
700d52a580bSJunchao Zhang           }
701d52a580bSJunchao Zhang         }
702d52a580bSJunchao Zhang         PetscCheck(upTriFactor, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
703d52a580bSJunchao Zhang         PetscCheck(loTriFactor, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
704d52a580bSJunchao Zhang         upTriFactor->csrMat->values->assign(AAUp, AAUp + a->nz);
705d52a580bSJunchao Zhang         loTriFactor->csrMat->values->assign(AALo, AALo + a->nz);
706d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu(2 * (a->nz) * sizeof(PetscScalar)));
707d52a580bSJunchao Zhang       }
708d52a580bSJunchao Zhang       PetscCallHIP(hipHostFree(AAUp));
709d52a580bSJunchao Zhang       PetscCallHIP(hipHostFree(AALo));
710d52a580bSJunchao Zhang     } catch (char *ex) {
711d52a580bSJunchao Zhang       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
712d52a580bSJunchao Zhang     }
713d52a580bSJunchao Zhang   }
714d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
715d52a580bSJunchao Zhang }
716d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(Mat A)717d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(Mat A)
718d52a580bSJunchao Zhang {
719d52a580bSJunchao Zhang   PetscBool                      perm_identity;
720d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a                   = (Mat_SeqAIJ *)A->data;
721d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
722d52a580bSJunchao Zhang   IS                             ip                  = a->row;
723d52a580bSJunchao Zhang   PetscInt                       n                   = A->rmap->n;
724d52a580bSJunchao Zhang 
725d52a580bSJunchao Zhang   PetscFunctionBegin;
726d52a580bSJunchao Zhang   PetscCheck(hipsparseTriFactors, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
727d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEBuildICCTriMatrices(A));
728d52a580bSJunchao Zhang   if (!hipsparseTriFactors->workVector) hipsparseTriFactors->workVector = new THRUSTARRAY(n);
729d52a580bSJunchao Zhang   hipsparseTriFactors->nnz = (a->nz - n) * 2 + n;
730d52a580bSJunchao Zhang 
731d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_BOTH;
732d52a580bSJunchao Zhang   /* lower triangular indices */
733d52a580bSJunchao Zhang   PetscCall(ISIdentity(ip, &perm_identity));
734d52a580bSJunchao Zhang   if (!perm_identity) {
735d52a580bSJunchao Zhang     IS              iip;
736d52a580bSJunchao Zhang     const PetscInt *irip, *rip;
737d52a580bSJunchao Zhang 
738d52a580bSJunchao Zhang     PetscCall(ISInvertPermutation(ip, PETSC_DECIDE, &iip));
739d52a580bSJunchao Zhang     PetscCall(ISGetIndices(iip, &irip));
740d52a580bSJunchao Zhang     PetscCall(ISGetIndices(ip, &rip));
741d52a580bSJunchao Zhang     hipsparseTriFactors->rpermIndices = new THRUSTINTARRAY(n);
742d52a580bSJunchao Zhang     hipsparseTriFactors->cpermIndices = new THRUSTINTARRAY(n);
743d52a580bSJunchao Zhang     hipsparseTriFactors->rpermIndices->assign(rip, rip + n);
744d52a580bSJunchao Zhang     hipsparseTriFactors->cpermIndices->assign(irip, irip + n);
745d52a580bSJunchao Zhang     PetscCall(ISRestoreIndices(iip, &irip));
746d52a580bSJunchao Zhang     PetscCall(ISDestroy(&iip));
747d52a580bSJunchao Zhang     PetscCall(ISRestoreIndices(ip, &rip));
748d52a580bSJunchao Zhang     PetscCall(PetscLogCpuToGpu(2. * n * sizeof(PetscInt)));
749d52a580bSJunchao Zhang   }
750d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
751d52a580bSJunchao Zhang }
752d52a580bSJunchao Zhang 
MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat B,Mat A,const MatFactorInfo * info)753d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat B, Mat A, const MatFactorInfo *info)
754d52a580bSJunchao Zhang {
755d52a580bSJunchao Zhang   PetscBool   perm_identity;
756d52a580bSJunchao Zhang   Mat_SeqAIJ *b  = (Mat_SeqAIJ *)B->data;
757d52a580bSJunchao Zhang   IS          ip = b->row;
758d52a580bSJunchao Zhang 
759d52a580bSJunchao Zhang   PetscFunctionBegin;
760d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
761d52a580bSJunchao Zhang   PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info));
762d52a580bSJunchao Zhang   B->offloadmask = PETSC_OFFLOAD_CPU;
763d52a580bSJunchao Zhang   /* determine which version of MatSolve needs to be used. */
764d52a580bSJunchao Zhang   PetscCall(ISIdentity(ip, &perm_identity));
765d52a580bSJunchao Zhang   if (perm_identity) {
766d52a580bSJunchao Zhang     B->ops->solve             = MatSolve_SeqAIJHIPSPARSE_NaturalOrdering;
767d52a580bSJunchao Zhang     B->ops->solvetranspose    = MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering;
768d52a580bSJunchao Zhang     B->ops->matsolve          = NULL;
769d52a580bSJunchao Zhang     B->ops->matsolvetranspose = NULL;
770d52a580bSJunchao Zhang   } else {
771d52a580bSJunchao Zhang     B->ops->solve             = MatSolve_SeqAIJHIPSPARSE;
772d52a580bSJunchao Zhang     B->ops->solvetranspose    = MatSolveTranspose_SeqAIJHIPSPARSE;
773d52a580bSJunchao Zhang     B->ops->matsolve          = NULL;
774d52a580bSJunchao Zhang     B->ops->matsolvetranspose = NULL;
775d52a580bSJunchao Zhang   }
776d52a580bSJunchao Zhang 
777d52a580bSJunchao Zhang   /* get the triangular factors */
778d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(B));
779d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
780d52a580bSJunchao Zhang }
781d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(Mat A)782d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(Mat A)
783d52a580bSJunchao Zhang {
784d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
785d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
786d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
787d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT;
788d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT;
789d52a580bSJunchao Zhang   hipsparseIndexBase_t                indexBase;
790d52a580bSJunchao Zhang   hipsparseMatrixType_t               matrixType;
791d52a580bSJunchao Zhang   hipsparseFillMode_t                 fillMode;
792d52a580bSJunchao Zhang   hipsparseDiagType_t                 diagType;
793d52a580bSJunchao Zhang 
794d52a580bSJunchao Zhang   PetscFunctionBegin;
795d52a580bSJunchao Zhang   /* allocate space for the transpose of the lower triangular factor */
796d52a580bSJunchao Zhang   PetscCall(PetscNew(&loTriFactorT));
797d52a580bSJunchao Zhang   loTriFactorT->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
798d52a580bSJunchao Zhang 
799d52a580bSJunchao Zhang   /* set the matrix descriptors of the lower triangular factor */
800d52a580bSJunchao Zhang   matrixType = hipsparseGetMatType(loTriFactor->descr);
801d52a580bSJunchao Zhang   indexBase  = hipsparseGetMatIndexBase(loTriFactor->descr);
802d52a580bSJunchao Zhang   fillMode   = hipsparseGetMatFillMode(loTriFactor->descr) == HIPSPARSE_FILL_MODE_UPPER ? HIPSPARSE_FILL_MODE_LOWER : HIPSPARSE_FILL_MODE_UPPER;
803d52a580bSJunchao Zhang   diagType   = hipsparseGetMatDiagType(loTriFactor->descr);
804d52a580bSJunchao Zhang 
805d52a580bSJunchao Zhang   /* Create the matrix description */
806d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactorT->descr));
807d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactorT->descr, indexBase));
808d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactorT->descr, matrixType));
809d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactorT->descr, fillMode));
810d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactorT->descr, diagType));
811d52a580bSJunchao Zhang 
812d52a580bSJunchao Zhang   /* set the operation */
813d52a580bSJunchao Zhang   loTriFactorT->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
814d52a580bSJunchao Zhang 
815d52a580bSJunchao Zhang   /* allocate GPU space for the CSC of the lower triangular factor*/
816d52a580bSJunchao Zhang   loTriFactorT->csrMat                 = new CsrMatrix;
817d52a580bSJunchao Zhang   loTriFactorT->csrMat->num_rows       = loTriFactor->csrMat->num_cols;
818d52a580bSJunchao Zhang   loTriFactorT->csrMat->num_cols       = loTriFactor->csrMat->num_rows;
819d52a580bSJunchao Zhang   loTriFactorT->csrMat->num_entries    = loTriFactor->csrMat->num_entries;
820d52a580bSJunchao Zhang   loTriFactorT->csrMat->row_offsets    = new THRUSTINTARRAY32(loTriFactorT->csrMat->num_rows + 1);
821d52a580bSJunchao Zhang   loTriFactorT->csrMat->column_indices = new THRUSTINTARRAY32(loTriFactorT->csrMat->num_entries);
822d52a580bSJunchao Zhang   loTriFactorT->csrMat->values         = new THRUSTARRAY(loTriFactorT->csrMat->num_entries);
823d52a580bSJunchao Zhang 
824d52a580bSJunchao Zhang   /* compute the transpose of the lower triangular factor, i.e. the CSC */
825d52a580bSJunchao Zhang   /* Csr2cscEx2 is not implemented in ROCm-5.2.0 and is planned for implementation in hipsparse with future releases of ROCm
826d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0)
827d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCsr2cscEx2_bufferSize(hipsparseTriFactors->handle, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_cols, loTriFactor->csrMat->num_entries, loTriFactor->csrMat->values->data().get(),
828d52a580bSJunchao Zhang                                                   loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactorT->csrMat->values->data().get(), loTriFactorT->csrMat->row_offsets->data().get(),
829d52a580bSJunchao Zhang                                                   loTriFactorT->csrMat->column_indices->data().get(), hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, &loTriFactor->csr2cscBufferSize));
830d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc(&loTriFactor->csr2cscBuffer, loTriFactor->csr2cscBufferSize));
831d52a580bSJunchao Zhang #endif
832d52a580bSJunchao Zhang */
833d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
834d52a580bSJunchao Zhang 
835d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparseTriFactors->handle, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_cols, loTriFactor->csrMat->num_entries, loTriFactor->csrMat->values->data().get(), loTriFactor->csrMat->row_offsets->data().get(),
836d52a580bSJunchao Zhang                                        loTriFactor->csrMat->column_indices->data().get(), loTriFactorT->csrMat->values->data().get(),
837d52a580bSJunchao Zhang #if 0 /* when Csr2cscEx2 is implemented in hipSparse PETSC_PKG_HIP_VERSION_GE(5, 2, 0)*/
838d52a580bSJunchao Zhang                           loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(),
839d52a580bSJunchao Zhang                           hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, loTriFactor->csr2cscBuffer));
840d52a580bSJunchao Zhang #else
841d52a580bSJunchao Zhang                                        loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->csrMat->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
842d52a580bSJunchao Zhang #endif
843d52a580bSJunchao Zhang 
844d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
845d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
846d52a580bSJunchao Zhang 
847d52a580bSJunchao Zhang   /* Create the solve analysis information */
848d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
849d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactorT->solveInfo));
850d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
851d52a580bSJunchao Zhang                                               loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, &loTriFactorT->solveBufferSize));
852d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc(&loTriFactorT->solveBuffer, loTriFactorT->solveBufferSize));
853d52a580bSJunchao Zhang 
854d52a580bSJunchao Zhang   /* perform the solve analysis */
855d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
856d52a580bSJunchao Zhang                                               loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer));
857d52a580bSJunchao Zhang 
858d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
859d52a580bSJunchao Zhang   PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
860d52a580bSJunchao Zhang 
861d52a580bSJunchao Zhang   /* assign the pointer */
862d52a580bSJunchao Zhang   ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtrTranspose = loTriFactorT;
863d52a580bSJunchao Zhang 
864d52a580bSJunchao Zhang   /*********************************************/
865d52a580bSJunchao Zhang   /* Now the Transpose of the Upper Tri Factor */
866d52a580bSJunchao Zhang   /*********************************************/
867d52a580bSJunchao Zhang 
868d52a580bSJunchao Zhang   /* allocate space for the transpose of the upper triangular factor */
869d52a580bSJunchao Zhang   PetscCall(PetscNew(&upTriFactorT));
870d52a580bSJunchao Zhang   upTriFactorT->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
871d52a580bSJunchao Zhang 
872d52a580bSJunchao Zhang   /* set the matrix descriptors of the upper triangular factor */
873d52a580bSJunchao Zhang   matrixType = hipsparseGetMatType(upTriFactor->descr);
874d52a580bSJunchao Zhang   indexBase  = hipsparseGetMatIndexBase(upTriFactor->descr);
875d52a580bSJunchao Zhang   fillMode   = hipsparseGetMatFillMode(upTriFactor->descr) == HIPSPARSE_FILL_MODE_UPPER ? HIPSPARSE_FILL_MODE_LOWER : HIPSPARSE_FILL_MODE_UPPER;
876d52a580bSJunchao Zhang   diagType   = hipsparseGetMatDiagType(upTriFactor->descr);
877d52a580bSJunchao Zhang 
878d52a580bSJunchao Zhang   /* Create the matrix description */
879d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactorT->descr));
880d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactorT->descr, indexBase));
881d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactorT->descr, matrixType));
882d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactorT->descr, fillMode));
883d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactorT->descr, diagType));
884d52a580bSJunchao Zhang 
885d52a580bSJunchao Zhang   /* set the operation */
886d52a580bSJunchao Zhang   upTriFactorT->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
887d52a580bSJunchao Zhang 
888d52a580bSJunchao Zhang   /* allocate GPU space for the CSC of the upper triangular factor*/
889d52a580bSJunchao Zhang   upTriFactorT->csrMat                 = new CsrMatrix;
890d52a580bSJunchao Zhang   upTriFactorT->csrMat->num_rows       = upTriFactor->csrMat->num_cols;
891d52a580bSJunchao Zhang   upTriFactorT->csrMat->num_cols       = upTriFactor->csrMat->num_rows;
892d52a580bSJunchao Zhang   upTriFactorT->csrMat->num_entries    = upTriFactor->csrMat->num_entries;
893d52a580bSJunchao Zhang   upTriFactorT->csrMat->row_offsets    = new THRUSTINTARRAY32(upTriFactorT->csrMat->num_rows + 1);
894d52a580bSJunchao Zhang   upTriFactorT->csrMat->column_indices = new THRUSTINTARRAY32(upTriFactorT->csrMat->num_entries);
895d52a580bSJunchao Zhang   upTriFactorT->csrMat->values         = new THRUSTARRAY(upTriFactorT->csrMat->num_entries);
896d52a580bSJunchao Zhang 
897d52a580bSJunchao Zhang   /* compute the transpose of the upper triangular factor, i.e. the CSC */
898d52a580bSJunchao Zhang   /* Csr2cscEx2 is not implemented in ROCm-5.2.0 and is planned for implementation in hipsparse with future releases of ROCm
899d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0)
900d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCsr2cscEx2_bufferSize(hipsparseTriFactors->handle, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_cols, upTriFactor->csrMat->num_entries, upTriFactor->csrMat->values->data().get(),
901d52a580bSJunchao Zhang                                                   upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactorT->csrMat->values->data().get(), upTriFactorT->csrMat->row_offsets->data().get(),
902d52a580bSJunchao Zhang                                                   upTriFactorT->csrMat->column_indices->data().get(), hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, &upTriFactor->csr2cscBufferSize));
903d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc(&upTriFactor->csr2cscBuffer, upTriFactor->csr2cscBufferSize));
904d52a580bSJunchao Zhang #endif
905d52a580bSJunchao Zhang */
906d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
907d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparseTriFactors->handle, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_cols, upTriFactor->csrMat->num_entries, upTriFactor->csrMat->values->data().get(), upTriFactor->csrMat->row_offsets->data().get(),
908d52a580bSJunchao Zhang                                        upTriFactor->csrMat->column_indices->data().get(), upTriFactorT->csrMat->values->data().get(),
909d52a580bSJunchao Zhang #if 0 /* when Csr2cscEx2 is implemented in hipSparse PETSC_PKG_HIP_VERSION_GE(5, 2, 0)*/
910d52a580bSJunchao Zhang                           upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(),
911d52a580bSJunchao Zhang                           hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, upTriFactor->csr2cscBuffer));
912d52a580bSJunchao Zhang #else
913d52a580bSJunchao Zhang                                        upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->csrMat->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
914d52a580bSJunchao Zhang #endif
915d52a580bSJunchao Zhang 
916d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
917d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
918d52a580bSJunchao Zhang 
919d52a580bSJunchao Zhang   /* Create the solve analysis information */
920d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
921d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactorT->solveInfo));
922d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
923d52a580bSJunchao Zhang                                               upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, &upTriFactorT->solveBufferSize));
924d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc(&upTriFactorT->solveBuffer, upTriFactorT->solveBufferSize));
925d52a580bSJunchao Zhang 
926d52a580bSJunchao Zhang   /* perform the solve analysis */
927d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
928d52a580bSJunchao Zhang                                               upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, upTriFactorT->solvePolicy, upTriFactorT->solveBuffer));
929d52a580bSJunchao Zhang 
930d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
931d52a580bSJunchao Zhang   PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
932d52a580bSJunchao Zhang 
933d52a580bSJunchao Zhang   /* assign the pointer */
934d52a580bSJunchao Zhang   ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtrTranspose = upTriFactorT;
935d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
936d52a580bSJunchao Zhang }
937d52a580bSJunchao Zhang 
938d52a580bSJunchao Zhang struct PetscScalarToPetscInt {
operator ()PetscScalarToPetscInt939d52a580bSJunchao Zhang   __host__ __device__ PetscInt operator()(PetscScalar s) { return (PetscInt)PetscRealPart(s); }
940d52a580bSJunchao Zhang };
941d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSEFormExplicitTranspose(Mat A)942d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEFormExplicitTranspose(Mat A)
943d52a580bSJunchao Zhang {
944d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
945d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *matstruct, *matstructT;
946d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a = (Mat_SeqAIJ *)A->data;
947d52a580bSJunchao Zhang   hipsparseIndexBase_t           indexBase;
948d52a580bSJunchao Zhang 
949d52a580bSJunchao Zhang   PetscFunctionBegin;
950d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
951d52a580bSJunchao Zhang   matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat;
952d52a580bSJunchao Zhang   PetscCheck(matstruct, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing mat struct");
953d52a580bSJunchao Zhang   matstructT = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->matTranspose;
954d52a580bSJunchao Zhang   PetscCheck(!A->transupdated || matstructT, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing matTranspose struct");
955d52a580bSJunchao Zhang   if (A->transupdated) PetscFunctionReturn(PETSC_SUCCESS);
956d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
957d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
958d52a580bSJunchao Zhang   if (hipsparsestruct->format != MAT_HIPSPARSE_CSR) PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE));
959d52a580bSJunchao Zhang   if (!hipsparsestruct->matTranspose) { /* create hipsparse matrix */
960d52a580bSJunchao Zhang     matstructT = new Mat_SeqAIJHIPSPARSEMultStruct;
961d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseCreateMatDescr(&matstructT->descr));
962d52a580bSJunchao Zhang     indexBase = hipsparseGetMatIndexBase(matstruct->descr);
963d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetMatIndexBase(matstructT->descr, indexBase));
964d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetMatType(matstructT->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
965d52a580bSJunchao Zhang 
966d52a580bSJunchao Zhang     /* set alpha and beta */
967d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&matstructT->alpha_one, sizeof(PetscScalar)));
968d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&matstructT->beta_zero, sizeof(PetscScalar)));
969d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&matstructT->beta_one, sizeof(PetscScalar)));
970d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(matstructT->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
971d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(matstructT->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
972d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(matstructT->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
973d52a580bSJunchao Zhang 
974d52a580bSJunchao Zhang     if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
975d52a580bSJunchao Zhang       CsrMatrix *matrixT      = new CsrMatrix;
976d52a580bSJunchao Zhang       matstructT->mat         = matrixT;
977d52a580bSJunchao Zhang       matrixT->num_rows       = A->cmap->n;
978d52a580bSJunchao Zhang       matrixT->num_cols       = A->rmap->n;
979d52a580bSJunchao Zhang       matrixT->num_entries    = a->nz;
980d52a580bSJunchao Zhang       matrixT->row_offsets    = new THRUSTINTARRAY32(matrixT->num_rows + 1);
981d52a580bSJunchao Zhang       matrixT->column_indices = new THRUSTINTARRAY32(a->nz);
982d52a580bSJunchao Zhang       matrixT->values         = new THRUSTARRAY(a->nz);
983d52a580bSJunchao Zhang 
984d52a580bSJunchao Zhang       if (!hipsparsestruct->rowoffsets_gpu) hipsparsestruct->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
985d52a580bSJunchao Zhang       hipsparsestruct->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
986d52a580bSJunchao Zhang 
987d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateCsr(&matstructT->matDescr, matrixT->num_rows, matrixT->num_cols, matrixT->num_entries, matrixT->row_offsets->data().get(), matrixT->column_indices->data().get(), matrixT->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx type due to THRUSTINTARRAY32 */
988d52a580bSJunchao Zhang                                             indexBase, hipsparse_scalartype));
989d52a580bSJunchao Zhang     } else if (hipsparsestruct->format == MAT_HIPSPARSE_ELL || hipsparsestruct->format == MAT_HIPSPARSE_HYB) {
990d52a580bSJunchao Zhang       CsrMatrix *temp  = new CsrMatrix;
991d52a580bSJunchao Zhang       CsrMatrix *tempT = new CsrMatrix;
992d52a580bSJunchao Zhang       /* First convert HYB to CSR */
993d52a580bSJunchao Zhang       temp->num_rows       = A->rmap->n;
994d52a580bSJunchao Zhang       temp->num_cols       = A->cmap->n;
995d52a580bSJunchao Zhang       temp->num_entries    = a->nz;
996d52a580bSJunchao Zhang       temp->row_offsets    = new THRUSTINTARRAY32(A->rmap->n + 1);
997d52a580bSJunchao Zhang       temp->column_indices = new THRUSTINTARRAY32(a->nz);
998d52a580bSJunchao Zhang       temp->values         = new THRUSTARRAY(a->nz);
999d52a580bSJunchao Zhang 
1000d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparse_hyb2csr(hipsparsestruct->handle, matstruct->descr, (hipsparseHybMat_t)matstruct->mat, temp->values->data().get(), temp->row_offsets->data().get(), temp->column_indices->data().get()));
1001d52a580bSJunchao Zhang 
1002d52a580bSJunchao Zhang       /* Next, convert CSR to CSC (i.e. the matrix transpose) */
1003d52a580bSJunchao Zhang       tempT->num_rows       = A->rmap->n;
1004d52a580bSJunchao Zhang       tempT->num_cols       = A->cmap->n;
1005d52a580bSJunchao Zhang       tempT->num_entries    = a->nz;
1006d52a580bSJunchao Zhang       tempT->row_offsets    = new THRUSTINTARRAY32(A->rmap->n + 1);
1007d52a580bSJunchao Zhang       tempT->column_indices = new THRUSTINTARRAY32(a->nz);
1008d52a580bSJunchao Zhang       tempT->values         = new THRUSTARRAY(a->nz);
1009d52a580bSJunchao Zhang 
1010d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparsestruct->handle, temp->num_rows, temp->num_cols, temp->num_entries, temp->values->data().get(), temp->row_offsets->data().get(), temp->column_indices->data().get(), tempT->values->data().get(),
1011d52a580bSJunchao Zhang                                            tempT->column_indices->data().get(), tempT->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
1012d52a580bSJunchao Zhang 
1013d52a580bSJunchao Zhang       /* Last, convert CSC to HYB */
1014d52a580bSJunchao Zhang       hipsparseHybMat_t hybMat;
1015d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateHybMat(&hybMat));
1016d52a580bSJunchao Zhang       hipsparseHybPartition_t partition = hipsparsestruct->format == MAT_HIPSPARSE_ELL ? HIPSPARSE_HYB_PARTITION_MAX : HIPSPARSE_HYB_PARTITION_AUTO;
1017d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparse_csr2hyb(hipsparsestruct->handle, A->rmap->n, A->cmap->n, matstructT->descr, tempT->values->data().get(), tempT->row_offsets->data().get(), tempT->column_indices->data().get(), hybMat, 0, partition));
1018d52a580bSJunchao Zhang 
1019d52a580bSJunchao Zhang       /* assign the pointer */
1020d52a580bSJunchao Zhang       matstructT->mat = hybMat;
1021d52a580bSJunchao Zhang       A->transupdated = PETSC_TRUE;
1022d52a580bSJunchao Zhang       /* delete temporaries */
1023d52a580bSJunchao Zhang       if (tempT) {
1024d52a580bSJunchao Zhang         if (tempT->values) delete (THRUSTARRAY *)tempT->values;
1025d52a580bSJunchao Zhang         if (tempT->column_indices) delete (THRUSTINTARRAY32 *)tempT->column_indices;
1026d52a580bSJunchao Zhang         if (tempT->row_offsets) delete (THRUSTINTARRAY32 *)tempT->row_offsets;
1027d52a580bSJunchao Zhang         delete (CsrMatrix *)tempT;
1028d52a580bSJunchao Zhang       }
1029d52a580bSJunchao Zhang       if (temp) {
1030d52a580bSJunchao Zhang         if (temp->values) delete (THRUSTARRAY *)temp->values;
1031d52a580bSJunchao Zhang         if (temp->column_indices) delete (THRUSTINTARRAY32 *)temp->column_indices;
1032d52a580bSJunchao Zhang         if (temp->row_offsets) delete (THRUSTINTARRAY32 *)temp->row_offsets;
1033d52a580bSJunchao Zhang         delete (CsrMatrix *)temp;
1034d52a580bSJunchao Zhang       }
1035d52a580bSJunchao Zhang     }
1036d52a580bSJunchao Zhang   }
1037d52a580bSJunchao Zhang   if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) { /* transpose mat struct may be already present, update data */
1038d52a580bSJunchao Zhang     CsrMatrix *matrix  = (CsrMatrix *)matstruct->mat;
1039d52a580bSJunchao Zhang     CsrMatrix *matrixT = (CsrMatrix *)matstructT->mat;
1040d52a580bSJunchao Zhang     PetscCheck(matrix, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix");
1041d52a580bSJunchao Zhang     PetscCheck(matrix->row_offsets, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix rows");
1042d52a580bSJunchao Zhang     PetscCheck(matrix->column_indices, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix cols");
1043d52a580bSJunchao Zhang     PetscCheck(matrix->values, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix values");
1044d52a580bSJunchao Zhang     PetscCheck(matrixT, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT");
1045d52a580bSJunchao Zhang     PetscCheck(matrixT->row_offsets, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT rows");
1046d52a580bSJunchao Zhang     PetscCheck(matrixT->column_indices, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT cols");
1047d52a580bSJunchao Zhang     PetscCheck(matrixT->values, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT values");
1048d52a580bSJunchao Zhang     if (!hipsparsestruct->rowoffsets_gpu) { /* this may be absent when we did not construct the transpose with csr2csc */
1049d52a580bSJunchao Zhang       hipsparsestruct->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
1050d52a580bSJunchao Zhang       hipsparsestruct->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
1051d52a580bSJunchao Zhang       PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
1052d52a580bSJunchao Zhang     }
1053d52a580bSJunchao Zhang     if (!hipsparsestruct->csr2csc_i) {
1054d52a580bSJunchao Zhang       THRUSTARRAY csr2csc_a(matrix->num_entries);
1055d52a580bSJunchao Zhang       PetscCallThrust(thrust::sequence(thrust::device, csr2csc_a.begin(), csr2csc_a.end(), 0.0));
1056d52a580bSJunchao Zhang 
1057d52a580bSJunchao Zhang       indexBase = hipsparseGetMatIndexBase(matstruct->descr);
1058d52a580bSJunchao Zhang       if (matrix->num_entries) {
1059d52a580bSJunchao Zhang         /* This routine is known to give errors with CUDA-11, but works fine with CUDA-10
1060d52a580bSJunchao Zhang            Need to verify this for ROCm.
1061d52a580bSJunchao Zhang         */
1062d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparsestruct->handle, A->rmap->n, A->cmap->n, matrix->num_entries, csr2csc_a.data().get(), hipsparsestruct->rowoffsets_gpu->data().get(), matrix->column_indices->data().get(), matrixT->values->data().get(),
1063d52a580bSJunchao Zhang                                              matrixT->column_indices->data().get(), matrixT->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
1064d52a580bSJunchao Zhang       } else {
1065d52a580bSJunchao Zhang         matrixT->row_offsets->assign(matrixT->row_offsets->size(), indexBase);
1066d52a580bSJunchao Zhang       }
1067d52a580bSJunchao Zhang 
1068d52a580bSJunchao Zhang       hipsparsestruct->csr2csc_i = new THRUSTINTARRAY(matrix->num_entries);
1069d52a580bSJunchao Zhang       PetscCallThrust(thrust::transform(thrust::device, matrixT->values->begin(), matrixT->values->end(), hipsparsestruct->csr2csc_i->begin(), PetscScalarToPetscInt()));
1070d52a580bSJunchao Zhang     }
1071d52a580bSJunchao Zhang     PetscCallThrust(
1072d52a580bSJunchao Zhang       thrust::copy(thrust::device, thrust::make_permutation_iterator(matrix->values->begin(), hipsparsestruct->csr2csc_i->begin()), thrust::make_permutation_iterator(matrix->values->begin(), hipsparsestruct->csr2csc_i->end()), matrixT->values->begin()));
1073d52a580bSJunchao Zhang   }
1074d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1075d52a580bSJunchao Zhang   PetscCall(PetscLogEventEnd(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
1076d52a580bSJunchao Zhang   /* the compressed row indices is not used for matTranspose */
1077d52a580bSJunchao Zhang   matstructT->cprowIndices = NULL;
1078d52a580bSJunchao Zhang   /* assign the pointer */
1079d52a580bSJunchao Zhang   ((Mat_SeqAIJHIPSPARSE *)A->spptr)->matTranspose = matstructT;
1080d52a580bSJunchao Zhang   A->transupdated                                 = PETSC_TRUE;
1081d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1082d52a580bSJunchao Zhang }
1083d52a580bSJunchao Zhang 
1084d52a580bSJunchao Zhang /* Why do we need to analyze the transposed matrix again? Can't we just use op(A) = HIPSPARSE_OPERATION_TRANSPOSE in MatSolve_SeqAIJHIPSPARSE? */
MatSolveTranspose_SeqAIJHIPSPARSE(Mat A,Vec bb,Vec xx)1085d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE(Mat A, Vec bb, Vec xx)
1086d52a580bSJunchao Zhang {
1087d52a580bSJunchao Zhang   PetscInt                              n = xx->map->n;
1088d52a580bSJunchao Zhang   const PetscScalar                    *barray;
1089d52a580bSJunchao Zhang   PetscScalar                          *xarray;
1090d52a580bSJunchao Zhang   thrust::device_ptr<const PetscScalar> bGPU;
1091d52a580bSJunchao Zhang   thrust::device_ptr<PetscScalar>       xGPU;
1092d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors        *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1093d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct   *loTriFactorT        = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1094d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct   *upTriFactorT        = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1095d52a580bSJunchao Zhang   THRUSTARRAY                          *tempGPU             = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1096d52a580bSJunchao Zhang 
1097d52a580bSJunchao Zhang   PetscFunctionBegin;
1098d52a580bSJunchao Zhang   /* Analyze the matrix and create the transpose ... on the fly */
1099d52a580bSJunchao Zhang   if (!loTriFactorT && !upTriFactorT) {
1100d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(A));
1101d52a580bSJunchao Zhang     loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1102d52a580bSJunchao Zhang     upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1103d52a580bSJunchao Zhang   }
1104d52a580bSJunchao Zhang 
1105d52a580bSJunchao Zhang   /* Get the GPU pointers */
1106d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1107d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(bb, &barray));
1108d52a580bSJunchao Zhang   xGPU = thrust::device_pointer_cast(xarray);
1109d52a580bSJunchao Zhang   bGPU = thrust::device_pointer_cast(barray);
1110d52a580bSJunchao Zhang 
1111d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1112d52a580bSJunchao Zhang   /* First, reorder with the row permutation */
1113d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->begin()), thrust::make_permutation_iterator(bGPU + n, hipsparseTriFactors->rpermIndices->end()), xGPU);
1114d52a580bSJunchao Zhang 
1115d52a580bSJunchao Zhang   /* First, solve U */
1116d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
1117d52a580bSJunchao Zhang                                            upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, xarray, tempGPU->data().get(), upTriFactorT->solvePolicy, upTriFactorT->solveBuffer));
1118d52a580bSJunchao Zhang 
1119d52a580bSJunchao Zhang   /* Then, solve L */
1120d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
1121d52a580bSJunchao Zhang                                            loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, tempGPU->data().get(), xarray, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer));
1122d52a580bSJunchao Zhang 
1123d52a580bSJunchao Zhang   /* Last, copy the solution, xGPU, into a temporary with the column permutation ... can't be done in place. */
1124d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(xGPU, hipsparseTriFactors->cpermIndices->begin()), thrust::make_permutation_iterator(xGPU + n, hipsparseTriFactors->cpermIndices->end()), tempGPU->begin());
1125d52a580bSJunchao Zhang 
1126d52a580bSJunchao Zhang   /* Copy the temporary to the full solution. */
1127d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), tempGPU->begin(), tempGPU->end(), xGPU);
1128d52a580bSJunchao Zhang 
1129d52a580bSJunchao Zhang   /* restore */
1130d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1131d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1132d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1133d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1134d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1135d52a580bSJunchao Zhang }
1136d52a580bSJunchao Zhang 
MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat A,Vec bb,Vec xx)1137d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat A, Vec bb, Vec xx)
1138d52a580bSJunchao Zhang {
1139d52a580bSJunchao Zhang   const PetscScalar                  *barray;
1140d52a580bSJunchao Zhang   PetscScalar                        *xarray;
1141d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1142d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT        = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1143d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT        = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1144d52a580bSJunchao Zhang   THRUSTARRAY                        *tempGPU             = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1145d52a580bSJunchao Zhang 
1146d52a580bSJunchao Zhang   PetscFunctionBegin;
1147d52a580bSJunchao Zhang   /* Analyze the matrix and create the transpose ... on the fly */
1148d52a580bSJunchao Zhang   if (!loTriFactorT && !upTriFactorT) {
1149d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(A));
1150d52a580bSJunchao Zhang     loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1151d52a580bSJunchao Zhang     upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1152d52a580bSJunchao Zhang   }
1153d52a580bSJunchao Zhang 
1154d52a580bSJunchao Zhang   /* Get the GPU pointers */
1155d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1156d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(bb, &barray));
1157d52a580bSJunchao Zhang 
1158d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1159d52a580bSJunchao Zhang   /* First, solve U */
1160d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
1161d52a580bSJunchao Zhang                                            upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, barray, tempGPU->data().get(), upTriFactorT->solvePolicy, upTriFactorT->solveBuffer));
1162d52a580bSJunchao Zhang 
1163d52a580bSJunchao Zhang   /* Then, solve L */
1164d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
1165d52a580bSJunchao Zhang                                            loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, tempGPU->data().get(), xarray, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer));
1166d52a580bSJunchao Zhang 
1167d52a580bSJunchao Zhang   /* restore */
1168d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1169d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1170d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1171d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1172d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1173d52a580bSJunchao Zhang }
1174d52a580bSJunchao Zhang 
MatSolve_SeqAIJHIPSPARSE(Mat A,Vec bb,Vec xx)1175d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE(Mat A, Vec bb, Vec xx)
1176d52a580bSJunchao Zhang {
1177d52a580bSJunchao Zhang   const PetscScalar                    *barray;
1178d52a580bSJunchao Zhang   PetscScalar                          *xarray;
1179d52a580bSJunchao Zhang   thrust::device_ptr<const PetscScalar> bGPU;
1180d52a580bSJunchao Zhang   thrust::device_ptr<PetscScalar>       xGPU;
1181d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors        *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1182d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct   *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
1183d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct   *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
1184d52a580bSJunchao Zhang   THRUSTARRAY                          *tempGPU             = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1185d52a580bSJunchao Zhang 
1186d52a580bSJunchao Zhang   PetscFunctionBegin;
1187d52a580bSJunchao Zhang   /* Get the GPU pointers */
1188d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1189d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(bb, &barray));
1190d52a580bSJunchao Zhang   xGPU = thrust::device_pointer_cast(xarray);
1191d52a580bSJunchao Zhang   bGPU = thrust::device_pointer_cast(barray);
1192d52a580bSJunchao Zhang 
1193d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1194d52a580bSJunchao Zhang   /* First, reorder with the row permutation */
1195d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->begin()), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->end()), tempGPU->begin());
1196d52a580bSJunchao Zhang 
1197d52a580bSJunchao Zhang   /* Next, solve L */
1198d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
1199d52a580bSJunchao Zhang                                            loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, tempGPU->data().get(), xarray, loTriFactor->solvePolicy, loTriFactor->solveBuffer));
1200d52a580bSJunchao Zhang 
1201d52a580bSJunchao Zhang   /* Then, solve U */
1202d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
1203d52a580bSJunchao Zhang                                            upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, xarray, tempGPU->data().get(), upTriFactor->solvePolicy, upTriFactor->solveBuffer));
1204d52a580bSJunchao Zhang 
1205d52a580bSJunchao Zhang   /* Last, reorder with the column permutation */
1206d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(tempGPU->begin(), hipsparseTriFactors->cpermIndices->begin()), thrust::make_permutation_iterator(tempGPU->begin(), hipsparseTriFactors->cpermIndices->end()), xGPU);
1207d52a580bSJunchao Zhang 
1208d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1209d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1210d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1211d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1212d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1213d52a580bSJunchao Zhang }
1214d52a580bSJunchao Zhang 
MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat A,Vec bb,Vec xx)1215d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat A, Vec bb, Vec xx)
1216d52a580bSJunchao Zhang {
1217d52a580bSJunchao Zhang   const PetscScalar                  *barray;
1218d52a580bSJunchao Zhang   PetscScalar                        *xarray;
1219d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1220d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
1221d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
1222d52a580bSJunchao Zhang   THRUSTARRAY                        *tempGPU             = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1223d52a580bSJunchao Zhang 
1224d52a580bSJunchao Zhang   PetscFunctionBegin;
1225d52a580bSJunchao Zhang   /* Get the GPU pointers */
1226d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1227d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(bb, &barray));
1228d52a580bSJunchao Zhang 
1229d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1230d52a580bSJunchao Zhang   /* First, solve L */
1231d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
1232d52a580bSJunchao Zhang                                            loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, barray, tempGPU->data().get(), loTriFactor->solvePolicy, loTriFactor->solveBuffer));
1233d52a580bSJunchao Zhang 
1234d52a580bSJunchao Zhang   /* Next, solve U */
1235d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
1236d52a580bSJunchao Zhang                                            upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, tempGPU->data().get(), xarray, upTriFactor->solvePolicy, upTriFactor->solveBuffer));
1237d52a580bSJunchao Zhang 
1238d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1239d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1240d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1241d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1242d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1243d52a580bSJunchao Zhang }
1244d52a580bSJunchao Zhang 
1245d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1246d52a580bSJunchao Zhang /* hipsparseSpSV_solve() and related functions first appeared in ROCm-4.5.0*/
MatSolve_SeqAIJHIPSPARSE_ILU0(Mat fact,Vec b,Vec x)1247d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x)
1248d52a580bSJunchao Zhang {
1249d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1250d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1251d52a580bSJunchao Zhang   const PetscScalar             *barray;
1252d52a580bSJunchao Zhang   PetscScalar                   *xarray;
1253d52a580bSJunchao Zhang 
1254d52a580bSJunchao Zhang   PetscFunctionBegin;
1255d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(x, &xarray));
1256d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(b, &barray));
1257d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1258d52a580bSJunchao Zhang 
1259d52a580bSJunchao Zhang   /* Solve L*y = b */
1260d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
1261d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
1262d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1263d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L,                   /* L Y = X */
1264d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()!
1265d52a580bSJunchao Zhang   #else
1266d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L,                                     /* L Y = X */
1267d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()!
1268d52a580bSJunchao Zhang   #endif
1269d52a580bSJunchao Zhang   /* Solve U*x = y */
1270d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
1271d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1272d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */
1273d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U));
1274d52a580bSJunchao Zhang   #else
1275d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */
1276d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U));
1277d52a580bSJunchao Zhang   #endif
1278d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(b, &barray));
1279d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
1280d52a580bSJunchao Zhang 
1281d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1282d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n));
1283d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1284d52a580bSJunchao Zhang }
1285d52a580bSJunchao Zhang 
MatSolveTranspose_SeqAIJHIPSPARSE_ILU0(Mat fact,Vec b,Vec x)1286d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x)
1287d52a580bSJunchao Zhang {
1288d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1289d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1290d52a580bSJunchao Zhang   const PetscScalar             *barray;
1291d52a580bSJunchao Zhang   PetscScalar                   *xarray;
1292d52a580bSJunchao Zhang 
1293d52a580bSJunchao Zhang   PetscFunctionBegin;
1294d52a580bSJunchao Zhang   if (!fs->createdTransposeSpSVDescr) { /* Call MatSolveTranspose() for the first time */
1295d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Lt));
1296d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* The matrix is still L. We only do transpose solve with it */
1297d52a580bSJunchao Zhang                                                 fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, &fs->spsvBufferSize_Lt));
1298d52a580bSJunchao Zhang 
1299d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Ut));
1300d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, &fs->spsvBufferSize_Ut));
1301d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Lt, fs->spsvBufferSize_Lt));
1302d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Ut, fs->spsvBufferSize_Ut));
1303d52a580bSJunchao Zhang     fs->createdTransposeSpSVDescr = PETSC_TRUE;
1304d52a580bSJunchao Zhang   }
1305d52a580bSJunchao Zhang 
1306d52a580bSJunchao Zhang   if (!fs->updatedTransposeSpSVAnalysis) {
1307d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1308d52a580bSJunchao Zhang 
1309d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut));
1310d52a580bSJunchao Zhang     fs->updatedTransposeSpSVAnalysis = PETSC_TRUE;
1311d52a580bSJunchao Zhang   }
1312d52a580bSJunchao Zhang 
1313d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(x, &xarray));
1314d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(b, &barray));
1315d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1316d52a580bSJunchao Zhang 
1317d52a580bSJunchao Zhang   /* Solve Ut*y = b */
1318d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
1319d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
1320d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1321d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
1322d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut));
1323d52a580bSJunchao Zhang   #else
1324d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
1325d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut));
1326d52a580bSJunchao Zhang   #endif
1327d52a580bSJunchao Zhang   /* Solve Lt*x = y */
1328d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
1329d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1330d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1331d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
1332d52a580bSJunchao Zhang   #else
1333d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1334d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1335d52a580bSJunchao Zhang   #endif
1336d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(b, &barray));
1337d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
1338d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1339d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n));
1340d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1341d52a580bSJunchao Zhang }
1342d52a580bSJunchao Zhang 
MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0(Mat fact,Mat A,const MatFactorInfo * info)1343d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0(Mat fact, Mat A, const MatFactorInfo *info)
1344d52a580bSJunchao Zhang {
1345d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs    = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1346d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij   = (Mat_SeqAIJ *)fact->data;
1347d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
1348d52a580bSJunchao Zhang   CsrMatrix                     *Acsr;
1349d52a580bSJunchao Zhang   PetscInt                       m, nz;
1350d52a580bSJunchao Zhang   PetscBool                      flg;
1351d52a580bSJunchao Zhang 
1352d52a580bSJunchao Zhang   PetscFunctionBegin;
1353d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1354d52a580bSJunchao Zhang     PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1355d52a580bSJunchao Zhang     PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1356d52a580bSJunchao Zhang   }
1357d52a580bSJunchao Zhang 
1358d52a580bSJunchao Zhang   /* Copy A's value to fact */
1359d52a580bSJunchao Zhang   m  = fact->rmap->n;
1360d52a580bSJunchao Zhang   nz = aij->nz;
1361d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
1362d52a580bSJunchao Zhang   Acsr = (CsrMatrix *)Acusp->mat->mat;
1363d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrVal, Acsr->values->data().get(), sizeof(PetscScalar) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1364d52a580bSJunchao Zhang 
1365d52a580bSJunchao Zhang   /* Factorize fact inplace */
1366d52a580bSJunchao Zhang   if (m)
1367d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseXcsrilu02(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */
1368d52a580bSJunchao Zhang                                           fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, fs->policy_M, fs->factBuffer_M));
1369d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1370d52a580bSJunchao Zhang     int               numerical_zero;
1371d52a580bSJunchao Zhang     hipsparseStatus_t status;
1372d52a580bSJunchao Zhang     status = hipsparseXcsrilu02_zeroPivot(fs->handle, fs->ilu0Info_M, &numerical_zero);
1373d52a580bSJunchao Zhang     PetscAssert(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Numerical zero pivot detected in csrilu02: A(%d,%d) is zero", numerical_zero, numerical_zero);
1374d52a580bSJunchao Zhang   }
1375d52a580bSJunchao Zhang 
1376d52a580bSJunchao Zhang   /* hipsparseSpSV_analysis() is numeric, i.e., it requires valid matrix values, therefore, we do it after hipsparseXcsrilu02() */
1377d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
1378d52a580bSJunchao Zhang 
1379d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U));
1380d52a580bSJunchao Zhang 
1381d52a580bSJunchao Zhang   /* L, U values have changed, reset the flag to indicate we need to redo hipsparseSpSV_analysis() for transpose solve */
1382d52a580bSJunchao Zhang   fs->updatedTransposeSpSVAnalysis = PETSC_FALSE;
1383d52a580bSJunchao Zhang 
1384d52a580bSJunchao Zhang   fact->offloadmask            = PETSC_OFFLOAD_GPU;
1385d52a580bSJunchao Zhang   fact->ops->solve             = MatSolve_SeqAIJHIPSPARSE_ILU0;
1386d52a580bSJunchao Zhang   fact->ops->solvetranspose    = MatSolveTranspose_SeqAIJHIPSPARSE_ILU0;
1387d52a580bSJunchao Zhang   fact->ops->matsolve          = NULL;
1388d52a580bSJunchao Zhang   fact->ops->matsolvetranspose = NULL;
1389d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(fs->numericFactFlops));
1390d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1391d52a580bSJunchao Zhang }
1392d52a580bSJunchao Zhang 
MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(Mat fact,Mat A,IS isrow,IS iscol,const MatFactorInfo * info)1393d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(Mat fact, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1394d52a580bSJunchao Zhang {
1395d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1396d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1397d52a580bSJunchao Zhang   PetscInt                       m, nz;
1398d52a580bSJunchao Zhang 
1399d52a580bSJunchao Zhang   PetscFunctionBegin;
1400d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1401d52a580bSJunchao Zhang     PetscBool flg, diagDense;
1402d52a580bSJunchao Zhang 
1403d52a580bSJunchao Zhang     PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1404d52a580bSJunchao Zhang     PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1405d52a580bSJunchao Zhang     PetscCheck(A->rmap->n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must be square matrix, rows %" PetscInt_FMT " columns %" PetscInt_FMT, A->rmap->n, A->cmap->n);
1406d52a580bSJunchao Zhang     PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, NULL, &diagDense));
1407d52a580bSJunchao Zhang     PetscCheck(diagDense, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix is missing diagonal entries");
1408d52a580bSJunchao Zhang   }
1409d52a580bSJunchao Zhang 
1410d52a580bSJunchao Zhang   /* Free the old stale stuff */
1411d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&fs));
1412d52a580bSJunchao Zhang 
1413d52a580bSJunchao Zhang   /* Copy over A's meta data to fact. Note that we also allocated fact's i,j,a on host,
1414d52a580bSJunchao Zhang      but they will not be used. Allocate them just for easy debugging.
1415d52a580bSJunchao Zhang    */
1416d52a580bSJunchao Zhang   PetscCall(MatDuplicateNoCreate_SeqAIJ(fact, A, MAT_DO_NOT_COPY_VALUES, PETSC_TRUE /*malloc*/));
1417d52a580bSJunchao Zhang 
1418d52a580bSJunchao Zhang   fact->offloadmask            = PETSC_OFFLOAD_BOTH;
1419d52a580bSJunchao Zhang   fact->factortype             = MAT_FACTOR_ILU;
1420d52a580bSJunchao Zhang   fact->info.factor_mallocs    = 0;
1421d52a580bSJunchao Zhang   fact->info.fill_ratio_given  = info->fill;
1422d52a580bSJunchao Zhang   fact->info.fill_ratio_needed = 1.0;
1423d52a580bSJunchao Zhang 
1424d52a580bSJunchao Zhang   aij->row = NULL;
1425d52a580bSJunchao Zhang   aij->col = NULL;
1426d52a580bSJunchao Zhang 
1427d52a580bSJunchao Zhang   /* ====================================================================== */
1428d52a580bSJunchao Zhang   /* Copy A's i, j to fact and also allocate the value array of fact.       */
1429d52a580bSJunchao Zhang   /* We'll do in-place factorization on fact                                */
1430d52a580bSJunchao Zhang   /* ====================================================================== */
1431d52a580bSJunchao Zhang   const int *Ai, *Aj;
1432d52a580bSJunchao Zhang 
1433d52a580bSJunchao Zhang   m  = fact->rmap->n;
1434d52a580bSJunchao Zhang   nz = aij->nz;
1435d52a580bSJunchao Zhang 
1436d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrRowPtr, sizeof(int) * (m + 1)));
1437d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrColIdx, sizeof(int) * nz));
1438d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrVal, sizeof(PetscScalar) * nz));
1439d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEGetIJ(A, PETSC_FALSE, &Ai, &Aj)); /* Do not use compressed Ai */
1440d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrRowPtr, Ai, sizeof(int) * (m + 1), hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1441d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrColIdx, Aj, sizeof(int) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1442d52a580bSJunchao Zhang 
1443d52a580bSJunchao Zhang   /* ====================================================================== */
1444d52a580bSJunchao Zhang   /* Create descriptors for M, L, U                                         */
1445d52a580bSJunchao Zhang   /* ====================================================================== */
1446d52a580bSJunchao Zhang   hipsparseFillMode_t fillMode;
1447d52a580bSJunchao Zhang   hipsparseDiagType_t diagType;
1448d52a580bSJunchao Zhang 
1449d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&fs->matDescr_M));
1450d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(fs->matDescr_M, HIPSPARSE_INDEX_BASE_ZERO));
1451d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(fs->matDescr_M, HIPSPARSE_MATRIX_TYPE_GENERAL));
1452d52a580bSJunchao Zhang 
1453d52a580bSJunchao Zhang   /* https://docs.amd.com/bundle/hipSPARSE-Documentation---hipSPARSE-documentation/page/usermanual.html/#hipsparse_8h_1a79e036b6c0680cb37e2aa53d3542a054
1454d52a580bSJunchao Zhang     hipsparseDiagType_t: This type indicates if the matrix diagonal entries are unity. The diagonal elements are always
1455d52a580bSJunchao Zhang     assumed to be present, but if HIPSPARSE_DIAG_TYPE_UNIT is passed to an API routine, then the routine assumes that
1456d52a580bSJunchao Zhang     all diagonal entries are unity and will not read or modify those entries. Note that in this case the routine
1457d52a580bSJunchao Zhang     assumes the diagonal entries are equal to one, regardless of what those entries are actually set to in memory.
1458d52a580bSJunchao Zhang   */
1459d52a580bSJunchao Zhang   fillMode = HIPSPARSE_FILL_MODE_LOWER;
1460d52a580bSJunchao Zhang   diagType = HIPSPARSE_DIAG_TYPE_UNIT;
1461d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_L, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
1462d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode)));
1463d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType)));
1464d52a580bSJunchao Zhang 
1465d52a580bSJunchao Zhang   fillMode = HIPSPARSE_FILL_MODE_UPPER;
1466d52a580bSJunchao Zhang   diagType = HIPSPARSE_DIAG_TYPE_NON_UNIT;
1467d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_U, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
1468d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_U, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode)));
1469d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_U, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType)));
1470d52a580bSJunchao Zhang 
1471d52a580bSJunchao Zhang   /* ========================================================================= */
1472d52a580bSJunchao Zhang   /* Query buffer sizes for csrilu0, SpSV and allocate buffers                 */
1473d52a580bSJunchao Zhang   /* ========================================================================= */
1474d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsrilu02Info(&fs->ilu0Info_M));
1475d52a580bSJunchao Zhang   if (m)
1476d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseXcsrilu02_bufferSize(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */
1477d52a580bSJunchao Zhang                                                      fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, &fs->factBufferSize_M));
1478d52a580bSJunchao Zhang 
1479d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->X, sizeof(PetscScalar) * m));
1480d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->Y, sizeof(PetscScalar) * m));
1481d52a580bSJunchao Zhang 
1482d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_X, m, fs->X, hipsparse_scalartype));
1483d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_Y, m, fs->Y, hipsparse_scalartype));
1484d52a580bSJunchao Zhang 
1485d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_L));
1486d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, &fs->spsvBufferSize_L));
1487d52a580bSJunchao Zhang 
1488d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_U));
1489d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, &fs->spsvBufferSize_U));
1490d52a580bSJunchao Zhang 
1491d52a580bSJunchao Zhang   /* It appears spsvBuffer_L/U can not be shared (i.e., the same) for our case, but factBuffer_M can share with either of spsvBuffer_L/U.
1492d52a580bSJunchao Zhang      To save memory, we make factBuffer_M share with the bigger of spsvBuffer_L/U.
1493d52a580bSJunchao Zhang    */
1494d52a580bSJunchao Zhang   if (fs->spsvBufferSize_L > fs->spsvBufferSize_U) {
1495d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_L, (size_t)fs->factBufferSize_M)));
1496d52a580bSJunchao Zhang     fs->spsvBuffer_L = fs->factBuffer_M;
1497d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_U, fs->spsvBufferSize_U));
1498d52a580bSJunchao Zhang   } else {
1499d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_U, (size_t)fs->factBufferSize_M)));
1500d52a580bSJunchao Zhang     fs->spsvBuffer_U = fs->factBuffer_M;
1501d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_L, fs->spsvBufferSize_L));
1502d52a580bSJunchao Zhang   }
1503d52a580bSJunchao Zhang 
1504d52a580bSJunchao Zhang   /* ========================================================================== */
1505d52a580bSJunchao Zhang   /* Perform analysis of ilu0 on M, SpSv on L and U                             */
1506d52a580bSJunchao Zhang   /* The lower(upper) triangular part of M has the same sparsity pattern as L(U)*/
1507d52a580bSJunchao Zhang   /* ========================================================================== */
1508d52a580bSJunchao Zhang   int structural_zero;
1509d52a580bSJunchao Zhang 
1510d52a580bSJunchao Zhang   fs->policy_M = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
1511d52a580bSJunchao Zhang   if (m)
1512d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseXcsrilu02_analysis(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */
1513d52a580bSJunchao Zhang                                                    fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, fs->policy_M, fs->factBuffer_M));
1514d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1515d52a580bSJunchao Zhang     /* Function hipsparseXcsrilu02_zeroPivot() is a blocking call. It calls hipDeviceSynchronize() to make sure all previous kernels are done. */
1516d52a580bSJunchao Zhang     hipsparseStatus_t status;
1517d52a580bSJunchao Zhang     status = hipsparseXcsrilu02_zeroPivot(fs->handle, fs->ilu0Info_M, &structural_zero);
1518d52a580bSJunchao Zhang     PetscCheck(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Structural zero pivot detected in csrilu02: A(%d,%d) is missing", structural_zero, structural_zero);
1519d52a580bSJunchao Zhang   }
1520d52a580bSJunchao Zhang 
1521d52a580bSJunchao Zhang   /* Estimate FLOPs of the numeric factorization */
1522d52a580bSJunchao Zhang   {
1523d52a580bSJunchao Zhang     Mat_SeqAIJ     *Aseq = (Mat_SeqAIJ *)A->data;
1524d52a580bSJunchao Zhang     PetscInt       *Ai, nzRow, nzLeft;
1525d52a580bSJunchao Zhang     PetscLogDouble  flops = 0.0;
1526d52a580bSJunchao Zhang     const PetscInt *Adiag;
1527d52a580bSJunchao Zhang 
1528d52a580bSJunchao Zhang     PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, &Adiag, NULL));
1529d52a580bSJunchao Zhang     Ai = Aseq->i;
1530d52a580bSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
1531d52a580bSJunchao Zhang       if (Ai[i] < Adiag[i] && Adiag[i] < Ai[i + 1]) { /* There are nonzeros left to the diagonal of row i */
1532d52a580bSJunchao Zhang         nzRow  = Ai[i + 1] - Ai[i];
1533d52a580bSJunchao Zhang         nzLeft = Adiag[i] - Ai[i];
1534d52a580bSJunchao Zhang         /* We want to eliminate nonzeros left to the diagonal one by one. Assume each time, nonzeros right
1535d52a580bSJunchao Zhang           and include the eliminated one will be updated, which incurs a multiplication and an addition.
1536d52a580bSJunchao Zhang         */
1537d52a580bSJunchao Zhang         nzLeft = (nzRow - 1) / 2;
1538d52a580bSJunchao Zhang         flops += nzLeft * (2.0 * nzRow - nzLeft + 1);
1539d52a580bSJunchao Zhang       }
1540d52a580bSJunchao Zhang     }
1541d52a580bSJunchao Zhang     fs->numericFactFlops = flops;
1542d52a580bSJunchao Zhang   }
1543d52a580bSJunchao Zhang   fact->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0;
1544d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1545d52a580bSJunchao Zhang }
1546d52a580bSJunchao Zhang 
MatSolve_SeqAIJHIPSPARSE_ICC0(Mat fact,Vec b,Vec x)1547d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ICC0(Mat fact, Vec b, Vec x)
1548d52a580bSJunchao Zhang {
1549d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1550d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1551d52a580bSJunchao Zhang   const PetscScalar             *barray;
1552d52a580bSJunchao Zhang   PetscScalar                   *xarray;
1553d52a580bSJunchao Zhang 
1554d52a580bSJunchao Zhang   PetscFunctionBegin;
1555d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(x, &xarray));
1556d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(b, &barray));
1557d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1558d52a580bSJunchao Zhang 
1559d52a580bSJunchao Zhang   /* Solve L*y = b */
1560d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
1561d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
1562d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1563d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
1564d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L));
1565d52a580bSJunchao Zhang   #else
1566d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
1567d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
1568d52a580bSJunchao Zhang   #endif
1569d52a580bSJunchao Zhang   /* Solve Lt*x = y */
1570d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
1571d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1572d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1573d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
1574d52a580bSJunchao Zhang   #else
1575d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1576d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1577d52a580bSJunchao Zhang   #endif
1578d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(b, &barray));
1579d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
1580d52a580bSJunchao Zhang 
1581d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1582d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n));
1583d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1584d52a580bSJunchao Zhang }
1585d52a580bSJunchao Zhang 
MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0(Mat fact,Mat A,const MatFactorInfo * info)1586d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0(Mat fact, Mat A, const MatFactorInfo *info)
1587d52a580bSJunchao Zhang {
1588d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs    = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1589d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij   = (Mat_SeqAIJ *)fact->data;
1590d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
1591d52a580bSJunchao Zhang   CsrMatrix                     *Acsr;
1592d52a580bSJunchao Zhang   PetscInt                       m, nz;
1593d52a580bSJunchao Zhang   PetscBool                      flg;
1594d52a580bSJunchao Zhang 
1595d52a580bSJunchao Zhang   PetscFunctionBegin;
1596d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1597d52a580bSJunchao Zhang     PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1598d52a580bSJunchao Zhang     PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1599d52a580bSJunchao Zhang   }
1600d52a580bSJunchao Zhang 
1601d52a580bSJunchao Zhang   /* Copy A's value to fact */
1602d52a580bSJunchao Zhang   m  = fact->rmap->n;
1603d52a580bSJunchao Zhang   nz = aij->nz;
1604d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
1605d52a580bSJunchao Zhang   Acsr = (CsrMatrix *)Acusp->mat->mat;
1606d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrVal, Acsr->values->data().get(), sizeof(PetscScalar) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1607d52a580bSJunchao Zhang 
1608d52a580bSJunchao Zhang   /* Factorize fact inplace */
1609d52a580bSJunchao Zhang   /* Function csric02() only takes the lower triangular part of matrix A to perform factorization.
1610d52a580bSJunchao Zhang      The matrix type must be HIPSPARSE_MATRIX_TYPE_GENERAL, the fill mode and diagonal type are ignored,
1611d52a580bSJunchao Zhang      and the strictly upper triangular part is ignored and never touched. It does not matter if A is Hermitian or not.
1612d52a580bSJunchao Zhang      In other words, from the point of view of csric02() A is Hermitian and only the lower triangular part is provided.
1613d52a580bSJunchao Zhang    */
1614d52a580bSJunchao Zhang   if (m) PetscCallHIPSPARSE(hipsparseXcsric02(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, fs->policy_M, fs->factBuffer_M));
1615d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1616d52a580bSJunchao Zhang     int               numerical_zero;
1617d52a580bSJunchao Zhang     hipsparseStatus_t status;
1618d52a580bSJunchao Zhang     status = hipsparseXcsric02_zeroPivot(fs->handle, fs->ic0Info_M, &numerical_zero);
1619d52a580bSJunchao Zhang     PetscAssert(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Numerical zero pivot detected in csric02: A(%d,%d) is zero", numerical_zero, numerical_zero);
1620d52a580bSJunchao Zhang   }
1621d52a580bSJunchao Zhang 
1622d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
1623d52a580bSJunchao Zhang 
1624d52a580bSJunchao Zhang   /* Note that hipsparse reports this error if we use double and HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
1625d52a580bSJunchao Zhang     ** On entry to hipsparseSpSV_analysis(): conjugate transpose (opA) is not supported for matA data type, current -> CUDA_R_64F
1626d52a580bSJunchao Zhang   */
1627d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1628d52a580bSJunchao Zhang 
1629d52a580bSJunchao Zhang   fact->offloadmask            = PETSC_OFFLOAD_GPU;
1630d52a580bSJunchao Zhang   fact->ops->solve             = MatSolve_SeqAIJHIPSPARSE_ICC0;
1631d52a580bSJunchao Zhang   fact->ops->solvetranspose    = MatSolve_SeqAIJHIPSPARSE_ICC0;
1632d52a580bSJunchao Zhang   fact->ops->matsolve          = NULL;
1633d52a580bSJunchao Zhang   fact->ops->matsolvetranspose = NULL;
1634d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(fs->numericFactFlops));
1635d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1636d52a580bSJunchao Zhang }
1637d52a580bSJunchao Zhang 
MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(Mat fact,Mat A,IS perm,const MatFactorInfo * info)1638d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(Mat fact, Mat A, IS perm, const MatFactorInfo *info)
1639d52a580bSJunchao Zhang {
1640d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1641d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1642d52a580bSJunchao Zhang   PetscInt                       m, nz;
1643d52a580bSJunchao Zhang 
1644d52a580bSJunchao Zhang   PetscFunctionBegin;
1645d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1646d52a580bSJunchao Zhang     PetscBool flg, diagDense;
1647d52a580bSJunchao Zhang 
1648d52a580bSJunchao Zhang     PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1649d52a580bSJunchao Zhang     PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1650d52a580bSJunchao Zhang     PetscCheck(A->rmap->n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must be square matrix, rows %" PetscInt_FMT " columns %" PetscInt_FMT, A->rmap->n, A->cmap->n);
1651d52a580bSJunchao Zhang     PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, NULL, &diagDense));
1652d52a580bSJunchao Zhang     PetscCheck(diagDense, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix is missing diagonal entries");
1653d52a580bSJunchao Zhang   }
1654d52a580bSJunchao Zhang 
1655d52a580bSJunchao Zhang   /* Free the old stale stuff */
1656d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&fs));
1657d52a580bSJunchao Zhang 
1658d52a580bSJunchao Zhang   /* Copy over A's meta data to fact. Note that we also allocated fact's i,j,a on host,
1659d52a580bSJunchao Zhang      but they will not be used. Allocate them just for easy debugging.
1660d52a580bSJunchao Zhang    */
1661d52a580bSJunchao Zhang   PetscCall(MatDuplicateNoCreate_SeqAIJ(fact, A, MAT_DO_NOT_COPY_VALUES, PETSC_TRUE /*malloc*/));
1662d52a580bSJunchao Zhang 
1663d52a580bSJunchao Zhang   fact->offloadmask            = PETSC_OFFLOAD_BOTH;
1664d52a580bSJunchao Zhang   fact->factortype             = MAT_FACTOR_ICC;
1665d52a580bSJunchao Zhang   fact->info.factor_mallocs    = 0;
1666d52a580bSJunchao Zhang   fact->info.fill_ratio_given  = info->fill;
1667d52a580bSJunchao Zhang   fact->info.fill_ratio_needed = 1.0;
1668d52a580bSJunchao Zhang 
1669d52a580bSJunchao Zhang   aij->row = NULL;
1670d52a580bSJunchao Zhang   aij->col = NULL;
1671d52a580bSJunchao Zhang 
1672d52a580bSJunchao Zhang   /* ====================================================================== */
1673d52a580bSJunchao Zhang   /* Copy A's i, j to fact and also allocate the value array of fact.       */
1674d52a580bSJunchao Zhang   /* We'll do in-place factorization on fact                                */
1675d52a580bSJunchao Zhang   /* ====================================================================== */
1676d52a580bSJunchao Zhang   const int *Ai, *Aj;
1677d52a580bSJunchao Zhang 
1678d52a580bSJunchao Zhang   m  = fact->rmap->n;
1679d52a580bSJunchao Zhang   nz = aij->nz;
1680d52a580bSJunchao Zhang 
1681d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrRowPtr, sizeof(int) * (m + 1)));
1682d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrColIdx, sizeof(int) * nz));
1683d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrVal, sizeof(PetscScalar) * nz));
1684d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEGetIJ(A, PETSC_FALSE, &Ai, &Aj)); /* Do not use compressed Ai */
1685d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrRowPtr, Ai, sizeof(int) * (m + 1), hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1686d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrColIdx, Aj, sizeof(int) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1687d52a580bSJunchao Zhang 
1688d52a580bSJunchao Zhang   /* ====================================================================== */
1689d52a580bSJunchao Zhang   /* Create mat descriptors for M, L                                        */
1690d52a580bSJunchao Zhang   /* ====================================================================== */
1691d52a580bSJunchao Zhang   hipsparseFillMode_t fillMode;
1692d52a580bSJunchao Zhang   hipsparseDiagType_t diagType;
1693d52a580bSJunchao Zhang 
1694d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&fs->matDescr_M));
1695d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(fs->matDescr_M, HIPSPARSE_INDEX_BASE_ZERO));
1696d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(fs->matDescr_M, HIPSPARSE_MATRIX_TYPE_GENERAL));
1697d52a580bSJunchao Zhang 
1698d52a580bSJunchao Zhang   /* https://docs.amd.com/bundle/hipSPARSE-Documentation---hipSPARSE-documentation/page/usermanual.html/#hipsparse_8h_1a79e036b6c0680cb37e2aa53d3542a054
1699d52a580bSJunchao Zhang     hipsparseDiagType_t: This type indicates if the matrix diagonal entries are unity. The diagonal elements are always
1700d52a580bSJunchao Zhang     assumed to be present, but if HIPSPARSE_DIAG_TYPE_UNIT is passed to an API routine, then the routine assumes that
1701d52a580bSJunchao Zhang     all diagonal entries are unity and will not read or modify those entries. Note that in this case the routine
1702d52a580bSJunchao Zhang     assumes the diagonal entries are equal to one, regardless of what those entries are actually set to in memory.
1703d52a580bSJunchao Zhang   */
1704d52a580bSJunchao Zhang   fillMode = HIPSPARSE_FILL_MODE_LOWER;
1705d52a580bSJunchao Zhang   diagType = HIPSPARSE_DIAG_TYPE_NON_UNIT;
1706d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_L, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
1707d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode)));
1708d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType)));
1709d52a580bSJunchao Zhang 
1710d52a580bSJunchao Zhang   /* ========================================================================= */
1711d52a580bSJunchao Zhang   /* Query buffer sizes for csric0, SpSV of L and Lt, and allocate buffers     */
1712d52a580bSJunchao Zhang   /* ========================================================================= */
1713d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsric02Info(&fs->ic0Info_M));
1714d52a580bSJunchao Zhang   if (m) PetscCallHIPSPARSE(hipsparseXcsric02_bufferSize(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, &fs->factBufferSize_M));
1715d52a580bSJunchao Zhang 
1716d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->X, sizeof(PetscScalar) * m));
1717d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->Y, sizeof(PetscScalar) * m));
1718d52a580bSJunchao Zhang 
1719d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_X, m, fs->X, hipsparse_scalartype));
1720d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_Y, m, fs->Y, hipsparse_scalartype));
1721d52a580bSJunchao Zhang 
1722d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_L));
1723d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, &fs->spsvBufferSize_L));
1724d52a580bSJunchao Zhang 
1725d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Lt));
1726d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, &fs->spsvBufferSize_Lt));
1727d52a580bSJunchao Zhang 
1728d52a580bSJunchao Zhang   /* To save device memory, we make the factorization buffer share with one of the solver buffer.
1729d52a580bSJunchao Zhang      See also comments in `MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0()`.
1730d52a580bSJunchao Zhang    */
1731d52a580bSJunchao Zhang   if (fs->spsvBufferSize_L > fs->spsvBufferSize_Lt) {
1732d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_L, (size_t)fs->factBufferSize_M)));
1733d52a580bSJunchao Zhang     fs->spsvBuffer_L = fs->factBuffer_M;
1734d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Lt, fs->spsvBufferSize_Lt));
1735d52a580bSJunchao Zhang   } else {
1736d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_Lt, (size_t)fs->factBufferSize_M)));
1737d52a580bSJunchao Zhang     fs->spsvBuffer_Lt = fs->factBuffer_M;
1738d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_L, fs->spsvBufferSize_L));
1739d52a580bSJunchao Zhang   }
1740d52a580bSJunchao Zhang 
1741d52a580bSJunchao Zhang   /* ========================================================================== */
1742d52a580bSJunchao Zhang   /* Perform analysis of ic0 on M                                               */
1743d52a580bSJunchao Zhang   /* The lower triangular part of M has the same sparsity pattern as L          */
1744d52a580bSJunchao Zhang   /* ========================================================================== */
1745d52a580bSJunchao Zhang   int structural_zero;
1746d52a580bSJunchao Zhang 
1747d52a580bSJunchao Zhang   fs->policy_M = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
1748d52a580bSJunchao Zhang   if (m) PetscCallHIPSPARSE(hipsparseXcsric02_analysis(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, fs->policy_M, fs->factBuffer_M));
1749d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1750d52a580bSJunchao Zhang     hipsparseStatus_t status;
1751d52a580bSJunchao Zhang     /* Function hipsparseXcsric02_zeroPivot() is a blocking call. It calls hipDeviceSynchronize() to make sure all previous kernels are done. */
1752d52a580bSJunchao Zhang     status = hipsparseXcsric02_zeroPivot(fs->handle, fs->ic0Info_M, &structural_zero);
1753d52a580bSJunchao Zhang     PetscCheck(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Structural zero pivot detected in csric02: A(%d,%d) is missing", structural_zero, structural_zero);
1754d52a580bSJunchao Zhang   }
1755d52a580bSJunchao Zhang 
1756d52a580bSJunchao Zhang   /* Estimate FLOPs of the numeric factorization */
1757d52a580bSJunchao Zhang   {
1758d52a580bSJunchao Zhang     Mat_SeqAIJ    *Aseq = (Mat_SeqAIJ *)A->data;
1759d52a580bSJunchao Zhang     PetscInt      *Ai, nzRow, nzLeft;
1760d52a580bSJunchao Zhang     PetscLogDouble flops = 0.0;
1761d52a580bSJunchao Zhang 
1762d52a580bSJunchao Zhang     Ai = Aseq->i;
1763d52a580bSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
1764d52a580bSJunchao Zhang       nzRow = Ai[i + 1] - Ai[i];
1765d52a580bSJunchao Zhang       if (nzRow > 1) {
1766d52a580bSJunchao Zhang         /* We want to eliminate nonzeros left to the diagonal one by one. Assume each time, nonzeros right
1767d52a580bSJunchao Zhang           and include the eliminated one will be updated, which incurs a multiplication and an addition.
1768d52a580bSJunchao Zhang         */
1769d52a580bSJunchao Zhang         nzLeft = (nzRow - 1) / 2;
1770d52a580bSJunchao Zhang         flops += nzLeft * (2.0 * nzRow - nzLeft + 1);
1771d52a580bSJunchao Zhang       }
1772d52a580bSJunchao Zhang     }
1773d52a580bSJunchao Zhang     fs->numericFactFlops = flops;
1774d52a580bSJunchao Zhang   }
1775d52a580bSJunchao Zhang   fact->ops->choleskyfactornumeric = MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0;
1776d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1777d52a580bSJunchao Zhang }
1778d52a580bSJunchao Zhang #endif
1779d52a580bSJunchao Zhang 
MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat B,Mat A,IS isrow,IS iscol,const MatFactorInfo * info)1780d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1781d52a580bSJunchao Zhang {
1782d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1783d52a580bSJunchao Zhang 
1784d52a580bSJunchao Zhang   PetscFunctionBegin;
1785d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1786d52a580bSJunchao Zhang   PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE;
1787d52a580bSJunchao Zhang   if (!info->factoronhost) {
1788d52a580bSJunchao Zhang     PetscCall(ISIdentity(isrow, &row_identity));
1789d52a580bSJunchao Zhang     PetscCall(ISIdentity(iscol, &col_identity));
1790d52a580bSJunchao Zhang   }
1791d52a580bSJunchao Zhang   if (!info->levels && row_identity && col_identity) PetscCall(MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(B, A, isrow, iscol, info));
1792d52a580bSJunchao Zhang   else
1793d52a580bSJunchao Zhang #endif
1794d52a580bSJunchao Zhang   {
1795d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1796d52a580bSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
1797d52a580bSJunchao Zhang     B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJHIPSPARSE;
1798d52a580bSJunchao Zhang   }
1799d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1800d52a580bSJunchao Zhang }
1801d52a580bSJunchao Zhang 
MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat B,Mat A,IS isrow,IS iscol,const MatFactorInfo * info)1802d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1803d52a580bSJunchao Zhang {
1804d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1805d52a580bSJunchao Zhang 
1806d52a580bSJunchao Zhang   PetscFunctionBegin;
1807d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1808d52a580bSJunchao Zhang   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
1809d52a580bSJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJHIPSPARSE;
1810d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1811d52a580bSJunchao Zhang }
1812d52a580bSJunchao Zhang 
MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat B,Mat A,IS perm,const MatFactorInfo * info)1813d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS perm, const MatFactorInfo *info)
1814d52a580bSJunchao Zhang {
1815d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1816d52a580bSJunchao Zhang 
1817d52a580bSJunchao Zhang   PetscFunctionBegin;
1818d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1819d52a580bSJunchao Zhang   PetscBool perm_identity = PETSC_FALSE;
1820d52a580bSJunchao Zhang   if (!info->factoronhost) PetscCall(ISIdentity(perm, &perm_identity));
1821d52a580bSJunchao Zhang   if (!info->levels && perm_identity) PetscCall(MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(B, A, perm, info));
1822d52a580bSJunchao Zhang   else
1823d52a580bSJunchao Zhang #endif
1824d52a580bSJunchao Zhang   {
1825d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1826d52a580bSJunchao Zhang     PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info));
1827d52a580bSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJHIPSPARSE;
1828d52a580bSJunchao Zhang   }
1829d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1830d52a580bSJunchao Zhang }
1831d52a580bSJunchao Zhang 
MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat B,Mat A,IS perm,const MatFactorInfo * info)1832d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS perm, const MatFactorInfo *info)
1833d52a580bSJunchao Zhang {
1834d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1835d52a580bSJunchao Zhang 
1836d52a580bSJunchao Zhang   PetscFunctionBegin;
1837d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1838d52a580bSJunchao Zhang   PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info));
1839d52a580bSJunchao Zhang   B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJHIPSPARSE;
1840d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1841d52a580bSJunchao Zhang }
1842d52a580bSJunchao Zhang 
MatFactorGetSolverType_seqaij_hipsparse(Mat A,MatSolverType * type)1843d52a580bSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_seqaij_hipsparse(Mat A, MatSolverType *type)
1844d52a580bSJunchao Zhang {
1845d52a580bSJunchao Zhang   PetscFunctionBegin;
1846d52a580bSJunchao Zhang   *type = MATSOLVERHIPSPARSE;
1847d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1848d52a580bSJunchao Zhang }
1849d52a580bSJunchao Zhang 
/*MC
  MATSOLVERHIPSPARSE = "hipsparse" - A matrix solver type providing triangular solvers for sequential matrices
  of type `MATSEQAIJHIPSPARSE` on a single GPU. Currently supported
  algorithms are ILU(k) and ICC(k). Typically, deeper factorizations (larger k) result in poorer
  performance in the triangular solves. Full LU and Cholesky factorizations can also be solved through the
  hipSPARSE triangular solve algorithm; however, the performance can be quite poor, so these
  algorithms are not recommended. This class does NOT support direct solver operations.

  Level: beginner

.seealso: [](ch_matrices), `Mat`, `MATSEQAIJHIPSPARSE`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJHIPSPARSE()`, `MATAIJHIPSPARSE`, `MatCreateAIJHIPSPARSE()`, `MatHIPSPARSESetFormat()`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
M*/
1862d52a580bSJunchao Zhang 
MatGetFactor_seqaijhipsparse_hipsparse(Mat A,MatFactorType ftype,Mat * B)1863d52a580bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_seqaijhipsparse_hipsparse(Mat A, MatFactorType ftype, Mat *B)
1864d52a580bSJunchao Zhang {
1865d52a580bSJunchao Zhang   PetscInt n = A->rmap->n;
1866d52a580bSJunchao Zhang 
1867d52a580bSJunchao Zhang   PetscFunctionBegin;
1868d52a580bSJunchao Zhang   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
1869d52a580bSJunchao Zhang   PetscCall(MatSetSizes(*B, n, n, n, n));
1870d52a580bSJunchao Zhang   (*B)->factortype = ftype;
1871d52a580bSJunchao Zhang   PetscCall(MatSetType(*B, MATSEQAIJHIPSPARSE));
1872d52a580bSJunchao Zhang 
1873d52a580bSJunchao Zhang   if (A->boundtocpu && A->bindingpropagates) PetscCall(MatBindToCPU(*B, PETSC_TRUE));
1874d52a580bSJunchao Zhang   if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) {
1875d52a580bSJunchao Zhang     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
1876d52a580bSJunchao Zhang     if (!A->boundtocpu) {
1877d52a580bSJunchao Zhang       (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJHIPSPARSE;
1878d52a580bSJunchao Zhang       (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJHIPSPARSE;
1879d52a580bSJunchao Zhang     } else {
1880d52a580bSJunchao Zhang       (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJ;
1881d52a580bSJunchao Zhang       (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJ;
1882d52a580bSJunchao Zhang     }
1883d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
1884d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU]));
1885d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT]));
1886d52a580bSJunchao Zhang   } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) {
1887d52a580bSJunchao Zhang     if (!A->boundtocpu) {
1888d52a580bSJunchao Zhang       (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJHIPSPARSE;
1889d52a580bSJunchao Zhang       (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE;
1890d52a580bSJunchao Zhang     } else {
1891d52a580bSJunchao Zhang       (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJ;
1892d52a580bSJunchao Zhang       (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJ;
1893d52a580bSJunchao Zhang     }
1894d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY]));
1895d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC]));
1896d52a580bSJunchao Zhang   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Factor type not supported for HIPSPARSE Matrix Types");
1897d52a580bSJunchao Zhang 
1898d52a580bSJunchao Zhang   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
1899d52a580bSJunchao Zhang   (*B)->canuseordering = PETSC_TRUE;
1900d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_seqaij_hipsparse));
1901d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1902d52a580bSJunchao Zhang }
1903d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSECopyFromGPU(Mat A)1904d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSECopyFromGPU(Mat A)
1905d52a580bSJunchao Zhang {
1906d52a580bSJunchao Zhang   Mat_SeqAIJ          *a    = (Mat_SeqAIJ *)A->data;
1907d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
1908d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1909d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1910d52a580bSJunchao Zhang #endif
1911d52a580bSJunchao Zhang 
1912d52a580bSJunchao Zhang   PetscFunctionBegin;
1913d52a580bSJunchao Zhang   if (A->offloadmask == PETSC_OFFLOAD_GPU) {
1914d52a580bSJunchao Zhang     PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyFromGPU, A, 0, 0, 0));
1915d52a580bSJunchao Zhang     if (A->factortype == MAT_FACTOR_NONE) {
1916d52a580bSJunchao Zhang       CsrMatrix *matrix = (CsrMatrix *)cusp->mat->mat;
1917d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(a->a, matrix->values->data().get(), a->nz * sizeof(PetscScalar), hipMemcpyDeviceToHost));
1918d52a580bSJunchao Zhang     }
1919d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1920d52a580bSJunchao Zhang     else if (fs->csrVal) {
1921d52a580bSJunchao Zhang       /* We have a factorized matrix on device and are able to copy it to host */
1922d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(a->a, fs->csrVal, a->nz * sizeof(PetscScalar), hipMemcpyDeviceToHost));
1923d52a580bSJunchao Zhang     }
1924d52a580bSJunchao Zhang #endif
1925d52a580bSJunchao Zhang     else
1926d52a580bSJunchao Zhang       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "No support for copying this type of factorized matrix from device to host");
1927d52a580bSJunchao Zhang     PetscCall(PetscLogGpuToCpu(a->nz * sizeof(PetscScalar)));
1928d52a580bSJunchao Zhang     PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyFromGPU, A, 0, 0, 0));
1929d52a580bSJunchao Zhang     A->offloadmask = PETSC_OFFLOAD_BOTH;
1930d52a580bSJunchao Zhang   }
1931d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1932d52a580bSJunchao Zhang }
1933d52a580bSJunchao Zhang 
MatSeqAIJGetArray_SeqAIJHIPSPARSE(Mat A,PetscScalar * array[])1934d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArray_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1935d52a580bSJunchao Zhang {
1936d52a580bSJunchao Zhang   PetscFunctionBegin;
1937d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
1938d52a580bSJunchao Zhang   *array = ((Mat_SeqAIJ *)A->data)->a;
1939d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1940d52a580bSJunchao Zhang }
1941d52a580bSJunchao Zhang 
MatSeqAIJRestoreArray_SeqAIJHIPSPARSE(Mat A,PetscScalar * array[])1942d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1943d52a580bSJunchao Zhang {
1944d52a580bSJunchao Zhang   PetscFunctionBegin;
1945d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_CPU;
1946d52a580bSJunchao Zhang   *array         = NULL;
1947d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1948d52a580bSJunchao Zhang }
1949d52a580bSJunchao Zhang 
MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE(Mat A,const PetscScalar * array[])1950d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE(Mat A, const PetscScalar *array[])
1951d52a580bSJunchao Zhang {
1952d52a580bSJunchao Zhang   PetscFunctionBegin;
1953d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
1954d52a580bSJunchao Zhang   *array = ((Mat_SeqAIJ *)A->data)->a;
1955d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1956d52a580bSJunchao Zhang }
1957d52a580bSJunchao Zhang 
MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE(Mat A,const PetscScalar * array[])1958d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE(Mat A, const PetscScalar *array[])
1959d52a580bSJunchao Zhang {
1960d52a580bSJunchao Zhang   PetscFunctionBegin;
1961d52a580bSJunchao Zhang   *array = NULL;
1962d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1963d52a580bSJunchao Zhang }
1964d52a580bSJunchao Zhang 
MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE(Mat A,PetscScalar * array[])1965d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1966d52a580bSJunchao Zhang {
1967d52a580bSJunchao Zhang   PetscFunctionBegin;
1968d52a580bSJunchao Zhang   *array = ((Mat_SeqAIJ *)A->data)->a;
1969d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1970d52a580bSJunchao Zhang }
1971d52a580bSJunchao Zhang 
MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE(Mat A,PetscScalar * array[])1972d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1973d52a580bSJunchao Zhang {
1974d52a580bSJunchao Zhang   PetscFunctionBegin;
1975d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_CPU;
1976d52a580bSJunchao Zhang   *array         = NULL;
1977d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1978d52a580bSJunchao Zhang }
1979d52a580bSJunchao Zhang 
MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE(Mat A,const PetscInt ** i,const PetscInt ** j,PetscScalar ** a,PetscMemType * mtype)1980d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
1981d52a580bSJunchao Zhang {
1982d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp;
1983d52a580bSJunchao Zhang   CsrMatrix           *matrix;
1984d52a580bSJunchao Zhang 
1985d52a580bSJunchao Zhang   PetscFunctionBegin;
1986d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
1987d52a580bSJunchao Zhang   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Not for factored matrix");
1988d52a580bSJunchao Zhang   cusp = static_cast<Mat_SeqAIJHIPSPARSE *>(A->spptr);
1989d52a580bSJunchao Zhang   PetscCheck(cusp != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "cusp is NULL");
1990d52a580bSJunchao Zhang   matrix = (CsrMatrix *)cusp->mat->mat;
1991d52a580bSJunchao Zhang 
1992d52a580bSJunchao Zhang   if (i) {
1993d52a580bSJunchao Zhang #if !defined(PETSC_USE_64BIT_INDICES)
1994d52a580bSJunchao Zhang     *i = matrix->row_offsets->data().get();
1995d52a580bSJunchao Zhang #else
1996d52a580bSJunchao Zhang     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSparse does not supported 64-bit indices");
1997d52a580bSJunchao Zhang #endif
1998d52a580bSJunchao Zhang   }
1999d52a580bSJunchao Zhang   if (j) {
2000d52a580bSJunchao Zhang #if !defined(PETSC_USE_64BIT_INDICES)
2001d52a580bSJunchao Zhang     *j = matrix->column_indices->data().get();
2002d52a580bSJunchao Zhang #else
2003d52a580bSJunchao Zhang     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSparse does not supported 64-bit indices");
2004d52a580bSJunchao Zhang #endif
2005d52a580bSJunchao Zhang   }
2006d52a580bSJunchao Zhang   if (a) *a = matrix->values->data().get();
2007d52a580bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_HIP;
2008d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2009d52a580bSJunchao Zhang }
2010d52a580bSJunchao Zhang 
/*
  MatSeqAIJHIPSPARSECopyToGPU - Make the device copy of a MATSEQAIJHIPSPARSE matrix current when the host holds
  the newer (or only) data, i.e. when offloadmask is PETSC_OFFLOAD_UNALLOCATED or PETSC_OFFLOAD_CPU.

  Two paths:
  - Values only: if A->nonzerostate matches the one recorded at the last rebuild and the storage format is CSR,
    a->a is copied into the existing device CSR value array and the cached transpose is invalidated.
  - Rebuild: otherwise the existing Mat_SeqAIJHIPSPARSEMultStruct, the cached transpose, workVector and
    rowoffsets_gpu are destroyed, and a new MultStruct is built from the host CSR data (or from the compressed-row
    data when a->compressedrow.use is set), in CSR or ELL/HYB form depending on hipsparsestruct->format.

  On success offloadmask becomes PETSC_OFFLOAD_BOTH, except in the rebuild path when the host has no value array
  (a->a == NULL): then only the pattern is uploaded and offloadmask is left unchanged.
  Errors with PETSC_ERR_GPU if A is bound to the CPU.
*/
PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSECopyToGPU(Mat A)
{
  Mat_SeqAIJHIPSPARSE           *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
  Mat_SeqAIJHIPSPARSEMultStruct *matstruct       = hipsparsestruct->mat;
  Mat_SeqAIJ                    *a               = (Mat_SeqAIJ *)A->data;
  PetscBool                      both            = PETSC_TRUE; /* cleared when only the pattern (no values) is uploaded */
  PetscInt                       m               = A->rmap->n, *ii, *ridx, tmp;

  PetscFunctionBegin;
  PetscCheck(!A->boundtocpu, PETSC_COMM_SELF, PETSC_ERR_GPU, "Cannot copy to GPU");
  if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
    if (A->nonzerostate == hipsparsestruct->nonzerostate && hipsparsestruct->format == MAT_HIPSPARSE_CSR) { /* Copy values only */
      CsrMatrix *matrix;
      matrix = (CsrMatrix *)hipsparsestruct->mat->mat;

      PetscCheck(!a->nz || a->a, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR values");
      PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
      matrix->values->assign(a->a, a->a + a->nz);
      PetscCallHIP(WaitForHIP());
      PetscCall(PetscLogCpuToGpu(a->nz * sizeof(PetscScalar)));
      PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
      /* values changed; presumably PETSC_FALSE keeps the transpose's structure and only marks its values stale — confirm */
      PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE));
    } else {
      /* Pattern changed (or non-CSR storage): tear down all device structures and rebuild them from the host data */
      PetscInt nnz;
      PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
      PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&hipsparsestruct->mat, hipsparsestruct->format));
      PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE));
      delete hipsparsestruct->workVector;
      delete hipsparsestruct->rowoffsets_gpu;
      hipsparsestruct->workVector     = NULL;
      hipsparsestruct->rowoffsets_gpu = NULL;
      /* NOTE(review): matstruct is only stored into hipsparsestruct->mat at the end of this try block; an early error
         return from a PetscCall*/PetscCallHIP*/PetscCallHIPSPARSE macro in between (assuming they return on failure,
         as PetscCall does) leaves the new matstruct unowned — confirm. */
      try {
        /* With compressed-row storage only a->compressedrow.nrows rows are uploaded; ridx holds their row indices in A */
        if (a->compressedrow.use) {
          m    = a->compressedrow.nrows;
          ii   = a->compressedrow.i;
          ridx = a->compressedrow.rindex;
        } else {
          m    = A->rmap->n;
          ii   = a->i;
          ridx = NULL;
        }
        PetscCheck(ii, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR row data");
        /* No host values: take nnz from the row offsets, upload the pattern only, and do not mark the device copy valid */
        if (!a->a) {
          nnz  = ii[m];
          both = PETSC_FALSE;
        } else nnz = a->nz;
        PetscCheck(!nnz || a->j, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR column data");

        /* create hipsparse matrix */
        hipsparsestruct->nrows = m;
        matstruct              = new Mat_SeqAIJHIPSPARSEMultStruct;
        PetscCallHIPSPARSE(hipsparseCreateMatDescr(&matstruct->descr));
        PetscCallHIPSPARSE(hipsparseSetMatIndexBase(matstruct->descr, HIPSPARSE_INDEX_BASE_ZERO));
        PetscCallHIPSPARSE(hipsparseSetMatType(matstruct->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));

        /* device-resident scalars 1, 0, 1 passed as alpha/beta to hipSPARSE calls; the handle is switched to device pointer mode below */
        PetscCallHIP(hipMalloc((void **)&matstruct->alpha_one, sizeof(PetscScalar)));
        PetscCallHIP(hipMalloc((void **)&matstruct->beta_zero, sizeof(PetscScalar)));
        PetscCallHIP(hipMalloc((void **)&matstruct->beta_one, sizeof(PetscScalar)));
        PetscCallHIP(hipMemcpy(matstruct->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
        PetscCallHIP(hipMemcpy(matstruct->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
        PetscCallHIP(hipMemcpy(matstruct->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
        PetscCallHIPSPARSE(hipsparseSetPointerMode(hipsparsestruct->handle, HIPSPARSE_POINTER_MODE_DEVICE));

        /* Build a hybrid/ellpack matrix if this option is chosen for the storage */
        if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
          /* set the matrix */
          CsrMatrix *mat      = new CsrMatrix;
          mat->num_rows       = m;
          mat->num_cols       = A->cmap->n;
          mat->num_entries    = nnz;
          mat->row_offsets    = new THRUSTINTARRAY32(m + 1);
          mat->column_indices = new THRUSTINTARRAY32(nnz);
          mat->values         = new THRUSTARRAY(nnz);
          mat->row_offsets->assign(ii, ii + m + 1);
          mat->column_indices->assign(a->j, a->j + nnz);
          if (a->a) mat->values->assign(a->a, a->a + nnz);

          /* assign the pointer */
          matstruct->mat = mat;
          if (mat->num_rows) { /* hipsparse errors on empty matrices! */
            PetscCallHIPSPARSE(hipsparseCreateCsr(&matstruct->matDescr, mat->num_rows, mat->num_cols, mat->num_entries, mat->row_offsets->data().get(), mat->column_indices->data().get(), mat->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx types due to THRUSTINTARRAY32 */
                                                  HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
          }
        } else if (hipsparsestruct->format == MAT_HIPSPARSE_ELL || hipsparsestruct->format == MAT_HIPSPARSE_HYB) {
          /* Stage a temporary device CSR copy, convert it to a HYB matrix (ELL uses the MAX partition), then free the CSR copy */
          CsrMatrix *mat      = new CsrMatrix;
          mat->num_rows       = m;
          mat->num_cols       = A->cmap->n;
          mat->num_entries    = nnz;
          mat->row_offsets    = new THRUSTINTARRAY32(m + 1);
          mat->column_indices = new THRUSTINTARRAY32(nnz);
          mat->values         = new THRUSTARRAY(nnz);
          mat->row_offsets->assign(ii, ii + m + 1);
          mat->column_indices->assign(a->j, a->j + nnz);
          if (a->a) mat->values->assign(a->a, a->a + nnz);

          hipsparseHybMat_t hybMat;
          PetscCallHIPSPARSE(hipsparseCreateHybMat(&hybMat));
          hipsparseHybPartition_t partition = hipsparsestruct->format == MAT_HIPSPARSE_ELL ? HIPSPARSE_HYB_PARTITION_MAX : HIPSPARSE_HYB_PARTITION_AUTO;
          PetscCallHIPSPARSE(hipsparse_csr2hyb(hipsparsestruct->handle, mat->num_rows, mat->num_cols, matstruct->descr, mat->values->data().get(), mat->row_offsets->data().get(), mat->column_indices->data().get(), hybMat, 0, partition));
          /* assign the pointer */
          matstruct->mat = hybMat;

          if (mat) {
            if (mat->values) delete (THRUSTARRAY *)mat->values;
            if (mat->column_indices) delete (THRUSTINTARRAY32 *)mat->column_indices;
            if (mat->row_offsets) delete (THRUSTINTARRAY32 *)mat->row_offsets;
            delete (CsrMatrix *)mat;
          }
        }

        /* assign the compressed row indices */
        if (a->compressedrow.use) {
          hipsparsestruct->workVector = new THRUSTARRAY(m);
          matstruct->cprowIndices     = new THRUSTINTARRAY(m);
          matstruct->cprowIndices->assign(ridx, ridx + m);
          tmp = m;
        } else {
          hipsparsestruct->workVector = NULL;
          matstruct->cprowIndices     = NULL;
          tmp                         = 0;
        }
        /* NOTE(review): the byte count uses a->nz even when nnz was taken from ii[m] (a->a == NULL) */
        PetscCall(PetscLogCpuToGpu(((m + 1) + (a->nz)) * sizeof(int) + tmp * sizeof(PetscInt) + (3 + (a->nz)) * sizeof(PetscScalar)));

        /* assign the pointer */
        hipsparsestruct->mat = matstruct;
      } catch (char *ex) {
        /* NOTE(review): thrust/rocThrust typically report failures with std::exception-derived types (e.g.
           thrust::system_error, std::bad_alloc), which this char * handler does not catch — confirm and consider
           catching std::exception as well. */
        SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
      }
      PetscCallHIP(WaitForHIP());
      PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
      /* remember the pattern just uploaded so the next call can take the values-only path */
      hipsparsestruct->nonzerostate = A->nonzerostate;
    }
    if (both) A->offloadmask = PETSC_OFFLOAD_BOTH;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
2147d52a580bSJunchao Zhang 
2148d52a580bSJunchao Zhang struct VecHIPPlusEquals {
2149d52a580bSJunchao Zhang   template <typename Tuple>
operator ()VecHIPPlusEquals2150d52a580bSJunchao Zhang   __host__ __device__ void operator()(Tuple t)
2151d52a580bSJunchao Zhang   {
2152d52a580bSJunchao Zhang     thrust::get<1>(t) = thrust::get<1>(t) + thrust::get<0>(t);
2153d52a580bSJunchao Zhang   }
2154d52a580bSJunchao Zhang };
2155d52a580bSJunchao Zhang 
2156d52a580bSJunchao Zhang struct VecHIPEquals {
2157d52a580bSJunchao Zhang   template <typename Tuple>
operator ()VecHIPEquals2158d52a580bSJunchao Zhang   __host__ __device__ void operator()(Tuple t)
2159d52a580bSJunchao Zhang   {
2160d52a580bSJunchao Zhang     thrust::get<1>(t) = thrust::get<0>(t);
2161d52a580bSJunchao Zhang   }
2162d52a580bSJunchao Zhang };
2163d52a580bSJunchao Zhang 
2164d52a580bSJunchao Zhang struct VecHIPEqualsReverse {
2165d52a580bSJunchao Zhang   template <typename Tuple>
operator ()VecHIPEqualsReverse2166d52a580bSJunchao Zhang   __host__ __device__ void operator()(Tuple t)
2167d52a580bSJunchao Zhang   {
2168d52a580bSJunchao Zhang     thrust::get<0>(t) = thrust::get<1>(t);
2169d52a580bSJunchao Zhang   }
2170d52a580bSJunchao Zhang };
2171d52a580bSJunchao Zhang 
2172d52a580bSJunchao Zhang struct MatProductCtx_MatMatHipsparse {
2173d52a580bSJunchao Zhang   PetscBool             cisdense;
2174d52a580bSJunchao Zhang   PetscScalar          *Bt;
2175d52a580bSJunchao Zhang   Mat                   X;
2176d52a580bSJunchao Zhang   PetscBool             reusesym; /* Hipsparse does not have split symbolic and numeric phases for sparse matmat operations */
2177d52a580bSJunchao Zhang   PetscLogDouble        flops;
2178d52a580bSJunchao Zhang   CsrMatrix            *Bcsr;
2179d52a580bSJunchao Zhang   hipsparseSpMatDescr_t matSpBDescr;
2180d52a580bSJunchao Zhang   PetscBool             initialized; /* C = alpha op(A) op(B) + beta C */
2181d52a580bSJunchao Zhang   hipsparseDnMatDescr_t matBDescr;
2182d52a580bSJunchao Zhang   hipsparseDnMatDescr_t matCDescr;
2183d52a580bSJunchao Zhang   PetscInt              Blda, Clda; /* Record leading dimensions of B and C here to detect changes*/
2184d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2185d52a580bSJunchao Zhang   void *dBuffer4, *dBuffer5;
2186d52a580bSJunchao Zhang #endif
2187d52a580bSJunchao Zhang   size_t                 mmBufferSize;
2188d52a580bSJunchao Zhang   void                  *mmBuffer, *mmBuffer2; /* SpGEMM WorkEstimation buffer */
2189d52a580bSJunchao Zhang   hipsparseSpGEMMDescr_t spgemmDesc;
2190d52a580bSJunchao Zhang };
2191d52a580bSJunchao Zhang 
MatProductCtxDestroy_MatMatHipsparse(PetscCtxRt data)2192d52a580bSJunchao Zhang static PetscErrorCode MatProductCtxDestroy_MatMatHipsparse(PetscCtxRt data)
2193d52a580bSJunchao Zhang {
2194d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata = *(MatProductCtx_MatMatHipsparse **)data;
2195d52a580bSJunchao Zhang 
2196d52a580bSJunchao Zhang   PetscFunctionBegin;
2197d52a580bSJunchao Zhang   PetscCallHIP(hipFree(mmdata->Bt));
2198d52a580bSJunchao Zhang   delete mmdata->Bcsr;
2199d52a580bSJunchao Zhang   if (mmdata->matSpBDescr) PetscCallHIPSPARSE(hipsparseDestroySpMat(mmdata->matSpBDescr));
2200d52a580bSJunchao Zhang   if (mmdata->matBDescr) PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matBDescr));
2201d52a580bSJunchao Zhang   if (mmdata->matCDescr) PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matCDescr));
2202d52a580bSJunchao Zhang   if (mmdata->spgemmDesc) PetscCallHIPSPARSE(hipsparseSpGEMM_destroyDescr(mmdata->spgemmDesc));
2203d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2204d52a580bSJunchao Zhang   if (mmdata->dBuffer4) PetscCallHIP(hipFree(mmdata->dBuffer4));
2205d52a580bSJunchao Zhang   if (mmdata->dBuffer5) PetscCallHIP(hipFree(mmdata->dBuffer5));
2206d52a580bSJunchao Zhang #endif
2207d52a580bSJunchao Zhang   if (mmdata->mmBuffer) PetscCallHIP(hipFree(mmdata->mmBuffer));
2208d52a580bSJunchao Zhang   if (mmdata->mmBuffer2) PetscCallHIP(hipFree(mmdata->mmBuffer2));
2209d52a580bSJunchao Zhang   PetscCall(MatDestroy(&mmdata->X));
2210d52a580bSJunchao Zhang   PetscCall(PetscFree(*(void **)data));
2211d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2212d52a580bSJunchao Zhang }
2213d52a580bSJunchao Zhang 
MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)2214d52a580bSJunchao Zhang static PetscErrorCode MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)
2215d52a580bSJunchao Zhang {
2216d52a580bSJunchao Zhang   Mat_Product                   *product = C->product;
2217d52a580bSJunchao Zhang   Mat                            A, B;
2218d52a580bSJunchao Zhang   PetscInt                       m, n, blda, clda;
2219d52a580bSJunchao Zhang   PetscBool                      flg, biship;
2220d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *cusp;
2221d52a580bSJunchao Zhang   hipsparseOperation_t           opA;
2222d52a580bSJunchao Zhang   const PetscScalar             *barray;
2223d52a580bSJunchao Zhang   PetscScalar                   *carray;
2224d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata;
2225d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *mat;
2226d52a580bSJunchao Zhang   CsrMatrix                     *csrmat;
2227d52a580bSJunchao Zhang 
2228d52a580bSJunchao Zhang   PetscFunctionBegin;
2229d52a580bSJunchao Zhang   MatCheckProduct(C, 1);
2230d52a580bSJunchao Zhang   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data empty");
2231d52a580bSJunchao Zhang   mmdata = (MatProductCtx_MatMatHipsparse *)product->data;
2232d52a580bSJunchao Zhang   A      = product->A;
2233d52a580bSJunchao Zhang   B      = product->B;
2234d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2235d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2236d52a580bSJunchao Zhang   /* currently CopyToGpu does not copy if the matrix is bound to CPU
2237d52a580bSJunchao Zhang      Instead of silently accepting the wrong answer, I prefer to raise the error */
2238d52a580bSJunchao Zhang   PetscCheck(!A->boundtocpu, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases");
2239d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
2240d52a580bSJunchao Zhang   cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2241d52a580bSJunchao Zhang   switch (product->type) {
2242d52a580bSJunchao Zhang   case MATPRODUCT_AB:
2243d52a580bSJunchao Zhang   case MATPRODUCT_PtAP:
2244d52a580bSJunchao Zhang     mat = cusp->mat;
2245d52a580bSJunchao Zhang     opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
2246d52a580bSJunchao Zhang     m   = A->rmap->n;
2247d52a580bSJunchao Zhang     n   = B->cmap->n;
2248d52a580bSJunchao Zhang     break;
2249d52a580bSJunchao Zhang   case MATPRODUCT_AtB:
2250d52a580bSJunchao Zhang     if (!A->form_explicit_transpose) {
2251d52a580bSJunchao Zhang       mat = cusp->mat;
2252d52a580bSJunchao Zhang       opA = HIPSPARSE_OPERATION_TRANSPOSE;
2253d52a580bSJunchao Zhang     } else {
2254d52a580bSJunchao Zhang       PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
2255d52a580bSJunchao Zhang       mat = cusp->matTranspose;
2256d52a580bSJunchao Zhang       opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
2257d52a580bSJunchao Zhang     }
2258d52a580bSJunchao Zhang     m = A->cmap->n;
2259d52a580bSJunchao Zhang     n = B->cmap->n;
2260d52a580bSJunchao Zhang     break;
2261d52a580bSJunchao Zhang   case MATPRODUCT_ABt:
2262d52a580bSJunchao Zhang   case MATPRODUCT_RARt:
2263d52a580bSJunchao Zhang     mat = cusp->mat;
2264d52a580bSJunchao Zhang     opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
2265d52a580bSJunchao Zhang     m   = A->rmap->n;
2266d52a580bSJunchao Zhang     n   = B->rmap->n;
2267d52a580bSJunchao Zhang     break;
2268d52a580bSJunchao Zhang   default:
2269d52a580bSJunchao Zhang     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2270d52a580bSJunchao Zhang   }
2271d52a580bSJunchao Zhang   PetscCheck(mat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
2272d52a580bSJunchao Zhang   csrmat = (CsrMatrix *)mat->mat;
2273d52a580bSJunchao Zhang   /* if the user passed a CPU matrix, copy the data to the GPU */
2274d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQDENSEHIP, &biship));
2275d52a580bSJunchao Zhang   if (!biship) PetscCall(MatConvert(B, MATSEQDENSEHIP, MAT_INPLACE_MATRIX, &B));
2276d52a580bSJunchao Zhang   PetscCall(MatDenseGetArrayReadAndMemType(B, &barray, nullptr));
2277d52a580bSJunchao Zhang   PetscCall(MatDenseGetLDA(B, &blda));
2278d52a580bSJunchao Zhang   if (product->type == MATPRODUCT_RARt || product->type == MATPRODUCT_PtAP) {
2279d52a580bSJunchao Zhang     PetscCall(MatDenseGetArrayWriteAndMemType(mmdata->X, &carray, nullptr));
2280d52a580bSJunchao Zhang     PetscCall(MatDenseGetLDA(mmdata->X, &clda));
2281d52a580bSJunchao Zhang   } else {
2282d52a580bSJunchao Zhang     PetscCall(MatDenseGetArrayWriteAndMemType(C, &carray, nullptr));
2283d52a580bSJunchao Zhang     PetscCall(MatDenseGetLDA(C, &clda));
2284d52a580bSJunchao Zhang   }
2285d52a580bSJunchao Zhang 
2286d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
2287d52a580bSJunchao Zhang   hipsparseOperation_t opB = (product->type == MATPRODUCT_ABt || product->type == MATPRODUCT_RARt) ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE;
2288d52a580bSJunchao Zhang   /* (re)allocate mmBuffer if not initialized or LDAs are different */
2289d52a580bSJunchao Zhang   if (!mmdata->initialized || mmdata->Blda != blda || mmdata->Clda != clda) {
2290d52a580bSJunchao Zhang     size_t mmBufferSize;
2291d52a580bSJunchao Zhang     if (mmdata->initialized && mmdata->Blda != blda) {
2292d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matBDescr));
2293d52a580bSJunchao Zhang       mmdata->matBDescr = NULL;
2294d52a580bSJunchao Zhang     }
2295d52a580bSJunchao Zhang     if (!mmdata->matBDescr) {
2296d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateDnMat(&mmdata->matBDescr, B->rmap->n, B->cmap->n, blda, (void *)barray, hipsparse_scalartype, HIPSPARSE_ORDER_COL));
2297d52a580bSJunchao Zhang       mmdata->Blda = blda;
2298d52a580bSJunchao Zhang     }
2299d52a580bSJunchao Zhang     if (mmdata->initialized && mmdata->Clda != clda) {
2300d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matCDescr));
2301d52a580bSJunchao Zhang       mmdata->matCDescr = NULL;
2302d52a580bSJunchao Zhang     }
2303d52a580bSJunchao Zhang     if (!mmdata->matCDescr) { /* matCDescr is for C or mmdata->X */
2304d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateDnMat(&mmdata->matCDescr, m, n, clda, (void *)carray, hipsparse_scalartype, HIPSPARSE_ORDER_COL));
2305d52a580bSJunchao Zhang       mmdata->Clda = clda;
2306d52a580bSJunchao Zhang     }
2307d52a580bSJunchao Zhang     if (!mat->matDescr) {
2308d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateCsr(&mat->matDescr, csrmat->num_rows, csrmat->num_cols, csrmat->num_entries, csrmat->row_offsets->data().get(), csrmat->column_indices->data().get(), csrmat->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx types due to THRUSTINTARRAY32 */
2309d52a580bSJunchao Zhang                                             HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2310d52a580bSJunchao Zhang     }
2311d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpMM_bufferSize(cusp->handle, opA, opB, mat->alpha_one, mat->matDescr, mmdata->matBDescr, mat->beta_zero, mmdata->matCDescr, hipsparse_scalartype, cusp->spmmAlg, &mmBufferSize));
2312d52a580bSJunchao Zhang     if ((mmdata->mmBuffer && mmdata->mmBufferSize < mmBufferSize) || !mmdata->mmBuffer) {
2313d52a580bSJunchao Zhang       PetscCallHIP(hipFree(mmdata->mmBuffer));
2314d52a580bSJunchao Zhang       PetscCallHIP(hipMalloc(&mmdata->mmBuffer, mmBufferSize));
2315d52a580bSJunchao Zhang       mmdata->mmBufferSize = mmBufferSize;
2316d52a580bSJunchao Zhang     }
2317d52a580bSJunchao Zhang     mmdata->initialized = PETSC_TRUE;
2318d52a580bSJunchao Zhang   } else {
2319d52a580bSJunchao Zhang     /* to be safe, always update pointers of the mats */
2320d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpMatSetValues(mat->matDescr, csrmat->values->data().get()));
2321d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDnMatSetValues(mmdata->matBDescr, (void *)barray));
2322d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDnMatSetValues(mmdata->matCDescr, (void *)carray));
2323d52a580bSJunchao Zhang   }
2324d52a580bSJunchao Zhang 
2325d52a580bSJunchao Zhang   /* do hipsparseSpMM, which supports transpose on B */
2326d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMM(cusp->handle, opA, opB, mat->alpha_one, mat->matDescr, mmdata->matBDescr, mat->beta_zero, mmdata->matCDescr, hipsparse_scalartype, cusp->spmmAlg, mmdata->mmBuffer));
2327d52a580bSJunchao Zhang 
2328d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
2329d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(n * 2.0 * csrmat->num_entries));
2330d52a580bSJunchao Zhang   PetscCall(MatDenseRestoreArrayReadAndMemType(B, &barray));
2331d52a580bSJunchao Zhang   if (product->type == MATPRODUCT_RARt) {
2332d52a580bSJunchao Zhang     PetscCall(MatDenseRestoreArrayWriteAndMemType(mmdata->X, &carray));
2333d52a580bSJunchao Zhang     PetscCall(MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal(B, mmdata->X, C, PETSC_FALSE, PETSC_FALSE));
2334d52a580bSJunchao Zhang   } else if (product->type == MATPRODUCT_PtAP) {
2335d52a580bSJunchao Zhang     PetscCall(MatDenseRestoreArrayWriteAndMemType(mmdata->X, &carray));
2336d52a580bSJunchao Zhang     PetscCall(MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal(B, mmdata->X, C, PETSC_TRUE, PETSC_FALSE));
2337d52a580bSJunchao Zhang   } else PetscCall(MatDenseRestoreArrayWriteAndMemType(C, &carray));
2338d52a580bSJunchao Zhang   if (mmdata->cisdense) PetscCall(MatConvert(C, MATSEQDENSE, MAT_INPLACE_MATRIX, &C));
2339d52a580bSJunchao Zhang   if (!biship) PetscCall(MatConvert(B, MATSEQDENSE, MAT_INPLACE_MATRIX, &B));
2340d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2341d52a580bSJunchao Zhang }
2342d52a580bSJunchao Zhang 
MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)2343d52a580bSJunchao Zhang static PetscErrorCode MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)
2344d52a580bSJunchao Zhang {
2345d52a580bSJunchao Zhang   Mat_Product                   *product = C->product;
2346d52a580bSJunchao Zhang   Mat                            A, B;
2347d52a580bSJunchao Zhang   PetscInt                       m, n;
2348d52a580bSJunchao Zhang   PetscBool                      cisdense, flg;
2349d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata;
2350d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *cusp;
2351d52a580bSJunchao Zhang 
2352d52a580bSJunchao Zhang   PetscFunctionBegin;
2353d52a580bSJunchao Zhang   MatCheckProduct(C, 1);
2354d52a580bSJunchao Zhang   PetscCheck(!C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data not empty");
2355d52a580bSJunchao Zhang   A = product->A;
2356d52a580bSJunchao Zhang   B = product->B;
2357d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2358d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2359d52a580bSJunchao Zhang   cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2360d52a580bSJunchao Zhang   PetscCheck(cusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2361d52a580bSJunchao Zhang   switch (product->type) {
2362d52a580bSJunchao Zhang   case MATPRODUCT_AB:
2363d52a580bSJunchao Zhang     m = A->rmap->n;
2364d52a580bSJunchao Zhang     n = B->cmap->n;
2365d52a580bSJunchao Zhang     break;
2366d52a580bSJunchao Zhang   case MATPRODUCT_AtB:
2367d52a580bSJunchao Zhang     m = A->cmap->n;
2368d52a580bSJunchao Zhang     n = B->cmap->n;
2369d52a580bSJunchao Zhang     break;
2370d52a580bSJunchao Zhang   case MATPRODUCT_ABt:
2371d52a580bSJunchao Zhang     m = A->rmap->n;
2372d52a580bSJunchao Zhang     n = B->rmap->n;
2373d52a580bSJunchao Zhang     break;
2374d52a580bSJunchao Zhang   case MATPRODUCT_PtAP:
2375d52a580bSJunchao Zhang     m = B->cmap->n;
2376d52a580bSJunchao Zhang     n = B->cmap->n;
2377d52a580bSJunchao Zhang     break;
2378d52a580bSJunchao Zhang   case MATPRODUCT_RARt:
2379d52a580bSJunchao Zhang     m = B->rmap->n;
2380d52a580bSJunchao Zhang     n = B->rmap->n;
2381d52a580bSJunchao Zhang     break;
2382d52a580bSJunchao Zhang   default:
2383d52a580bSJunchao Zhang     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2384d52a580bSJunchao Zhang   }
2385d52a580bSJunchao Zhang   PetscCall(MatSetSizes(C, m, n, m, n));
2386d52a580bSJunchao Zhang   /* if C is of type MATSEQDENSE (CPU), perform the operation on the GPU and then copy on the CPU */
2387d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATSEQDENSE, &cisdense));
2388d52a580bSJunchao Zhang   PetscCall(MatSetType(C, MATSEQDENSEHIP));
2389d52a580bSJunchao Zhang 
2390d52a580bSJunchao Zhang   /* product data */
2391d52a580bSJunchao Zhang   PetscCall(PetscNew(&mmdata));
2392d52a580bSJunchao Zhang   mmdata->cisdense = cisdense;
2393d52a580bSJunchao Zhang   /* for these products we need intermediate storage */
2394d52a580bSJunchao Zhang   if (product->type == MATPRODUCT_RARt || product->type == MATPRODUCT_PtAP) {
2395d52a580bSJunchao Zhang     PetscCall(MatCreate(PetscObjectComm((PetscObject)C), &mmdata->X));
2396d52a580bSJunchao Zhang     PetscCall(MatSetType(mmdata->X, MATSEQDENSEHIP));
2397d52a580bSJunchao Zhang     /* do not preallocate, since the first call to MatDenseHIPGetArray will preallocate on the GPU for us */
2398d52a580bSJunchao Zhang     if (product->type == MATPRODUCT_RARt) PetscCall(MatSetSizes(mmdata->X, A->rmap->n, B->rmap->n, A->rmap->n, B->rmap->n));
2399d52a580bSJunchao Zhang     else PetscCall(MatSetSizes(mmdata->X, A->rmap->n, B->cmap->n, A->rmap->n, B->cmap->n));
2400d52a580bSJunchao Zhang   }
2401d52a580bSJunchao Zhang   C->product->data       = mmdata;
2402d52a580bSJunchao Zhang   C->product->destroy    = MatProductCtxDestroy_MatMatHipsparse;
2403d52a580bSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP;
2404d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2405d52a580bSJunchao Zhang }
2406d52a580bSJunchao Zhang 
MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)2407d52a580bSJunchao Zhang static PetscErrorCode MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)
2408d52a580bSJunchao Zhang {
2409d52a580bSJunchao Zhang   Mat_Product                   *product = C->product;
2410d52a580bSJunchao Zhang   Mat                            A, B;
2411d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp, *Bcusp, *Ccusp;
2412d52a580bSJunchao Zhang   Mat_SeqAIJ                    *c = (Mat_SeqAIJ *)C->data;
2413d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *Amat, *Bmat, *Cmat;
2414d52a580bSJunchao Zhang   CsrMatrix                     *Acsr, *Bcsr, *Ccsr;
2415d52a580bSJunchao Zhang   PetscBool                      flg;
2416d52a580bSJunchao Zhang   MatProductType                 ptype;
2417d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata;
2418d52a580bSJunchao Zhang   hipsparseSpMatDescr_t          BmatSpDescr;
2419d52a580bSJunchao Zhang   hipsparseOperation_t           opA = HIPSPARSE_OPERATION_NON_TRANSPOSE, opB = HIPSPARSE_OPERATION_NON_TRANSPOSE; /* hipSPARSE spgemm doesn't support transpose yet */
2420d52a580bSJunchao Zhang 
2421d52a580bSJunchao Zhang   PetscFunctionBegin;
2422d52a580bSJunchao Zhang   MatCheckProduct(C, 1);
2423d52a580bSJunchao Zhang   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data empty");
2424d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATSEQAIJHIPSPARSE, &flg));
2425d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for C of type %s", ((PetscObject)C)->type_name);
2426d52a580bSJunchao Zhang   mmdata = (MatProductCtx_MatMatHipsparse *)C->product->data;
2427d52a580bSJunchao Zhang   A      = product->A;
2428d52a580bSJunchao Zhang   B      = product->B;
2429d52a580bSJunchao Zhang   if (mmdata->reusesym) { /* this happens when api_user is true, meaning that the matrix values have been already computed in the MatProductSymbolic phase */
2430d52a580bSJunchao Zhang     mmdata->reusesym = PETSC_FALSE;
2431d52a580bSJunchao Zhang     Ccusp            = (Mat_SeqAIJHIPSPARSE *)C->spptr;
2432d52a580bSJunchao Zhang     PetscCheck(Ccusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2433d52a580bSJunchao Zhang     Cmat = Ccusp->mat;
2434d52a580bSJunchao Zhang     PetscCheck(Cmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C mult struct for product type %s", MatProductTypes[C->product->type]);
2435d52a580bSJunchao Zhang     Ccsr = (CsrMatrix *)Cmat->mat;
2436d52a580bSJunchao Zhang     PetscCheck(Ccsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C CSR struct");
2437d52a580bSJunchao Zhang     goto finalize;
2438d52a580bSJunchao Zhang   }
2439d52a580bSJunchao Zhang   if (!c->nz) goto finalize;
2440d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2441d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2442d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQAIJHIPSPARSE, &flg));
2443d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for B of type %s", ((PetscObject)B)->type_name);
2444d52a580bSJunchao Zhang   PetscCheck(!A->boundtocpu, PetscObjectComm((PetscObject)C), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases");
2445d52a580bSJunchao Zhang   PetscCheck(!B->boundtocpu, PetscObjectComm((PetscObject)C), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases");
2446d52a580bSJunchao Zhang   Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2447d52a580bSJunchao Zhang   Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr;
2448d52a580bSJunchao Zhang   Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr;
2449d52a580bSJunchao Zhang   PetscCheck(Acusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2450d52a580bSJunchao Zhang   PetscCheck(Bcusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2451d52a580bSJunchao Zhang   PetscCheck(Ccusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2452d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
2453d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
2454d52a580bSJunchao Zhang 
2455d52a580bSJunchao Zhang   ptype = product->type;
2456d52a580bSJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
2457d52a580bSJunchao Zhang     ptype = MATPRODUCT_AB;
2458d52a580bSJunchao Zhang     PetscCheck(product->symbolic_used_the_fact_A_is_symmetric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Symbolic should have been built using the fact that A is symmetric");
2459d52a580bSJunchao Zhang   }
2460d52a580bSJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
2461d52a580bSJunchao Zhang     ptype = MATPRODUCT_AB;
2462d52a580bSJunchao Zhang     PetscCheck(product->symbolic_used_the_fact_B_is_symmetric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Symbolic should have been built using the fact that B is symmetric");
2463d52a580bSJunchao Zhang   }
2464d52a580bSJunchao Zhang   switch (ptype) {
2465d52a580bSJunchao Zhang   case MATPRODUCT_AB:
2466d52a580bSJunchao Zhang     Amat = Acusp->mat;
2467d52a580bSJunchao Zhang     Bmat = Bcusp->mat;
2468d52a580bSJunchao Zhang     break;
2469d52a580bSJunchao Zhang   case MATPRODUCT_AtB:
2470d52a580bSJunchao Zhang     Amat = Acusp->matTranspose;
2471d52a580bSJunchao Zhang     Bmat = Bcusp->mat;
2472d52a580bSJunchao Zhang     break;
2473d52a580bSJunchao Zhang   case MATPRODUCT_ABt:
2474d52a580bSJunchao Zhang     Amat = Acusp->mat;
2475d52a580bSJunchao Zhang     Bmat = Bcusp->matTranspose;
2476d52a580bSJunchao Zhang     break;
2477d52a580bSJunchao Zhang   default:
2478d52a580bSJunchao Zhang     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2479d52a580bSJunchao Zhang   }
2480d52a580bSJunchao Zhang   Cmat = Ccusp->mat;
2481d52a580bSJunchao Zhang   PetscCheck(Amat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A mult struct for product type %s", MatProductTypes[ptype]);
2482d52a580bSJunchao Zhang   PetscCheck(Bmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B mult struct for product type %s", MatProductTypes[ptype]);
2483d52a580bSJunchao Zhang   PetscCheck(Cmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C mult struct for product type %s", MatProductTypes[ptype]);
2484d52a580bSJunchao Zhang   Acsr = (CsrMatrix *)Amat->mat;
2485d52a580bSJunchao Zhang   Bcsr = mmdata->Bcsr ? mmdata->Bcsr : (CsrMatrix *)Bmat->mat; /* B may be in compressed row storage */
2486d52a580bSJunchao Zhang   Ccsr = (CsrMatrix *)Cmat->mat;
2487d52a580bSJunchao Zhang   PetscCheck(Acsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A CSR struct");
2488d52a580bSJunchao Zhang   PetscCheck(Bcsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B CSR struct");
2489d52a580bSJunchao Zhang   PetscCheck(Ccsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C CSR struct");
2490d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
2491d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0)
2492d52a580bSJunchao Zhang   BmatSpDescr = mmdata->Bcsr ? mmdata->matSpBDescr : Bmat->matDescr; /* B may be in compressed row storage */
2493d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2494d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2495d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMMreuse_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2496d52a580bSJunchao Zhang   #else
2497d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, mmdata->mmBuffer));
2498d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_copy(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2499d52a580bSJunchao Zhang   #endif
2500d52a580bSJunchao Zhang #else
2501d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparse_csr_spgemm(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->values->data().get(), Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr,
2502d52a580bSJunchao Zhang                                           Bcsr->num_entries, Bcsr->values->data().get(), Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->values->data().get(), Ccsr->row_offsets->data().get(),
2503d52a580bSJunchao Zhang                                           Ccsr->column_indices->data().get()));
2504d52a580bSJunchao Zhang #endif
2505d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(mmdata->flops));
2506d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
2507d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
2508d52a580bSJunchao Zhang   C->offloadmask = PETSC_OFFLOAD_GPU;
2509d52a580bSJunchao Zhang finalize:
2510d52a580bSJunchao Zhang   /* shorter version of MatAssemblyEnd_SeqAIJ */
2511d52a580bSJunchao Zhang   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded, %" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
2512d52a580bSJunchao Zhang   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
2513d52a580bSJunchao Zhang   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
2514d52a580bSJunchao Zhang   c->reallocs = 0;
2515d52a580bSJunchao Zhang   C->info.mallocs += 0;
2516d52a580bSJunchao Zhang   C->info.nz_unneeded = 0;
2517d52a580bSJunchao Zhang   C->assembled = C->was_assembled = PETSC_TRUE;
2518d52a580bSJunchao Zhang   C->num_ass++;
2519d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2520d52a580bSJunchao Zhang }
2521d52a580bSJunchao Zhang 
MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)2522d52a580bSJunchao Zhang static PetscErrorCode MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)
2523d52a580bSJunchao Zhang {
2524d52a580bSJunchao Zhang   Mat_Product                   *product = C->product;
2525d52a580bSJunchao Zhang   Mat                            A, B;
2526d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp, *Bcusp, *Ccusp;
2527d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a, *b, *c;
2528d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *Amat, *Bmat, *Cmat;
2529d52a580bSJunchao Zhang   CsrMatrix                     *Acsr, *Bcsr, *Ccsr;
2530d52a580bSJunchao Zhang   PetscInt                       i, j, m, n, k;
2531d52a580bSJunchao Zhang   PetscBool                      flg;
2532d52a580bSJunchao Zhang   MatProductType                 ptype;
2533d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata;
2534d52a580bSJunchao Zhang   PetscLogDouble                 flops;
2535d52a580bSJunchao Zhang   PetscBool                      biscompressed, ciscompressed;
2536d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0)
2537d52a580bSJunchao Zhang   int64_t               C_num_rows1, C_num_cols1, C_nnz1;
2538d52a580bSJunchao Zhang   hipsparseSpMatDescr_t BmatSpDescr;
2539d52a580bSJunchao Zhang #else
2540d52a580bSJunchao Zhang   int cnz;
2541d52a580bSJunchao Zhang #endif
2542d52a580bSJunchao Zhang   hipsparseOperation_t opA = HIPSPARSE_OPERATION_NON_TRANSPOSE, opB = HIPSPARSE_OPERATION_NON_TRANSPOSE; /* hipSPARSE spgemm doesn't support transpose yet */
2543d52a580bSJunchao Zhang 
2544d52a580bSJunchao Zhang   PetscFunctionBegin;
2545d52a580bSJunchao Zhang   MatCheckProduct(C, 1);
2546d52a580bSJunchao Zhang   PetscCheck(!C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data not empty");
2547d52a580bSJunchao Zhang   A = product->A;
2548d52a580bSJunchao Zhang   B = product->B;
2549d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2550d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2551d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQAIJHIPSPARSE, &flg));
2552d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for B of type %s", ((PetscObject)B)->type_name);
2553d52a580bSJunchao Zhang   a = (Mat_SeqAIJ *)A->data;
2554d52a580bSJunchao Zhang   b = (Mat_SeqAIJ *)B->data;
2555d52a580bSJunchao Zhang   /* product data */
2556d52a580bSJunchao Zhang   PetscCall(PetscNew(&mmdata));
2557d52a580bSJunchao Zhang   C->product->data    = mmdata;
2558d52a580bSJunchao Zhang   C->product->destroy = MatProductCtxDestroy_MatMatHipsparse;
2559d52a580bSJunchao Zhang 
2560d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
2561d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
2562d52a580bSJunchao Zhang   Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; /* Access spptr after MatSeqAIJHIPSPARSECopyToGPU, not before */
2563d52a580bSJunchao Zhang   Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr;
2564d52a580bSJunchao Zhang   PetscCheck(Acusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2565d52a580bSJunchao Zhang   PetscCheck(Bcusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2566d52a580bSJunchao Zhang 
2567d52a580bSJunchao Zhang   ptype = product->type;
2568d52a580bSJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
2569d52a580bSJunchao Zhang     ptype                                          = MATPRODUCT_AB;
2570d52a580bSJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
2571d52a580bSJunchao Zhang   }
2572d52a580bSJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
2573d52a580bSJunchao Zhang     ptype                                          = MATPRODUCT_AB;
2574d52a580bSJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
2575d52a580bSJunchao Zhang   }
2576d52a580bSJunchao Zhang   biscompressed = PETSC_FALSE;
2577d52a580bSJunchao Zhang   ciscompressed = PETSC_FALSE;
2578d52a580bSJunchao Zhang   switch (ptype) {
2579d52a580bSJunchao Zhang   case MATPRODUCT_AB:
2580d52a580bSJunchao Zhang     m    = A->rmap->n;
2581d52a580bSJunchao Zhang     n    = B->cmap->n;
2582d52a580bSJunchao Zhang     k    = A->cmap->n;
2583d52a580bSJunchao Zhang     Amat = Acusp->mat;
2584d52a580bSJunchao Zhang     Bmat = Bcusp->mat;
2585d52a580bSJunchao Zhang     if (a->compressedrow.use) ciscompressed = PETSC_TRUE;
2586d52a580bSJunchao Zhang     if (b->compressedrow.use) biscompressed = PETSC_TRUE;
2587d52a580bSJunchao Zhang     break;
2588d52a580bSJunchao Zhang   case MATPRODUCT_AtB:
2589d52a580bSJunchao Zhang     m = A->cmap->n;
2590d52a580bSJunchao Zhang     n = B->cmap->n;
2591d52a580bSJunchao Zhang     k = A->rmap->n;
2592d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
2593d52a580bSJunchao Zhang     Amat = Acusp->matTranspose;
2594d52a580bSJunchao Zhang     Bmat = Bcusp->mat;
2595d52a580bSJunchao Zhang     if (b->compressedrow.use) biscompressed = PETSC_TRUE;
2596d52a580bSJunchao Zhang     break;
2597d52a580bSJunchao Zhang   case MATPRODUCT_ABt:
2598d52a580bSJunchao Zhang     m = A->rmap->n;
2599d52a580bSJunchao Zhang     n = B->rmap->n;
2600d52a580bSJunchao Zhang     k = A->cmap->n;
2601d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(B));
2602d52a580bSJunchao Zhang     Amat = Acusp->mat;
2603d52a580bSJunchao Zhang     Bmat = Bcusp->matTranspose;
2604d52a580bSJunchao Zhang     if (a->compressedrow.use) ciscompressed = PETSC_TRUE;
2605d52a580bSJunchao Zhang     break;
2606d52a580bSJunchao Zhang   default:
2607d52a580bSJunchao Zhang     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2608d52a580bSJunchao Zhang   }
2609d52a580bSJunchao Zhang 
2610d52a580bSJunchao Zhang   /* create hipsparse matrix */
2611d52a580bSJunchao Zhang   PetscCall(MatSetSizes(C, m, n, m, n));
2612d52a580bSJunchao Zhang   PetscCall(MatSetType(C, MATSEQAIJHIPSPARSE));
2613d52a580bSJunchao Zhang   c     = (Mat_SeqAIJ *)C->data;
2614d52a580bSJunchao Zhang   Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr;
2615d52a580bSJunchao Zhang   Cmat  = new Mat_SeqAIJHIPSPARSEMultStruct;
2616d52a580bSJunchao Zhang   Ccsr  = new CsrMatrix;
2617d52a580bSJunchao Zhang 
2618d52a580bSJunchao Zhang   c->compressedrow.use = ciscompressed;
2619d52a580bSJunchao Zhang   if (c->compressedrow.use) { /* if a is in compressed row, than c will be in compressed row format */
2620d52a580bSJunchao Zhang     c->compressedrow.nrows = a->compressedrow.nrows;
2621d52a580bSJunchao Zhang     PetscCall(PetscMalloc2(c->compressedrow.nrows + 1, &c->compressedrow.i, c->compressedrow.nrows, &c->compressedrow.rindex));
2622d52a580bSJunchao Zhang     PetscCall(PetscArraycpy(c->compressedrow.rindex, a->compressedrow.rindex, c->compressedrow.nrows));
2623d52a580bSJunchao Zhang     Ccusp->workVector  = new THRUSTARRAY(c->compressedrow.nrows);
2624d52a580bSJunchao Zhang     Cmat->cprowIndices = new THRUSTINTARRAY(c->compressedrow.nrows);
2625d52a580bSJunchao Zhang     Cmat->cprowIndices->assign(c->compressedrow.rindex, c->compressedrow.rindex + c->compressedrow.nrows);
2626d52a580bSJunchao Zhang   } else {
2627d52a580bSJunchao Zhang     c->compressedrow.nrows  = 0;
2628d52a580bSJunchao Zhang     c->compressedrow.i      = NULL;
2629d52a580bSJunchao Zhang     c->compressedrow.rindex = NULL;
2630d52a580bSJunchao Zhang     Ccusp->workVector       = NULL;
2631d52a580bSJunchao Zhang     Cmat->cprowIndices      = NULL;
2632d52a580bSJunchao Zhang   }
2633d52a580bSJunchao Zhang   Ccusp->nrows      = ciscompressed ? c->compressedrow.nrows : m;
2634d52a580bSJunchao Zhang   Ccusp->mat        = Cmat;
2635d52a580bSJunchao Zhang   Ccusp->mat->mat   = Ccsr;
2636d52a580bSJunchao Zhang   Ccsr->num_rows    = Ccusp->nrows;
2637d52a580bSJunchao Zhang   Ccsr->num_cols    = n;
2638d52a580bSJunchao Zhang   Ccsr->row_offsets = new THRUSTINTARRAY32(Ccusp->nrows + 1);
2639d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&Cmat->descr));
2640d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(Cmat->descr, HIPSPARSE_INDEX_BASE_ZERO));
2641d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(Cmat->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
2642d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&Cmat->alpha_one, sizeof(PetscScalar)));
2643d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&Cmat->beta_zero, sizeof(PetscScalar)));
2644d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&Cmat->beta_one, sizeof(PetscScalar)));
2645d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(Cmat->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
2646d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(Cmat->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
2647d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(Cmat->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
2648d52a580bSJunchao Zhang   if (!Ccsr->num_rows || !Ccsr->num_cols || !a->nz || !b->nz) { /* hipsparse raise errors in different calls when matrices have zero rows/columns! */
2649d52a580bSJunchao Zhang     thrust::fill(thrust::device, Ccsr->row_offsets->begin(), Ccsr->row_offsets->end(), 0);
2650d52a580bSJunchao Zhang     c->nz                = 0;
2651d52a580bSJunchao Zhang     Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2652d52a580bSJunchao Zhang     Ccsr->values         = new THRUSTARRAY(c->nz);
2653d52a580bSJunchao Zhang     goto finalizesym;
2654d52a580bSJunchao Zhang   }
2655d52a580bSJunchao Zhang 
2656d52a580bSJunchao Zhang   PetscCheck(Amat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A mult struct for product type %s", MatProductTypes[ptype]);
2657d52a580bSJunchao Zhang   PetscCheck(Bmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B mult struct for product type %s", MatProductTypes[ptype]);
2658d52a580bSJunchao Zhang   Acsr = (CsrMatrix *)Amat->mat;
2659d52a580bSJunchao Zhang   if (!biscompressed) {
2660d52a580bSJunchao Zhang     Bcsr        = (CsrMatrix *)Bmat->mat;
2661d52a580bSJunchao Zhang     BmatSpDescr = Bmat->matDescr;
2662d52a580bSJunchao Zhang   } else { /* we need to use row offsets for the full matrix */
2663d52a580bSJunchao Zhang     CsrMatrix *cBcsr     = (CsrMatrix *)Bmat->mat;
2664d52a580bSJunchao Zhang     Bcsr                 = new CsrMatrix;
2665d52a580bSJunchao Zhang     Bcsr->num_rows       = B->rmap->n;
2666d52a580bSJunchao Zhang     Bcsr->num_cols       = cBcsr->num_cols;
2667d52a580bSJunchao Zhang     Bcsr->num_entries    = cBcsr->num_entries;
2668d52a580bSJunchao Zhang     Bcsr->column_indices = cBcsr->column_indices;
2669d52a580bSJunchao Zhang     Bcsr->values         = cBcsr->values;
2670d52a580bSJunchao Zhang     if (!Bcusp->rowoffsets_gpu) {
2671d52a580bSJunchao Zhang       Bcusp->rowoffsets_gpu = new THRUSTINTARRAY32(B->rmap->n + 1);
2672d52a580bSJunchao Zhang       Bcusp->rowoffsets_gpu->assign(b->i, b->i + B->rmap->n + 1);
2673d52a580bSJunchao Zhang       PetscCall(PetscLogCpuToGpu((B->rmap->n + 1) * sizeof(PetscInt)));
2674d52a580bSJunchao Zhang     }
2675d52a580bSJunchao Zhang     Bcsr->row_offsets = Bcusp->rowoffsets_gpu;
2676d52a580bSJunchao Zhang     mmdata->Bcsr      = Bcsr;
2677d52a580bSJunchao Zhang     if (Bcsr->num_rows && Bcsr->num_cols) {
2678d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateCsr(&mmdata->matSpBDescr, Bcsr->num_rows, Bcsr->num_cols, Bcsr->num_entries, Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Bcsr->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2679d52a580bSJunchao Zhang     }
2680d52a580bSJunchao Zhang     BmatSpDescr = mmdata->matSpBDescr;
2681d52a580bSJunchao Zhang   }
2682d52a580bSJunchao Zhang   PetscCheck(Acsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A CSR struct");
2683d52a580bSJunchao Zhang   PetscCheck(Bcsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B CSR struct");
2684d52a580bSJunchao Zhang   /* precompute flops count */
2685d52a580bSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
2686d52a580bSJunchao Zhang     for (i = 0, flops = 0; i < A->rmap->n; i++) {
2687d52a580bSJunchao Zhang       const PetscInt st = a->i[i];
2688d52a580bSJunchao Zhang       const PetscInt en = a->i[i + 1];
2689d52a580bSJunchao Zhang       for (j = st; j < en; j++) {
2690d52a580bSJunchao Zhang         const PetscInt brow = a->j[j];
2691d52a580bSJunchao Zhang         flops += 2. * (b->i[brow + 1] - b->i[brow]);
2692d52a580bSJunchao Zhang       }
2693d52a580bSJunchao Zhang     }
2694d52a580bSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
2695d52a580bSJunchao Zhang     for (i = 0, flops = 0; i < A->rmap->n; i++) {
2696d52a580bSJunchao Zhang       const PetscInt anzi = a->i[i + 1] - a->i[i];
2697d52a580bSJunchao Zhang       const PetscInt bnzi = b->i[i + 1] - b->i[i];
2698d52a580bSJunchao Zhang       flops += (2. * anzi) * bnzi;
2699d52a580bSJunchao Zhang     }
2700d52a580bSJunchao Zhang   } else flops = 0.; /* TODO */
2701d52a580bSJunchao Zhang 
2702d52a580bSJunchao Zhang   mmdata->flops = flops;
2703d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
2704d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0)
2705d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2706d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsr(&Cmat->matDescr, Ccsr->num_rows, Ccsr->num_cols, 0, Ccsr->row_offsets->data().get(), NULL, NULL, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2707d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_createDescr(&mmdata->spgemmDesc));
2708d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2709d52a580bSJunchao Zhang   {
2710d52a580bSJunchao Zhang     /* hipsparseSpGEMMreuse has more reasonable APIs than hipsparseSpGEMM, so we prefer to use it.
2711d52a580bSJunchao Zhang      We follow the sample code at https://github.com/ROCmSoftwarePlatform/hipSPARSE/blob/develop/clients/include/testing_spgemmreuse_csr.hpp
2712d52a580bSJunchao Zhang   */
2713d52a580bSJunchao Zhang     void *dBuffer1 = NULL;
2714d52a580bSJunchao Zhang     void *dBuffer2 = NULL;
2715d52a580bSJunchao Zhang     void *dBuffer3 = NULL;
2716d52a580bSJunchao Zhang     /* dBuffer4, dBuffer5 are needed by hipsparseSpGEMMreuse_compute, and therefore are stored in mmdata */
2717d52a580bSJunchao Zhang     size_t bufferSize1 = 0;
2718d52a580bSJunchao Zhang     size_t bufferSize2 = 0;
2719d52a580bSJunchao Zhang     size_t bufferSize3 = 0;
2720d52a580bSJunchao Zhang     size_t bufferSize4 = 0;
2721d52a580bSJunchao Zhang     size_t bufferSize5 = 0;
2722d52a580bSJunchao Zhang 
2723d52a580bSJunchao Zhang     /* ask bufferSize1 bytes for external memory */
2724d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpGEMMreuse_workEstimation(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize1, NULL));
2725d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&dBuffer1, bufferSize1));
2726d52a580bSJunchao Zhang     /* inspect the matrices A and B to understand the memory requirement for the next step */
2727d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpGEMMreuse_workEstimation(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize1, dBuffer1));
2728d52a580bSJunchao Zhang 
2729d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpGEMMreuse_nnz(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize2, NULL, &bufferSize3, NULL, &bufferSize4, NULL));
2730d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&dBuffer2, bufferSize2));
2731d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&dBuffer3, bufferSize3));
2732d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&mmdata->dBuffer4, bufferSize4));
2733d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpGEMMreuse_nnz(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize2, dBuffer2, &bufferSize3, dBuffer3, &bufferSize4, mmdata->dBuffer4));
2734d52a580bSJunchao Zhang     PetscCallHIP(hipFree(dBuffer1));
2735d52a580bSJunchao Zhang     PetscCallHIP(hipFree(dBuffer2));
2736d52a580bSJunchao Zhang 
2737d52a580bSJunchao Zhang     /* get matrix C non-zero entries C_nnz1 */
2738d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpMatGetSize(Cmat->matDescr, &C_num_rows1, &C_num_cols1, &C_nnz1));
2739d52a580bSJunchao Zhang     c->nz = (PetscInt)C_nnz1;
2740d52a580bSJunchao Zhang     /* allocate matrix C */
2741d52a580bSJunchao Zhang     Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2742d52a580bSJunchao Zhang     PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2743d52a580bSJunchao Zhang     Ccsr->values = new THRUSTARRAY(c->nz);
2744d52a580bSJunchao Zhang     PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2745d52a580bSJunchao Zhang     /* update matC with the new pointers */
2746d52a580bSJunchao Zhang     if (c->nz) { /* 5.5.1 has a bug with nz = 0, exposed by mat_tests_ex123_2_hypre */
2747d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCsrSetPointers(Cmat->matDescr, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get()));
2748d52a580bSJunchao Zhang 
2749d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSpGEMMreuse_copy(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize5, NULL));
2750d52a580bSJunchao Zhang       PetscCallHIP(hipMalloc((void **)&mmdata->dBuffer5, bufferSize5));
2751d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSpGEMMreuse_copy(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize5, mmdata->dBuffer5));
2752d52a580bSJunchao Zhang       PetscCallHIP(hipFree(dBuffer3));
2753d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSpGEMMreuse_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2754d52a580bSJunchao Zhang     }
2755d52a580bSJunchao Zhang     PetscCall(PetscInfo(C, "Buffer sizes for type %s, result %" PetscInt_FMT " x %" PetscInt_FMT " (k %" PetscInt_FMT ", nzA %" PetscInt_FMT ", nzB %" PetscInt_FMT ", nzC %" PetscInt_FMT ") are: %ldKB %ldKB\n", MatProductTypes[ptype], m, n, k, a->nz, b->nz, c->nz, bufferSize4 / 1024, bufferSize5 / 1024));
2756d52a580bSJunchao Zhang   }
2757d52a580bSJunchao Zhang   #else
2758d52a580bSJunchao Zhang   size_t bufSize2;
2759d52a580bSJunchao Zhang   /* ask bufferSize bytes for external memory */
2760d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_workEstimation(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufSize2, NULL));
2761d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&mmdata->mmBuffer2, bufSize2));
2762d52a580bSJunchao Zhang   /* inspect the matrices A and B to understand the memory requirement for the next step */
2763d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_workEstimation(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufSize2, mmdata->mmBuffer2));
2764d52a580bSJunchao Zhang   /* ask bufferSize again bytes for external memory */
2765d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, NULL));
2766d52a580bSJunchao Zhang   /* Similar to CUSPARSE, we need both buffers to perform the operations properly!
2767d52a580bSJunchao Zhang      mmdata->mmBuffer2 does not appear anywhere in the compute/copy API
2768d52a580bSJunchao Zhang      it only appears for the workEstimation stuff, but it seems it is needed in compute, so probably the address
2769d52a580bSJunchao Zhang      is stored in the descriptor! What a messy API... */
2770d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&mmdata->mmBuffer, mmdata->mmBufferSize));
2771d52a580bSJunchao Zhang   /* compute the intermediate product of A * B */
2772d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, mmdata->mmBuffer));
2773d52a580bSJunchao Zhang   /* get matrix C non-zero entries C_nnz1 */
2774d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatGetSize(Cmat->matDescr, &C_num_rows1, &C_num_cols1, &C_nnz1));
2775d52a580bSJunchao Zhang   c->nz = (PetscInt)C_nnz1;
2776d52a580bSJunchao Zhang   PetscCall(PetscInfo(C, "Buffer sizes for type %s, result %" PetscInt_FMT " x %" PetscInt_FMT " (k %" PetscInt_FMT ", nzA %" PetscInt_FMT ", nzB %" PetscInt_FMT ", nzC %" PetscInt_FMT ") are: %ldKB %ldKB\n", MatProductTypes[ptype], m, n, k, a->nz, b->nz, c->nz, bufSize2 / 1024,
2777d52a580bSJunchao Zhang                       mmdata->mmBufferSize / 1024));
2778d52a580bSJunchao Zhang   Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2779d52a580bSJunchao Zhang   PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2780d52a580bSJunchao Zhang   Ccsr->values = new THRUSTARRAY(c->nz);
2781d52a580bSJunchao Zhang   PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2782d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCsrSetPointers(Cmat->matDescr, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get()));
2783d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_copy(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2784d52a580bSJunchao Zhang   #endif
2785d52a580bSJunchao Zhang #else
2786d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_HOST));
2787d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrgemmNnz(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr, Bcsr->num_entries,
2788d52a580bSJunchao Zhang                                           Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->row_offsets->data().get(), &cnz));
2789d52a580bSJunchao Zhang   c->nz                = cnz;
2790d52a580bSJunchao Zhang   Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2791d52a580bSJunchao Zhang   PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2792d52a580bSJunchao Zhang   Ccsr->values = new THRUSTARRAY(c->nz);
2793d52a580bSJunchao Zhang   PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2794d52a580bSJunchao Zhang 
2795d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2796d52a580bSJunchao Zhang   /* with the old gemm interface (removed from 11.0 on) we cannot compute the symbolic factorization only.
2797d52a580bSJunchao Zhang       I have tried using the gemm2 interface (alpha * A * B + beta * D), which allows to do symbolic by passing NULL for values, but it seems quite buggy when
2798d52a580bSJunchao Zhang       D is NULL, despite the fact that CUSPARSE documentation claims it is supported! */
2799d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparse_csr_spgemm(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->values->data().get(), Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr,
2800d52a580bSJunchao Zhang                                           Bcsr->num_entries, Bcsr->values->data().get(), Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->values->data().get(), Ccsr->row_offsets->data().get(),
2801d52a580bSJunchao Zhang                                           Ccsr->column_indices->data().get()));
2802d52a580bSJunchao Zhang #endif
2803d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(mmdata->flops));
2804d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
2805d52a580bSJunchao Zhang finalizesym:
2806d52a580bSJunchao Zhang   c->free_a = PETSC_TRUE;
2807d52a580bSJunchao Zhang   PetscCall(PetscShmgetAllocateArray(c->nz, sizeof(PetscInt), (void **)&c->j));
2808d52a580bSJunchao Zhang   PetscCall(PetscShmgetAllocateArray(m + 1, sizeof(PetscInt), (void **)&c->i));
2809d52a580bSJunchao Zhang   c->free_ij = PETSC_TRUE;
2810d52a580bSJunchao Zhang   if (PetscDefined(USE_64BIT_INDICES)) { /* 32 to 64-bit conversion on the GPU and then copy to host (lazy) */
2811d52a580bSJunchao Zhang     PetscInt      *d_i = c->i;
2812d52a580bSJunchao Zhang     THRUSTINTARRAY ii(Ccsr->row_offsets->size());
2813d52a580bSJunchao Zhang     THRUSTINTARRAY jj(Ccsr->column_indices->size());
2814d52a580bSJunchao Zhang     ii = *Ccsr->row_offsets;
2815d52a580bSJunchao Zhang     jj = *Ccsr->column_indices;
2816d52a580bSJunchao Zhang     if (ciscompressed) d_i = c->compressedrow.i;
2817d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(d_i, ii.data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2818d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(c->j, jj.data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2819d52a580bSJunchao Zhang   } else {
2820d52a580bSJunchao Zhang     PetscInt *d_i = c->i;
2821d52a580bSJunchao Zhang     if (ciscompressed) d_i = c->compressedrow.i;
2822d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(d_i, Ccsr->row_offsets->data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2823d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(c->j, Ccsr->column_indices->data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2824d52a580bSJunchao Zhang   }
2825d52a580bSJunchao Zhang   if (ciscompressed) { /* need to expand host row offsets */
2826d52a580bSJunchao Zhang     PetscInt r = 0;
2827d52a580bSJunchao Zhang     c->i[0]    = 0;
2828d52a580bSJunchao Zhang     for (k = 0; k < c->compressedrow.nrows; k++) {
2829d52a580bSJunchao Zhang       const PetscInt next = c->compressedrow.rindex[k];
2830d52a580bSJunchao Zhang       const PetscInt old  = c->compressedrow.i[k];
2831d52a580bSJunchao Zhang       for (; r < next; r++) c->i[r + 1] = old;
2832d52a580bSJunchao Zhang     }
2833d52a580bSJunchao Zhang     for (; r < m; r++) c->i[r + 1] = c->compressedrow.i[c->compressedrow.nrows];
2834d52a580bSJunchao Zhang   }
2835d52a580bSJunchao Zhang   PetscCall(PetscLogGpuToCpu((Ccsr->column_indices->size() + Ccsr->row_offsets->size()) * sizeof(PetscInt)));
2836d52a580bSJunchao Zhang   PetscCall(PetscMalloc1(m, &c->ilen));
2837d52a580bSJunchao Zhang   PetscCall(PetscMalloc1(m, &c->imax));
2838d52a580bSJunchao Zhang   c->maxnz         = c->nz;
2839d52a580bSJunchao Zhang   c->nonzerorowcnt = 0;
2840d52a580bSJunchao Zhang   c->rmax          = 0;
2841d52a580bSJunchao Zhang   for (k = 0; k < m; k++) {
2842d52a580bSJunchao Zhang     const PetscInt nn = c->i[k + 1] - c->i[k];
2843d52a580bSJunchao Zhang     c->ilen[k] = c->imax[k] = nn;
2844d52a580bSJunchao Zhang     c->nonzerorowcnt += (PetscInt)!!nn;
2845d52a580bSJunchao Zhang     c->rmax = PetscMax(c->rmax, nn);
2846d52a580bSJunchao Zhang   }
2847d52a580bSJunchao Zhang   PetscCall(PetscMalloc1(c->nz, &c->a));
2848d52a580bSJunchao Zhang   Ccsr->num_entries = c->nz;
2849d52a580bSJunchao Zhang 
2850d52a580bSJunchao Zhang   C->nonzerostate++;
2851d52a580bSJunchao Zhang   PetscCall(PetscLayoutSetUp(C->rmap));
2852d52a580bSJunchao Zhang   PetscCall(PetscLayoutSetUp(C->cmap));
2853d52a580bSJunchao Zhang   Ccusp->nonzerostate = C->nonzerostate;
2854d52a580bSJunchao Zhang   C->offloadmask      = PETSC_OFFLOAD_UNALLOCATED;
2855d52a580bSJunchao Zhang   C->preallocated     = PETSC_TRUE;
2856d52a580bSJunchao Zhang   C->assembled        = PETSC_FALSE;
2857d52a580bSJunchao Zhang   C->was_assembled    = PETSC_FALSE;
2858d52a580bSJunchao Zhang   if (product->api_user && A->offloadmask == PETSC_OFFLOAD_BOTH && B->offloadmask == PETSC_OFFLOAD_BOTH) { /* flag the matrix C values as computed, so that the numeric phase will only call MatAssembly */
2859d52a580bSJunchao Zhang     mmdata->reusesym = PETSC_TRUE;
2860d52a580bSJunchao Zhang     C->offloadmask   = PETSC_OFFLOAD_GPU;
2861d52a580bSJunchao Zhang   }
2862d52a580bSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE;
2863d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2864d52a580bSJunchao Zhang }
2865d52a580bSJunchao Zhang 
2866d52a580bSJunchao Zhang /* handles sparse or dense B */
MatProductSetFromOptions_SeqAIJHIPSPARSE(Mat mat)2867d52a580bSJunchao Zhang static PetscErrorCode MatProductSetFromOptions_SeqAIJHIPSPARSE(Mat mat)
2868d52a580bSJunchao Zhang {
2869d52a580bSJunchao Zhang   Mat_Product *product = mat->product;
2870d52a580bSJunchao Zhang   PetscBool    isdense = PETSC_FALSE, Biscusp = PETSC_FALSE, Ciscusp = PETSC_TRUE;
2871d52a580bSJunchao Zhang 
2872d52a580bSJunchao Zhang   PetscFunctionBegin;
2873d52a580bSJunchao Zhang   MatCheckProduct(mat, 1);
2874d52a580bSJunchao Zhang   PetscCall(PetscObjectBaseTypeCompare((PetscObject)product->B, MATSEQDENSE, &isdense));
2875d52a580bSJunchao Zhang   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJHIPSPARSE, &Biscusp));
2876d52a580bSJunchao Zhang   if (product->type == MATPRODUCT_ABC) {
2877d52a580bSJunchao Zhang     Ciscusp = PETSC_FALSE;
2878d52a580bSJunchao Zhang     if (!product->C->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJHIPSPARSE, &Ciscusp));
2879d52a580bSJunchao Zhang   }
2880d52a580bSJunchao Zhang   if (Biscusp && Ciscusp) { /* we can always select the CPU backend */
2881d52a580bSJunchao Zhang     PetscBool usecpu = PETSC_FALSE;
2882d52a580bSJunchao Zhang     switch (product->type) {
2883d52a580bSJunchao Zhang     case MATPRODUCT_AB:
2884d52a580bSJunchao Zhang       if (product->api_user) {
2885d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
2886d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
2887d52a580bSJunchao Zhang         PetscOptionsEnd();
2888d52a580bSJunchao Zhang       } else {
2889d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
2890d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
2891d52a580bSJunchao Zhang         PetscOptionsEnd();
2892d52a580bSJunchao Zhang       }
2893d52a580bSJunchao Zhang       break;
2894d52a580bSJunchao Zhang     case MATPRODUCT_AtB:
2895d52a580bSJunchao Zhang       if (product->api_user) {
2896d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
2897d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
2898d52a580bSJunchao Zhang         PetscOptionsEnd();
2899d52a580bSJunchao Zhang       } else {
2900d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
2901d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
2902d52a580bSJunchao Zhang         PetscOptionsEnd();
2903d52a580bSJunchao Zhang       }
2904d52a580bSJunchao Zhang       break;
2905d52a580bSJunchao Zhang     case MATPRODUCT_PtAP:
2906d52a580bSJunchao Zhang       if (product->api_user) {
2907d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
2908d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
2909d52a580bSJunchao Zhang         PetscOptionsEnd();
2910d52a580bSJunchao Zhang       } else {
2911d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
2912d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
2913d52a580bSJunchao Zhang         PetscOptionsEnd();
2914d52a580bSJunchao Zhang       }
2915d52a580bSJunchao Zhang       break;
2916d52a580bSJunchao Zhang     case MATPRODUCT_RARt:
2917d52a580bSJunchao Zhang       if (product->api_user) {
2918d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatRARt", "Mat");
2919d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-matrart_backend_cpu", "Use CPU code", "MatRARt", usecpu, &usecpu, NULL));
2920d52a580bSJunchao Zhang         PetscOptionsEnd();
2921d52a580bSJunchao Zhang       } else {
2922d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_RARt", "Mat");
2923d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatRARt", usecpu, &usecpu, NULL));
2924d52a580bSJunchao Zhang         PetscOptionsEnd();
2925d52a580bSJunchao Zhang       }
2926d52a580bSJunchao Zhang       break;
2927d52a580bSJunchao Zhang     case MATPRODUCT_ABC:
2928d52a580bSJunchao Zhang       if (product->api_user) {
2929d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMatMult", "Mat");
2930d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-matmatmatmult_backend_cpu", "Use CPU code", "MatMatMatMult", usecpu, &usecpu, NULL));
2931d52a580bSJunchao Zhang         PetscOptionsEnd();
2932d52a580bSJunchao Zhang       } else {
2933d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_ABC", "Mat");
2934d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMatMult", usecpu, &usecpu, NULL));
2935d52a580bSJunchao Zhang         PetscOptionsEnd();
2936d52a580bSJunchao Zhang       }
2937d52a580bSJunchao Zhang       break;
2938d52a580bSJunchao Zhang     default:
2939d52a580bSJunchao Zhang       break;
2940d52a580bSJunchao Zhang     }
2941d52a580bSJunchao Zhang     if (usecpu) Biscusp = Ciscusp = PETSC_FALSE;
2942d52a580bSJunchao Zhang   }
2943d52a580bSJunchao Zhang   /* dispatch */
2944d52a580bSJunchao Zhang   if (isdense) {
2945d52a580bSJunchao Zhang     switch (product->type) {
2946d52a580bSJunchao Zhang     case MATPRODUCT_AB:
2947d52a580bSJunchao Zhang     case MATPRODUCT_AtB:
2948d52a580bSJunchao Zhang     case MATPRODUCT_ABt:
2949d52a580bSJunchao Zhang     case MATPRODUCT_PtAP:
2950d52a580bSJunchao Zhang     case MATPRODUCT_RARt:
2951d52a580bSJunchao Zhang       if (product->A->boundtocpu) PetscCall(MatProductSetFromOptions_SeqAIJ_SeqDense(mat));
2952d52a580bSJunchao Zhang       else mat->ops->productsymbolic = MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP;
2953d52a580bSJunchao Zhang       break;
2954d52a580bSJunchao Zhang     case MATPRODUCT_ABC:
2955d52a580bSJunchao Zhang       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
2956d52a580bSJunchao Zhang       break;
2957d52a580bSJunchao Zhang     default:
2958d52a580bSJunchao Zhang       break;
2959d52a580bSJunchao Zhang     }
2960d52a580bSJunchao Zhang   } else if (Biscusp && Ciscusp) {
2961d52a580bSJunchao Zhang     switch (product->type) {
2962d52a580bSJunchao Zhang     case MATPRODUCT_AB:
2963d52a580bSJunchao Zhang     case MATPRODUCT_AtB:
2964d52a580bSJunchao Zhang     case MATPRODUCT_ABt:
2965d52a580bSJunchao Zhang       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE;
2966d52a580bSJunchao Zhang       break;
2967d52a580bSJunchao Zhang     case MATPRODUCT_PtAP:
2968d52a580bSJunchao Zhang     case MATPRODUCT_RARt:
2969d52a580bSJunchao Zhang     case MATPRODUCT_ABC:
2970d52a580bSJunchao Zhang       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
2971d52a580bSJunchao Zhang       break;
2972d52a580bSJunchao Zhang     default:
2973d52a580bSJunchao Zhang       break;
2974d52a580bSJunchao Zhang     }
2975d52a580bSJunchao Zhang   } else PetscCall(MatProductSetFromOptions_SeqAIJ(mat)); /* fallback for AIJ */
2976d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2977d52a580bSJunchao Zhang }
2978d52a580bSJunchao Zhang 
MatMult_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy)2979d52a580bSJunchao Zhang static PetscErrorCode MatMult_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
2980d52a580bSJunchao Zhang {
2981d52a580bSJunchao Zhang   PetscFunctionBegin;
2982d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_FALSE, PETSC_FALSE));
2983d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2984d52a580bSJunchao Zhang }
2985d52a580bSJunchao Zhang 
MatMultAdd_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy,Vec zz)2986d52a580bSJunchao Zhang static PetscErrorCode MatMultAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
2987d52a580bSJunchao Zhang {
2988d52a580bSJunchao Zhang   PetscFunctionBegin;
2989d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_FALSE, PETSC_FALSE));
2990d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2991d52a580bSJunchao Zhang }
2992d52a580bSJunchao Zhang 
MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy)2993d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
2994d52a580bSJunchao Zhang {
2995d52a580bSJunchao Zhang   PetscFunctionBegin;
2996d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_TRUE, PETSC_TRUE));
2997d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2998d52a580bSJunchao Zhang }
2999d52a580bSJunchao Zhang 
MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy,Vec zz)3000d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
3001d52a580bSJunchao Zhang {
3002d52a580bSJunchao Zhang   PetscFunctionBegin;
3003d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_TRUE, PETSC_TRUE));
3004d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3005d52a580bSJunchao Zhang }
3006d52a580bSJunchao Zhang 
MatMultTranspose_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy)3007d52a580bSJunchao Zhang static PetscErrorCode MatMultTranspose_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
3008d52a580bSJunchao Zhang {
3009d52a580bSJunchao Zhang   PetscFunctionBegin;
3010d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_TRUE, PETSC_FALSE));
3011d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3012d52a580bSJunchao Zhang }
3013d52a580bSJunchao Zhang 
ScatterAdd(PetscInt n,PetscInt * idx,const PetscScalar * x,PetscScalar * y)3014d52a580bSJunchao Zhang __global__ static void ScatterAdd(PetscInt n, PetscInt *idx, const PetscScalar *x, PetscScalar *y)
3015d52a580bSJunchao Zhang {
3016d52a580bSJunchao Zhang   int i = blockIdx.x * blockDim.x + threadIdx.x;
3017d52a580bSJunchao Zhang   if (i < n) y[idx[i]] += x[i];
3018d52a580bSJunchao Zhang }
3019d52a580bSJunchao Zhang 
3020d52a580bSJunchao Zhang /* z = op(A) x + y. If trans & !herm, op = ^T; if trans & herm, op = ^H; if !trans, op = no-op */
MatMultAddKernel_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy,Vec zz,PetscBool trans,PetscBool herm)3021d52a580bSJunchao Zhang static PetscErrorCode MatMultAddKernel_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz, PetscBool trans, PetscBool herm)
3022d52a580bSJunchao Zhang {
3023d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a               = (Mat_SeqAIJ *)A->data;
3024d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3025d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *matstruct;
3026d52a580bSJunchao Zhang   PetscScalar                   *xarray, *zarray, *dptr, *beta, *xptr;
3027d52a580bSJunchao Zhang   hipsparseOperation_t           opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
3028d52a580bSJunchao Zhang   PetscBool                      compressed;
3029d52a580bSJunchao Zhang   PetscInt                       nx, ny;
3030d52a580bSJunchao Zhang 
3031d52a580bSJunchao Zhang   PetscFunctionBegin;
3032d52a580bSJunchao Zhang   PetscCheck(!herm || trans, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Hermitian and not transpose not supported");
3033d52a580bSJunchao Zhang   if (!a->nz) {
3034d52a580bSJunchao Zhang     if (yy) PetscCall(VecSeq_HIP::Copy(yy, zz));
3035d52a580bSJunchao Zhang     else PetscCall(VecSeq_HIP::Set(zz, 0));
3036d52a580bSJunchao Zhang     PetscFunctionReturn(PETSC_SUCCESS);
3037d52a580bSJunchao Zhang   }
3038d52a580bSJunchao Zhang   /* The line below is necessary due to the operations that modify the matrix on the CPU (axpy, scale, etc) */
3039d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3040d52a580bSJunchao Zhang   if (!trans) {
3041d52a580bSJunchao Zhang     matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat;
3042d52a580bSJunchao Zhang     PetscCheck(matstruct, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "SeqAIJHIPSPARSE does not have a 'mat' (need to fix)");
3043d52a580bSJunchao Zhang   } else {
3044d52a580bSJunchao Zhang     if (herm || !A->form_explicit_transpose) {
3045d52a580bSJunchao Zhang       opA       = herm ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE : HIPSPARSE_OPERATION_TRANSPOSE;
3046d52a580bSJunchao Zhang       matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat;
3047d52a580bSJunchao Zhang     } else {
3048d52a580bSJunchao Zhang       if (!hipsparsestruct->matTranspose) PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
3049d52a580bSJunchao Zhang       matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->matTranspose;
3050d52a580bSJunchao Zhang     }
3051d52a580bSJunchao Zhang   }
3052d52a580bSJunchao Zhang   /* Does the matrix use compressed rows (i.e., drop zero rows)? */
3053d52a580bSJunchao Zhang   compressed = matstruct->cprowIndices ? PETSC_TRUE : PETSC_FALSE;
3054d52a580bSJunchao Zhang   try {
3055d52a580bSJunchao Zhang     PetscCall(VecHIPGetArrayRead(xx, (const PetscScalar **)&xarray));
3056d52a580bSJunchao Zhang     if (yy == zz) PetscCall(VecHIPGetArray(zz, &zarray)); /* read & write zz, so need to get up-to-date zarray on GPU */
3057d52a580bSJunchao Zhang     else PetscCall(VecHIPGetArrayWrite(zz, &zarray));     /* write zz, so no need to init zarray on GPU */
3058d52a580bSJunchao Zhang 
3059d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
3060d52a580bSJunchao Zhang     if (opA == HIPSPARSE_OPERATION_NON_TRANSPOSE) {
3061d52a580bSJunchao Zhang       /* z = A x + beta y.
3062d52a580bSJunchao Zhang          If A is compressed (with less rows), then Ax is shorter than the full z, so we need a work vector to store Ax.
3063d52a580bSJunchao Zhang          When A is non-compressed, and z = y, we can set beta=1 to compute y = Ax + y in one call.
3064d52a580bSJunchao Zhang       */
3065d52a580bSJunchao Zhang       xptr = xarray;
3066d52a580bSJunchao Zhang       dptr = compressed ? hipsparsestruct->workVector->data().get() : zarray;
3067d52a580bSJunchao Zhang       beta = (yy == zz && !compressed) ? matstruct->beta_one : matstruct->beta_zero;
3068d52a580bSJunchao Zhang       /* Get length of x, y for y=Ax. ny might be shorter than the work vector's allocated length, since the work vector is
3069d52a580bSJunchao Zhang           allocated to accommodate different uses. So we get the length info directly from mat.
3070d52a580bSJunchao Zhang        */
3071d52a580bSJunchao Zhang       if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
3072d52a580bSJunchao Zhang         CsrMatrix *mat = (CsrMatrix *)matstruct->mat;
3073d52a580bSJunchao Zhang         nx             = mat->num_cols;
3074d52a580bSJunchao Zhang         ny             = mat->num_rows;
3075d52a580bSJunchao Zhang       }
3076d52a580bSJunchao Zhang     } else {
3077d52a580bSJunchao Zhang       /* z = A^T x + beta y
3078d52a580bSJunchao Zhang          If A is compressed, then we need a work vector as the shorter version of x to compute A^T x.
3079d52a580bSJunchao Zhang          Note A^Tx is of full length, so we set beta to 1.0 if y exists.
3080d52a580bSJunchao Zhang        */
3081d52a580bSJunchao Zhang       xptr = compressed ? hipsparsestruct->workVector->data().get() : xarray;
3082d52a580bSJunchao Zhang       dptr = zarray;
3083d52a580bSJunchao Zhang       beta = yy ? matstruct->beta_one : matstruct->beta_zero;
3084d52a580bSJunchao Zhang       if (compressed) { /* Scatter x to work vector */
3085d52a580bSJunchao Zhang         thrust::device_ptr<PetscScalar> xarr = thrust::device_pointer_cast(xarray);
3086d52a580bSJunchao Zhang         thrust::for_each(
3087d52a580bSJunchao Zhang #if PetscDefined(HAVE_THRUST_ASYNC)
3088d52a580bSJunchao Zhang           thrust::hip::par.on(PetscDefaultHipStream),
3089d52a580bSJunchao Zhang #endif
3090d52a580bSJunchao Zhang           thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(xarr, matstruct->cprowIndices->begin()))),
3091d52a580bSJunchao Zhang           thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(xarr, matstruct->cprowIndices->begin()))) + matstruct->cprowIndices->size(), VecHIPEqualsReverse());
3092d52a580bSJunchao Zhang       }
3093d52a580bSJunchao Zhang       if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
3094d52a580bSJunchao Zhang         CsrMatrix *mat = (CsrMatrix *)matstruct->mat;
3095d52a580bSJunchao Zhang         nx             = mat->num_rows;
3096d52a580bSJunchao Zhang         ny             = mat->num_cols;
3097d52a580bSJunchao Zhang       }
3098d52a580bSJunchao Zhang     }
3099d52a580bSJunchao Zhang     /* csr_spmv does y = alpha op(A) x + beta y */
3100d52a580bSJunchao Zhang     if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
3101*5a884c48SSatish Balay #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0) && !PETSC_PKG_HIP_VERSION_EQ(7, 2, 0)
3102d52a580bSJunchao Zhang       PetscCheck(opA >= 0 && opA <= 2, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE API on hipsparseOperation_t has changed and PETSc has not been updated accordingly");
3103d52a580bSJunchao Zhang       if (!matstruct->hipSpMV[opA].initialized) { /* built on demand */
3104d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateDnVec(&matstruct->hipSpMV[opA].vecXDescr, nx, xptr, hipsparse_scalartype));
3105d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateDnVec(&matstruct->hipSpMV[opA].vecYDescr, ny, dptr, hipsparse_scalartype));
3106d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSpMV_bufferSize(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->matDescr, matstruct->hipSpMV[opA].vecXDescr, beta, matstruct->hipSpMV[opA].vecYDescr, hipsparse_scalartype, hipsparsestruct->spmvAlg,
3107d52a580bSJunchao Zhang                                                     &matstruct->hipSpMV[opA].spmvBufferSize));
3108d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc(&matstruct->hipSpMV[opA].spmvBuffer, matstruct->hipSpMV[opA].spmvBufferSize));
3109d52a580bSJunchao Zhang         matstruct->hipSpMV[opA].initialized = PETSC_TRUE;
3110d52a580bSJunchao Zhang       } else {
3111d52a580bSJunchao Zhang         /* x, y's value pointers might change between calls, but their shape is kept, so we just update pointers */
3112d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDnVecSetValues(matstruct->hipSpMV[opA].vecXDescr, xptr));
3113d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDnVecSetValues(matstruct->hipSpMV[opA].vecYDescr, dptr));
3114d52a580bSJunchao Zhang       }
3115d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSpMV(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->matDescr, /* built in MatSeqAIJHIPSPARSECopyToGPU() or MatSeqAIJHIPSPARSEFormExplicitTranspose() */
3116d52a580bSJunchao Zhang                                        matstruct->hipSpMV[opA].vecXDescr, beta, matstruct->hipSpMV[opA].vecYDescr, hipsparse_scalartype, hipsparsestruct->spmvAlg, matstruct->hipSpMV[opA].spmvBuffer));
3117d52a580bSJunchao Zhang #else
3118d52a580bSJunchao Zhang       CsrMatrix *mat = (CsrMatrix *)matstruct->mat;
3119*5a884c48SSatish Balay       nx             = mat->num_rows; /* nx,ny are set before the #if block, set them again to avoid set-but-not-used warning */
3120*5a884c48SSatish Balay       ny             = mat->num_cols;
3121*5a884c48SSatish Balay       PetscCallHIPSPARSE(hipsparse_csr_spmv(hipsparsestruct->handle, opA, nx, ny, mat->num_entries, matstruct->alpha_one, matstruct->descr, mat->values->data().get(), mat->row_offsets->data().get(), mat->column_indices->data().get(), xptr, beta, dptr));
3122d52a580bSJunchao Zhang #endif
3123d52a580bSJunchao Zhang     } else {
3124d52a580bSJunchao Zhang       if (hipsparsestruct->nrows) {
3125d52a580bSJunchao Zhang         hipsparseHybMat_t hybMat = (hipsparseHybMat_t)matstruct->mat;
3126d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparse_hyb_spmv(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->descr, hybMat, xptr, beta, dptr));
3127d52a580bSJunchao Zhang       }
3128d52a580bSJunchao Zhang     }
3129d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
3130d52a580bSJunchao Zhang 
3131d52a580bSJunchao Zhang     if (opA == HIPSPARSE_OPERATION_NON_TRANSPOSE) {
3132d52a580bSJunchao Zhang       if (yy) {                                     /* MatMultAdd: zz = A*xx + yy */
3133d52a580bSJunchao Zhang         if (compressed) {                           /* A is compressed. We first copy yy to zz, then ScatterAdd the work vector to zz */
3134d52a580bSJunchao Zhang           PetscCall(VecSeq_HIP::Copy(yy, zz));      /* zz = yy */
3135d52a580bSJunchao Zhang         } else if (zz != yy) {                      /* A is not compressed. zz already contains A*xx, and we just need to add yy */
3136d52a580bSJunchao Zhang           PetscCall(VecSeq_HIP::AXPY(zz, 1.0, yy)); /* zz += yy */
3137d52a580bSJunchao Zhang         }
3138d52a580bSJunchao Zhang       } else if (compressed) { /* MatMult: zz = A*xx. A is compressed, so we zero zz first, then ScatterAdd the work vector to zz */
3139d52a580bSJunchao Zhang         PetscCall(VecSeq_HIP::Set(zz, 0));
3140d52a580bSJunchao Zhang       }
3141d52a580bSJunchao Zhang 
3142d52a580bSJunchao Zhang       /* ScatterAdd the result from work vector into the full vector when A is compressed */
3143d52a580bSJunchao Zhang       if (compressed) {
3144d52a580bSJunchao Zhang         PetscCall(PetscLogGpuTimeBegin());
3145d52a580bSJunchao Zhang         /* I wanted to make this for_each asynchronous but failed. thrust::async::for_each() returns an event (internally registered)
3146d52a580bSJunchao Zhang            and in the destructor of the scope, it will call hipStreamSynchronize() on this stream. One has to store all events to
3147d52a580bSJunchao Zhang            prevent that. So I just add a ScatterAdd kernel.
3148d52a580bSJunchao Zhang          */
3149d52a580bSJunchao Zhang #if 0
3150d52a580bSJunchao Zhang         thrust::device_ptr<PetscScalar> zptr = thrust::device_pointer_cast(zarray);
3151d52a580bSJunchao Zhang         thrust::async::for_each(thrust::hip::par.on(hipsparsestruct->stream),
3152d52a580bSJunchao Zhang                          thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(zptr, matstruct->cprowIndices->begin()))),
3153d52a580bSJunchao Zhang                          thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(zptr, matstruct->cprowIndices->begin()))) + matstruct->cprowIndices->size(),
3154d52a580bSJunchao Zhang                          VecHIPPlusEquals());
3155d52a580bSJunchao Zhang #else
3156d52a580bSJunchao Zhang         PetscInt n = matstruct->cprowIndices->size();
3157d52a580bSJunchao Zhang         hipLaunchKernelGGL(ScatterAdd, dim3((n + 255) / 256), dim3(256), 0, PetscDefaultHipStream, n, matstruct->cprowIndices->data().get(), hipsparsestruct->workVector->data().get(), zarray);
3158d52a580bSJunchao Zhang #endif
3159d52a580bSJunchao Zhang         PetscCall(PetscLogGpuTimeEnd());
3160d52a580bSJunchao Zhang       }
3161d52a580bSJunchao Zhang     } else {
3162d52a580bSJunchao Zhang       if (yy && yy != zz) PetscCall(VecSeq_HIP::AXPY(zz, 1.0, yy)); /* zz += yy */
3163d52a580bSJunchao Zhang     }
3164d52a580bSJunchao Zhang     PetscCall(VecHIPRestoreArrayRead(xx, (const PetscScalar **)&xarray));
3165d52a580bSJunchao Zhang     if (yy == zz) PetscCall(VecHIPRestoreArray(zz, &zarray));
3166d52a580bSJunchao Zhang     else PetscCall(VecHIPRestoreArrayWrite(zz, &zarray));
3167d52a580bSJunchao Zhang   } catch (char *ex) {
3168d52a580bSJunchao Zhang     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
3169d52a580bSJunchao Zhang   }
3170d52a580bSJunchao Zhang   if (yy) PetscCall(PetscLogGpuFlops(2.0 * a->nz));
3171d52a580bSJunchao Zhang   else PetscCall(PetscLogGpuFlops(2.0 * a->nz - a->nonzerorowcnt));
3172d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3173d52a580bSJunchao Zhang }
3174d52a580bSJunchao Zhang 
MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat A,Vec xx,Vec yy,Vec zz)3175d52a580bSJunchao Zhang static PetscErrorCode MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
3176d52a580bSJunchao Zhang {
3177d52a580bSJunchao Zhang   PetscFunctionBegin;
3178d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_TRUE, PETSC_FALSE));
3179d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3180d52a580bSJunchao Zhang }
3181d52a580bSJunchao Zhang 
MatAssemblyEnd_SeqAIJHIPSPARSE(Mat A,MatAssemblyType mode)3182d52a580bSJunchao Zhang static PetscErrorCode MatAssemblyEnd_SeqAIJHIPSPARSE(Mat A, MatAssemblyType mode)
3183d52a580bSJunchao Zhang {
3184d52a580bSJunchao Zhang   PetscFunctionBegin;
3185d52a580bSJunchao Zhang   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
3186d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3187d52a580bSJunchao Zhang }
3188d52a580bSJunchao Zhang 
3189d52a580bSJunchao Zhang /*@
3190d52a580bSJunchao Zhang   MatCreateSeqAIJHIPSPARSE - Creates a sparse matrix in `MATAIJHIPSPARSE` (compressed row) format.
3191d52a580bSJunchao Zhang   This matrix will ultimately pushed down to AMD GPUs and use the HIPSPARSE library for calculations.
3192d52a580bSJunchao Zhang 
3193d52a580bSJunchao Zhang   Collective
3194d52a580bSJunchao Zhang 
3195d52a580bSJunchao Zhang   Input Parameters:
3196d52a580bSJunchao Zhang + comm - MPI communicator, set to `PETSC_COMM_SELF`
3197d52a580bSJunchao Zhang . m    - number of rows
3198d52a580bSJunchao Zhang . n    - number of columns
3199d52a580bSJunchao Zhang . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is set
3200d52a580bSJunchao Zhang - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
3201d52a580bSJunchao Zhang 
3202d52a580bSJunchao Zhang   Output Parameter:
3203d52a580bSJunchao Zhang . A - the matrix
3204d52a580bSJunchao Zhang 
3205d52a580bSJunchao Zhang   Level: intermediate
3206d52a580bSJunchao Zhang 
3207d52a580bSJunchao Zhang   Notes:
3208d52a580bSJunchao Zhang   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
3209d52a580bSJunchao Zhang   `MatXXXXSetPreallocation()` paradgm instead of this routine directly.
3210d52a580bSJunchao Zhang   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation`]
3211d52a580bSJunchao Zhang 
3212d52a580bSJunchao Zhang   The AIJ format (compressed row storage), is fully compatible with standard Fortran
3213d52a580bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
3214d52a580bSJunchao Zhang   either one (as in Fortran) or zero.
3215d52a580bSJunchao Zhang 
3216d52a580bSJunchao Zhang   Specify the preallocated storage with either `nz` or `nnz` (not both).
3217d52a580bSJunchao Zhang   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
3218d52a580bSJunchao Zhang   allocation.
3219d52a580bSJunchao Zhang 
3220d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATSEQAIJHIPSPARSE`, `MATAIJHIPSPARSE`
3221d52a580bSJunchao Zhang @*/
MatCreateSeqAIJHIPSPARSE(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt nz,const PetscInt nnz[],Mat * A)3222d52a580bSJunchao Zhang PetscErrorCode MatCreateSeqAIJHIPSPARSE(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
3223d52a580bSJunchao Zhang {
3224d52a580bSJunchao Zhang   PetscFunctionBegin;
3225d52a580bSJunchao Zhang   PetscCall(MatCreate(comm, A));
3226d52a580bSJunchao Zhang   PetscCall(MatSetSizes(*A, m, n, m, n));
3227d52a580bSJunchao Zhang   PetscCall(MatSetType(*A, MATSEQAIJHIPSPARSE));
3228d52a580bSJunchao Zhang   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
3229d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3230d52a580bSJunchao Zhang }
3231d52a580bSJunchao Zhang 
MatDestroy_SeqAIJHIPSPARSE(Mat A)3232d52a580bSJunchao Zhang static PetscErrorCode MatDestroy_SeqAIJHIPSPARSE(Mat A)
3233d52a580bSJunchao Zhang {
3234d52a580bSJunchao Zhang   PetscFunctionBegin;
3235d52a580bSJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) PetscCall(MatSeqAIJHIPSPARSE_Destroy(A));
3236d52a580bSJunchao Zhang   else PetscCall(MatSeqAIJHIPSPARSETriFactors_Destroy((Mat_SeqAIJHIPSPARSETriFactors **)&A->spptr));
3237d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", NULL));
3238d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetFormat_C", NULL));
3239d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetUseCPUSolve_C", NULL));
3240d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", NULL));
3241d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", NULL));
3242d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", NULL));
3243d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
3244d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
3245d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
3246d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijhipsparse_hypre_C", NULL));
3247d52a580bSJunchao Zhang   PetscCall(MatDestroy_SeqAIJ(A));
3248d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3249d52a580bSJunchao Zhang }
3250d52a580bSJunchao Zhang 
MatDuplicate_SeqAIJHIPSPARSE(Mat A,MatDuplicateOption cpvalues,Mat * B)3251d52a580bSJunchao Zhang static PetscErrorCode MatDuplicate_SeqAIJHIPSPARSE(Mat A, MatDuplicateOption cpvalues, Mat *B)
3252d52a580bSJunchao Zhang {
3253d52a580bSJunchao Zhang   PetscFunctionBegin;
3254d52a580bSJunchao Zhang   PetscCall(MatDuplicate_SeqAIJ(A, cpvalues, B));
3255d52a580bSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJHIPSPARSE(*B, MATSEQAIJHIPSPARSE, MAT_INPLACE_MATRIX, B));
3256d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3257d52a580bSJunchao Zhang }
3258d52a580bSJunchao Zhang 
MatAXPY_SeqAIJHIPSPARSE(Mat Y,PetscScalar a,Mat X,MatStructure str)3259d52a580bSJunchao Zhang static PetscErrorCode MatAXPY_SeqAIJHIPSPARSE(Mat Y, PetscScalar a, Mat X, MatStructure str)
3260d52a580bSJunchao Zhang {
3261d52a580bSJunchao Zhang   Mat_SeqAIJ          *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
3262d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cy;
3263d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cx;
3264d52a580bSJunchao Zhang   PetscScalar         *ay;
3265d52a580bSJunchao Zhang   const PetscScalar   *ax;
3266d52a580bSJunchao Zhang   CsrMatrix           *csry, *csrx;
3267d52a580bSJunchao Zhang 
3268d52a580bSJunchao Zhang   PetscFunctionBegin;
3269d52a580bSJunchao Zhang   cy = (Mat_SeqAIJHIPSPARSE *)Y->spptr;
3270d52a580bSJunchao Zhang   cx = (Mat_SeqAIJHIPSPARSE *)X->spptr;
3271d52a580bSJunchao Zhang   if (X->ops->axpy != Y->ops->axpy) {
3272d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(Y, PETSC_FALSE));
3273d52a580bSJunchao Zhang     PetscCall(MatAXPY_SeqAIJ(Y, a, X, str));
3274d52a580bSJunchao Zhang     PetscFunctionReturn(PETSC_SUCCESS);
3275d52a580bSJunchao Zhang   }
3276d52a580bSJunchao Zhang   /* if we are here, it means both matrices are bound to GPU */
3277d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(Y));
3278d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(X));
3279d52a580bSJunchao Zhang   PetscCheck(cy->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)Y), PETSC_ERR_GPU, "only MAT_HIPSPARSE_CSR supported");
3280d52a580bSJunchao Zhang   PetscCheck(cx->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)X), PETSC_ERR_GPU, "only MAT_HIPSPARSE_CSR supported");
3281d52a580bSJunchao Zhang   csry = (CsrMatrix *)cy->mat->mat;
3282d52a580bSJunchao Zhang   csrx = (CsrMatrix *)cx->mat->mat;
3283d52a580bSJunchao Zhang   /* see if we can turn this into a hipblas axpy */
3284d52a580bSJunchao Zhang   if (str != SAME_NONZERO_PATTERN && x->nz == y->nz && !x->compressedrow.use && !y->compressedrow.use) {
3285d52a580bSJunchao Zhang     bool eq = thrust::equal(thrust::device, csry->row_offsets->begin(), csry->row_offsets->end(), csrx->row_offsets->begin());
3286d52a580bSJunchao Zhang     if (eq) eq = thrust::equal(thrust::device, csry->column_indices->begin(), csry->column_indices->end(), csrx->column_indices->begin());
3287d52a580bSJunchao Zhang     if (eq) str = SAME_NONZERO_PATTERN;
3288d52a580bSJunchao Zhang   }
3289d52a580bSJunchao Zhang   /* spgeam is buggy with one column */
3290d52a580bSJunchao Zhang   if (Y->cmap->n == 1 && str != SAME_NONZERO_PATTERN) str = DIFFERENT_NONZERO_PATTERN;
3291d52a580bSJunchao Zhang   if (str == SUBSET_NONZERO_PATTERN) {
3292d52a580bSJunchao Zhang     PetscScalar b = 1.0;
3293d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
3294d52a580bSJunchao Zhang     size_t bufferSize;
3295d52a580bSJunchao Zhang     void  *buffer;
3296d52a580bSJunchao Zhang #endif
3297d52a580bSJunchao Zhang 
3298d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(X, &ax));
3299d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay));
3300d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetPointerMode(cy->handle, HIPSPARSE_POINTER_MODE_HOST));
3301d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
3302d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparse_csr_spgeam_bufferSize(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(),
3303d52a580bSJunchao Zhang                                                        csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get(), &bufferSize));
3304d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc(&buffer, bufferSize));
3305d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
3306d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparse_csr_spgeam(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(),
3307d52a580bSJunchao Zhang                                             csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get(), buffer));
3308d52a580bSJunchao Zhang     PetscCall(PetscLogGpuFlops(x->nz + y->nz));
3309d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
3310d52a580bSJunchao Zhang     PetscCallHIP(hipFree(buffer));
3311d52a580bSJunchao Zhang #else
3312d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
3313d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparse_csr_spgeam(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(),
3314d52a580bSJunchao Zhang                                             csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get()));
3315d52a580bSJunchao Zhang     PetscCall(PetscLogGpuFlops(x->nz + y->nz));
3316d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
3317d52a580bSJunchao Zhang #endif
3318d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetPointerMode(cy->handle, HIPSPARSE_POINTER_MODE_DEVICE));
3319d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(X, &ax));
3320d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay));
3321d52a580bSJunchao Zhang   } else if (str == SAME_NONZERO_PATTERN) {
3322d52a580bSJunchao Zhang     hipblasHandle_t hipblasv2handle;
3323d52a580bSJunchao Zhang     PetscBLASInt    one = 1, bnz = 1;
3324d52a580bSJunchao Zhang 
3325d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(X, &ax));
3326d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay));
3327d52a580bSJunchao Zhang     PetscCall(PetscHIPBLASGetHandle(&hipblasv2handle));
3328d52a580bSJunchao Zhang     PetscCall(PetscBLASIntCast(x->nz, &bnz));
3329d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
3330d52a580bSJunchao Zhang     PetscCallHIPBLAS(hipblasXaxpy(hipblasv2handle, bnz, &a, ax, one, ay, one));
3331d52a580bSJunchao Zhang     PetscCall(PetscLogGpuFlops(2.0 * bnz));
3332d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
3333d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(X, &ax));
3334d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay));
3335d52a580bSJunchao Zhang   } else {
3336d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(Y, PETSC_FALSE));
3337d52a580bSJunchao Zhang     PetscCall(MatAXPY_SeqAIJ(Y, a, X, str));
3338d52a580bSJunchao Zhang   }
3339d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3340d52a580bSJunchao Zhang }
3341d52a580bSJunchao Zhang 
MatScale_SeqAIJHIPSPARSE(Mat Y,PetscScalar a)3342d52a580bSJunchao Zhang static PetscErrorCode MatScale_SeqAIJHIPSPARSE(Mat Y, PetscScalar a)
3343d52a580bSJunchao Zhang {
3344d52a580bSJunchao Zhang   Mat_SeqAIJ     *y = (Mat_SeqAIJ *)Y->data;
3345d52a580bSJunchao Zhang   PetscScalar    *ay;
3346d52a580bSJunchao Zhang   hipblasHandle_t hipblasv2handle;
3347d52a580bSJunchao Zhang   PetscBLASInt    one = 1, bnz = 1;
3348d52a580bSJunchao Zhang 
3349d52a580bSJunchao Zhang   PetscFunctionBegin;
3350d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay));
3351d52a580bSJunchao Zhang   PetscCall(PetscHIPBLASGetHandle(&hipblasv2handle));
3352d52a580bSJunchao Zhang   PetscCall(PetscBLASIntCast(y->nz, &bnz));
3353d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
3354d52a580bSJunchao Zhang   PetscCallHIPBLAS(hipblasXscal(hipblasv2handle, bnz, &a, ay, one));
3355d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(bnz));
3356d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
3357d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay));
3358d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3359d52a580bSJunchao Zhang }
3360d52a580bSJunchao Zhang 
MatZeroEntries_SeqAIJHIPSPARSE(Mat A)3361d52a580bSJunchao Zhang static PetscErrorCode MatZeroEntries_SeqAIJHIPSPARSE(Mat A)
3362d52a580bSJunchao Zhang {
3363d52a580bSJunchao Zhang   PetscBool   both = PETSC_FALSE;
3364d52a580bSJunchao Zhang   Mat_SeqAIJ *a    = (Mat_SeqAIJ *)A->data;
3365d52a580bSJunchao Zhang 
3366d52a580bSJunchao Zhang   PetscFunctionBegin;
3367d52a580bSJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
3368d52a580bSJunchao Zhang     Mat_SeqAIJHIPSPARSE *spptr = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3369d52a580bSJunchao Zhang     if (spptr->mat) {
3370d52a580bSJunchao Zhang       CsrMatrix *matrix = (CsrMatrix *)spptr->mat->mat;
3371d52a580bSJunchao Zhang       if (matrix->values) {
3372d52a580bSJunchao Zhang         both = PETSC_TRUE;
3373d52a580bSJunchao Zhang         thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), 0.);
3374d52a580bSJunchao Zhang       }
3375d52a580bSJunchao Zhang     }
3376d52a580bSJunchao Zhang     if (spptr->matTranspose) {
3377d52a580bSJunchao Zhang       CsrMatrix *matrix = (CsrMatrix *)spptr->matTranspose->mat;
3378d52a580bSJunchao Zhang       if (matrix->values) thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), 0.);
3379d52a580bSJunchao Zhang     }
3380d52a580bSJunchao Zhang   }
3381d52a580bSJunchao Zhang   //PetscCall(MatZeroEntries_SeqAIJ(A));
3382d52a580bSJunchao Zhang   PetscCall(PetscArrayzero(a->a, a->i[A->rmap->n]));
3383d52a580bSJunchao Zhang   if (both) A->offloadmask = PETSC_OFFLOAD_BOTH;
3384d52a580bSJunchao Zhang   else A->offloadmask = PETSC_OFFLOAD_CPU;
3385d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3386d52a580bSJunchao Zhang }
3387d52a580bSJunchao Zhang 
MatGetCurrentMemType_SeqAIJHIPSPARSE(PETSC_UNUSED Mat A,PetscMemType * m)3388d52a580bSJunchao Zhang static PetscErrorCode MatGetCurrentMemType_SeqAIJHIPSPARSE(PETSC_UNUSED Mat A, PetscMemType *m)
3389d52a580bSJunchao Zhang {
3390d52a580bSJunchao Zhang   PetscFunctionBegin;
3391d52a580bSJunchao Zhang   *m = PETSC_MEMTYPE_HIP;
3392d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3393d52a580bSJunchao Zhang }
3394d52a580bSJunchao Zhang 
MatBindToCPU_SeqAIJHIPSPARSE(Mat A,PetscBool flg)3395d52a580bSJunchao Zhang static PetscErrorCode MatBindToCPU_SeqAIJHIPSPARSE(Mat A, PetscBool flg)
3396d52a580bSJunchao Zhang {
3397d52a580bSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
3398d52a580bSJunchao Zhang 
3399d52a580bSJunchao Zhang   PetscFunctionBegin;
3400d52a580bSJunchao Zhang   if (A->factortype != MAT_FACTOR_NONE) {
3401d52a580bSJunchao Zhang     A->boundtocpu = flg;
3402d52a580bSJunchao Zhang     PetscFunctionReturn(PETSC_SUCCESS);
3403d52a580bSJunchao Zhang   }
3404d52a580bSJunchao Zhang   if (flg) {
3405d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
3406d52a580bSJunchao Zhang 
3407d52a580bSJunchao Zhang     A->ops->scale                     = MatScale_SeqAIJ;
3408d52a580bSJunchao Zhang     A->ops->axpy                      = MatAXPY_SeqAIJ;
3409d52a580bSJunchao Zhang     A->ops->zeroentries               = MatZeroEntries_SeqAIJ;
3410d52a580bSJunchao Zhang     A->ops->mult                      = MatMult_SeqAIJ;
3411d52a580bSJunchao Zhang     A->ops->multadd                   = MatMultAdd_SeqAIJ;
3412d52a580bSJunchao Zhang     A->ops->multtranspose             = MatMultTranspose_SeqAIJ;
3413d52a580bSJunchao Zhang     A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJ;
3414d52a580bSJunchao Zhang     A->ops->multhermitiantranspose    = NULL;
3415d52a580bSJunchao Zhang     A->ops->multhermitiantransposeadd = NULL;
3416d52a580bSJunchao Zhang     A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJ;
3417d52a580bSJunchao Zhang     A->ops->getcurrentmemtype         = NULL;
3418d52a580bSJunchao Zhang     PetscCall(PetscMemzero(a->ops, sizeof(Mat_SeqAIJOps)));
3419d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", NULL));
3420d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", NULL));
3421d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", NULL));
3422d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
3423d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
3424d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", NULL));
3425d52a580bSJunchao Zhang   } else {
3426d52a580bSJunchao Zhang     A->ops->scale                     = MatScale_SeqAIJHIPSPARSE;
3427d52a580bSJunchao Zhang     A->ops->axpy                      = MatAXPY_SeqAIJHIPSPARSE;
3428d52a580bSJunchao Zhang     A->ops->zeroentries               = MatZeroEntries_SeqAIJHIPSPARSE;
3429d52a580bSJunchao Zhang     A->ops->mult                      = MatMult_SeqAIJHIPSPARSE;
3430d52a580bSJunchao Zhang     A->ops->multadd                   = MatMultAdd_SeqAIJHIPSPARSE;
3431d52a580bSJunchao Zhang     A->ops->multtranspose             = MatMultTranspose_SeqAIJHIPSPARSE;
3432d52a580bSJunchao Zhang     A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJHIPSPARSE;
3433d52a580bSJunchao Zhang     A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJHIPSPARSE;
3434d52a580bSJunchao Zhang     A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE;
3435d52a580bSJunchao Zhang     A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJHIPSPARSE;
3436d52a580bSJunchao Zhang     A->ops->getcurrentmemtype         = MatGetCurrentMemType_SeqAIJHIPSPARSE;
3437d52a580bSJunchao Zhang     a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJHIPSPARSE;
3438d52a580bSJunchao Zhang     a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJHIPSPARSE;
3439d52a580bSJunchao Zhang     a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE;
3440d52a580bSJunchao Zhang     a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE;
3441d52a580bSJunchao Zhang     a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE;
3442d52a580bSJunchao Zhang     a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE;
3443d52a580bSJunchao Zhang     a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE;
3444d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", MatSeqAIJCopySubArray_SeqAIJHIPSPARSE));
3445d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", MatProductSetFromOptions_SeqAIJHIPSPARSE));
3446d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", MatProductSetFromOptions_SeqAIJHIPSPARSE));
3447d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJHIPSPARSE));
3448d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJHIPSPARSE));
3449d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", MatProductSetFromOptions_SeqAIJHIPSPARSE));
3450d52a580bSJunchao Zhang   }
3451d52a580bSJunchao Zhang   A->boundtocpu = flg;
3452d52a580bSJunchao Zhang   if (flg && a->inode.size_csr) a->inode.use = PETSC_TRUE;
3453d52a580bSJunchao Zhang   else a->inode.use = PETSC_FALSE;
3454d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3455d52a580bSJunchao Zhang }
3456d52a580bSJunchao Zhang 
MatConvert_SeqAIJ_SeqAIJHIPSPARSE(Mat A,MatType mtype,MatReuse reuse,Mat * newmat)3457d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJHIPSPARSE(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
3458d52a580bSJunchao Zhang {
3459d52a580bSJunchao Zhang   Mat B;
3460d52a580bSJunchao Zhang 
3461d52a580bSJunchao Zhang   PetscFunctionBegin;
3462d52a580bSJunchao Zhang   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_HIP)); /* first use of HIPSPARSE may be via MatConvert */
3463d52a580bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
3464d52a580bSJunchao Zhang     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
3465d52a580bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
3466d52a580bSJunchao Zhang     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
3467d52a580bSJunchao Zhang   }
3468d52a580bSJunchao Zhang   B = *newmat;
3469d52a580bSJunchao Zhang   PetscCall(PetscFree(B->defaultvectype));
3470d52a580bSJunchao Zhang   PetscCall(PetscStrallocpy(VECHIP, &B->defaultvectype));
3471d52a580bSJunchao Zhang   if (reuse != MAT_REUSE_MATRIX && !B->spptr) {
3472d52a580bSJunchao Zhang     if (B->factortype == MAT_FACTOR_NONE) {
3473d52a580bSJunchao Zhang       Mat_SeqAIJHIPSPARSE *spptr;
3474d52a580bSJunchao Zhang       PetscCall(PetscNew(&spptr));
3475d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreate(&spptr->handle));
3476d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSetStream(spptr->handle, PetscDefaultHipStream));
3477d52a580bSJunchao Zhang       spptr->format = MAT_HIPSPARSE_CSR;
3478d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
3479d52a580bSJunchao Zhang       spptr->spmvAlg = HIPSPARSE_SPMV_CSR_ALG1;
3480d52a580bSJunchao Zhang #else
3481d52a580bSJunchao Zhang       spptr->spmvAlg = HIPSPARSE_CSRMV_ALG1; /* default, since we only support csr */
3482d52a580bSJunchao Zhang #endif
3483d52a580bSJunchao Zhang       spptr->spmmAlg = HIPSPARSE_SPMM_CSR_ALG1; /* default, only support column-major dense matrix B */
3484d52a580bSJunchao Zhang       //spptr->csr2cscAlg = HIPSPARSE_CSR2CSC_ALG1;
3485d52a580bSJunchao Zhang 
3486d52a580bSJunchao Zhang       B->spptr = spptr;
3487d52a580bSJunchao Zhang     } else {
3488d52a580bSJunchao Zhang       Mat_SeqAIJHIPSPARSETriFactors *spptr;
3489d52a580bSJunchao Zhang 
3490d52a580bSJunchao Zhang       PetscCall(PetscNew(&spptr));
3491d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreate(&spptr->handle));
3492d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSetStream(spptr->handle, PetscDefaultHipStream));
3493d52a580bSJunchao Zhang       B->spptr = spptr;
3494d52a580bSJunchao Zhang     }
3495d52a580bSJunchao Zhang     B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
3496d52a580bSJunchao Zhang   }
3497d52a580bSJunchao Zhang   B->ops->assemblyend       = MatAssemblyEnd_SeqAIJHIPSPARSE;
3498d52a580bSJunchao Zhang   B->ops->destroy           = MatDestroy_SeqAIJHIPSPARSE;
3499d52a580bSJunchao Zhang   B->ops->setoption         = MatSetOption_SeqAIJHIPSPARSE;
3500d52a580bSJunchao Zhang   B->ops->setfromoptions    = MatSetFromOptions_SeqAIJHIPSPARSE;
3501d52a580bSJunchao Zhang   B->ops->bindtocpu         = MatBindToCPU_SeqAIJHIPSPARSE;
3502d52a580bSJunchao Zhang   B->ops->duplicate         = MatDuplicate_SeqAIJHIPSPARSE;
3503d52a580bSJunchao Zhang   B->ops->getcurrentmemtype = MatGetCurrentMemType_SeqAIJHIPSPARSE;
3504d52a580bSJunchao Zhang 
3505d52a580bSJunchao Zhang   PetscCall(MatBindToCPU_SeqAIJHIPSPARSE(B, PETSC_FALSE));
3506d52a580bSJunchao Zhang   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATSEQAIJHIPSPARSE));
3507d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatHIPSPARSESetFormat_C", MatHIPSPARSESetFormat_SeqAIJHIPSPARSE));
3508d52a580bSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
3509d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_seqaijhipsparse_hypre_C", MatConvert_AIJ_HYPRE));
3510d52a580bSJunchao Zhang #endif
3511d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatHIPSPARSESetUseCPUSolve_C", MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE));
3512d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3513d52a580bSJunchao Zhang }
3514d52a580bSJunchao Zhang 
MatCreate_SeqAIJHIPSPARSE(Mat B)3515d52a580bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJHIPSPARSE(Mat B)
3516d52a580bSJunchao Zhang {
3517d52a580bSJunchao Zhang   PetscFunctionBegin;
3518d52a580bSJunchao Zhang   PetscCall(MatCreate_SeqAIJ(B));
3519d52a580bSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJHIPSPARSE(B, MATSEQAIJHIPSPARSE, MAT_INPLACE_MATRIX, &B));
3520d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3521d52a580bSJunchao Zhang }
3522d52a580bSJunchao Zhang 
3523d52a580bSJunchao Zhang /*MC
3524d52a580bSJunchao Zhang    MATSEQAIJHIPSPARSE - MATAIJHIPSPARSE = "(seq)aijhipsparse" - A matrix type to be used for sparse matrices on AMD GPUs
3525d52a580bSJunchao Zhang 
   A matrix type whose data resides on AMD (or NVIDIA) GPUs. These matrices can be stored in
   CSR, ELL, or hybrid (HYB) format.
   All matrix calculations are performed on the GPU using the HIPSPARSE library.
3529d52a580bSJunchao Zhang 
3530d52a580bSJunchao Zhang    Options Database Keys:
3531d52a580bSJunchao Zhang +  -mat_type aijhipsparse - sets the matrix type to `MATSEQAIJHIPSPARSE`
3532d52a580bSJunchao Zhang .  -mat_hipsparse_storage_format csr - sets the storage format of matrices (for `MatMult()` and factors in `MatSolve()`).
3533d52a580bSJunchao Zhang                                        Other options include ell (ellpack) or hyb (hybrid).
.  -mat_hipsparse_mult_storage_format csr - sets the storage format of matrices (for `MatMult()`). Other options include ell (ellpack) or hyb (hybrid).
3535d52a580bSJunchao Zhang -  -mat_hipsparse_use_cpu_solve - Do `MatSolve()` on the CPU
3536d52a580bSJunchao Zhang 
3537d52a580bSJunchao Zhang   Level: beginner
3538d52a580bSJunchao Zhang 
3539d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJHIPSPARSE()`, `MATAIJHIPSPARSE`, `MatCreateAIJHIPSPARSE()`, `MatHIPSPARSESetFormat()`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
3540d52a580bSJunchao Zhang M*/
3541d52a580bSJunchao Zhang 
MatSolverTypeRegister_HIPSPARSE(void)3542d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_HIPSPARSE(void)
3543d52a580bSJunchao Zhang {
3544d52a580bSJunchao Zhang   PetscFunctionBegin;
3545d52a580bSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_LU, MatGetFactor_seqaijhipsparse_hipsparse));
3546d52a580bSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_CHOLESKY, MatGetFactor_seqaijhipsparse_hipsparse));
3547d52a580bSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_ILU, MatGetFactor_seqaijhipsparse_hipsparse));
3548d52a580bSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_ICC, MatGetFactor_seqaijhipsparse_hipsparse));
3549d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3550d52a580bSJunchao Zhang }
3551d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSE_Destroy(Mat mat)3552d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSE_Destroy(Mat mat)
3553d52a580bSJunchao Zhang {
3554d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = static_cast<Mat_SeqAIJHIPSPARSE *>(mat->spptr);
3555d52a580bSJunchao Zhang 
3556d52a580bSJunchao Zhang   PetscFunctionBegin;
3557d52a580bSJunchao Zhang   if (cusp) {
3558d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->mat, cusp->format));
3559d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->matTranspose, cusp->format));
3560d52a580bSJunchao Zhang     delete cusp->workVector;
3561d52a580bSJunchao Zhang     delete cusp->rowoffsets_gpu;
3562d52a580bSJunchao Zhang     delete cusp->csr2csc_i;
3563d52a580bSJunchao Zhang     delete cusp->coords;
3564d52a580bSJunchao Zhang     if (cusp->handle) PetscCallHIPSPARSE(hipsparseDestroy(cusp->handle));
3565d52a580bSJunchao Zhang     PetscCall(PetscFree(mat->spptr));
3566d52a580bSJunchao Zhang   }
3567d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3568d52a580bSJunchao Zhang }
3569d52a580bSJunchao Zhang 
CsrMatrix_Destroy(CsrMatrix ** mat)3570d52a580bSJunchao Zhang static PetscErrorCode CsrMatrix_Destroy(CsrMatrix **mat)
3571d52a580bSJunchao Zhang {
3572d52a580bSJunchao Zhang   PetscFunctionBegin;
3573d52a580bSJunchao Zhang   if (*mat) {
3574d52a580bSJunchao Zhang     delete (*mat)->values;
3575d52a580bSJunchao Zhang     delete (*mat)->column_indices;
3576d52a580bSJunchao Zhang     delete (*mat)->row_offsets;
3577d52a580bSJunchao Zhang     delete *mat;
3578d52a580bSJunchao Zhang     *mat = 0;
3579d52a580bSJunchao Zhang   }
3580d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3581d52a580bSJunchao Zhang }
3582d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct ** trifactor)3583d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct **trifactor)
3584d52a580bSJunchao Zhang {
3585d52a580bSJunchao Zhang   PetscFunctionBegin;
3586d52a580bSJunchao Zhang   if (*trifactor) {
3587d52a580bSJunchao Zhang     if ((*trifactor)->descr) PetscCallHIPSPARSE(hipsparseDestroyMatDescr((*trifactor)->descr));
3588d52a580bSJunchao Zhang     if ((*trifactor)->solveInfo) PetscCallHIPSPARSE(hipsparseDestroyCsrsvInfo((*trifactor)->solveInfo));
3589d52a580bSJunchao Zhang     PetscCall(CsrMatrix_Destroy(&(*trifactor)->csrMat));
3590d52a580bSJunchao Zhang     if ((*trifactor)->solveBuffer) PetscCallHIP(hipFree((*trifactor)->solveBuffer));
3591d52a580bSJunchao Zhang     if ((*trifactor)->AA_h) PetscCallHIP(hipHostFree((*trifactor)->AA_h));
3592d52a580bSJunchao Zhang     if ((*trifactor)->csr2cscBuffer) PetscCallHIP(hipFree((*trifactor)->csr2cscBuffer));
3593d52a580bSJunchao Zhang     PetscCall(PetscFree(*trifactor));
3594d52a580bSJunchao Zhang   }
3595d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3596d52a580bSJunchao Zhang }
3597d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct ** matstruct,MatHIPSPARSEStorageFormat format)3598d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct **matstruct, MatHIPSPARSEStorageFormat format)
3599d52a580bSJunchao Zhang {
3600d52a580bSJunchao Zhang   CsrMatrix *mat;
3601d52a580bSJunchao Zhang 
3602d52a580bSJunchao Zhang   PetscFunctionBegin;
3603d52a580bSJunchao Zhang   if (*matstruct) {
3604d52a580bSJunchao Zhang     if ((*matstruct)->mat) {
3605d52a580bSJunchao Zhang       if (format == MAT_HIPSPARSE_ELL || format == MAT_HIPSPARSE_HYB) {
3606d52a580bSJunchao Zhang         hipsparseHybMat_t hybMat = (hipsparseHybMat_t)(*matstruct)->mat;
3607d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDestroyHybMat(hybMat));
3608d52a580bSJunchao Zhang       } else {
3609d52a580bSJunchao Zhang         mat = (CsrMatrix *)(*matstruct)->mat;
3610d52a580bSJunchao Zhang         PetscCall(CsrMatrix_Destroy(&mat));
3611d52a580bSJunchao Zhang       }
3612d52a580bSJunchao Zhang     }
3613d52a580bSJunchao Zhang     if ((*matstruct)->descr) PetscCallHIPSPARSE(hipsparseDestroyMatDescr((*matstruct)->descr));
3614d52a580bSJunchao Zhang     delete (*matstruct)->cprowIndices;
3615d52a580bSJunchao Zhang     if ((*matstruct)->alpha_one) PetscCallHIP(hipFree((*matstruct)->alpha_one));
3616d52a580bSJunchao Zhang     if ((*matstruct)->beta_zero) PetscCallHIP(hipFree((*matstruct)->beta_zero));
3617d52a580bSJunchao Zhang     if ((*matstruct)->beta_one) PetscCallHIP(hipFree((*matstruct)->beta_one));
3618d52a580bSJunchao Zhang 
3619d52a580bSJunchao Zhang     Mat_SeqAIJHIPSPARSEMultStruct *mdata = *matstruct;
3620d52a580bSJunchao Zhang     if (mdata->matDescr) PetscCallHIPSPARSE(hipsparseDestroySpMat(mdata->matDescr));
3621d52a580bSJunchao Zhang     for (int i = 0; i < 3; i++) {
3622d52a580bSJunchao Zhang       if (mdata->hipSpMV[i].initialized) {
3623d52a580bSJunchao Zhang         PetscCallHIP(hipFree(mdata->hipSpMV[i].spmvBuffer));
3624d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDestroyDnVec(mdata->hipSpMV[i].vecXDescr));
3625d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDestroyDnVec(mdata->hipSpMV[i].vecYDescr));
3626d52a580bSJunchao Zhang       }
3627d52a580bSJunchao Zhang     }
3628d52a580bSJunchao Zhang     delete *matstruct;
3629d52a580bSJunchao Zhang     *matstruct = NULL;
3630d52a580bSJunchao Zhang   }
3631d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3632d52a580bSJunchao Zhang }
3633d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSETriFactors_Reset(Mat_SeqAIJHIPSPARSETriFactors_p * trifactors)3634d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Reset(Mat_SeqAIJHIPSPARSETriFactors_p *trifactors)
3635d52a580bSJunchao Zhang {
3636d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs = *trifactors;
3637d52a580bSJunchao Zhang 
3638d52a580bSJunchao Zhang   PetscFunctionBegin;
3639d52a580bSJunchao Zhang   if (fs) {
3640d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->loTriFactorPtr));
3641d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->upTriFactorPtr));
3642d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->loTriFactorPtrTranspose));
3643d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->upTriFactorPtrTranspose));
3644d52a580bSJunchao Zhang     delete fs->rpermIndices;
3645d52a580bSJunchao Zhang     delete fs->cpermIndices;
3646d52a580bSJunchao Zhang     delete fs->workVector;
3647d52a580bSJunchao Zhang     fs->rpermIndices  = NULL;
3648d52a580bSJunchao Zhang     fs->cpermIndices  = NULL;
3649d52a580bSJunchao Zhang     fs->workVector    = NULL;
3650d52a580bSJunchao Zhang     fs->init_dev_prop = PETSC_FALSE;
3651d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
3652d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->csrRowPtr));
3653d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->csrColIdx));
3654d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->csrVal));
3655d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->X));
3656d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->Y));
3657d52a580bSJunchao Zhang     // PetscCallHIP(hipFree(fs->factBuffer_M)); /* No needed since factBuffer_M shares with one of spsvBuffer_L/U */
3658d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->spsvBuffer_L));
3659d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->spsvBuffer_U));
3660d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->spsvBuffer_Lt));
3661d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->spsvBuffer_Ut));
3662d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDestroyMatDescr(fs->matDescr_M));
3663d52a580bSJunchao Zhang     if (fs->spMatDescr_L) PetscCallHIPSPARSE(hipsparseDestroySpMat(fs->spMatDescr_L));
3664d52a580bSJunchao Zhang     if (fs->spMatDescr_U) PetscCallHIPSPARSE(hipsparseDestroySpMat(fs->spMatDescr_U));
3665d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_L));
3666d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_Lt));
3667d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_U));
3668d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_Ut));
3669d52a580bSJunchao Zhang     if (fs->dnVecDescr_X) PetscCallHIPSPARSE(hipsparseDestroyDnVec(fs->dnVecDescr_X));
3670d52a580bSJunchao Zhang     if (fs->dnVecDescr_Y) PetscCallHIPSPARSE(hipsparseDestroyDnVec(fs->dnVecDescr_Y));
3671d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDestroyCsrilu02Info(fs->ilu0Info_M));
3672d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDestroyCsric02Info(fs->ic0Info_M));
3673d52a580bSJunchao Zhang 
3674d52a580bSJunchao Zhang     fs->createdTransposeSpSVDescr    = PETSC_FALSE;
3675d52a580bSJunchao Zhang     fs->updatedTransposeSpSVAnalysis = PETSC_FALSE;
3676d52a580bSJunchao Zhang #endif
3677d52a580bSJunchao Zhang   }
3678d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3679d52a580bSJunchao Zhang }
3680d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors ** trifactors)3681d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors **trifactors)
3682d52a580bSJunchao Zhang {
3683d52a580bSJunchao Zhang   hipsparseHandle_t handle;
3684d52a580bSJunchao Zhang 
3685d52a580bSJunchao Zhang   PetscFunctionBegin;
3686d52a580bSJunchao Zhang   if (*trifactors) {
3687d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(trifactors));
3688d52a580bSJunchao Zhang     if ((handle = (*trifactors)->handle)) PetscCallHIPSPARSE(hipsparseDestroy(handle));
3689d52a580bSJunchao Zhang     PetscCall(PetscFree(*trifactors));
3690d52a580bSJunchao Zhang   }
3691d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3692d52a580bSJunchao Zhang }
3693d52a580bSJunchao Zhang 
3694d52a580bSJunchao Zhang struct IJCompare {
operator ()IJCompare3695d52a580bSJunchao Zhang   __host__ __device__ inline bool operator()(const thrust::tuple<PetscInt, PetscInt> &t1, const thrust::tuple<PetscInt, PetscInt> &t2)
3696d52a580bSJunchao Zhang   {
3697d52a580bSJunchao Zhang     if (t1.get<0>() < t2.get<0>()) return true;
3698d52a580bSJunchao Zhang     if (t1.get<0>() == t2.get<0>()) return t1.get<1>() < t2.get<1>();
3699d52a580bSJunchao Zhang     return false;
3700d52a580bSJunchao Zhang   }
3701d52a580bSJunchao Zhang };
3702d52a580bSJunchao Zhang 
MatSeqAIJHIPSPARSEInvalidateTranspose(Mat A,PetscBool destroy)3703d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEInvalidateTranspose(Mat A, PetscBool destroy)
3704d52a580bSJunchao Zhang {
3705d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3706d52a580bSJunchao Zhang 
3707d52a580bSJunchao Zhang   PetscFunctionBegin;
3708d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3709d52a580bSJunchao Zhang   if (!cusp) PetscFunctionReturn(PETSC_SUCCESS);
3710d52a580bSJunchao Zhang   if (destroy) {
3711d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->matTranspose, cusp->format));
3712d52a580bSJunchao Zhang     delete cusp->csr2csc_i;
3713d52a580bSJunchao Zhang     cusp->csr2csc_i = NULL;
3714d52a580bSJunchao Zhang   }
3715d52a580bSJunchao Zhang   A->transupdated = PETSC_FALSE;
3716d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3717d52a580bSJunchao Zhang }
3718d52a580bSJunchao Zhang 
MatCOOStructDestroy_SeqAIJHIPSPARSE(PetscCtxRt data)3719d52a580bSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_SeqAIJHIPSPARSE(PetscCtxRt data)
3720d52a580bSJunchao Zhang {
3721d52a580bSJunchao Zhang   MatCOOStruct_SeqAIJ *coo = *(MatCOOStruct_SeqAIJ **)data;
3722d52a580bSJunchao Zhang 
3723d52a580bSJunchao Zhang   PetscFunctionBegin;
3724d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->perm));
3725d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->jmap));
3726d52a580bSJunchao Zhang   PetscCall(PetscFree(coo));
3727d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3728d52a580bSJunchao Zhang }
3729d52a580bSJunchao Zhang 
MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat mat,PetscCount coo_n,PetscInt coo_i[],PetscInt coo_j[])3730d52a580bSJunchao Zhang static PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
3731d52a580bSJunchao Zhang {
3732d52a580bSJunchao Zhang   PetscBool            dev_ij = PETSC_FALSE;
3733d52a580bSJunchao Zhang   PetscMemType         mtype  = PETSC_MEMTYPE_HOST;
3734d52a580bSJunchao Zhang   PetscInt            *i, *j;
3735d52a580bSJunchao Zhang   PetscContainer       container_h;
3736d52a580bSJunchao Zhang   MatCOOStruct_SeqAIJ *coo_h, *coo_d;
3737d52a580bSJunchao Zhang 
3738d52a580bSJunchao Zhang   PetscFunctionBegin;
3739d52a580bSJunchao Zhang   PetscCall(PetscGetMemType(coo_i, &mtype));
3740d52a580bSJunchao Zhang   if (PetscMemTypeDevice(mtype)) {
3741d52a580bSJunchao Zhang     dev_ij = PETSC_TRUE;
3742d52a580bSJunchao Zhang     PetscCall(PetscMalloc2(coo_n, &i, coo_n, &j));
3743d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(i, coo_i, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost));
3744d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(j, coo_j, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost));
3745d52a580bSJunchao Zhang   } else {
3746d52a580bSJunchao Zhang     i = coo_i;
3747d52a580bSJunchao Zhang     j = coo_j;
3748d52a580bSJunchao Zhang   }
3749d52a580bSJunchao Zhang   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, i, j));
3750d52a580bSJunchao Zhang   if (dev_ij) PetscCall(PetscFree2(i, j));
3751d52a580bSJunchao Zhang   mat->offloadmask = PETSC_OFFLOAD_CPU;
3752d52a580bSJunchao Zhang   // Create the GPU memory
3753d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(mat));
3754d52a580bSJunchao Zhang 
3755d52a580bSJunchao Zhang   // Copy the COO struct to device
3756d52a580bSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
3757d52a580bSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, &coo_h));
3758d52a580bSJunchao Zhang   PetscCall(PetscMalloc1(1, &coo_d));
3759d52a580bSJunchao Zhang   *coo_d = *coo_h; // do a shallow copy and then amend some fields that need to be different
3760d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->jmap, (coo_h->nz + 1) * sizeof(PetscCount)));
3761d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->jmap, coo_h->jmap, (coo_h->nz + 1) * sizeof(PetscCount), hipMemcpyHostToDevice));
3762d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->perm, coo_h->Atot * sizeof(PetscCount)));
3763d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->perm, coo_h->perm, coo_h->Atot * sizeof(PetscCount), hipMemcpyHostToDevice));
3764d52a580bSJunchao Zhang 
3765d52a580bSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
3766d52a580bSJunchao Zhang   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJHIPSPARSE));
3767d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3768d52a580bSJunchao Zhang }
3769d52a580bSJunchao Zhang 
MatAddCOOValues(const PetscScalar kv[],PetscCount nnz,const PetscCount jmap[],const PetscCount perm[],InsertMode imode,PetscScalar a[])3770d52a580bSJunchao Zhang __global__ static void MatAddCOOValues(const PetscScalar kv[], PetscCount nnz, const PetscCount jmap[], const PetscCount perm[], InsertMode imode, PetscScalar a[])
3771d52a580bSJunchao Zhang {
3772d52a580bSJunchao Zhang   PetscCount       i         = blockIdx.x * blockDim.x + threadIdx.x;
3773d52a580bSJunchao Zhang   const PetscCount grid_size = gridDim.x * blockDim.x;
3774d52a580bSJunchao Zhang   for (; i < nnz; i += grid_size) {
3775d52a580bSJunchao Zhang     PetscScalar sum = 0.0;
3776d52a580bSJunchao Zhang     for (PetscCount k = jmap[i]; k < jmap[i + 1]; k++) sum += kv[perm[k]];
3777d52a580bSJunchao Zhang     a[i] = (imode == INSERT_VALUES ? 0.0 : a[i]) + sum;
3778d52a580bSJunchao Zhang   }
3779d52a580bSJunchao Zhang }
3780d52a580bSJunchao Zhang 
MatSetValuesCOO_SeqAIJHIPSPARSE(Mat A,const PetscScalar v[],InsertMode imode)3781d52a580bSJunchao Zhang static PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE(Mat A, const PetscScalar v[], InsertMode imode)
3782d52a580bSJunchao Zhang {
3783d52a580bSJunchao Zhang   Mat_SeqAIJ          *seq  = (Mat_SeqAIJ *)A->data;
3784d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *dev  = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3785d52a580bSJunchao Zhang   PetscCount           Annz = seq->nz;
3786d52a580bSJunchao Zhang   PetscMemType         memtype;
3787d52a580bSJunchao Zhang   const PetscScalar   *v1 = v;
3788d52a580bSJunchao Zhang   PetscScalar         *Aa;
3789d52a580bSJunchao Zhang   PetscContainer       container;
3790d52a580bSJunchao Zhang   MatCOOStruct_SeqAIJ *coo;
3791d52a580bSJunchao Zhang 
3792d52a580bSJunchao Zhang   PetscFunctionBegin;
3793d52a580bSJunchao Zhang   if (!dev->mat) PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3794d52a580bSJunchao Zhang 
3795d52a580bSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
3796d52a580bSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, &coo));
3797d52a580bSJunchao Zhang 
3798d52a580bSJunchao Zhang   PetscCall(PetscGetMemType(v, &memtype));
3799d52a580bSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
3800d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&v1, coo->n * sizeof(PetscScalar)));
3801d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy((void *)v1, v, coo->n * sizeof(PetscScalar), hipMemcpyHostToDevice));
3802d52a580bSJunchao Zhang   }
3803d52a580bSJunchao Zhang 
3804d52a580bSJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJHIPSPARSEGetArrayWrite(A, &Aa));
3805d52a580bSJunchao Zhang   else PetscCall(MatSeqAIJHIPSPARSEGetArray(A, &Aa));
3806d52a580bSJunchao Zhang 
3807d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
3808d52a580bSJunchao Zhang   if (Annz) {
3809d52a580bSJunchao Zhang     hipLaunchKernelGGL(HIP_KERNEL_NAME(MatAddCOOValues), dim3((Annz + 255) / 256), dim3(256), 0, PetscDefaultHipStream, v1, Annz, coo->jmap, coo->perm, imode, Aa);
3810d52a580bSJunchao Zhang     PetscCallHIP(hipPeekAtLastError());
3811d52a580bSJunchao Zhang   }
3812d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
3813d52a580bSJunchao Zhang 
3814d52a580bSJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJHIPSPARSERestoreArrayWrite(A, &Aa));
3815d52a580bSJunchao Zhang   else PetscCall(MatSeqAIJHIPSPARSERestoreArray(A, &Aa));
3816d52a580bSJunchao Zhang 
3817d52a580bSJunchao Zhang   if (PetscMemTypeHost(memtype)) PetscCallHIP(hipFree((void *)v1));
3818d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3819d52a580bSJunchao Zhang }
3820d52a580bSJunchao Zhang 
3821d52a580bSJunchao Zhang /*@C
3822d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSEGetIJ - returns the device row storage `i` and `j` indices for `MATSEQAIJHIPSPARSE` matrices.
3823d52a580bSJunchao Zhang 
3824d52a580bSJunchao Zhang   Not Collective
3825d52a580bSJunchao Zhang 
3826d52a580bSJunchao Zhang   Input Parameters:
3827d52a580bSJunchao Zhang + A          - the matrix
3828d52a580bSJunchao Zhang - compressed - `PETSC_TRUE` or `PETSC_FALSE` indicating the matrix data structure should be always returned in compressed form
3829d52a580bSJunchao Zhang 
3830d52a580bSJunchao Zhang   Output Parameters:
3831d52a580bSJunchao Zhang + i - the CSR row pointers
3832d52a580bSJunchao Zhang - j - the CSR column indices
3833d52a580bSJunchao Zhang 
3834d52a580bSJunchao Zhang   Level: developer
3835d52a580bSJunchao Zhang 
3836d52a580bSJunchao Zhang   Note:
3837d52a580bSJunchao Zhang   When compressed is true, the CSR structure does not contain empty rows
3838d52a580bSJunchao Zhang 
3839d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSERestoreIJ()`, `MatSeqAIJHIPSPARSEGetArrayRead()`
3840d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSEGetIJ(Mat A,PetscBool compressed,const int * i[],const int * j[])3841d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetIJ(Mat A, PetscBool compressed, const int *i[], const int *j[])
3842d52a580bSJunchao Zhang {
3843d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3844d52a580bSJunchao Zhang   Mat_SeqAIJ          *a    = (Mat_SeqAIJ *)A->data;
3845d52a580bSJunchao Zhang   CsrMatrix           *csr;
3846d52a580bSJunchao Zhang 
3847d52a580bSJunchao Zhang   PetscFunctionBegin;
3848d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3849d52a580bSJunchao Zhang   if (!i || !j) PetscFunctionReturn(PETSC_SUCCESS);
3850d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3851d52a580bSJunchao Zhang   PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
3852d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3853d52a580bSJunchao Zhang   PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
3854d52a580bSJunchao Zhang   csr = (CsrMatrix *)cusp->mat->mat;
3855d52a580bSJunchao Zhang   if (i) {
3856d52a580bSJunchao Zhang     if (!compressed && a->compressedrow.use) { /* need full row offset */
3857d52a580bSJunchao Zhang       if (!cusp->rowoffsets_gpu) {
3858d52a580bSJunchao Zhang         cusp->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
3859d52a580bSJunchao Zhang         cusp->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
3860d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
3861d52a580bSJunchao Zhang       }
3862d52a580bSJunchao Zhang       *i = cusp->rowoffsets_gpu->data().get();
3863d52a580bSJunchao Zhang     } else *i = csr->row_offsets->data().get();
3864d52a580bSJunchao Zhang   }
3865d52a580bSJunchao Zhang   if (j) *j = csr->column_indices->data().get();
3866d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3867d52a580bSJunchao Zhang }
3868d52a580bSJunchao Zhang 
3869d52a580bSJunchao Zhang /*@C
3870d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSERestoreIJ - restore the device row storage `i` and `j` indices obtained with `MatSeqAIJHIPSPARSEGetIJ()`
3871d52a580bSJunchao Zhang 
3872d52a580bSJunchao Zhang   Not Collective
3873d52a580bSJunchao Zhang 
3874d52a580bSJunchao Zhang   Input Parameters:
3875d52a580bSJunchao Zhang + A          - the matrix
3876d52a580bSJunchao Zhang . compressed - `PETSC_TRUE` or `PETSC_FALSE` indicating the matrix data structure should be always returned in compressed form
3877d52a580bSJunchao Zhang . i          - the CSR row pointers
3878d52a580bSJunchao Zhang - j          - the CSR column indices
3879d52a580bSJunchao Zhang 
3880d52a580bSJunchao Zhang   Level: developer
3881d52a580bSJunchao Zhang 
3882d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetIJ()`
3883d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSERestoreIJ(Mat A,PetscBool compressed,const int * i[],const int * j[])3884d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreIJ(Mat A, PetscBool compressed, const int *i[], const int *j[])
3885d52a580bSJunchao Zhang {
3886d52a580bSJunchao Zhang   PetscFunctionBegin;
3887d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3888d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3889d52a580bSJunchao Zhang   if (i) *i = NULL;
3890d52a580bSJunchao Zhang   if (j) *j = NULL;
3891d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3892d52a580bSJunchao Zhang }
3893d52a580bSJunchao Zhang 
3894d52a580bSJunchao Zhang /*@C
3895d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSEGetArrayRead - gives read-only access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored
3896d52a580bSJunchao Zhang 
3897d52a580bSJunchao Zhang   Not Collective
3898d52a580bSJunchao Zhang 
3899d52a580bSJunchao Zhang   Input Parameter:
3900d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix
3901d52a580bSJunchao Zhang 
3902d52a580bSJunchao Zhang   Output Parameter:
3903d52a580bSJunchao Zhang . a - pointer to the device data
3904d52a580bSJunchao Zhang 
3905d52a580bSJunchao Zhang   Level: developer
3906d52a580bSJunchao Zhang 
3907d52a580bSJunchao Zhang   Note:
3908d52a580bSJunchao Zhang   May trigger host-device copies if the up-to-date matrix data is on host
3909d52a580bSJunchao Zhang 
3910d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`, `MatSeqAIJHIPSPARSEGetArrayWrite()`, `MatSeqAIJHIPSPARSERestoreArrayRead()`
3911d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSEGetArrayRead(Mat A,const PetscScalar * a[])3912d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArrayRead(Mat A, const PetscScalar *a[])
3913d52a580bSJunchao Zhang {
3914d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3915d52a580bSJunchao Zhang   CsrMatrix           *csr;
3916d52a580bSJunchao Zhang 
3917d52a580bSJunchao Zhang   PetscFunctionBegin;
3918d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3919d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
3920d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3921d52a580bSJunchao Zhang   PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
3922d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3923d52a580bSJunchao Zhang   PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
3924d52a580bSJunchao Zhang   csr = (CsrMatrix *)cusp->mat->mat;
3925d52a580bSJunchao Zhang   PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory");
3926d52a580bSJunchao Zhang   *a = csr->values->data().get();
3927d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3928d52a580bSJunchao Zhang }
3929d52a580bSJunchao Zhang 
3930d52a580bSJunchao Zhang /*@C
3931d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSERestoreArrayRead - restore the read-only access array obtained from `MatSeqAIJHIPSPARSEGetArrayRead()`
3932d52a580bSJunchao Zhang 
3933d52a580bSJunchao Zhang   Not Collective
3934d52a580bSJunchao Zhang 
3935d52a580bSJunchao Zhang   Input Parameters:
3936d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix
3937d52a580bSJunchao Zhang - a - pointer to the device data
3938d52a580bSJunchao Zhang 
3939d52a580bSJunchao Zhang   Level: developer
3940d52a580bSJunchao Zhang 
3941d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayRead()`
3942d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSERestoreArrayRead(Mat A,const PetscScalar * a[])3943d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArrayRead(Mat A, const PetscScalar *a[])
3944d52a580bSJunchao Zhang {
3945d52a580bSJunchao Zhang   PetscFunctionBegin;
3946d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3947d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
3948d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3949d52a580bSJunchao Zhang   *a = NULL;
3950d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3951d52a580bSJunchao Zhang }
3952d52a580bSJunchao Zhang 
3953d52a580bSJunchao Zhang /*@C
3954d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSEGetArray - gives read-write access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored
3955d52a580bSJunchao Zhang 
3956d52a580bSJunchao Zhang   Not Collective
3957d52a580bSJunchao Zhang 
3958d52a580bSJunchao Zhang   Input Parameter:
3959d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix
3960d52a580bSJunchao Zhang 
3961d52a580bSJunchao Zhang   Output Parameter:
3962d52a580bSJunchao Zhang . a - pointer to the device data
3963d52a580bSJunchao Zhang 
3964d52a580bSJunchao Zhang   Level: developer
3965d52a580bSJunchao Zhang 
3966d52a580bSJunchao Zhang   Note:
3967d52a580bSJunchao Zhang   May trigger host-device copies if up-to-date matrix data is on host
3968d52a580bSJunchao Zhang 
3969d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayRead()`, `MatSeqAIJHIPSPARSEGetArrayWrite()`, `MatSeqAIJHIPSPARSERestoreArray()`
3970d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSEGetArray(Mat A,PetscScalar * a[])3971d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArray(Mat A, PetscScalar *a[])
3972d52a580bSJunchao Zhang {
3973d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3974d52a580bSJunchao Zhang   CsrMatrix           *csr;
3975d52a580bSJunchao Zhang 
3976d52a580bSJunchao Zhang   PetscFunctionBegin;
3977d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3978d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
3979d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3980d52a580bSJunchao Zhang   PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
3981d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3982d52a580bSJunchao Zhang   PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
3983d52a580bSJunchao Zhang   csr = (CsrMatrix *)cusp->mat->mat;
3984d52a580bSJunchao Zhang   PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory");
3985d52a580bSJunchao Zhang   *a             = csr->values->data().get();
3986d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_GPU;
3987d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE));
3988d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3989d52a580bSJunchao Zhang }
3990d52a580bSJunchao Zhang /*@C
3991d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSERestoreArray - restore the read-write access array obtained from `MatSeqAIJHIPSPARSEGetArray()`
3992d52a580bSJunchao Zhang 
3993d52a580bSJunchao Zhang   Not Collective
3994d52a580bSJunchao Zhang 
3995d52a580bSJunchao Zhang   Input Parameters:
3996d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix
3997d52a580bSJunchao Zhang - a - pointer to the device data
3998d52a580bSJunchao Zhang 
3999d52a580bSJunchao Zhang   Level: developer
4000d52a580bSJunchao Zhang 
4001d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`
4002d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSERestoreArray(Mat A,PetscScalar * a[])4003d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArray(Mat A, PetscScalar *a[])
4004d52a580bSJunchao Zhang {
4005d52a580bSJunchao Zhang   PetscFunctionBegin;
4006d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4007d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
4008d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4009d52a580bSJunchao Zhang   PetscCall(PetscObjectStateIncrease((PetscObject)A));
4010d52a580bSJunchao Zhang   *a = NULL;
4011d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4012d52a580bSJunchao Zhang }
4013d52a580bSJunchao Zhang 
4014d52a580bSJunchao Zhang /*@C
4015d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSEGetArrayWrite - gives write access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored
4016d52a580bSJunchao Zhang 
4017d52a580bSJunchao Zhang   Not Collective
4018d52a580bSJunchao Zhang 
4019d52a580bSJunchao Zhang   Input Parameter:
4020d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix
4021d52a580bSJunchao Zhang 
4022d52a580bSJunchao Zhang   Output Parameter:
4023d52a580bSJunchao Zhang . a - pointer to the device data
4024d52a580bSJunchao Zhang 
4025d52a580bSJunchao Zhang   Level: developer
4026d52a580bSJunchao Zhang 
4027d52a580bSJunchao Zhang   Note:
4028d52a580bSJunchao Zhang   Does not trigger host-device copies and flags data validity on the GPU
4029d52a580bSJunchao Zhang 
4030d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`, `MatSeqAIJHIPSPARSEGetArrayRead()`, `MatSeqAIJHIPSPARSERestoreArrayWrite()`
4031d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSEGetArrayWrite(Mat A,PetscScalar * a[])4032d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArrayWrite(Mat A, PetscScalar *a[])
4033d52a580bSJunchao Zhang {
4034d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
4035d52a580bSJunchao Zhang   CsrMatrix           *csr;
4036d52a580bSJunchao Zhang 
4037d52a580bSJunchao Zhang   PetscFunctionBegin;
4038d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4039d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
4040d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4041d52a580bSJunchao Zhang   PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4042d52a580bSJunchao Zhang   PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4043d52a580bSJunchao Zhang   csr = (CsrMatrix *)cusp->mat->mat;
4044d52a580bSJunchao Zhang   PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory");
4045d52a580bSJunchao Zhang   *a             = csr->values->data().get();
4046d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_GPU;
4047d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE));
4048d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4049d52a580bSJunchao Zhang }
4050d52a580bSJunchao Zhang 
4051d52a580bSJunchao Zhang /*@C
4052d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSERestoreArrayWrite - restore the write-only access array obtained from `MatSeqAIJHIPSPARSEGetArrayWrite()`
4053d52a580bSJunchao Zhang 
4054d52a580bSJunchao Zhang   Not Collective
4055d52a580bSJunchao Zhang 
4056d52a580bSJunchao Zhang   Input Parameters:
4057d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix
4058d52a580bSJunchao Zhang - a - pointer to the device data
4059d52a580bSJunchao Zhang 
4060d52a580bSJunchao Zhang   Level: developer
4061d52a580bSJunchao Zhang 
4062d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayWrite()`
4063d52a580bSJunchao Zhang @*/
MatSeqAIJHIPSPARSERestoreArrayWrite(Mat A,PetscScalar * a[])4064d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArrayWrite(Mat A, PetscScalar *a[])
4065d52a580bSJunchao Zhang {
4066d52a580bSJunchao Zhang   PetscFunctionBegin;
4067d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4068d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
4069d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4070d52a580bSJunchao Zhang   PetscCall(PetscObjectStateIncrease((PetscObject)A));
4071d52a580bSJunchao Zhang   *a = NULL;
4072d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4073d52a580bSJunchao Zhang }
4074d52a580bSJunchao Zhang 
4075d52a580bSJunchao Zhang struct IJCompare4 {
operator ()IJCompare44076d52a580bSJunchao Zhang   __host__ __device__ inline bool operator()(const thrust::tuple<int, int, PetscScalar, int> &t1, const thrust::tuple<int, int, PetscScalar, int> &t2)
4077d52a580bSJunchao Zhang   {
4078d52a580bSJunchao Zhang     if (t1.get<0>() < t2.get<0>()) return true;
4079d52a580bSJunchao Zhang     if (t1.get<0>() == t2.get<0>()) return t1.get<1>() < t2.get<1>();
4080d52a580bSJunchao Zhang     return false;
4081d52a580bSJunchao Zhang   }
4082d52a580bSJunchao Zhang };
4083d52a580bSJunchao Zhang 
4084d52a580bSJunchao Zhang struct Shift {
4085d52a580bSJunchao Zhang   int _shift;
4086d52a580bSJunchao Zhang 
ShiftShift4087d52a580bSJunchao Zhang   Shift(int shift) : _shift(shift) { }
operator ()Shift4088d52a580bSJunchao Zhang   __host__ __device__ inline int operator()(const int &c) { return c + _shift; }
4089d52a580bSJunchao Zhang };
4090d52a580bSJunchao Zhang 
4091d52a580bSJunchao Zhang /* merges two SeqAIJHIPSPARSE matrices A, B by concatenating their rows. [A';B']' operation in MATLAB notation */
MatSeqAIJHIPSPARSEMergeMats(Mat A,Mat B,MatReuse reuse,Mat * C)4092d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
4093d52a580bSJunchao Zhang {
4094d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a = (Mat_SeqAIJ *)A->data, *b = (Mat_SeqAIJ *)B->data, *c;
4095d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr, *Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr, *Ccusp;
4096d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *Cmat;
4097d52a580bSJunchao Zhang   CsrMatrix                     *Acsr, *Bcsr, *Ccsr;
4098d52a580bSJunchao Zhang   PetscInt                       Annz, Bnnz;
4099d52a580bSJunchao Zhang   PetscInt                       i, m, n, zero = 0;
4100d52a580bSJunchao Zhang 
4101d52a580bSJunchao Zhang   PetscFunctionBegin;
4102d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4103d52a580bSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
4104d52a580bSJunchao Zhang   PetscAssertPointer(C, 4);
4105d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4106d52a580bSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJHIPSPARSE);
4107d52a580bSJunchao Zhang   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
4108d52a580bSJunchao Zhang   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
4109d52a580bSJunchao Zhang   PetscCheck(Acusp->format != MAT_HIPSPARSE_ELL && Acusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4110d52a580bSJunchao Zhang   PetscCheck(Bcusp->format != MAT_HIPSPARSE_ELL && Bcusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4111d52a580bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
4112d52a580bSJunchao Zhang     m = A->rmap->n;
4113d52a580bSJunchao Zhang     n = A->cmap->n + B->cmap->n;
4114d52a580bSJunchao Zhang     PetscCall(MatCreate(PETSC_COMM_SELF, C));
4115d52a580bSJunchao Zhang     PetscCall(MatSetSizes(*C, m, n, m, n));
4116d52a580bSJunchao Zhang     PetscCall(MatSetType(*C, MATSEQAIJHIPSPARSE));
4117d52a580bSJunchao Zhang     c                       = (Mat_SeqAIJ *)(*C)->data;
4118d52a580bSJunchao Zhang     Ccusp                   = (Mat_SeqAIJHIPSPARSE *)(*C)->spptr;
4119d52a580bSJunchao Zhang     Cmat                    = new Mat_SeqAIJHIPSPARSEMultStruct;
4120d52a580bSJunchao Zhang     Ccsr                    = new CsrMatrix;
4121d52a580bSJunchao Zhang     Cmat->cprowIndices      = NULL;
4122d52a580bSJunchao Zhang     c->compressedrow.use    = PETSC_FALSE;
4123d52a580bSJunchao Zhang     c->compressedrow.nrows  = 0;
4124d52a580bSJunchao Zhang     c->compressedrow.i      = NULL;
4125d52a580bSJunchao Zhang     c->compressedrow.rindex = NULL;
4126d52a580bSJunchao Zhang     Ccusp->workVector       = NULL;
4127d52a580bSJunchao Zhang     Ccusp->nrows            = m;
4128d52a580bSJunchao Zhang     Ccusp->mat              = Cmat;
4129d52a580bSJunchao Zhang     Ccusp->mat->mat         = Ccsr;
4130d52a580bSJunchao Zhang     Ccsr->num_rows          = m;
4131d52a580bSJunchao Zhang     Ccsr->num_cols          = n;
4132d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseCreateMatDescr(&Cmat->descr));
4133d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetMatIndexBase(Cmat->descr, HIPSPARSE_INDEX_BASE_ZERO));
4134d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetMatType(Cmat->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
4135d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&Cmat->alpha_one, sizeof(PetscScalar)));
4136d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&Cmat->beta_zero, sizeof(PetscScalar)));
4137d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&Cmat->beta_one, sizeof(PetscScalar)));
4138d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(Cmat->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4139d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(Cmat->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
4140d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(Cmat->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4141d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
4142d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
4143d52a580bSJunchao Zhang     PetscCheck(Acusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4144d52a580bSJunchao Zhang     PetscCheck(Bcusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4145d52a580bSJunchao Zhang 
4146d52a580bSJunchao Zhang     Acsr                 = (CsrMatrix *)Acusp->mat->mat;
4147d52a580bSJunchao Zhang     Bcsr                 = (CsrMatrix *)Bcusp->mat->mat;
4148d52a580bSJunchao Zhang     Annz                 = (PetscInt)Acsr->column_indices->size();
4149d52a580bSJunchao Zhang     Bnnz                 = (PetscInt)Bcsr->column_indices->size();
4150d52a580bSJunchao Zhang     c->nz                = Annz + Bnnz;
4151d52a580bSJunchao Zhang     Ccsr->row_offsets    = new THRUSTINTARRAY32(m + 1);
4152d52a580bSJunchao Zhang     Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
4153d52a580bSJunchao Zhang     Ccsr->values         = new THRUSTARRAY(c->nz);
4154d52a580bSJunchao Zhang     Ccsr->num_entries    = c->nz;
4155d52a580bSJunchao Zhang     Ccusp->coords        = new THRUSTINTARRAY(c->nz);
4156d52a580bSJunchao Zhang     if (c->nz) {
4157d52a580bSJunchao Zhang       auto              Acoo = new THRUSTINTARRAY32(Annz);
4158d52a580bSJunchao Zhang       auto              Bcoo = new THRUSTINTARRAY32(Bnnz);
4159d52a580bSJunchao Zhang       auto              Ccoo = new THRUSTINTARRAY32(c->nz);
4160d52a580bSJunchao Zhang       THRUSTINTARRAY32 *Aroff, *Broff;
4161d52a580bSJunchao Zhang 
4162d52a580bSJunchao Zhang       if (a->compressedrow.use) { /* need full row offset */
4163d52a580bSJunchao Zhang         if (!Acusp->rowoffsets_gpu) {
4164d52a580bSJunchao Zhang           Acusp->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
4165d52a580bSJunchao Zhang           Acusp->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
4166d52a580bSJunchao Zhang           PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
4167d52a580bSJunchao Zhang         }
4168d52a580bSJunchao Zhang         Aroff = Acusp->rowoffsets_gpu;
4169d52a580bSJunchao Zhang       } else Aroff = Acsr->row_offsets;
4170d52a580bSJunchao Zhang       if (b->compressedrow.use) { /* need full row offset */
4171d52a580bSJunchao Zhang         if (!Bcusp->rowoffsets_gpu) {
4172d52a580bSJunchao Zhang           Bcusp->rowoffsets_gpu = new THRUSTINTARRAY32(B->rmap->n + 1);
4173d52a580bSJunchao Zhang           Bcusp->rowoffsets_gpu->assign(b->i, b->i + B->rmap->n + 1);
4174d52a580bSJunchao Zhang           PetscCall(PetscLogCpuToGpu((B->rmap->n + 1) * sizeof(PetscInt)));
4175d52a580bSJunchao Zhang         }
4176d52a580bSJunchao Zhang         Broff = Bcusp->rowoffsets_gpu;
4177d52a580bSJunchao Zhang       } else Broff = Bcsr->row_offsets;
4178d52a580bSJunchao Zhang       PetscCall(PetscLogGpuTimeBegin());
4179d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseXcsr2coo(Acusp->handle, Aroff->data().get(), Annz, m, Acoo->data().get(), HIPSPARSE_INDEX_BASE_ZERO));
4180d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseXcsr2coo(Bcusp->handle, Broff->data().get(), Bnnz, m, Bcoo->data().get(), HIPSPARSE_INDEX_BASE_ZERO));
4181d52a580bSJunchao Zhang       /* Issues when using bool with large matrices on SUMMIT 10.2.89 */
4182d52a580bSJunchao Zhang       auto Aperm = thrust::make_constant_iterator(1);
4183d52a580bSJunchao Zhang       auto Bperm = thrust::make_constant_iterator(0);
4184d52a580bSJunchao Zhang       auto Bcib  = thrust::make_transform_iterator(Bcsr->column_indices->begin(), Shift(A->cmap->n));
4185d52a580bSJunchao Zhang       auto Bcie  = thrust::make_transform_iterator(Bcsr->column_indices->end(), Shift(A->cmap->n));
4186d52a580bSJunchao Zhang       auto wPerm = new THRUSTINTARRAY32(Annz + Bnnz);
4187d52a580bSJunchao Zhang       auto Azb   = thrust::make_zip_iterator(thrust::make_tuple(Acoo->begin(), Acsr->column_indices->begin(), Acsr->values->begin(), Aperm));
4188d52a580bSJunchao Zhang       auto Aze   = thrust::make_zip_iterator(thrust::make_tuple(Acoo->end(), Acsr->column_indices->end(), Acsr->values->end(), Aperm));
4189d52a580bSJunchao Zhang       auto Bzb   = thrust::make_zip_iterator(thrust::make_tuple(Bcoo->begin(), Bcib, Bcsr->values->begin(), Bperm));
4190d52a580bSJunchao Zhang       auto Bze   = thrust::make_zip_iterator(thrust::make_tuple(Bcoo->end(), Bcie, Bcsr->values->end(), Bperm));
4191d52a580bSJunchao Zhang       auto Czb   = thrust::make_zip_iterator(thrust::make_tuple(Ccoo->begin(), Ccsr->column_indices->begin(), Ccsr->values->begin(), wPerm->begin()));
4192d52a580bSJunchao Zhang       auto p1    = Ccusp->coords->begin();
4193d52a580bSJunchao Zhang       auto p2    = Ccusp->coords->begin();
4194d52a580bSJunchao Zhang       thrust::advance(p2, Annz);
4195d52a580bSJunchao Zhang       PetscCallThrust(thrust::merge(thrust::device, Azb, Aze, Bzb, Bze, Czb, IJCompare4()));
4196d52a580bSJunchao Zhang       auto cci = thrust::make_counting_iterator(zero);
4197d52a580bSJunchao Zhang       auto cce = thrust::make_counting_iterator(c->nz);
4198d52a580bSJunchao Zhang #if 0 //Errors on SUMMIT cuda 11.1.0
4199d52a580bSJunchao Zhang       PetscCallThrust(thrust::partition_copy(thrust::device, cci, cce, wPerm->begin(), p1, p2, thrust::identity<int>()));
4200d52a580bSJunchao Zhang #else
4201*5a884c48SSatish Balay       auto pred = [](const int &x) { return x; };
4202d52a580bSJunchao Zhang       PetscCallThrust(thrust::copy_if(thrust::device, cci, cce, wPerm->begin(), p1, pred));
4203d52a580bSJunchao Zhang       PetscCallThrust(thrust::remove_copy_if(thrust::device, cci, cce, wPerm->begin(), p2, pred));
4204d52a580bSJunchao Zhang #endif
4205d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseXcoo2csr(Ccusp->handle, Ccoo->data().get(), c->nz, m, Ccsr->row_offsets->data().get(), HIPSPARSE_INDEX_BASE_ZERO));
4206d52a580bSJunchao Zhang       PetscCall(PetscLogGpuTimeEnd());
4207d52a580bSJunchao Zhang       delete wPerm;
4208d52a580bSJunchao Zhang       delete Acoo;
4209d52a580bSJunchao Zhang       delete Bcoo;
4210d52a580bSJunchao Zhang       delete Ccoo;
4211d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateCsr(&Cmat->matDescr, Ccsr->num_rows, Ccsr->num_cols, Ccsr->num_entries, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
4212d52a580bSJunchao Zhang 
4213d52a580bSJunchao Zhang       if (A->form_explicit_transpose && B->form_explicit_transpose) { /* if A and B have the transpose, generate C transpose too */
4214d52a580bSJunchao Zhang         PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
4215d52a580bSJunchao Zhang         PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(B));
4216d52a580bSJunchao Zhang         PetscBool                      AT = Acusp->matTranspose ? PETSC_TRUE : PETSC_FALSE, BT = Bcusp->matTranspose ? PETSC_TRUE : PETSC_FALSE;
4217d52a580bSJunchao Zhang         Mat_SeqAIJHIPSPARSEMultStruct *CmatT = new Mat_SeqAIJHIPSPARSEMultStruct;
4218d52a580bSJunchao Zhang         CsrMatrix                     *CcsrT = new CsrMatrix;
4219d52a580bSJunchao Zhang         CsrMatrix                     *AcsrT = AT ? (CsrMatrix *)Acusp->matTranspose->mat : NULL;
4220d52a580bSJunchao Zhang         CsrMatrix                     *BcsrT = BT ? (CsrMatrix *)Bcusp->matTranspose->mat : NULL;
4221d52a580bSJunchao Zhang 
4222d52a580bSJunchao Zhang         (*C)->form_explicit_transpose = PETSC_TRUE;
4223d52a580bSJunchao Zhang         (*C)->transupdated            = PETSC_TRUE;
4224d52a580bSJunchao Zhang         Ccusp->rowoffsets_gpu         = NULL;
4225d52a580bSJunchao Zhang         CmatT->cprowIndices           = NULL;
4226d52a580bSJunchao Zhang         CmatT->mat                    = CcsrT;
4227d52a580bSJunchao Zhang         CcsrT->num_rows               = n;
4228d52a580bSJunchao Zhang         CcsrT->num_cols               = m;
4229d52a580bSJunchao Zhang         CcsrT->num_entries            = c->nz;
4230d52a580bSJunchao Zhang         CcsrT->row_offsets            = new THRUSTINTARRAY32(n + 1);
4231d52a580bSJunchao Zhang         CcsrT->column_indices         = new THRUSTINTARRAY32(c->nz);
4232d52a580bSJunchao Zhang         CcsrT->values                 = new THRUSTARRAY(c->nz);
4233d52a580bSJunchao Zhang 
4234d52a580bSJunchao Zhang         PetscCall(PetscLogGpuTimeBegin());
4235d52a580bSJunchao Zhang         auto rT = CcsrT->row_offsets->begin();
4236d52a580bSJunchao Zhang         if (AT) {
4237d52a580bSJunchao Zhang           rT = thrust::copy(AcsrT->row_offsets->begin(), AcsrT->row_offsets->end(), rT);
4238d52a580bSJunchao Zhang           thrust::advance(rT, -1);
4239d52a580bSJunchao Zhang         }
4240d52a580bSJunchao Zhang         if (BT) {
4241d52a580bSJunchao Zhang           auto titb = thrust::make_transform_iterator(BcsrT->row_offsets->begin(), Shift(a->nz));
4242d52a580bSJunchao Zhang           auto tite = thrust::make_transform_iterator(BcsrT->row_offsets->end(), Shift(a->nz));
4243d52a580bSJunchao Zhang           thrust::copy(titb, tite, rT);
4244d52a580bSJunchao Zhang         }
4245d52a580bSJunchao Zhang         auto cT = CcsrT->column_indices->begin();
4246d52a580bSJunchao Zhang         if (AT) cT = thrust::copy(AcsrT->column_indices->begin(), AcsrT->column_indices->end(), cT);
4247d52a580bSJunchao Zhang         if (BT) thrust::copy(BcsrT->column_indices->begin(), BcsrT->column_indices->end(), cT);
4248d52a580bSJunchao Zhang         auto vT = CcsrT->values->begin();
4249d52a580bSJunchao Zhang         if (AT) vT = thrust::copy(AcsrT->values->begin(), AcsrT->values->end(), vT);
4250d52a580bSJunchao Zhang         if (BT) thrust::copy(BcsrT->values->begin(), BcsrT->values->end(), vT);
4251d52a580bSJunchao Zhang         PetscCall(PetscLogGpuTimeEnd());
4252d52a580bSJunchao Zhang 
4253d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateMatDescr(&CmatT->descr));
4254d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatIndexBase(CmatT->descr, HIPSPARSE_INDEX_BASE_ZERO));
4255d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatType(CmatT->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
4256d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc((void **)&CmatT->alpha_one, sizeof(PetscScalar)));
4257d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc((void **)&CmatT->beta_zero, sizeof(PetscScalar)));
4258d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc((void **)&CmatT->beta_one, sizeof(PetscScalar)));
4259d52a580bSJunchao Zhang         PetscCallHIP(hipMemcpy(CmatT->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4260d52a580bSJunchao Zhang         PetscCallHIP(hipMemcpy(CmatT->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
4261d52a580bSJunchao Zhang         PetscCallHIP(hipMemcpy(CmatT->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4262d52a580bSJunchao Zhang 
4263d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateCsr(&CmatT->matDescr, CcsrT->num_rows, CcsrT->num_cols, CcsrT->num_entries, CcsrT->row_offsets->data().get(), CcsrT->column_indices->data().get(), CcsrT->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
4264d52a580bSJunchao Zhang         Ccusp->matTranspose = CmatT;
4265d52a580bSJunchao Zhang       }
4266d52a580bSJunchao Zhang     }
4267d52a580bSJunchao Zhang 
4268d52a580bSJunchao Zhang     c->free_a = PETSC_TRUE;
4269d52a580bSJunchao Zhang     PetscCall(PetscShmgetAllocateArray(c->nz, sizeof(PetscInt), (void **)&c->j));
4270d52a580bSJunchao Zhang     PetscCall(PetscShmgetAllocateArray(m + 1, sizeof(PetscInt), (void **)&c->i));
4271d52a580bSJunchao Zhang     c->free_ij = PETSC_TRUE;
4272d52a580bSJunchao Zhang     if (PetscDefined(USE_64BIT_INDICES)) { /* 32 to 64-bit conversion on the GPU and then copy to host (lazy) */
4273d52a580bSJunchao Zhang       THRUSTINTARRAY ii(Ccsr->row_offsets->size());
4274d52a580bSJunchao Zhang       THRUSTINTARRAY jj(Ccsr->column_indices->size());
4275d52a580bSJunchao Zhang       ii = *Ccsr->row_offsets;
4276d52a580bSJunchao Zhang       jj = *Ccsr->column_indices;
4277d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(c->i, ii.data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4278d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(c->j, jj.data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4279d52a580bSJunchao Zhang     } else {
4280d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(c->i, Ccsr->row_offsets->data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4281d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(c->j, Ccsr->column_indices->data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4282d52a580bSJunchao Zhang     }
4283d52a580bSJunchao Zhang     PetscCall(PetscLogGpuToCpu((Ccsr->column_indices->size() + Ccsr->row_offsets->size()) * sizeof(PetscInt)));
4284d52a580bSJunchao Zhang     PetscCall(PetscMalloc1(m, &c->ilen));
4285d52a580bSJunchao Zhang     PetscCall(PetscMalloc1(m, &c->imax));
4286d52a580bSJunchao Zhang     c->maxnz         = c->nz;
4287d52a580bSJunchao Zhang     c->nonzerorowcnt = 0;
4288d52a580bSJunchao Zhang     c->rmax          = 0;
4289d52a580bSJunchao Zhang     for (i = 0; i < m; i++) {
4290d52a580bSJunchao Zhang       const PetscInt nn = c->i[i + 1] - c->i[i];
4291d52a580bSJunchao Zhang       c->ilen[i] = c->imax[i] = nn;
4292d52a580bSJunchao Zhang       c->nonzerorowcnt += (PetscInt)!!nn;
4293d52a580bSJunchao Zhang       c->rmax = PetscMax(c->rmax, nn);
4294d52a580bSJunchao Zhang     }
4295d52a580bSJunchao Zhang     PetscCall(PetscMalloc1(c->nz, &c->a));
4296d52a580bSJunchao Zhang     (*C)->nonzerostate++;
4297d52a580bSJunchao Zhang     PetscCall(PetscLayoutSetUp((*C)->rmap));
4298d52a580bSJunchao Zhang     PetscCall(PetscLayoutSetUp((*C)->cmap));
4299d52a580bSJunchao Zhang     Ccusp->nonzerostate = (*C)->nonzerostate;
4300d52a580bSJunchao Zhang     (*C)->preallocated  = PETSC_TRUE;
4301d52a580bSJunchao Zhang   } else {
4302d52a580bSJunchao Zhang     PetscCheck((*C)->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, (*C)->rmap->n, B->rmap->n);
4303d52a580bSJunchao Zhang     c = (Mat_SeqAIJ *)(*C)->data;
4304d52a580bSJunchao Zhang     if (c->nz) {
4305d52a580bSJunchao Zhang       Ccusp = (Mat_SeqAIJHIPSPARSE *)(*C)->spptr;
4306d52a580bSJunchao Zhang       PetscCheck(Ccusp->coords, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing coords");
4307d52a580bSJunchao Zhang       PetscCheck(Ccusp->format != MAT_HIPSPARSE_ELL && Ccusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4308d52a580bSJunchao Zhang       PetscCheck(Ccusp->nonzerostate == (*C)->nonzerostate, PETSC_COMM_SELF, PETSC_ERR_COR, "Wrong nonzerostate");
4309d52a580bSJunchao Zhang       PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
4310d52a580bSJunchao Zhang       PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
4311d52a580bSJunchao Zhang       PetscCheck(Acusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4312d52a580bSJunchao Zhang       PetscCheck(Bcusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4313d52a580bSJunchao Zhang       Acsr = (CsrMatrix *)Acusp->mat->mat;
4314d52a580bSJunchao Zhang       Bcsr = (CsrMatrix *)Bcusp->mat->mat;
4315d52a580bSJunchao Zhang       Ccsr = (CsrMatrix *)Ccusp->mat->mat;
4316d52a580bSJunchao Zhang       PetscCheck(Acsr->num_entries == (PetscInt)Acsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "A nnz %" PetscInt_FMT " != %" PetscInt_FMT, Acsr->num_entries, (PetscInt)Acsr->values->size());
4317d52a580bSJunchao Zhang       PetscCheck(Bcsr->num_entries == (PetscInt)Bcsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "B nnz %" PetscInt_FMT " != %" PetscInt_FMT, Bcsr->num_entries, (PetscInt)Bcsr->values->size());
4318d52a580bSJunchao Zhang       PetscCheck(Ccsr->num_entries == (PetscInt)Ccsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "C nnz %" PetscInt_FMT " != %" PetscInt_FMT, Ccsr->num_entries, (PetscInt)Ccsr->values->size());
4319d52a580bSJunchao Zhang       PetscCheck(Ccsr->num_entries == Acsr->num_entries + Bcsr->num_entries, PETSC_COMM_SELF, PETSC_ERR_COR, "C nnz %" PetscInt_FMT " != %" PetscInt_FMT " + %" PetscInt_FMT, Ccsr->num_entries, Acsr->num_entries, Bcsr->num_entries);
4320d52a580bSJunchao Zhang       PetscCheck(Ccusp->coords->size() == Ccsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "permSize %" PetscInt_FMT " != %" PetscInt_FMT, (PetscInt)Ccusp->coords->size(), (PetscInt)Ccsr->values->size());
4321d52a580bSJunchao Zhang       auto pmid = Ccusp->coords->begin();
4322d52a580bSJunchao Zhang       thrust::advance(pmid, Acsr->num_entries);
4323d52a580bSJunchao Zhang       PetscCall(PetscLogGpuTimeBegin());
4324d52a580bSJunchao Zhang       auto zibait = thrust::make_zip_iterator(thrust::make_tuple(Acsr->values->begin(), thrust::make_permutation_iterator(Ccsr->values->begin(), Ccusp->coords->begin())));
4325d52a580bSJunchao Zhang       auto zieait = thrust::make_zip_iterator(thrust::make_tuple(Acsr->values->end(), thrust::make_permutation_iterator(Ccsr->values->begin(), pmid)));
4326d52a580bSJunchao Zhang       thrust::for_each(zibait, zieait, VecHIPEquals());
4327d52a580bSJunchao Zhang       auto zibbit = thrust::make_zip_iterator(thrust::make_tuple(Bcsr->values->begin(), thrust::make_permutation_iterator(Ccsr->values->begin(), pmid)));
4328d52a580bSJunchao Zhang       auto ziebit = thrust::make_zip_iterator(thrust::make_tuple(Bcsr->values->end(), thrust::make_permutation_iterator(Ccsr->values->begin(), Ccusp->coords->end())));
4329d52a580bSJunchao Zhang       thrust::for_each(zibbit, ziebit, VecHIPEquals());
4330d52a580bSJunchao Zhang       PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(*C, PETSC_FALSE));
4331d52a580bSJunchao Zhang       if (A->form_explicit_transpose && B->form_explicit_transpose && (*C)->form_explicit_transpose) {
4332d52a580bSJunchao Zhang         PetscCheck(Ccusp->matTranspose, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing transpose Mat_SeqAIJHIPSPARSEMultStruct");
4333d52a580bSJunchao Zhang         PetscBool  AT = Acusp->matTranspose ? PETSC_TRUE : PETSC_FALSE, BT = Bcusp->matTranspose ? PETSC_TRUE : PETSC_FALSE;
4334d52a580bSJunchao Zhang         CsrMatrix *AcsrT = AT ? (CsrMatrix *)Acusp->matTranspose->mat : NULL;
4335d52a580bSJunchao Zhang         CsrMatrix *BcsrT = BT ? (CsrMatrix *)Bcusp->matTranspose->mat : NULL;
4336d52a580bSJunchao Zhang         CsrMatrix *CcsrT = (CsrMatrix *)Ccusp->matTranspose->mat;
4337d52a580bSJunchao Zhang         auto       vT    = CcsrT->values->begin();
4338d52a580bSJunchao Zhang         if (AT) vT = thrust::copy(AcsrT->values->begin(), AcsrT->values->end(), vT);
4339d52a580bSJunchao Zhang         if (BT) thrust::copy(BcsrT->values->begin(), BcsrT->values->end(), vT);
4340d52a580bSJunchao Zhang         (*C)->transupdated = PETSC_TRUE;
4341d52a580bSJunchao Zhang       }
4342d52a580bSJunchao Zhang       PetscCall(PetscLogGpuTimeEnd());
4343d52a580bSJunchao Zhang     }
4344d52a580bSJunchao Zhang   }
4345d52a580bSJunchao Zhang   PetscCall(PetscObjectStateIncrease((PetscObject)*C));
4346d52a580bSJunchao Zhang   (*C)->assembled     = PETSC_TRUE;
4347d52a580bSJunchao Zhang   (*C)->was_assembled = PETSC_FALSE;
4348d52a580bSJunchao Zhang   (*C)->offloadmask   = PETSC_OFFLOAD_GPU;
4349d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4350d52a580bSJunchao Zhang }
4351d52a580bSJunchao Zhang 
MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat A,PetscInt n,const PetscInt idx[],PetscScalar v[])4352d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat A, PetscInt n, const PetscInt idx[], PetscScalar v[])
4353d52a580bSJunchao Zhang {
4354d52a580bSJunchao Zhang   bool               dmem;
4355d52a580bSJunchao Zhang   const PetscScalar *av;
4356d52a580bSJunchao Zhang 
4357d52a580bSJunchao Zhang   PetscFunctionBegin;
4358d52a580bSJunchao Zhang   dmem = isHipMem(v);
4359d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(A, &av));
4360d52a580bSJunchao Zhang   if (n && idx) {
4361d52a580bSJunchao Zhang     THRUSTINTARRAY widx(n);
4362d52a580bSJunchao Zhang     widx.assign(idx, idx + n);
4363d52a580bSJunchao Zhang     PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
4364d52a580bSJunchao Zhang 
4365d52a580bSJunchao Zhang     THRUSTARRAY                    *w = NULL;
4366d52a580bSJunchao Zhang     thrust::device_ptr<PetscScalar> dv;
4367d52a580bSJunchao Zhang     if (dmem) dv = thrust::device_pointer_cast(v);
4368d52a580bSJunchao Zhang     else {
4369d52a580bSJunchao Zhang       w  = new THRUSTARRAY(n);
4370d52a580bSJunchao Zhang       dv = w->data();
4371d52a580bSJunchao Zhang     }
4372d52a580bSJunchao Zhang     thrust::device_ptr<const PetscScalar> dav = thrust::device_pointer_cast(av);
4373d52a580bSJunchao Zhang 
4374d52a580bSJunchao Zhang     auto zibit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.begin()), dv));
4375d52a580bSJunchao Zhang     auto zieit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.end()), dv + n));
4376d52a580bSJunchao Zhang     thrust::for_each(zibit, zieit, VecHIPEquals());
4377d52a580bSJunchao Zhang     if (w) PetscCallHIP(hipMemcpy(v, w->data().get(), n * sizeof(PetscScalar), hipMemcpyDeviceToHost));
4378d52a580bSJunchao Zhang     delete w;
4379d52a580bSJunchao Zhang   } else PetscCallHIP(hipMemcpy(v, av, n * sizeof(PetscScalar), dmem ? hipMemcpyDeviceToDevice : hipMemcpyDeviceToHost));
4380d52a580bSJunchao Zhang 
4381d52a580bSJunchao Zhang   if (!dmem) PetscCall(PetscLogCpuToGpu(n * sizeof(PetscScalar)));
4382d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(A, &av));
4383d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4384d52a580bSJunchao Zhang }
4385