xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 4e8208cbcbc709572b8abe32f33c78b69c819375)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscmat_kokkos.hpp>
4 #include <petscpkg_version.h>
5 #include <petsc/private/sfimpl.h>
6 #include <petsc/private/kokkosimpl.hpp>
7 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
8 #include <../src/mat/impls/aij/mpi/mpiaij.h>
9 #include <KokkosSparse_spadd.hpp>
10 #include <KokkosSparse_spgemm.hpp>
11 
MatAssemblyEnd_MPIAIJKokkos(Mat A,MatAssemblyType mode)12 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
13 {
14   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
15 
16   PetscFunctionBegin;
17   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
18   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
19      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
20    */
21   if (mode == MAT_FINAL_ASSEMBLY) {
22     PetscScalarKokkosView v;
23 
24     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
25     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
26     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
27     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
28     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
29   }
30   PetscFunctionReturn(PETSC_SUCCESS);
31 }
32 
MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[])33 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
34 {
35   Mat_MPIAIJ *mpiaij;
36 
37   PetscFunctionBegin;
38   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
39   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
40   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
41   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
42   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
43   PetscFunctionReturn(PETSC_SUCCESS);
44 }
45 
MatMult_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)46 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
47 {
48   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
49   PetscInt    nt;
50 
51   PetscFunctionBegin;
52   PetscCall(VecGetLocalSize(xx, &nt));
53   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
54   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
55   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
56   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
57   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
58   PetscFunctionReturn(PETSC_SUCCESS);
59 }
60 
MatMultAdd_MPIAIJKokkos(Mat mat,Vec xx,Vec yy,Vec zz)61 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
62 {
63   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
64   PetscInt    nt;
65 
66   PetscFunctionBegin;
67   PetscCall(VecGetLocalSize(xx, &nt));
68   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
69   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
70   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
71   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
72   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
73   PetscFunctionReturn(PETSC_SUCCESS);
74 }
75 
MatMultTranspose_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)76 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
77 {
78   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
79   PetscInt    nt;
80 
81   PetscFunctionBegin;
82   PetscCall(VecGetLocalSize(xx, &nt));
83   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
84   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
85   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
86   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
87   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
88   PetscFunctionReturn(PETSC_SUCCESS);
89 }
90 
91 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
92    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
93    C still uses local column ids. Their corresponding global column ids are returned in glob.
94 */
MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat,MatReuse reuse,IS * glob,Mat * C)95 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
96 {
97   Mat             Ad, Ao;
98   const PetscInt *cmap;
99 
100   PetscFunctionBegin;
101   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
102   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
103   if (glob) {
104     PetscInt cst, i, dn, on, *gidx;
105     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
106     PetscCall(MatGetLocalSize(Ao, NULL, &on));
107     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
108     PetscCall(PetscMalloc1(dn + on, &gidx));
109     for (i = 0; i < dn; i++) gidx[i] = cst + i;
110     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
111     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
112   }
113   PetscFunctionReturn(PETSC_SUCCESS);
114 }
115 
116 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
117 struct MatMatStruct {
118   PetscInt            n, *garray;     // C's garray and its size.
119   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
120   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
121   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
122   PetscIntKokkosView  E_NzLeft;
123   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
124   MatScalarKokkosView rootBuf, leafBuf;
125   KokkosCsrMatrix     Fd, Fo; // F in split form
126 
127   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
128   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
129   KernelHandle kh3; // compute C3
130   KernelHandle kh4; // compute C4
131 
132   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
133   PetscInt E_VectorLength;
134   PetscInt E_RowsPerTeam;
135   PetscInt F_TeamSize;
136   PetscInt F_VectorLength;
137   PetscInt F_RowsPerTeam;
138 
~MatMatStructMatMatStruct139   ~MatMatStruct()
140   {
141     PetscFunctionBegin;
142     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
143     PetscFunctionReturnVoid();
144   }
145 };
146 
147 struct MatMatStruct_AB : public MatMatStruct {
148   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
149   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
150   PetscIntKokkosView rowoffset;
151 };
152 
153 struct MatMatStruct_AtB : public MatMatStruct {
154   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
155   MatColIdxKokkosView Fdjperm;
156   MatColIdxKokkosView Fojmap;
157   MatColIdxKokkosView Fojperm;
158 };
159 
160 struct MatProductCtx_MPIAIJKokkos {
161   MatMatStruct_AB  *mmAB     = nullptr;
162   MatMatStruct_AtB *mmAtB    = nullptr;
163   PetscBool         reusesym = PETSC_FALSE;
164   Mat               Z        = nullptr; // store Z=AB in computing BtAB
165 
~MatProductCtx_MPIAIJKokkosMatProductCtx_MPIAIJKokkos166   ~MatProductCtx_MPIAIJKokkos()
167   {
168     delete mmAB;
169     delete mmAtB;
170     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
171   }
172 };
173 
MatProductCtxDestroy_MPIAIJKokkos(PetscCtxRt data)174 static PetscErrorCode MatProductCtxDestroy_MPIAIJKokkos(PetscCtxRt data)
175 {
176   PetscFunctionBegin;
177   PetscCallCXX(delete *reinterpret_cast<MatProductCtx_MPIAIJKokkos **>(data));
178   PetscFunctionReturn(PETSC_SUCCESS);
179 }
180 
181 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
182 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
183 template <class ExecutionSpace>
MatMergeGetLaunchParameters(PetscInt numRows,PetscInt nnz,PetscInt rows_per_thread,PetscInt & team_size,PetscInt & vector_length,PetscInt & rows_per_team)184 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
185 {
186 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1)
187   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>();
188 #else
189   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>;
190 #endif
191   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
192 
193   PetscFunctionBegin;
194   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
195 
196   if (nnz_per_row < 1) nnz_per_row = 1;
197 
198   int max_vector_length = teamPolicy.vector_length_max();
199 
200   if (vector_length < 1) {
201     vector_length = 1;
202     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
203   }
204 
205   // Determine rows per thread
206   if (rows_per_thread < 1) {
207     if (is_gpu_exec_space) rows_per_thread = 1;
208     else {
209       if (nnz_per_row < 20 && nnz > 5000000) {
210         rows_per_thread = 256;
211       } else rows_per_thread = 64;
212     }
213   }
214 
215   if (team_size < 1) {
216     if (is_gpu_exec_space) {
217       team_size = 256 / vector_length;
218     } else {
219       team_size = 1;
220     }
221   }
222 
223   rows_per_team = rows_per_thread * team_size;
224 
225   if (rows_per_team < 0) {
226     PetscInt nnz_per_team = 4096;
227     PetscInt conc         = ExecutionSpace().concurrency();
228     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
229     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
230   }
231   PetscFunctionReturn(PETSC_SUCCESS);
232 }
233 
234 /*
235   Reduce two sets of global indices into local ones
236 
237   Input Parameters:
238 +  n1          - size of garray1[], the first set
239 .  garray1[n1] - a sorted global index array (without duplicates)
240 .  m           - size of indices[], the second set
241 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
242 
243   Output Parameters:
244 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
245 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
246 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
247 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
248 
249    Example, say
250     n1         = 5
251     garray1[5] = {1, 4, 7, 8, 10}
252     m          = 4
253     indices[4] = {2, 4, 8, 9}
254 
255    Combining them together, we have 7 global indices in garray2[]
256     n2         = 7
257     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
258 
259    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
260     map[5] = {0, 2, 3, 4, 6}
261 
262    On output, indices[] is updated with local indices
263     indices[4] = {1, 2, 4, 5}
264 */
ReduceTwoSetsOfGlobalIndices(PetscInt n1,const PetscInt * garray1,PetscInt m,PetscInt * indices,PetscInt * n2_,PetscInt ** garray2_,PetscInt * map)265 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
266 {
267   PetscHMapI    g2l = nullptr;
268   PetscHashIter iter;
269   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
270   PetscInt      n2, *garray2;
271 
272   PetscFunctionBegin;
273   tot = 0;
274   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
275   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
276     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
277     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
278   }
279 
280   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
281     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
282     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
283   }
284 
285   // Pull out (unique) globals in the hash table and put them in garray2[]
286   n2 = tot;
287   PetscCall(PetscMalloc1(n2, &garray2));
288   tot = 0;
289   PetscHashIterBegin(g2l, iter);
290   while (!PetscHashIterAtEnd(g2l, iter)) {
291     PetscHashIterGetKey(g2l, iter, key);
292     PetscHashIterNext(g2l, iter);
293     garray2[tot++] = key;
294   }
295 
296   // Sort garray2[] and then map them to local indices starting from 0
297   PetscCall(PetscSortInt(n2, garray2));
298   PetscCall(PetscHMapIClear(g2l));
299   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
300 
301   // Rewrite indices[] with local indices
302   for (PetscInt i = 0; i < m; i++) {
303     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
304     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
305     indices[i] = val;
306   }
307   // Record the map that maps garray1[i] to garray2[map[i]]
308   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
309   PetscCall(PetscHMapIDestroy(&g2l));
310   *n2_      = n2;
311   *garray2_ = garray2;
312   PetscFunctionReturn(PETSC_SUCCESS);
313 }
314 
315 /*
316   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
317 
318   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
319 
320   Think of each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
321   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
322 
323   Input Parameters:
324 +  comm       - MPI communicator of E
325 .  A          - diag block of E, using local column indices
326 .  B          - off-diag block of E, using local column indices
327 .  cstart      - (global) start column of Ed
328 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
329 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
330 .  ownerSF     - the SF specifies ownership (root) of rows in E
331 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
332 -  mm          - to stash intermediate data structures for reuse
333 
334   Output Parameters:
335 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
336 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
337 
338   Notes:
339   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
340 
341  */
MatMPIAIJKokkosReduceBegin(MPI_Comm comm,KokkosCsrMatrix A,KokkosCsrMatrix B,PetscInt cstart,PetscInt cend,const PetscInt * garray1,PetscSF ownerSF,MatReuse reuse,PetscInt * map,MatMatStruct_AtB * mm)342 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
343 {
344   PetscFunctionBegin;
345   if (reuse == MAT_INITIAL_MATRIX) {
346     PetscInt Em = A.numRows(), Fm;
347     PetscInt n1 = B.numCols();
348 
349     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
350 
351     // Do the analysis on host
352     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.row_map);
353     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.entries);
354     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.row_map);
355     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.entries);
356     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
357     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
358 
359     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
360     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
361     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
362     for (PetscInt i = 0; i < Em; i++) {
363       const PetscInt *first, *last, *it;
364       PetscInt        count, step;
365       // std::lower_bound(first,last,cstart), but need to use global column indices
366       first = Bj + Bi[i];
367       last  = Bj + Bi[i + 1];
368       count = last - first;
369       while (count > 0) {
370         it   = first;
371         step = count / 2;
372         it += step;
373         if (garray1[*it] < cstart) { // map local to global
374           first = ++it;
375           count -= step + 1;
376         } else count = step;
377       }
378       E_NzLeft[i] = first - (Bj + Bi[i]);
379       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
380     }
381 
382     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
383     const PetscMPIInt *iranks, *ranks;
384     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
385     PetscMPIInt        niranks, nranks;
386     MPI_Request       *reqs;
387     PetscMPIInt        tag;
388     PetscSF            reduceSF;
389     PetscInt          *sdisp, *rdisp;
390 
391     PetscCall(PetscCommGetNewTag(comm, &tag));
392     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
393     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
394 
395     // Find out length of each row I will receive. Even for the same row index, when they are from
396     // different senders, they might have different lengths (and sparsity patterns)
397     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
398     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
399 
400     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
401 
402     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
403     recvRowLen[0] = 0; // since we will make it in CSR format later
404     recvRowLen++;      // advance the pointer now
405     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
406     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]));
407     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
408 
409     // Build the real PetscSF for reducing E rows (buffer to buffer)
410     rdisp[0] = 0;
411     for (PetscInt i = 0; i < niranks; i++) {
412       rdisp[i + 1] = rdisp[i];
413       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) rdisp[i + 1] += recvRowLen[j];
414     }
415     recvRowLen--; // put it back into csr format
416     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
417 
418     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
419     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
420     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
421 
422     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
423     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
424     PetscSFNode *iremote;
425 
426     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
427     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
428     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
429 
430     for (PetscInt i = 0; i < nranks; i++) {
431       PetscInt count = 0;
432       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
433       for (PetscInt j = 0; j < count; j++) {
434         iremote[nleaves + j].rank  = ranks[i];
435         iremote[nleaves + j].index = sdisp[i] + j;
436       }
437       nleaves += count;
438     }
439     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
440 
441     PetscCall(PetscSFCreate(comm, &reduceSF));
442     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
443 
444     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
445     PetscInt *sendCol, *recvCol;
446     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
447     for (PetscInt k = 0; k < roffset[nranks]; k++) {
448       PetscInt  i      = rmine[k]; // row to be copied
449       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
450       PetscInt  nzLeft = E_NzLeft[i];
451       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
452       for (PetscInt j = 0; j < alen + blen; j++) {
453         if (j < nzLeft) {
454           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
455         } else if (j < nzLeft + alen) {
456           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
457         } else {
458           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
459         }
460       }
461     }
462     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
463     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
464 
465     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
466     PetscInt *recvRowPerm, *recvColSorted;
467     PetscInt *recvNzPerm, *recvNzPermSorted;
468     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
469 
470     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
471     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
472     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
473 
474     // i[] array, nz are always easiest to compute
475     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
476     MatRowMapType          *Fdi, *Foi;
477     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
478     PetscInt                iter;
479 
480     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
481     Kokkos::deep_copy(Foi_h, 0);
482     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
483     Foi  = Foi_h.data() + 1;
484     iter = 0;
485     while (iter < recvRowCnt) { // iter over received rows
486       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
487       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
488 
489       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
490 
491       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
492       PetscInt  nz    = 0; // nz (with dups) in the current row
493       PetscInt *jbuf  = recvColSorted + FnzDups;
494       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
495       PetscInt *jbuf2 = jbuf; // temp pointers
496       PetscInt *pbuf2 = pbuf;
497       for (PetscInt d = 0; d < dupRows; d++) {
498         PetscInt i   = recvRowPerm[iter + d];
499         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
500         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
501         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
502         jbuf2 += len;
503         pbuf2 += len;
504         nz += len;
505       }
506       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
507 
508       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
509       PetscInt cur = 0;
510       while (cur < nz) {
511         PetscInt curColIdx = jbuf[cur];
512         PetscInt dups      = 1;
513 
514         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
515         if (curColIdx >= cstart && curColIdx < cend) {
516           Fdi[curRowIdx]++;
517           FdnzDups += dups;
518         } else {
519           Foi[curRowIdx]++;
520           FonzDups += dups;
521         }
522         cur += dups;
523       }
524 
525       FnzDups += nz;
526       iter += dupRows; // Move to next unique row
527     }
528 
529     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
530     Foi = Foi_h.data();
531     for (PetscInt i = 0; i < Fm; i++) {
532       Fdi[i + 1] += Fdi[i];
533       Foi[i + 1] += Foi[i];
534     }
535     Fdnz = Fdi[Fm];
536     Fonz = Foi[Fm];
537     PetscCall(PetscFree2(sendCol, recvCol));
538 
539     // Allocate j, jmap, jperm for Fd and Fo
540     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
541     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
542     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
543     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
544     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
545     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
546 
547     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
548     Fdjmap[0] = 0;
549     Fojmap[0] = 0;
550     FnzDups   = 0;
551     Fdnz      = 0;
552     Fonz      = 0;
553     iter      = 0; // iter over received rows
554     while (iter < recvRowCnt) {
555       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
556       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
557       PetscInt nz        = 0;                           // nz (with dups) in the current row
558 
559       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
560       for (PetscInt d = 0; d < dupRows; d++) {
561         PetscInt i = recvRowPerm[iter + d];
562         nz += recvRowLen[i + 1] - recvRowLen[i];
563       }
564 
565       PetscInt *jbuf = recvColSorted + FnzDups;
566       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
567       PetscInt cur = 0;
568       while (cur < nz) {
569         PetscInt curColIdx = jbuf[cur];
570         PetscInt dups      = 1;
571 
572         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
573         if (curColIdx >= cstart && curColIdx < cend) {
574           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
575           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
576           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
577           FdnzDups += dups;
578           Fdnz++;
579         } else {
580           Foj[Fonz]        = curColIdx; // in global
581           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
582           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
583           FonzDups += dups;
584           Fonz++;
585         }
586         cur += dups;
587         FnzDups += dups;
588       }
589       iter += dupRows; // Move to next unique row
590     }
591     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
592     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
593 
594     // Combine global column indices in garray1[] and Foj[]
595     PetscInt n2, *garray2;
596 
597     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
598     mm->sf       = reduceSF;
599     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
600     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
601     mm->garray   = garray2; // give ownership, so no free
602     mm->n        = n2;
603     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
604     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
605     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
606     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
607     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
608 
609     // Output Fd and Fo in KokkosCsrMatrix format
610     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
611     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
612     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
613     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
614     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
615     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
616 
617     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
618     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
619 
620     // Compute kernel launch parameters in merging E
621     PetscInt teamSize, vectorLength, rowsPerTeam;
622 
623     teamSize = vectorLength = rowsPerTeam = -1;
624     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
625     mm->E_TeamSize     = teamSize;
626     mm->E_VectorLength = vectorLength;
627     mm->E_RowsPerTeam  = rowsPerTeam;
628   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
629 
630   // Handy aliases
631   auto       &Aa           = A.values;
632   auto       &Ba           = B.values;
633   const auto &Ai           = A.graph.row_map;
634   const auto &Bi           = B.graph.row_map;
635   const auto &E_NzLeft     = mm->E_NzLeft;
636   auto       &leafBuf      = mm->leafBuf;
637   auto       &rootBuf      = mm->rootBuf;
638   PetscSF     reduceSF     = mm->sf;
639   PetscInt    Em           = A.numRows();
640   PetscInt    teamSize     = mm->E_TeamSize;
641   PetscInt    vectorLength = mm->E_VectorLength;
642   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
643   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
644 
645   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
646   PetscCallCXX(Kokkos::parallel_for(
647     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
648       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
649         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
650         if (i < Em) {
651           PetscInt disp   = Ai(i) + Bi(i);
652           PetscInt alen   = Ai(i + 1) - Ai(i);
653           PetscInt blen   = Bi(i + 1) - Bi(i);
654           PetscInt nzleft = E_NzLeft(i);
655 
656           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
657             MatScalar &val = leafBuf(disp + j);
658             if (j < nzleft) { // B left
659               val = Ba(Bi(i) + j);
660             } else if (j < nzleft + alen) { // diag A
661               val = Aa(Ai(i) + j - nzleft);
662             } else { // B right
663               val = Ba(Bi(i) + j - alen);
664             }
665           });
666         }
667       });
668     }));
669   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
670   PetscFunctionReturn(PETSC_SUCCESS);
671 }
672 
673 // To finish MatMPIAIJKokkosReduce.
MatMPIAIJKokkosReduceEnd(MPI_Comm comm,KokkosCsrMatrix A,KokkosCsrMatrix B,PetscInt cstart,PetscInt cend,const PetscInt * garray1,PetscSF ownerSF,MatReuse reuse,PetscInt * map,MatMatStruct_AtB * mm)674 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
675 {
676   auto       &leafBuf  = mm->leafBuf;
677   auto       &rootBuf  = mm->rootBuf;
678   auto       &Fda      = mm->Fd.values;
679   const auto &Fdjmap   = mm->Fdjmap;
680   const auto &Fdjperm  = mm->Fdjperm;
681   auto        Fdnz     = mm->Fd.nnz();
682   auto       &Foa      = mm->Fo.values;
683   const auto &Fojmap   = mm->Fojmap;
684   const auto &Fojperm  = mm->Fojperm;
685   auto        Fonz     = mm->Fo.nnz();
686   PetscSF     reduceSF = mm->sf;
687 
688   PetscFunctionBegin;
689   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
690 
691   // Reduce data in rootBuf to Fd and Fo
692   PetscCallCXX(Kokkos::parallel_for(
693     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
694       PetscScalar sum = 0.0;
695       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
696       Fda(i) = sum;
697     }));
698 
699   PetscCallCXX(Kokkos::parallel_for(
700     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
701       PetscScalar sum = 0.0;
702       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
703       Foa(i) = sum;
704     }));
705   PetscFunctionReturn(PETSC_SUCCESS);
706 }
707 
708 /*
709   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
710 
711   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
712   device and involves various index mapping.
713 
714   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
715   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
716   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
717   F has the same column layout as E.
718 
719   Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
720   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
721   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
722   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
723   column indices in Fo and update Fo with local indices.
724 
725    Input Parameters:
726 +   E       - the MPIAIJKOKKOS matrix
727 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
728 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
729 -   mm      - to stash matproduct intermediate data structures
730 
731     Output Parameters:
732 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
733 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
734 
735     Notes:
736     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
737     The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities.
738 */
MatMPIAIJKokkosBcastBegin(Mat E,PetscSF ownerSF,MatReuse reuse,PetscInt * map,MatMatStruct_AB * mm)739 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
740 {
741   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
742   Mat               A = empi->A, B = empi->B; // diag and off-diag
743   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
744   PetscInt          Em = E->rmap->n; // #local rows
745   MPI_Comm          comm;
746 
747   PetscFunctionBegin;
748   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
749   if (reuse == MAT_INITIAL_MATRIX) {
750     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
751     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
752     const PetscInt *garray1 = empi->garray; // its size is n1
753     PetscInt        cstart, cend;
754     PetscSF         bcastSF;
755 
756     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
757 
758     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
759     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
760     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
761     for (PetscInt i = 0; i < Em; i++) {
762       const PetscInt *first, *last, *it;
763       PetscInt        count, step;
764       // std::lower_bound(first,last,cstart), but need to use global column indices
765       first = Bj + Bi[i];
766       last  = Bj + Bi[i + 1];
767       count = last - first;
768       while (count > 0) {
769         it   = first;
770         step = count / 2;
771         it += step;
772         if (empi->garray[*it] < cstart) { // map local to global
773           first = ++it;
774           count -= step + 1;
775         } else count = step;
776       }
777       E_NzLeft[i] = first - (Bj + Bi[i]);
778       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
779     }
780 
781     // Compute row pointer Fi of F
782     PetscInt *Fi, Fm, Fnz;
783     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
784     PetscCall(PetscMalloc1(Fm + 1, &Fi));
785     Fi[0] = 0;
786     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
787     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
788     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
789     Fnz = Fi[Fm];
790 
791     // Build the real PetscSF for bcasting E rows (buffer to buffer)
792     const PetscMPIInt *iranks, *ranks;
793     const PetscInt    *ioffset, *irootloc, *roffset;
794     PetscMPIInt        niranks, nranks;
795     PetscInt          *sdisp, *rdisp;
796     MPI_Request       *reqs;
797     PetscMPIInt        tag;
798 
799     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
800     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
801     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
802 
803     sdisp[0] = 0; // send displacement
804     for (PetscInt i = 0; i < niranks; i++) {
805       sdisp[i + 1] = sdisp[i];
806       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
807         PetscInt r = irootloc[j]; // row to be sent
808         sdisp[i + 1] += E_RowLen[r];
809       }
810     }
811 
812     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
813     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
814     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
815     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
816 
817     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
818     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
819     PetscSFNode *iremote;                  // give ownership to bcastSF
820     PetscCall(PetscMalloc1(nleaves, &iremote));
821     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
822       PetscInt k = 0;
823       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
824         iremote[j].rank  = ranks[i];
825         iremote[j].index = rdisp[i] + k; // their root location
826         k++;
827       }
828     }
829     PetscCall(PetscSFCreate(comm, &bcastSF));
830     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
831     PetscCall(PetscFree3(sdisp, rdisp, reqs));
832 
833     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
834     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
835     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
836     rowoffset[0]                     = 0;
837     for (PetscInt i = 0; i < ioffset[niranks]; i++) rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]];
838 
839     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
840     PetscInt *jbuf, *Fj;
841     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
842     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
843       PetscInt  i      = irootloc[k]; // row to be copied
844       PetscInt *buf    = &jbuf[rowoffset[k]];
845       PetscInt  nzLeft = E_NzLeft[i];
846       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
847       for (PetscInt j = 0; j < alen + blen; j++) {
848         if (j < nzLeft) {
849           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
850         } else if (j < nzLeft + alen) {
851           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
852         } else {
853           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
854         }
855       }
856     }
857     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
858     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
859 
860     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
861     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
862     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
863     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
864     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
865 
866     Fdi[0] = Foi[0] = 0;
867     for (PetscInt i = 0; i < Fm; i++) {
868       PetscInt *first, *last, *lb1, *lb2;
869       // cut the row into: Left, [cstart, cend), Right
870       first       = Fj + Fi[i];
871       last        = Fj + Fi[i + 1];
872       lb1         = std::lower_bound(first, last, cstart);
873       F_NzLeft[i] = lb1 - first;
874       lb2         = std::lower_bound(first, last, cend);
875       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
876       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
877     }
878     for (PetscInt i = 0; i < Fm; i++) {
879       Fdi[i + 1] += Fdi[i];
880       Foi[i + 1] += Foi[i];
881     }
882 
883     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
884     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
885     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
886     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
887 
888     for (PetscInt i = 0; i < Fm; i++) {
889       PetscInt nzLeft = F_NzLeft[i];
890       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
891       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
892         gid = Fj[Fi[i] + j];
893         if (j < nzLeft) { // left, in global
894           Foj[Foi[i] + j] = gid;
895         } else if (j < nzLeft + len) { // diag, in local
896           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
897         } else { // right, in global
898           Foj[Foi[i] + j - len] = gid;
899         }
900       }
901     }
902     PetscCall(PetscFree2(jbuf, Fj));
903     PetscCall(PetscFree(Fi));
904 
905     // Reduce global indices in Foj[] and garray1[] into local ones
906     PetscInt n2, *garray2;
907     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
908 
909     // Record the plans built above, for reuse
910     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
911     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
912     Kokkos::deep_copy(irootloc_h, tmp);
913     mm->sf        = bcastSF;
914     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
915     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
916     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
917     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
918     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
919     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
920     mm->garray    = garray2;
921     mm->n         = n2;
922 
923     // Output Fd and Fo in KokkosCsrMatrix format
924     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
925     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
926     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
927     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
928     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
929 
930     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
931     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
932 
933     // Compute kernel launch parameters in merging E or splitting F
934     PetscInt teamSize, vectorLength, rowsPerTeam;
935 
936     teamSize = vectorLength = rowsPerTeam = -1;
937     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
938     mm->E_TeamSize     = teamSize;
939     mm->E_VectorLength = vectorLength;
940     mm->E_RowsPerTeam  = rowsPerTeam;
941 
942     teamSize = vectorLength = rowsPerTeam = -1;
943     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
944     mm->F_TeamSize     = teamSize;
945     mm->F_VectorLength = vectorLength;
946     mm->F_RowsPerTeam  = rowsPerTeam;
947   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
948 
949   // Sync E's value to device
950   PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace()));
951   PetscCall(KokkosDualViewSyncDevice(bkok->a_dual, PetscGetKokkosExecutionSpace()));
952 
953   // Handy aliases
954   const auto &Aa = akok->a_dual.view_device();
955   const auto &Ba = bkok->a_dual.view_device();
956   const auto &Ai = akok->i_dual.view_device();
957   const auto &Bi = bkok->i_dual.view_device();
958 
959   // Fetch the plans
960   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
961   PetscSF             &bcastSF   = mm->sf;
962   MatScalarKokkosView &rootBuf   = mm->rootBuf;
963   MatScalarKokkosView &leafBuf   = mm->leafBuf;
964   PetscIntKokkosView  &irootloc  = mm->irootloc;
965   PetscIntKokkosView  &rowoffset = mm->rowoffset;
966 
967   PetscInt teamSize     = mm->E_TeamSize;
968   PetscInt vectorLength = mm->E_VectorLength;
969   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
970   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
971 
972   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
973   PetscCallCXX(Kokkos::parallel_for(
974     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
975       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
976         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
977         if (r < irootloc.extent(0)) {
978           PetscInt i      = irootloc(r); // row i of E
979           PetscInt disp   = rowoffset(r);
980           PetscInt alen   = Ai(i + 1) - Ai(i);
981           PetscInt blen   = Bi(i + 1) - Bi(i);
982           PetscInt nzleft = E_NzLeft(i);
983 
984           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
985             if (j < nzleft) { // B left
986               rootBuf(disp + j) = Ba(Bi(i) + j);
987             } else if (j < nzleft + alen) { // diag A
988               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
989             } else { // B right
990               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
991             }
992           });
993         }
994       });
995     }));
996   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
997   PetscFunctionReturn(PETSC_SUCCESS);
998 }
999 
1000 // To finish MatMPIAIJKokkosBcast.
MatMPIAIJKokkosBcastEnd(Mat E,PetscSF ownerSF,MatReuse reuse,PetscInt * map,MatMatStruct_AB * mm)1001 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1002 {
1003   PetscFunctionBegin;
1004   const auto &Fd  = mm->Fd;
1005   const auto &Fo  = mm->Fo;
1006   const auto &Fdi = Fd.graph.row_map;
1007   const auto &Foi = Fo.graph.row_map;
1008   auto       &Fda = Fd.values;
1009   auto       &Foa = Fo.values;
1010   auto        Fm  = Fd.numRows();
1011 
1012   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1013   PetscSF             &bcastSF      = mm->sf;
1014   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1015   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1016   PetscInt             teamSize     = mm->F_TeamSize;
1017   PetscInt             vectorLength = mm->F_VectorLength;
1018   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1019   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1020 
1021   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1022 
1023   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1024   PetscCallCXX(Kokkos::parallel_for(
1025     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1026       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1027         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1028         if (i < Fm) {
1029           PetscInt nzLeft = F_NzLeft(i);
1030           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1031           PetscInt blen   = Foi(i + 1) - Foi(i);
1032           PetscInt Fii    = Fdi(i) + Foi(i);
1033 
1034           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1035             PetscScalar val = leafBuf(Fii + j);
1036             if (j < nzLeft) { // left
1037               Foa(Foi(i) + j) = val;
1038             } else if (j < nzLeft + alen) { // diag
1039               Fda(Fdi(i) + j - nzLeft) = val;
1040             } else { // right
1041               Foa(Foi(i) + j - alen) = val;
1042             }
1043           });
1044         }
1045       });
1046     }));
1047   PetscFunctionReturn(PETSC_SUCCESS);
1048 }
1049 
MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product * product,Mat A,Mat B,MatMatStruct_AtB * mm)1050 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1051 {
1052   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1053   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1054   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1055   PetscInt        cstart, cend;
1056   MPI_Comm        comm;
1057 
1058   PetscFunctionBegin;
1059   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1060   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1061   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1062   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1063   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1064   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1065   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1066 
1067   // TODO: add command line options to select spgemm algorithms
1068   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1069 
1070   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1071 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1072   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1073   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1074   #endif
1075 #endif
1076 
1077   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1078   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1079   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1080   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1081 
1082   // Aot * (B's diag + B's off-diag)
1083   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1084   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1085   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1086   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1087   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1088   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1089 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1090 
1091   PetscCallCXX(sort_crs_matrix(mm->C3));
1092   PetscCallCXX(sort_crs_matrix(mm->C4));
1093 #endif
1094 
1095   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1096   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1097   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1098   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1099 
1100   // Adt * (B's diag + B's off-diag)
1101   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1102   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1103   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1104   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1105 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1106   PetscCallCXX(sort_crs_matrix(mm->C1));
1107   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1108 #endif
1109 
1110   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1111 
1112   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1113   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1114   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1115   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1116   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1117 
1118   // C = (C1+Fd, C2+Fo)
1119   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1120   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1121   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1122   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1123   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1124   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1125   PetscFunctionReturn(PETSC_SUCCESS);
1126 }
1127 
MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product * product,Mat A,Mat B,MatMatStruct_AtB * mm)1128 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1129 {
1130   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1131   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1132   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1133   MPI_Comm        comm;
1134 
1135   PetscFunctionBegin;
1136   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1137   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1138   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1139   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1140   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1141 
1142   // Aot * (B's diag + B's off-diag)
1143   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1144   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1145 
1146   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1147   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1148 
1149   // Adt * (B's diag + B's off-diag)
1150   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1151   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1152 
1153   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1154 
1155   // C = (C1+Fd, C2+Fo)
1156   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1157   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1158   PetscFunctionReturn(PETSC_SUCCESS);
1159 }
1160 
1161 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1162 
1163   Input Parameters:
1164 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1165 .  A        - an MPIAIJKOKKOS matrix
1166 .  B        - an MPIAIJKOKKOS matrix
1167 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1168 */
MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product * product,Mat A,Mat B,MatMatStruct_AB * mm)1169 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1170 {
1171   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1172   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1173   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1174 
1175   PetscFunctionBegin;
1176   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1177   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1178   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1179   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1180 
1181   // TODO: add command line options to select spgemm algorithms
1182   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1183 
1184   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1185 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1186   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1187   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1188   #endif
1189 #endif
1190 
1191   mm->kh1.create_spgemm_handle(spgemm_alg);
1192   mm->kh2.create_spgemm_handle(spgemm_alg);
1193   mm->kh3.create_spgemm_handle(spgemm_alg);
1194   mm->kh4.create_spgemm_handle(spgemm_alg);
1195 
1196   // Bcast B's rows to form F, and overlap the communication
1197   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1198   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1199 
1200   // A's diag * (B's diag + B's off-diag)
1201   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1202   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1203   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1204   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1205   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1206   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1207 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1208   PetscCallCXX(sort_crs_matrix(mm->C1));
1209   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1210 #endif
1211 
1212   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1213 
1214   // A's off-diag * (F's diag + F's off-diag)
1215   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1216   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1217   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1218   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1219 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1220   PetscCallCXX(sort_crs_matrix(mm->C3));
1221   PetscCallCXX(sort_crs_matrix(mm->C4));
1222 #endif
1223 
1224   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1225   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1226   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1227   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1228   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1229 
1230   // C = (Cd, Co) = (C1+C3, C2+C4)
1231   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1232   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1233   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1234   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1235   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1236   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1237   PetscFunctionReturn(PETSC_SUCCESS);
1238 }
1239 
MatProductNumeric_MPIAIJKokkos_AB(Mat_Product * product,Mat A,Mat B,MatMatStruct_AB * mm)1240 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1241 {
1242   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1243   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1244   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1245 
1246   PetscFunctionBegin;
1247   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1248   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1249   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1250   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1251 
1252   // Bcast B's rows to form F, and overlap the communication
1253   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1254 
1255   // A's diag * (B's diag + B's off-diag)
1256   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1257   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1258 
1259   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1260 
1261   // A's off-diag * (F's diag + F's off-diag)
1262   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1263   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1264 
1265   // C = (Cd, Co) = (C1+C3, C2+C4)
1266   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1267   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1268   PetscFunctionReturn(PETSC_SUCCESS);
1269 }
1270 
MatProductNumeric_MPIAIJKokkos(Mat C)1271 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1272 {
1273   Mat_MPIAIJ                 *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1274   Mat_Product                *product;
1275   MatProductCtx_MPIAIJKokkos *pdata;
1276   MatProductType              ptype;
1277   Mat                         A, B;
1278 
1279   PetscFunctionBegin;
1280   MatCheckProduct(C, 1); // make sure C is a product
1281   product = C->product;
1282   pdata   = static_cast<MatProductCtx_MPIAIJKokkos *>(product->data);
1283   ptype   = product->type;
1284   A       = product->A;
1285   B       = product->B;
1286 
1287   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1288   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1289   // we still do numeric.
1290   if (pdata->reusesym) { // numeric reuses results from symbolic
1291     pdata->reusesym = PETSC_FALSE;
1292     PetscFunctionReturn(PETSC_SUCCESS);
1293   }
1294 
1295   if (ptype == MATPRODUCT_AB) {
1296     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1297   } else if (ptype == MATPRODUCT_AtB) {
1298     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1299   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1300     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1301     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1302   }
1303   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1304   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1305   PetscFunctionReturn(PETSC_SUCCESS);
1306 }
1307 
MatProductSymbolic_MPIAIJKokkos(Mat C)1308 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1309 {
1310   Mat                         A, B;
1311   Mat_Product                *product;
1312   MatProductType              ptype;
1313   MatProductCtx_MPIAIJKokkos *pdata;
1314   MatMatStruct               *mm = NULL;
1315   PetscInt                    m, n, M, N;
1316   Mat                         Cd, Co;
1317   MPI_Comm                    comm;
1318   Mat_MPIAIJ                 *mpiaij;
1319 
1320   PetscFunctionBegin;
1321   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1322   MatCheckProduct(C, 1);
1323   product = C->product;
1324   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1325   ptype = product->type;
1326   A     = product->A;
1327   B     = product->B;
1328 
1329   switch (ptype) {
1330   case MATPRODUCT_AB:
1331     m = A->rmap->n;
1332     n = B->cmap->n;
1333     M = A->rmap->N;
1334     N = B->cmap->N;
1335     break;
1336   case MATPRODUCT_AtB:
1337     m = A->cmap->n;
1338     n = B->cmap->n;
1339     M = A->cmap->N;
1340     N = B->cmap->N;
1341     break;
1342   case MATPRODUCT_PtAP:
1343     m = B->cmap->n;
1344     n = B->cmap->n;
1345     M = B->cmap->N;
1346     N = B->cmap->N;
1347     break; /* BtAB */
1348   default:
1349     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1350   }
1351 
1352   PetscCall(MatSetSizes(C, m, n, M, N));
1353   PetscCall(PetscLayoutSetUp(C->rmap));
1354   PetscCall(PetscLayoutSetUp(C->cmap));
1355   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1356 
1357   pdata           = new MatProductCtx_MPIAIJKokkos();
1358   pdata->reusesym = product->api_user;
1359 
1360   if (ptype == MATPRODUCT_AB) {
1361     auto mmAB = new MatMatStruct_AB();
1362     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1363     mm = pdata->mmAB = mmAB;
1364   } else if (ptype == MATPRODUCT_AtB) {
1365     auto mmAtB = new MatMatStruct_AtB();
1366     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1367     mm = pdata->mmAtB = mmAtB;
1368   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1369     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1370 
1371     auto mmAB = new MatMatStruct_AB();
1372     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1373     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1374     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1375     pdata->mmAB = mmAB;
1376 
1377     m = A->rmap->n; // Z's layout
1378     n = B->cmap->n;
1379     M = A->rmap->N;
1380     N = B->cmap->N;
1381     PetscCall(MatCreateMPIAIJWithSeqAIJ(comm, M, N, Zd, Zo, mmAB->garray, &Z));
1382 
1383     auto mmAtB = new MatMatStruct_AtB();
1384     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1385 
1386     pdata->Z = Z; // give ownership to pdata
1387     mm = pdata->mmAtB = mmAtB;
1388   }
1389 
1390   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1391   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1392 
1393   mpiaij         = (Mat_MPIAIJ *)C->data;
1394   mpiaij->A      = Cd;
1395   mpiaij->B      = Co;
1396   mpiaij->garray = mm->garray;
1397 
1398   C->preallocated     = PETSC_TRUE;
1399   C->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
1400 
1401   PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
1402   PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
1403   PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
1404   PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
1405   PetscCall(MatSetOption(C, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
1406 
1407   /* set block sizes */
1408   switch (ptype) {
1409   case MATPRODUCT_PtAP:
1410     if (B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->cmap->bs, B->cmap->bs));
1411     break;
1412   case MATPRODUCT_RARt:
1413     if (B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->rmap->bs, B->rmap->bs));
1414     break;
1415   case MATPRODUCT_ABC:
1416     PetscCall(MatSetBlockSizesFromMats(C, A, product->C));
1417     break;
1418   case MATPRODUCT_AB:
1419     PetscCall(MatSetBlockSizesFromMats(C, A, B));
1420     break;
1421   case MATPRODUCT_AtB:
1422     if (A->cmap->bs > 1 || B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->cmap->bs, B->cmap->bs));
1423     break;
1424   case MATPRODUCT_ABt:
1425     if (A->rmap->bs > 1 || B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->rmap->bs, B->rmap->bs));
1426     break;
1427   default:
1428     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for ProductType %s", MatProductTypes[ptype]);
1429   }
1430   C->product->data       = pdata;
1431   C->product->destroy    = MatProductCtxDestroy_MPIAIJKokkos;
1432   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1433   PetscFunctionReturn(PETSC_SUCCESS);
1434 }
1435 
MatProductSetFromOptions_MPIAIJKokkos(Mat mat)1436 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1437 {
1438   Mat_Product *product = mat->product;
1439   PetscBool    match   = PETSC_FALSE;
1440   PetscBool    usecpu  = PETSC_FALSE;
1441 
1442   PetscFunctionBegin;
1443   MatCheckProduct(mat, 1);
1444   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1445   if (match) { /* we can always fallback to the CPU if requested */
1446     switch (product->type) {
1447     case MATPRODUCT_AB:
1448       if (product->api_user) {
1449         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1450         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1451         PetscOptionsEnd();
1452       } else {
1453         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1454         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1455         PetscOptionsEnd();
1456       }
1457       break;
1458     case MATPRODUCT_AtB:
1459       if (product->api_user) {
1460         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1461         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1462         PetscOptionsEnd();
1463       } else {
1464         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1465         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1466         PetscOptionsEnd();
1467       }
1468       break;
1469     case MATPRODUCT_PtAP:
1470       if (product->api_user) {
1471         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1472         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1473         PetscOptionsEnd();
1474       } else {
1475         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1476         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1477         PetscOptionsEnd();
1478       }
1479       break;
1480     default:
1481       break;
1482     }
1483     match = (PetscBool)!usecpu;
1484   }
1485   if (match) {
1486     switch (product->type) {
1487     case MATPRODUCT_AB:
1488     case MATPRODUCT_AtB:
1489     case MATPRODUCT_PtAP:
1490       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1491       break;
1492     default:
1493       break;
1494     }
1495   }
1496   /* fallback to MPIAIJ ops */
1497   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1498   PetscFunctionReturn(PETSC_SUCCESS);
1499 }
1500 
1501 // Mirror of MatCOOStruct_MPIAIJ on device
1502 struct MatCOOStruct_MPIAIJKokkos {
1503   PetscCount           n;
1504   PetscSF              sf;
1505   PetscCount           Annz, Bnnz;
1506   PetscCount           Annz2, Bnnz2;
1507   PetscCountKokkosView Ajmap1, Aperm1;
1508   PetscCountKokkosView Bjmap1, Bperm1;
1509   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1510   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1511   PetscCountKokkosView Cperm1;
1512   MatScalarKokkosView  sendbuf, recvbuf;
1513 
MatCOOStruct_MPIAIJKokkosMatCOOStruct_MPIAIJKokkos1514   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1515   {
1516     auto exec = PetscGetKokkosExecutionSpace();
1517 
1518     n       = coo_h->n;
1519     sf      = coo_h->sf;
1520     Annz    = coo_h->Annz;
1521     Bnnz    = coo_h->Bnnz;
1522     Annz2   = coo_h->Annz2;
1523     Bnnz2   = coo_h->Bnnz2;
1524     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1525     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1526     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1527     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1528     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1529     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1530     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1531     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1532     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1533     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1534     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1535     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1536     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1537     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1538   }
1539 
~MatCOOStruct_MPIAIJKokkosMatCOOStruct_MPIAIJKokkos1540   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1541 };
1542 
MatCOOStructDestroy_MPIAIJKokkos(PetscCtxRt data)1543 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(PetscCtxRt data)
1544 {
1545   PetscFunctionBegin;
1546   PetscCallCXX(delete *static_cast<MatCOOStruct_MPIAIJKokkos **>(data));
1547   PetscFunctionReturn(PETSC_SUCCESS);
1548 }
1549 
MatSetPreallocationCOO_MPIAIJKokkos(Mat mat,PetscCount coo_n,PetscInt coo_i[],PetscInt coo_j[])1550 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1551 {
1552   PetscContainer             container_h, container_d;
1553   MatCOOStruct_MPIAIJ       *coo_h;
1554   MatCOOStruct_MPIAIJKokkos *coo_d;
1555 
1556   PetscFunctionBegin;
1557   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1558   mat->preallocated = PETSC_TRUE;
1559   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1560   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1561   PetscCall(MatZeroEntries(mat));
1562 
1563   // Copy the COO struct to device
1564   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1565   PetscCall(PetscContainerGetPointer(container_h, &coo_h));
1566   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1567 
1568   // Put the COO struct in a container and then attach that to the matrix
1569   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1570   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1571   PetscCall(PetscContainerSetCtxDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1572   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1573   PetscCall(PetscContainerDestroy(&container_d));
1574   PetscFunctionReturn(PETSC_SUCCESS);
1575 }
1576 
MatSetValuesCOO_MPIAIJKokkos(Mat mat,const PetscScalar v[],InsertMode imode)1577 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1578 {
1579   Mat_MPIAIJ                   *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1580   Mat                           A = mpiaij->A, B = mpiaij->B;
1581   MatScalarKokkosView           Aa, Ba;
1582   MatScalarKokkosView           v1;
1583   PetscMemType                  memtype;
1584   PetscContainer                container;
1585   MatCOOStruct_MPIAIJKokkos    *coo;
1586   Kokkos::DefaultExecutionSpace exec = PetscGetKokkosExecutionSpace();
1587 
1588   PetscFunctionBegin;
1589   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1590   PetscCall(PetscContainerGetPointer(container, &coo));
1591 
1592   const auto &n      = coo->n;
1593   const auto &Annz   = coo->Annz;
1594   const auto &Annz2  = coo->Annz2;
1595   const auto &Bnnz   = coo->Bnnz;
1596   const auto &Bnnz2  = coo->Bnnz2;
1597   const auto &vsend  = coo->sendbuf;
1598   const auto &v2     = coo->recvbuf;
1599   const auto &Ajmap1 = coo->Ajmap1;
1600   const auto &Ajmap2 = coo->Ajmap2;
1601   const auto &Aimap2 = coo->Aimap2;
1602   const auto &Bjmap1 = coo->Bjmap1;
1603   const auto &Bjmap2 = coo->Bjmap2;
1604   const auto &Bimap2 = coo->Bimap2;
1605   const auto &Aperm1 = coo->Aperm1;
1606   const auto &Aperm2 = coo->Aperm2;
1607   const auto &Bperm1 = coo->Bperm1;
1608   const auto &Bperm2 = coo->Bperm2;
1609   const auto &Cperm1 = coo->Cperm1;
1610 
1611   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1612   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1613     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1614   } else {
1615     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1616   }
1617 
1618   if (imode == INSERT_VALUES) {
1619     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1620     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1621   } else {
1622     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1623     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1624   }
1625 
1626   PetscCall(PetscLogGpuTimeBegin());
1627   /* Pack entries to be sent to remote */
1628   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1629 
1630   /* Send remote entries to their owner and overlap the communication with local computation */
1631   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1632   /* Add local entries to A and B in one kernel */
1633   Kokkos::parallel_for(
1634     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1635       PetscScalar sum = 0.0;
1636       if (i < Annz) {
1637         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1638         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1639       } else {
1640         i -= Annz;
1641         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1642         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1643       }
1644     });
1645   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1646 
1647   /* Add received remote entries to A and B in one kernel */
1648   Kokkos::parallel_for(
1649     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1650       if (i < Annz2) {
1651         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1652       } else {
1653         i -= Annz2;
1654         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1655       }
1656     });
1657   PetscCall(PetscLogGpuTimeEnd());
1658 
1659   if (imode == INSERT_VALUES) {
1660     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1661     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1662   } else {
1663     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1664     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1665   }
1666   PetscFunctionReturn(PETSC_SUCCESS);
1667 }
1668 
MatDestroy_MPIAIJKokkos(Mat A)1669 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1670 {
1671   PetscFunctionBegin;
1672   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1673   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1674   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1675   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1676 #if defined(PETSC_HAVE_HYPRE)
1677   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijkokkos_hypre_C", NULL));
1678 #endif
1679   PetscCall(MatDestroy_MPIAIJ(A));
1680   PetscFunctionReturn(PETSC_SUCCESS);
1681 }
1682 
MatShift_MPIAIJKokkos(Mat A,PetscScalar a)1683 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1684 {
1685   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1686   PetscBool   congruent;
1687 
1688   PetscFunctionBegin;
1689   PetscCall(MatHasCongruentLayouts(A, &congruent));
1690   if (congruent) { // square matrix and the diagonals are solely in the diag block
1691     PetscCall(MatShift(mpiaij->A, a));
1692   } else { // too hard, use the general version
1693     PetscCall(MatShift_Basic(A, a));
1694   }
1695   PetscFunctionReturn(PETSC_SUCCESS);
1696 }
1697 
MatSetOps_MPIAIJKokkos(Mat B)1698 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1699 {
1700   PetscFunctionBegin;
1701   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1702   B->ops->mult                  = MatMult_MPIAIJKokkos;
1703   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1704   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1705   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1706   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1707   B->ops->shift                 = MatShift_MPIAIJKokkos;
1708   B->ops->getcurrentmemtype     = MatGetCurrentMemType_MPIAIJ;
1709 
1710   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1711   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1712   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1713   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1714 #if defined(PETSC_HAVE_HYPRE)
1715   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_mpiaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
1716 #endif
1717   PetscFunctionReturn(PETSC_SUCCESS);
1718 }
1719 
MatConvert_MPIAIJ_MPIAIJKokkos(Mat A,MatType mtype,MatReuse reuse,Mat * newmat)1720 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1721 {
1722   Mat         B;
1723   Mat_MPIAIJ *a;
1724 
1725   PetscFunctionBegin;
1726   if (reuse == MAT_INITIAL_MATRIX) {
1727     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1728   } else if (reuse == MAT_REUSE_MATRIX) {
1729     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1730   }
1731   B = *newmat;
1732 
1733   B->boundtocpu = PETSC_FALSE;
1734   PetscCall(PetscFree(B->defaultvectype));
1735   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1736   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1737 
1738   a = static_cast<Mat_MPIAIJ *>(A->data);
1739   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1740   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1741   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1742   PetscCall(MatSetOps_MPIAIJKokkos(B));
1743   PetscFunctionReturn(PETSC_SUCCESS);
1744 }
1745 
1746 /*MC
1747    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos.
1748 
1749    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1750 
1751    Options Database Key:
1752 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1753 
1754   Level: beginner
1755 
1756 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1757 M*/
MatCreate_MPIAIJKokkos(Mat A)1758 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1759 {
1760   PetscFunctionBegin;
1761   PetscCall(PetscKokkosInitializeCheck());
1762   PetscCall(MatCreate_MPIAIJ(A));
1763   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1764   PetscFunctionReturn(PETSC_SUCCESS);
1765 }
1766 
1767 /*@C
1768   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
1771 
1772   Collective
1773 
1774   Input Parameters:
1775 + comm  - MPI communicator, set to `PETSC_COMM_SELF`
1776 . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1777            This value should be the same as the local size used in creating the
1778            y vector for the matrix-vector product y = Ax.
1779 . n     - This value should be the same as the local size used in creating the
1780        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1781        calculated if N is given) For square matrices n is almost always `m`.
1782 . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1783 . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1784 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1785            (same value is used for all local rows)
1786 . d_nnz - array containing the number of nonzeros in the various rows of the
1787            DIAGONAL portion of the local submatrix (possibly different for each row)
1788            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1789            The size of this array is equal to the number of local rows, i.e `m`.
1790            For matrices you plan to factor you must leave room for the diagonal entry and
1791            put in the entry even if it is zero.
1792 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1793            submatrix (same value is used for all local rows).
1794 - o_nnz - array containing the number of nonzeros in the various rows of the
1795            OFF-DIAGONAL portion of the local submatrix (possibly different for
1796            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1797            structure. The size of this array is equal to the number
1798            of local rows, i.e `m`.
1799 
1800   Output Parameter:
1801 . A - the matrix
1802 
1803   Level: intermediate
1804 
1805   Notes:
1806   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1807   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1808   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1809 
  The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
1811   storage.  That is, the stored row and column indices can begin at
1812   either one (as in Fortran) or zero.
1813 
1814 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1815           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1816 @*/
MatCreateAIJKokkos(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt M,PetscInt N,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[],Mat * A)1817 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1818 {
1819   PetscMPIInt size;
1820 
1821   PetscFunctionBegin;
1822   PetscCall(MatCreate(comm, A));
1823   PetscCall(MatSetSizes(*A, m, n, M, N));
1824   PetscCallMPI(MPI_Comm_size(comm, &size));
1825   if (size > 1) {
1826     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1827     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1828   } else {
1829     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1830     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1831   }
1832   PetscFunctionReturn(PETSC_SUCCESS);
1833 }
1834