1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscmat_kokkos.hpp>
4 #include <petscpkg_version.h>
5 #include <petsc/private/sfimpl.h>
6 #include <petsc/private/kokkosimpl.hpp>
7 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
8 #include <../src/mat/impls/aij/mpi/mpiaij.h>
9 #include <KokkosSparse_spadd.hpp>
10 #include <KokkosSparse_spgemm.hpp>
11
static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
13 {
14 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
15
16 PetscFunctionBegin;
17 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
18 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
19 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
20 */
21 if (mode == MAT_FINAL_ASSEMBLY) {
22 PetscScalarKokkosView v;
23
24 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
25 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
26 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); // lvec is init'ed on host, without copying to device
27 PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
28 PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
29 }
30 PetscFunctionReturn(PETSC_SUCCESS);
31 }
32
static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
34 {
35 Mat_MPIAIJ *mpiaij;
36
37 PetscFunctionBegin;
// reuse MPIAIJ's preallocation, which sets A/B's blocksize among other things
39 PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
40 mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
41 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
42 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
43 PetscFunctionReturn(PETSC_SUCCESS);
44 }
45
static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
47 {
48 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
49 PetscInt nt;
50
51 PetscFunctionBegin;
52 PetscCall(VecGetLocalSize(xx, &nt));
53 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
54 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
55 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
56 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
57 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
58 PetscFunctionReturn(PETSC_SUCCESS);
59 }
60
static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
62 {
63 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
64 PetscInt nt;
65
66 PetscFunctionBegin;
67 PetscCall(VecGetLocalSize(xx, &nt));
68 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
69 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
70 PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
71 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
72 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
73 PetscFunctionReturn(PETSC_SUCCESS);
74 }
75
static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
77 {
78 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
79 PetscInt nt;
80
81 PetscFunctionBegin;
82 PetscCall(VecGetLocalSize(xx, &nt));
83 PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
84 PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
85 PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
86 PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
87 PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
88 PetscFunctionReturn(PETSC_SUCCESS);
89 }
90
91 /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
92 A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
93 C still uses local column ids. Their corresponding global column ids are returned in glob.
94 */
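/* A minimal usage sketch (hypothetical caller code, not part of this file) of the public interface backed by
   this routine, assuming mat is an assembled MATMPIAIJKOKKOS matrix:

     Mat C;    // the merged local matrix, of type MATSEQAIJKOKKOS
     IS  glob; // global column ids corresponding to C's local column ids
     PetscCall(MatMPIAIJGetLocalMatMerge(mat, MAT_INITIAL_MATRIX, &glob, &C));
     // ... use C with local column ids; translate them through glob when global ids are needed ...
     PetscCall(ISDestroy(&glob));
     PetscCall(MatDestroy(&C));
*/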
static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
96 {
97 Mat Ad, Ao;
98 const PetscInt *cmap;
99
100 PetscFunctionBegin;
101 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
102 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
103 if (glob) {
104 PetscInt cst, i, dn, on, *gidx;
105 PetscCall(MatGetLocalSize(Ad, NULL, &dn));
106 PetscCall(MatGetLocalSize(Ao, NULL, &on));
107 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
108 PetscCall(PetscMalloc1(dn + on, &gidx));
109 for (i = 0; i < dn; i++) gidx[i] = cst + i;
110 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
111 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
112 }
113 PetscFunctionReturn(PETSC_SUCCESS);
114 }
115
116 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
117 struct MatMatStruct {
118 PetscInt n, *garray; // C's garray and its size.
KokkosCsrMatrix Cd, Co; // C in split form (all with local column indices)
120 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products
KokkosCsrMatrix C2_mid, C4_mid; // aliases of C2, C4; they share a[], i[], but have different j[] (hence column size)
122 PetscIntKokkosView E_NzLeft;
123 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F
124 MatScalarKokkosView rootBuf, leafBuf;
125 KokkosCsrMatrix Fd, Fo; // F in split form
126
127 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
128 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
129 KernelHandle kh3; // compute C3
130 KernelHandle kh4; // compute C4
131
132 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
133 PetscInt E_VectorLength;
134 PetscInt E_RowsPerTeam;
135 PetscInt F_TeamSize;
136 PetscInt F_VectorLength;
137 PetscInt F_RowsPerTeam;
138
~MatMatStruct()
140 {
141 PetscFunctionBegin;
142 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
143 PetscFunctionReturnVoid();
144 }
145 };
146
147 struct MatMatStruct_AB : public MatMatStruct {
148 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
149 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
150 PetscIntKokkosView rowoffset;
151 };
152
153 struct MatMatStruct_AtB : public MatMatStruct {
154 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
155 MatColIdxKokkosView Fdjperm;
156 MatColIdxKokkosView Fojmap;
157 MatColIdxKokkosView Fojperm;
158 };
159
160 struct MatProductCtx_MPIAIJKokkos {
161 MatMatStruct_AB *mmAB = nullptr;
162 MatMatStruct_AtB *mmAtB = nullptr;
163 PetscBool reusesym = PETSC_FALSE;
164 Mat Z = nullptr; // store Z=AB in computing BtAB
165
~MatProductCtx_MPIAIJKokkos()
167 {
168 delete mmAB;
169 delete mmAtB;
170 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
171 }
172 };
173
static PetscErrorCode MatProductCtxDestroy_MPIAIJKokkos(PetscCtxRt data)
175 {
176 PetscFunctionBegin;
177 PetscCallCXX(delete *reinterpret_cast<MatProductCtx_MPIAIJKokkos **>(data));
178 PetscFunctionReturn(PETSC_SUCCESS);
179 }
180
// Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters for the Kokkos nested loops which we use to merge or
182 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
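// A worked example of the rule above (a sketch, tracing the logic below): on a GPU execution space with
// nnz_per_row = 32 and team_size = vector_length = rows_per_thread = -1 on input, the loop picks
// vector_length = 8 (the first power of two with 6*vector_length >= nnz_per_row), then rows_per_thread = 1,
// team_size = 256/8 = 32, and rows_per_team = 32, so that vector_length * team_size = 256 as intended.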
183 template <class ExecutionSpace>
static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
185 {
186 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1)
187 constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>();
188 #else
189 constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>;
190 #endif
191 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
192
193 PetscFunctionBegin;
PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might encounter empty matrices
195
196 if (nnz_per_row < 1) nnz_per_row = 1;
197
198 int max_vector_length = teamPolicy.vector_length_max();
199
200 if (vector_length < 1) {
201 vector_length = 1;
202 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
203 }
204
205 // Determine rows per thread
206 if (rows_per_thread < 1) {
207 if (is_gpu_exec_space) rows_per_thread = 1;
208 else {
209 if (nnz_per_row < 20 && nnz > 5000000) {
210 rows_per_thread = 256;
211 } else rows_per_thread = 64;
212 }
213 }
214
215 if (team_size < 1) {
216 if (is_gpu_exec_space) {
217 team_size = 256 / vector_length;
218 } else {
219 team_size = 1;
220 }
221 }
222
223 rows_per_team = rows_per_thread * team_size;
224
225 if (rows_per_team < 0) {
226 PetscInt nnz_per_team = 4096;
227 PetscInt conc = ExecutionSpace().concurrency();
228 while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
229 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
230 }
231 PetscFunctionReturn(PETSC_SUCCESS);
232 }
233
234 /*
235 Reduce two sets of global indices into local ones
236
237 Input Parameters:
238 + n1 - size of garray1[], the first set
239 . garray1[n1] - a sorted global index array (without duplicates)
240 . m - size of indices[], the second set
- indices[m] - an unsorted global index array (might have duplicates), which will be updated on output with local indices
242
243 Output Parameters:
244 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[]
245 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
246 . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]]
247 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
248
249 Example, say
250 n1 = 5
251 garray1[5] = {1, 4, 7, 8, 10}
252 m = 4
253 indices[4] = {2, 4, 8, 9}
254
255 Combining them together, we have 7 global indices in garray2[]
256 n2 = 7
257 garray2[7] = {1, 2, 4, 7, 8, 9, 10}
258
259 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
260 map[5] = {0, 2, 3, 4, 6}
261
262 On output, indices[] is updated with local indices
263 indices[4] = {1, 2, 4, 5}
264 */
static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
266 {
267 PetscHMapI g2l = nullptr;
268 PetscHashIter iter;
269 PetscInt tot, key, val; // total unique global indices. key is global id; val is local id
270 PetscInt n2, *garray2;
271
272 PetscFunctionBegin;
273 tot = 0;
274 PetscCall(PetscHMapICreateWithSize(n1, &g2l));
275 for (PetscInt i = 0; i < m; i++) { // insert those in indices[]
276 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
277 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet
278 }
279
280 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
281 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
282 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
283 }
284
285 // Pull out (unique) globals in the hash table and put them in garray2[]
286 n2 = tot;
287 PetscCall(PetscMalloc1(n2, &garray2));
288 tot = 0;
289 PetscHashIterBegin(g2l, iter);
290 while (!PetscHashIterAtEnd(g2l, iter)) {
291 PetscHashIterGetKey(g2l, iter, key);
292 PetscHashIterNext(g2l, iter);
293 garray2[tot++] = key;
294 }
295
296 // Sort garray2[] and then map them to local indices starting from 0
297 PetscCall(PetscSortInt(n2, garray2));
298 PetscCall(PetscHMapIClear(g2l));
299 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
300
301 // Rewrite indices[] with local indices
302 for (PetscInt i = 0; i < m; i++) {
303 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
304 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
305 indices[i] = val;
306 }
307 // Record the map that maps garray1[i] to garray2[map[i]]
308 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
309 PetscCall(PetscHMapIDestroy(&g2l));
310 *n2_ = n2;
311 *garray2_ = garray2;
312 PetscFunctionReturn(PETSC_SUCCESS);
313 }
314
315 /*
316 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
317
318 It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
319
Think of each row of E as a leaf; the given ownerSF then specifies a root for each leaf. Roots may connect to multiple leaves.
321 In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
322
323 Input Parameters:
324 + comm - MPI communicator of E
325 . A - diag block of E, using local column indices
326 . B - off-diag block of E, using local column indices
327 . cstart - (global) start column of Ed
328 . cend - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend)
329 . garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
330 . ownerSF - the SF specifies ownership (root) of rows in E
331 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
332 - mm - to stash intermediate data structures for reuse
333
334 Output Parameters:
335 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
336 - mm - contains various info, such as garray2[], F (Fd, Fo) etc.
337
338 Notes:
339 When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
340
341 */
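/* A small example (hypothetical, 2 ranks) of the reduction described above: suppose ownerSF says that row 3 of E
   on rank 0 and row 0 of E on rank 1 both have their root at row 5 of F (owned by rank 0). If those two rows have
   global column indices {1, 4} and {4, 7}, the sparse merge gives row 5 of F the columns {1, 4, 7}; the two
   contributions to column 4 are summed once the values arrive in rootBuf (via the jmap/jperm plans built below). */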
static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
343 {
344 PetscFunctionBegin;
345 if (reuse == MAT_INITIAL_MATRIX) {
346 PetscInt Em = A.numRows(), Fm;
347 PetscInt n1 = B.numCols();
348
349 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
350
351 // Do the analysis on host
352 auto Ai_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.row_map);
353 auto Aj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.entries);
354 auto Bi_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.row_map);
355 auto Bj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.entries);
356 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
357 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
358
359 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
360 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
361 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
362 for (PetscInt i = 0; i < Em; i++) {
363 const PetscInt *first, *last, *it;
364 PetscInt count, step;
365 // std::lower_bound(first,last,cstart), but need to use global column indices
366 first = Bj + Bi[i];
367 last = Bj + Bi[i + 1];
368 count = last - first;
369 while (count > 0) {
370 it = first;
371 step = count / 2;
372 it += step;
373 if (garray1[*it] < cstart) { // map local to global
374 first = ++it;
375 count -= step + 1;
376 } else count = step;
377 }
378 E_NzLeft[i] = first - (Bj + Bi[i]);
379 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
380 }
381
382 // Get length of rows (i.e., sizes of leaves) that contribute to my roots
383 const PetscMPIInt *iranks, *ranks;
384 const PetscInt *ioffset, *irootloc, *roffset, *rmine;
385 PetscMPIInt niranks, nranks;
386 MPI_Request *reqs;
387 PetscMPIInt tag;
388 PetscSF reduceSF;
389 PetscInt *sdisp, *rdisp;
390
391 PetscCall(PetscCommGetNewTag(comm, &tag));
392 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them)
393 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
394
395 // Find out length of each row I will receive. Even for the same row index, when they are from
396 // different senders, they might have different lengths (and sparsity patterns)
397 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
PetscInt *sendRowLen, *recvRowLen; // lengths of rows I need to send/recv per process
399
400 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
401
402 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
403 recvRowLen[0] = 0; // since we will make it in CSR format later
404 recvRowLen++; // advance the pointer now
405 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
406 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]));
407 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
408
409 // Build the real PetscSF for reducing E rows (buffer to buffer)
410 rdisp[0] = 0;
411 for (PetscInt i = 0; i < niranks; i++) {
412 rdisp[i + 1] = rdisp[i];
413 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) rdisp[i + 1] += recvRowLen[j];
414 }
415 recvRowLen--; // put it back into csr format
416 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
417
418 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
419 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
420 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
421
422 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send
423 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv
424 PetscSFNode *iremote;
425
426 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
427 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
PetscCall(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
429
430 for (PetscInt i = 0; i < nranks; i++) {
431 PetscInt count = 0;
432 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
433 for (PetscInt j = 0; j < count; j++) {
434 iremote[nleaves + j].rank = ranks[i];
435 iremote[nleaves + j].index = sdisp[i] + j;
436 }
437 nleaves += count;
438 }
439 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
440
441 PetscCall(PetscSFCreate(comm, &reduceSF));
442 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
443
444 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
445 PetscInt *sendCol, *recvCol;
446 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
447 for (PetscInt k = 0; k < roffset[nranks]; k++) {
448 PetscInt i = rmine[k]; // row to be copied
449 PetscInt *buf = &sendCol[Ai[i] + Bi[i]];
450 PetscInt nzLeft = E_NzLeft[i];
451 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
452 for (PetscInt j = 0; j < alen + blen; j++) {
453 if (j < nzLeft) {
454 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
455 } else if (j < nzLeft + alen) {
456 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
457 } else {
458 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
459 }
460 }
461 }
462 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
463 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
464
465 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
466 PetscInt *recvRowPerm, *recvColSorted;
467 PetscInt *recvNzPerm, *recvNzPermSorted;
468 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
469
470 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros
471 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
472 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
473
474 // i[] array, nz are always easiest to compute
475 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
476 MatRowMapType *Fdi, *Foi;
477 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
478 PetscInt iter;
479
480 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
481 Kokkos::deep_copy(Foi_h, 0);
482 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below
483 Foi = Foi_h.data() + 1;
484 iter = 0;
485 while (iter < recvRowCnt) { // iter over received rows
486 PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
487 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns)
488
489 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
490
491 // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
492 PetscInt nz = 0; // nz (with dups) in the current row
493 PetscInt *jbuf = recvColSorted + FnzDups;
494 PetscInt *pbuf = recvNzPermSorted + FnzDups;
495 PetscInt *jbuf2 = jbuf; // temp pointers
496 PetscInt *pbuf2 = pbuf;
497 for (PetscInt d = 0; d < dupRows; d++) {
498 PetscInt i = recvRowPerm[iter + d];
499 PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
500 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
501 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
502 jbuf2 += len;
503 pbuf2 += len;
504 nz += len;
505 }
506 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
507
508 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
509 PetscInt cur = 0;
510 while (cur < nz) {
511 PetscInt curColIdx = jbuf[cur];
512 PetscInt dups = 1;
513
514 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
515 if (curColIdx >= cstart && curColIdx < cend) {
516 Fdi[curRowIdx]++;
517 FdnzDups += dups;
518 } else {
519 Foi[curRowIdx]++;
520 FonzDups += dups;
521 }
522 cur += dups;
523 }
524
525 FnzDups += nz;
526 iter += dupRows; // Move to next unique row
527 }
528
529 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
530 Foi = Foi_h.data();
531 for (PetscInt i = 0; i < Fm; i++) {
532 Fdi[i + 1] += Fdi[i];
533 Foi[i + 1] += Foi[i];
534 }
535 Fdnz = Fdi[Fm];
536 Fonz = Foi[Fm];
537 PetscCall(PetscFree2(sendCol, recvCol));
538
539 // Allocate j, jmap, jperm for Fd and Fo
540 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
541 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
542 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
543 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
544 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
545 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
546
547 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
548 Fdjmap[0] = 0;
549 Fojmap[0] = 0;
550 FnzDups = 0;
551 Fdnz = 0;
552 Fonz = 0;
553 iter = 0; // iter over received rows
554 while (iter < recvRowCnt) {
555 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
556 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths)
557 PetscInt nz = 0; // nz (with dups) in the current row
558
559 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
560 for (PetscInt d = 0; d < dupRows; d++) {
561 PetscInt i = recvRowPerm[iter + d];
562 nz += recvRowLen[i + 1] - recvRowLen[i];
563 }
564
565 PetscInt *jbuf = recvColSorted + FnzDups;
566 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
567 PetscInt cur = 0;
568 while (cur < nz) {
569 PetscInt curColIdx = jbuf[cur];
570 PetscInt dups = 1;
571
572 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
573 if (curColIdx >= cstart && curColIdx < cend) {
574 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local
575 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
576 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
577 FdnzDups += dups;
578 Fdnz++;
579 } else {
580 Foj[Fonz] = curColIdx; // in global
581 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
582 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
583 FonzDups += dups;
584 Fonz++;
585 }
586 cur += dups;
587 FnzDups += dups;
588 }
589 iter += dupRows; // Move to next unique row
590 }
591 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
592 PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
593
594 // Combine global column indices in garray1[] and Foj[]
595 PetscInt n2, *garray2;
596
597 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
598 mm->sf = reduceSF;
599 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
600 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots);
601 mm->garray = garray2; // give ownership, so no free
602 mm->n = n2;
603 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
604 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
605 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
606 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
607 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
608
609 // Output Fd and Fo in KokkosCsrMatrix format
610 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
611 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
612 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
613 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
614 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
615 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
616
617 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
618 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
619
620 // Compute kernel launch parameters in merging E
621 PetscInt teamSize, vectorLength, rowsPerTeam;
622
623 teamSize = vectorLength = rowsPerTeam = -1;
624 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
625 mm->E_TeamSize = teamSize;
626 mm->E_VectorLength = vectorLength;
627 mm->E_RowsPerTeam = rowsPerTeam;
628 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
629
630 // Handy aliases
631 auto &Aa = A.values;
632 auto &Ba = B.values;
633 const auto &Ai = A.graph.row_map;
634 const auto &Bi = B.graph.row_map;
635 const auto &E_NzLeft = mm->E_NzLeft;
636 auto &leafBuf = mm->leafBuf;
637 auto &rootBuf = mm->rootBuf;
638 PetscSF reduceSF = mm->sf;
639 PetscInt Em = A.numRows();
640 PetscInt teamSize = mm->E_TeamSize;
641 PetscInt vectorLength = mm->E_VectorLength;
642 PetscInt rowsPerTeam = mm->E_RowsPerTeam;
643 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam;
644
645 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
646 PetscCallCXX(Kokkos::parallel_for(
647 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
648 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
649 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
650 if (i < Em) {
651 PetscInt disp = Ai(i) + Bi(i);
652 PetscInt alen = Ai(i + 1) - Ai(i);
653 PetscInt blen = Bi(i + 1) - Bi(i);
654 PetscInt nzleft = E_NzLeft(i);
655
656 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
657 MatScalar &val = leafBuf(disp + j);
658 if (j < nzleft) { // B left
659 val = Ba(Bi(i) + j);
660 } else if (j < nzleft + alen) { // diag A
661 val = Aa(Ai(i) + j - nzleft);
662 } else { // B right
663 val = Ba(Bi(i) + j - alen);
664 }
665 });
666 }
667 });
668 }));
669 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
670 PetscFunctionReturn(PETSC_SUCCESS);
671 }
672
673 // To finish MatMPIAIJKokkosReduce.
static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
675 {
676 auto &leafBuf = mm->leafBuf;
677 auto &rootBuf = mm->rootBuf;
678 auto &Fda = mm->Fd.values;
679 const auto &Fdjmap = mm->Fdjmap;
680 const auto &Fdjperm = mm->Fdjperm;
681 auto Fdnz = mm->Fd.nnz();
682 auto &Foa = mm->Fo.values;
683 const auto &Fojmap = mm->Fojmap;
684 const auto &Fojperm = mm->Fojperm;
685 auto Fonz = mm->Fo.nnz();
686 PetscSF reduceSF = mm->sf;
687
688 PetscFunctionBegin;
689 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
690
691 // Reduce data in rootBuf to Fd and Fo
692 PetscCallCXX(Kokkos::parallel_for(
693 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
694 PetscScalar sum = 0.0;
695 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
696 Fda(i) = sum;
697 }));
698
699 PetscCallCXX(Kokkos::parallel_for(
700 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
701 PetscScalar sum = 0.0;
702 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
703 Foa(i) = sum;
704 }));
705 PetscFunctionReturn(PETSC_SUCCESS);
706 }
707
708 /*
709 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
710
711 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
712 device and involves various index mapping.
713
714 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
715 Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
716 to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
717 F has the same column layout as E.
718
Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
720 Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
721 Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
722 column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
723 column indices in Fo and update Fo with local indices.
724
725 Input Parameters:
726 + E - the MPIAIJKOKKOS matrix
727 . ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
728 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
729 - mm - to stash matproduct intermediate data structures
730
731 Output Parameters:
732 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
733 - mm - contains various info, such as garray2[], Fd, Fo, etc.
734
735 Notes:
736 When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
The routine is provided in the split-phase form MatMPIAIJKokkosBcastBegin/End() to provide opportunities for overlapping computation and communication.
738 */
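/* A small example (hypothetical) of the split described above: suppose this rank's column ownership is
   [cstart, cend) = [4, 8) and a received row of F has global columns {1, 5, 7, 9}. Columns 5 and 7 go to Fd with
   local indices {1, 3} (i.e., 5-cstart and 7-cstart), while columns 1 and 9 go to Fo and keep their global indices
   until ReduceTwoSetsOfGlobalIndices() maps them, together with garray1[], to local indices against garray2[]. */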
static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
740 {
741 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data);
742 Mat A = empi->A, B = empi->B; // diag and off-diag
743 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
744 PetscInt Em = E->rmap->n; // #local rows
745 MPI_Comm comm;
746
747 PetscFunctionBegin;
PetscCall(PetscObjectGetComm((PetscObject)E, &comm));
749 if (reuse == MAT_INITIAL_MATRIX) {
750 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
751 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
752 const PetscInt *garray1 = empi->garray; // its size is n1
753 PetscInt cstart, cend;
754 PetscSF bcastSF;
755
756 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
757
758 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
759 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
760 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
761 for (PetscInt i = 0; i < Em; i++) {
762 const PetscInt *first, *last, *it;
763 PetscInt count, step;
764 // std::lower_bound(first,last,cstart), but need to use global column indices
765 first = Bj + Bi[i];
766 last = Bj + Bi[i + 1];
767 count = last - first;
768 while (count > 0) {
769 it = first;
770 step = count / 2;
771 it += step;
772 if (empi->garray[*it] < cstart) { // map local to global
773 first = ++it;
774 count -= step + 1;
775 } else count = step;
776 }
777 E_NzLeft[i] = first - (Bj + Bi[i]);
778 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
779 }
780
781 // Compute row pointer Fi of F
782 PetscInt *Fi, Fm, Fnz;
783 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
784 PetscCall(PetscMalloc1(Fm + 1, &Fi));
785 Fi[0] = 0;
786 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
787 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
788 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
789 Fnz = Fi[Fm];
790
791 // Build the real PetscSF for bcasting E rows (buffer to buffer)
792 const PetscMPIInt *iranks, *ranks;
793 const PetscInt *ioffset, *irootloc, *roffset;
794 PetscMPIInt niranks, nranks;
795 PetscInt *sdisp, *rdisp;
796 MPI_Request *reqs;
797 PetscMPIInt tag;
798
799 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
800 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info
801 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
802
803 sdisp[0] = 0; // send displacement
804 for (PetscInt i = 0; i < niranks; i++) {
805 sdisp[i + 1] = sdisp[i];
806 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
807 PetscInt r = irootloc[j]; // row to be sent
808 sdisp[i + 1] += E_RowLen[r];
809 }
810 }
811
PetscCall(PetscCommGetNewTag(comm, &tag));
813 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
814 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
815 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
816
817 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive
818 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send
819 PetscSFNode *iremote; // give ownership to bcastSF
820 PetscCall(PetscMalloc1(nleaves, &iremote));
821 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
822 PetscInt k = 0;
823 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
824 iremote[j].rank = ranks[i];
825 iremote[j].index = rdisp[i] + k; // their root location
826 k++;
827 }
828 }
829 PetscCall(PetscSFCreate(comm, &bcastSF));
830 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
831 PetscCall(PetscFree3(sdisp, rdisp, reqs));
832
833 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
834 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destination offset in the copy
836 rowoffset[0] = 0;
837 for (PetscInt i = 0; i < ioffset[niranks]; i++) rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]];
838
839 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
840 PetscInt *jbuf, *Fj;
841 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
842 for (PetscInt k = 0; k < ioffset[niranks]; k++) {
843 PetscInt i = irootloc[k]; // row to be copied
844 PetscInt *buf = &jbuf[rowoffset[k]];
845 PetscInt nzLeft = E_NzLeft[i];
846 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
847 for (PetscInt j = 0; j < alen + blen; j++) {
848 if (j < nzLeft) {
849 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
850 } else if (j < nzLeft + alen) {
851 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
852 } else {
853 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
854 }
855 }
856 }
857 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
858 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
859
860 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
861 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
862 MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
863 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
864 MatColIdxType *F_NzLeft = F_NzLeft_h.data();
865
866 Fdi[0] = Foi[0] = 0;
867 for (PetscInt i = 0; i < Fm; i++) {
868 PetscInt *first, *last, *lb1, *lb2;
869 // cut the row into: Left, [cstart, cend), Right
870 first = Fj + Fi[i];
871 last = Fj + Fi[i + 1];
872 lb1 = std::lower_bound(first, last, cstart);
873 F_NzLeft[i] = lb1 - first;
874 lb2 = std::lower_bound(first, last, cend);
875 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi
876 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
877 }
878 for (PetscInt i = 0; i < Fm; i++) {
879 Fdi[i + 1] += Fdi[i];
880 Foi[i + 1] += Foi[i];
881 }
882
883 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
884 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm];
885 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
886 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
887
888 for (PetscInt i = 0; i < Fm; i++) {
889 PetscInt nzLeft = F_NzLeft[i];
890 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len
891 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
892 gid = Fj[Fi[i] + j];
893 if (j < nzLeft) { // left, in global
894 Foj[Foi[i] + j] = gid;
895 } else if (j < nzLeft + len) { // diag, in local
896 Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
897 } else { // right, in global
898 Foj[Foi[i] + j - len] = gid;
899 }
900 }
901 }
902 PetscCall(PetscFree2(jbuf, Fj));
903 PetscCall(PetscFree(Fi));
904
905 // Reduce global indices in Foj[] and garray1[] into local ones
906 PetscInt n2, *garray2;
907 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
908
909 // Record the plans built above, for reuse
910 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
911 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
912 Kokkos::deep_copy(irootloc_h, tmp);
913 mm->sf = bcastSF;
914 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
915 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
916 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
917 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
918 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots);
919 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
920 mm->garray = garray2;
921 mm->n = n2;
922
923 // Output Fd and Fo in KokkosCsrMatrix format
924 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
925 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
926 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
927 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
928 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
929
930 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
931 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
932
933 // Compute kernel launch parameters in merging E or splitting F
934 PetscInt teamSize, vectorLength, rowsPerTeam;
935
936 teamSize = vectorLength = rowsPerTeam = -1;
937 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
938 mm->E_TeamSize = teamSize;
939 mm->E_VectorLength = vectorLength;
940 mm->E_RowsPerTeam = rowsPerTeam;
941
942 teamSize = vectorLength = rowsPerTeam = -1;
943 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
944 mm->F_TeamSize = teamSize;
945 mm->F_VectorLength = vectorLength;
946 mm->F_RowsPerTeam = rowsPerTeam;
947 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
948
949 // Sync E's value to device
950 PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace()));
951 PetscCall(KokkosDualViewSyncDevice(bkok->a_dual, PetscGetKokkosExecutionSpace()));
952
953 // Handy aliases
954 const auto &Aa = akok->a_dual.view_device();
955 const auto &Ba = bkok->a_dual.view_device();
956 const auto &Ai = akok->i_dual.view_device();
957 const auto &Bi = bkok->i_dual.view_device();
958
959 // Fetch the plans
960 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft;
961 PetscSF &bcastSF = mm->sf;
962 MatScalarKokkosView &rootBuf = mm->rootBuf;
963 MatScalarKokkosView &leafBuf = mm->leafBuf;
964 PetscIntKokkosView &irootloc = mm->irootloc;
965 PetscIntKokkosView &rowoffset = mm->rowoffset;
966
967 PetscInt teamSize = mm->E_TeamSize;
968 PetscInt vectorLength = mm->E_VectorLength;
969 PetscInt rowsPerTeam = mm->E_RowsPerTeam;
970 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
971
972 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
973 PetscCallCXX(Kokkos::parallel_for(
974 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
975 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
976 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
977 if (r < irootloc.extent(0)) {
978 PetscInt i = irootloc(r); // row i of E
979 PetscInt disp = rowoffset(r);
980 PetscInt alen = Ai(i + 1) - Ai(i);
981 PetscInt blen = Bi(i + 1) - Bi(i);
982 PetscInt nzleft = E_NzLeft(i);
983
984 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
985 if (j < nzleft) { // B left
986 rootBuf(disp + j) = Ba(Bi(i) + j);
987 } else if (j < nzleft + alen) { // diag A
988 rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
989 } else { // B right
990 rootBuf(disp + j) = Ba(Bi(i) + j - alen);
991 }
992 });
993 }
994 });
995 }));
996 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
997 PetscFunctionReturn(PETSC_SUCCESS);
998 }
999
1000 // To finish MatMPIAIJKokkosBcast.
static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1002 {
1003 PetscFunctionBegin;
1004 const auto &Fd = mm->Fd;
1005 const auto &Fo = mm->Fo;
1006 const auto &Fdi = Fd.graph.row_map;
1007 const auto &Foi = Fo.graph.row_map;
1008 auto &Fda = Fd.values;
1009 auto &Foa = Fo.values;
1010 auto Fm = Fd.numRows();
1011
1012 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft;
1013 PetscSF &bcastSF = mm->sf;
1014 MatScalarKokkosView &rootBuf = mm->rootBuf;
1015 MatScalarKokkosView &leafBuf = mm->leafBuf;
1016 PetscInt teamSize = mm->F_TeamSize;
1017 PetscInt vectorLength = mm->F_VectorLength;
1018 PetscInt rowsPerTeam = mm->F_RowsPerTeam;
1019 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1020
1021 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1022
1023 // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1024 PetscCallCXX(Kokkos::parallel_for(
1025 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1026 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1027 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1028 if (i < Fm) {
1029 PetscInt nzLeft = F_NzLeft(i);
1030 PetscInt alen = Fdi(i + 1) - Fdi(i);
1031 PetscInt blen = Foi(i + 1) - Foi(i);
1032 PetscInt Fii = Fdi(i) + Foi(i);
1033
1034 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1035 PetscScalar val = leafBuf(Fii + j);
1036 if (j < nzLeft) { // left
1037 Foa(Foi(i) + j) = val;
1038 } else if (j < nzLeft + alen) { // diag
1039 Fda(Fdi(i) + j - nzLeft) = val;
1040 } else { // right
1041 Foa(Foi(i) + j - alen) = val;
1042 }
1043 });
1044 }
1045 });
1046 }));
1047 PetscFunctionReturn(PETSC_SUCCESS);
1048 }
1049
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1051 {
1052 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1053 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1054 KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1055 PetscInt cstart, cend;
1056 MPI_Comm comm;
1057
1058 PetscFunctionBegin;
1059 PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1060 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1061 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1062 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1063 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1064 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1065 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1066
1067 // TODO: add command line options to select spgemm algorithms
1068 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1069
1070 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1071 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1072 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1073 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1074 #endif
1075 #endif
1076
1077 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1078 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1079 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1080 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1081
1082 // Aot * (B's diag + B's off-diag)
1083 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1084 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1085 // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1086 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1087 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1088 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1089 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1090
1091 PetscCallCXX(sort_crs_matrix(mm->C3));
1092 PetscCallCXX(sort_crs_matrix(mm->C4));
1093 #endif
1094
1095 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1096 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1097 PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1098 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1099
1100 // Adt * (B's diag + B's off-diag)
1101 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1102 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1103 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1104 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1105 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1106 PetscCallCXX(sort_crs_matrix(mm->C1));
1107 PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1108 #endif
1109
1110 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1111
1112 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1113 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1114 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1115 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1116 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1117
1118 // C = (C1+Fd, C2+Fo)
1119 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1120 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1121 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1122 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1123 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1124 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1125 PetscFunctionReturn(PETSC_SUCCESS);
1126 }
1127
static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1129 {
1130 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1131 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1132 KokkosCsrMatrix Adt, Aot, Bd, Bo;
1133 MPI_Comm comm;
1134
1135 PetscFunctionBegin;
1136 PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1137 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1138 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1139 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1140 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1141
1142 // Aot * (B's diag + B's off-diag)
1143 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1144 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1145
1146 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1147 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1148
1149 // Adt * (B's diag + B's off-diag)
1150 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1151 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1152
1153 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1154
1155 // C = (C1+Fd, C2+Fo)
1156 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1157 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1158 PetscFunctionReturn(PETSC_SUCCESS);
1159 }
1160
1161 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1162
1163 Input Parameters:
+ product - the Mat_Product that carries out the computation, passed in to access info about this mat product
1165 . A - an MPIAIJKOKKOS matrix
1166 . B - an MPIAIJKOKKOS matrix
- mm - a struct used to stash intermediate data when computing AB; it persists from the symbolic to the numeric operation
1168 */
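/* In split form A = (Ad, Ao) and B = (Bd, Bo), and with F = (Fd, Fo) denoting the needed rows of B bcast to this
   rank (see MatMPIAIJKokkosBcastBegin/End()), the product below is, in sketch form,

     Cd = C1 + C3 = Ad*Bd + Ao*Fd,   Co = C2 + C4 = Ad*Bo + Ao*Fo,

   where C2 is C2_mid with its column indices remapped from B's off-diag numbering to C's off-diag numbering. */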
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1170 {
1171 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1172 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1173 KokkosCsrMatrix Ad, Ao, Bd, Bo;
1174
1175 PetscFunctionBegin;
1176 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1177 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1178 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1179 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1180
1181 // TODO: add command line options to select spgemm algorithms
1182 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1183
1184 // CUDA 10.2's spgemm has bugs, and we prefer the SpGEMMreuse APIs introduced in CUDA 11.4, so fall back to the native KK algorithm for older CUDA
1185 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1186 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1187 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1188 #endif
1189 #endif
1190
1191 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1192 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1193 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1194 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1195
1196 // Bcast B's rows to form F, and overlap the communication with the local computation below
1197 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1198 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1199
1200 // A's diag * (B's diag + B's off-diag)
1201 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1202 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1203 // KK spgemm_symbolic() only populates the result's row map, not its column indices.
1204 // TODO: Remove the fake spgemm_numeric() once KK fixes this problem.
1205 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1206 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1207 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1208 PetscCallCXX(sort_crs_matrix(mm->C1));
1209 PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1210 #endif
1211
1212 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1213
1214 // A's off-diag * (F's diag + F's off-diag)
1215 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1216 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1217 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1218 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1219 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1220 PetscCallCXX(sort_crs_matrix(mm->C3));
1221 PetscCallCXX(sort_crs_matrix(mm->C4));
1222 #endif
1223
1224 // Create C2, which shares the a (values) and i (row map) arrays with C2_mid, but has new column indices and a potentially larger column size
1225 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1226 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1227 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1228 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1229
1230 // C = (Cd, Co) = (C1+C3, C2+C4)
1231 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, C3 are sorted
1232 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, C4 are sorted
1233 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1234 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1235 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1236 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1237 PetscFunctionReturn(PETSC_SUCCESS);
1238 }
1239
1240 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1241 {
1242 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1243 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1244 KokkosCsrMatrix Ad, Ao, Bd, Bo;
1245
1246 PetscFunctionBegin;
1247 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1248 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1249 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1250 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1251
1252 // Bcast B's rows to form F, and overlap the communication with the local computation below
1253 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1254
1255 // A's diag * (B's diag + B's off-diag)
1256 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1257 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1258
1259 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1260
1261 // A's off-diag * (F's diag + F's off-diag)
1262 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1263 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1264
1265 // C = (Cd, Co) = (C1+C3, C2+C4)
1266 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1267 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1268 PetscFunctionReturn(PETSC_SUCCESS);
1269 }
1270
1271 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1272 {
1273 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1274 Mat_Product *product;
1275 MatProductCtx_MPIAIJKokkos *pdata;
1276 MatProductType ptype;
1277 Mat A, B;
1278
1279 PetscFunctionBegin;
1280 MatCheckProduct(C, 1); // make sure C is a product
1281 product = C->product;
1282 pdata = static_cast<MatProductCtx_MPIAIJKokkos *>(product->data);
1283 ptype = product->type;
1284 A = product->A;
1285 B = product->B;
1286
1287 // See if the numeric phase has already been done in the symbolic phase (e.g., the user called MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1288 // If so, skip the numeric phase, but reset the flag so that the next time the user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1289 // we still do the numeric phase (see the illustrative call sequence sketched below).
1290 if (pdata->reusesym) { // numeric reuses results from symbolic
1291 pdata->reusesym = PETSC_FALSE;
1292 PetscFunctionReturn(PETSC_SUCCESS);
1293 }
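  /* Illustrative call sequence (a sketch of typical user code, not part of this routine):
       PetscCall(MatMatMult(A, B, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &C)); // symbolic also does numeric; reusesym is set
       PetscCall(MatMatMult(A, B, MAT_REUSE_MATRIX, PETSC_DEFAULT, &C));   // reusesym was reset above, so numeric runs again
  */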
1294
1295 if (ptype == MATPRODUCT_AB) {
1296 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1297 } else if (ptype == MATPRODUCT_AtB) {
1298 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1299 } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C = BtZ
1300 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1301 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1302 }
1303 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that C's local blocks (cmpi->A and cmpi->B) are modified on device
1304 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1305 PetscFunctionReturn(PETSC_SUCCESS);
1306 }
1307
1308 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1309 {
1310 Mat A, B;
1311 Mat_Product *product;
1312 MatProductType ptype;
1313 MatProductCtx_MPIAIJKokkos *pdata;
1314 MatMatStruct *mm = NULL;
1315 PetscInt m, n, M, N;
1316 Mat Cd, Co;
1317 MPI_Comm comm;
1318 Mat_MPIAIJ *mpiaij;
1319
1320 PetscFunctionBegin;
1321 PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1322 MatCheckProduct(C, 1);
1323 product = C->product;
1324 PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1325 ptype = product->type;
1326 A = product->A;
1327 B = product->B;
1328
1329 switch (ptype) {
1330 case MATPRODUCT_AB:
1331 m = A->rmap->n;
1332 n = B->cmap->n;
1333 M = A->rmap->N;
1334 N = B->cmap->N;
1335 break;
1336 case MATPRODUCT_AtB:
1337 m = A->cmap->n;
1338 n = B->cmap->n;
1339 M = A->cmap->N;
1340 N = B->cmap->N;
1341 break;
1342 case MATPRODUCT_PtAP:
1343 m = B->cmap->n;
1344 n = B->cmap->n;
1345 M = B->cmap->N;
1346 N = B->cmap->N;
1347 break; /* BtAB */
1348 default:
1349 SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1350 }
1351
1352 PetscCall(MatSetSizes(C, m, n, M, N));
1353 PetscCall(PetscLayoutSetUp(C->rmap));
1354 PetscCall(PetscLayoutSetUp(C->cmap));
1355 PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1356
1357 pdata = new MatProductCtx_MPIAIJKokkos();
1358 pdata->reusesym = product->api_user;
1359
1360 if (ptype == MATPRODUCT_AB) {
1361 auto mmAB = new MatMatStruct_AB();
1362 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1363 mm = pdata->mmAB = mmAB;
1364 } else if (ptype == MATPRODUCT_AtB) {
1365 auto mmAtB = new MatMatStruct_AtB();
1366 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1367 mm = pdata->mmAtB = mmAtB;
1368 } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C = BtZ
1369 Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z
1370
1371 auto mmAB = new MatMatStruct_AB();
1372 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1373 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1374 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1375 pdata->mmAB = mmAB;
1376
1377 m = A->rmap->n; // Z's layout
1378 n = B->cmap->n;
1379 M = A->rmap->N;
1380 N = B->cmap->N;
1381 PetscCall(MatCreateMPIAIJWithSeqAIJ(comm, M, N, Zd, Zo, mmAB->garray, &Z));
1382
1383 auto mmAtB = new MatMatStruct_AtB();
1384 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1385
1386 pdata->Z = Z; // give ownership to pdata
1387 mm = pdata->mmAtB = mmAtB;
1388 }
1389
1390 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1391 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1392
1393 mpiaij = (Mat_MPIAIJ *)C->data;
1394 mpiaij->A = Cd;
1395 mpiaij->B = Co;
1396 mpiaij->garray = mm->garray;
1397
1398 C->preallocated = PETSC_TRUE;
1399 C->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ; in effect, this makes MatAssemblyBegin() a no-op */
1400
1401 PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
1402 PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
1403 PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
1404 PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
1405 PetscCall(MatSetOption(C, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
1406
1407 /* set block sizes */
1408 switch (ptype) {
1409 case MATPRODUCT_PtAP:
1410 if (B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->cmap->bs, B->cmap->bs));
1411 break;
1412 case MATPRODUCT_RARt:
1413 if (B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->rmap->bs, B->rmap->bs));
1414 break;
1415 case MATPRODUCT_ABC:
1416 PetscCall(MatSetBlockSizesFromMats(C, A, product->C));
1417 break;
1418 case MATPRODUCT_AB:
1419 PetscCall(MatSetBlockSizesFromMats(C, A, B));
1420 break;
1421 case MATPRODUCT_AtB:
1422 if (A->cmap->bs > 1 || B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->cmap->bs, B->cmap->bs));
1423 break;
1424 case MATPRODUCT_ABt:
1425 if (A->rmap->bs > 1 || B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->rmap->bs, B->rmap->bs));
1426 break;
1427 default:
1428 SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for ProductType %s", MatProductTypes[ptype]);
1429 }
1430 C->product->data = pdata;
1431 C->product->destroy = MatProductCtxDestroy_MPIAIJKokkos;
1432 C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1433 PetscFunctionReturn(PETSC_SUCCESS);
1434 }
1435
1436 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1437 {
1438 Mat_Product *product = mat->product;
1439 PetscBool match = PETSC_FALSE;
1440 PetscBool usecpu = PETSC_FALSE;
1441
1442 PetscFunctionBegin;
1443 MatCheckProduct(mat, 1);
1444 if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1445 if (match) { /* we can always fall back to the CPU if requested */
1446 switch (product->type) {
1447 case MATPRODUCT_AB:
1448 if (product->api_user) {
1449 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1450 PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1451 PetscOptionsEnd();
1452 } else {
1453 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1454 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1455 PetscOptionsEnd();
1456 }
1457 break;
1458 case MATPRODUCT_AtB:
1459 if (product->api_user) {
1460 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1461 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1462 PetscOptionsEnd();
1463 } else {
1464 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1465 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1466 PetscOptionsEnd();
1467 }
1468 break;
1469 case MATPRODUCT_PtAP:
1470 if (product->api_user) {
1471 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1472 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1473 PetscOptionsEnd();
1474 } else {
1475 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1476 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1477 PetscOptionsEnd();
1478 }
1479 break;
1480 default:
1481 break;
1482 }
1483 match = (PetscBool)!usecpu;
1484 }
1485 if (match) {
1486 switch (product->type) {
1487 case MATPRODUCT_AB:
1488 case MATPRODUCT_AtB:
1489 case MATPRODUCT_PtAP:
1490 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1491 break;
1492 default:
1493 break;
1494 }
1495 }
1496 /* fallback to MPIAIJ ops */
1497 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1498 PetscFunctionReturn(PETSC_SUCCESS);
1499 }
1500
1501 // Mirror of MatCOOStruct_MPIAIJ on device
1502 struct MatCOOStruct_MPIAIJKokkos {
1503 PetscCount n;
1504 PetscSF sf;
1505 PetscCount Annz, Bnnz;
1506 PetscCount Annz2, Bnnz2;
1507 PetscCountKokkosView Ajmap1, Aperm1;
1508 PetscCountKokkosView Bjmap1, Bperm1;
1509 PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1510 PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1511 PetscCountKokkosView Cperm1;
1512 MatScalarKokkosView sendbuf, recvbuf;
1513
1514 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1515 {
1516 auto exec = PetscGetKokkosExecutionSpace();
1517
1518 n = coo_h->n;
1519 sf = coo_h->sf;
1520 Annz = coo_h->Annz;
1521 Bnnz = coo_h->Bnnz;
1522 Annz2 = coo_h->Annz2;
1523 Bnnz2 = coo_h->Bnnz2;
1524 Ajmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1525 Aperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1526 Bjmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1527 Bperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1528 Aimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1529 Ajmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1530 Aperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1531 Bimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1532 Bjmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1533 Bperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1534 Cperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1535 sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1536 recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1537 PetscCallVoid(PetscObjectReference((PetscObject)sf));
1538 }
1539
1540 ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1541 };
1542
1543 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(PetscCtxRt data)
1544 {
1545 PetscFunctionBegin;
1546 PetscCallCXX(delete *static_cast<MatCOOStruct_MPIAIJKokkos **>(data));
1547 PetscFunctionReturn(PETSC_SUCCESS);
1548 }
1549
1550 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1551 {
1552 PetscContainer container_h, container_d;
1553 MatCOOStruct_MPIAIJ *coo_h;
1554 MatCOOStruct_MPIAIJKokkos *coo_d;
1555
1556 PetscFunctionBegin;
1557 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1558 mat->preallocated = PETSC_TRUE;
1559 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1560 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1561 PetscCall(MatZeroEntries(mat));
1562
1563 // Copy the COO struct to device
1564 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1565 PetscCall(PetscContainerGetPointer(container_h, &coo_h));
1566 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1567
1568 // Put the COO struct in a container and then attach that to the matrix
1569 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1570 PetscCall(PetscContainerSetPointer(container_d, coo_d));
1571 PetscCall(PetscContainerSetCtxDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1572 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1573 PetscCall(PetscContainerDestroy(&container_d));
1574 PetscFunctionReturn(PETSC_SUCCESS);
1575 }
1576
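/* Illustrative user-side COO assembly sequence (a sketch; the array names and sizes are hypothetical,
   only the PETSc calls are real):
     PetscCall(MatSetPreallocationCOO(A, ncoo, coo_i, coo_j)); // dispatches to MatSetPreallocationCOO_MPIAIJKokkos() for MATMPIAIJKOKKOS
     PetscCall(MatSetValuesCOO(A, values, ADD_VALUES));        // dispatches to MatSetValuesCOO_MPIAIJKokkos() below
*/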
1577 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1578 {
1579 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1580 Mat A = mpiaij->A, B = mpiaij->B;
1581 MatScalarKokkosView Aa, Ba;
1582 MatScalarKokkosView v1;
1583 PetscMemType memtype;
1584 PetscContainer container;
1585 MatCOOStruct_MPIAIJKokkos *coo;
1586 Kokkos::DefaultExecutionSpace exec = PetscGetKokkosExecutionSpace();
1587
1588 PetscFunctionBegin;
1589 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1590 PetscCall(PetscContainerGetPointer(container, &coo));
1591
1592 const auto &n = coo->n;
1593 const auto &Annz = coo->Annz;
1594 const auto &Annz2 = coo->Annz2;
1595 const auto &Bnnz = coo->Bnnz;
1596 const auto &Bnnz2 = coo->Bnnz2;
1597 const auto &vsend = coo->sendbuf;
1598 const auto &v2 = coo->recvbuf;
1599 const auto &Ajmap1 = coo->Ajmap1;
1600 const auto &Ajmap2 = coo->Ajmap2;
1601 const auto &Aimap2 = coo->Aimap2;
1602 const auto &Bjmap1 = coo->Bjmap1;
1603 const auto &Bjmap2 = coo->Bjmap2;
1604 const auto &Bimap2 = coo->Bimap2;
1605 const auto &Aperm1 = coo->Aperm1;
1606 const auto &Aperm2 = coo->Aperm2;
1607 const auto &Bperm1 = coo->Bperm1;
1608 const auto &Bperm2 = coo->Bperm2;
1609 const auto &Cperm1 = coo->Cperm1;
1610
1611 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1612 if (PetscMemTypeHost(memtype)) { /* If the user provided v[] on the host, copy it to the device */
1613 v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1614 } else {
1615 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1616 }
1617
1618 if (imode == INSERT_VALUES) {
1619 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1620 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1621 } else {
1622 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1623 PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1624 }
1625
1626 PetscCall(PetscLogGpuTimeBegin());
1627 /* Pack entries to be sent to remote ranks */
1628 Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1629
1630 /* Send remote entries to their owners and overlap the communication with local computation */
1631 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1632 /* Add local entries to A and B in one kernel */
1633 Kokkos::parallel_for(
1634 Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1635 PetscScalar sum = 0.0;
1636 if (i < Annz) {
1637 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1638 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1639 } else {
1640 i -= Annz;
1641 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1642 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1643 }
1644 });
1645 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1646
1647 /* Add received remote entries to A and B in one kernel */
1648 Kokkos::parallel_for(
1649 Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1650 if (i < Annz2) {
1651 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1652 } else {
1653 i -= Annz2;
1654 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1655 }
1656 });
1657 PetscCall(PetscLogGpuTimeEnd());
1658
1659 if (imode == INSERT_VALUES) {
1660 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1661 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1662 } else {
1663 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1664 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1665 }
1666 PetscFunctionReturn(PETSC_SUCCESS);
1667 }
1668
1669 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1670 {
1671 PetscFunctionBegin;
1672 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1673 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1674 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1675 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1676 #if defined(PETSC_HAVE_HYPRE)
1677 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijkokkos_hypre_C", NULL));
1678 #endif
1679 PetscCall(MatDestroy_MPIAIJ(A));
1680 PetscFunctionReturn(PETSC_SUCCESS);
1681 }
1682
1683 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1684 {
1685 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1686 PetscBool congruent;
1687
1688 PetscFunctionBegin;
1689 PetscCall(MatHasCongruentLayouts(A, &congruent));
1690 if (congruent) { // square matrix whose diagonal entries lie entirely in the diagonal block
1691 PetscCall(MatShift(mpiaij->A, a));
1692 } else { // too hard, use the general version
1693 PetscCall(MatShift_Basic(A, a));
1694 }
1695 PetscFunctionReturn(PETSC_SUCCESS);
1696 }
1697
1698 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1699 {
1700 PetscFunctionBegin;
1701 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos;
1702 B->ops->mult = MatMult_MPIAIJKokkos;
1703 B->ops->multadd = MatMultAdd_MPIAIJKokkos;
1704 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos;
1705 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1706 B->ops->destroy = MatDestroy_MPIAIJKokkos;
1707 B->ops->shift = MatShift_MPIAIJKokkos;
1708 B->ops->getcurrentmemtype = MatGetCurrentMemType_MPIAIJ;
1709
1710 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1711 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1712 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1713 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1714 #if defined(PETSC_HAVE_HYPRE)
1715 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_mpiaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
1716 #endif
1717 PetscFunctionReturn(PETSC_SUCCESS);
1718 }
1719
1720 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1721 {
1722 Mat B;
1723 Mat_MPIAIJ *a;
1724
1725 PetscFunctionBegin;
1726 if (reuse == MAT_INITIAL_MATRIX) {
1727 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1728 } else if (reuse == MAT_REUSE_MATRIX) {
1729 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1730 }
1731 B = *newmat;
1732
1733 B->boundtocpu = PETSC_FALSE;
1734 PetscCall(PetscFree(B->defaultvectype));
1735 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1736 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1737
1738 a = static_cast<Mat_MPIAIJ *>(A->data);
1739 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1740 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1741 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1742 PetscCall(MatSetOps_MPIAIJKokkos(B));
1743 PetscFunctionReturn(PETSC_SUCCESS);
1744 }
1745
1746 /*MC
1747 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos.
1748
1749 A matrix type using the Kokkos Kernels CrsMatrix type for portability across different device types.
1750
1751 Options Database Key:
1752 . -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1753
1754 Level: beginner
1755
1756 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1757 M*/
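/* Illustrative ways to select this type (a sketch; the calls and the option are standard PETSc, while
   the program and variable names are hypothetical):
     ./app -mat_type aijkokkos
   or, programmatically,
     PetscCall(MatCreate(PETSC_COMM_WORLD, &A));
     PetscCall(MatSetSizes(A, PETSC_DECIDE, PETSC_DECIDE, M, N));
     PetscCall(MatSetType(A, MATAIJKOKKOS));
     PetscCall(MatSetFromOptions(A));
     PetscCall(MatSetUp(A));
*/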
1758 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1759 {
1760 PetscFunctionBegin;
1761 PetscCall(PetscKokkosInitializeCheck());
1762 PetscCall(MatCreate_MPIAIJ(A));
1763 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1764 PetscFunctionReturn(PETSC_SUCCESS);
1765 }
1766
1767 /*@C
1768 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
1769 (the default parallel PETSc format). This matrix will ultimately be pushed down
1770 to Kokkos for calculations.
1771
1772 Collective
1773
1774 Input Parameters:
1775 + comm - MPI communicator
1776 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1777 This value should be the same as the local size used in creating the
1778 y vector for the matrix-vector product y = Ax.
1779 . n - This value should be the same as the local size used in creating the
1780 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1781 calculated if N is given) For square matrices n is almost always `m`.
1782 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1783 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1784 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix
1785 (same value is used for all local rows)
1786 . d_nnz - array containing the number of nonzeros in the various rows of the
1787 DIAGONAL portion of the local submatrix (possibly different for each row)
1788 or `NULL`, if `d_nz` is used to specify the nonzero structure.
1789 The size of this array is equal to the number of local rows, i.e `m`.
1790 For matrices you plan to factor you must leave room for the diagonal entry and
1791 put in the entry even if it is zero.
1792 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local
1793 submatrix (same value is used for all local rows).
1794 - o_nnz - array containing the number of nonzeros in the various rows of the
1795 OFF-DIAGONAL portion of the local submatrix (possibly different for
1796 each row) or `NULL`, if `o_nz` is used to specify the nonzero
1797 structure. The size of this array is equal to the number
1798 of local rows, i.e `m`.
1799
1800 Output Parameter:
1801 . A - the matrix
1802
1803 Level: intermediate
1804
1805 Notes:
1806 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1807 MatXXXXSetPreallocation() paradigm instead of this routine directly.
1808 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1809
1810 The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
1811 storage. That is, the stored row and column indices can begin at
1812 either one (as in Fortran) or zero.
1813
1814 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1815 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1816 @*/
1817 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1818 {
1819 PetscMPIInt size;
1820
1821 PetscFunctionBegin;
1822 PetscCall(MatCreate(comm, A));
1823 PetscCall(MatSetSizes(*A, m, n, M, N));
1824 PetscCallMPI(MPI_Comm_size(comm, &size));
1825 if (size > 1) {
1826 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1827 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1828 } else {
1829 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1830 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1831 }
1832 PetscFunctionReturn(PETSC_SUCCESS);
1833 }
1834