xref: /petsc/src/mat/impls/aij/mpi/fdmpiaij.c (revision 2d30e087755efd99e28fdfe792ffbeb2ee1ea928)
1 #include <../src/mat/impls/sell/mpi/mpisell.h>
2 #include <../src/mat/impls/aij/mpi/mpiaij.h>
3 #include <../src/mat/impls/baij/mpi/mpibaij.h>
4 #include <petsc/private/isimpl.h>
5 
6 PetscErrorCode MatFDColoringApply_BAIJ(Mat J, MatFDColoring coloring, Vec x1, void *sctx) {
7   PetscErrorCode (*f)(void *, Vec, Vec, void *) = (PetscErrorCode(*)(void *, Vec, Vec, void *))coloring->f;
8   PetscInt           k, cstart, cend, l, row, col, nz, spidx, i, j;
9   PetscScalar        dx = 0.0, *w3_array, *dy_i, *dy = coloring->dy;
10   PetscScalar       *vscale_array;
11   const PetscScalar *xx;
12   PetscReal          epsilon = coloring->error_rel, umin = coloring->umin, unorm;
13   Vec                w1 = coloring->w1, w2 = coloring->w2, w3, vscale = coloring->vscale;
14   void              *fctx  = coloring->fctx;
15   PetscInt           ctype = coloring->ctype, nxloc, nrows_k;
16   PetscScalar       *valaddr;
17   MatEntry          *Jentry  = coloring->matentry;
18   MatEntry2         *Jentry2 = coloring->matentry2;
19   const PetscInt     ncolors = coloring->ncolors, *ncolumns = coloring->ncolumns, *nrows = coloring->nrows;
20   PetscInt           bs = J->rmap->bs;
21 
22   PetscFunctionBegin;
23   PetscCall(VecBindToCPU(x1, PETSC_TRUE));
24   /* (1) Set w1 = F(x1) */
25   if (!coloring->fset) {
26     PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, coloring, 0, 0, 0));
27     PetscCall((*f)(sctx, x1, w1, fctx));
28     PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, coloring, 0, 0, 0));
29   } else {
30     coloring->fset = PETSC_FALSE;
31   }
32 
33   /* (2) Compute vscale = 1./dx - the local scale factors, including ghost points */
34   PetscCall(VecGetLocalSize(x1, &nxloc));
35   if (coloring->htype[0] == 'w') {
36     /* vscale = dx is a constant scalar */
37     PetscCall(VecNorm(x1, NORM_2, &unorm));
38     dx = 1.0 / (PetscSqrtReal(1.0 + unorm) * epsilon);
39   } else {
40     PetscCall(VecGetArrayRead(x1, &xx));
41     PetscCall(VecGetArray(vscale, &vscale_array));
42     for (col = 0; col < nxloc; col++) {
43       dx = xx[col];
44       if (PetscAbsScalar(dx) < umin) {
45         if (PetscRealPart(dx) >= 0.0) dx = umin;
46         else if (PetscRealPart(dx) < 0.0) dx = -umin;
47       }
48       dx *= epsilon;
49       vscale_array[col] = 1.0 / dx;
50     }
51     PetscCall(VecRestoreArrayRead(x1, &xx));
52     PetscCall(VecRestoreArray(vscale, &vscale_array));
53   }
54   if (ctype == IS_COLORING_GLOBAL && coloring->htype[0] == 'd') {
55     PetscCall(VecGhostUpdateBegin(vscale, INSERT_VALUES, SCATTER_FORWARD));
56     PetscCall(VecGhostUpdateEnd(vscale, INSERT_VALUES, SCATTER_FORWARD));
57   }
58 
59   /* (3) Loop over each color */
60   if (!coloring->w3) {
61     PetscCall(VecDuplicate(x1, &coloring->w3));
62     /* Vec is used intensively in particular piece of scalar CPU code; won't benefit from bouncing back and forth to the GPU */
63     PetscCall(VecBindToCPU(coloring->w3, PETSC_TRUE));
64     PetscCall(PetscLogObjectParent((PetscObject)coloring, (PetscObject)coloring->w3));
65   }
66   w3 = coloring->w3;
67 
68   PetscCall(VecGetOwnershipRange(x1, &cstart, &cend)); /* used by ghosted vscale */
69   if (vscale) PetscCall(VecGetArray(vscale, &vscale_array));
70   nz = 0;
71   for (k = 0; k < ncolors; k++) {
72     coloring->currentcolor = k;
73 
74     /*
75       (3-1) Loop over each column associated with color
76       adding the perturbation to the vector w3 = x1 + dx.
77     */
78     PetscCall(VecCopy(x1, w3));
79     dy_i = dy;
80     for (i = 0; i < bs; i++) { /* Loop over a block of columns */
81       PetscCall(VecGetArray(w3, &w3_array));
82       if (ctype == IS_COLORING_GLOBAL) w3_array -= cstart; /* shift pointer so global index can be used */
83       if (coloring->htype[0] == 'w') {
84         for (l = 0; l < ncolumns[k]; l++) {
85           col = i + bs * coloring->columns[k][l]; /* local column (in global index!) of the matrix we are probing for */
86           w3_array[col] += 1.0 / dx;
87           if (i) w3_array[col - 1] -= 1.0 / dx; /* resume original w3[col-1] */
88         }
89       } else {                  /* htype == 'ds' */
90         vscale_array -= cstart; /* shift pointer so global index can be used */
91         for (l = 0; l < ncolumns[k]; l++) {
92           col = i + bs * coloring->columns[k][l]; /* local column (in global index!) of the matrix we are probing for */
93           w3_array[col] += 1.0 / vscale_array[col];
94           if (i) w3_array[col - 1] -= 1.0 / vscale_array[col - 1]; /* resume original w3[col-1] */
95         }
96         vscale_array += cstart;
97       }
98       if (ctype == IS_COLORING_GLOBAL) w3_array += cstart;
99       PetscCall(VecRestoreArray(w3, &w3_array));
100 
101       /*
102        (3-2) Evaluate function at w3 = x1 + dx (here dx is a vector of perturbations)
103                            w2 = F(x1 + dx) - F(x1)
104        */
105       PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, 0, 0, 0, 0));
106       PetscCall(VecPlaceArray(w2, dy_i)); /* place w2 to the array dy_i */
107       PetscCall((*f)(sctx, w3, w2, fctx));
108       PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, 0, 0, 0, 0));
109       PetscCall(VecAXPY(w2, -1.0, w1));
110       PetscCall(VecResetArray(w2));
111       dy_i += nxloc; /* points to dy+i*nxloc */
112     }
113 
114     /*
115      (3-3) Loop over rows of vector, putting results into Jacobian matrix
116     */
117     nrows_k = nrows[k];
118     if (coloring->htype[0] == 'w') {
119       for (l = 0; l < nrows_k; l++) {
120         row     = bs * Jentry2[nz].row; /* local row index */
121         valaddr = Jentry2[nz++].valaddr;
122         spidx   = 0;
123         dy_i    = dy;
124         for (i = 0; i < bs; i++) {   /* column of the block */
125           for (j = 0; j < bs; j++) { /* row of the block */
126             valaddr[spidx++] = dy_i[row + j] * dx;
127           }
128           dy_i += nxloc; /* points to dy+i*nxloc */
129         }
130       }
131     } else { /* htype == 'ds' */
132       for (l = 0; l < nrows_k; l++) {
133         row     = bs * Jentry[nz].row; /* local row index */
134         col     = bs * Jentry[nz].col; /* local column index */
135         valaddr = Jentry[nz++].valaddr;
136         spidx   = 0;
137         dy_i    = dy;
138         for (i = 0; i < bs; i++) {   /* column of the block */
139           for (j = 0; j < bs; j++) { /* row of the block */
140             valaddr[spidx++] = dy_i[row + j] * vscale_array[col + i];
141           }
142           dy_i += nxloc; /* points to dy+i*nxloc */
143         }
144       }
145     }
146   }
147   PetscCall(MatAssemblyBegin(J, MAT_FINAL_ASSEMBLY));
148   PetscCall(MatAssemblyEnd(J, MAT_FINAL_ASSEMBLY));
149   if (vscale) PetscCall(VecRestoreArray(vscale, &vscale_array));
150 
151   coloring->currentcolor = -1;
152   PetscCall(VecBindToCPU(x1, PETSC_FALSE));
153   PetscFunctionReturn(0);
154 }
155 
156 /* this is declared PETSC_EXTERN because it is used by MatFDColoringUseDM() which is in the DM library */
157 PetscErrorCode MatFDColoringApply_AIJ(Mat J, MatFDColoring coloring, Vec x1, void *sctx) {
158   PetscErrorCode (*f)(void *, Vec, Vec, void *) = (PetscErrorCode(*)(void *, Vec, Vec, void *))coloring->f;
159   PetscInt           k, cstart, cend, l, row, col, nz;
160   PetscScalar        dx = 0.0, *y, *w3_array;
161   const PetscScalar *xx;
162   PetscScalar       *vscale_array;
163   PetscReal          epsilon = coloring->error_rel, umin = coloring->umin, unorm;
164   Vec                w1 = coloring->w1, w2 = coloring->w2, w3, vscale = coloring->vscale;
165   void              *fctx  = coloring->fctx;
166   ISColoringType     ctype = coloring->ctype;
167   PetscInt           nxloc, nrows_k;
168   MatEntry          *Jentry  = coloring->matentry;
169   MatEntry2         *Jentry2 = coloring->matentry2;
170   const PetscInt     ncolors = coloring->ncolors, *ncolumns = coloring->ncolumns, *nrows = coloring->nrows;
171   PetscBool          alreadyboundtocpu;
172 
173   PetscFunctionBegin;
174   PetscCall(VecBoundToCPU(x1, &alreadyboundtocpu));
175   PetscCall(VecBindToCPU(x1, PETSC_TRUE));
176   PetscCheck(!(ctype == IS_COLORING_LOCAL) || !(J->ops->fdcoloringapply == MatFDColoringApply_AIJ), PetscObjectComm((PetscObject)J), PETSC_ERR_SUP, "Must call MatColoringUseDM() with IS_COLORING_LOCAL");
177   /* (1) Set w1 = F(x1) */
178   if (!coloring->fset) {
179     PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, 0, 0, 0, 0));
180     PetscCall((*f)(sctx, x1, w1, fctx));
181     PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, 0, 0, 0, 0));
182   } else {
183     coloring->fset = PETSC_FALSE;
184   }
185 
186   /* (2) Compute vscale = 1./dx - the local scale factors, including ghost points */
187   if (coloring->htype[0] == 'w') {
188     /* vscale = 1./dx is a constant scalar */
189     PetscCall(VecNorm(x1, NORM_2, &unorm));
190     dx = 1.0 / (PetscSqrtReal(1.0 + unorm) * epsilon);
191   } else {
192     PetscCall(VecGetLocalSize(x1, &nxloc));
193     PetscCall(VecGetArrayRead(x1, &xx));
194     PetscCall(VecGetArray(vscale, &vscale_array));
195     for (col = 0; col < nxloc; col++) {
196       dx = xx[col];
197       if (PetscAbsScalar(dx) < umin) {
198         if (PetscRealPart(dx) >= 0.0) dx = umin;
199         else if (PetscRealPart(dx) < 0.0) dx = -umin;
200       }
201       dx *= epsilon;
202       vscale_array[col] = 1.0 / dx;
203     }
204     PetscCall(VecRestoreArrayRead(x1, &xx));
205     PetscCall(VecRestoreArray(vscale, &vscale_array));
206   }
207   if (ctype == IS_COLORING_GLOBAL && coloring->htype[0] == 'd') {
208     PetscCall(VecGhostUpdateBegin(vscale, INSERT_VALUES, SCATTER_FORWARD));
209     PetscCall(VecGhostUpdateEnd(vscale, INSERT_VALUES, SCATTER_FORWARD));
210   }
211 
212   /* (3) Loop over each color */
213   if (!coloring->w3) {
214     PetscCall(VecDuplicate(x1, &coloring->w3));
215     PetscCall(PetscLogObjectParent((PetscObject)coloring, (PetscObject)coloring->w3));
216   }
217   w3 = coloring->w3;
218 
219   PetscCall(VecGetOwnershipRange(x1, &cstart, &cend)); /* used by ghosted vscale */
220   if (vscale) PetscCall(VecGetArray(vscale, &vscale_array));
221   nz = 0;
222 
223   if (coloring->bcols > 1) { /* use blocked insertion of Jentry */
224     PetscInt     i, m = J->rmap->n, nbcols, bcols = coloring->bcols;
225     PetscScalar *dy = coloring->dy, *dy_k;
226 
227     nbcols = 0;
228     for (k = 0; k < ncolors; k += bcols) {
229       /*
230        (3-1) Loop over each column associated with color
231        adding the perturbation to the vector w3 = x1 + dx.
232        */
233 
234       dy_k = dy;
235       if (k + bcols > ncolors) bcols = ncolors - k;
236       for (i = 0; i < bcols; i++) {
237         coloring->currentcolor = k + i;
238 
239         PetscCall(VecCopy(x1, w3));
240         PetscCall(VecGetArray(w3, &w3_array));
241         if (ctype == IS_COLORING_GLOBAL) w3_array -= cstart; /* shift pointer so global index can be used */
242         if (coloring->htype[0] == 'w') {
243           for (l = 0; l < ncolumns[k + i]; l++) {
244             col = coloring->columns[k + i][l]; /* local column (in global index!) of the matrix we are probing for */
245             w3_array[col] += 1.0 / dx;
246           }
247         } else {                  /* htype == 'ds' */
248           vscale_array -= cstart; /* shift pointer so global index can be used */
249           for (l = 0; l < ncolumns[k + i]; l++) {
250             col = coloring->columns[k + i][l]; /* local column (in global index!) of the matrix we are probing for */
251             w3_array[col] += 1.0 / vscale_array[col];
252           }
253           vscale_array += cstart;
254         }
255         if (ctype == IS_COLORING_GLOBAL) w3_array += cstart;
256         PetscCall(VecRestoreArray(w3, &w3_array));
257 
258         /*
259          (3-2) Evaluate function at w3 = x1 + dx (here dx is a vector of perturbations)
260                            w2 = F(x1 + dx) - F(x1)
261          */
262         PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, 0, 0, 0, 0));
263         PetscCall(VecPlaceArray(w2, dy_k)); /* place w2 to the array dy_i */
264         PetscCall((*f)(sctx, w3, w2, fctx));
265         PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, 0, 0, 0, 0));
266         PetscCall(VecAXPY(w2, -1.0, w1));
267         PetscCall(VecResetArray(w2));
268         dy_k += m; /* points to dy+i*nxloc */
269       }
270 
271       /*
272        (3-3) Loop over block rows of vector, putting results into Jacobian matrix
273        */
274       nrows_k = nrows[nbcols++];
275 
276       if (coloring->htype[0] == 'w') {
277         for (l = 0; l < nrows_k; l++) {
278           row = Jentry2[nz].row; /* local row index */
279                                  /* The 'useless' ifdef is due to a bug in NVIDIA nvc 21.11, which triggers a segfault on this line. We write it in
280              another way, and it seems work. See https://lists.mcs.anl.gov/pipermail/petsc-users/2021-December/045158.html
281            */
282 #if defined(PETSC_USE_COMPLEX)
283           PetscScalar *tmp = Jentry2[nz].valaddr;
284           *tmp             = dy[row] * dx;
285 #else
286           *(Jentry2[nz].valaddr) = dy[row] * dx;
287 #endif
288           nz++;
289         }
290       } else { /* htype == 'ds' */
291         for (l = 0; l < nrows_k; l++) {
292           row = Jentry[nz].row; /* local row index */
293 #if defined(PETSC_USE_COMPLEX)  /* See https://lists.mcs.anl.gov/pipermail/petsc-users/2021-December/045158.html */
294           PetscScalar *tmp = Jentry[nz].valaddr;
295           *tmp             = dy[row] * vscale_array[Jentry[nz].col];
296 #else
297           *(Jentry[nz].valaddr)  = dy[row] * vscale_array[Jentry[nz].col];
298 #endif
299           nz++;
300         }
301       }
302     }
303   } else { /* bcols == 1 */
304     for (k = 0; k < ncolors; k++) {
305       coloring->currentcolor = k;
306 
307       /*
308        (3-1) Loop over each column associated with color
309        adding the perturbation to the vector w3 = x1 + dx.
310        */
311       PetscCall(VecCopy(x1, w3));
312       PetscCall(VecGetArray(w3, &w3_array));
313       if (ctype == IS_COLORING_GLOBAL) w3_array -= cstart; /* shift pointer so global index can be used */
314       if (coloring->htype[0] == 'w') {
315         for (l = 0; l < ncolumns[k]; l++) {
316           col = coloring->columns[k][l]; /* local column (in global index!) of the matrix we are probing for */
317           w3_array[col] += 1.0 / dx;
318         }
319       } else {                  /* htype == 'ds' */
320         vscale_array -= cstart; /* shift pointer so global index can be used */
321         for (l = 0; l < ncolumns[k]; l++) {
322           col = coloring->columns[k][l]; /* local column (in global index!) of the matrix we are probing for */
323           w3_array[col] += 1.0 / vscale_array[col];
324         }
325         vscale_array += cstart;
326       }
327       if (ctype == IS_COLORING_GLOBAL) w3_array += cstart;
328       PetscCall(VecRestoreArray(w3, &w3_array));
329 
330       /*
331        (3-2) Evaluate function at w3 = x1 + dx (here dx is a vector of perturbations)
332                            w2 = F(x1 + dx) - F(x1)
333        */
334       PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, 0, 0, 0, 0));
335       PetscCall((*f)(sctx, w3, w2, fctx));
336       PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, 0, 0, 0, 0));
337       PetscCall(VecAXPY(w2, -1.0, w1));
338 
339       /*
340        (3-3) Loop over rows of vector, putting results into Jacobian matrix
341        */
342       nrows_k = nrows[k];
343       PetscCall(VecGetArray(w2, &y));
344       if (coloring->htype[0] == 'w') {
345         for (l = 0; l < nrows_k; l++) {
346           row = Jentry2[nz].row; /* local row index */
347 #if defined(PETSC_USE_COMPLEX)   /* See https://lists.mcs.anl.gov/pipermail/petsc-users/2021-December/045158.html */
348           PetscScalar *tmp = Jentry2[nz].valaddr;
349           *tmp             = y[row] * dx;
350 #else
351           *(Jentry2[nz].valaddr) = y[row] * dx;
352 #endif
353           nz++;
354         }
355       } else { /* htype == 'ds' */
356         for (l = 0; l < nrows_k; l++) {
357           row = Jentry[nz].row; /* local row index */
358 #if defined(PETSC_USE_COMPLEX)  /* See https://lists.mcs.anl.gov/pipermail/petsc-users/2021-December/045158.html */
359           PetscScalar *tmp = Jentry[nz].valaddr;
360           *tmp             = y[row] * vscale_array[Jentry[nz].col];
361 #else
362           *(Jentry[nz].valaddr)  = y[row] * vscale_array[Jentry[nz].col];
363 #endif
364           nz++;
365         }
366       }
367       PetscCall(VecRestoreArray(w2, &y));
368     }
369   }
370 
371 #if defined(PETSC_HAVE_VIENNACL) || defined(PETSC_HAVE_CUDA)
372   if (J->offloadmask != PETSC_OFFLOAD_UNALLOCATED) J->offloadmask = PETSC_OFFLOAD_CPU;
373 #endif
374   PetscCall(MatAssemblyBegin(J, MAT_FINAL_ASSEMBLY));
375   PetscCall(MatAssemblyEnd(J, MAT_FINAL_ASSEMBLY));
376   if (vscale) PetscCall(VecRestoreArray(vscale, &vscale_array));
377   coloring->currentcolor = -1;
378   if (!alreadyboundtocpu) PetscCall(VecBindToCPU(x1, PETSC_FALSE));
379   PetscFunctionReturn(0);
380 }
381 
382 PetscErrorCode MatFDColoringSetUp_MPIXAIJ(Mat mat, ISColoring iscoloring, MatFDColoring c) {
383   PetscMPIInt            size, *ncolsonproc, *disp, nn;
384   PetscInt               i, n, nrows, nrows_i, j, k, m, ncols, col, *rowhit, cstart, cend, colb;
385   const PetscInt        *is, *A_ci, *A_cj, *B_ci, *B_cj, *row = NULL, *ltog = NULL;
386   PetscInt               nis = iscoloring->n, nctot, *cols, tmp = 0;
387   ISLocalToGlobalMapping map   = mat->cmap->mapping;
388   PetscInt               ctype = c->ctype, *spidxA, *spidxB, nz, bs, bs2, spidx;
389   Mat                    A, B;
390   PetscScalar           *A_val, *B_val, **valaddrhit;
391   MatEntry              *Jentry;
392   MatEntry2             *Jentry2;
393   PetscBool              isBAIJ, isSELL;
394   PetscInt               bcols = c->bcols;
395 #if defined(PETSC_USE_CTABLE)
396   PetscTable colmap = NULL;
397 #else
398   PetscInt *colmap = NULL;      /* local col number of off-diag col */
399 #endif
400 
401   PetscFunctionBegin;
402   if (ctype == IS_COLORING_LOCAL) {
403     PetscCheck(map, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_INCOMP, "When using ghosted differencing matrix must have local to global mapping provided with MatSetLocalToGlobalMapping");
404     PetscCall(ISLocalToGlobalMappingGetIndices(map, &ltog));
405   }
406 
407   PetscCall(MatGetBlockSize(mat, &bs));
408   PetscCall(PetscObjectBaseTypeCompare((PetscObject)mat, MATMPIBAIJ, &isBAIJ));
409   PetscCall(PetscObjectTypeCompare((PetscObject)mat, MATMPISELL, &isSELL));
410   if (isBAIJ) {
411     Mat_MPIBAIJ *baij = (Mat_MPIBAIJ *)mat->data;
412     Mat_SeqBAIJ *spA, *spB;
413     A     = baij->A;
414     spA   = (Mat_SeqBAIJ *)A->data;
415     A_val = spA->a;
416     B     = baij->B;
417     spB   = (Mat_SeqBAIJ *)B->data;
418     B_val = spB->a;
419     nz    = spA->nz + spB->nz; /* total nonzero entries of mat */
420     if (!baij->colmap) PetscCall(MatCreateColmap_MPIBAIJ_Private(mat));
421     colmap = baij->colmap;
422     PetscCall(MatGetColumnIJ_SeqBAIJ_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
423     PetscCall(MatGetColumnIJ_SeqBAIJ_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
424 
425     if (ctype == IS_COLORING_GLOBAL && c->htype[0] == 'd') { /* create vscale for storing dx */
426       PetscInt *garray;
427       PetscCall(PetscMalloc1(B->cmap->n, &garray));
428       for (i = 0; i < baij->B->cmap->n / bs; i++) {
429         for (j = 0; j < bs; j++) garray[i * bs + j] = bs * baij->garray[i] + j;
430       }
431       PetscCall(VecCreateGhost(PetscObjectComm((PetscObject)mat), mat->cmap->n, PETSC_DETERMINE, B->cmap->n, garray, &c->vscale));
432       PetscCall(VecBindToCPU(c->vscale, PETSC_TRUE));
433       PetscCall(PetscFree(garray));
434     }
435   } else if (isSELL) {
436     Mat_MPISELL *sell = (Mat_MPISELL *)mat->data;
437     Mat_SeqSELL *spA, *spB;
438     A     = sell->A;
439     spA   = (Mat_SeqSELL *)A->data;
440     A_val = spA->val;
441     B     = sell->B;
442     spB   = (Mat_SeqSELL *)B->data;
443     B_val = spB->val;
444     nz    = spA->nz + spB->nz; /* total nonzero entries of mat */
445     if (!sell->colmap) {
446       /* Allow access to data structures of local part of matrix
447        - creates aij->colmap which maps global column number to local number in part B */
448       PetscCall(MatCreateColmap_MPISELL_Private(mat));
449     }
450     colmap = sell->colmap;
451     PetscCall(MatGetColumnIJ_SeqSELL_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
452     PetscCall(MatGetColumnIJ_SeqSELL_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
453 
454     bs = 1; /* only bs=1 is supported for non MPIBAIJ matrix */
455 
456     if (ctype == IS_COLORING_GLOBAL && c->htype[0] == 'd') { /* create vscale for storing dx */
457       PetscCall(VecCreateGhost(PetscObjectComm((PetscObject)mat), mat->cmap->n, PETSC_DETERMINE, B->cmap->n, sell->garray, &c->vscale));
458       PetscCall(VecBindToCPU(c->vscale, PETSC_TRUE));
459     }
460   } else {
461     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)mat->data;
462     Mat_SeqAIJ *spA, *spB;
463     A     = aij->A;
464     spA   = (Mat_SeqAIJ *)A->data;
465     A_val = spA->a;
466     B     = aij->B;
467     spB   = (Mat_SeqAIJ *)B->data;
468     B_val = spB->a;
469     nz    = spA->nz + spB->nz; /* total nonzero entries of mat */
470     if (!aij->colmap) {
471       /* Allow access to data structures of local part of matrix
472        - creates aij->colmap which maps global column number to local number in part B */
473       PetscCall(MatCreateColmap_MPIAIJ_Private(mat));
474     }
475     colmap = aij->colmap;
476     PetscCall(MatGetColumnIJ_SeqAIJ_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
477     PetscCall(MatGetColumnIJ_SeqAIJ_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
478 
479     bs = 1; /* only bs=1 is supported for non MPIBAIJ matrix */
480 
481     if (ctype == IS_COLORING_GLOBAL && c->htype[0] == 'd') { /* create vscale for storing dx */
482       PetscCall(VecCreateGhost(PetscObjectComm((PetscObject)mat), mat->cmap->n, PETSC_DETERMINE, B->cmap->n, aij->garray, &c->vscale));
483       PetscCall(VecBindToCPU(c->vscale, PETSC_TRUE));
484     }
485   }
486 
487   m      = mat->rmap->n / bs;
488   cstart = mat->cmap->rstart / bs;
489   cend   = mat->cmap->rend / bs;
490 
491   PetscCall(PetscMalloc2(nis, &c->ncolumns, nis, &c->columns));
492   PetscCall(PetscMalloc1(nis, &c->nrows));
493   PetscCall(PetscLogObjectMemory((PetscObject)c, 3 * nis * sizeof(PetscInt)));
494 
495   if (c->htype[0] == 'd') {
496     PetscCall(PetscMalloc1(nz, &Jentry));
497     PetscCall(PetscLogObjectMemory((PetscObject)c, nz * sizeof(MatEntry)));
498     c->matentry = Jentry;
499   } else if (c->htype[0] == 'w') {
500     PetscCall(PetscMalloc1(nz, &Jentry2));
501     PetscCall(PetscLogObjectMemory((PetscObject)c, nz * sizeof(MatEntry2)));
502     c->matentry2 = Jentry2;
503   } else SETERRQ(PetscObjectComm((PetscObject)mat), PETSC_ERR_SUP, "htype is not supported");
504 
505   PetscCall(PetscMalloc2(m + 1, &rowhit, m + 1, &valaddrhit));
506   nz = 0;
507   PetscCall(ISColoringGetIS(iscoloring, PETSC_OWN_POINTER, PETSC_IGNORE, &c->isa));
508 
509   if (ctype == IS_COLORING_GLOBAL) {
510     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
511     PetscCall(PetscMalloc2(size, &ncolsonproc, size, &disp));
512   }
513 
514   for (i = 0; i < nis; i++) { /* for each local color */
515     PetscCall(ISGetLocalSize(c->isa[i], &n));
516     PetscCall(ISGetIndices(c->isa[i], &is));
517 
518     c->ncolumns[i] = n; /* local number of columns of this color on this process */
519     c->columns[i]  = (PetscInt *)is;
520 
521     if (ctype == IS_COLORING_GLOBAL) {
522       /* Determine nctot, the total (parallel) number of columns of this color */
523       /* ncolsonproc[j]: local ncolumns on proc[j] of this color */
524       PetscCall(PetscMPIIntCast(n, &nn));
525       PetscCallMPI(MPI_Allgather(&nn, 1, MPI_INT, ncolsonproc, 1, MPI_INT, PetscObjectComm((PetscObject)mat)));
526       nctot = 0;
527       for (j = 0; j < size; j++) nctot += ncolsonproc[j];
528       if (!nctot) PetscCall(PetscInfo(mat, "Coloring of matrix has some unneeded colors with no corresponding rows\n"));
529 
530       disp[0] = 0;
531       for (j = 1; j < size; j++) disp[j] = disp[j - 1] + ncolsonproc[j - 1];
532 
533       /* Get cols, the complete list of columns for this color on each process */
534       PetscCall(PetscMalloc1(nctot + 1, &cols));
535       PetscCallMPI(MPI_Allgatherv((void *)is, n, MPIU_INT, cols, ncolsonproc, disp, MPIU_INT, PetscObjectComm((PetscObject)mat)));
536     } else if (ctype == IS_COLORING_LOCAL) {
537       /* Determine local number of columns of this color on this process, including ghost points */
538       nctot = n;
539       cols  = (PetscInt *)is;
540     } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Not provided for this MatFDColoring type");
541 
542     /* Mark all rows affect by these columns */
543     PetscCall(PetscArrayzero(rowhit, m));
544     bs2     = bs * bs;
545     nrows_i = 0;
546     for (j = 0; j < nctot; j++) { /* loop over columns*/
547       if (ctype == IS_COLORING_LOCAL) {
548         col = ltog[cols[j]];
549       } else {
550         col = cols[j];
551       }
552       if (col >= cstart && col < cend) { /* column is in A, diagonal block of mat */
553         tmp   = A_ci[col - cstart];
554         row   = A_cj + tmp;
555         nrows = A_ci[col - cstart + 1] - tmp;
556         nrows_i += nrows;
557 
558         /* loop over columns of A marking them in rowhit */
559         for (k = 0; k < nrows; k++) {
560           /* set valaddrhit for part A */
561           spidx            = bs2 * spidxA[tmp + k];
562           valaddrhit[*row] = &A_val[spidx];
563           rowhit[*row++]   = col - cstart + 1; /* local column index */
564         }
565       } else { /* column is in B, off-diagonal block of mat */
566 #if defined(PETSC_USE_CTABLE)
567         PetscCall(PetscTableFind(colmap, col + 1, &colb));
568         colb--;
569 #else
570         colb = colmap[col] - 1; /* local column index */
571 #endif
572         if (colb == -1) {
573           nrows = 0;
574         } else {
575           colb  = colb / bs;
576           tmp   = B_ci[colb];
577           row   = B_cj + tmp;
578           nrows = B_ci[colb + 1] - tmp;
579         }
580         nrows_i += nrows;
581         /* loop over columns of B marking them in rowhit */
582         for (k = 0; k < nrows; k++) {
583           /* set valaddrhit for part B */
584           spidx            = bs2 * spidxB[tmp + k];
585           valaddrhit[*row] = &B_val[spidx];
586           rowhit[*row++]   = colb + 1 + cend - cstart; /* local column index */
587         }
588       }
589     }
590     c->nrows[i] = nrows_i;
591 
592     if (c->htype[0] == 'd') {
593       for (j = 0; j < m; j++) {
594         if (rowhit[j]) {
595           Jentry[nz].row     = j;             /* local row index */
596           Jentry[nz].col     = rowhit[j] - 1; /* local column index */
597           Jentry[nz].valaddr = valaddrhit[j]; /* address of mat value for this entry */
598           nz++;
599         }
600       }
601     } else { /* c->htype == 'wp' */
602       for (j = 0; j < m; j++) {
603         if (rowhit[j]) {
604           Jentry2[nz].row     = j;             /* local row index */
605           Jentry2[nz].valaddr = valaddrhit[j]; /* address of mat value for this entry */
606           nz++;
607         }
608       }
609     }
610     if (ctype == IS_COLORING_GLOBAL) PetscCall(PetscFree(cols));
611   }
612   if (ctype == IS_COLORING_GLOBAL) PetscCall(PetscFree2(ncolsonproc, disp));
613 
614   if (bcols > 1) { /* reorder Jentry for faster MatFDColoringApply() */
615     PetscCall(MatFDColoringSetUpBlocked_AIJ_Private(mat, c, nz));
616   }
617 
618   if (isBAIJ) {
619     PetscCall(MatRestoreColumnIJ_SeqBAIJ_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
620     PetscCall(MatRestoreColumnIJ_SeqBAIJ_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
621     PetscCall(PetscMalloc1(bs * mat->rmap->n, &c->dy));
622   } else if (isSELL) {
623     PetscCall(MatRestoreColumnIJ_SeqSELL_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
624     PetscCall(MatRestoreColumnIJ_SeqSELL_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
625   } else {
626     PetscCall(MatRestoreColumnIJ_SeqAIJ_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
627     PetscCall(MatRestoreColumnIJ_SeqAIJ_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
628   }
629 
630   PetscCall(ISColoringRestoreIS(iscoloring, PETSC_OWN_POINTER, &c->isa));
631   PetscCall(PetscFree2(rowhit, valaddrhit));
632 
633   if (ctype == IS_COLORING_LOCAL) PetscCall(ISLocalToGlobalMappingRestoreIndices(map, &ltog));
634   PetscCall(PetscInfo(c, "ncolors %" PetscInt_FMT ", brows %" PetscInt_FMT " and bcols %" PetscInt_FMT " are used.\n", c->ncolors, c->brows, c->bcols));
635   PetscFunctionReturn(0);
636 }
637 
638 PetscErrorCode MatFDColoringCreate_MPIXAIJ(Mat mat, ISColoring iscoloring, MatFDColoring c) {
639   PetscInt  bs, nis = iscoloring->n, m = mat->rmap->n;
640   PetscBool isBAIJ, isSELL;
641 
642   PetscFunctionBegin;
643   /* set default brows and bcols for speedup inserting the dense matrix into sparse Jacobian;
644    bcols is chosen s.t. dy-array takes 50% of memory space as mat */
645   PetscCall(MatGetBlockSize(mat, &bs));
646   PetscCall(PetscObjectBaseTypeCompare((PetscObject)mat, MATMPIBAIJ, &isBAIJ));
647   PetscCall(PetscObjectTypeCompare((PetscObject)mat, MATMPISELL, &isSELL));
648   if (isBAIJ || m == 0) {
649     c->brows = m;
650     c->bcols = 1;
651   } else if (isSELL) {
652     /* bcols is chosen s.t. dy-array takes 50% of local memory space as mat */
653     Mat_MPISELL *sell = (Mat_MPISELL *)mat->data;
654     Mat_SeqSELL *spA, *spB;
655     Mat          A, B;
656     PetscInt     nz, brows, bcols;
657     PetscReal    mem;
658 
659     bs = 1; /* only bs=1 is supported for MPISELL matrix */
660 
661     A     = sell->A;
662     spA   = (Mat_SeqSELL *)A->data;
663     B     = sell->B;
664     spB   = (Mat_SeqSELL *)B->data;
665     nz    = spA->nz + spB->nz; /* total local nonzero entries of mat */
666     mem   = nz * (sizeof(PetscScalar) + sizeof(PetscInt)) + 3 * m * sizeof(PetscInt);
667     bcols = (PetscInt)(0.5 * mem / (m * sizeof(PetscScalar)));
668     brows = 1000 / bcols;
669     if (bcols > nis) bcols = nis;
670     if (brows == 0 || brows > m) brows = m;
671     c->brows = brows;
672     c->bcols = bcols;
673   } else { /* mpiaij matrix */
674     /* bcols is chosen s.t. dy-array takes 50% of local memory space as mat */
675     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)mat->data;
676     Mat_SeqAIJ *spA, *spB;
677     Mat         A, B;
678     PetscInt    nz, brows, bcols;
679     PetscReal   mem;
680 
681     bs = 1; /* only bs=1 is supported for MPIAIJ matrix */
682 
683     A     = aij->A;
684     spA   = (Mat_SeqAIJ *)A->data;
685     B     = aij->B;
686     spB   = (Mat_SeqAIJ *)B->data;
687     nz    = spA->nz + spB->nz; /* total local nonzero entries of mat */
688     mem   = nz * (sizeof(PetscScalar) + sizeof(PetscInt)) + 3 * m * sizeof(PetscInt);
689     bcols = (PetscInt)(0.5 * mem / (m * sizeof(PetscScalar)));
690     brows = 1000 / bcols;
691     if (bcols > nis) bcols = nis;
692     if (brows == 0 || brows > m) brows = m;
693     c->brows = brows;
694     c->bcols = bcols;
695   }
696 
697   c->M       = mat->rmap->N / bs; /* set the global rows and columns and local rows */
698   c->N       = mat->cmap->N / bs;
699   c->m       = mat->rmap->n / bs;
700   c->rstart  = mat->rmap->rstart / bs;
701   c->ncolors = nis;
702   PetscFunctionReturn(0);
703 }
704 
705 /*@C
706 
707     MatFDColoringSetValues - takes a matrix in compressed color format and enters the matrix into a PETSc `Mat`
708 
709    Collective on J
710 
711    Input Parameters:
712 +    J - the sparse matrix
713 .    coloring - created with `MatFDColoringCreate()` and a local coloring
714 -    y - column major storage of matrix values with one color of values per column, the number of rows of y should match
715          the number of local rows of J and the number of columns is the number of colors.
716 
717    Level: intermediate
718 
719    Notes: the matrix in compressed color format may come from an automatic differentiation code
720 
721    The code will be slightly faster if `MatFDColoringSetBlockSize`(coloring,`PETSC_DEFAULT`,nc); is called immediately after creating the coloring
722 
723 .seealso: `MatFDColoringCreate()`, `ISColoring`, `ISColoringCreate()`, `ISColoringSetType()`, `IS_COLORING_LOCAL`, `MatFDColoringSetBlockSize()`
724 @*/
725 PetscErrorCode MatFDColoringSetValues(Mat J, MatFDColoring coloring, const PetscScalar *y) {
726   MatEntry2      *Jentry2;
727   PetscInt        row, i, nrows_k, l, ncolors, nz = 0, bcols, nbcols = 0;
728   const PetscInt *nrows;
729   PetscBool       eq;
730 
731   PetscFunctionBegin;
732   PetscValidHeaderSpecific(J, MAT_CLASSID, 1);
733   PetscValidHeaderSpecific(coloring, MAT_FDCOLORING_CLASSID, 2);
734   PetscCall(PetscObjectCompareId((PetscObject)J, coloring->matid, &eq));
735   PetscCheck(eq, PetscObjectComm((PetscObject)J), PETSC_ERR_ARG_WRONG, "Matrix used with MatFDColoringSetValues() must be that used with MatFDColoringCreate()");
736   Jentry2 = coloring->matentry2;
737   nrows   = coloring->nrows;
738   ncolors = coloring->ncolors;
739   bcols   = coloring->bcols;
740 
741   for (i = 0; i < ncolors; i += bcols) {
742     nrows_k = nrows[nbcols++];
743     for (l = 0; l < nrows_k; l++) {
744       row                      = Jentry2[nz].row; /* local row index */
745       *(Jentry2[nz++].valaddr) = y[row];
746     }
747     y += bcols * coloring->m;
748   }
749   PetscCall(MatAssemblyBegin(J, MAT_FINAL_ASSEMBLY));
750   PetscCall(MatAssemblyEnd(J, MAT_FINAL_ASSEMBLY));
751   PetscFunctionReturn(0);
752 }
753