xref: /petsc/src/mat/impls/baij/seq/baijfact9.c (revision e8c0849ab8fe171bed529bea27238c9b402db591)
1 /*
2     Factorization code for BAIJ format.
3 */
4 #include <../src/mat/impls/baij/seq/baij.h>
5 #include <petsc/private/kernels/blockinvert.h>
6 
/*
      Version for when blocks are 5 by 5

   In-place ILU(0) numeric factorization for SeqBAIJ with bs = 5, applying the
   row permutation/inverse-column-permutation stored in the factor C (b->row,
   b->icol).  All 5x5 block-block products are fully unrolled into scalar
   temporaries p1..p25 (multiplier block), x1..x25 (pivot-row block) and
   m1..m25 (resulting multiplier).  The index pattern pc[0] = p1*x1 + p6*x2 + ...
   shows blocks are stored column-major within each 25-entry chunk.
   Diagonal blocks of already-finished rows are stored INVERTED (see the
   PetscKernel_A_gets_inverse_A_5 call at the bottom of the row loop), so
   multiplying by them forms the L multiplier directly.
*/
PetscErrorCode MatILUFactorNumeric_SeqBAIJ_5_inplace(Mat C, Mat A, const MatFactorInfo *info)
{
  Mat_SeqBAIJ     *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
  IS               isrow = b->row, isicol = b->icol;
  const PetscInt  *r, *ic;
  PetscInt        *bi = b->i, *bj = b->j, *ajtmpold, *ajtmp;
  PetscInt         i, j, n = a->mbs, nz, row, idx, ipvt[5];
  const PetscInt  *diag_offset;
  PetscInt        *ai = a->i, *aj = a->j, *pj;
  MatScalar       *w, *pv, *rtmp, *x, *pc;
  const MatScalar *v, *aa = a->a;
  MatScalar        p1, p2, p3, p4, m1, m2, m3, m4, m5, m6, m7, m8, m9, x1, x2, x3, x4;
  MatScalar        p5, p6, p7, p8, p9, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, x16;
  MatScalar        x17, x18, x19, x20, x21, x22, x23, x24, x25, p10, p11, p12, p13, p14;
  MatScalar        p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25, m10, m11, m12;
  MatScalar        m13, m14, m15, m16, m17, m18, m19, m20, m21, m22, m23, m24, m25;
  MatScalar       *ba    = b->a, work[25];
  PetscReal        shift = info->shiftamount;
  PetscBool        allowzeropivot, zeropivotdetected;

  PetscFunctionBegin;
  /* Since A is C and C is labeled as a factored matrix we need to lie to MatGetDiagonalMarkers_SeqBAIJ() to get it to compute the diagonals */
  A->factortype = MAT_FACTOR_NONE;
  PetscCall(MatGetDiagonalMarkers_SeqBAIJ(A, &diag_offset, NULL));
  A->factortype  = MAT_FACTOR_ILU;
  allowzeropivot = PetscNot(A->erroriffailure);
  PetscCall(ISGetIndices(isrow, &r));
  PetscCall(ISGetIndices(isicol, &ic));
  /* rtmp holds one expanded block row of the factor: 25 scalars per block column */
  PetscCall(PetscMalloc1(25 * (n + 1), &rtmp));

/* Toggle between memzero/memcpy library calls and the hand-unrolled scalar copies below */
#define PETSC_USE_MEMZERO 1
#define PETSC_USE_MEMCPY  1

  for (i = 0; i < n; i++) {
    /* zero the blocks of rtmp that this factor row will touch */
    nz    = bi[i + 1] - bi[i];
    ajtmp = bj + bi[i];
    for (j = 0; j < nz; j++) {
#if defined(PETSC_USE_MEMZERO)
      PetscCall(PetscArrayzero(rtmp + 25 * ajtmp[j], 25));
#else
      x    = rtmp + 25 * ajtmp[j];
      x[0] = x[1] = x[2] = x[3] = x[4] = x[5] = x[6] = x[7] = x[8] = x[9] = 0.0;
      x[10] = x[11] = x[12] = x[13] = x[14] = x[15] = x[16] = x[17] = 0.0;
      x[18] = x[19] = x[20] = x[21] = x[22] = x[23] = x[24] = 0.0;
#endif
    }
    /* load in initial (unfactored row): row r[i] of A, columns mapped through ic[] */
    idx      = r[i];
    nz       = ai[idx + 1] - ai[idx];
    ajtmpold = aj + ai[idx];
    v        = aa + 25 * ai[idx];
    for (j = 0; j < nz; j++) {
#if defined(PETSC_USE_MEMCPY)
      PetscCall(PetscArraycpy(rtmp + 25 * ic[ajtmpold[j]], v, 25));
#else
      x     = rtmp + 25 * ic[ajtmpold[j]];
      x[0]  = v[0];
      x[1]  = v[1];
      x[2]  = v[2];
      x[3]  = v[3];
      x[4]  = v[4];
      x[5]  = v[5];
      x[6]  = v[6];
      x[7]  = v[7];
      x[8]  = v[8];
      x[9]  = v[9];
      x[10] = v[10];
      x[11] = v[11];
      x[12] = v[12];
      x[13] = v[13];
      x[14] = v[14];
      x[15] = v[15];
      x[16] = v[16];
      x[17] = v[17];
      x[18] = v[18];
      x[19] = v[19];
      x[20] = v[20];
      x[21] = v[21];
      x[22] = v[22];
      x[23] = v[23];
      x[24] = v[24];
#endif
      v += 25;
    }
    /* eliminate against every prior row that appears in this row's sparsity pattern */
    row = *ajtmp++;
    while (row < i) {
      pc  = rtmp + 25 * row; /* candidate L(i,row) block */
      p1  = pc[0];
      p2  = pc[1];
      p3  = pc[2];
      p4  = pc[3];
      p5  = pc[4];
      p6  = pc[5];
      p7  = pc[6];
      p8  = pc[7];
      p9  = pc[8];
      p10 = pc[9];
      p11 = pc[10];
      p12 = pc[11];
      p13 = pc[12];
      p14 = pc[13];
      p15 = pc[14];
      p16 = pc[15];
      p17 = pc[16];
      p18 = pc[17];
      p19 = pc[18];
      p20 = pc[19];
      p21 = pc[20];
      p22 = pc[21];
      p23 = pc[22];
      p24 = pc[23];
      p25 = pc[24];
      /* skip the update entirely when the whole 5x5 block is zero */
      if (p1 != 0.0 || p2 != 0.0 || p3 != 0.0 || p4 != 0.0 || p5 != 0.0 || p6 != 0.0 || p7 != 0.0 || p8 != 0.0 || p9 != 0.0 || p10 != 0.0 || p11 != 0.0 || p12 != 0.0 || p13 != 0.0 || p14 != 0.0 || p15 != 0.0 || p16 != 0.0 || p17 != 0.0 || p18 != 0.0 || p19 != 0.0 || p20 != 0.0 || p21 != 0.0 || p22 != 0.0 || p23 != 0.0 || p24 != 0.0 || p25 != 0.0) {
        pv    = ba + 25 * diag_offset[row]; /* inverted diagonal block of the pivot row */
        pj    = bj + diag_offset[row] + 1;  /* column indices of U(row,:) past the diagonal */
        x1    = pv[0];
        x2    = pv[1];
        x3    = pv[2];
        x4    = pv[3];
        x5    = pv[4];
        x6    = pv[5];
        x7    = pv[6];
        x8    = pv[7];
        x9    = pv[8];
        x10   = pv[9];
        x11   = pv[10];
        x12   = pv[11];
        x13   = pv[12];
        x14   = pv[13];
        x15   = pv[14];
        x16   = pv[15];
        x17   = pv[16];
        x18   = pv[17];
        x19   = pv[18];
        x20   = pv[19];
        x21   = pv[20];
        x22   = pv[21];
        x23   = pv[22];
        x24   = pv[23];
        x25   = pv[24];
        /* multiplier block m = pc * inv(diag(row)); stored back into pc (the L factor) */
        pc[0] = m1 = p1 * x1 + p6 * x2 + p11 * x3 + p16 * x4 + p21 * x5;
        pc[1] = m2 = p2 * x1 + p7 * x2 + p12 * x3 + p17 * x4 + p22 * x5;
        pc[2] = m3 = p3 * x1 + p8 * x2 + p13 * x3 + p18 * x4 + p23 * x5;
        pc[3] = m4 = p4 * x1 + p9 * x2 + p14 * x3 + p19 * x4 + p24 * x5;
        pc[4] = m5 = p5 * x1 + p10 * x2 + p15 * x3 + p20 * x4 + p25 * x5;

        pc[5] = m6 = p1 * x6 + p6 * x7 + p11 * x8 + p16 * x9 + p21 * x10;
        pc[6] = m7 = p2 * x6 + p7 * x7 + p12 * x8 + p17 * x9 + p22 * x10;
        pc[7] = m8 = p3 * x6 + p8 * x7 + p13 * x8 + p18 * x9 + p23 * x10;
        pc[8] = m9 = p4 * x6 + p9 * x7 + p14 * x8 + p19 * x9 + p24 * x10;
        pc[9] = m10 = p5 * x6 + p10 * x7 + p15 * x8 + p20 * x9 + p25 * x10;

        pc[10] = m11 = p1 * x11 + p6 * x12 + p11 * x13 + p16 * x14 + p21 * x15;
        pc[11] = m12 = p2 * x11 + p7 * x12 + p12 * x13 + p17 * x14 + p22 * x15;
        pc[12] = m13 = p3 * x11 + p8 * x12 + p13 * x13 + p18 * x14 + p23 * x15;
        pc[13] = m14 = p4 * x11 + p9 * x12 + p14 * x13 + p19 * x14 + p24 * x15;
        pc[14] = m15 = p5 * x11 + p10 * x12 + p15 * x13 + p20 * x14 + p25 * x15;

        pc[15] = m16 = p1 * x16 + p6 * x17 + p11 * x18 + p16 * x19 + p21 * x20;
        pc[16] = m17 = p2 * x16 + p7 * x17 + p12 * x18 + p17 * x19 + p22 * x20;
        pc[17] = m18 = p3 * x16 + p8 * x17 + p13 * x18 + p18 * x19 + p23 * x20;
        pc[18] = m19 = p4 * x16 + p9 * x17 + p14 * x18 + p19 * x19 + p24 * x20;
        pc[19] = m20 = p5 * x16 + p10 * x17 + p15 * x18 + p20 * x19 + p25 * x20;

        pc[20] = m21 = p1 * x21 + p6 * x22 + p11 * x23 + p16 * x24 + p21 * x25;
        pc[21] = m22 = p2 * x21 + p7 * x22 + p12 * x23 + p17 * x24 + p22 * x25;
        pc[22] = m23 = p3 * x21 + p8 * x22 + p13 * x23 + p18 * x24 + p23 * x25;
        pc[23] = m24 = p4 * x21 + p9 * x22 + p14 * x23 + p19 * x24 + p24 * x25;
        pc[24] = m25 = p5 * x21 + p10 * x22 + p15 * x23 + p20 * x24 + p25 * x25;

        /* rank-one block update: rtmp(i, pj[j]) -= m * U(row, pj[j]) for each U block */
        nz = bi[row + 1] - diag_offset[row] - 1;
        pv += 25;
        for (j = 0; j < nz; j++) {
          x1  = pv[0];
          x2  = pv[1];
          x3  = pv[2];
          x4  = pv[3];
          x5  = pv[4];
          x6  = pv[5];
          x7  = pv[6];
          x8  = pv[7];
          x9  = pv[8];
          x10 = pv[9];
          x11 = pv[10];
          x12 = pv[11];
          x13 = pv[12];
          x14 = pv[13];
          x15 = pv[14];
          x16 = pv[15];
          x17 = pv[16];
          x18 = pv[17];
          x19 = pv[18];
          x20 = pv[19];
          x21 = pv[20];
          x22 = pv[21];
          x23 = pv[22];
          x24 = pv[23];
          x25 = pv[24];
          x   = rtmp + 25 * pj[j];
          x[0] -= m1 * x1 + m6 * x2 + m11 * x3 + m16 * x4 + m21 * x5;
          x[1] -= m2 * x1 + m7 * x2 + m12 * x3 + m17 * x4 + m22 * x5;
          x[2] -= m3 * x1 + m8 * x2 + m13 * x3 + m18 * x4 + m23 * x5;
          x[3] -= m4 * x1 + m9 * x2 + m14 * x3 + m19 * x4 + m24 * x5;
          x[4] -= m5 * x1 + m10 * x2 + m15 * x3 + m20 * x4 + m25 * x5;

          x[5] -= m1 * x6 + m6 * x7 + m11 * x8 + m16 * x9 + m21 * x10;
          x[6] -= m2 * x6 + m7 * x7 + m12 * x8 + m17 * x9 + m22 * x10;
          x[7] -= m3 * x6 + m8 * x7 + m13 * x8 + m18 * x9 + m23 * x10;
          x[8] -= m4 * x6 + m9 * x7 + m14 * x8 + m19 * x9 + m24 * x10;
          x[9] -= m5 * x6 + m10 * x7 + m15 * x8 + m20 * x9 + m25 * x10;

          x[10] -= m1 * x11 + m6 * x12 + m11 * x13 + m16 * x14 + m21 * x15;
          x[11] -= m2 * x11 + m7 * x12 + m12 * x13 + m17 * x14 + m22 * x15;
          x[12] -= m3 * x11 + m8 * x12 + m13 * x13 + m18 * x14 + m23 * x15;
          x[13] -= m4 * x11 + m9 * x12 + m14 * x13 + m19 * x14 + m24 * x15;
          x[14] -= m5 * x11 + m10 * x12 + m15 * x13 + m20 * x14 + m25 * x15;

          x[15] -= m1 * x16 + m6 * x17 + m11 * x18 + m16 * x19 + m21 * x20;
          x[16] -= m2 * x16 + m7 * x17 + m12 * x18 + m17 * x19 + m22 * x20;
          x[17] -= m3 * x16 + m8 * x17 + m13 * x18 + m18 * x19 + m23 * x20;
          x[18] -= m4 * x16 + m9 * x17 + m14 * x18 + m19 * x19 + m24 * x20;
          x[19] -= m5 * x16 + m10 * x17 + m15 * x18 + m20 * x19 + m25 * x20;

          x[20] -= m1 * x21 + m6 * x22 + m11 * x23 + m16 * x24 + m21 * x25;
          x[21] -= m2 * x21 + m7 * x22 + m12 * x23 + m17 * x24 + m22 * x25;
          x[22] -= m3 * x21 + m8 * x22 + m13 * x23 + m18 * x24 + m23 * x25;
          x[23] -= m4 * x21 + m9 * x22 + m14 * x23 + m19 * x24 + m24 * x25;
          x[24] -= m5 * x21 + m10 * x22 + m15 * x23 + m20 * x24 + m25 * x25;

          pv += 25;
        }
        PetscCall(PetscLogFlops(250.0 * nz + 225.0));
      }
      row = *ajtmp++;
    }
    /* finished row so stick it into b->a */
    pv = ba + 25 * bi[i];
    pj = bj + bi[i];
    nz = bi[i + 1] - bi[i];
    for (j = 0; j < nz; j++) {
#if defined(PETSC_USE_MEMCPY)
      PetscCall(PetscArraycpy(pv, rtmp + 25 * pj[j], 25));
#else
      x      = rtmp + 25 * pj[j];
      pv[0]  = x[0];
      pv[1]  = x[1];
      pv[2]  = x[2];
      pv[3]  = x[3];
      pv[4]  = x[4];
      pv[5]  = x[5];
      pv[6]  = x[6];
      pv[7]  = x[7];
      pv[8]  = x[8];
      pv[9]  = x[9];
      pv[10] = x[10];
      pv[11] = x[11];
      pv[12] = x[12];
      pv[13] = x[13];
      pv[14] = x[14];
      pv[15] = x[15];
      pv[16] = x[16];
      pv[17] = x[17];
      pv[18] = x[18];
      pv[19] = x[19];
      pv[20] = x[20];
      pv[21] = x[21];
      pv[22] = x[22];
      pv[23] = x[23];
      pv[24] = x[24];
#endif
      pv += 25;
    }
    /* invert diagonal block (in place) so later rows and the solves can multiply by it */
    w = ba + 25 * diag_offset[i];
    PetscCall(PetscKernel_A_gets_inverse_A_5(w, ipvt, work, shift, allowzeropivot, &zeropivotdetected));
    if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
  }

  PetscCall(PetscFree(rtmp));
  PetscCall(ISRestoreIndices(isicol, &ic));
  PetscCall(ISRestoreIndices(isrow, &r));

  C->ops->solve          = MatSolve_SeqBAIJ_5_inplace;
  C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_5_inplace;
  C->assembled           = PETSC_TRUE;

  PetscCall(PetscLogFlops(1.333333333333 * 5 * 5 * 5 * b->mbs)); /* from inverting diagonal blocks */
  PetscFunctionReturn(PETSC_SUCCESS);
}
299 
/* MatLUFactorNumeric_SeqBAIJ_5 -
     copied from MatLUFactorNumeric_SeqBAIJ_N_inplace() and manually re-implemented
       PetscKernel_A_gets_A_times_B()
       PetscKernel_A_gets_A_minus_B_times_C()
       PetscKernel_A_gets_inverse_A()

   Out-of-place LU numeric factorization for bs = 5 with row/column permutations
   (b->row, b->icol).  The factor B uses the split L/U storage where b->diag
   (bdiag) indexes diagonal blocks and U rows are stored between bdiag[i+1]+1
   and bdiag[i]; diagonal blocks are stored inverted.
*/
PetscErrorCode MatLUFactorNumeric_SeqBAIJ_5(Mat B, Mat A, const MatFactorInfo *info)
{
  Mat             C = B;
  Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
  IS              isrow = b->row, isicol = b->icol;
  const PetscInt *r, *ic;
  PetscInt        i, j, k, nz, nzL, row;
  const PetscInt  n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
  const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
  MatScalar      *rtmp, *pc, *mwork, *v, *pv, *aa = a->a, work[25];
  PetscInt        flg, ipvt[5];
  PetscReal       shift = info->shiftamount;
  PetscBool       allowzeropivot, zeropivotdetected;

  PetscFunctionBegin;
  allowzeropivot = PetscNot(A->erroriffailure);
  PetscCall(ISGetIndices(isrow, &r));
  PetscCall(ISGetIndices(isicol, &ic));

  /* generate work space needed by the factorization */
  PetscCall(PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork));
  PetscCall(PetscArrayzero(rtmp, bs2 * n));

  for (i = 0; i < n; i++) {
    /* zero rtmp */
    /* L part */
    nz    = bi[i + 1] - bi[i];
    bjtmp = bj + bi[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));

    /* U part */
    nz    = bdiag[i] - bdiag[i + 1];
    bjtmp = bj + bdiag[i + 1] + 1;
    for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));

    /* load in initial (unfactored row): row r[i] of A with columns mapped through ic[] */
    nz    = ai[r[i] + 1] - ai[r[i]];
    ajtmp = aj + ai[r[i]];
    v     = aa + bs2 * ai[r[i]];
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(rtmp + bs2 * ic[ajtmp[j]], v + bs2 * j, bs2));

    /* elimination */
    bjtmp = bj + bi[i];
    nzL   = bi[i + 1] - bi[i];
    for (k = 0; k < nzL; k++) {
      row = bjtmp[k];
      pc  = rtmp + bs2 * row;
      /* scan the candidate multiplier block; skip the update if it is entirely zero */
      for (flg = 0, j = 0; j < bs2; j++) {
        if (pc[j] != 0.0) {
          flg = 1;
          break;
        }
      }
      if (flg) {
        pv = b->a + bs2 * bdiag[row]; /* inverted diagonal block of pivot row */
        /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
        PetscCall(PetscKernel_A_gets_A_times_B_5(pc, pv, mwork));

        pj = b->j + bdiag[row + 1] + 1; /* beginning of U(row,:) */
        pv = b->a + bs2 * (bdiag[row + 1] + 1);
        nz = bdiag[row] - bdiag[row + 1] - 1; /* num of entries in U(row,:), excluding diag */
        for (j = 0; j < nz; j++) {
          /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
          /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
          v = rtmp + bs2 * pj[j];
          PetscCall(PetscKernel_A_gets_A_minus_B_times_C_5(v, pc, pv));
          pv += bs2;
        }
        PetscCall(PetscLogFlops(250.0 * nz + 225)); /* flops = 2*bs^3*nz + 2*bs^3 - bs2 */
      }
    }

    /* finished row so stick it into b->a */
    /* L part */
    pv = b->a + bs2 * bi[i];
    pj = b->j + bi[i];
    nz = bi[i + 1] - bi[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));

    /* Mark diagonal and invert diagonal for simpler triangular solves */
    pv = b->a + bs2 * bdiag[i];
    pj = b->j + bdiag[i];
    PetscCall(PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2));
    PetscCall(PetscKernel_A_gets_inverse_A_5(pv, ipvt, work, shift, allowzeropivot, &zeropivotdetected));
    if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;

    /* U part */
    pv = b->a + bs2 * (bdiag[i + 1] + 1);
    pj = b->j + bdiag[i + 1] + 1;
    nz = bdiag[i] - bdiag[i + 1] - 1;
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
  }

  PetscCall(PetscFree2(rtmp, mwork));
  PetscCall(ISRestoreIndices(isicol, &ic));
  PetscCall(ISRestoreIndices(isrow, &r));

  C->ops->solve          = MatSolve_SeqBAIJ_5;
  C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_5;
  C->assembled           = PETSC_TRUE;

  PetscCall(PetscLogFlops(1.333333333333 * 5 * 5 * 5 * n)); /* from inverting diagonal blocks */
  PetscFunctionReturn(PETSC_SUCCESS);
}
411 
/*
      Version for when blocks are 5 by 5 Using natural ordering

   Identical algorithm to MatILUFactorNumeric_SeqBAIJ_5_inplace but with the
   identity permutation, so the r[]/ic[] index translations disappear.  The
   5x5 block operations are fully unrolled; the pc[0] = p1*x1 + p6*x2 + ...
   pattern shows column-major storage within each 25-entry block.  Diagonal
   blocks are inverted in place as each row finishes.
*/
PetscErrorCode MatILUFactorNumeric_SeqBAIJ_5_NaturalOrdering_inplace(Mat C, Mat A, const MatFactorInfo *info)
{
  Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
  PetscInt     i, j, n = a->mbs, *bi = b->i, *bj = b->j, ipvt[5];
  PetscInt    *ajtmpold, *ajtmp, nz, row;
  PetscInt    *diag_offset = b->diag, *ai = a->i, *aj = a->j, *pj;
  MatScalar   *pv, *v, *rtmp, *pc, *w, *x;
  MatScalar    x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15;
  MatScalar    x16, x17, x18, x19, x20, x21, x22, x23, x24, x25;
  MatScalar    p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15;
  MatScalar    p16, p17, p18, p19, p20, p21, p22, p23, p24, p25;
  MatScalar    m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15;
  MatScalar    m16, m17, m18, m19, m20, m21, m22, m23, m24, m25;
  MatScalar   *ba = b->a, *aa = a->a, work[25];
  PetscReal    shift = info->shiftamount;
  PetscBool    allowzeropivot, zeropivotdetected;

  PetscFunctionBegin;
  allowzeropivot = PetscNot(A->erroriffailure);
  /* rtmp holds one expanded block row: 25 scalars per block column */
  PetscCall(PetscMalloc1(25 * (n + 1), &rtmp));
  for (i = 0; i < n; i++) {
    /* zero the blocks of rtmp that this factor row will touch */
    nz    = bi[i + 1] - bi[i];
    ajtmp = bj + bi[i];
    for (j = 0; j < nz; j++) {
      x    = rtmp + 25 * ajtmp[j];
      x[0] = x[1] = x[2] = x[3] = x[4] = x[5] = x[6] = x[7] = x[8] = x[9] = 0.0;
      x[10] = x[11] = x[12] = x[13] = x[14] = x[15] = 0.0;
      x[16] = x[17] = x[18] = x[19] = x[20] = x[21] = x[22] = x[23] = x[24] = 0.0;
    }
    /* load in initial (unfactored row) */
    nz       = ai[i + 1] - ai[i];
    ajtmpold = aj + ai[i];
    v        = aa + 25 * ai[i];
    for (j = 0; j < nz; j++) {
      x     = rtmp + 25 * ajtmpold[j];
      x[0]  = v[0];
      x[1]  = v[1];
      x[2]  = v[2];
      x[3]  = v[3];
      x[4]  = v[4];
      x[5]  = v[5];
      x[6]  = v[6];
      x[7]  = v[7];
      x[8]  = v[8];
      x[9]  = v[9];
      x[10] = v[10];
      x[11] = v[11];
      x[12] = v[12];
      x[13] = v[13];
      x[14] = v[14];
      x[15] = v[15];
      x[16] = v[16];
      x[17] = v[17];
      x[18] = v[18];
      x[19] = v[19];
      x[20] = v[20];
      x[21] = v[21];
      x[22] = v[22];
      x[23] = v[23];
      x[24] = v[24];
      v += 25;
    }
    /* eliminate against every prior row in this row's sparsity pattern */
    row = *ajtmp++;
    while (row < i) {
      pc  = rtmp + 25 * row; /* candidate L(i,row) block */
      p1  = pc[0];
      p2  = pc[1];
      p3  = pc[2];
      p4  = pc[3];
      p5  = pc[4];
      p6  = pc[5];
      p7  = pc[6];
      p8  = pc[7];
      p9  = pc[8];
      p10 = pc[9];
      p11 = pc[10];
      p12 = pc[11];
      p13 = pc[12];
      p14 = pc[13];
      p15 = pc[14];
      p16 = pc[15];
      p17 = pc[16];
      p18 = pc[17];
      p19 = pc[18];
      p20 = pc[19];
      p21 = pc[20];
      p22 = pc[21];
      p23 = pc[22];
      p24 = pc[23];
      p25 = pc[24];
      /* skip the update entirely when the whole 5x5 block is zero */
      if (p1 != 0.0 || p2 != 0.0 || p3 != 0.0 || p4 != 0.0 || p5 != 0.0 || p6 != 0.0 || p7 != 0.0 || p8 != 0.0 || p9 != 0.0 || p10 != 0.0 || p11 != 0.0 || p12 != 0.0 || p13 != 0.0 || p14 != 0.0 || p15 != 0.0 || p16 != 0.0 || p17 != 0.0 || p18 != 0.0 || p19 != 0.0 || p20 != 0.0 || p21 != 0.0 || p22 != 0.0 || p23 != 0.0 || p24 != 0.0 || p25 != 0.0) {
        pv    = ba + 25 * diag_offset[row]; /* inverted diagonal block of the pivot row */
        pj    = bj + diag_offset[row] + 1;  /* column indices of U(row,:) past the diagonal */
        x1    = pv[0];
        x2    = pv[1];
        x3    = pv[2];
        x4    = pv[3];
        x5    = pv[4];
        x6    = pv[5];
        x7    = pv[6];
        x8    = pv[7];
        x9    = pv[8];
        x10   = pv[9];
        x11   = pv[10];
        x12   = pv[11];
        x13   = pv[12];
        x14   = pv[13];
        x15   = pv[14];
        x16   = pv[15];
        x17   = pv[16];
        x18   = pv[17];
        x19   = pv[18];
        x20   = pv[19];
        x21   = pv[20];
        x22   = pv[21];
        x23   = pv[22];
        x24   = pv[23];
        x25   = pv[24];
        /* multiplier block m = pc * inv(diag(row)); stored back into pc (the L factor) */
        pc[0] = m1 = p1 * x1 + p6 * x2 + p11 * x3 + p16 * x4 + p21 * x5;
        pc[1] = m2 = p2 * x1 + p7 * x2 + p12 * x3 + p17 * x4 + p22 * x5;
        pc[2] = m3 = p3 * x1 + p8 * x2 + p13 * x3 + p18 * x4 + p23 * x5;
        pc[3] = m4 = p4 * x1 + p9 * x2 + p14 * x3 + p19 * x4 + p24 * x5;
        pc[4] = m5 = p5 * x1 + p10 * x2 + p15 * x3 + p20 * x4 + p25 * x5;

        pc[5] = m6 = p1 * x6 + p6 * x7 + p11 * x8 + p16 * x9 + p21 * x10;
        pc[6] = m7 = p2 * x6 + p7 * x7 + p12 * x8 + p17 * x9 + p22 * x10;
        pc[7] = m8 = p3 * x6 + p8 * x7 + p13 * x8 + p18 * x9 + p23 * x10;
        pc[8] = m9 = p4 * x6 + p9 * x7 + p14 * x8 + p19 * x9 + p24 * x10;
        pc[9] = m10 = p5 * x6 + p10 * x7 + p15 * x8 + p20 * x9 + p25 * x10;

        pc[10] = m11 = p1 * x11 + p6 * x12 + p11 * x13 + p16 * x14 + p21 * x15;
        pc[11] = m12 = p2 * x11 + p7 * x12 + p12 * x13 + p17 * x14 + p22 * x15;
        pc[12] = m13 = p3 * x11 + p8 * x12 + p13 * x13 + p18 * x14 + p23 * x15;
        pc[13] = m14 = p4 * x11 + p9 * x12 + p14 * x13 + p19 * x14 + p24 * x15;
        pc[14] = m15 = p5 * x11 + p10 * x12 + p15 * x13 + p20 * x14 + p25 * x15;

        pc[15] = m16 = p1 * x16 + p6 * x17 + p11 * x18 + p16 * x19 + p21 * x20;
        pc[16] = m17 = p2 * x16 + p7 * x17 + p12 * x18 + p17 * x19 + p22 * x20;
        pc[17] = m18 = p3 * x16 + p8 * x17 + p13 * x18 + p18 * x19 + p23 * x20;
        pc[18] = m19 = p4 * x16 + p9 * x17 + p14 * x18 + p19 * x19 + p24 * x20;
        pc[19] = m20 = p5 * x16 + p10 * x17 + p15 * x18 + p20 * x19 + p25 * x20;

        pc[20] = m21 = p1 * x21 + p6 * x22 + p11 * x23 + p16 * x24 + p21 * x25;
        pc[21] = m22 = p2 * x21 + p7 * x22 + p12 * x23 + p17 * x24 + p22 * x25;
        pc[22] = m23 = p3 * x21 + p8 * x22 + p13 * x23 + p18 * x24 + p23 * x25;
        pc[23] = m24 = p4 * x21 + p9 * x22 + p14 * x23 + p19 * x24 + p24 * x25;
        pc[24] = m25 = p5 * x21 + p10 * x22 + p15 * x23 + p20 * x24 + p25 * x25;

        /* block update: rtmp(i, pj[j]) -= m * U(row, pj[j]) for each U block of row */
        nz = bi[row + 1] - diag_offset[row] - 1;
        pv += 25;
        for (j = 0; j < nz; j++) {
          x1  = pv[0];
          x2  = pv[1];
          x3  = pv[2];
          x4  = pv[3];
          x5  = pv[4];
          x6  = pv[5];
          x7  = pv[6];
          x8  = pv[7];
          x9  = pv[8];
          x10 = pv[9];
          x11 = pv[10];
          x12 = pv[11];
          x13 = pv[12];
          x14 = pv[13];
          x15 = pv[14];
          x16 = pv[15];
          x17 = pv[16];
          x18 = pv[17];
          x19 = pv[18];
          x20 = pv[19];
          x21 = pv[20];
          x22 = pv[21];
          x23 = pv[22];
          x24 = pv[23];
          x25 = pv[24];
          x   = rtmp + 25 * pj[j];
          x[0] -= m1 * x1 + m6 * x2 + m11 * x3 + m16 * x4 + m21 * x5;
          x[1] -= m2 * x1 + m7 * x2 + m12 * x3 + m17 * x4 + m22 * x5;
          x[2] -= m3 * x1 + m8 * x2 + m13 * x3 + m18 * x4 + m23 * x5;
          x[3] -= m4 * x1 + m9 * x2 + m14 * x3 + m19 * x4 + m24 * x5;
          x[4] -= m5 * x1 + m10 * x2 + m15 * x3 + m20 * x4 + m25 * x5;

          x[5] -= m1 * x6 + m6 * x7 + m11 * x8 + m16 * x9 + m21 * x10;
          x[6] -= m2 * x6 + m7 * x7 + m12 * x8 + m17 * x9 + m22 * x10;
          x[7] -= m3 * x6 + m8 * x7 + m13 * x8 + m18 * x9 + m23 * x10;
          x[8] -= m4 * x6 + m9 * x7 + m14 * x8 + m19 * x9 + m24 * x10;
          x[9] -= m5 * x6 + m10 * x7 + m15 * x8 + m20 * x9 + m25 * x10;

          x[10] -= m1 * x11 + m6 * x12 + m11 * x13 + m16 * x14 + m21 * x15;
          x[11] -= m2 * x11 + m7 * x12 + m12 * x13 + m17 * x14 + m22 * x15;
          x[12] -= m3 * x11 + m8 * x12 + m13 * x13 + m18 * x14 + m23 * x15;
          x[13] -= m4 * x11 + m9 * x12 + m14 * x13 + m19 * x14 + m24 * x15;
          x[14] -= m5 * x11 + m10 * x12 + m15 * x13 + m20 * x14 + m25 * x15;

          x[15] -= m1 * x16 + m6 * x17 + m11 * x18 + m16 * x19 + m21 * x20;
          x[16] -= m2 * x16 + m7 * x17 + m12 * x18 + m17 * x19 + m22 * x20;
          x[17] -= m3 * x16 + m8 * x17 + m13 * x18 + m18 * x19 + m23 * x20;
          x[18] -= m4 * x16 + m9 * x17 + m14 * x18 + m19 * x19 + m24 * x20;
          x[19] -= m5 * x16 + m10 * x17 + m15 * x18 + m20 * x19 + m25 * x20;

          x[20] -= m1 * x21 + m6 * x22 + m11 * x23 + m16 * x24 + m21 * x25;
          x[21] -= m2 * x21 + m7 * x22 + m12 * x23 + m17 * x24 + m22 * x25;
          x[22] -= m3 * x21 + m8 * x22 + m13 * x23 + m18 * x24 + m23 * x25;
          x[23] -= m4 * x21 + m9 * x22 + m14 * x23 + m19 * x24 + m24 * x25;
          x[24] -= m5 * x21 + m10 * x22 + m15 * x23 + m20 * x24 + m25 * x25;
          pv += 25;
        }
        PetscCall(PetscLogFlops(250.0 * nz + 225.0));
      }
      row = *ajtmp++;
    }
    /* finished row so stick it into b->a */
    pv = ba + 25 * bi[i];
    pj = bj + bi[i];
    nz = bi[i + 1] - bi[i];
    for (j = 0; j < nz; j++) {
      x      = rtmp + 25 * pj[j];
      pv[0]  = x[0];
      pv[1]  = x[1];
      pv[2]  = x[2];
      pv[3]  = x[3];
      pv[4]  = x[4];
      pv[5]  = x[5];
      pv[6]  = x[6];
      pv[7]  = x[7];
      pv[8]  = x[8];
      pv[9]  = x[9];
      pv[10] = x[10];
      pv[11] = x[11];
      pv[12] = x[12];
      pv[13] = x[13];
      pv[14] = x[14];
      pv[15] = x[15];
      pv[16] = x[16];
      pv[17] = x[17];
      pv[18] = x[18];
      pv[19] = x[19];
      pv[20] = x[20];
      pv[21] = x[21];
      pv[22] = x[22];
      pv[23] = x[23];
      pv[24] = x[24];
      pv += 25;
    }
    /* invert diagonal block (in place) so later rows and the solves can multiply by it */
    w = ba + 25 * diag_offset[i];
    PetscCall(PetscKernel_A_gets_inverse_A_5(w, ipvt, work, shift, allowzeropivot, &zeropivotdetected));
    if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
  }

  PetscCall(PetscFree(rtmp));

  C->ops->solve          = MatSolve_SeqBAIJ_5_NaturalOrdering_inplace;
  C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_5_NaturalOrdering_inplace;
  C->assembled           = PETSC_TRUE;

  PetscCall(PetscLogFlops(1.333333333333 * 5 * 5 * 5 * b->mbs)); /* from inverting diagonal blocks */
  PetscFunctionReturn(PETSC_SUCCESS);
}
675 
/*
   MatLUFactorNumeric_SeqBAIJ_5_NaturalOrdering - same algorithm as
   MatLUFactorNumeric_SeqBAIJ_5 but with the identity permutation, so the
   r[]/ic[] translations are gone.  Uses the split L/U factor layout where
   b->diag (bdiag) indexes diagonal blocks, U(i,:) lives between
   bdiag[i+1]+1 and bdiag[i], and diagonal blocks are stored inverted.
*/
PetscErrorCode MatLUFactorNumeric_SeqBAIJ_5_NaturalOrdering(Mat B, Mat A, const MatFactorInfo *info)
{
  Mat             C = B;
  Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
  PetscInt        i, j, k, nz, nzL, row;
  const PetscInt  n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
  const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
  MatScalar      *rtmp, *pc, *mwork, *v, *vv, *pv, *aa = a->a, work[25];
  PetscInt        flg, ipvt[5];
  PetscReal       shift = info->shiftamount;
  PetscBool       allowzeropivot, zeropivotdetected;

  PetscFunctionBegin;
  allowzeropivot = PetscNot(A->erroriffailure);

  /* generate work space needed by the factorization */
  PetscCall(PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork));
  PetscCall(PetscArrayzero(rtmp, bs2 * n));

  for (i = 0; i < n; i++) {
    /* zero rtmp */
    /* L part */
    nz    = bi[i + 1] - bi[i];
    bjtmp = bj + bi[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));

    /* U part */
    nz    = bdiag[i] - bdiag[i + 1];
    bjtmp = bj + bdiag[i + 1] + 1;
    for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));

    /* load in initial (unfactored row) */
    nz    = ai[i + 1] - ai[i];
    ajtmp = aj + ai[i];
    v     = aa + bs2 * ai[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(rtmp + bs2 * ajtmp[j], v + bs2 * j, bs2));

    /* elimination */
    bjtmp = bj + bi[i];
    nzL   = bi[i + 1] - bi[i];
    for (k = 0; k < nzL; k++) {
      row = bjtmp[k];
      pc  = rtmp + bs2 * row;
      /* scan the candidate multiplier block; skip the update if it is entirely zero */
      for (flg = 0, j = 0; j < bs2; j++) {
        if (pc[j] != 0.0) {
          flg = 1;
          break;
        }
      }
      if (flg) {
        pv = b->a + bs2 * bdiag[row]; /* inverted diagonal block of pivot row */
        /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
        PetscCall(PetscKernel_A_gets_A_times_B_5(pc, pv, mwork));

        pj = b->j + bdiag[row + 1] + 1; /* beginning of U(row,:) */
        pv = b->a + bs2 * (bdiag[row + 1] + 1);
        nz = bdiag[row] - bdiag[row + 1] - 1; /* num of entries in U(row,:), excluding diag */
        for (j = 0; j < nz; j++) {
          /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
          /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
          vv = rtmp + bs2 * pj[j];
          PetscCall(PetscKernel_A_gets_A_minus_B_times_C_5(vv, pc, pv));
          pv += bs2;
        }
        PetscCall(PetscLogFlops(250.0 * nz + 225)); /* flops = 2*bs^3*nz + 2*bs^3 - bs2 */
      }
    }

    /* finished row so stick it into b->a */
    /* L part */
    pv = b->a + bs2 * bi[i];
    pj = b->j + bi[i];
    nz = bi[i + 1] - bi[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));

    /* Mark diagonal and invert diagonal for simpler triangular solves */
    pv = b->a + bs2 * bdiag[i];
    pj = b->j + bdiag[i];
    PetscCall(PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2));
    PetscCall(PetscKernel_A_gets_inverse_A_5(pv, ipvt, work, shift, allowzeropivot, &zeropivotdetected));
    if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;

    /* U part */
    pv = b->a + bs2 * (bdiag[i + 1] + 1);
    pj = b->j + bdiag[i + 1] + 1;
    nz = bdiag[i] - bdiag[i + 1] - 1;
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
  }
  PetscCall(PetscFree2(rtmp, mwork));

  C->ops->solve          = MatSolve_SeqBAIJ_5_NaturalOrdering;
  C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_5_NaturalOrdering;
  C->assembled           = PETSC_TRUE;

  PetscCall(PetscLogFlops(1.333333333333 * 5 * 5 * 5 * n)); /* from inverting diagonal blocks */
  PetscFunctionReturn(PETSC_SUCCESS);
}
773 
774 /*
775    Version for when blocks are 9 by 9
776  */
777 #if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
778   #include <immintrin.h>
/*
   MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering - Numeric LU factorization for a
   SeqBAIJ matrix with 9x9 blocks, natural (identity) row/column ordering.

   Input Parameters:
+  B    - the factor matrix (symbolic factorization b->i/b->j/b->diag already set up; aliased as C)
.  A    - the matrix being factored
-  info - factorization options; only info->shiftamount is used here

   Each block row i is expanded into the dense work row rtmp (one bs2 chunk per
   block column of the union of L(i,:) and U(i,:)), eliminated against the
   previously factored rows listed in L(i,:), and copied back into b->a.  The
   diagonal block is stored INVERTED so the triangular solves need only
   block-vector multiplies.
*/
PetscErrorCode MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering(Mat B, Mat A, const MatFactorInfo *info)
{
  Mat             C = B;
  Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
  PetscInt        i, j, k, nz, nzL, row;
  const PetscInt  n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
  const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
  MatScalar      *rtmp, *pc, *mwork, *v, *pv, *aa = a->a;
  PetscInt        flg;
  PetscReal       shift = info->shiftamount;
  PetscBool       allowzeropivot, zeropivotdetected;

  PetscFunctionBegin;
  allowzeropivot = PetscNot(A->erroriffailure);

  /* generate work space needed by the factorization: a dense row of n blocks plus one scratch block */
  PetscCall(PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork));
  PetscCall(PetscArrayzero(rtmp, bs2 * n));

  for (i = 0; i < n; i++) {
    /* zero only the rtmp blocks this row will touch (cheaper than zeroing all of rtmp) */
    /* L part */
    nz    = bi[i + 1] - bi[i];
    bjtmp = bj + bi[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));

    /* U part */
    nz    = bdiag[i] - bdiag[i + 1];
    bjtmp = bj + bdiag[i + 1] + 1;
    for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));

    /* load in initial (unfactored) row of A, scattered by block-column index */
    nz    = ai[i + 1] - ai[i];
    ajtmp = aj + ai[i];
    v     = aa + bs2 * ai[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(rtmp + bs2 * ajtmp[j], v + bs2 * j, bs2));

    /* elimination: for each already-factored row appearing in L(i,:) */
    bjtmp = bj + bi[i];
    nzL   = bi[i + 1] - bi[i];
    for (k = 0; k < nzL; k++) {
      row = bjtmp[k];
      pc  = rtmp + bs2 * row;
      /* skip the whole update when the candidate multiplier block is identically zero */
      for (flg = 0, j = 0; j < bs2; j++) {
        if (pc[j] != 0.0) {
          flg = 1;
          break;
        }
      }
      if (flg) {
        /* pc = pc * inv(D(row)): bdiag[row] points at the stored inverted diagonal block */
        pv = b->a + bs2 * bdiag[row];
        /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
        PetscCall(PetscKernel_A_gets_A_times_B_9(pc, pv, mwork));

        pj = b->j + bdiag[row + 1] + 1; /* beginning of U(row,:) */
        pv = b->a + bs2 * (bdiag[row + 1] + 1);
        nz = bdiag[row] - bdiag[row + 1] - 1; /* num of entries in U(row,:), excluding diag */
        for (j = 0; j < nz; j++) {
          /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
          /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
          v = rtmp + bs2 * pj[j];
          PetscCall(PetscKernel_A_gets_A_minus_B_times_C_9(v, pc, pv + 81 * j));
          /* pv is addressed as pv + 81*j (81 == bs2 for bs = 9) rather than advanced in place */
        }
        PetscCall(PetscLogFlops(1458 * nz + 1377)); /* flops = 2*bs^3*nz + 2*bs^3 - bs2 */
      }
    }

    /* finished row so stick it into b->a */
    /* L part */
    pv = b->a + bs2 * bi[i];
    pj = b->j + bi[i];
    nz = bi[i + 1] - bi[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));

    /* Mark diagonal and invert diagonal for simpler triangular solves */
    pv = b->a + bs2 * bdiag[i];
    pj = b->j + bdiag[i];
    PetscCall(PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2));
    PetscCall(PetscKernel_A_gets_inverse_A_9(pv, shift, allowzeropivot, &zeropivotdetected));
    if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;

    /* U part */
    pv = b->a + bs2 * (bdiag[i + 1] + 1);
    pj = b->j + bdiag[i + 1] + 1;
    nz = bdiag[i] - bdiag[i + 1] - 1;
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
  }
  PetscCall(PetscFree2(rtmp, mwork));

  /* install the block-size-9 solver kernels on the factor */
  C->ops->solve          = MatSolve_SeqBAIJ_9_NaturalOrdering;
  C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_N;
  C->assembled           = PETSC_TRUE;

  PetscCall(PetscLogFlops(1.333333333333 * 9 * 9 * 9 * n)); /* from inverting diagonal blocks */
  PetscFunctionReturn(PETSC_SUCCESS);
}
876 
MatSolve_SeqBAIJ_9_NaturalOrdering(Mat A,Vec bb,Vec xx)877 PetscErrorCode MatSolve_SeqBAIJ_9_NaturalOrdering(Mat A, Vec bb, Vec xx)
878 {
879   Mat_SeqBAIJ       *a  = (Mat_SeqBAIJ *)A->data;
880   const PetscInt    *ai = a->i, *aj = a->j, *adiag = a->diag, *vi;
881   PetscInt           i, k, n                       = a->mbs;
882   PetscInt           nz, bs = A->rmap->bs, bs2 = a->bs2;
883   const MatScalar   *aa = a->a, *v;
884   PetscScalar       *x, *s, *t, *ls;
885   const PetscScalar *b;
886   __m256d            a0, a1, a2, a3, a4, a5, w0, w1, w2, w3, s0, s1, s2, v0, v1, v2, v3;
887 
888   PetscFunctionBegin;
889   PetscCall(VecGetArrayRead(bb, &b));
890   PetscCall(VecGetArray(xx, &x));
891   t = a->solve_work;
892 
893   /* forward solve the lower triangular */
894   PetscCall(PetscArraycpy(t, b, bs)); /* copy 1st block of b to t */
895 
896   for (i = 1; i < n; i++) {
897     v  = aa + bs2 * ai[i];
898     vi = aj + ai[i];
899     nz = ai[i + 1] - ai[i];
900     s  = t + bs * i;
901     PetscCall(PetscArraycpy(s, b + bs * i, bs)); /* copy i_th block of b to t */
902 
903     __m256d s0, s1, s2;
904     s0 = _mm256_loadu_pd(s + 0);
905     s1 = _mm256_loadu_pd(s + 4);
906     s2 = _mm256_maskload_pd(s + 8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
907 
908     for (k = 0; k < nz; k++) {
909       w0 = _mm256_set1_pd((t + bs * vi[k])[0]);
910       a0 = _mm256_loadu_pd(&v[0]);
911       s0 = _mm256_fnmadd_pd(a0, w0, s0);
912       a1 = _mm256_loadu_pd(&v[4]);
913       s1 = _mm256_fnmadd_pd(a1, w0, s1);
914       a2 = _mm256_loadu_pd(&v[8]);
915       s2 = _mm256_fnmadd_pd(a2, w0, s2);
916 
917       w1 = _mm256_set1_pd((t + bs * vi[k])[1]);
918       a3 = _mm256_loadu_pd(&v[9]);
919       s0 = _mm256_fnmadd_pd(a3, w1, s0);
920       a4 = _mm256_loadu_pd(&v[13]);
921       s1 = _mm256_fnmadd_pd(a4, w1, s1);
922       a5 = _mm256_loadu_pd(&v[17]);
923       s2 = _mm256_fnmadd_pd(a5, w1, s2);
924 
925       w2 = _mm256_set1_pd((t + bs * vi[k])[2]);
926       a0 = _mm256_loadu_pd(&v[18]);
927       s0 = _mm256_fnmadd_pd(a0, w2, s0);
928       a1 = _mm256_loadu_pd(&v[22]);
929       s1 = _mm256_fnmadd_pd(a1, w2, s1);
930       a2 = _mm256_loadu_pd(&v[26]);
931       s2 = _mm256_fnmadd_pd(a2, w2, s2);
932 
933       w3 = _mm256_set1_pd((t + bs * vi[k])[3]);
934       a3 = _mm256_loadu_pd(&v[27]);
935       s0 = _mm256_fnmadd_pd(a3, w3, s0);
936       a4 = _mm256_loadu_pd(&v[31]);
937       s1 = _mm256_fnmadd_pd(a4, w3, s1);
938       a5 = _mm256_loadu_pd(&v[35]);
939       s2 = _mm256_fnmadd_pd(a5, w3, s2);
940 
941       w0 = _mm256_set1_pd((t + bs * vi[k])[4]);
942       a0 = _mm256_loadu_pd(&v[36]);
943       s0 = _mm256_fnmadd_pd(a0, w0, s0);
944       a1 = _mm256_loadu_pd(&v[40]);
945       s1 = _mm256_fnmadd_pd(a1, w0, s1);
946       a2 = _mm256_loadu_pd(&v[44]);
947       s2 = _mm256_fnmadd_pd(a2, w0, s2);
948 
949       w1 = _mm256_set1_pd((t + bs * vi[k])[5]);
950       a3 = _mm256_loadu_pd(&v[45]);
951       s0 = _mm256_fnmadd_pd(a3, w1, s0);
952       a4 = _mm256_loadu_pd(&v[49]);
953       s1 = _mm256_fnmadd_pd(a4, w1, s1);
954       a5 = _mm256_loadu_pd(&v[53]);
955       s2 = _mm256_fnmadd_pd(a5, w1, s2);
956 
957       w2 = _mm256_set1_pd((t + bs * vi[k])[6]);
958       a0 = _mm256_loadu_pd(&v[54]);
959       s0 = _mm256_fnmadd_pd(a0, w2, s0);
960       a1 = _mm256_loadu_pd(&v[58]);
961       s1 = _mm256_fnmadd_pd(a1, w2, s1);
962       a2 = _mm256_loadu_pd(&v[62]);
963       s2 = _mm256_fnmadd_pd(a2, w2, s2);
964 
965       w3 = _mm256_set1_pd((t + bs * vi[k])[7]);
966       a3 = _mm256_loadu_pd(&v[63]);
967       s0 = _mm256_fnmadd_pd(a3, w3, s0);
968       a4 = _mm256_loadu_pd(&v[67]);
969       s1 = _mm256_fnmadd_pd(a4, w3, s1);
970       a5 = _mm256_loadu_pd(&v[71]);
971       s2 = _mm256_fnmadd_pd(a5, w3, s2);
972 
973       w0 = _mm256_set1_pd((t + bs * vi[k])[8]);
974       a0 = _mm256_loadu_pd(&v[72]);
975       s0 = _mm256_fnmadd_pd(a0, w0, s0);
976       a1 = _mm256_loadu_pd(&v[76]);
977       s1 = _mm256_fnmadd_pd(a1, w0, s1);
978       a2 = _mm256_maskload_pd(v + 80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
979       s2 = _mm256_fnmadd_pd(a2, w0, s2);
980       v += bs2;
981     }
982     _mm256_storeu_pd(&s[0], s0);
983     _mm256_storeu_pd(&s[4], s1);
984     _mm256_maskstore_pd(&s[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), s2);
985   }
986 
987   /* backward solve the upper triangular */
988   ls = a->solve_work + A->cmap->n;
989   for (i = n - 1; i >= 0; i--) {
990     v  = aa + bs2 * (adiag[i + 1] + 1);
991     vi = aj + adiag[i + 1] + 1;
992     nz = adiag[i] - adiag[i + 1] - 1;
993     PetscCall(PetscArraycpy(ls, t + i * bs, bs));
994 
995     s0 = _mm256_loadu_pd(ls + 0);
996     s1 = _mm256_loadu_pd(ls + 4);
997     s2 = _mm256_maskload_pd(ls + 8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
998 
999     for (k = 0; k < nz; k++) {
1000       w0 = _mm256_set1_pd((t + bs * vi[k])[0]);
1001       a0 = _mm256_loadu_pd(&v[0]);
1002       s0 = _mm256_fnmadd_pd(a0, w0, s0);
1003       a1 = _mm256_loadu_pd(&v[4]);
1004       s1 = _mm256_fnmadd_pd(a1, w0, s1);
1005       a2 = _mm256_loadu_pd(&v[8]);
1006       s2 = _mm256_fnmadd_pd(a2, w0, s2);
1007 
1008       /* v += 9; */
1009       w1 = _mm256_set1_pd((t + bs * vi[k])[1]);
1010       a3 = _mm256_loadu_pd(&v[9]);
1011       s0 = _mm256_fnmadd_pd(a3, w1, s0);
1012       a4 = _mm256_loadu_pd(&v[13]);
1013       s1 = _mm256_fnmadd_pd(a4, w1, s1);
1014       a5 = _mm256_loadu_pd(&v[17]);
1015       s2 = _mm256_fnmadd_pd(a5, w1, s2);
1016 
1017       /* v += 9; */
1018       w2 = _mm256_set1_pd((t + bs * vi[k])[2]);
1019       a0 = _mm256_loadu_pd(&v[18]);
1020       s0 = _mm256_fnmadd_pd(a0, w2, s0);
1021       a1 = _mm256_loadu_pd(&v[22]);
1022       s1 = _mm256_fnmadd_pd(a1, w2, s1);
1023       a2 = _mm256_loadu_pd(&v[26]);
1024       s2 = _mm256_fnmadd_pd(a2, w2, s2);
1025 
1026       /* v += 9; */
1027       w3 = _mm256_set1_pd((t + bs * vi[k])[3]);
1028       a3 = _mm256_loadu_pd(&v[27]);
1029       s0 = _mm256_fnmadd_pd(a3, w3, s0);
1030       a4 = _mm256_loadu_pd(&v[31]);
1031       s1 = _mm256_fnmadd_pd(a4, w3, s1);
1032       a5 = _mm256_loadu_pd(&v[35]);
1033       s2 = _mm256_fnmadd_pd(a5, w3, s2);
1034 
1035       /* v += 9; */
1036       w0 = _mm256_set1_pd((t + bs * vi[k])[4]);
1037       a0 = _mm256_loadu_pd(&v[36]);
1038       s0 = _mm256_fnmadd_pd(a0, w0, s0);
1039       a1 = _mm256_loadu_pd(&v[40]);
1040       s1 = _mm256_fnmadd_pd(a1, w0, s1);
1041       a2 = _mm256_loadu_pd(&v[44]);
1042       s2 = _mm256_fnmadd_pd(a2, w0, s2);
1043 
1044       /* v += 9; */
1045       w1 = _mm256_set1_pd((t + bs * vi[k])[5]);
1046       a3 = _mm256_loadu_pd(&v[45]);
1047       s0 = _mm256_fnmadd_pd(a3, w1, s0);
1048       a4 = _mm256_loadu_pd(&v[49]);
1049       s1 = _mm256_fnmadd_pd(a4, w1, s1);
1050       a5 = _mm256_loadu_pd(&v[53]);
1051       s2 = _mm256_fnmadd_pd(a5, w1, s2);
1052 
1053       /* v += 9; */
1054       w2 = _mm256_set1_pd((t + bs * vi[k])[6]);
1055       a0 = _mm256_loadu_pd(&v[54]);
1056       s0 = _mm256_fnmadd_pd(a0, w2, s0);
1057       a1 = _mm256_loadu_pd(&v[58]);
1058       s1 = _mm256_fnmadd_pd(a1, w2, s1);
1059       a2 = _mm256_loadu_pd(&v[62]);
1060       s2 = _mm256_fnmadd_pd(a2, w2, s2);
1061 
1062       /* v += 9; */
1063       w3 = _mm256_set1_pd((t + bs * vi[k])[7]);
1064       a3 = _mm256_loadu_pd(&v[63]);
1065       s0 = _mm256_fnmadd_pd(a3, w3, s0);
1066       a4 = _mm256_loadu_pd(&v[67]);
1067       s1 = _mm256_fnmadd_pd(a4, w3, s1);
1068       a5 = _mm256_loadu_pd(&v[71]);
1069       s2 = _mm256_fnmadd_pd(a5, w3, s2);
1070 
1071       /* v += 9; */
1072       w0 = _mm256_set1_pd((t + bs * vi[k])[8]);
1073       a0 = _mm256_loadu_pd(&v[72]);
1074       s0 = _mm256_fnmadd_pd(a0, w0, s0);
1075       a1 = _mm256_loadu_pd(&v[76]);
1076       s1 = _mm256_fnmadd_pd(a1, w0, s1);
1077       a2 = _mm256_maskload_pd(v + 80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
1078       s2 = _mm256_fnmadd_pd(a2, w0, s2);
1079       v += bs2;
1080     }
1081 
1082     _mm256_storeu_pd(&ls[0], s0);
1083     _mm256_storeu_pd(&ls[4], s1);
1084     _mm256_maskstore_pd(&ls[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), s2);
1085 
1086     w0 = _mm256_setzero_pd();
1087     w1 = _mm256_setzero_pd();
1088     w2 = _mm256_setzero_pd();
1089 
1090     /* first row */
1091     v0 = _mm256_set1_pd(ls[0]);
1092     a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[0]);
1093     w0 = _mm256_fmadd_pd(a0, v0, w0);
1094     a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[4]);
1095     w1 = _mm256_fmadd_pd(a1, v0, w1);
1096     a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[8]);
1097     w2 = _mm256_fmadd_pd(a2, v0, w2);
1098 
1099     /* second row */
1100     v1 = _mm256_set1_pd(ls[1]);
1101     a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[9]);
1102     w0 = _mm256_fmadd_pd(a3, v1, w0);
1103     a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[13]);
1104     w1 = _mm256_fmadd_pd(a4, v1, w1);
1105     a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[17]);
1106     w2 = _mm256_fmadd_pd(a5, v1, w2);
1107 
1108     /* third row */
1109     v2 = _mm256_set1_pd(ls[2]);
1110     a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[18]);
1111     w0 = _mm256_fmadd_pd(a0, v2, w0);
1112     a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[22]);
1113     w1 = _mm256_fmadd_pd(a1, v2, w1);
1114     a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[26]);
1115     w2 = _mm256_fmadd_pd(a2, v2, w2);
1116 
1117     /* fourth row */
1118     v3 = _mm256_set1_pd(ls[3]);
1119     a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[27]);
1120     w0 = _mm256_fmadd_pd(a3, v3, w0);
1121     a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[31]);
1122     w1 = _mm256_fmadd_pd(a4, v3, w1);
1123     a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[35]);
1124     w2 = _mm256_fmadd_pd(a5, v3, w2);
1125 
1126     /* fifth row */
1127     v0 = _mm256_set1_pd(ls[4]);
1128     a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[36]);
1129     w0 = _mm256_fmadd_pd(a0, v0, w0);
1130     a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[40]);
1131     w1 = _mm256_fmadd_pd(a1, v0, w1);
1132     a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[44]);
1133     w2 = _mm256_fmadd_pd(a2, v0, w2);
1134 
1135     /* sixth row */
1136     v1 = _mm256_set1_pd(ls[5]);
1137     a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[45]);
1138     w0 = _mm256_fmadd_pd(a3, v1, w0);
1139     a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[49]);
1140     w1 = _mm256_fmadd_pd(a4, v1, w1);
1141     a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[53]);
1142     w2 = _mm256_fmadd_pd(a5, v1, w2);
1143 
1144     /* seventh row */
1145     v2 = _mm256_set1_pd(ls[6]);
1146     a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[54]);
1147     w0 = _mm256_fmadd_pd(a0, v2, w0);
1148     a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[58]);
1149     w1 = _mm256_fmadd_pd(a1, v2, w1);
1150     a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[62]);
1151     w2 = _mm256_fmadd_pd(a2, v2, w2);
1152 
1153     /* eighth row */
1154     v3 = _mm256_set1_pd(ls[7]);
1155     a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[63]);
1156     w0 = _mm256_fmadd_pd(a3, v3, w0);
1157     a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[67]);
1158     w1 = _mm256_fmadd_pd(a4, v3, w1);
1159     a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[71]);
1160     w2 = _mm256_fmadd_pd(a5, v3, w2);
1161 
1162     /* ninth row */
1163     v0 = _mm256_set1_pd(ls[8]);
1164     a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[72]);
1165     w0 = _mm256_fmadd_pd(a3, v0, w0);
1166     a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[76]);
1167     w1 = _mm256_fmadd_pd(a4, v0, w1);
1168     a2 = _mm256_maskload_pd(&(aa + bs2 * adiag[i])[80], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
1169     w2 = _mm256_fmadd_pd(a2, v0, w2);
1170 
1171     _mm256_storeu_pd(&(t + i * bs)[0], w0);
1172     _mm256_storeu_pd(&(t + i * bs)[4], w1);
1173     _mm256_maskstore_pd(&(t + i * bs)[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), w2);
1174 
1175     PetscCall(PetscArraycpy(x + i * bs, t + i * bs, bs));
1176   }
1177 
1178   PetscCall(VecRestoreArrayRead(bb, &b));
1179   PetscCall(VecRestoreArray(xx, &x));
1180   PetscCall(PetscLogFlops(2.0 * (a->bs2) * (a->nz) - A->rmap->bs * A->cmap->n));
1181   PetscFunctionReturn(PETSC_SUCCESS);
1182 }
1183 #endif
1184