xref: /petsc/src/ksp/pc/impls/vpbjacobi/vpbjacobi.c (revision bd89dbf26d8a5efecb980364933175da61864cd7)
1 #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
2 #include <petsc/private/matimpl.h>
3 
PCApply_VPBJacobi(PC pc,Vec x,Vec y)4 static PetscErrorCode PCApply_VPBJacobi(PC pc, Vec x, Vec y)
5 {
6   PC_VPBJacobi      *jac = (PC_VPBJacobi *)pc->data;
7   PetscInt           i, ncnt = 0;
8   const MatScalar   *diag = jac->diag;
9   PetscInt           ib, jb, bs;
10   const PetscScalar *xx;
11   PetscScalar       *yy, x0, x1, x2, x3, x4, x5, x6;
12   PetscInt           nblocks;
13   const PetscInt    *bsizes;
14 
15   PetscFunctionBegin;
16   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
17   PetscCall(VecGetArrayRead(x, &xx));
18   PetscCall(VecGetArray(y, &yy));
19   for (i = 0; i < nblocks; i++) {
20     bs = bsizes[i];
21     switch (bs) {
22     case 1:
23       yy[ncnt] = *diag * xx[ncnt];
24       break;
25     case 2:
26       x0           = xx[ncnt];
27       x1           = xx[ncnt + 1];
28       yy[ncnt]     = diag[0] * x0 + diag[2] * x1;
29       yy[ncnt + 1] = diag[1] * x0 + diag[3] * x1;
30       break;
31     case 3:
32       x0           = xx[ncnt];
33       x1           = xx[ncnt + 1];
34       x2           = xx[ncnt + 2];
35       yy[ncnt]     = diag[0] * x0 + diag[3] * x1 + diag[6] * x2;
36       yy[ncnt + 1] = diag[1] * x0 + diag[4] * x1 + diag[7] * x2;
37       yy[ncnt + 2] = diag[2] * x0 + diag[5] * x1 + diag[8] * x2;
38       break;
39     case 4:
40       x0           = xx[ncnt];
41       x1           = xx[ncnt + 1];
42       x2           = xx[ncnt + 2];
43       x3           = xx[ncnt + 3];
44       yy[ncnt]     = diag[0] * x0 + diag[4] * x1 + diag[8] * x2 + diag[12] * x3;
45       yy[ncnt + 1] = diag[1] * x0 + diag[5] * x1 + diag[9] * x2 + diag[13] * x3;
46       yy[ncnt + 2] = diag[2] * x0 + diag[6] * x1 + diag[10] * x2 + diag[14] * x3;
47       yy[ncnt + 3] = diag[3] * x0 + diag[7] * x1 + diag[11] * x2 + diag[15] * x3;
48       break;
49     case 5:
50       x0           = xx[ncnt];
51       x1           = xx[ncnt + 1];
52       x2           = xx[ncnt + 2];
53       x3           = xx[ncnt + 3];
54       x4           = xx[ncnt + 4];
55       yy[ncnt]     = diag[0] * x0 + diag[5] * x1 + diag[10] * x2 + diag[15] * x3 + diag[20] * x4;
56       yy[ncnt + 1] = diag[1] * x0 + diag[6] * x1 + diag[11] * x2 + diag[16] * x3 + diag[21] * x4;
57       yy[ncnt + 2] = diag[2] * x0 + diag[7] * x1 + diag[12] * x2 + diag[17] * x3 + diag[22] * x4;
58       yy[ncnt + 3] = diag[3] * x0 + diag[8] * x1 + diag[13] * x2 + diag[18] * x3 + diag[23] * x4;
59       yy[ncnt + 4] = diag[4] * x0 + diag[9] * x1 + diag[14] * x2 + diag[19] * x3 + diag[24] * x4;
60       break;
61     case 6:
62       x0           = xx[ncnt];
63       x1           = xx[ncnt + 1];
64       x2           = xx[ncnt + 2];
65       x3           = xx[ncnt + 3];
66       x4           = xx[ncnt + 4];
67       x5           = xx[ncnt + 5];
68       yy[ncnt]     = diag[0] * x0 + diag[6] * x1 + diag[12] * x2 + diag[18] * x3 + diag[24] * x4 + diag[30] * x5;
69       yy[ncnt + 1] = diag[1] * x0 + diag[7] * x1 + diag[13] * x2 + diag[19] * x3 + diag[25] * x4 + diag[31] * x5;
70       yy[ncnt + 2] = diag[2] * x0 + diag[8] * x1 + diag[14] * x2 + diag[20] * x3 + diag[26] * x4 + diag[32] * x5;
71       yy[ncnt + 3] = diag[3] * x0 + diag[9] * x1 + diag[15] * x2 + diag[21] * x3 + diag[27] * x4 + diag[33] * x5;
72       yy[ncnt + 4] = diag[4] * x0 + diag[10] * x1 + diag[16] * x2 + diag[22] * x3 + diag[28] * x4 + diag[34] * x5;
73       yy[ncnt + 5] = diag[5] * x0 + diag[11] * x1 + diag[17] * x2 + diag[23] * x3 + diag[29] * x4 + diag[35] * x5;
74       break;
75     case 7:
76       x0           = xx[ncnt];
77       x1           = xx[ncnt + 1];
78       x2           = xx[ncnt + 2];
79       x3           = xx[ncnt + 3];
80       x4           = xx[ncnt + 4];
81       x5           = xx[ncnt + 5];
82       x6           = xx[ncnt + 6];
83       yy[ncnt]     = diag[0] * x0 + diag[7] * x1 + diag[14] * x2 + diag[21] * x3 + diag[28] * x4 + diag[35] * x5 + diag[42] * x6;
84       yy[ncnt + 1] = diag[1] * x0 + diag[8] * x1 + diag[15] * x2 + diag[22] * x3 + diag[29] * x4 + diag[36] * x5 + diag[43] * x6;
85       yy[ncnt + 2] = diag[2] * x0 + diag[9] * x1 + diag[16] * x2 + diag[23] * x3 + diag[30] * x4 + diag[37] * x5 + diag[44] * x6;
86       yy[ncnt + 3] = diag[3] * x0 + diag[10] * x1 + diag[17] * x2 + diag[24] * x3 + diag[31] * x4 + diag[38] * x5 + diag[45] * x6;
87       yy[ncnt + 4] = diag[4] * x0 + diag[11] * x1 + diag[18] * x2 + diag[25] * x3 + diag[32] * x4 + diag[39] * x5 + diag[46] * x6;
88       yy[ncnt + 5] = diag[5] * x0 + diag[12] * x1 + diag[19] * x2 + diag[26] * x3 + diag[33] * x4 + diag[40] * x5 + diag[47] * x6;
89       yy[ncnt + 6] = diag[6] * x0 + diag[13] * x1 + diag[20] * x2 + diag[27] * x3 + diag[34] * x4 + diag[41] * x5 + diag[48] * x6;
90       break;
91     default:
92       for (ib = 0; ib < bs; ib++) {
93         PetscScalar rowsum = 0;
94         for (jb = 0; jb < bs; jb++) rowsum += diag[ib + jb * bs] * xx[ncnt + jb];
95         yy[ncnt + ib] = rowsum;
96       }
97     }
98     ncnt += bsizes[i];
99     diag += bsizes[i] * bsizes[i];
100   }
101   PetscCall(VecRestoreArrayRead(x, &xx));
102   PetscCall(VecRestoreArray(y, &yy));
103   PetscFunctionReturn(PETSC_SUCCESS);
104 }
105 
PCApplyTranspose_VPBJacobi(PC pc,Vec x,Vec y)106 static PetscErrorCode PCApplyTranspose_VPBJacobi(PC pc, Vec x, Vec y)
107 {
108   PC_VPBJacobi      *jac = (PC_VPBJacobi *)pc->data;
109   PetscInt           i, ncnt = 0;
110   const MatScalar   *diag = jac->diag;
111   PetscInt           ib, jb, bs;
112   const PetscScalar *xx;
113   PetscScalar       *yy, x0, x1, x2, x3, x4, x5, x6;
114   PetscInt           nblocks;
115   const PetscInt    *bsizes;
116 
117   PetscFunctionBegin;
118   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
119   PetscCall(VecGetArrayRead(x, &xx));
120   PetscCall(VecGetArray(y, &yy));
121   for (i = 0; i < nblocks; i++) {
122     bs = bsizes[i];
123     switch (bs) {
124     case 1:
125       yy[ncnt] = *diag * xx[ncnt];
126       break;
127     case 2:
128       x0           = xx[ncnt];
129       x1           = xx[ncnt + 1];
130       yy[ncnt]     = diag[0] * x0 + diag[1] * x1;
131       yy[ncnt + 1] = diag[2] * x0 + diag[3] * x1;
132       break;
133     case 3:
134       x0           = xx[ncnt];
135       x1           = xx[ncnt + 1];
136       x2           = xx[ncnt + 2];
137       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2;
138       yy[ncnt + 1] = diag[3] * x0 + diag[4] * x1 + diag[5] * x2;
139       yy[ncnt + 2] = diag[6] * x0 + diag[7] * x1 + diag[8] * x2;
140       break;
141     case 4:
142       x0           = xx[ncnt];
143       x1           = xx[ncnt + 1];
144       x2           = xx[ncnt + 2];
145       x3           = xx[ncnt + 3];
146       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3;
147       yy[ncnt + 1] = diag[4] * x0 + diag[5] * x1 + diag[6] * x2 + diag[7] * x3;
148       yy[ncnt + 2] = diag[8] * x0 + diag[9] * x1 + diag[10] * x2 + diag[11] * x3;
149       yy[ncnt + 3] = diag[12] * x0 + diag[13] * x1 + diag[14] * x2 + diag[15] * x3;
150       break;
151     case 5:
152       x0           = xx[ncnt];
153       x1           = xx[ncnt + 1];
154       x2           = xx[ncnt + 2];
155       x3           = xx[ncnt + 3];
156       x4           = xx[ncnt + 4];
157       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4;
158       yy[ncnt + 1] = diag[5] * x0 + diag[6] * x1 + diag[7] * x2 + diag[8] * x3 + diag[9] * x4;
159       yy[ncnt + 2] = diag[10] * x0 + diag[11] * x1 + diag[12] * x2 + diag[13] * x3 + diag[14] * x4;
160       yy[ncnt + 3] = diag[15] * x0 + diag[16] * x1 + diag[17] * x2 + diag[18] * x3 + diag[19] * x4;
161       yy[ncnt + 4] = diag[20] * x0 + diag[21] * x1 + diag[22] * x2 + diag[23] * x3 + diag[24] * x4;
162       break;
163     case 6:
164       x0           = xx[ncnt];
165       x1           = xx[ncnt + 1];
166       x2           = xx[ncnt + 2];
167       x3           = xx[ncnt + 3];
168       x4           = xx[ncnt + 4];
169       x5           = xx[ncnt + 5];
170       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4 + diag[5] * x5;
171       yy[ncnt + 1] = diag[6] * x0 + diag[7] * x1 + diag[8] * x2 + diag[9] * x3 + diag[10] * x4 + diag[11] * x5;
172       yy[ncnt + 2] = diag[12] * x0 + diag[13] * x1 + diag[14] * x2 + diag[15] * x3 + diag[16] * x4 + diag[17] * x5;
173       yy[ncnt + 3] = diag[18] * x0 + diag[19] * x1 + diag[20] * x2 + diag[21] * x3 + diag[22] * x4 + diag[23] * x5;
174       yy[ncnt + 4] = diag[24] * x0 + diag[25] * x1 + diag[26] * x2 + diag[27] * x3 + diag[28] * x4 + diag[29] * x5;
175       yy[ncnt + 5] = diag[30] * x0 + diag[31] * x1 + diag[32] * x2 + diag[33] * x3 + diag[34] * x4 + diag[35] * x5;
176       break;
177     case 7:
178       x0           = xx[ncnt];
179       x1           = xx[ncnt + 1];
180       x2           = xx[ncnt + 2];
181       x3           = xx[ncnt + 3];
182       x4           = xx[ncnt + 4];
183       x5           = xx[ncnt + 5];
184       x6           = xx[ncnt + 6];
185       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4 + diag[5] * x5 + diag[6] * x6;
186       yy[ncnt + 1] = diag[7] * x0 + diag[8] * x1 + diag[9] * x2 + diag[10] * x3 + diag[11] * x4 + diag[12] * x5 + diag[13] * x6;
187       yy[ncnt + 2] = diag[14] * x0 + diag[15] * x1 + diag[16] * x2 + diag[17] * x3 + diag[18] * x4 + diag[19] * x5 + diag[20] * x6;
188       yy[ncnt + 3] = diag[21] * x0 + diag[22] * x1 + diag[23] * x2 + diag[24] * x3 + diag[25] * x4 + diag[26] * x5 + diag[27] * x6;
189       yy[ncnt + 4] = diag[28] * x0 + diag[29] * x1 + diag[30] * x2 + diag[31] * x3 + diag[32] * x4 + diag[33] * x5 + diag[34] * x6;
190       yy[ncnt + 5] = diag[35] * x0 + diag[36] * x1 + diag[37] * x2 + diag[38] * x3 + diag[39] * x4 + diag[40] * x5 + diag[41] * x6;
191       yy[ncnt + 6] = diag[42] * x0 + diag[43] * x1 + diag[44] * x2 + diag[45] * x3 + diag[46] * x4 + diag[47] * x5 + diag[48] * x6;
192       break;
193     default:
194       for (ib = 0; ib < bs; ib++) {
195         PetscScalar rowsum = 0;
196         for (jb = 0; jb < bs; jb++) rowsum += diag[ib * bs + jb] * xx[ncnt + jb];
197         yy[ncnt + ib] = rowsum;
198       }
199     }
200     ncnt += bsizes[i];
201     diag += bsizes[i] * bsizes[i];
202   }
203   PetscCall(VecRestoreArrayRead(x, &xx));
204   PetscCall(VecRestoreArray(y, &yy));
205   PetscFunctionReturn(PETSC_SUCCESS);
206 }
207 
PCSetUp_VPBJacobi_Host(PC pc,Mat diagVPB)208 PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Host(PC pc, Mat diagVPB)
209 {
210   PC_VPBJacobi   *jac = (PC_VPBJacobi *)pc->data;
211   Mat             A   = diagVPB ? diagVPB : pc->pmat;
212   MatFactorError  err;
213   PetscInt        i, nsize = 0, nlocal;
214   PetscInt        nblocks;
215   const PetscInt *bsizes;
216 
217   PetscFunctionBegin;
218   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
219   PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL));
220   PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI");
221   if (!jac->diag) {
222     PetscInt max_bs = -1, min_bs = PETSC_INT_MAX;
223     for (i = 0; i < nblocks; i++) {
224       min_bs = PetscMin(min_bs, bsizes[i]);
225       max_bs = PetscMax(max_bs, bsizes[i]);
226       nsize += bsizes[i] * bsizes[i];
227     }
228     PetscCall(PetscMalloc1(nsize, &jac->diag));
229     jac->nblocks = nblocks;
230     jac->min_bs  = min_bs;
231     jac->max_bs  = max_bs;
232   }
233   PetscCall(MatInvertVariableBlockDiagonal(A, nblocks, bsizes, jac->diag));
234   PetscCall(MatFactorGetError(A, &err));
235   if (err) pc->failedreason = (PCFailedReason)err;
236   pc->ops->apply          = PCApply_VPBJacobi;
237   pc->ops->applytranspose = PCApplyTranspose_VPBJacobi;
238   PetscFunctionReturn(PETSC_SUCCESS);
239 }
240 
PCSetUp_VPBJacobi(PC pc)241 static PetscErrorCode PCSetUp_VPBJacobi(PC pc)
242 {
243   PetscBool flg;
244   Mat       diagVPB = NULL;
245 
246   PetscFunctionBegin;
247   // In PCCreate_VPBJacobi() pmat might have not been set, so we wait to the last minute to do the dispatch
248 
249   // pmat (e.g., MatCEED from libCEED) might have its own method to provide a matrix (diagVPB)
250   // made of the diagonal blocks. So we check both pmat and diagVPB.
251   PetscCall(MatHasOperation(pc->pmat, MATOP_GET_VBLOCK_DIAGONAL, &flg));
252   if (flg) PetscUseTypeMethod(pc->pmat, getvblockdiagonal, &diagVPB); // diagVPB's reference count is increased upon return
253 
254 #if defined(PETSC_HAVE_CUDA)
255   PetscBool isCuda;
256   PetscCall(PetscObjectTypeCompareAny((PetscObject)pc->pmat, &isCuda, MATSEQAIJCUSPARSE, MATMPIAIJCUSPARSE, ""));
257   if (!isCuda && diagVPB) PetscCall(PetscObjectTypeCompareAny((PetscObject)diagVPB, &isCuda, MATSEQAIJCUSPARSE, MATMPIAIJCUSPARSE, ""));
258 #endif
259 #if defined(PETSC_HAVE_KOKKOS_KERNELS)
260   PetscBool isKok;
261   PetscCall(PetscObjectTypeCompareAny((PetscObject)pc->pmat, &isKok, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, ""));
262   if (!isKok && diagVPB) PetscCall(PetscObjectTypeCompareAny((PetscObject)diagVPB, &isKok, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, ""));
263 #endif
264 
265 #if defined(PETSC_HAVE_CUDA)
266   if (isCuda) PetscCall(PCSetUp_VPBJacobi_CUDA(pc, diagVPB));
267   else
268 #endif
269 #if defined(PETSC_HAVE_KOKKOS_KERNELS)
270     if (isKok)
271     PetscCall(PCSetUp_VPBJacobi_Kokkos(pc, diagVPB));
272   else
273 #endif
274   {
275     PetscCall(PCSetUp_VPBJacobi_Host(pc, diagVPB));
276   }
277   PetscCall(MatDestroy(&diagVPB)); // since we don't need it anymore, we don't need to stash it in PC_VPBJacobi
278   PetscFunctionReturn(PETSC_SUCCESS);
279 }
280 
PCView_VPBJacobi(PC pc,PetscViewer viewer)281 static PetscErrorCode PCView_VPBJacobi(PC pc, PetscViewer viewer)
282 {
283   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
284   PetscBool     isascii;
285 
286   PetscFunctionBegin;
287   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
288   if (isascii) {
289     PetscCall(PetscViewerASCIIPrintf(viewer, "  number of blocks: %" PetscInt_FMT "\n", jac->nblocks));
290     PetscCall(PetscViewerASCIIPrintf(viewer, "  block sizes: min=%" PetscInt_FMT " max=%" PetscInt_FMT "\n", jac->min_bs, jac->max_bs));
291   }
292   PetscFunctionReturn(PETSC_SUCCESS);
293 }
294 
PCDestroy_VPBJacobi(PC pc)295 PETSC_INTERN PetscErrorCode PCDestroy_VPBJacobi(PC pc)
296 {
297   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
298 
299   PetscFunctionBegin;
300   /*
301       Free the private data structure that was hanging off the PC
302   */
303   PetscCall(PetscFree(jac->diag));
304   PetscCall(MatDestroy(&jac->diagVPB));
305   PetscCall(PetscFree(pc->data));
306   PetscFunctionReturn(PETSC_SUCCESS);
307 }
308 
309 /*MC
310      PCVPBJACOBI - Variable size point block Jacobi preconditioner
311 
312    Level: beginner
313 
314    Notes:
315      See `PCJACOBI` for point Jacobi preconditioning, `PCPBJACOBI` for fixed point block size, and `PCBJACOBI` for large size blocks
316 
317      This works for `MATAIJ` matrices
318 
319      Uses dense LU factorization with partial pivoting to invert the blocks; if a zero pivot
320      is detected a PETSc error is generated.
321 
322      One must call `MatSetVariableBlockSizes()` to use this preconditioner
323 
324    Developer Notes:
     This should honor the `PCSetErrorIfFailure()` flag: when it is `PETSC_FALSE`, the
     factorization should continue even after a zero pivot is found, producing a NaN that
     terminates the `KSP` with `KSP_DIVERGED_NANORINF`, so that a nonlinear solver/ODE
     integrator can recover instead of the program stopping, as currently happens.
329 
330      Perhaps should provide an option that allows generation of a valid preconditioner
331      even if a block is singular as the `PCJACOBI` does.
332 
333 .seealso: [](ch_ksp), `MatSetVariableBlockSizes()`, `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCJACOBI`, `PCPBJACOBI`, `PCBJACOBI`
334 M*/
335 
PCCreate_VPBJacobi(PC pc)336 PETSC_EXTERN PetscErrorCode PCCreate_VPBJacobi(PC pc)
337 {
338   PC_VPBJacobi *jac;
339 
340   PetscFunctionBegin;
341   /*
342      Creates the private data structure for this preconditioner and
343      attach it to the PC object.
344   */
345   PetscCall(PetscNew(&jac));
346   pc->data = (void *)jac;
347 
348   /*
349      Initialize the pointers to vectors to ZERO; these will be used to store
350      diagonal entries of the matrix for fast preconditioner application.
351   */
352   jac->diag = NULL;
353 
354   /*
355       Set the pointers for the functions that are provided above.
356       Now when the user-level routines (such as PCApply(), PCDestroy(), etc.)
357       are called, they will automatically call these functions.  Note we
358       choose not to provide a couple of these functions since they are
359       not needed.
360   */
361   pc->ops->apply               = PCApply_VPBJacobi;
362   pc->ops->applytranspose      = NULL;
363   pc->ops->setup               = PCSetUp_VPBJacobi;
364   pc->ops->destroy             = PCDestroy_VPBJacobi;
365   pc->ops->setfromoptions      = NULL;
366   pc->ops->view                = PCView_VPBJacobi;
367   pc->ops->applyrichardson     = NULL;
368   pc->ops->applysymmetricleft  = NULL;
369   pc->ops->applysymmetricright = NULL;
370   PetscFunctionReturn(PETSC_SUCCESS);
371 }
372