/* Source: /petsc/src/mat/impls/baij/seq/baijsolv.c (revision a02648fdf9ec0d41d7b5ca02cb70ddcfa0e65728) */
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
MatSolve_SeqBAIJ_N_inplace(Mat A,Vec bb,Vec xx)4 PetscErrorCode MatSolve_SeqBAIJ_N_inplace(Mat A, Vec bb, Vec xx)
5 {
6   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
7   IS                 iscol = a->col, isrow = a->row;
8   const PetscInt    *r, *c, *rout, *cout;
9   const PetscInt     n = a->mbs, *ai = a->i, *aj = a->j, *vi;
10   PetscInt           i, nz;
11   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
12   const MatScalar   *aa = a->a, *v;
13   PetscScalar       *x, *s, *t, *ls;
14   const PetscScalar *b;
15 
16   PetscFunctionBegin;
17   PetscCheck(bs > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Expected bs %" PetscInt_FMT " > 0", bs);
18   PetscCall(VecGetArrayRead(bb, &b));
19   PetscCall(VecGetArray(xx, &x));
20   t = a->solve_work;
21 
22   PetscCall(ISGetIndices(isrow, &rout));
23   r = rout;
24   PetscCall(ISGetIndices(iscol, &cout));
25   c = cout + (n - 1);
26 
27   /* forward solve the lower triangular */
28   PetscCall(PetscArraycpy(t, b + bs * (*r++), bs));
29   for (i = 1; i < n; i++) {
30     v  = aa + bs2 * ai[i];
31     vi = aj + ai[i];
32     nz = a->diag[i] - ai[i];
33     s  = t + bs * i;
34     PetscCall(PetscArraycpy(s, b + bs * (*r++), bs));
35     while (nz--) {
36       PetscKernel_v_gets_v_minus_A_times_w(bs, s, v, t + bs * (*vi++));
37       v += bs2;
38     }
39   }
40   /* backward solve the upper triangular */
41   ls = a->solve_work + A->cmap->n;
42   for (i = n - 1; i >= 0; i--) {
43     v  = aa + bs2 * (a->diag[i] + 1);
44     vi = aj + a->diag[i] + 1;
45     nz = ai[i + 1] - a->diag[i] - 1;
46     PetscCall(PetscArraycpy(ls, t + i * bs, bs));
47     while (nz--) {
48       PetscKernel_v_gets_v_minus_A_times_w(bs, ls, v, t + bs * (*vi++));
49       v += bs2;
50     }
51     PetscKernel_w_gets_A_times_v(bs, ls, aa + bs2 * a->diag[i], t + i * bs);
52     PetscCall(PetscArraycpy(x + bs * (*c--), t + i * bs, bs));
53   }
54 
55   PetscCall(ISRestoreIndices(isrow, &rout));
56   PetscCall(ISRestoreIndices(iscol, &cout));
57   PetscCall(VecRestoreArrayRead(bb, &b));
58   PetscCall(VecRestoreArray(xx, &x));
59   PetscCall(PetscLogFlops(2.0 * (a->bs2) * (a->nz) - A->rmap->bs * A->cmap->n));
60   PetscFunctionReturn(PETSC_SUCCESS);
61 }
62 
MatSolve_SeqBAIJ_7_inplace(Mat A,Vec bb,Vec xx)63 PetscErrorCode MatSolve_SeqBAIJ_7_inplace(Mat A, Vec bb, Vec xx)
64 {
65   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
66   IS                 iscol = a->col, isrow = a->row;
67   const PetscInt    *r, *c, *ai = a->i, *aj = a->j;
68   const PetscInt    *rout, *cout, *diag = a->diag, *vi, n = a->mbs;
69   PetscInt           i, nz, idx, idt, idc;
70   const MatScalar   *aa = a->a, *v;
71   PetscScalar        s1, s2, s3, s4, s5, s6, s7, x1, x2, x3, x4, x5, x6, x7, *x, *t;
72   const PetscScalar *b;
73 
74   PetscFunctionBegin;
75   PetscCall(VecGetArrayRead(bb, &b));
76   PetscCall(VecGetArray(xx, &x));
77   t = a->solve_work;
78 
79   PetscCall(ISGetIndices(isrow, &rout));
80   r = rout;
81   PetscCall(ISGetIndices(iscol, &cout));
82   c = cout + (n - 1);
83 
84   /* forward solve the lower triangular */
85   idx  = 7 * (*r++);
86   t[0] = b[idx];
87   t[1] = b[1 + idx];
88   t[2] = b[2 + idx];
89   t[3] = b[3 + idx];
90   t[4] = b[4 + idx];
91   t[5] = b[5 + idx];
92   t[6] = b[6 + idx];
93 
94   for (i = 1; i < n; i++) {
95     v   = aa + 49 * ai[i];
96     vi  = aj + ai[i];
97     nz  = diag[i] - ai[i];
98     idx = 7 * (*r++);
99     s1  = b[idx];
100     s2  = b[1 + idx];
101     s3  = b[2 + idx];
102     s4  = b[3 + idx];
103     s5  = b[4 + idx];
104     s6  = b[5 + idx];
105     s7  = b[6 + idx];
106     while (nz--) {
107       idx = 7 * (*vi++);
108       x1  = t[idx];
109       x2  = t[1 + idx];
110       x3  = t[2 + idx];
111       x4  = t[3 + idx];
112       x5  = t[4 + idx];
113       x6  = t[5 + idx];
114       x7  = t[6 + idx];
115       s1 -= v[0] * x1 + v[7] * x2 + v[14] * x3 + v[21] * x4 + v[28] * x5 + v[35] * x6 + v[42] * x7;
116       s2 -= v[1] * x1 + v[8] * x2 + v[15] * x3 + v[22] * x4 + v[29] * x5 + v[36] * x6 + v[43] * x7;
117       s3 -= v[2] * x1 + v[9] * x2 + v[16] * x3 + v[23] * x4 + v[30] * x5 + v[37] * x6 + v[44] * x7;
118       s4 -= v[3] * x1 + v[10] * x2 + v[17] * x3 + v[24] * x4 + v[31] * x5 + v[38] * x6 + v[45] * x7;
119       s5 -= v[4] * x1 + v[11] * x2 + v[18] * x3 + v[25] * x4 + v[32] * x5 + v[39] * x6 + v[46] * x7;
120       s6 -= v[5] * x1 + v[12] * x2 + v[19] * x3 + v[26] * x4 + v[33] * x5 + v[40] * x6 + v[47] * x7;
121       s7 -= v[6] * x1 + v[13] * x2 + v[20] * x3 + v[27] * x4 + v[34] * x5 + v[41] * x6 + v[48] * x7;
122       v += 49;
123     }
124     idx        = 7 * i;
125     t[idx]     = s1;
126     t[1 + idx] = s2;
127     t[2 + idx] = s3;
128     t[3 + idx] = s4;
129     t[4 + idx] = s5;
130     t[5 + idx] = s6;
131     t[6 + idx] = s7;
132   }
133   /* backward solve the upper triangular */
134   for (i = n - 1; i >= 0; i--) {
135     v   = aa + 49 * diag[i] + 49;
136     vi  = aj + diag[i] + 1;
137     nz  = ai[i + 1] - diag[i] - 1;
138     idt = 7 * i;
139     s1  = t[idt];
140     s2  = t[1 + idt];
141     s3  = t[2 + idt];
142     s4  = t[3 + idt];
143     s5  = t[4 + idt];
144     s6  = t[5 + idt];
145     s7  = t[6 + idt];
146     while (nz--) {
147       idx = 7 * (*vi++);
148       x1  = t[idx];
149       x2  = t[1 + idx];
150       x3  = t[2 + idx];
151       x4  = t[3 + idx];
152       x5  = t[4 + idx];
153       x6  = t[5 + idx];
154       x7  = t[6 + idx];
155       s1 -= v[0] * x1 + v[7] * x2 + v[14] * x3 + v[21] * x4 + v[28] * x5 + v[35] * x6 + v[42] * x7;
156       s2 -= v[1] * x1 + v[8] * x2 + v[15] * x3 + v[22] * x4 + v[29] * x5 + v[36] * x6 + v[43] * x7;
157       s3 -= v[2] * x1 + v[9] * x2 + v[16] * x3 + v[23] * x4 + v[30] * x5 + v[37] * x6 + v[44] * x7;
158       s4 -= v[3] * x1 + v[10] * x2 + v[17] * x3 + v[24] * x4 + v[31] * x5 + v[38] * x6 + v[45] * x7;
159       s5 -= v[4] * x1 + v[11] * x2 + v[18] * x3 + v[25] * x4 + v[32] * x5 + v[39] * x6 + v[46] * x7;
160       s6 -= v[5] * x1 + v[12] * x2 + v[19] * x3 + v[26] * x4 + v[33] * x5 + v[40] * x6 + v[47] * x7;
161       s7 -= v[6] * x1 + v[13] * x2 + v[20] * x3 + v[27] * x4 + v[34] * x5 + v[41] * x6 + v[48] * x7;
162       v += 49;
163     }
164     idc    = 7 * (*c--);
165     v      = aa + 49 * diag[i];
166     x[idc] = t[idt] = v[0] * s1 + v[7] * s2 + v[14] * s3 + v[21] * s4 + v[28] * s5 + v[35] * s6 + v[42] * s7;
167     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[8] * s2 + v[15] * s3 + v[22] * s4 + v[29] * s5 + v[36] * s6 + v[43] * s7;
168     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[9] * s2 + v[16] * s3 + v[23] * s4 + v[30] * s5 + v[37] * s6 + v[44] * s7;
169     x[3 + idc] = t[3 + idt] = v[3] * s1 + v[10] * s2 + v[17] * s3 + v[24] * s4 + v[31] * s5 + v[38] * s6 + v[45] * s7;
170     x[4 + idc] = t[4 + idt] = v[4] * s1 + v[11] * s2 + v[18] * s3 + v[25] * s4 + v[32] * s5 + v[39] * s6 + v[46] * s7;
171     x[5 + idc] = t[5 + idt] = v[5] * s1 + v[12] * s2 + v[19] * s3 + v[26] * s4 + v[33] * s5 + v[40] * s6 + v[47] * s7;
172     x[6 + idc] = t[6 + idt] = v[6] * s1 + v[13] * s2 + v[20] * s3 + v[27] * s4 + v[34] * s5 + v[41] * s6 + v[48] * s7;
173   }
174 
175   PetscCall(ISRestoreIndices(isrow, &rout));
176   PetscCall(ISRestoreIndices(iscol, &cout));
177   PetscCall(VecRestoreArrayRead(bb, &b));
178   PetscCall(VecRestoreArray(xx, &x));
179   PetscCall(PetscLogFlops(2.0 * 49 * (a->nz) - 7.0 * A->cmap->n));
180   PetscFunctionReturn(PETSC_SUCCESS);
181 }
182 
MatSolve_SeqBAIJ_7(Mat A,Vec bb,Vec xx)183 PetscErrorCode MatSolve_SeqBAIJ_7(Mat A, Vec bb, Vec xx)
184 {
185   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
186   IS                 iscol = a->col, isrow = a->row;
187   const PetscInt    *r, *c, *ai = a->i, *aj = a->j, *adiag = a->diag;
188   const PetscInt     n = a->mbs, *rout, *cout, *vi;
189   PetscInt           i, nz, idx, idt, idc, m;
190   const MatScalar   *aa = a->a, *v;
191   PetscScalar        s1, s2, s3, s4, s5, s6, s7, x1, x2, x3, x4, x5, x6, x7, *x, *t;
192   const PetscScalar *b;
193 
194   PetscFunctionBegin;
195   PetscCall(VecGetArrayRead(bb, &b));
196   PetscCall(VecGetArray(xx, &x));
197   t = a->solve_work;
198 
199   PetscCall(ISGetIndices(isrow, &rout));
200   r = rout;
201   PetscCall(ISGetIndices(iscol, &cout));
202   c = cout;
203 
204   /* forward solve the lower triangular */
205   idx  = 7 * r[0];
206   t[0] = b[idx];
207   t[1] = b[1 + idx];
208   t[2] = b[2 + idx];
209   t[3] = b[3 + idx];
210   t[4] = b[4 + idx];
211   t[5] = b[5 + idx];
212   t[6] = b[6 + idx];
213 
214   for (i = 1; i < n; i++) {
215     v   = aa + 49 * ai[i];
216     vi  = aj + ai[i];
217     nz  = ai[i + 1] - ai[i];
218     idx = 7 * r[i];
219     s1  = b[idx];
220     s2  = b[1 + idx];
221     s3  = b[2 + idx];
222     s4  = b[3 + idx];
223     s5  = b[4 + idx];
224     s6  = b[5 + idx];
225     s7  = b[6 + idx];
226     for (m = 0; m < nz; m++) {
227       idx = 7 * vi[m];
228       x1  = t[idx];
229       x2  = t[1 + idx];
230       x3  = t[2 + idx];
231       x4  = t[3 + idx];
232       x5  = t[4 + idx];
233       x6  = t[5 + idx];
234       x7  = t[6 + idx];
235       s1 -= v[0] * x1 + v[7] * x2 + v[14] * x3 + v[21] * x4 + v[28] * x5 + v[35] * x6 + v[42] * x7;
236       s2 -= v[1] * x1 + v[8] * x2 + v[15] * x3 + v[22] * x4 + v[29] * x5 + v[36] * x6 + v[43] * x7;
237       s3 -= v[2] * x1 + v[9] * x2 + v[16] * x3 + v[23] * x4 + v[30] * x5 + v[37] * x6 + v[44] * x7;
238       s4 -= v[3] * x1 + v[10] * x2 + v[17] * x3 + v[24] * x4 + v[31] * x5 + v[38] * x6 + v[45] * x7;
239       s5 -= v[4] * x1 + v[11] * x2 + v[18] * x3 + v[25] * x4 + v[32] * x5 + v[39] * x6 + v[46] * x7;
240       s6 -= v[5] * x1 + v[12] * x2 + v[19] * x3 + v[26] * x4 + v[33] * x5 + v[40] * x6 + v[47] * x7;
241       s7 -= v[6] * x1 + v[13] * x2 + v[20] * x3 + v[27] * x4 + v[34] * x5 + v[41] * x6 + v[48] * x7;
242       v += 49;
243     }
244     idx        = 7 * i;
245     t[idx]     = s1;
246     t[1 + idx] = s2;
247     t[2 + idx] = s3;
248     t[3 + idx] = s4;
249     t[4 + idx] = s5;
250     t[5 + idx] = s6;
251     t[6 + idx] = s7;
252   }
253   /* backward solve the upper triangular */
254   for (i = n - 1; i >= 0; i--) {
255     v   = aa + 49 * (adiag[i + 1] + 1);
256     vi  = aj + adiag[i + 1] + 1;
257     nz  = adiag[i] - adiag[i + 1] - 1;
258     idt = 7 * i;
259     s1  = t[idt];
260     s2  = t[1 + idt];
261     s3  = t[2 + idt];
262     s4  = t[3 + idt];
263     s5  = t[4 + idt];
264     s6  = t[5 + idt];
265     s7  = t[6 + idt];
266     for (m = 0; m < nz; m++) {
267       idx = 7 * vi[m];
268       x1  = t[idx];
269       x2  = t[1 + idx];
270       x3  = t[2 + idx];
271       x4  = t[3 + idx];
272       x5  = t[4 + idx];
273       x6  = t[5 + idx];
274       x7  = t[6 + idx];
275       s1 -= v[0] * x1 + v[7] * x2 + v[14] * x3 + v[21] * x4 + v[28] * x5 + v[35] * x6 + v[42] * x7;
276       s2 -= v[1] * x1 + v[8] * x2 + v[15] * x3 + v[22] * x4 + v[29] * x5 + v[36] * x6 + v[43] * x7;
277       s3 -= v[2] * x1 + v[9] * x2 + v[16] * x3 + v[23] * x4 + v[30] * x5 + v[37] * x6 + v[44] * x7;
278       s4 -= v[3] * x1 + v[10] * x2 + v[17] * x3 + v[24] * x4 + v[31] * x5 + v[38] * x6 + v[45] * x7;
279       s5 -= v[4] * x1 + v[11] * x2 + v[18] * x3 + v[25] * x4 + v[32] * x5 + v[39] * x6 + v[46] * x7;
280       s6 -= v[5] * x1 + v[12] * x2 + v[19] * x3 + v[26] * x4 + v[33] * x5 + v[40] * x6 + v[47] * x7;
281       s7 -= v[6] * x1 + v[13] * x2 + v[20] * x3 + v[27] * x4 + v[34] * x5 + v[41] * x6 + v[48] * x7;
282       v += 49;
283     }
284     idc    = 7 * c[i];
285     x[idc] = t[idt] = v[0] * s1 + v[7] * s2 + v[14] * s3 + v[21] * s4 + v[28] * s5 + v[35] * s6 + v[42] * s7;
286     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[8] * s2 + v[15] * s3 + v[22] * s4 + v[29] * s5 + v[36] * s6 + v[43] * s7;
287     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[9] * s2 + v[16] * s3 + v[23] * s4 + v[30] * s5 + v[37] * s6 + v[44] * s7;
288     x[3 + idc] = t[3 + idt] = v[3] * s1 + v[10] * s2 + v[17] * s3 + v[24] * s4 + v[31] * s5 + v[38] * s6 + v[45] * s7;
289     x[4 + idc] = t[4 + idt] = v[4] * s1 + v[11] * s2 + v[18] * s3 + v[25] * s4 + v[32] * s5 + v[39] * s6 + v[46] * s7;
290     x[5 + idc] = t[5 + idt] = v[5] * s1 + v[12] * s2 + v[19] * s3 + v[26] * s4 + v[33] * s5 + v[40] * s6 + v[47] * s7;
291     x[6 + idc] = t[6 + idt] = v[6] * s1 + v[13] * s2 + v[20] * s3 + v[27] * s4 + v[34] * s5 + v[41] * s6 + v[48] * s7;
292   }
293 
294   PetscCall(ISRestoreIndices(isrow, &rout));
295   PetscCall(ISRestoreIndices(iscol, &cout));
296   PetscCall(VecRestoreArrayRead(bb, &b));
297   PetscCall(VecRestoreArray(xx, &x));
298   PetscCall(PetscLogFlops(2.0 * 49 * (a->nz) - 7.0 * A->cmap->n));
299   PetscFunctionReturn(PETSC_SUCCESS);
300 }
301 
MatSolve_SeqBAIJ_6_inplace(Mat A,Vec bb,Vec xx)302 PetscErrorCode MatSolve_SeqBAIJ_6_inplace(Mat A, Vec bb, Vec xx)
303 {
304   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
305   IS                 iscol = a->col, isrow = a->row;
306   const PetscInt    *r, *c, *rout, *cout;
307   const PetscInt    *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j;
308   PetscInt           i, nz, idx, idt, idc;
309   const MatScalar   *aa = a->a, *v;
310   PetscScalar       *x, s1, s2, s3, s4, s5, s6, x1, x2, x3, x4, x5, x6, *t;
311   const PetscScalar *b;
312 
313   PetscFunctionBegin;
314   PetscCall(VecGetArrayRead(bb, &b));
315   PetscCall(VecGetArray(xx, &x));
316   t = a->solve_work;
317 
318   PetscCall(ISGetIndices(isrow, &rout));
319   r = rout;
320   PetscCall(ISGetIndices(iscol, &cout));
321   c = cout + (n - 1);
322 
323   /* forward solve the lower triangular */
324   idx  = 6 * (*r++);
325   t[0] = b[idx];
326   t[1] = b[1 + idx];
327   t[2] = b[2 + idx];
328   t[3] = b[3 + idx];
329   t[4] = b[4 + idx];
330   t[5] = b[5 + idx];
331   for (i = 1; i < n; i++) {
332     v   = aa + 36 * ai[i];
333     vi  = aj + ai[i];
334     nz  = diag[i] - ai[i];
335     idx = 6 * (*r++);
336     s1  = b[idx];
337     s2  = b[1 + idx];
338     s3  = b[2 + idx];
339     s4  = b[3 + idx];
340     s5  = b[4 + idx];
341     s6  = b[5 + idx];
342     while (nz--) {
343       idx = 6 * (*vi++);
344       x1  = t[idx];
345       x2  = t[1 + idx];
346       x3  = t[2 + idx];
347       x4  = t[3 + idx];
348       x5  = t[4 + idx];
349       x6  = t[5 + idx];
350       s1 -= v[0] * x1 + v[6] * x2 + v[12] * x3 + v[18] * x4 + v[24] * x5 + v[30] * x6;
351       s2 -= v[1] * x1 + v[7] * x2 + v[13] * x3 + v[19] * x4 + v[25] * x5 + v[31] * x6;
352       s3 -= v[2] * x1 + v[8] * x2 + v[14] * x3 + v[20] * x4 + v[26] * x5 + v[32] * x6;
353       s4 -= v[3] * x1 + v[9] * x2 + v[15] * x3 + v[21] * x4 + v[27] * x5 + v[33] * x6;
354       s5 -= v[4] * x1 + v[10] * x2 + v[16] * x3 + v[22] * x4 + v[28] * x5 + v[34] * x6;
355       s6 -= v[5] * x1 + v[11] * x2 + v[17] * x3 + v[23] * x4 + v[29] * x5 + v[35] * x6;
356       v += 36;
357     }
358     idx        = 6 * i;
359     t[idx]     = s1;
360     t[1 + idx] = s2;
361     t[2 + idx] = s3;
362     t[3 + idx] = s4;
363     t[4 + idx] = s5;
364     t[5 + idx] = s6;
365   }
366   /* backward solve the upper triangular */
367   for (i = n - 1; i >= 0; i--) {
368     v   = aa + 36 * diag[i] + 36;
369     vi  = aj + diag[i] + 1;
370     nz  = ai[i + 1] - diag[i] - 1;
371     idt = 6 * i;
372     s1  = t[idt];
373     s2  = t[1 + idt];
374     s3  = t[2 + idt];
375     s4  = t[3 + idt];
376     s5  = t[4 + idt];
377     s6  = t[5 + idt];
378     while (nz--) {
379       idx = 6 * (*vi++);
380       x1  = t[idx];
381       x2  = t[1 + idx];
382       x3  = t[2 + idx];
383       x4  = t[3 + idx];
384       x5  = t[4 + idx];
385       x6  = t[5 + idx];
386       s1 -= v[0] * x1 + v[6] * x2 + v[12] * x3 + v[18] * x4 + v[24] * x5 + v[30] * x6;
387       s2 -= v[1] * x1 + v[7] * x2 + v[13] * x3 + v[19] * x4 + v[25] * x5 + v[31] * x6;
388       s3 -= v[2] * x1 + v[8] * x2 + v[14] * x3 + v[20] * x4 + v[26] * x5 + v[32] * x6;
389       s4 -= v[3] * x1 + v[9] * x2 + v[15] * x3 + v[21] * x4 + v[27] * x5 + v[33] * x6;
390       s5 -= v[4] * x1 + v[10] * x2 + v[16] * x3 + v[22] * x4 + v[28] * x5 + v[34] * x6;
391       s6 -= v[5] * x1 + v[11] * x2 + v[17] * x3 + v[23] * x4 + v[29] * x5 + v[35] * x6;
392       v += 36;
393     }
394     idc    = 6 * (*c--);
395     v      = aa + 36 * diag[i];
396     x[idc] = t[idt] = v[0] * s1 + v[6] * s2 + v[12] * s3 + v[18] * s4 + v[24] * s5 + v[30] * s6;
397     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[7] * s2 + v[13] * s3 + v[19] * s4 + v[25] * s5 + v[31] * s6;
398     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[8] * s2 + v[14] * s3 + v[20] * s4 + v[26] * s5 + v[32] * s6;
399     x[3 + idc] = t[3 + idt] = v[3] * s1 + v[9] * s2 + v[15] * s3 + v[21] * s4 + v[27] * s5 + v[33] * s6;
400     x[4 + idc] = t[4 + idt] = v[4] * s1 + v[10] * s2 + v[16] * s3 + v[22] * s4 + v[28] * s5 + v[34] * s6;
401     x[5 + idc] = t[5 + idt] = v[5] * s1 + v[11] * s2 + v[17] * s3 + v[23] * s4 + v[29] * s5 + v[35] * s6;
402   }
403 
404   PetscCall(ISRestoreIndices(isrow, &rout));
405   PetscCall(ISRestoreIndices(iscol, &cout));
406   PetscCall(VecRestoreArrayRead(bb, &b));
407   PetscCall(VecRestoreArray(xx, &x));
408   PetscCall(PetscLogFlops(2.0 * 36 * (a->nz) - 6.0 * A->cmap->n));
409   PetscFunctionReturn(PETSC_SUCCESS);
410 }
411 
MatSolve_SeqBAIJ_6(Mat A,Vec bb,Vec xx)412 PetscErrorCode MatSolve_SeqBAIJ_6(Mat A, Vec bb, Vec xx)
413 {
414   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
415   IS                 iscol = a->col, isrow = a->row;
416   const PetscInt    *r, *c, *rout, *cout;
417   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *adiag = a->diag;
418   PetscInt           i, nz, idx, idt, idc, m;
419   const MatScalar   *aa = a->a, *v;
420   PetscScalar       *x, s1, s2, s3, s4, s5, s6, x1, x2, x3, x4, x5, x6, *t;
421   const PetscScalar *b;
422 
423   PetscFunctionBegin;
424   PetscCall(VecGetArrayRead(bb, &b));
425   PetscCall(VecGetArray(xx, &x));
426   t = a->solve_work;
427 
428   PetscCall(ISGetIndices(isrow, &rout));
429   r = rout;
430   PetscCall(ISGetIndices(iscol, &cout));
431   c = cout;
432 
433   /* forward solve the lower triangular */
434   idx  = 6 * r[0];
435   t[0] = b[idx];
436   t[1] = b[1 + idx];
437   t[2] = b[2 + idx];
438   t[3] = b[3 + idx];
439   t[4] = b[4 + idx];
440   t[5] = b[5 + idx];
441   for (i = 1; i < n; i++) {
442     v   = aa + 36 * ai[i];
443     vi  = aj + ai[i];
444     nz  = ai[i + 1] - ai[i];
445     idx = 6 * r[i];
446     s1  = b[idx];
447     s2  = b[1 + idx];
448     s3  = b[2 + idx];
449     s4  = b[3 + idx];
450     s5  = b[4 + idx];
451     s6  = b[5 + idx];
452     for (m = 0; m < nz; m++) {
453       idx = 6 * vi[m];
454       x1  = t[idx];
455       x2  = t[1 + idx];
456       x3  = t[2 + idx];
457       x4  = t[3 + idx];
458       x5  = t[4 + idx];
459       x6  = t[5 + idx];
460       s1 -= v[0] * x1 + v[6] * x2 + v[12] * x3 + v[18] * x4 + v[24] * x5 + v[30] * x6;
461       s2 -= v[1] * x1 + v[7] * x2 + v[13] * x3 + v[19] * x4 + v[25] * x5 + v[31] * x6;
462       s3 -= v[2] * x1 + v[8] * x2 + v[14] * x3 + v[20] * x4 + v[26] * x5 + v[32] * x6;
463       s4 -= v[3] * x1 + v[9] * x2 + v[15] * x3 + v[21] * x4 + v[27] * x5 + v[33] * x6;
464       s5 -= v[4] * x1 + v[10] * x2 + v[16] * x3 + v[22] * x4 + v[28] * x5 + v[34] * x6;
465       s6 -= v[5] * x1 + v[11] * x2 + v[17] * x3 + v[23] * x4 + v[29] * x5 + v[35] * x6;
466       v += 36;
467     }
468     idx        = 6 * i;
469     t[idx]     = s1;
470     t[1 + idx] = s2;
471     t[2 + idx] = s3;
472     t[3 + idx] = s4;
473     t[4 + idx] = s5;
474     t[5 + idx] = s6;
475   }
476   /* backward solve the upper triangular */
477   for (i = n - 1; i >= 0; i--) {
478     v   = aa + 36 * (adiag[i + 1] + 1);
479     vi  = aj + adiag[i + 1] + 1;
480     nz  = adiag[i] - adiag[i + 1] - 1;
481     idt = 6 * i;
482     s1  = t[idt];
483     s2  = t[1 + idt];
484     s3  = t[2 + idt];
485     s4  = t[3 + idt];
486     s5  = t[4 + idt];
487     s6  = t[5 + idt];
488     for (m = 0; m < nz; m++) {
489       idx = 6 * vi[m];
490       x1  = t[idx];
491       x2  = t[1 + idx];
492       x3  = t[2 + idx];
493       x4  = t[3 + idx];
494       x5  = t[4 + idx];
495       x6  = t[5 + idx];
496       s1 -= v[0] * x1 + v[6] * x2 + v[12] * x3 + v[18] * x4 + v[24] * x5 + v[30] * x6;
497       s2 -= v[1] * x1 + v[7] * x2 + v[13] * x3 + v[19] * x4 + v[25] * x5 + v[31] * x6;
498       s3 -= v[2] * x1 + v[8] * x2 + v[14] * x3 + v[20] * x4 + v[26] * x5 + v[32] * x6;
499       s4 -= v[3] * x1 + v[9] * x2 + v[15] * x3 + v[21] * x4 + v[27] * x5 + v[33] * x6;
500       s5 -= v[4] * x1 + v[10] * x2 + v[16] * x3 + v[22] * x4 + v[28] * x5 + v[34] * x6;
501       s6 -= v[5] * x1 + v[11] * x2 + v[17] * x3 + v[23] * x4 + v[29] * x5 + v[35] * x6;
502       v += 36;
503     }
504     idc    = 6 * c[i];
505     x[idc] = t[idt] = v[0] * s1 + v[6] * s2 + v[12] * s3 + v[18] * s4 + v[24] * s5 + v[30] * s6;
506     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[7] * s2 + v[13] * s3 + v[19] * s4 + v[25] * s5 + v[31] * s6;
507     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[8] * s2 + v[14] * s3 + v[20] * s4 + v[26] * s5 + v[32] * s6;
508     x[3 + idc] = t[3 + idt] = v[3] * s1 + v[9] * s2 + v[15] * s3 + v[21] * s4 + v[27] * s5 + v[33] * s6;
509     x[4 + idc] = t[4 + idt] = v[4] * s1 + v[10] * s2 + v[16] * s3 + v[22] * s4 + v[28] * s5 + v[34] * s6;
510     x[5 + idc] = t[5 + idt] = v[5] * s1 + v[11] * s2 + v[17] * s3 + v[23] * s4 + v[29] * s5 + v[35] * s6;
511   }
512 
513   PetscCall(ISRestoreIndices(isrow, &rout));
514   PetscCall(ISRestoreIndices(iscol, &cout));
515   PetscCall(VecRestoreArrayRead(bb, &b));
516   PetscCall(VecRestoreArray(xx, &x));
517   PetscCall(PetscLogFlops(2.0 * 36 * (a->nz) - 6.0 * A->cmap->n));
518   PetscFunctionReturn(PETSC_SUCCESS);
519 }
520 
MatSolve_SeqBAIJ_5_inplace(Mat A,Vec bb,Vec xx)521 PetscErrorCode MatSolve_SeqBAIJ_5_inplace(Mat A, Vec bb, Vec xx)
522 {
523   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
524   IS                 iscol = a->col, isrow = a->row;
525   const PetscInt    *r, *c, *rout, *cout, *diag = a->diag;
526   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j;
527   PetscInt           i, nz, idx, idt, idc;
528   const MatScalar   *aa = a->a, *v;
529   PetscScalar       *x, s1, s2, s3, s4, s5, x1, x2, x3, x4, x5, *t;
530   const PetscScalar *b;
531 
532   PetscFunctionBegin;
533   PetscCall(VecGetArrayRead(bb, &b));
534   PetscCall(VecGetArray(xx, &x));
535   t = a->solve_work;
536 
537   PetscCall(ISGetIndices(isrow, &rout));
538   r = rout;
539   PetscCall(ISGetIndices(iscol, &cout));
540   c = cout + (n - 1);
541 
542   /* forward solve the lower triangular */
543   idx  = 5 * (*r++);
544   t[0] = b[idx];
545   t[1] = b[1 + idx];
546   t[2] = b[2 + idx];
547   t[3] = b[3 + idx];
548   t[4] = b[4 + idx];
549   for (i = 1; i < n; i++) {
550     v   = aa + 25 * ai[i];
551     vi  = aj + ai[i];
552     nz  = diag[i] - ai[i];
553     idx = 5 * (*r++);
554     s1  = b[idx];
555     s2  = b[1 + idx];
556     s3  = b[2 + idx];
557     s4  = b[3 + idx];
558     s5  = b[4 + idx];
559     while (nz--) {
560       idx = 5 * (*vi++);
561       x1  = t[idx];
562       x2  = t[1 + idx];
563       x3  = t[2 + idx];
564       x4  = t[3 + idx];
565       x5  = t[4 + idx];
566       s1 -= v[0] * x1 + v[5] * x2 + v[10] * x3 + v[15] * x4 + v[20] * x5;
567       s2 -= v[1] * x1 + v[6] * x2 + v[11] * x3 + v[16] * x4 + v[21] * x5;
568       s3 -= v[2] * x1 + v[7] * x2 + v[12] * x3 + v[17] * x4 + v[22] * x5;
569       s4 -= v[3] * x1 + v[8] * x2 + v[13] * x3 + v[18] * x4 + v[23] * x5;
570       s5 -= v[4] * x1 + v[9] * x2 + v[14] * x3 + v[19] * x4 + v[24] * x5;
571       v += 25;
572     }
573     idx        = 5 * i;
574     t[idx]     = s1;
575     t[1 + idx] = s2;
576     t[2 + idx] = s3;
577     t[3 + idx] = s4;
578     t[4 + idx] = s5;
579   }
580   /* backward solve the upper triangular */
581   for (i = n - 1; i >= 0; i--) {
582     v   = aa + 25 * diag[i] + 25;
583     vi  = aj + diag[i] + 1;
584     nz  = ai[i + 1] - diag[i] - 1;
585     idt = 5 * i;
586     s1  = t[idt];
587     s2  = t[1 + idt];
588     s3  = t[2 + idt];
589     s4  = t[3 + idt];
590     s5  = t[4 + idt];
591     while (nz--) {
592       idx = 5 * (*vi++);
593       x1  = t[idx];
594       x2  = t[1 + idx];
595       x3  = t[2 + idx];
596       x4  = t[3 + idx];
597       x5  = t[4 + idx];
598       s1 -= v[0] * x1 + v[5] * x2 + v[10] * x3 + v[15] * x4 + v[20] * x5;
599       s2 -= v[1] * x1 + v[6] * x2 + v[11] * x3 + v[16] * x4 + v[21] * x5;
600       s3 -= v[2] * x1 + v[7] * x2 + v[12] * x3 + v[17] * x4 + v[22] * x5;
601       s4 -= v[3] * x1 + v[8] * x2 + v[13] * x3 + v[18] * x4 + v[23] * x5;
602       s5 -= v[4] * x1 + v[9] * x2 + v[14] * x3 + v[19] * x4 + v[24] * x5;
603       v += 25;
604     }
605     idc    = 5 * (*c--);
606     v      = aa + 25 * diag[i];
607     x[idc] = t[idt] = v[0] * s1 + v[5] * s2 + v[10] * s3 + v[15] * s4 + v[20] * s5;
608     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[6] * s2 + v[11] * s3 + v[16] * s4 + v[21] * s5;
609     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[7] * s2 + v[12] * s3 + v[17] * s4 + v[22] * s5;
610     x[3 + idc] = t[3 + idt] = v[3] * s1 + v[8] * s2 + v[13] * s3 + v[18] * s4 + v[23] * s5;
611     x[4 + idc] = t[4 + idt] = v[4] * s1 + v[9] * s2 + v[14] * s3 + v[19] * s4 + v[24] * s5;
612   }
613 
614   PetscCall(ISRestoreIndices(isrow, &rout));
615   PetscCall(ISRestoreIndices(iscol, &cout));
616   PetscCall(VecRestoreArrayRead(bb, &b));
617   PetscCall(VecRestoreArray(xx, &x));
618   PetscCall(PetscLogFlops(2.0 * 25 * (a->nz) - 5.0 * A->cmap->n));
619   PetscFunctionReturn(PETSC_SUCCESS);
620 }
621 
MatSolve_SeqBAIJ_5(Mat A,Vec bb,Vec xx)622 PetscErrorCode MatSolve_SeqBAIJ_5(Mat A, Vec bb, Vec xx)
623 {
624   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
625   IS                 iscol = a->col, isrow = a->row;
626   const PetscInt    *r, *c, *rout, *cout;
627   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *adiag = a->diag;
628   PetscInt           i, nz, idx, idt, idc, m;
629   const MatScalar   *aa = a->a, *v;
630   PetscScalar       *x, s1, s2, s3, s4, s5, x1, x2, x3, x4, x5, *t;
631   const PetscScalar *b;
632 
633   PetscFunctionBegin;
634   PetscCall(VecGetArrayRead(bb, &b));
635   PetscCall(VecGetArray(xx, &x));
636   t = a->solve_work;
637 
638   PetscCall(ISGetIndices(isrow, &rout));
639   r = rout;
640   PetscCall(ISGetIndices(iscol, &cout));
641   c = cout;
642 
643   /* forward solve the lower triangular */
644   idx  = 5 * r[0];
645   t[0] = b[idx];
646   t[1] = b[1 + idx];
647   t[2] = b[2 + idx];
648   t[3] = b[3 + idx];
649   t[4] = b[4 + idx];
650   for (i = 1; i < n; i++) {
651     v   = aa + 25 * ai[i];
652     vi  = aj + ai[i];
653     nz  = ai[i + 1] - ai[i];
654     idx = 5 * r[i];
655     s1  = b[idx];
656     s2  = b[1 + idx];
657     s3  = b[2 + idx];
658     s4  = b[3 + idx];
659     s5  = b[4 + idx];
660     for (m = 0; m < nz; m++) {
661       idx = 5 * vi[m];
662       x1  = t[idx];
663       x2  = t[1 + idx];
664       x3  = t[2 + idx];
665       x4  = t[3 + idx];
666       x5  = t[4 + idx];
667       s1 -= v[0] * x1 + v[5] * x2 + v[10] * x3 + v[15] * x4 + v[20] * x5;
668       s2 -= v[1] * x1 + v[6] * x2 + v[11] * x3 + v[16] * x4 + v[21] * x5;
669       s3 -= v[2] * x1 + v[7] * x2 + v[12] * x3 + v[17] * x4 + v[22] * x5;
670       s4 -= v[3] * x1 + v[8] * x2 + v[13] * x3 + v[18] * x4 + v[23] * x5;
671       s5 -= v[4] * x1 + v[9] * x2 + v[14] * x3 + v[19] * x4 + v[24] * x5;
672       v += 25;
673     }
674     idx        = 5 * i;
675     t[idx]     = s1;
676     t[1 + idx] = s2;
677     t[2 + idx] = s3;
678     t[3 + idx] = s4;
679     t[4 + idx] = s5;
680   }
681   /* backward solve the upper triangular */
682   for (i = n - 1; i >= 0; i--) {
683     v   = aa + 25 * (adiag[i + 1] + 1);
684     vi  = aj + adiag[i + 1] + 1;
685     nz  = adiag[i] - adiag[i + 1] - 1;
686     idt = 5 * i;
687     s1  = t[idt];
688     s2  = t[1 + idt];
689     s3  = t[2 + idt];
690     s4  = t[3 + idt];
691     s5  = t[4 + idt];
692     for (m = 0; m < nz; m++) {
693       idx = 5 * vi[m];
694       x1  = t[idx];
695       x2  = t[1 + idx];
696       x3  = t[2 + idx];
697       x4  = t[3 + idx];
698       x5  = t[4 + idx];
699       s1 -= v[0] * x1 + v[5] * x2 + v[10] * x3 + v[15] * x4 + v[20] * x5;
700       s2 -= v[1] * x1 + v[6] * x2 + v[11] * x3 + v[16] * x4 + v[21] * x5;
701       s3 -= v[2] * x1 + v[7] * x2 + v[12] * x3 + v[17] * x4 + v[22] * x5;
702       s4 -= v[3] * x1 + v[8] * x2 + v[13] * x3 + v[18] * x4 + v[23] * x5;
703       s5 -= v[4] * x1 + v[9] * x2 + v[14] * x3 + v[19] * x4 + v[24] * x5;
704       v += 25;
705     }
706     idc    = 5 * c[i];
707     x[idc] = t[idt] = v[0] * s1 + v[5] * s2 + v[10] * s3 + v[15] * s4 + v[20] * s5;
708     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[6] * s2 + v[11] * s3 + v[16] * s4 + v[21] * s5;
709     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[7] * s2 + v[12] * s3 + v[17] * s4 + v[22] * s5;
710     x[3 + idc] = t[3 + idt] = v[3] * s1 + v[8] * s2 + v[13] * s3 + v[18] * s4 + v[23] * s5;
711     x[4 + idc] = t[4 + idt] = v[4] * s1 + v[9] * s2 + v[14] * s3 + v[19] * s4 + v[24] * s5;
712   }
713 
714   PetscCall(ISRestoreIndices(isrow, &rout));
715   PetscCall(ISRestoreIndices(iscol, &cout));
716   PetscCall(VecRestoreArrayRead(bb, &b));
717   PetscCall(VecRestoreArray(xx, &x));
718   PetscCall(PetscLogFlops(2.0 * 25 * (a->nz) - 5.0 * A->cmap->n));
719   PetscFunctionReturn(PETSC_SUCCESS);
720 }
721 
MatSolve_SeqBAIJ_4_inplace(Mat A,Vec bb,Vec xx)722 PetscErrorCode MatSolve_SeqBAIJ_4_inplace(Mat A, Vec bb, Vec xx)
723 {
724   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
725   IS                 iscol = a->col, isrow = a->row;
726   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j;
727   PetscInt           i, nz, idx, idt, idc;
728   const PetscInt    *r, *c, *diag = a->diag, *rout, *cout;
729   const MatScalar   *aa = a->a, *v;
730   PetscScalar       *x, s1, s2, s3, s4, x1, x2, x3, x4, *t;
731   const PetscScalar *b;
732 
733   PetscFunctionBegin;
734   PetscCall(VecGetArrayRead(bb, &b));
735   PetscCall(VecGetArray(xx, &x));
736   t = a->solve_work;
737 
738   PetscCall(ISGetIndices(isrow, &rout));
739   r = rout;
740   PetscCall(ISGetIndices(iscol, &cout));
741   c = cout + (n - 1);
742 
743   /* forward solve the lower triangular */
744   idx  = 4 * (*r++);
745   t[0] = b[idx];
746   t[1] = b[1 + idx];
747   t[2] = b[2 + idx];
748   t[3] = b[3 + idx];
749   for (i = 1; i < n; i++) {
750     v   = aa + 16 * ai[i];
751     vi  = aj + ai[i];
752     nz  = diag[i] - ai[i];
753     idx = 4 * (*r++);
754     s1  = b[idx];
755     s2  = b[1 + idx];
756     s3  = b[2 + idx];
757     s4  = b[3 + idx];
758     while (nz--) {
759       idx = 4 * (*vi++);
760       x1  = t[idx];
761       x2  = t[1 + idx];
762       x3  = t[2 + idx];
763       x4  = t[3 + idx];
764       s1 -= v[0] * x1 + v[4] * x2 + v[8] * x3 + v[12] * x4;
765       s2 -= v[1] * x1 + v[5] * x2 + v[9] * x3 + v[13] * x4;
766       s3 -= v[2] * x1 + v[6] * x2 + v[10] * x3 + v[14] * x4;
767       s4 -= v[3] * x1 + v[7] * x2 + v[11] * x3 + v[15] * x4;
768       v += 16;
769     }
770     idx        = 4 * i;
771     t[idx]     = s1;
772     t[1 + idx] = s2;
773     t[2 + idx] = s3;
774     t[3 + idx] = s4;
775   }
776   /* backward solve the upper triangular */
777   for (i = n - 1; i >= 0; i--) {
778     v   = aa + 16 * diag[i] + 16;
779     vi  = aj + diag[i] + 1;
780     nz  = ai[i + 1] - diag[i] - 1;
781     idt = 4 * i;
782     s1  = t[idt];
783     s2  = t[1 + idt];
784     s3  = t[2 + idt];
785     s4  = t[3 + idt];
786     while (nz--) {
787       idx = 4 * (*vi++);
788       x1  = t[idx];
789       x2  = t[1 + idx];
790       x3  = t[2 + idx];
791       x4  = t[3 + idx];
792       s1 -= v[0] * x1 + v[4] * x2 + v[8] * x3 + v[12] * x4;
793       s2 -= v[1] * x1 + v[5] * x2 + v[9] * x3 + v[13] * x4;
794       s3 -= v[2] * x1 + v[6] * x2 + v[10] * x3 + v[14] * x4;
795       s4 -= v[3] * x1 + v[7] * x2 + v[11] * x3 + v[15] * x4;
796       v += 16;
797     }
798     idc    = 4 * (*c--);
799     v      = aa + 16 * diag[i];
800     x[idc] = t[idt] = v[0] * s1 + v[4] * s2 + v[8] * s3 + v[12] * s4;
801     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[5] * s2 + v[9] * s3 + v[13] * s4;
802     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[6] * s2 + v[10] * s3 + v[14] * s4;
803     x[3 + idc] = t[3 + idt] = v[3] * s1 + v[7] * s2 + v[11] * s3 + v[15] * s4;
804   }
805 
806   PetscCall(ISRestoreIndices(isrow, &rout));
807   PetscCall(ISRestoreIndices(iscol, &cout));
808   PetscCall(VecRestoreArrayRead(bb, &b));
809   PetscCall(VecRestoreArray(xx, &x));
810   PetscCall(PetscLogFlops(2.0 * 16 * (a->nz) - 4.0 * A->cmap->n));
811   PetscFunctionReturn(PETSC_SUCCESS);
812 }
813 
MatSolve_SeqBAIJ_4(Mat A,Vec bb,Vec xx)814 PetscErrorCode MatSolve_SeqBAIJ_4(Mat A, Vec bb, Vec xx)
815 {
816   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
817   IS                 iscol = a->col, isrow = a->row;
818   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *adiag = a->diag;
819   PetscInt           i, nz, idx, idt, idc, m;
820   const PetscInt    *r, *c, *rout, *cout;
821   const MatScalar   *aa = a->a, *v;
822   PetscScalar       *x, s1, s2, s3, s4, x1, x2, x3, x4, *t;
823   const PetscScalar *b;
824 
825   PetscFunctionBegin;
826   PetscCall(VecGetArrayRead(bb, &b));
827   PetscCall(VecGetArray(xx, &x));
828   t = a->solve_work;
829 
830   PetscCall(ISGetIndices(isrow, &rout));
831   r = rout;
832   PetscCall(ISGetIndices(iscol, &cout));
833   c = cout;
834 
835   /* forward solve the lower triangular */
836   idx  = 4 * r[0];
837   t[0] = b[idx];
838   t[1] = b[1 + idx];
839   t[2] = b[2 + idx];
840   t[3] = b[3 + idx];
841   for (i = 1; i < n; i++) {
842     v   = aa + 16 * ai[i];
843     vi  = aj + ai[i];
844     nz  = ai[i + 1] - ai[i];
845     idx = 4 * r[i];
846     s1  = b[idx];
847     s2  = b[1 + idx];
848     s3  = b[2 + idx];
849     s4  = b[3 + idx];
850     for (m = 0; m < nz; m++) {
851       idx = 4 * vi[m];
852       x1  = t[idx];
853       x2  = t[1 + idx];
854       x3  = t[2 + idx];
855       x4  = t[3 + idx];
856       s1 -= v[0] * x1 + v[4] * x2 + v[8] * x3 + v[12] * x4;
857       s2 -= v[1] * x1 + v[5] * x2 + v[9] * x3 + v[13] * x4;
858       s3 -= v[2] * x1 + v[6] * x2 + v[10] * x3 + v[14] * x4;
859       s4 -= v[3] * x1 + v[7] * x2 + v[11] * x3 + v[15] * x4;
860       v += 16;
861     }
862     idx        = 4 * i;
863     t[idx]     = s1;
864     t[1 + idx] = s2;
865     t[2 + idx] = s3;
866     t[3 + idx] = s4;
867   }
868   /* backward solve the upper triangular */
869   for (i = n - 1; i >= 0; i--) {
870     v   = aa + 16 * (adiag[i + 1] + 1);
871     vi  = aj + adiag[i + 1] + 1;
872     nz  = adiag[i] - adiag[i + 1] - 1;
873     idt = 4 * i;
874     s1  = t[idt];
875     s2  = t[1 + idt];
876     s3  = t[2 + idt];
877     s4  = t[3 + idt];
878     for (m = 0; m < nz; m++) {
879       idx = 4 * vi[m];
880       x1  = t[idx];
881       x2  = t[1 + idx];
882       x3  = t[2 + idx];
883       x4  = t[3 + idx];
884       s1 -= v[0] * x1 + v[4] * x2 + v[8] * x3 + v[12] * x4;
885       s2 -= v[1] * x1 + v[5] * x2 + v[9] * x3 + v[13] * x4;
886       s3 -= v[2] * x1 + v[6] * x2 + v[10] * x3 + v[14] * x4;
887       s4 -= v[3] * x1 + v[7] * x2 + v[11] * x3 + v[15] * x4;
888       v += 16;
889     }
890     idc    = 4 * c[i];
891     x[idc] = t[idt] = v[0] * s1 + v[4] * s2 + v[8] * s3 + v[12] * s4;
892     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[5] * s2 + v[9] * s3 + v[13] * s4;
893     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[6] * s2 + v[10] * s3 + v[14] * s4;
894     x[3 + idc] = t[3 + idt] = v[3] * s1 + v[7] * s2 + v[11] * s3 + v[15] * s4;
895   }
896 
897   PetscCall(ISRestoreIndices(isrow, &rout));
898   PetscCall(ISRestoreIndices(iscol, &cout));
899   PetscCall(VecRestoreArrayRead(bb, &b));
900   PetscCall(VecRestoreArray(xx, &x));
901   PetscCall(PetscLogFlops(2.0 * 16 * (a->nz) - 4.0 * A->cmap->n));
902   PetscFunctionReturn(PETSC_SUCCESS);
903 }
904 
MatSolve_SeqBAIJ_3_inplace(Mat A,Vec bb,Vec xx)905 PetscErrorCode MatSolve_SeqBAIJ_3_inplace(Mat A, Vec bb, Vec xx)
906 {
907   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
908   IS                 iscol = a->col, isrow = a->row;
909   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j;
910   PetscInt           i, nz, idx, idt, idc;
911   const PetscInt    *r, *c, *diag = a->diag, *rout, *cout;
912   const MatScalar   *aa = a->a, *v;
913   PetscScalar       *x, s1, s2, s3, x1, x2, x3, *t;
914   const PetscScalar *b;
915 
916   PetscFunctionBegin;
917   PetscCall(VecGetArrayRead(bb, &b));
918   PetscCall(VecGetArray(xx, &x));
919   t = a->solve_work;
920 
921   PetscCall(ISGetIndices(isrow, &rout));
922   r = rout;
923   PetscCall(ISGetIndices(iscol, &cout));
924   c = cout + (n - 1);
925 
926   /* forward solve the lower triangular */
927   idx  = 3 * (*r++);
928   t[0] = b[idx];
929   t[1] = b[1 + idx];
930   t[2] = b[2 + idx];
931   for (i = 1; i < n; i++) {
932     v   = aa + 9 * ai[i];
933     vi  = aj + ai[i];
934     nz  = diag[i] - ai[i];
935     idx = 3 * (*r++);
936     s1  = b[idx];
937     s2  = b[1 + idx];
938     s3  = b[2 + idx];
939     while (nz--) {
940       idx = 3 * (*vi++);
941       x1  = t[idx];
942       x2  = t[1 + idx];
943       x3  = t[2 + idx];
944       s1 -= v[0] * x1 + v[3] * x2 + v[6] * x3;
945       s2 -= v[1] * x1 + v[4] * x2 + v[7] * x3;
946       s3 -= v[2] * x1 + v[5] * x2 + v[8] * x3;
947       v += 9;
948     }
949     idx        = 3 * i;
950     t[idx]     = s1;
951     t[1 + idx] = s2;
952     t[2 + idx] = s3;
953   }
954   /* backward solve the upper triangular */
955   for (i = n - 1; i >= 0; i--) {
956     v   = aa + 9 * diag[i] + 9;
957     vi  = aj + diag[i] + 1;
958     nz  = ai[i + 1] - diag[i] - 1;
959     idt = 3 * i;
960     s1  = t[idt];
961     s2  = t[1 + idt];
962     s3  = t[2 + idt];
963     while (nz--) {
964       idx = 3 * (*vi++);
965       x1  = t[idx];
966       x2  = t[1 + idx];
967       x3  = t[2 + idx];
968       s1 -= v[0] * x1 + v[3] * x2 + v[6] * x3;
969       s2 -= v[1] * x1 + v[4] * x2 + v[7] * x3;
970       s3 -= v[2] * x1 + v[5] * x2 + v[8] * x3;
971       v += 9;
972     }
973     idc    = 3 * (*c--);
974     v      = aa + 9 * diag[i];
975     x[idc] = t[idt] = v[0] * s1 + v[3] * s2 + v[6] * s3;
976     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[4] * s2 + v[7] * s3;
977     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[5] * s2 + v[8] * s3;
978   }
979   PetscCall(ISRestoreIndices(isrow, &rout));
980   PetscCall(ISRestoreIndices(iscol, &cout));
981   PetscCall(VecRestoreArrayRead(bb, &b));
982   PetscCall(VecRestoreArray(xx, &x));
983   PetscCall(PetscLogFlops(2.0 * 9 * (a->nz) - 3.0 * A->cmap->n));
984   PetscFunctionReturn(PETSC_SUCCESS);
985 }
986 
MatSolve_SeqBAIJ_3(Mat A,Vec bb,Vec xx)987 PetscErrorCode MatSolve_SeqBAIJ_3(Mat A, Vec bb, Vec xx)
988 {
989   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
990   IS                 iscol = a->col, isrow = a->row;
991   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *adiag = a->diag;
992   PetscInt           i, nz, idx, idt, idc, m;
993   const PetscInt    *r, *c, *rout, *cout;
994   const MatScalar   *aa = a->a, *v;
995   PetscScalar       *x, s1, s2, s3, x1, x2, x3, *t;
996   const PetscScalar *b;
997 
998   PetscFunctionBegin;
999   PetscCall(VecGetArrayRead(bb, &b));
1000   PetscCall(VecGetArray(xx, &x));
1001   t = a->solve_work;
1002 
1003   PetscCall(ISGetIndices(isrow, &rout));
1004   r = rout;
1005   PetscCall(ISGetIndices(iscol, &cout));
1006   c = cout;
1007 
1008   /* forward solve the lower triangular */
1009   idx  = 3 * r[0];
1010   t[0] = b[idx];
1011   t[1] = b[1 + idx];
1012   t[2] = b[2 + idx];
1013   for (i = 1; i < n; i++) {
1014     v   = aa + 9 * ai[i];
1015     vi  = aj + ai[i];
1016     nz  = ai[i + 1] - ai[i];
1017     idx = 3 * r[i];
1018     s1  = b[idx];
1019     s2  = b[1 + idx];
1020     s3  = b[2 + idx];
1021     for (m = 0; m < nz; m++) {
1022       idx = 3 * vi[m];
1023       x1  = t[idx];
1024       x2  = t[1 + idx];
1025       x3  = t[2 + idx];
1026       s1 -= v[0] * x1 + v[3] * x2 + v[6] * x3;
1027       s2 -= v[1] * x1 + v[4] * x2 + v[7] * x3;
1028       s3 -= v[2] * x1 + v[5] * x2 + v[8] * x3;
1029       v += 9;
1030     }
1031     idx        = 3 * i;
1032     t[idx]     = s1;
1033     t[1 + idx] = s2;
1034     t[2 + idx] = s3;
1035   }
1036   /* backward solve the upper triangular */
1037   for (i = n - 1; i >= 0; i--) {
1038     v   = aa + 9 * (adiag[i + 1] + 1);
1039     vi  = aj + adiag[i + 1] + 1;
1040     nz  = adiag[i] - adiag[i + 1] - 1;
1041     idt = 3 * i;
1042     s1  = t[idt];
1043     s2  = t[1 + idt];
1044     s3  = t[2 + idt];
1045     for (m = 0; m < nz; m++) {
1046       idx = 3 * vi[m];
1047       x1  = t[idx];
1048       x2  = t[1 + idx];
1049       x3  = t[2 + idx];
1050       s1 -= v[0] * x1 + v[3] * x2 + v[6] * x3;
1051       s2 -= v[1] * x1 + v[4] * x2 + v[7] * x3;
1052       s3 -= v[2] * x1 + v[5] * x2 + v[8] * x3;
1053       v += 9;
1054     }
1055     idc    = 3 * c[i];
1056     x[idc] = t[idt] = v[0] * s1 + v[3] * s2 + v[6] * s3;
1057     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[4] * s2 + v[7] * s3;
1058     x[2 + idc] = t[2 + idt] = v[2] * s1 + v[5] * s2 + v[8] * s3;
1059   }
1060   PetscCall(ISRestoreIndices(isrow, &rout));
1061   PetscCall(ISRestoreIndices(iscol, &cout));
1062   PetscCall(VecRestoreArrayRead(bb, &b));
1063   PetscCall(VecRestoreArray(xx, &x));
1064   PetscCall(PetscLogFlops(2.0 * 9 * (a->nz) - 3.0 * A->cmap->n));
1065   PetscFunctionReturn(PETSC_SUCCESS);
1066 }
1067 
MatSolve_SeqBAIJ_2_inplace(Mat A,Vec bb,Vec xx)1068 PetscErrorCode MatSolve_SeqBAIJ_2_inplace(Mat A, Vec bb, Vec xx)
1069 {
1070   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
1071   IS                 iscol = a->col, isrow = a->row;
1072   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j;
1073   PetscInt           i, nz, idx, idt, idc;
1074   const PetscInt    *r, *c, *diag = a->diag, *rout, *cout;
1075   const MatScalar   *aa = a->a, *v;
1076   PetscScalar       *x, s1, s2, x1, x2, *t;
1077   const PetscScalar *b;
1078 
1079   PetscFunctionBegin;
1080   PetscCall(VecGetArrayRead(bb, &b));
1081   PetscCall(VecGetArray(xx, &x));
1082   t = a->solve_work;
1083 
1084   PetscCall(ISGetIndices(isrow, &rout));
1085   r = rout;
1086   PetscCall(ISGetIndices(iscol, &cout));
1087   c = cout + (n - 1);
1088 
1089   /* forward solve the lower triangular */
1090   idx  = 2 * (*r++);
1091   t[0] = b[idx];
1092   t[1] = b[1 + idx];
1093   for (i = 1; i < n; i++) {
1094     v   = aa + 4 * ai[i];
1095     vi  = aj + ai[i];
1096     nz  = diag[i] - ai[i];
1097     idx = 2 * (*r++);
1098     s1  = b[idx];
1099     s2  = b[1 + idx];
1100     while (nz--) {
1101       idx = 2 * (*vi++);
1102       x1  = t[idx];
1103       x2  = t[1 + idx];
1104       s1 -= v[0] * x1 + v[2] * x2;
1105       s2 -= v[1] * x1 + v[3] * x2;
1106       v += 4;
1107     }
1108     idx        = 2 * i;
1109     t[idx]     = s1;
1110     t[1 + idx] = s2;
1111   }
1112   /* backward solve the upper triangular */
1113   for (i = n - 1; i >= 0; i--) {
1114     v   = aa + 4 * diag[i] + 4;
1115     vi  = aj + diag[i] + 1;
1116     nz  = ai[i + 1] - diag[i] - 1;
1117     idt = 2 * i;
1118     s1  = t[idt];
1119     s2  = t[1 + idt];
1120     while (nz--) {
1121       idx = 2 * (*vi++);
1122       x1  = t[idx];
1123       x2  = t[1 + idx];
1124       s1 -= v[0] * x1 + v[2] * x2;
1125       s2 -= v[1] * x1 + v[3] * x2;
1126       v += 4;
1127     }
1128     idc    = 2 * (*c--);
1129     v      = aa + 4 * diag[i];
1130     x[idc] = t[idt] = v[0] * s1 + v[2] * s2;
1131     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[3] * s2;
1132   }
1133   PetscCall(ISRestoreIndices(isrow, &rout));
1134   PetscCall(ISRestoreIndices(iscol, &cout));
1135   PetscCall(VecRestoreArrayRead(bb, &b));
1136   PetscCall(VecRestoreArray(xx, &x));
1137   PetscCall(PetscLogFlops(2.0 * 4 * (a->nz) - 2.0 * A->cmap->n));
1138   PetscFunctionReturn(PETSC_SUCCESS);
1139 }
1140 
MatSolve_SeqBAIJ_2(Mat A,Vec bb,Vec xx)1141 PetscErrorCode MatSolve_SeqBAIJ_2(Mat A, Vec bb, Vec xx)
1142 {
1143   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
1144   IS                 iscol = a->col, isrow = a->row;
1145   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *adiag = a->diag;
1146   PetscInt           i, nz, idx, jdx, idt, idc, m;
1147   const PetscInt    *r, *c, *rout, *cout;
1148   const MatScalar   *aa = a->a, *v;
1149   PetscScalar       *x, s1, s2, x1, x2, *t;
1150   const PetscScalar *b;
1151 
1152   PetscFunctionBegin;
1153   PetscCall(VecGetArrayRead(bb, &b));
1154   PetscCall(VecGetArray(xx, &x));
1155   t = a->solve_work;
1156 
1157   PetscCall(ISGetIndices(isrow, &rout));
1158   r = rout;
1159   PetscCall(ISGetIndices(iscol, &cout));
1160   c = cout;
1161 
1162   /* forward solve the lower triangular */
1163   idx  = 2 * r[0];
1164   t[0] = b[idx];
1165   t[1] = b[1 + idx];
1166   for (i = 1; i < n; i++) {
1167     v   = aa + 4 * ai[i];
1168     vi  = aj + ai[i];
1169     nz  = ai[i + 1] - ai[i];
1170     idx = 2 * r[i];
1171     s1  = b[idx];
1172     s2  = b[1 + idx];
1173     for (m = 0; m < nz; m++) {
1174       jdx = 2 * vi[m];
1175       x1  = t[jdx];
1176       x2  = t[1 + jdx];
1177       s1 -= v[0] * x1 + v[2] * x2;
1178       s2 -= v[1] * x1 + v[3] * x2;
1179       v += 4;
1180     }
1181     idx        = 2 * i;
1182     t[idx]     = s1;
1183     t[1 + idx] = s2;
1184   }
1185   /* backward solve the upper triangular */
1186   for (i = n - 1; i >= 0; i--) {
1187     v   = aa + 4 * (adiag[i + 1] + 1);
1188     vi  = aj + adiag[i + 1] + 1;
1189     nz  = adiag[i] - adiag[i + 1] - 1;
1190     idt = 2 * i;
1191     s1  = t[idt];
1192     s2  = t[1 + idt];
1193     for (m = 0; m < nz; m++) {
1194       idx = 2 * vi[m];
1195       x1  = t[idx];
1196       x2  = t[1 + idx];
1197       s1 -= v[0] * x1 + v[2] * x2;
1198       s2 -= v[1] * x1 + v[3] * x2;
1199       v += 4;
1200     }
1201     idc    = 2 * c[i];
1202     x[idc] = t[idt] = v[0] * s1 + v[2] * s2;
1203     x[1 + idc] = t[1 + idt] = v[1] * s1 + v[3] * s2;
1204   }
1205   PetscCall(ISRestoreIndices(isrow, &rout));
1206   PetscCall(ISRestoreIndices(iscol, &cout));
1207   PetscCall(VecRestoreArrayRead(bb, &b));
1208   PetscCall(VecRestoreArray(xx, &x));
1209   PetscCall(PetscLogFlops(2.0 * 4 * (a->nz) - 2.0 * A->cmap->n));
1210   PetscFunctionReturn(PETSC_SUCCESS);
1211 }
1212 
MatSolve_SeqBAIJ_1_inplace(Mat A,Vec bb,Vec xx)1213 PetscErrorCode MatSolve_SeqBAIJ_1_inplace(Mat A, Vec bb, Vec xx)
1214 {
1215   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
1216   IS                 iscol = a->col, isrow = a->row;
1217   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j;
1218   PetscInt           i, nz;
1219   const PetscInt    *r, *c, *diag = a->diag, *rout, *cout;
1220   const MatScalar   *aa = a->a, *v;
1221   PetscScalar       *x, s1, *t;
1222   const PetscScalar *b;
1223 
1224   PetscFunctionBegin;
1225   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
1226 
1227   PetscCall(VecGetArrayRead(bb, &b));
1228   PetscCall(VecGetArray(xx, &x));
1229   t = a->solve_work;
1230 
1231   PetscCall(ISGetIndices(isrow, &rout));
1232   r = rout;
1233   PetscCall(ISGetIndices(iscol, &cout));
1234   c = cout + (n - 1);
1235 
1236   /* forward solve the lower triangular */
1237   t[0] = b[*r++];
1238   for (i = 1; i < n; i++) {
1239     v  = aa + ai[i];
1240     vi = aj + ai[i];
1241     nz = diag[i] - ai[i];
1242     s1 = b[*r++];
1243     while (nz--) s1 -= (*v++) * t[*vi++];
1244     t[i] = s1;
1245   }
1246   /* backward solve the upper triangular */
1247   for (i = n - 1; i >= 0; i--) {
1248     v  = aa + diag[i] + 1;
1249     vi = aj + diag[i] + 1;
1250     nz = ai[i + 1] - diag[i] - 1;
1251     s1 = t[i];
1252     while (nz--) s1 -= (*v++) * t[*vi++];
1253     x[*c--] = t[i] = aa[diag[i]] * s1;
1254   }
1255 
1256   PetscCall(ISRestoreIndices(isrow, &rout));
1257   PetscCall(ISRestoreIndices(iscol, &cout));
1258   PetscCall(VecRestoreArrayRead(bb, &b));
1259   PetscCall(VecRestoreArray(xx, &x));
1260   PetscCall(PetscLogFlops(2.0 * 1 * (a->nz) - A->cmap->n));
1261   PetscFunctionReturn(PETSC_SUCCESS);
1262 }
1263 
MatSolve_SeqBAIJ_1(Mat A,Vec bb,Vec xx)1264 PetscErrorCode MatSolve_SeqBAIJ_1(Mat A, Vec bb, Vec xx)
1265 {
1266   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
1267   IS                 iscol = a->col, isrow = a->row;
1268   PetscInt           i, n = a->mbs, *vi, *ai = a->i, *aj = a->j, *adiag = a->diag, nz;
1269   const PetscInt    *rout, *cout, *r, *c;
1270   PetscScalar       *x, *tmp, sum;
1271   const PetscScalar *b;
1272   const MatScalar   *aa = a->a, *v;
1273 
1274   PetscFunctionBegin;
1275   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
1276 
1277   PetscCall(VecGetArrayRead(bb, &b));
1278   PetscCall(VecGetArray(xx, &x));
1279   tmp = a->solve_work;
1280 
1281   PetscCall(ISGetIndices(isrow, &rout));
1282   r = rout;
1283   PetscCall(ISGetIndices(iscol, &cout));
1284   c = cout;
1285 
1286   /* forward solve the lower triangular */
1287   tmp[0] = b[r[0]];
1288   v      = aa;
1289   vi     = aj;
1290   for (i = 1; i < n; i++) {
1291     nz  = ai[i + 1] - ai[i];
1292     sum = b[r[i]];
1293     PetscSparseDenseMinusDot(sum, tmp, v, vi, nz);
1294     tmp[i] = sum;
1295     v += nz;
1296     vi += nz;
1297   }
1298 
1299   /* backward solve the upper triangular */
1300   for (i = n - 1; i >= 0; i--) {
1301     v   = aa + adiag[i + 1] + 1;
1302     vi  = aj + adiag[i + 1] + 1;
1303     nz  = adiag[i] - adiag[i + 1] - 1;
1304     sum = tmp[i];
1305     PetscSparseDenseMinusDot(sum, tmp, v, vi, nz);
1306     x[c[i]] = tmp[i] = sum * v[nz]; /* v[nz] = aa[adiag[i]] */
1307   }
1308 
1309   PetscCall(ISRestoreIndices(isrow, &rout));
1310   PetscCall(ISRestoreIndices(iscol, &cout));
1311   PetscCall(VecRestoreArrayRead(bb, &b));
1312   PetscCall(VecRestoreArray(xx, &x));
1313   PetscCall(PetscLogFlops(2.0 * a->nz - A->cmap->n));
1314   PetscFunctionReturn(PETSC_SUCCESS);
1315 }
1316