xref: /petsc/src/ksp/pc/impls/pbjacobi/pbjacobi.c (revision d9acb416d05abeed0a33bde3a81aeb2ea0364f6a) !
1 #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h>
2 
3 static PetscErrorCode PCApply_PBJacobi(PC pc, Vec x, Vec y)
4 {
5   PC_PBJacobi       *jac = (PC_PBJacobi *)pc->data;
6   PetscInt           i, ib, jb;
7   const PetscInt     m    = jac->mbs;
8   const PetscInt     bs   = jac->bs;
9   const MatScalar   *diag = jac->diag;
10   PetscScalar       *yy, x0, x1, x2, x3, x4, x5, x6;
11   const PetscScalar *xx;
12 
13   PetscFunctionBegin;
14   PetscCall(VecGetArrayRead(x, &xx));
15   PetscCall(VecGetArray(y, &yy));
16   switch (bs) {
17   case 1:
18     for (i = 0; i < m; i++) yy[i] = diag[i] * xx[i];
19     break;
20   case 2:
21     for (i = 0; i < m; i++) {
22       x0            = xx[2 * i];
23       x1            = xx[2 * i + 1];
24       yy[2 * i]     = diag[0] * x0 + diag[2] * x1;
25       yy[2 * i + 1] = diag[1] * x0 + diag[3] * x1;
26       diag += 4;
27     }
28     break;
29   case 3:
30     for (i = 0; i < m; i++) {
31       x0 = xx[3 * i];
32       x1 = xx[3 * i + 1];
33       x2 = xx[3 * i + 2];
34 
35       yy[3 * i]     = diag[0] * x0 + diag[3] * x1 + diag[6] * x2;
36       yy[3 * i + 1] = diag[1] * x0 + diag[4] * x1 + diag[7] * x2;
37       yy[3 * i + 2] = diag[2] * x0 + diag[5] * x1 + diag[8] * x2;
38       diag += 9;
39     }
40     break;
41   case 4:
42     for (i = 0; i < m; i++) {
43       x0 = xx[4 * i];
44       x1 = xx[4 * i + 1];
45       x2 = xx[4 * i + 2];
46       x3 = xx[4 * i + 3];
47 
48       yy[4 * i]     = diag[0] * x0 + diag[4] * x1 + diag[8] * x2 + diag[12] * x3;
49       yy[4 * i + 1] = diag[1] * x0 + diag[5] * x1 + diag[9] * x2 + diag[13] * x3;
50       yy[4 * i + 2] = diag[2] * x0 + diag[6] * x1 + diag[10] * x2 + diag[14] * x3;
51       yy[4 * i + 3] = diag[3] * x0 + diag[7] * x1 + diag[11] * x2 + diag[15] * x3;
52       diag += 16;
53     }
54     break;
55   case 5:
56     for (i = 0; i < m; i++) {
57       x0 = xx[5 * i];
58       x1 = xx[5 * i + 1];
59       x2 = xx[5 * i + 2];
60       x3 = xx[5 * i + 3];
61       x4 = xx[5 * i + 4];
62 
63       yy[5 * i]     = diag[0] * x0 + diag[5] * x1 + diag[10] * x2 + diag[15] * x3 + diag[20] * x4;
64       yy[5 * i + 1] = diag[1] * x0 + diag[6] * x1 + diag[11] * x2 + diag[16] * x3 + diag[21] * x4;
65       yy[5 * i + 2] = diag[2] * x0 + diag[7] * x1 + diag[12] * x2 + diag[17] * x3 + diag[22] * x4;
66       yy[5 * i + 3] = diag[3] * x0 + diag[8] * x1 + diag[13] * x2 + diag[18] * x3 + diag[23] * x4;
67       yy[5 * i + 4] = diag[4] * x0 + diag[9] * x1 + diag[14] * x2 + diag[19] * x3 + diag[24] * x4;
68       diag += 25;
69     }
70     break;
71   case 6:
72     for (i = 0; i < m; i++) {
73       x0 = xx[6 * i];
74       x1 = xx[6 * i + 1];
75       x2 = xx[6 * i + 2];
76       x3 = xx[6 * i + 3];
77       x4 = xx[6 * i + 4];
78       x5 = xx[6 * i + 5];
79 
80       yy[6 * i]     = diag[0] * x0 + diag[6] * x1 + diag[12] * x2 + diag[18] * x3 + diag[24] * x4 + diag[30] * x5;
81       yy[6 * i + 1] = diag[1] * x0 + diag[7] * x1 + diag[13] * x2 + diag[19] * x3 + diag[25] * x4 + diag[31] * x5;
82       yy[6 * i + 2] = diag[2] * x0 + diag[8] * x1 + diag[14] * x2 + diag[20] * x3 + diag[26] * x4 + diag[32] * x5;
83       yy[6 * i + 3] = diag[3] * x0 + diag[9] * x1 + diag[15] * x2 + diag[21] * x3 + diag[27] * x4 + diag[33] * x5;
84       yy[6 * i + 4] = diag[4] * x0 + diag[10] * x1 + diag[16] * x2 + diag[22] * x3 + diag[28] * x4 + diag[34] * x5;
85       yy[6 * i + 5] = diag[5] * x0 + diag[11] * x1 + diag[17] * x2 + diag[23] * x3 + diag[29] * x4 + diag[35] * x5;
86       diag += 36;
87     }
88     break;
89   case 7:
90     for (i = 0; i < m; i++) {
91       x0 = xx[7 * i];
92       x1 = xx[7 * i + 1];
93       x2 = xx[7 * i + 2];
94       x3 = xx[7 * i + 3];
95       x4 = xx[7 * i + 4];
96       x5 = xx[7 * i + 5];
97       x6 = xx[7 * i + 6];
98 
99       yy[7 * i]     = diag[0] * x0 + diag[7] * x1 + diag[14] * x2 + diag[21] * x3 + diag[28] * x4 + diag[35] * x5 + diag[42] * x6;
100       yy[7 * i + 1] = diag[1] * x0 + diag[8] * x1 + diag[15] * x2 + diag[22] * x3 + diag[29] * x4 + diag[36] * x5 + diag[43] * x6;
101       yy[7 * i + 2] = diag[2] * x0 + diag[9] * x1 + diag[16] * x2 + diag[23] * x3 + diag[30] * x4 + diag[37] * x5 + diag[44] * x6;
102       yy[7 * i + 3] = diag[3] * x0 + diag[10] * x1 + diag[17] * x2 + diag[24] * x3 + diag[31] * x4 + diag[38] * x5 + diag[45] * x6;
103       yy[7 * i + 4] = diag[4] * x0 + diag[11] * x1 + diag[18] * x2 + diag[25] * x3 + diag[32] * x4 + diag[39] * x5 + diag[46] * x6;
104       yy[7 * i + 5] = diag[5] * x0 + diag[12] * x1 + diag[19] * x2 + diag[26] * x3 + diag[33] * x4 + diag[40] * x5 + diag[47] * x6;
105       yy[7 * i + 6] = diag[6] * x0 + diag[13] * x1 + diag[20] * x2 + diag[27] * x3 + diag[34] * x4 + diag[41] * x5 + diag[48] * x6;
106       diag += 49;
107     }
108     break;
109   default:
110     for (i = 0; i < m; i++) {
111       for (ib = 0; ib < bs; ib++) {
112         PetscScalar rowsum = 0;
113         for (jb = 0; jb < bs; jb++) rowsum += diag[ib + jb * bs] * xx[bs * i + jb];
114         yy[bs * i + ib] = rowsum;
115       }
116       diag += bs * bs;
117     }
118   }
119   PetscCall(VecRestoreArrayRead(x, &xx));
120   PetscCall(VecRestoreArray(y, &yy));
121   PetscCall(PetscLogFlops((2.0 * bs * bs - bs) * m)); /* 2*bs2 - bs */
122   PetscFunctionReturn(PETSC_SUCCESS);
123 }
124 
125 static PetscErrorCode PCApplyTranspose_PBJacobi(PC pc, Vec x, Vec y)
126 {
127   PC_PBJacobi       *jac = (PC_PBJacobi *)pc->data;
128   PetscInt           i, ib, jb;
129   const PetscInt     m    = jac->mbs;
130   const PetscInt     bs   = jac->bs;
131   const MatScalar   *diag = jac->diag;
132   PetscScalar       *yy, x0, x1, x2, x3, x4, x5, x6;
133   const PetscScalar *xx;
134 
135   PetscFunctionBegin;
136   PetscCall(VecGetArrayRead(x, &xx));
137   PetscCall(VecGetArray(y, &yy));
138   switch (bs) {
139   case 1:
140     for (i = 0; i < m; i++) yy[i] = diag[i] * xx[i];
141     break;
142   case 2:
143     for (i = 0; i < m; i++) {
144       x0            = xx[2 * i];
145       x1            = xx[2 * i + 1];
146       yy[2 * i]     = diag[0] * x0 + diag[1] * x1;
147       yy[2 * i + 1] = diag[2] * x0 + diag[3] * x1;
148       diag += 4;
149     }
150     break;
151   case 3:
152     for (i = 0; i < m; i++) {
153       x0 = xx[3 * i];
154       x1 = xx[3 * i + 1];
155       x2 = xx[3 * i + 2];
156 
157       yy[3 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2;
158       yy[3 * i + 1] = diag[3] * x0 + diag[4] * x1 + diag[5] * x2;
159       yy[3 * i + 2] = diag[6] * x0 + diag[7] * x1 + diag[8] * x2;
160       diag += 9;
161     }
162     break;
163   case 4:
164     for (i = 0; i < m; i++) {
165       x0 = xx[4 * i];
166       x1 = xx[4 * i + 1];
167       x2 = xx[4 * i + 2];
168       x3 = xx[4 * i + 3];
169 
170       yy[4 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3;
171       yy[4 * i + 1] = diag[4] * x0 + diag[5] * x1 + diag[6] * x2 + diag[7] * x3;
172       yy[4 * i + 2] = diag[8] * x0 + diag[9] * x1 + diag[10] * x2 + diag[11] * x3;
173       yy[4 * i + 3] = diag[12] * x0 + diag[13] * x1 + diag[14] * x2 + diag[15] * x3;
174       diag += 16;
175     }
176     break;
177   case 5:
178     for (i = 0; i < m; i++) {
179       x0 = xx[5 * i];
180       x1 = xx[5 * i + 1];
181       x2 = xx[5 * i + 2];
182       x3 = xx[5 * i + 3];
183       x4 = xx[5 * i + 4];
184 
185       yy[5 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4;
186       yy[5 * i + 1] = diag[5] * x0 + diag[6] * x1 + diag[7] * x2 + diag[8] * x3 + diag[9] * x4;
187       yy[5 * i + 2] = diag[10] * x0 + diag[11] * x1 + diag[12] * x2 + diag[13] * x3 + diag[14] * x4;
188       yy[5 * i + 3] = diag[15] * x0 + diag[16] * x1 + diag[17] * x2 + diag[18] * x3 + diag[19] * x4;
189       yy[5 * i + 4] = diag[20] * x0 + diag[21] * x1 + diag[22] * x2 + diag[23] * x3 + diag[24] * x4;
190       diag += 25;
191     }
192     break;
193   case 6:
194     for (i = 0; i < m; i++) {
195       x0 = xx[6 * i];
196       x1 = xx[6 * i + 1];
197       x2 = xx[6 * i + 2];
198       x3 = xx[6 * i + 3];
199       x4 = xx[6 * i + 4];
200       x5 = xx[6 * i + 5];
201 
202       yy[6 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4 + diag[5] * x5;
203       yy[6 * i + 1] = diag[6] * x0 + diag[7] * x1 + diag[8] * x2 + diag[9] * x3 + diag[10] * x4 + diag[11] * x5;
204       yy[6 * i + 2] = diag[12] * x0 + diag[13] * x1 + diag[14] * x2 + diag[15] * x3 + diag[16] * x4 + diag[17] * x5;
205       yy[6 * i + 3] = diag[18] * x0 + diag[19] * x1 + diag[20] * x2 + diag[21] * x3 + diag[22] * x4 + diag[23] * x5;
206       yy[6 * i + 4] = diag[24] * x0 + diag[25] * x1 + diag[26] * x2 + diag[27] * x3 + diag[28] * x4 + diag[29] * x5;
207       yy[6 * i + 5] = diag[30] * x0 + diag[31] * x1 + diag[32] * x2 + diag[33] * x3 + diag[34] * x4 + diag[35] * x5;
208       diag += 36;
209     }
210     break;
211   case 7:
212     for (i = 0; i < m; i++) {
213       x0 = xx[7 * i];
214       x1 = xx[7 * i + 1];
215       x2 = xx[7 * i + 2];
216       x3 = xx[7 * i + 3];
217       x4 = xx[7 * i + 4];
218       x5 = xx[7 * i + 5];
219       x6 = xx[7 * i + 6];
220 
221       yy[7 * i]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4 + diag[5] * x5 + diag[6] * x6;
222       yy[7 * i + 1] = diag[7] * x0 + diag[8] * x1 + diag[9] * x2 + diag[10] * x3 + diag[11] * x4 + diag[12] * x5 + diag[13] * x6;
223       yy[7 * i + 2] = diag[14] * x0 + diag[15] * x1 + diag[16] * x2 + diag[17] * x3 + diag[18] * x4 + diag[19] * x5 + diag[20] * x6;
224       yy[7 * i + 3] = diag[21] * x0 + diag[22] * x1 + diag[23] * x2 + diag[24] * x3 + diag[25] * x4 + diag[26] * x5 + diag[27] * x6;
225       yy[7 * i + 4] = diag[28] * x0 + diag[29] * x1 + diag[30] * x2 + diag[31] * x3 + diag[32] * x4 + diag[33] * x5 + diag[34] * x6;
226       yy[7 * i + 5] = diag[35] * x0 + diag[36] * x1 + diag[37] * x2 + diag[38] * x3 + diag[39] * x4 + diag[40] * x5 + diag[41] * x6;
227       yy[7 * i + 6] = diag[42] * x0 + diag[43] * x1 + diag[44] * x2 + diag[45] * x3 + diag[46] * x4 + diag[47] * x5 + diag[48] * x6;
228       diag += 49;
229     }
230     break;
231   default:
232     for (i = 0; i < m; i++) {
233       for (ib = 0; ib < bs; ib++) {
234         PetscScalar rowsum = 0;
235         for (jb = 0; jb < bs; jb++) rowsum += diag[ib * bs + jb] * xx[bs * i + jb];
236         yy[bs * i + ib] = rowsum;
237       }
238       diag += bs * bs;
239     }
240   }
241   PetscCall(VecRestoreArrayRead(x, &xx));
242   PetscCall(VecRestoreArray(y, &yy));
243   PetscCall(PetscLogFlops((2.0 * bs * bs - bs) * m));
244   PetscFunctionReturn(PETSC_SUCCESS);
245 }
246 
247 PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Host(PC pc)
248 {
249   PC_PBJacobi   *jac = (PC_PBJacobi *)pc->data;
250   Mat            A   = pc->pmat;
251   MatFactorError err;
252   PetscInt       nlocal;
253 
254   PetscFunctionBegin;
255   PetscCall(MatInvertBlockDiagonal(A, &jac->diag));
256   PetscCall(MatFactorGetError(A, &err));
257   if (err) pc->failedreason = (PCFailedReason)err;
258 
259   PetscCall(MatGetBlockSize(A, &jac->bs));
260   PetscCall(MatGetLocalSize(A, &nlocal, NULL));
261   jac->mbs = nlocal / jac->bs;
262   PetscFunctionReturn(PETSC_SUCCESS);
263 }
264 
265 static PetscErrorCode PCSetUp_PBJacobi(PC pc)
266 {
267   PetscFunctionBegin;
268   /* In PCCreate_PBJacobi() pmat might have not been set, so we wait to the last minute to do the dispatch */
269 #if defined(PETSC_HAVE_CUDA)
270   PetscBool isCuda;
271   PetscCall(PetscObjectTypeCompareAny((PetscObject)pc->pmat, &isCuda, MATSEQAIJCUSPARSE, MATMPIAIJCUSPARSE, ""));
272 #endif
273 #if defined(PETSC_HAVE_KOKKOS_KERNELS)
274   PetscBool isKok;
275   PetscCall(PetscObjectTypeCompareAny((PetscObject)pc->pmat, &isKok, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, ""));
276 #endif
277 
278 #if defined(PETSC_HAVE_CUDA)
279   if (isCuda) PetscCall(PCSetUp_PBJacobi_CUDA(pc));
280   else
281 #endif
282 #if defined(PETSC_HAVE_KOKKOS_KERNELS)
283     if (isKok)
284     PetscCall(PCSetUp_PBJacobi_Kokkos(pc));
285   else
286 #endif
287   {
288     PetscCall(PCSetUp_PBJacobi_Host(pc));
289   }
290   PetscFunctionReturn(PETSC_SUCCESS);
291 }
292 
293 PetscErrorCode PCDestroy_PBJacobi(PC pc)
294 {
295   PetscFunctionBegin;
296   /*
297       Free the private data structure that was hanging off the PC
298   */
299   PetscCall(PetscFree(pc->data));
300   PetscFunctionReturn(PETSC_SUCCESS);
301 }
302 
303 static PetscErrorCode PCView_PBJacobi(PC pc, PetscViewer viewer)
304 {
305   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
306   PetscBool    iascii;
307 
308   PetscFunctionBegin;
309   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
310   if (iascii) PetscCall(PetscViewerASCIIPrintf(viewer, "  point-block size %" PetscInt_FMT "\n", jac->bs));
311   PetscFunctionReturn(PETSC_SUCCESS);
312 }
313 
314 /*MC
315      PCPBJACOBI - Point block Jacobi preconditioner
316 
317    Notes:
318     See `PCJACOBI` for diagonal Jacobi, `PCVPBJACOBI` for variable-size point block, and `PCBJACOBI` for large size blocks
319 
320    This works for `MATAIJ` and `MATBAIJ` matrices and uses the blocksize provided to the matrix
321 
322    Uses dense LU factorization with partial pivoting to invert the blocks; if a zero pivot
323    is detected a PETSc error is generated.
324 
325    Developer Notes:
326      This should support the `PCSetErrorIfFailure()` flag set to `PETSC_TRUE` to allow
327      the factorization to continue even after a zero pivot is found, resulting in a NaN and hence
328      terminating `KSP` with a `KSP_DIVERGED_NANORINF`, allowing
329      a nonlinear solver/ODE integrator to recover without stopping the program as currently happens.
330 
331      Perhaps should provide an option that allows generation of a valid preconditioner
332      even if a block is singular as the `PCJACOBI` does.
333 
334    Level: beginner
335 
336 .seealso: [](ch_ksp), `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCJACOBI`, `PCVPBJACOBI`, `PCBJACOBI`
337 M*/
338 
339 PETSC_EXTERN PetscErrorCode PCCreate_PBJacobi(PC pc)
340 {
341   PC_PBJacobi *jac;
342 
343   PetscFunctionBegin;
344   /*
345      Creates the private data structure for this preconditioner and
346      attach it to the PC object.
347   */
348   PetscCall(PetscNew(&jac));
349   pc->data = (void *)jac;
350 
351   /*
352      Initialize the pointers to vectors to ZERO; these will be used to store
353      diagonal entries of the matrix for fast preconditioner application.
354   */
355   jac->diag = NULL;
356 
357   /*
358       Set the pointers for the functions that are provided above.
359       Now when the user-level routines (such as PCApply(), PCDestroy(), etc.)
360       are called, they will automatically call these functions.  Note we
361       choose not to provide a couple of these functions since they are
362       not needed.
363   */
364   pc->ops->apply               = PCApply_PBJacobi;
365   pc->ops->applytranspose      = PCApplyTranspose_PBJacobi;
366   pc->ops->setup               = PCSetUp_PBJacobi;
367   pc->ops->destroy             = PCDestroy_PBJacobi;
368   pc->ops->setfromoptions      = NULL;
369   pc->ops->view                = PCView_PBJacobi;
370   pc->ops->applyrichardson     = NULL;
371   pc->ops->applysymmetricleft  = NULL;
372   pc->ops->applysymmetricright = NULL;
373   PetscFunctionReturn(PETSC_SUCCESS);
374 }
375