xref: /petsc/src/ksp/pc/impls/pbjacobi/pbjacobi.c (revision bd89dbf26d8a5efecb980364933175da61864cd7)
1 #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h>
2 #include <petsc/private/matimpl.h>
3 
/*
   PCApply_PBJacobi - Applies the point-block Jacobi preconditioner: y = B x, where B is the
   block-diagonal matrix whose blocks are stored in jac->diag.

   jac->diag holds jac->mbs dense bs x bs blocks stored one after another. Within a block, the
   entry in row ib and column jb is diag[ib + jb * bs] (column-major), as the default case shows.
   On the host path these are the inverted diagonal blocks produced by MatInvertBlockDiagonal()
   (see PCSetUp_PBJacobi_Host()).

   Block sizes 1-7 use hand-unrolled kernels; any other block size uses the generic loop.
*/
static PetscErrorCode PCApply_PBJacobi(PC pc, Vec x, Vec y)
{
  PC_PBJacobi       *jac = (PC_PBJacobi *)pc->data;
  PetscInt           i, ib, jb;
  const PetscInt     m    = jac->mbs;  /* number of local blocks */
  const PetscInt     bs   = jac->bs;   /* block size */
  const MatScalar   *diag = jac->diag; /* advanced by bs*bs after each block when bs > 1 */
  PetscScalar       *yy, x0, x1, x2, x3, x4, x5, x6;
  const PetscScalar *xx;

  PetscFunctionBegin;
  PetscCall(VecGetArrayRead(x, &xx));
  PetscCall(VecGetArray(y, &yy));
  switch (bs) {
  case 1:
    /* 1x1 blocks: plain diagonal scaling; diag is indexed directly instead of being advanced */
    for (i = 0; i < m; i++) yy[i] = diag[i] * xx[i];
    break;
  case 2:
    for (i = 0; i < m; i++) {
      x0            = xx[2 * i];
      x1            = xx[2 * i + 1];
      yy[2 * i]     = diag[0] * x0 + diag[2] * x1;
      yy[2 * i + 1] = diag[1] * x0 + diag[3] * x1;
      diag += 4;
    }
    break;
  case 3:
    for (i = 0; i < m; i++) {
      x0 = xx[3 * i];
      x1 = xx[3 * i + 1];
      x2 = xx[3 * i + 2];

      yy[3 * i]     = diag[0] * x0 + diag[3] * x1 + diag[6] * x2;
      yy[3 * i + 1] = diag[1] * x0 + diag[4] * x1 + diag[7] * x2;
      yy[3 * i + 2] = diag[2] * x0 + diag[5] * x1 + diag[8] * x2;
      diag += 9;
    }
    break;
  case 4:
    for (i = 0; i < m; i++) {
      x0 = xx[4 * i];
      x1 = xx[4 * i + 1];
      x2 = xx[4 * i + 2];
      x3 = xx[4 * i + 3];

      yy[4 * i]     = diag[0] * x0 + diag[4] * x1 + diag[8] * x2 + diag[12] * x3;
      yy[4 * i + 1] = diag[1] * x0 + diag[5] * x1 + diag[9] * x2 + diag[13] * x3;
      yy[4 * i + 2] = diag[2] * x0 + diag[6] * x1 + diag[10] * x2 + diag[14] * x3;
      yy[4 * i + 3] = diag[3] * x0 + diag[7] * x1 + diag[11] * x2 + diag[15] * x3;
      diag += 16;
    }
    break;
  case 5:
    for (i = 0; i < m; i++) {
      x0 = xx[5 * i];
      x1 = xx[5 * i + 1];
      x2 = xx[5 * i + 2];
      x3 = xx[5 * i + 3];
      x4 = xx[5 * i + 4];

      yy[5 * i]     = diag[0] * x0 + diag[5] * x1 + diag[10] * x2 + diag[15] * x3 + diag[20] * x4;
      yy[5 * i + 1] = diag[1] * x0 + diag[6] * x1 + diag[11] * x2 + diag[16] * x3 + diag[21] * x4;
      yy[5 * i + 2] = diag[2] * x0 + diag[7] * x1 + diag[12] * x2 + diag[17] * x3 + diag[22] * x4;
      yy[5 * i + 3] = diag[3] * x0 + diag[8] * x1 + diag[13] * x2 + diag[18] * x3 + diag[23] * x4;
      yy[5 * i + 4] = diag[4] * x0 + diag[9] * x1 + diag[14] * x2 + diag[19] * x3 + diag[24] * x4;
      diag += 25;
    }
    break;
  case 6:
    for (i = 0; i < m; i++) {
      x0 = xx[6 * i];
      x1 = xx[6 * i + 1];
      x2 = xx[6 * i + 2];
      x3 = xx[6 * i + 3];
      x4 = xx[6 * i + 4];
      x5 = xx[6 * i + 5];

      yy[6 * i]     = diag[0] * x0 + diag[6] * x1 + diag[12] * x2 + diag[18] * x3 + diag[24] * x4 + diag[30] * x5;
      yy[6 * i + 1] = diag[1] * x0 + diag[7] * x1 + diag[13] * x2 + diag[19] * x3 + diag[25] * x4 + diag[31] * x5;
      yy[6 * i + 2] = diag[2] * x0 + diag[8] * x1 + diag[14] * x2 + diag[20] * x3 + diag[26] * x4 + diag[32] * x5;
      yy[6 * i + 3] = diag[3] * x0 + diag[9] * x1 + diag[15] * x2 + diag[21] * x3 + diag[27] * x4 + diag[33] * x5;
      yy[6 * i + 4] = diag[4] * x0 + diag[10] * x1 + diag[16] * x2 + diag[22] * x3 + diag[28] * x4 + diag[34] * x5;
      yy[6 * i + 5] = diag[5] * x0 + diag[11] * x1 + diag[17] * x2 + diag[23] * x3 + diag[29] * x4 + diag[35] * x5;
      diag += 36;
    }
    break;
  case 7:
    for (i = 0; i < m; i++) {
      x0 = xx[7 * i];
      x1 = xx[7 * i + 1];
      x2 = xx[7 * i + 2];
      x3 = xx[7 * i + 3];
      x4 = xx[7 * i + 4];
      x5 = xx[7 * i + 5];
      x6 = xx[7 * i + 6];

      yy[7 * i]     = diag[0] * x0 + diag[7] * x1 + diag[14] * x2 + diag[21] * x3 + diag[28] * x4 + diag[35] * x5 + diag[42] * x6;
      yy[7 * i + 1] = diag[1] * x0 + diag[8] * x1 + diag[15] * x2 + diag[22] * x3 + diag[29] * x4 + diag[36] * x5 + diag[43] * x6;
      yy[7 * i + 2] = diag[2] * x0 + diag[9] * x1 + diag[16] * x2 + diag[23] * x3 + diag[30] * x4 + diag[37] * x5 + diag[44] * x6;
      yy[7 * i + 3] = diag[3] * x0 + diag[10] * x1 + diag[17] * x2 + diag[24] * x3 + diag[31] * x4 + diag[38] * x5 + diag[45] * x6;
      yy[7 * i + 4] = diag[4] * x0 + diag[11] * x1 + diag[18] * x2 + diag[25] * x3 + diag[32] * x4 + diag[39] * x5 + diag[46] * x6;
      yy[7 * i + 5] = diag[5] * x0 + diag[12] * x1 + diag[19] * x2 + diag[26] * x3 + diag[33] * x4 + diag[40] * x5 + diag[47] * x6;
      yy[7 * i + 6] = diag[6] * x0 + diag[13] * x1 + diag[20] * x2 + diag[27] * x3 + diag[34] * x4 + diag[41] * x5 + diag[48] * x6;
      diag += 49;
    }
    break;
  default:
    /* Generic block size: dense column-major block times the matching slice of x */
    for (i = 0; i < m; i++) {
      for (ib = 0; ib < bs; ib++) {
        PetscScalar rowsum = 0;
        for (jb = 0; jb < bs; jb++) rowsum += diag[ib + jb * bs] * xx[bs * i + jb];
        yy[bs * i + ib] = rowsum;
      }
      diag += bs * bs;
    }
  }
  PetscCall(VecRestoreArrayRead(x, &xx));
  PetscCall(VecRestoreArray(y, &yy));
  /* Per block: bs*bs multiplies plus bs*(bs-1) additions = 2*bs^2 - bs flops */
  PetscCall(PetscLogFlops((2.0 * bs * bs - bs) * m)); /* 2*bs2 - bs */
  PetscFunctionReturn(PETSC_SUCCESS);
}
125 
/*
   PCApplyTranspose_PBJacobi - Applies the transpose of the point-block Jacobi preconditioner:
   y = B^T x, where B is the block-diagonal matrix whose blocks are stored in jac->diag.

   The storage is the same as in PCApply_PBJacobi(): jac->mbs consecutive bs x bs blocks, with
   entry (ib, jb) of a block at diag[ib + jb * bs]. Applying the transpose therefore reads each
   block row by row: output row ib uses diag[ib * bs + jb] for jb = 0..bs-1, as the default case
   shows.

   Block sizes 1-7 use hand-unrolled kernels; any other block size uses the generic loop.
*/
static PetscErrorCode PCApplyTranspose_PBJacobi(PC pc, Vec x, Vec y)
{
  PC_PBJacobi       *jac = (PC_PBJacobi *)pc->data;
  PetscInt           i, ib, jb;
  const PetscInt     m    = jac->mbs;  /* number of local blocks */
  const PetscInt     bs   = jac->bs;   /* block size */
  const MatScalar   *diag = jac->diag; /* advanced by bs*bs after each block when bs > 1 */
  PetscScalar       *yy, x0, x1, x2, x3, x4, x5, x6;
  const PetscScalar *xx;

  PetscFunctionBegin;
  PetscCall(VecGetArrayRead(x, &xx));
  PetscCall(VecGetArray(y, &yy));
  switch (bs) {
  case 1:
    /* A 1x1 block is its own transpose: identical to the non-transposed apply */
    for (i = 0; i < m; i++) yy[i] = diag[i] * xx[i];
    break;
  case 2:
    for (i = 0; i < m; i++) {
      x0            = xx[2 * i];
      x1            = xx[2 * i + 1];
      yy[2 * i]     = diag[0] * x0 + diag[1] * x1;
      yy[2 * i + 1] = diag[2] * x0 + diag[3] * x1;
      diag += 4;
    }
    break;
  case 3:
    for (i = 0; i < m; i++) {
      x0 = xx[3 * i];
      x1 = xx[3 * i + 1];
      x2 = xx[3 * i + 2];

      yy[3 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2;
      yy[3 * i + 1] = diag[3] * x0 + diag[4] * x1 + diag[5] * x2;
      yy[3 * i + 2] = diag[6] * x0 + diag[7] * x1 + diag[8] * x2;
      diag += 9;
    }
    break;
  case 4:
    for (i = 0; i < m; i++) {
      x0 = xx[4 * i];
      x1 = xx[4 * i + 1];
      x2 = xx[4 * i + 2];
      x3 = xx[4 * i + 3];

      yy[4 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3;
      yy[4 * i + 1] = diag[4] * x0 + diag[5] * x1 + diag[6] * x2 + diag[7] * x3;
      yy[4 * i + 2] = diag[8] * x0 + diag[9] * x1 + diag[10] * x2 + diag[11] * x3;
      yy[4 * i + 3] = diag[12] * x0 + diag[13] * x1 + diag[14] * x2 + diag[15] * x3;
      diag += 16;
    }
    break;
  case 5:
    for (i = 0; i < m; i++) {
      x0 = xx[5 * i];
      x1 = xx[5 * i + 1];
      x2 = xx[5 * i + 2];
      x3 = xx[5 * i + 3];
      x4 = xx[5 * i + 4];

      yy[5 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4;
      yy[5 * i + 1] = diag[5] * x0 + diag[6] * x1 + diag[7] * x2 + diag[8] * x3 + diag[9] * x4;
      yy[5 * i + 2] = diag[10] * x0 + diag[11] * x1 + diag[12] * x2 + diag[13] * x3 + diag[14] * x4;
      yy[5 * i + 3] = diag[15] * x0 + diag[16] * x1 + diag[17] * x2 + diag[18] * x3 + diag[19] * x4;
      yy[5 * i + 4] = diag[20] * x0 + diag[21] * x1 + diag[22] * x2 + diag[23] * x3 + diag[24] * x4;
      diag += 25;
    }
    break;
  case 6:
    for (i = 0; i < m; i++) {
      x0 = xx[6 * i];
      x1 = xx[6 * i + 1];
      x2 = xx[6 * i + 2];
      x3 = xx[6 * i + 3];
      x4 = xx[6 * i + 4];
      x5 = xx[6 * i + 5];

      yy[6 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4 + diag[5] * x5;
      yy[6 * i + 1] = diag[6] * x0 + diag[7] * x1 + diag[8] * x2 + diag[9] * x3 + diag[10] * x4 + diag[11] * x5;
      yy[6 * i + 2] = diag[12] * x0 + diag[13] * x1 + diag[14] * x2 + diag[15] * x3 + diag[16] * x4 + diag[17] * x5;
      yy[6 * i + 3] = diag[18] * x0 + diag[19] * x1 + diag[20] * x2 + diag[21] * x3 + diag[22] * x4 + diag[23] * x5;
      yy[6 * i + 4] = diag[24] * x0 + diag[25] * x1 + diag[26] * x2 + diag[27] * x3 + diag[28] * x4 + diag[29] * x5;
      yy[6 * i + 5] = diag[30] * x0 + diag[31] * x1 + diag[32] * x2 + diag[33] * x3 + diag[34] * x4 + diag[35] * x5;
      diag += 36;
    }
    break;
  case 7:
    for (i = 0; i < m; i++) {
      x0 = xx[7 * i];
      x1 = xx[7 * i + 1];
      x2 = xx[7 * i + 2];
      x3 = xx[7 * i + 3];
      x4 = xx[7 * i + 4];
      x5 = xx[7 * i + 5];
      x6 = xx[7 * i + 6];

      yy[7 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4 + diag[5] * x5 + diag[6] * x6;
      yy[7 * i + 1] = diag[7] * x0 + diag[8] * x1 + diag[9] * x2 + diag[10] * x3 + diag[11] * x4 + diag[12] * x5 + diag[13] * x6;
      yy[7 * i + 2] = diag[14] * x0 + diag[15] * x1 + diag[16] * x2 + diag[17] * x3 + diag[18] * x4 + diag[19] * x5 + diag[20] * x6;
      yy[7 * i + 3] = diag[21] * x0 + diag[22] * x1 + diag[23] * x2 + diag[24] * x3 + diag[25] * x4 + diag[26] * x5 + diag[27] * x6;
      yy[7 * i + 4] = diag[28] * x0 + diag[29] * x1 + diag[30] * x2 + diag[31] * x3 + diag[32] * x4 + diag[33] * x5 + diag[34] * x6;
      yy[7 * i + 5] = diag[35] * x0 + diag[36] * x1 + diag[37] * x2 + diag[38] * x3 + diag[39] * x4 + diag[40] * x5 + diag[41] * x6;
      yy[7 * i + 6] = diag[42] * x0 + diag[43] * x1 + diag[44] * x2 + diag[45] * x3 + diag[46] * x4 + diag[47] * x5 + diag[48] * x6;
      diag += 49;
    }
    break;
  default:
    /* Generic block size: read the column-major block row-wise to multiply by its transpose */
    for (i = 0; i < m; i++) {
      for (ib = 0; ib < bs; ib++) {
        PetscScalar rowsum = 0;
        for (jb = 0; jb < bs; jb++) rowsum += diag[ib * bs + jb] * xx[bs * i + jb];
        yy[bs * i + ib] = rowsum;
      }
      diag += bs * bs;
    }
  }
  PetscCall(VecRestoreArrayRead(x, &xx));
  PetscCall(VecRestoreArray(y, &yy));
  /* Same cost as the non-transposed apply: 2*bs^2 - bs flops per block */
  PetscCall(PetscLogFlops((2.0 * bs * bs - bs) * m));
  PetscFunctionReturn(PETSC_SUCCESS);
}
247 
PCSetUp_PBJacobi_Host(PC pc,Mat diagPB)248 PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Host(PC pc, Mat diagPB)
249 {
250   PC_PBJacobi   *jac = (PC_PBJacobi *)pc->data;
251   Mat            A   = diagPB ? diagPB : pc->pmat;
252   MatFactorError err;
253   PetscInt       nlocal;
254 
255   PetscFunctionBegin;
256   PetscCall(MatInvertBlockDiagonal(A, &jac->diag));
257   PetscCall(MatFactorGetError(A, &err));
258   if (err) pc->failedreason = (PCFailedReason)err;
259 
260   PetscCall(MatGetBlockSize(A, &jac->bs));
261   PetscCall(MatGetLocalSize(A, &nlocal, NULL));
262   jac->mbs = nlocal / jac->bs;
263   PetscFunctionReturn(PETSC_SUCCESS);
264 }
265 
PCSetUp_PBJacobi(PC pc)266 static PetscErrorCode PCSetUp_PBJacobi(PC pc)
267 {
268   PetscBool flg;
269   Mat       diagPB = NULL;
270 
271   PetscFunctionBegin;
272   /* In PCCreate_PBJacobi() pmat might have not been set, so we wait to the last minute to do the dispatch */
273 
274   // pmat (e.g., MatCEED from libCEED) might have its own method to provide a matrix (diagPB)
275   // made of the diagonal blocks. So we check both pmat and diagVPB.
276   PetscCall(MatHasOperation(pc->pmat, MATOP_GET_BLOCK_DIAGONAL, &flg));
277   if (flg) PetscUseTypeMethod(pc->pmat, getblockdiagonal, &diagPB); // diagPB's reference count is increased upon return
278 
279 #if defined(PETSC_HAVE_CUDA)
280   PetscBool isCuda;
281   PetscCall(PetscObjectTypeCompareAny((PetscObject)pc->pmat, &isCuda, MATSEQAIJCUSPARSE, MATMPIAIJCUSPARSE, ""));
282   if (!isCuda && diagPB) PetscCall(PetscObjectTypeCompareAny((PetscObject)diagPB, &isCuda, MATSEQAIJCUSPARSE, MATMPIAIJCUSPARSE, ""));
283 #endif
284 #if defined(PETSC_HAVE_KOKKOS_KERNELS)
285   PetscBool isKok;
286   PetscCall(PetscObjectTypeCompareAny((PetscObject)pc->pmat, &isKok, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, ""));
287   if (!isKok && diagPB) PetscCall(PetscObjectTypeCompareAny((PetscObject)diagPB, &isKok, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, ""));
288 #endif
289 
290 #if defined(PETSC_HAVE_CUDA)
291   if (isCuda) PetscCall(PCSetUp_PBJacobi_CUDA(pc, diagPB));
292   else
293 #endif
294 #if defined(PETSC_HAVE_KOKKOS_KERNELS)
295     if (isKok)
296     PetscCall(PCSetUp_PBJacobi_Kokkos(pc, diagPB));
297   else
298 #endif
299   {
300     PetscCall(PCSetUp_PBJacobi_Host(pc, diagPB));
301   }
302   PetscCall(MatDestroy(&diagPB)); // since we don't need it anymore, we don't stash it in PC_PBJacobi
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
PCDestroy_PBJacobi(PC pc)306 PetscErrorCode PCDestroy_PBJacobi(PC pc)
307 {
308   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
309 
310   PetscFunctionBegin;
311   /*
312       Free the private data structure that was hanging off the PC
313   */
314   // PetscCall(PetscFree(jac->diag)); // the memory is owned by e.g., a->ibdiag in Mat_SeqAIJ, so don't free it here.
315   PetscCall(MatDestroy(&jac->diagPB));
316   PetscCall(PetscFree(pc->data));
317   PetscFunctionReturn(PETSC_SUCCESS);
318 }
319 
PCView_PBJacobi(PC pc,PetscViewer viewer)320 static PetscErrorCode PCView_PBJacobi(PC pc, PetscViewer viewer)
321 {
322   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
323   PetscBool    isascii;
324 
325   PetscFunctionBegin;
326   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
327   if (isascii) PetscCall(PetscViewerASCIIPrintf(viewer, "  point-block size %" PetscInt_FMT "\n", jac->bs));
328   PetscFunctionReturn(PETSC_SUCCESS);
329 }
330 
331 /*MC
332      PCPBJACOBI - Point block Jacobi preconditioner
333 
334    Notes:
335     See `PCJACOBI` for diagonal Jacobi, `PCVPBJACOBI` for variable-size point block, and `PCBJACOBI` for large size blocks
336 
   This works for `MATAIJ` and `MATBAIJ` matrices and uses the block size set on the matrix
338 
339    Uses dense LU factorization with partial pivoting to invert the blocks; if a zero pivot
340    is detected a PETSc error is generated.
341 
342    Developer Notes:
343      This should support the `PCSetErrorIfFailure()` flag set to `PETSC_TRUE` to allow
344      the factorization to continue even after a zero pivot is found resulting in a NaN and hence
345      terminating `KSP` with a `KSP_DIVERGED_NANORINF` allowing
346      a nonlinear solver/ODE integrator to recover without stopping the program as currently happens.
347 
     Perhaps this should provide an option that generates a valid preconditioner
     even if a block is singular, as `PCJACOBI` does.
350 
351    Level: beginner
352 
353 .seealso: [](ch_ksp), `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCJACOBI`, `PCVPBJACOBI`, `PCBJACOBI`
354 M*/
355 
PCCreate_PBJacobi(PC pc)356 PETSC_EXTERN PetscErrorCode PCCreate_PBJacobi(PC pc)
357 {
358   PC_PBJacobi *jac;
359 
360   PetscFunctionBegin;
361   /*
362      Creates the private data structure for this preconditioner and
363      attach it to the PC object.
364   */
365   PetscCall(PetscNew(&jac));
366   pc->data = (void *)jac;
367 
368   /*
369      Initialize the pointers to vectors to ZERO; these will be used to store
370      diagonal entries of the matrix for fast preconditioner application.
371   */
372   jac->diag = NULL;
373 
374   /*
375       Set the pointers for the functions that are provided above.
376       Now when the user-level routines (such as PCApply(), PCDestroy(), etc.)
377       are called, they will automatically call these functions.  Note we
378       choose not to provide a couple of these functions since they are
379       not needed.
380   */
381   pc->ops->apply               = PCApply_PBJacobi;
382   pc->ops->applytranspose      = PCApplyTranspose_PBJacobi;
383   pc->ops->setup               = PCSetUp_PBJacobi;
384   pc->ops->destroy             = PCDestroy_PBJacobi;
385   pc->ops->setfromoptions      = NULL;
386   pc->ops->view                = PCView_PBJacobi;
387   pc->ops->applyrichardson     = NULL;
388   pc->ops->applysymmetricleft  = NULL;
389   pc->ops->applysymmetricright = NULL;
390   PetscFunctionReturn(PETSC_SUCCESS);
391 }
392