xref: /petsc/src/mat/impls/aij/mpi/mpihipsparse/mpiaijhipsparse.hip.cxx (revision daba9d70159ea2f6905738fcbec7404635487b2b)
1*d52a580bSJunchao Zhang /* Portions of this code are under:
2*d52a580bSJunchao Zhang    Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
3*d52a580bSJunchao Zhang */
4*d52a580bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h> /*I "petscmat.h" I*/
5*d52a580bSJunchao Zhang #include <../src/mat/impls/aij/seq/seqhipsparse/hipsparsematimpl.h>
6*d52a580bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpihipsparse/mpihipsparsematimpl.h>
7*d52a580bSJunchao Zhang #include <thrust/advance.h>
8*d52a580bSJunchao Zhang #include <thrust/partition.h>
9*d52a580bSJunchao Zhang #include <thrust/sort.h>
10*d52a580bSJunchao Zhang #include <thrust/unique.h>
11*d52a580bSJunchao Zhang #include <petscsf.h>
12*d52a580bSJunchao Zhang 
13*d52a580bSJunchao Zhang struct VecHIPEquals {
14*d52a580bSJunchao Zhang   template <typename Tuple>
operator ()VecHIPEquals15*d52a580bSJunchao Zhang   __host__ __device__ void operator()(Tuple t)
16*d52a580bSJunchao Zhang   {
17*d52a580bSJunchao Zhang     thrust::get<1>(t) = thrust::get<0>(t);
18*d52a580bSJunchao Zhang   }
19*d52a580bSJunchao Zhang };
20*d52a580bSJunchao Zhang 
MatCOOStructDestroy_MPIAIJHIPSPARSE(PetscCtxRt data)21*d52a580bSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_MPIAIJHIPSPARSE(PetscCtxRt data)
22*d52a580bSJunchao Zhang {
23*d52a580bSJunchao Zhang   MatCOOStruct_MPIAIJ *coo = *(MatCOOStruct_MPIAIJ **)data;
24*d52a580bSJunchao Zhang 
25*d52a580bSJunchao Zhang   PetscFunctionBegin;
26*d52a580bSJunchao Zhang   PetscCall(PetscSFDestroy(&coo->sf));
27*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Ajmap1));
28*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Aperm1));
29*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Bjmap1));
30*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Bperm1));
31*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Aimap2));
32*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Ajmap2));
33*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Aperm2));
34*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Bimap2));
35*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Bjmap2));
36*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Bperm2));
37*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->Cperm1));
38*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->sendbuf));
39*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->recvbuf));
40*d52a580bSJunchao Zhang   PetscCall(PetscFree(coo));
41*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
42*d52a580bSJunchao Zhang }
43*d52a580bSJunchao Zhang 
MatSetPreallocationCOO_MPIAIJHIPSPARSE(Mat mat,PetscCount coo_n,PetscInt coo_i[],PetscInt coo_j[])44*d52a580bSJunchao Zhang static PetscErrorCode MatSetPreallocationCOO_MPIAIJHIPSPARSE(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
45*d52a580bSJunchao Zhang {
46*d52a580bSJunchao Zhang   Mat_MPIAIJ          *mpiaij = (Mat_MPIAIJ *)mat->data;
47*d52a580bSJunchao Zhang   PetscBool            dev_ij = PETSC_FALSE;
48*d52a580bSJunchao Zhang   PetscMemType         mtype  = PETSC_MEMTYPE_HOST;
49*d52a580bSJunchao Zhang   PetscInt            *i, *j;
50*d52a580bSJunchao Zhang   PetscContainer       container_h;
51*d52a580bSJunchao Zhang   MatCOOStruct_MPIAIJ *coo_h, *coo_d;
52*d52a580bSJunchao Zhang 
53*d52a580bSJunchao Zhang   PetscFunctionBegin;
54*d52a580bSJunchao Zhang   PetscCall(PetscFree(mpiaij->garray));
55*d52a580bSJunchao Zhang   PetscCall(VecDestroy(&mpiaij->lvec));
56*d52a580bSJunchao Zhang #if defined(PETSC_USE_CTABLE)
57*d52a580bSJunchao Zhang   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
58*d52a580bSJunchao Zhang #else
59*d52a580bSJunchao Zhang   PetscCall(PetscFree(mpiaij->colmap));
60*d52a580bSJunchao Zhang #endif
61*d52a580bSJunchao Zhang   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
62*d52a580bSJunchao Zhang   mat->assembled     = PETSC_FALSE;
63*d52a580bSJunchao Zhang   mat->was_assembled = PETSC_FALSE;
64*d52a580bSJunchao Zhang   PetscCall(PetscGetMemType(coo_i, &mtype));
65*d52a580bSJunchao Zhang   if (PetscMemTypeDevice(mtype)) {
66*d52a580bSJunchao Zhang     dev_ij = PETSC_TRUE;
67*d52a580bSJunchao Zhang     PetscCall(PetscMalloc2(coo_n, &i, coo_n, &j));
68*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(i, coo_i, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost));
69*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(j, coo_j, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost));
70*d52a580bSJunchao Zhang   } else {
71*d52a580bSJunchao Zhang     i = coo_i;
72*d52a580bSJunchao Zhang     j = coo_j;
73*d52a580bSJunchao Zhang   }
74*d52a580bSJunchao Zhang 
75*d52a580bSJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j));
76*d52a580bSJunchao Zhang   if (dev_ij) PetscCall(PetscFree2(i, j));
77*d52a580bSJunchao Zhang   mat->offloadmask = PETSC_OFFLOAD_CPU;
78*d52a580bSJunchao Zhang   // Create the GPU memory
79*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(mpiaij->A));
80*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(mpiaij->B));
81*d52a580bSJunchao Zhang 
82*d52a580bSJunchao Zhang   // Copy the COO struct to device
83*d52a580bSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
84*d52a580bSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, &coo_h));
85*d52a580bSJunchao Zhang   PetscCall(PetscMalloc1(1, &coo_d));
86*d52a580bSJunchao Zhang   *coo_d = *coo_h; // do a shallow copy and then amend fields in coo_d
87*d52a580bSJunchao Zhang 
88*d52a580bSJunchao Zhang   PetscCall(PetscObjectReference((PetscObject)coo_d->sf)); // Since we destroy the sf in both coo_h and coo_d
89*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Ajmap1, (coo_h->Annz + 1) * sizeof(PetscCount)));
90*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Aperm1, coo_h->Atot1 * sizeof(PetscCount)));
91*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Bjmap1, (coo_h->Bnnz + 1) * sizeof(PetscCount)));
92*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Bperm1, coo_h->Btot1 * sizeof(PetscCount)));
93*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Aimap2, coo_h->Annz2 * sizeof(PetscCount)));
94*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Ajmap2, (coo_h->Annz2 + 1) * sizeof(PetscCount)));
95*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Aperm2, coo_h->Atot2 * sizeof(PetscCount)));
96*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Bimap2, coo_h->Bnnz2 * sizeof(PetscCount)));
97*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Bjmap2, (coo_h->Bnnz2 + 1) * sizeof(PetscCount)));
98*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Bperm2, coo_h->Btot2 * sizeof(PetscCount)));
99*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->Cperm1, coo_h->sendlen * sizeof(PetscCount)));
100*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->sendbuf, coo_h->sendlen * sizeof(PetscScalar)));
101*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->recvbuf, coo_h->recvlen * sizeof(PetscScalar)));
102*d52a580bSJunchao Zhang 
103*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Ajmap1, coo_h->Ajmap1, (coo_h->Annz + 1) * sizeof(PetscCount), hipMemcpyHostToDevice));
104*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Aperm1, coo_h->Aperm1, coo_h->Atot1 * sizeof(PetscCount), hipMemcpyHostToDevice));
105*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Bjmap1, coo_h->Bjmap1, (coo_h->Bnnz + 1) * sizeof(PetscCount), hipMemcpyHostToDevice));
106*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Bperm1, coo_h->Bperm1, coo_h->Btot1 * sizeof(PetscCount), hipMemcpyHostToDevice));
107*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Aimap2, coo_h->Aimap2, coo_h->Annz2 * sizeof(PetscCount), hipMemcpyHostToDevice));
108*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Ajmap2, coo_h->Ajmap2, (coo_h->Annz2 + 1) * sizeof(PetscCount), hipMemcpyHostToDevice));
109*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Aperm2, coo_h->Aperm2, coo_h->Atot2 * sizeof(PetscCount), hipMemcpyHostToDevice));
110*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Bimap2, coo_h->Bimap2, coo_h->Bnnz2 * sizeof(PetscCount), hipMemcpyHostToDevice));
111*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Bjmap2, coo_h->Bjmap2, (coo_h->Bnnz2 + 1) * sizeof(PetscCount), hipMemcpyHostToDevice));
112*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Bperm2, coo_h->Bperm2, coo_h->Btot2 * sizeof(PetscCount), hipMemcpyHostToDevice));
113*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->Cperm1, coo_h->Cperm1, coo_h->sendlen * sizeof(PetscCount), hipMemcpyHostToDevice));
114*d52a580bSJunchao Zhang 
115*d52a580bSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
116*d52a580bSJunchao Zhang   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_MPIAIJHIPSPARSE));
117*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
118*d52a580bSJunchao Zhang }
119*d52a580bSJunchao Zhang 
MatPackCOOValues(const PetscScalar kv[],PetscCount nnz,const PetscCount perm[],PetscScalar buf[])120*d52a580bSJunchao Zhang __global__ static void MatPackCOOValues(const PetscScalar kv[], PetscCount nnz, const PetscCount perm[], PetscScalar buf[])
121*d52a580bSJunchao Zhang {
122*d52a580bSJunchao Zhang   PetscCount       i         = blockIdx.x * blockDim.x + threadIdx.x;
123*d52a580bSJunchao Zhang   const PetscCount grid_size = gridDim.x * blockDim.x;
124*d52a580bSJunchao Zhang   for (; i < nnz; i += grid_size) buf[i] = kv[perm[i]];
125*d52a580bSJunchao Zhang }
126*d52a580bSJunchao Zhang 
MatAddLocalCOOValues(const PetscScalar kv[],InsertMode imode,PetscCount Annz,const PetscCount Ajmap1[],const PetscCount Aperm1[],PetscScalar Aa[],PetscCount Bnnz,const PetscCount Bjmap1[],const PetscCount Bperm1[],PetscScalar Ba[])127*d52a580bSJunchao Zhang __global__ static void MatAddLocalCOOValues(const PetscScalar kv[], InsertMode imode, PetscCount Annz, const PetscCount Ajmap1[], const PetscCount Aperm1[], PetscScalar Aa[], PetscCount Bnnz, const PetscCount Bjmap1[], const PetscCount Bperm1[], PetscScalar Ba[])
128*d52a580bSJunchao Zhang {
129*d52a580bSJunchao Zhang   PetscCount       i         = blockIdx.x * blockDim.x + threadIdx.x;
130*d52a580bSJunchao Zhang   const PetscCount grid_size = gridDim.x * blockDim.x;
131*d52a580bSJunchao Zhang   for (; i < Annz + Bnnz; i += grid_size) {
132*d52a580bSJunchao Zhang     PetscScalar sum = 0.0;
133*d52a580bSJunchao Zhang     if (i < Annz) {
134*d52a580bSJunchao Zhang       for (PetscCount k = Ajmap1[i]; k < Ajmap1[i + 1]; k++) sum += kv[Aperm1[k]];
135*d52a580bSJunchao Zhang       Aa[i] = (imode == INSERT_VALUES ? 0.0 : Aa[i]) + sum;
136*d52a580bSJunchao Zhang     } else {
137*d52a580bSJunchao Zhang       i -= Annz;
138*d52a580bSJunchao Zhang       for (PetscCount k = Bjmap1[i]; k < Bjmap1[i + 1]; k++) sum += kv[Bperm1[k]];
139*d52a580bSJunchao Zhang       Ba[i] = (imode == INSERT_VALUES ? 0.0 : Ba[i]) + sum;
140*d52a580bSJunchao Zhang     }
141*d52a580bSJunchao Zhang   }
142*d52a580bSJunchao Zhang }
143*d52a580bSJunchao Zhang 
MatAddRemoteCOOValues(const PetscScalar kv[],PetscCount Annz2,const PetscCount Aimap2[],const PetscCount Ajmap2[],const PetscCount Aperm2[],PetscScalar Aa[],PetscCount Bnnz2,const PetscCount Bimap2[],const PetscCount Bjmap2[],const PetscCount Bperm2[],PetscScalar Ba[])144*d52a580bSJunchao Zhang __global__ static void MatAddRemoteCOOValues(const PetscScalar kv[], PetscCount Annz2, const PetscCount Aimap2[], const PetscCount Ajmap2[], const PetscCount Aperm2[], PetscScalar Aa[], PetscCount Bnnz2, const PetscCount Bimap2[], const PetscCount Bjmap2[], const PetscCount Bperm2[], PetscScalar Ba[])
145*d52a580bSJunchao Zhang {
146*d52a580bSJunchao Zhang   PetscCount       i         = blockIdx.x * blockDim.x + threadIdx.x;
147*d52a580bSJunchao Zhang   const PetscCount grid_size = gridDim.x * blockDim.x;
148*d52a580bSJunchao Zhang   for (; i < Annz2 + Bnnz2; i += grid_size) {
149*d52a580bSJunchao Zhang     if (i < Annz2) {
150*d52a580bSJunchao Zhang       for (PetscCount k = Ajmap2[i]; k < Ajmap2[i + 1]; k++) Aa[Aimap2[i]] += kv[Aperm2[k]];
151*d52a580bSJunchao Zhang     } else {
152*d52a580bSJunchao Zhang       i -= Annz2;
153*d52a580bSJunchao Zhang       for (PetscCount k = Bjmap2[i]; k < Bjmap2[i + 1]; k++) Ba[Bimap2[i]] += kv[Bperm2[k]];
154*d52a580bSJunchao Zhang     }
155*d52a580bSJunchao Zhang   }
156*d52a580bSJunchao Zhang }
157*d52a580bSJunchao Zhang 
MatSetValuesCOO_MPIAIJHIPSPARSE(Mat mat,const PetscScalar v[],InsertMode imode)158*d52a580bSJunchao Zhang static PetscErrorCode MatSetValuesCOO_MPIAIJHIPSPARSE(Mat mat, const PetscScalar v[], InsertMode imode)
159*d52a580bSJunchao Zhang {
160*d52a580bSJunchao Zhang   Mat_MPIAIJ          *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
161*d52a580bSJunchao Zhang   Mat                  A = mpiaij->A, B = mpiaij->B;
162*d52a580bSJunchao Zhang   PetscScalar         *Aa, *Ba;
163*d52a580bSJunchao Zhang   const PetscScalar   *v1 = v;
164*d52a580bSJunchao Zhang   PetscMemType         memtype;
165*d52a580bSJunchao Zhang   PetscContainer       container;
166*d52a580bSJunchao Zhang   MatCOOStruct_MPIAIJ *coo;
167*d52a580bSJunchao Zhang 
168*d52a580bSJunchao Zhang   PetscFunctionBegin;
169*d52a580bSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
170*d52a580bSJunchao Zhang   PetscCheck(container, PetscObjectComm((PetscObject)mat), PETSC_ERR_PLIB, "Not found MatCOOStruct on this matrix");
171*d52a580bSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, &coo));
172*d52a580bSJunchao Zhang 
173*d52a580bSJunchao Zhang   const auto &Annz   = coo->Annz;
174*d52a580bSJunchao Zhang   const auto &Annz2  = coo->Annz2;
175*d52a580bSJunchao Zhang   const auto &Bnnz   = coo->Bnnz;
176*d52a580bSJunchao Zhang   const auto &Bnnz2  = coo->Bnnz2;
177*d52a580bSJunchao Zhang   const auto &vsend  = coo->sendbuf;
178*d52a580bSJunchao Zhang   const auto &v2     = coo->recvbuf;
179*d52a580bSJunchao Zhang   const auto &Ajmap1 = coo->Ajmap1;
180*d52a580bSJunchao Zhang   const auto &Ajmap2 = coo->Ajmap2;
181*d52a580bSJunchao Zhang   const auto &Aimap2 = coo->Aimap2;
182*d52a580bSJunchao Zhang   const auto &Bjmap1 = coo->Bjmap1;
183*d52a580bSJunchao Zhang   const auto &Bjmap2 = coo->Bjmap2;
184*d52a580bSJunchao Zhang   const auto &Bimap2 = coo->Bimap2;
185*d52a580bSJunchao Zhang   const auto &Aperm1 = coo->Aperm1;
186*d52a580bSJunchao Zhang   const auto &Aperm2 = coo->Aperm2;
187*d52a580bSJunchao Zhang   const auto &Bperm1 = coo->Bperm1;
188*d52a580bSJunchao Zhang   const auto &Bperm2 = coo->Bperm2;
189*d52a580bSJunchao Zhang   const auto &Cperm1 = coo->Cperm1;
190*d52a580bSJunchao Zhang 
191*d52a580bSJunchao Zhang   PetscCall(PetscGetMemType(v, &memtype));
192*d52a580bSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device */
193*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&v1, coo->n * sizeof(PetscScalar)));
194*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy((void *)v1, v, coo->n * sizeof(PetscScalar), hipMemcpyHostToDevice));
195*d52a580bSJunchao Zhang   }
196*d52a580bSJunchao Zhang 
197*d52a580bSJunchao Zhang   if (imode == INSERT_VALUES) {
198*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArrayWrite(A, &Aa)); /* write matrix values */
199*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArrayWrite(B, &Ba));
200*d52a580bSJunchao Zhang   } else {
201*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArray(A, &Aa)); /* read & write matrix values */
202*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArray(B, &Ba));
203*d52a580bSJunchao Zhang   }
204*d52a580bSJunchao Zhang 
205*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
206*d52a580bSJunchao Zhang   /* Pack entries to be sent to remote */
207*d52a580bSJunchao Zhang   if (coo->sendlen) {
208*d52a580bSJunchao Zhang     hipLaunchKernelGGL(HIP_KERNEL_NAME(MatPackCOOValues), dim3((coo->sendlen + 255) / 256), dim3(256), 0, PetscDefaultHipStream, v1, coo->sendlen, Cperm1, vsend);
209*d52a580bSJunchao Zhang     PetscCallHIP(hipPeekAtLastError());
210*d52a580bSJunchao Zhang   }
211*d52a580bSJunchao Zhang 
212*d52a580bSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
213*d52a580bSJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_HIP, vsend, PETSC_MEMTYPE_HIP, v2, MPI_REPLACE));
214*d52a580bSJunchao Zhang   /* Add local entries to A and B */
215*d52a580bSJunchao Zhang   if (Annz + Bnnz > 0) {
216*d52a580bSJunchao Zhang     hipLaunchKernelGGL(HIP_KERNEL_NAME(MatAddLocalCOOValues), dim3((Annz + Bnnz + 255) / 256), dim3(256), 0, PetscDefaultHipStream, v1, imode, Annz, Ajmap1, Aperm1, Aa, Bnnz, Bjmap1, Bperm1, Ba);
217*d52a580bSJunchao Zhang     PetscCallHIP(hipPeekAtLastError());
218*d52a580bSJunchao Zhang   }
219*d52a580bSJunchao Zhang   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend, v2, MPI_REPLACE));
220*d52a580bSJunchao Zhang 
221*d52a580bSJunchao Zhang   /* Add received remote entries to A and B */
222*d52a580bSJunchao Zhang   if (Annz2 + Bnnz2 > 0) {
223*d52a580bSJunchao Zhang     hipLaunchKernelGGL(HIP_KERNEL_NAME(MatAddRemoteCOOValues), dim3((Annz2 + Bnnz2 + 255) / 256), dim3(256), 0, PetscDefaultHipStream, v2, Annz2, Aimap2, Ajmap2, Aperm2, Aa, Bnnz2, Bimap2, Bjmap2, Bperm2, Ba);
224*d52a580bSJunchao Zhang     PetscCallHIP(hipPeekAtLastError());
225*d52a580bSJunchao Zhang   }
226*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
227*d52a580bSJunchao Zhang 
228*d52a580bSJunchao Zhang   if (imode == INSERT_VALUES) {
229*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArrayWrite(A, &Aa));
230*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArrayWrite(B, &Ba));
231*d52a580bSJunchao Zhang   } else {
232*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArray(A, &Aa));
233*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArray(B, &Ba));
234*d52a580bSJunchao Zhang   }
235*d52a580bSJunchao Zhang   if (PetscMemTypeHost(memtype)) PetscCallHIP(hipFree((void *)v1));
236*d52a580bSJunchao Zhang   mat->offloadmask = PETSC_OFFLOAD_GPU;
237*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
238*d52a580bSJunchao Zhang }
239*d52a580bSJunchao Zhang 
MatMPIAIJGetLocalMatMerge_MPIAIJHIPSPARSE(Mat A,MatReuse scall,IS * glob,Mat * A_loc)240*d52a580bSJunchao Zhang static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJHIPSPARSE(Mat A, MatReuse scall, IS *glob, Mat *A_loc)
241*d52a580bSJunchao Zhang {
242*d52a580bSJunchao Zhang   Mat             Ad, Ao;
243*d52a580bSJunchao Zhang   const PetscInt *cmap;
244*d52a580bSJunchao Zhang 
245*d52a580bSJunchao Zhang   PetscFunctionBegin;
246*d52a580bSJunchao Zhang   PetscCall(MatMPIAIJGetSeqAIJ(A, &Ad, &Ao, &cmap));
247*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEMergeMats(Ad, Ao, scall, A_loc));
248*d52a580bSJunchao Zhang   if (glob) {
249*d52a580bSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
250*d52a580bSJunchao Zhang 
251*d52a580bSJunchao Zhang     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
252*d52a580bSJunchao Zhang     PetscCall(MatGetLocalSize(Ao, NULL, &on));
253*d52a580bSJunchao Zhang     PetscCall(MatGetOwnershipRangeColumn(A, &cst, NULL));
254*d52a580bSJunchao Zhang     PetscCall(PetscMalloc1(dn + on, &gidx));
255*d52a580bSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
256*d52a580bSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
257*d52a580bSJunchao Zhang     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
258*d52a580bSJunchao Zhang   }
259*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
260*d52a580bSJunchao Zhang }
261*d52a580bSJunchao Zhang 
MatMPIAIJSetPreallocation_MPIAIJHIPSPARSE(Mat B,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[])262*d52a580bSJunchao Zhang static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJHIPSPARSE(Mat B, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
263*d52a580bSJunchao Zhang {
264*d52a580bSJunchao Zhang   Mat_MPIAIJ          *b               = (Mat_MPIAIJ *)B->data;
265*d52a580bSJunchao Zhang   Mat_MPIAIJHIPSPARSE *hipsparseStruct = (Mat_MPIAIJHIPSPARSE *)b->spptr;
266*d52a580bSJunchao Zhang   PetscInt             i;
267*d52a580bSJunchao Zhang 
268*d52a580bSJunchao Zhang   PetscFunctionBegin;
269*d52a580bSJunchao Zhang   if (B->hash_active) {
270*d52a580bSJunchao Zhang     B->ops[0]      = b->cops;
271*d52a580bSJunchao Zhang     B->hash_active = PETSC_FALSE;
272*d52a580bSJunchao Zhang   }
273*d52a580bSJunchao Zhang   PetscCall(PetscLayoutSetUp(B->rmap));
274*d52a580bSJunchao Zhang   PetscCall(PetscLayoutSetUp(B->cmap));
275*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG) && d_nnz) {
276*d52a580bSJunchao Zhang     for (i = 0; i < B->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
277*d52a580bSJunchao Zhang   }
278*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG) && o_nnz) {
279*d52a580bSJunchao Zhang     for (i = 0; i < B->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
280*d52a580bSJunchao Zhang   }
281*d52a580bSJunchao Zhang #if defined(PETSC_USE_CTABLE)
282*d52a580bSJunchao Zhang   PetscCall(PetscHMapIDestroy(&b->colmap));
283*d52a580bSJunchao Zhang #else
284*d52a580bSJunchao Zhang   PetscCall(PetscFree(b->colmap));
285*d52a580bSJunchao Zhang #endif
286*d52a580bSJunchao Zhang   PetscCall(PetscFree(b->garray));
287*d52a580bSJunchao Zhang   PetscCall(VecDestroy(&b->lvec));
288*d52a580bSJunchao Zhang   PetscCall(VecScatterDestroy(&b->Mvctx));
289*d52a580bSJunchao Zhang   /* Because the B will have been resized we simply destroy it and create a new one each time */
290*d52a580bSJunchao Zhang   PetscCall(MatDestroy(&b->B));
291*d52a580bSJunchao Zhang   if (!b->A) {
292*d52a580bSJunchao Zhang     PetscCall(MatCreate(PETSC_COMM_SELF, &b->A));
293*d52a580bSJunchao Zhang     PetscCall(MatSetSizes(b->A, B->rmap->n, B->cmap->n, B->rmap->n, B->cmap->n));
294*d52a580bSJunchao Zhang   }
295*d52a580bSJunchao Zhang   if (!b->B) {
296*d52a580bSJunchao Zhang     PetscMPIInt size;
297*d52a580bSJunchao Zhang     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)B), &size));
298*d52a580bSJunchao Zhang     PetscCall(MatCreate(PETSC_COMM_SELF, &b->B));
299*d52a580bSJunchao Zhang     PetscCall(MatSetSizes(b->B, B->rmap->n, size > 1 ? B->cmap->N : 0, B->rmap->n, size > 1 ? B->cmap->N : 0));
300*d52a580bSJunchao Zhang   }
301*d52a580bSJunchao Zhang   PetscCall(MatSetType(b->A, MATSEQAIJHIPSPARSE));
302*d52a580bSJunchao Zhang   PetscCall(MatSetType(b->B, MATSEQAIJHIPSPARSE));
303*d52a580bSJunchao Zhang   PetscCall(MatBindToCPU(b->A, B->boundtocpu));
304*d52a580bSJunchao Zhang   PetscCall(MatBindToCPU(b->B, B->boundtocpu));
305*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJSetPreallocation(b->A, d_nz, d_nnz));
306*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJSetPreallocation(b->B, o_nz, o_nnz));
307*d52a580bSJunchao Zhang   PetscCall(MatHIPSPARSESetFormat(b->A, MAT_HIPSPARSE_MULT, hipsparseStruct->diagGPUMatFormat));
308*d52a580bSJunchao Zhang   PetscCall(MatHIPSPARSESetFormat(b->B, MAT_HIPSPARSE_MULT, hipsparseStruct->offdiagGPUMatFormat));
309*d52a580bSJunchao Zhang   B->preallocated  = PETSC_TRUE;
310*d52a580bSJunchao Zhang   B->was_assembled = PETSC_FALSE;
311*d52a580bSJunchao Zhang   B->assembled     = PETSC_FALSE;
312*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
313*d52a580bSJunchao Zhang }
314*d52a580bSJunchao Zhang 
MatMult_MPIAIJHIPSPARSE(Mat A,Vec xx,Vec yy)315*d52a580bSJunchao Zhang static PetscErrorCode MatMult_MPIAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
316*d52a580bSJunchao Zhang {
317*d52a580bSJunchao Zhang   Mat_MPIAIJ *a = (Mat_MPIAIJ *)A->data;
318*d52a580bSJunchao Zhang 
319*d52a580bSJunchao Zhang   PetscFunctionBegin;
320*d52a580bSJunchao Zhang   PetscCall(VecScatterBegin(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
321*d52a580bSJunchao Zhang   PetscCall((*a->A->ops->mult)(a->A, xx, yy));
322*d52a580bSJunchao Zhang   PetscCall(VecScatterEnd(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
323*d52a580bSJunchao Zhang   PetscCall((*a->B->ops->multadd)(a->B, a->lvec, yy, yy));
324*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
325*d52a580bSJunchao Zhang }
326*d52a580bSJunchao Zhang 
MatZeroEntries_MPIAIJHIPSPARSE(Mat A)327*d52a580bSJunchao Zhang static PetscErrorCode MatZeroEntries_MPIAIJHIPSPARSE(Mat A)
328*d52a580bSJunchao Zhang {
329*d52a580bSJunchao Zhang   Mat_MPIAIJ *l = (Mat_MPIAIJ *)A->data;
330*d52a580bSJunchao Zhang 
331*d52a580bSJunchao Zhang   PetscFunctionBegin;
332*d52a580bSJunchao Zhang   PetscCall(MatZeroEntries(l->A));
333*d52a580bSJunchao Zhang   PetscCall(MatZeroEntries(l->B));
334*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
335*d52a580bSJunchao Zhang }
336*d52a580bSJunchao Zhang 
MatMultAdd_MPIAIJHIPSPARSE(Mat A,Vec xx,Vec yy,Vec zz)337*d52a580bSJunchao Zhang static PetscErrorCode MatMultAdd_MPIAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
338*d52a580bSJunchao Zhang {
339*d52a580bSJunchao Zhang   Mat_MPIAIJ *a = (Mat_MPIAIJ *)A->data;
340*d52a580bSJunchao Zhang 
341*d52a580bSJunchao Zhang   PetscFunctionBegin;
342*d52a580bSJunchao Zhang   PetscCall(VecScatterBegin(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
343*d52a580bSJunchao Zhang   PetscCall((*a->A->ops->multadd)(a->A, xx, yy, zz));
344*d52a580bSJunchao Zhang   PetscCall(VecScatterEnd(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
345*d52a580bSJunchao Zhang   PetscCall((*a->B->ops->multadd)(a->B, a->lvec, zz, zz));
346*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
347*d52a580bSJunchao Zhang }
348*d52a580bSJunchao Zhang 
MatMultTranspose_MPIAIJHIPSPARSE(Mat A,Vec xx,Vec yy)349*d52a580bSJunchao Zhang static PetscErrorCode MatMultTranspose_MPIAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
350*d52a580bSJunchao Zhang {
351*d52a580bSJunchao Zhang   Mat_MPIAIJ *a = (Mat_MPIAIJ *)A->data;
352*d52a580bSJunchao Zhang 
353*d52a580bSJunchao Zhang   PetscFunctionBegin;
354*d52a580bSJunchao Zhang   PetscCall((*a->B->ops->multtranspose)(a->B, xx, a->lvec));
355*d52a580bSJunchao Zhang   PetscCall((*a->A->ops->multtranspose)(a->A, xx, yy));
356*d52a580bSJunchao Zhang   PetscCall(VecScatterBegin(a->Mvctx, a->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
357*d52a580bSJunchao Zhang   PetscCall(VecScatterEnd(a->Mvctx, a->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
358*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
359*d52a580bSJunchao Zhang }
360*d52a580bSJunchao Zhang 
MatHIPSPARSESetFormat_MPIAIJHIPSPARSE(Mat A,MatHIPSPARSEFormatOperation op,MatHIPSPARSEStorageFormat format)361*d52a580bSJunchao Zhang static PetscErrorCode MatHIPSPARSESetFormat_MPIAIJHIPSPARSE(Mat A, MatHIPSPARSEFormatOperation op, MatHIPSPARSEStorageFormat format)
362*d52a580bSJunchao Zhang {
363*d52a580bSJunchao Zhang   Mat_MPIAIJ          *a               = (Mat_MPIAIJ *)A->data;
364*d52a580bSJunchao Zhang   Mat_MPIAIJHIPSPARSE *hipsparseStruct = (Mat_MPIAIJHIPSPARSE *)a->spptr;
365*d52a580bSJunchao Zhang 
366*d52a580bSJunchao Zhang   PetscFunctionBegin;
367*d52a580bSJunchao Zhang   switch (op) {
368*d52a580bSJunchao Zhang   case MAT_HIPSPARSE_MULT_DIAG:
369*d52a580bSJunchao Zhang     hipsparseStruct->diagGPUMatFormat = format;
370*d52a580bSJunchao Zhang     break;
371*d52a580bSJunchao Zhang   case MAT_HIPSPARSE_MULT_OFFDIAG:
372*d52a580bSJunchao Zhang     hipsparseStruct->offdiagGPUMatFormat = format;
373*d52a580bSJunchao Zhang     break;
374*d52a580bSJunchao Zhang   case MAT_HIPSPARSE_ALL:
375*d52a580bSJunchao Zhang     hipsparseStruct->diagGPUMatFormat    = format;
376*d52a580bSJunchao Zhang     hipsparseStruct->offdiagGPUMatFormat = format;
377*d52a580bSJunchao Zhang     break;
378*d52a580bSJunchao Zhang   default:
379*d52a580bSJunchao Zhang     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "unsupported operation %d for MatHIPSPARSEFormatOperation. Only MAT_HIPSPARSE_MULT_DIAG, MAT_HIPSPARSE_MULT_DIAG, and MAT_HIPSPARSE_MULT_ALL are currently supported.", op);
380*d52a580bSJunchao Zhang   }
381*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
382*d52a580bSJunchao Zhang }
383*d52a580bSJunchao Zhang 
MatSetFromOptions_MPIAIJHIPSPARSE(Mat A,PetscOptionItems PetscOptionsObject)384*d52a580bSJunchao Zhang static PetscErrorCode MatSetFromOptions_MPIAIJHIPSPARSE(Mat A, PetscOptionItems PetscOptionsObject)
385*d52a580bSJunchao Zhang {
386*d52a580bSJunchao Zhang   MatHIPSPARSEStorageFormat format;
387*d52a580bSJunchao Zhang   PetscBool                 flg;
388*d52a580bSJunchao Zhang   Mat_MPIAIJ               *a               = (Mat_MPIAIJ *)A->data;
389*d52a580bSJunchao Zhang   Mat_MPIAIJHIPSPARSE      *hipsparseStruct = (Mat_MPIAIJHIPSPARSE *)a->spptr;
390*d52a580bSJunchao Zhang 
391*d52a580bSJunchao Zhang   PetscFunctionBegin;
392*d52a580bSJunchao Zhang   PetscOptionsHeadBegin(PetscOptionsObject, "MPIAIJHIPSPARSE options");
393*d52a580bSJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
394*d52a580bSJunchao Zhang     PetscCall(PetscOptionsEnum("-mat_hipsparse_mult_diag_storage_format", "sets storage format of the diagonal blocks of (mpi)aijhipsparse gpu matrices for SpMV", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparseStruct->diagGPUMatFormat, (PetscEnum *)&format, &flg));
395*d52a580bSJunchao Zhang     if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_MULT_DIAG, format));
396*d52a580bSJunchao Zhang     PetscCall(PetscOptionsEnum("-mat_hipsparse_mult_offdiag_storage_format", "sets storage format of the off-diagonal blocks (mpi)aijhipsparse gpu matrices for SpMV", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparseStruct->offdiagGPUMatFormat, (PetscEnum *)&format, &flg));
397*d52a580bSJunchao Zhang     if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_MULT_OFFDIAG, format));
398*d52a580bSJunchao Zhang     PetscCall(PetscOptionsEnum("-mat_hipsparse_storage_format", "sets storage format of the diagonal and off-diagonal blocks (mpi)aijhipsparse gpu matrices for SpMV", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparseStruct->diagGPUMatFormat, (PetscEnum *)&format, &flg));
399*d52a580bSJunchao Zhang     if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_ALL, format));
400*d52a580bSJunchao Zhang   }
401*d52a580bSJunchao Zhang   PetscOptionsHeadEnd();
402*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
403*d52a580bSJunchao Zhang }
404*d52a580bSJunchao Zhang 
MatAssemblyEnd_MPIAIJHIPSPARSE(Mat A,MatAssemblyType mode)405*d52a580bSJunchao Zhang static PetscErrorCode MatAssemblyEnd_MPIAIJHIPSPARSE(Mat A, MatAssemblyType mode)
406*d52a580bSJunchao Zhang {
407*d52a580bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
408*d52a580bSJunchao Zhang 
409*d52a580bSJunchao Zhang   PetscFunctionBegin;
410*d52a580bSJunchao Zhang   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
411*d52a580bSJunchao Zhang   if (mpiaij->lvec) PetscCall(VecSetType(mpiaij->lvec, VECSEQHIP));
412*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
413*d52a580bSJunchao Zhang }
414*d52a580bSJunchao Zhang 
MatDestroy_MPIAIJHIPSPARSE(Mat A)415*d52a580bSJunchao Zhang static PetscErrorCode MatDestroy_MPIAIJHIPSPARSE(Mat A)
416*d52a580bSJunchao Zhang {
417*d52a580bSJunchao Zhang   Mat_MPIAIJ          *aij             = (Mat_MPIAIJ *)A->data;
418*d52a580bSJunchao Zhang   Mat_MPIAIJHIPSPARSE *hipsparseStruct = (Mat_MPIAIJHIPSPARSE *)aij->spptr;
419*d52a580bSJunchao Zhang 
420*d52a580bSJunchao Zhang   PetscFunctionBegin;
421*d52a580bSJunchao Zhang   PetscCheck(hipsparseStruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr");
422*d52a580bSJunchao Zhang   PetscCallCXX(delete hipsparseStruct);
423*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
424*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
425*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
426*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
427*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetFormat_C", NULL));
428*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijhipsparse_hypre_C", NULL));
429*d52a580bSJunchao Zhang   PetscCall(MatDestroy_MPIAIJ(A));
430*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
431*d52a580bSJunchao Zhang }
432*d52a580bSJunchao Zhang 
/* Convert an MPIAIJ matrix to MATMPIAIJHIPSPARSE.

   Depending on reuse, the target A is a duplicate of B (MAT_INITIAL_MATRIX),
   an existing matrix B is copied into (MAT_REUSE_MATRIX), or B itself
   (MAT_INPLACE_MATRIX).  The sequential diagonal/off-diagonal blocks are
   retyped to MATSEQAIJHIPSPARSE, the scatter work vector to VECSEQHIP, the
   GPU function table is installed, and the type-specific operations are
   composed on the object. */
PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJHIPSPARSE(Mat B, MatType mtype, MatReuse reuse, Mat *newmat)
{
  Mat_MPIAIJ *a;
  Mat         A;

  PetscFunctionBegin;
  /* ensure a HIP device exists before any GPU storage is created */
  PetscCall(PetscDeviceInitialize(PETSC_DEVICE_HIP));
  if (reuse == MAT_INITIAL_MATRIX) PetscCall(MatDuplicate(B, MAT_COPY_VALUES, newmat));
  else if (reuse == MAT_REUSE_MATRIX) PetscCall(MatCopy(B, *newmat, SAME_NONZERO_PATTERN));
  /* for MAT_INPLACE_MATRIX, *newmat is B itself */
  A             = *newmat;
  A->boundtocpu = PETSC_FALSE;
  /* vectors created from this matrix should default to GPU vectors */
  PetscCall(PetscFree(A->defaultvectype));
  PetscCall(PetscStrallocpy(VECHIP, &A->defaultvectype));

  a = (Mat_MPIAIJ *)A->data;
  if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJHIPSPARSE));
  if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJHIPSPARSE));
  if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQHIP));

  /* with MAT_REUSE_MATRIX the target may already carry a Mat_MPIAIJHIPSPARSE */
  if (reuse != MAT_REUSE_MATRIX && !a->spptr) PetscCallCXX(a->spptr = new Mat_MPIAIJHIPSPARSE);

  A->ops->assemblyend           = MatAssemblyEnd_MPIAIJHIPSPARSE;
  A->ops->mult                  = MatMult_MPIAIJHIPSPARSE;
  A->ops->multadd               = MatMultAdd_MPIAIJHIPSPARSE;
  A->ops->multtranspose         = MatMultTranspose_MPIAIJHIPSPARSE;
  A->ops->setfromoptions        = MatSetFromOptions_MPIAIJHIPSPARSE;
  A->ops->destroy               = MatDestroy_MPIAIJHIPSPARSE;
  A->ops->zeroentries           = MatZeroEntries_MPIAIJHIPSPARSE;
  A->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJBACKEND;
  A->ops->getcurrentmemtype     = MatGetCurrentMemType_MPIAIJ;

  PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATMPIAIJHIPSPARSE));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJHIPSPARSE));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJHIPSPARSE));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetFormat_C", MatHIPSPARSESetFormat_MPIAIJHIPSPARSE));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJHIPSPARSE));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJHIPSPARSE));
#if defined(PETSC_HAVE_HYPRE)
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijhipsparse_hypre_C", MatConvert_AIJ_HYPRE));
#endif
  PetscFunctionReturn(PETSC_SUCCESS);
}
475*d52a580bSJunchao Zhang 
/* Constructor registered for MATMPIAIJHIPSPARSE: build a plain MPIAIJ matrix
   and convert it in place to the HIPSPARSE type. */
PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJHIPSPARSE(Mat A)
{
  PetscFunctionBegin;
  /* make sure a HIP device is initialized before any GPU data is touched */
  PetscCall(PetscDeviceInitialize(PETSC_DEVICE_HIP));
  PetscCall(MatCreate_MPIAIJ(A));
  PetscCall(MatConvert_MPIAIJ_MPIAIJHIPSPARSE(A, MATMPIAIJHIPSPARSE, MAT_INPLACE_MATRIX, &A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
484*d52a580bSJunchao Zhang 
485*d52a580bSJunchao Zhang /*@
486*d52a580bSJunchao Zhang   MatCreateAIJHIPSPARSE - Creates a sparse matrix in AIJ (compressed row) format
487*d52a580bSJunchao Zhang   (the default parallel PETSc format).
488*d52a580bSJunchao Zhang 
489*d52a580bSJunchao Zhang   Collective
490*d52a580bSJunchao Zhang 
491*d52a580bSJunchao Zhang   Input Parameters:
492*d52a580bSJunchao Zhang + comm  - MPI communicator, set to `PETSC_COMM_SELF`
493*d52a580bSJunchao Zhang . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
494*d52a580bSJunchao Zhang            This value should be the same as the local size used in creating the
495*d52a580bSJunchao Zhang            y vector for the matrix-vector product y = Ax.
496*d52a580bSJunchao Zhang . n     - This value should be the same as the local size used in creating the
497*d52a580bSJunchao Zhang        x vector for the matrix-vector product y = Ax. (or PETSC_DECIDE to have
498*d52a580bSJunchao Zhang        calculated if `N` is given) For square matrices `n` is almost always `m`.
499*d52a580bSJunchao Zhang . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
500*d52a580bSJunchao Zhang . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
501*d52a580bSJunchao Zhang . d_nz  - number of nonzeros per row (same for all rows), for the "diagonal" portion of the matrix
502*d52a580bSJunchao Zhang . d_nnz - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`, for the "diagonal" portion of the matrix
503*d52a580bSJunchao Zhang . o_nz  - number of nonzeros per row (same for all rows), for the "off-diagonal" portion of the matrix
504*d52a580bSJunchao Zhang - o_nnz - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`, for the "off-diagonal" portion of the matrix
505*d52a580bSJunchao Zhang 
506*d52a580bSJunchao Zhang   Output Parameter:
507*d52a580bSJunchao Zhang . A - the matrix
508*d52a580bSJunchao Zhang 
509*d52a580bSJunchao Zhang   Level: intermediate
510*d52a580bSJunchao Zhang 
511*d52a580bSJunchao Zhang   Notes:
  This matrix will ultimately be pushed down to AMD GPUs and use the HIPSPARSE library for
  calculations. For good matrix assembly performance the user should preallocate the matrix
  storage by setting the parameters `d_nz`/`o_nz` (or the arrays `d_nnz`/`o_nnz`).
515*d52a580bSJunchao Zhang 
516*d52a580bSJunchao Zhang   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
517*d52a580bSJunchao Zhang   MatXXXXSetPreallocation() paradigm instead of this routine directly.
518*d52a580bSJunchao Zhang   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
519*d52a580bSJunchao Zhang 
520*d52a580bSJunchao Zhang   If `d_nnz` (`o_nnz`) is given then `d_nz` (`o_nz`) is ignored
521*d52a580bSJunchao Zhang 
522*d52a580bSJunchao Zhang   The `MATAIJ` format (compressed row storage), is fully compatible with standard Fortran
523*d52a580bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
524*d52a580bSJunchao Zhang   either one (as in Fortran) or zero.
525*d52a580bSJunchao Zhang 
526*d52a580bSJunchao Zhang   Specify the preallocated storage with either `d_nz` (`o_nz`) or `d_nnz` (`o_nnz`) (not both).
527*d52a580bSJunchao Zhang   Set `d_nz` (`o_nz`) = `PETSC_DEFAULT` and `d_nnz` (`o_nnz`) = `NULL` for PETSc to control dynamic memory
528*d52a580bSJunchao Zhang   allocation.
529*d52a580bSJunchao Zhang 
530*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJHIPSPARSE`, `MATAIJHIPSPARSE`
531*d52a580bSJunchao Zhang @*/
MatCreateAIJHIPSPARSE(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt M,PetscInt N,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[],Mat * A)532*d52a580bSJunchao Zhang PetscErrorCode MatCreateAIJHIPSPARSE(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
533*d52a580bSJunchao Zhang {
534*d52a580bSJunchao Zhang   PetscMPIInt size;
535*d52a580bSJunchao Zhang 
536*d52a580bSJunchao Zhang   PetscFunctionBegin;
537*d52a580bSJunchao Zhang   PetscCall(MatCreate(comm, A));
538*d52a580bSJunchao Zhang   PetscCall(MatSetSizes(*A, m, n, M, N));
539*d52a580bSJunchao Zhang   PetscCallMPI(MPI_Comm_size(comm, &size));
540*d52a580bSJunchao Zhang   if (size > 1) {
541*d52a580bSJunchao Zhang     PetscCall(MatSetType(*A, MATMPIAIJHIPSPARSE));
542*d52a580bSJunchao Zhang     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
543*d52a580bSJunchao Zhang   } else {
544*d52a580bSJunchao Zhang     PetscCall(MatSetType(*A, MATSEQAIJHIPSPARSE));
545*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
546*d52a580bSJunchao Zhang   }
547*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
548*d52a580bSJunchao Zhang }
549*d52a580bSJunchao Zhang 
550*d52a580bSJunchao Zhang /*MC
   MATAIJHIPSPARSE - A matrix type to be used for sparse matrices; it is the same as `MATMPIAIJHIPSPARSE`.
552*d52a580bSJunchao Zhang 
553*d52a580bSJunchao Zhang    A matrix type whose data resides on GPUs. These matrices can be in either
554*d52a580bSJunchao Zhang    CSR, ELL, or Hybrid format. All matrix calculations are performed on AMD GPUs using the HIPSPARSE library.
555*d52a580bSJunchao Zhang 
556*d52a580bSJunchao Zhang    This matrix type is identical to `MATSEQAIJHIPSPARSE` when constructed with a single process communicator,
557*d52a580bSJunchao Zhang    and `MATMPIAIJHIPSPARSE` otherwise.  As a result, for single process communicators,
558*d52a580bSJunchao Zhang    `MatSeqAIJSetPreallocation()` is supported, and similarly `MatMPIAIJSetPreallocation()` is supported
559*d52a580bSJunchao Zhang    for communicators controlling multiple processes.  It is recommended that you call both of
560*d52a580bSJunchao Zhang    the above preallocation routines for simplicity.
561*d52a580bSJunchao Zhang 
562*d52a580bSJunchao Zhang    Options Database Keys:
563*d52a580bSJunchao Zhang +  -mat_type mpiaijhipsparse - sets the matrix type to `MATMPIAIJHIPSPARSE`
564*d52a580bSJunchao Zhang .  -mat_hipsparse_storage_format csr - sets the storage format of diagonal and off-diagonal matrices. Other options include ell (ellpack) or hyb (hybrid).
565*d52a580bSJunchao Zhang .  -mat_hipsparse_mult_diag_storage_format csr - sets the storage format of diagonal matrix. Other options include ell (ellpack) or hyb (hybrid).
566*d52a580bSJunchao Zhang -  -mat_hipsparse_mult_offdiag_storage_format csr - sets the storage format of off-diagonal matrix. Other options include ell (ellpack) or hyb (hybrid).
567*d52a580bSJunchao Zhang 
568*d52a580bSJunchao Zhang   Level: beginner
569*d52a580bSJunchao Zhang 
570*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreateAIJHIPSPARSE()`, `MATSEQAIJHIPSPARSE`, `MATMPIAIJHIPSPARSE`, `MatCreateSeqAIJHIPSPARSE()`, `MatHIPSPARSESetFormat()`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
571*d52a580bSJunchao Zhang M*/
572*d52a580bSJunchao Zhang 
573*d52a580bSJunchao Zhang /*MC
   MATMPIAIJHIPSPARSE - A matrix type to be used for sparse matrices; it is the same as `MATAIJHIPSPARSE`.
575*d52a580bSJunchao Zhang 
576*d52a580bSJunchao Zhang   Level: beginner
577*d52a580bSJunchao Zhang 
578*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MATAIJHIPSPARSE`, `MATSEQAIJHIPSPARSE`
579*d52a580bSJunchao Zhang M*/
580