xref: /petsc/src/mat/impls/baij/seq/baijfact13.c (revision d9acb416d05abeed0a33bde3a81aeb2ea0364f6a)
1 /*
2     Factorization code for BAIJ format.
3 */
4 #include <../src/mat/impls/baij/seq/baij.h>
5 #include <petsc/private/kernels/blockinvert.h>
6 
7 /*
8       Version for when blocks are 3 by 3
9 */
10 PetscErrorCode MatLUFactorNumeric_SeqBAIJ_3_inplace(Mat C, Mat A, const MatFactorInfo *info)
11 {
12   Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
13   IS              isrow = b->row, isicol = b->icol;
14   const PetscInt *r, *ic;
15   PetscInt        i, j, n = a->mbs, *bi = b->i, *bj = b->j;
16   PetscInt       *ajtmpold, *ajtmp, nz, row, *ai = a->i, *aj = a->j;
17   PetscInt       *diag_offset = b->diag, idx, *pj;
18   MatScalar      *pv, *v, *rtmp, *pc, *w, *x;
19   MatScalar       p1, p2, p3, p4, m1, m2, m3, m4, m5, m6, m7, m8, m9, x1, x2, x3, x4;
20   MatScalar       p5, p6, p7, p8, p9, x5, x6, x7, x8, x9;
21   MatScalar      *ba = b->a, *aa = a->a;
22   PetscReal       shift = info->shiftamount;
23   PetscBool       allowzeropivot, zeropivotdetected;
24 
25   PetscFunctionBegin;
26   PetscCall(ISGetIndices(isrow, &r));
27   PetscCall(ISGetIndices(isicol, &ic));
28   PetscCall(PetscMalloc1(9 * (n + 1), &rtmp));
29   allowzeropivot = PetscNot(A->erroriffailure);
30 
31   for (i = 0; i < n; i++) {
32     nz    = bi[i + 1] - bi[i];
33     ajtmp = bj + bi[i];
34     for (j = 0; j < nz; j++) {
35       x    = rtmp + 9 * ajtmp[j];
36       x[0] = x[1] = x[2] = x[3] = x[4] = x[5] = x[6] = x[7] = x[8] = 0.0;
37     }
38     /* load in initial (unfactored row) */
39     idx      = r[i];
40     nz       = ai[idx + 1] - ai[idx];
41     ajtmpold = aj + ai[idx];
42     v        = aa + 9 * ai[idx];
43     for (j = 0; j < nz; j++) {
44       x    = rtmp + 9 * ic[ajtmpold[j]];
45       x[0] = v[0];
46       x[1] = v[1];
47       x[2] = v[2];
48       x[3] = v[3];
49       x[4] = v[4];
50       x[5] = v[5];
51       x[6] = v[6];
52       x[7] = v[7];
53       x[8] = v[8];
54       v += 9;
55     }
56     row = *ajtmp++;
57     while (row < i) {
58       pc = rtmp + 9 * row;
59       p1 = pc[0];
60       p2 = pc[1];
61       p3 = pc[2];
62       p4 = pc[3];
63       p5 = pc[4];
64       p6 = pc[5];
65       p7 = pc[6];
66       p8 = pc[7];
67       p9 = pc[8];
68       if (p1 != 0.0 || p2 != 0.0 || p3 != 0.0 || p4 != 0.0 || p5 != 0.0 || p6 != 0.0 || p7 != 0.0 || p8 != 0.0 || p9 != 0.0) {
69         pv    = ba + 9 * diag_offset[row];
70         pj    = bj + diag_offset[row] + 1;
71         x1    = pv[0];
72         x2    = pv[1];
73         x3    = pv[2];
74         x4    = pv[3];
75         x5    = pv[4];
76         x6    = pv[5];
77         x7    = pv[6];
78         x8    = pv[7];
79         x9    = pv[8];
80         pc[0] = m1 = p1 * x1 + p4 * x2 + p7 * x3;
81         pc[1] = m2 = p2 * x1 + p5 * x2 + p8 * x3;
82         pc[2] = m3 = p3 * x1 + p6 * x2 + p9 * x3;
83 
84         pc[3] = m4 = p1 * x4 + p4 * x5 + p7 * x6;
85         pc[4] = m5 = p2 * x4 + p5 * x5 + p8 * x6;
86         pc[5] = m6 = p3 * x4 + p6 * x5 + p9 * x6;
87 
88         pc[6] = m7 = p1 * x7 + p4 * x8 + p7 * x9;
89         pc[7] = m8 = p2 * x7 + p5 * x8 + p8 * x9;
90         pc[8] = m9 = p3 * x7 + p6 * x8 + p9 * x9;
91         nz         = bi[row + 1] - diag_offset[row] - 1;
92         pv += 9;
93         for (j = 0; j < nz; j++) {
94           x1 = pv[0];
95           x2 = pv[1];
96           x3 = pv[2];
97           x4 = pv[3];
98           x5 = pv[4];
99           x6 = pv[5];
100           x7 = pv[6];
101           x8 = pv[7];
102           x9 = pv[8];
103           x  = rtmp + 9 * pj[j];
104           x[0] -= m1 * x1 + m4 * x2 + m7 * x3;
105           x[1] -= m2 * x1 + m5 * x2 + m8 * x3;
106           x[2] -= m3 * x1 + m6 * x2 + m9 * x3;
107 
108           x[3] -= m1 * x4 + m4 * x5 + m7 * x6;
109           x[4] -= m2 * x4 + m5 * x5 + m8 * x6;
110           x[5] -= m3 * x4 + m6 * x5 + m9 * x6;
111 
112           x[6] -= m1 * x7 + m4 * x8 + m7 * x9;
113           x[7] -= m2 * x7 + m5 * x8 + m8 * x9;
114           x[8] -= m3 * x7 + m6 * x8 + m9 * x9;
115           pv += 9;
116         }
117         PetscCall(PetscLogFlops(54.0 * nz + 36.0));
118       }
119       row = *ajtmp++;
120     }
121     /* finished row so stick it into b->a */
122     pv = ba + 9 * bi[i];
123     pj = bj + bi[i];
124     nz = bi[i + 1] - bi[i];
125     for (j = 0; j < nz; j++) {
126       x     = rtmp + 9 * pj[j];
127       pv[0] = x[0];
128       pv[1] = x[1];
129       pv[2] = x[2];
130       pv[3] = x[3];
131       pv[4] = x[4];
132       pv[5] = x[5];
133       pv[6] = x[6];
134       pv[7] = x[7];
135       pv[8] = x[8];
136       pv += 9;
137     }
138     /* invert diagonal block */
139     w = ba + 9 * diag_offset[i];
140     PetscCall(PetscKernel_A_gets_inverse_A_3(w, shift, allowzeropivot, &zeropivotdetected));
141     if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
142   }
143 
144   PetscCall(PetscFree(rtmp));
145   PetscCall(ISRestoreIndices(isicol, &ic));
146   PetscCall(ISRestoreIndices(isrow, &r));
147 
148   C->ops->solve          = MatSolve_SeqBAIJ_3_inplace;
149   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_3_inplace;
150   C->assembled           = PETSC_TRUE;
151 
152   PetscCall(PetscLogFlops(1.333333333333 * 3 * 3 * 3 * b->mbs)); /* from inverting diagonal blocks */
153   PetscFunctionReturn(PETSC_SUCCESS);
154 }
155 
156 /* MatLUFactorNumeric_SeqBAIJ_3 -
157      copied from MatLUFactorNumeric_SeqBAIJ_N_inplace() and manually re-implemented
158        PetscKernel_A_gets_A_times_B()
159        PetscKernel_A_gets_A_minus_B_times_C()
160        PetscKernel_A_gets_inverse_A()
161 */
162 PetscErrorCode MatLUFactorNumeric_SeqBAIJ_3(Mat B, Mat A, const MatFactorInfo *info)
163 {
164   Mat             C = B;
165   Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
166   IS              isrow = b->row, isicol = b->icol;
167   const PetscInt *r, *ic;
168   PetscInt        i, j, k, nz, nzL, row;
169   const PetscInt  n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
170   const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
171   MatScalar      *rtmp, *pc, *mwork, *v, *pv, *aa = a->a;
172   PetscInt        flg;
173   PetscReal       shift = info->shiftamount;
174   PetscBool       allowzeropivot, zeropivotdetected;
175 
176   PetscFunctionBegin;
177   PetscCall(ISGetIndices(isrow, &r));
178   PetscCall(ISGetIndices(isicol, &ic));
179   allowzeropivot = PetscNot(A->erroriffailure);
180 
181   /* generate work space needed by the factorization */
182   PetscCall(PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork));
183   PetscCall(PetscArrayzero(rtmp, bs2 * n));
184 
185   for (i = 0; i < n; i++) {
186     /* zero rtmp */
187     /* L part */
188     nz    = bi[i + 1] - bi[i];
189     bjtmp = bj + bi[i];
190     for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
191 
192     /* U part */
193     nz    = bdiag[i] - bdiag[i + 1];
194     bjtmp = bj + bdiag[i + 1] + 1;
195     for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
196 
197     /* load in initial (unfactored row) */
198     nz    = ai[r[i] + 1] - ai[r[i]];
199     ajtmp = aj + ai[r[i]];
200     v     = aa + bs2 * ai[r[i]];
201     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(rtmp + bs2 * ic[ajtmp[j]], v + bs2 * j, bs2));
202 
203     /* elimination */
204     bjtmp = bj + bi[i];
205     nzL   = bi[i + 1] - bi[i];
206     for (k = 0; k < nzL; k++) {
207       row = bjtmp[k];
208       pc  = rtmp + bs2 * row;
209       for (flg = 0, j = 0; j < bs2; j++) {
210         if (pc[j] != 0.0) {
211           flg = 1;
212           break;
213         }
214       }
215       if (flg) {
216         pv = b->a + bs2 * bdiag[row];
217         /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
218         PetscCall(PetscKernel_A_gets_A_times_B_3(pc, pv, mwork));
219 
220         pj = b->j + bdiag[row + 1] + 1; /* beginning of U(row,:) */
221         pv = b->a + bs2 * (bdiag[row + 1] + 1);
222         nz = bdiag[row] - bdiag[row + 1] - 1; /* num of entries in U(row,:) excluding diag */
223         for (j = 0; j < nz; j++) {
224           /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
225           /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
226           v = rtmp + bs2 * pj[j];
227           PetscCall(PetscKernel_A_gets_A_minus_B_times_C_3(v, pc, pv));
228           pv += bs2;
229         }
230         PetscCall(PetscLogFlops(54.0 * nz + 45)); /* flops = 2*bs^3*nz + 2*bs^3 - bs2) */
231       }
232     }
233 
234     /* finished row so stick it into b->a */
235     /* L part */
236     pv = b->a + bs2 * bi[i];
237     pj = b->j + bi[i];
238     nz = bi[i + 1] - bi[i];
239     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
240 
241     /* Mark diagonal and invert diagonal for simpler triangular solves */
242     pv = b->a + bs2 * bdiag[i];
243     pj = b->j + bdiag[i];
244     PetscCall(PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2));
245     PetscCall(PetscKernel_A_gets_inverse_A_3(pv, shift, allowzeropivot, &zeropivotdetected));
246     if (zeropivotdetected) B->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
247 
248     /* U part */
249     pj = b->j + bdiag[i + 1] + 1;
250     pv = b->a + bs2 * (bdiag[i + 1] + 1);
251     nz = bdiag[i] - bdiag[i + 1] - 1;
252     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
253   }
254 
255   PetscCall(PetscFree2(rtmp, mwork));
256   PetscCall(ISRestoreIndices(isicol, &ic));
257   PetscCall(ISRestoreIndices(isrow, &r));
258 
259   C->ops->solve          = MatSolve_SeqBAIJ_3;
260   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_3;
261   C->assembled           = PETSC_TRUE;
262 
263   PetscCall(PetscLogFlops(1.333333333333 * 3 * 3 * 3 * n)); /* from inverting diagonal blocks */
264   PetscFunctionReturn(PETSC_SUCCESS);
265 }
266 
267 PetscErrorCode MatLUFactorNumeric_SeqBAIJ_3_NaturalOrdering_inplace(Mat C, Mat A, const MatFactorInfo *info)
268 {
269   Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
270   PetscInt     i, j, n = a->mbs, *bi = b->i, *bj = b->j;
271   PetscInt    *ajtmpold, *ajtmp, nz, row;
272   PetscInt    *diag_offset = b->diag, *ai = a->i, *aj = a->j, *pj;
273   MatScalar   *pv, *v, *rtmp, *pc, *w, *x;
274   MatScalar    p1, p2, p3, p4, m1, m2, m3, m4, m5, m6, m7, m8, m9, x1, x2, x3, x4;
275   MatScalar    p5, p6, p7, p8, p9, x5, x6, x7, x8, x9;
276   MatScalar   *ba = b->a, *aa = a->a;
277   PetscReal    shift = info->shiftamount;
278   PetscBool    allowzeropivot, zeropivotdetected;
279 
280   PetscFunctionBegin;
281   PetscCall(PetscMalloc1(9 * (n + 1), &rtmp));
282   allowzeropivot = PetscNot(A->erroriffailure);
283 
284   for (i = 0; i < n; i++) {
285     nz    = bi[i + 1] - bi[i];
286     ajtmp = bj + bi[i];
287     for (j = 0; j < nz; j++) {
288       x    = rtmp + 9 * ajtmp[j];
289       x[0] = x[1] = x[2] = x[3] = x[4] = x[5] = x[6] = x[7] = x[8] = 0.0;
290     }
291     /* load in initial (unfactored row) */
292     nz       = ai[i + 1] - ai[i];
293     ajtmpold = aj + ai[i];
294     v        = aa + 9 * ai[i];
295     for (j = 0; j < nz; j++) {
296       x    = rtmp + 9 * ajtmpold[j];
297       x[0] = v[0];
298       x[1] = v[1];
299       x[2] = v[2];
300       x[3] = v[3];
301       x[4] = v[4];
302       x[5] = v[5];
303       x[6] = v[6];
304       x[7] = v[7];
305       x[8] = v[8];
306       v += 9;
307     }
308     row = *ajtmp++;
309     while (row < i) {
310       pc = rtmp + 9 * row;
311       p1 = pc[0];
312       p2 = pc[1];
313       p3 = pc[2];
314       p4 = pc[3];
315       p5 = pc[4];
316       p6 = pc[5];
317       p7 = pc[6];
318       p8 = pc[7];
319       p9 = pc[8];
320       if (p1 != 0.0 || p2 != 0.0 || p3 != 0.0 || p4 != 0.0 || p5 != 0.0 || p6 != 0.0 || p7 != 0.0 || p8 != 0.0 || p9 != 0.0) {
321         pv    = ba + 9 * diag_offset[row];
322         pj    = bj + diag_offset[row] + 1;
323         x1    = pv[0];
324         x2    = pv[1];
325         x3    = pv[2];
326         x4    = pv[3];
327         x5    = pv[4];
328         x6    = pv[5];
329         x7    = pv[6];
330         x8    = pv[7];
331         x9    = pv[8];
332         pc[0] = m1 = p1 * x1 + p4 * x2 + p7 * x3;
333         pc[1] = m2 = p2 * x1 + p5 * x2 + p8 * x3;
334         pc[2] = m3 = p3 * x1 + p6 * x2 + p9 * x3;
335 
336         pc[3] = m4 = p1 * x4 + p4 * x5 + p7 * x6;
337         pc[4] = m5 = p2 * x4 + p5 * x5 + p8 * x6;
338         pc[5] = m6 = p3 * x4 + p6 * x5 + p9 * x6;
339 
340         pc[6] = m7 = p1 * x7 + p4 * x8 + p7 * x9;
341         pc[7] = m8 = p2 * x7 + p5 * x8 + p8 * x9;
342         pc[8] = m9 = p3 * x7 + p6 * x8 + p9 * x9;
343 
344         nz = bi[row + 1] - diag_offset[row] - 1;
345         pv += 9;
346         for (j = 0; j < nz; j++) {
347           x1 = pv[0];
348           x2 = pv[1];
349           x3 = pv[2];
350           x4 = pv[3];
351           x5 = pv[4];
352           x6 = pv[5];
353           x7 = pv[6];
354           x8 = pv[7];
355           x9 = pv[8];
356           x  = rtmp + 9 * pj[j];
357           x[0] -= m1 * x1 + m4 * x2 + m7 * x3;
358           x[1] -= m2 * x1 + m5 * x2 + m8 * x3;
359           x[2] -= m3 * x1 + m6 * x2 + m9 * x3;
360 
361           x[3] -= m1 * x4 + m4 * x5 + m7 * x6;
362           x[4] -= m2 * x4 + m5 * x5 + m8 * x6;
363           x[5] -= m3 * x4 + m6 * x5 + m9 * x6;
364 
365           x[6] -= m1 * x7 + m4 * x8 + m7 * x9;
366           x[7] -= m2 * x7 + m5 * x8 + m8 * x9;
367           x[8] -= m3 * x7 + m6 * x8 + m9 * x9;
368           pv += 9;
369         }
370         PetscCall(PetscLogFlops(54.0 * nz + 36.0));
371       }
372       row = *ajtmp++;
373     }
374     /* finished row so stick it into b->a */
375     pv = ba + 9 * bi[i];
376     pj = bj + bi[i];
377     nz = bi[i + 1] - bi[i];
378     for (j = 0; j < nz; j++) {
379       x     = rtmp + 9 * pj[j];
380       pv[0] = x[0];
381       pv[1] = x[1];
382       pv[2] = x[2];
383       pv[3] = x[3];
384       pv[4] = x[4];
385       pv[5] = x[5];
386       pv[6] = x[6];
387       pv[7] = x[7];
388       pv[8] = x[8];
389       pv += 9;
390     }
391     /* invert diagonal block */
392     w = ba + 9 * diag_offset[i];
393     PetscCall(PetscKernel_A_gets_inverse_A_3(w, shift, allowzeropivot, &zeropivotdetected));
394     if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
395   }
396 
397   PetscCall(PetscFree(rtmp));
398 
399   C->ops->solve          = MatSolve_SeqBAIJ_3_NaturalOrdering_inplace;
400   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_3_NaturalOrdering_inplace;
401   C->assembled           = PETSC_TRUE;
402 
403   PetscCall(PetscLogFlops(1.333333333333 * 3 * 3 * 3 * b->mbs)); /* from inverting diagonal blocks */
404   PetscFunctionReturn(PETSC_SUCCESS);
405 }
406 
407 /*
408   MatLUFactorNumeric_SeqBAIJ_3_NaturalOrdering -
409     copied from MatLUFactorNumeric_SeqBAIJ_2_NaturalOrdering_inplace()
410 */
411 PetscErrorCode MatLUFactorNumeric_SeqBAIJ_3_NaturalOrdering(Mat B, Mat A, const MatFactorInfo *info)
412 {
413   Mat             C = B;
414   Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
415   PetscInt        i, j, k, nz, nzL, row;
416   const PetscInt  n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
417   const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
418   MatScalar      *rtmp, *pc, *mwork, *v, *pv, *aa = a->a;
419   PetscInt        flg;
420   PetscReal       shift = info->shiftamount;
421   PetscBool       allowzeropivot, zeropivotdetected;
422 
423   PetscFunctionBegin;
424   allowzeropivot = PetscNot(A->erroriffailure);
425 
426   /* generate work space needed by the factorization */
427   PetscCall(PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork));
428   PetscCall(PetscArrayzero(rtmp, bs2 * n));
429 
430   for (i = 0; i < n; i++) {
431     /* zero rtmp */
432     /* L part */
433     nz    = bi[i + 1] - bi[i];
434     bjtmp = bj + bi[i];
435     for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
436 
437     /* U part */
438     nz    = bdiag[i] - bdiag[i + 1];
439     bjtmp = bj + bdiag[i + 1] + 1;
440     for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
441 
442     /* load in initial (unfactored row) */
443     nz    = ai[i + 1] - ai[i];
444     ajtmp = aj + ai[i];
445     v     = aa + bs2 * ai[i];
446     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(rtmp + bs2 * ajtmp[j], v + bs2 * j, bs2));
447 
448     /* elimination */
449     bjtmp = bj + bi[i];
450     nzL   = bi[i + 1] - bi[i];
451     for (k = 0; k < nzL; k++) {
452       row = bjtmp[k];
453       pc  = rtmp + bs2 * row;
454       for (flg = 0, j = 0; j < bs2; j++) {
455         if (pc[j] != 0.0) {
456           flg = 1;
457           break;
458         }
459       }
460       if (flg) {
461         pv = b->a + bs2 * bdiag[row];
462         /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
463         PetscCall(PetscKernel_A_gets_A_times_B_3(pc, pv, mwork));
464 
465         pj = b->j + bdiag[row + 1] + 1; /* beginning of U(row,:) */
466         pv = b->a + bs2 * (bdiag[row + 1] + 1);
467         nz = bdiag[row] - bdiag[row + 1] - 1; /* num of entries in U(row,:) excluding diag */
468         for (j = 0; j < nz; j++) {
469           /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
470           /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
471           v = rtmp + bs2 * pj[j];
472           PetscCall(PetscKernel_A_gets_A_minus_B_times_C_3(v, pc, pv));
473           pv += bs2;
474         }
475         PetscCall(PetscLogFlops(54.0 * nz + 45)); /* flops = 2*bs^3*nz + 2*bs^3 - bs2) */
476       }
477     }
478 
479     /* finished row so stick it into b->a */
480     /* L part */
481     pv = b->a + bs2 * bi[i];
482     pj = b->j + bi[i];
483     nz = bi[i + 1] - bi[i];
484     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
485 
486     /* Mark diagonal and invert diagonal for simpler triangular solves */
487     pv = b->a + bs2 * bdiag[i];
488     pj = b->j + bdiag[i];
489     PetscCall(PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2));
490     PetscCall(PetscKernel_A_gets_inverse_A_3(pv, shift, allowzeropivot, &zeropivotdetected));
491     if (zeropivotdetected) B->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
492 
493     /* U part */
494     pv = b->a + bs2 * (bdiag[i + 1] + 1);
495     pj = b->j + bdiag[i + 1] + 1;
496     nz = bdiag[i] - bdiag[i + 1] - 1;
497     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
498   }
499   PetscCall(PetscFree2(rtmp, mwork));
500 
501   C->ops->solve          = MatSolve_SeqBAIJ_3_NaturalOrdering;
502   C->ops->forwardsolve   = MatForwardSolve_SeqBAIJ_3_NaturalOrdering;
503   C->ops->backwardsolve  = MatBackwardSolve_SeqBAIJ_3_NaturalOrdering;
504   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_3_NaturalOrdering;
505   C->assembled           = PETSC_TRUE;
506 
507   PetscCall(PetscLogFlops(1.333333333333 * 3 * 3 * 3 * n)); /* from inverting diagonal blocks */
508   PetscFunctionReturn(PETSC_SUCCESS);
509 }
510