xref: /petsc/src/mat/impls/baij/seq/baijsolvtran5.c (revision 31d78bcd2b98084dc1368b20eb1129c8b9fb39fe)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
MatSolveTranspose_SeqBAIJ_5_inplace(Mat A,Vec bb,Vec xx)4 PetscErrorCode MatSolveTranspose_SeqBAIJ_5_inplace(Mat A, Vec bb, Vec xx)
5 {
6   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
7   IS                 iscol = a->col, isrow = a->row;
8   const PetscInt    *r, *c, *rout, *cout;
9   const PetscInt    *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j;
10   PetscInt           i, nz, idx, idt, ii, ic, ir, oidx;
11   const MatScalar   *aa = a->a, *v;
12   PetscScalar        s1, s2, s3, s4, s5, x1, x2, x3, x4, x5, *x, *t;
13   const PetscScalar *b;
14 
15   PetscFunctionBegin;
16   PetscCall(VecGetArrayRead(bb, &b));
17   PetscCall(VecGetArray(xx, &x));
18   t = a->solve_work;
19 
20   PetscCall(ISGetIndices(isrow, &rout));
21   r = rout;
22   PetscCall(ISGetIndices(iscol, &cout));
23   c = cout;
24 
25   /* copy the b into temp work space according to permutation */
26   ii = 0;
27   for (i = 0; i < n; i++) {
28     ic        = 5 * c[i];
29     t[ii]     = b[ic];
30     t[ii + 1] = b[ic + 1];
31     t[ii + 2] = b[ic + 2];
32     t[ii + 3] = b[ic + 3];
33     t[ii + 4] = b[ic + 4];
34     ii += 5;
35   }
36 
37   /* forward solve the U^T */
38   idx = 0;
39   for (i = 0; i < n; i++) {
40     v = aa + 25 * diag[i];
41     /* multiply by the inverse of the block diagonal */
42     x1 = t[idx];
43     x2 = t[1 + idx];
44     x3 = t[2 + idx];
45     x4 = t[3 + idx];
46     x5 = t[4 + idx];
47     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5;
48     s2 = v[5] * x1 + v[6] * x2 + v[7] * x3 + v[8] * x4 + v[9] * x5;
49     s3 = v[10] * x1 + v[11] * x2 + v[12] * x3 + v[13] * x4 + v[14] * x5;
50     s4 = v[15] * x1 + v[16] * x2 + v[17] * x3 + v[18] * x4 + v[19] * x5;
51     s5 = v[20] * x1 + v[21] * x2 + v[22] * x3 + v[23] * x4 + v[24] * x5;
52     v += 25;
53 
54     vi = aj + diag[i] + 1;
55     nz = ai[i + 1] - diag[i] - 1;
56     while (nz--) {
57       oidx = 5 * (*vi++);
58       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5;
59       t[oidx + 1] -= v[5] * s1 + v[6] * s2 + v[7] * s3 + v[8] * s4 + v[9] * s5;
60       t[oidx + 2] -= v[10] * s1 + v[11] * s2 + v[12] * s3 + v[13] * s4 + v[14] * s5;
61       t[oidx + 3] -= v[15] * s1 + v[16] * s2 + v[17] * s3 + v[18] * s4 + v[19] * s5;
62       t[oidx + 4] -= v[20] * s1 + v[21] * s2 + v[22] * s3 + v[23] * s4 + v[24] * s5;
63       v += 25;
64     }
65     t[idx]     = s1;
66     t[1 + idx] = s2;
67     t[2 + idx] = s3;
68     t[3 + idx] = s4;
69     t[4 + idx] = s5;
70     idx += 5;
71   }
72   /* backward solve the L^T */
73   for (i = n - 1; i >= 0; i--) {
74     v   = aa + 25 * diag[i] - 25;
75     vi  = aj + diag[i] - 1;
76     nz  = diag[i] - ai[i];
77     idt = 5 * i;
78     s1  = t[idt];
79     s2  = t[1 + idt];
80     s3  = t[2 + idt];
81     s4  = t[3 + idt];
82     s5  = t[4 + idt];
83     while (nz--) {
84       idx = 5 * (*vi--);
85       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5;
86       t[idx + 1] -= v[5] * s1 + v[6] * s2 + v[7] * s3 + v[8] * s4 + v[9] * s5;
87       t[idx + 2] -= v[10] * s1 + v[11] * s2 + v[12] * s3 + v[13] * s4 + v[14] * s5;
88       t[idx + 3] -= v[15] * s1 + v[16] * s2 + v[17] * s3 + v[18] * s4 + v[19] * s5;
89       t[idx + 4] -= v[20] * s1 + v[21] * s2 + v[22] * s3 + v[23] * s4 + v[24] * s5;
90       v -= 25;
91     }
92   }
93 
94   /* copy t into x according to permutation */
95   ii = 0;
96   for (i = 0; i < n; i++) {
97     ir        = 5 * r[i];
98     x[ir]     = t[ii];
99     x[ir + 1] = t[ii + 1];
100     x[ir + 2] = t[ii + 2];
101     x[ir + 3] = t[ii + 3];
102     x[ir + 4] = t[ii + 4];
103     ii += 5;
104   }
105 
106   PetscCall(ISRestoreIndices(isrow, &rout));
107   PetscCall(ISRestoreIndices(iscol, &cout));
108   PetscCall(VecRestoreArrayRead(bb, &b));
109   PetscCall(VecRestoreArray(xx, &x));
110   PetscCall(PetscLogFlops(2.0 * 25 * (a->nz) - 5.0 * A->cmap->n));
111   PetscFunctionReturn(PETSC_SUCCESS);
112 }
113 
MatSolveTranspose_SeqBAIJ_5(Mat A,Vec bb,Vec xx)114 PetscErrorCode MatSolveTranspose_SeqBAIJ_5(Mat A, Vec bb, Vec xx)
115 {
116   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
117   IS                 iscol = a->col, isrow = a->row;
118   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag;
119   const PetscInt    *r, *c, *rout, *cout;
120   PetscInt           nz, idx, idt, j, i, oidx, ii, ic, ir;
121   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
122   const MatScalar   *aa = a->a, *v;
123   PetscScalar        s1, s2, s3, s4, s5, x1, x2, x3, x4, x5, *x, *t;
124   const PetscScalar *b;
125 
126   PetscFunctionBegin;
127   PetscCall(VecGetArrayRead(bb, &b));
128   PetscCall(VecGetArray(xx, &x));
129   t = a->solve_work;
130 
131   PetscCall(ISGetIndices(isrow, &rout));
132   r = rout;
133   PetscCall(ISGetIndices(iscol, &cout));
134   c = cout;
135 
136   /* copy b into temp work space according to permutation */
137   for (i = 0; i < n; i++) {
138     ii        = bs * i;
139     ic        = bs * c[i];
140     t[ii]     = b[ic];
141     t[ii + 1] = b[ic + 1];
142     t[ii + 2] = b[ic + 2];
143     t[ii + 3] = b[ic + 3];
144     t[ii + 4] = b[ic + 4];
145   }
146 
147   /* forward solve the U^T */
148   idx = 0;
149   for (i = 0; i < n; i++) {
150     v = aa + bs2 * diag[i];
151     /* multiply by the inverse of the block diagonal */
152     x1 = t[idx];
153     x2 = t[1 + idx];
154     x3 = t[2 + idx];
155     x4 = t[3 + idx];
156     x5 = t[4 + idx];
157     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5;
158     s2 = v[5] * x1 + v[6] * x2 + v[7] * x3 + v[8] * x4 + v[9] * x5;
159     s3 = v[10] * x1 + v[11] * x2 + v[12] * x3 + v[13] * x4 + v[14] * x5;
160     s4 = v[15] * x1 + v[16] * x2 + v[17] * x3 + v[18] * x4 + v[19] * x5;
161     s5 = v[20] * x1 + v[21] * x2 + v[22] * x3 + v[23] * x4 + v[24] * x5;
162     v -= bs2;
163 
164     vi = aj + diag[i] - 1;
165     nz = diag[i] - diag[i + 1] - 1;
166     for (j = 0; j > -nz; j--) {
167       oidx = bs * vi[j];
168       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5;
169       t[oidx + 1] -= v[5] * s1 + v[6] * s2 + v[7] * s3 + v[8] * s4 + v[9] * s5;
170       t[oidx + 2] -= v[10] * s1 + v[11] * s2 + v[12] * s3 + v[13] * s4 + v[14] * s5;
171       t[oidx + 3] -= v[15] * s1 + v[16] * s2 + v[17] * s3 + v[18] * s4 + v[19] * s5;
172       t[oidx + 4] -= v[20] * s1 + v[21] * s2 + v[22] * s3 + v[23] * s4 + v[24] * s5;
173       v -= bs2;
174     }
175     t[idx]     = s1;
176     t[1 + idx] = s2;
177     t[2 + idx] = s3;
178     t[3 + idx] = s4;
179     t[4 + idx] = s5;
180     idx += bs;
181   }
182   /* backward solve the L^T */
183   for (i = n - 1; i >= 0; i--) {
184     v   = aa + bs2 * ai[i];
185     vi  = aj + ai[i];
186     nz  = ai[i + 1] - ai[i];
187     idt = bs * i;
188     s1  = t[idt];
189     s2  = t[1 + idt];
190     s3  = t[2 + idt];
191     s4  = t[3 + idt];
192     s5  = t[4 + idt];
193     for (j = 0; j < nz; j++) {
194       idx = bs * vi[j];
195       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5;
196       t[idx + 1] -= v[5] * s1 + v[6] * s2 + v[7] * s3 + v[8] * s4 + v[9] * s5;
197       t[idx + 2] -= v[10] * s1 + v[11] * s2 + v[12] * s3 + v[13] * s4 + v[14] * s5;
198       t[idx + 3] -= v[15] * s1 + v[16] * s2 + v[17] * s3 + v[18] * s4 + v[19] * s5;
199       t[idx + 4] -= v[20] * s1 + v[21] * s2 + v[22] * s3 + v[23] * s4 + v[24] * s5;
200       v += bs2;
201     }
202   }
203 
204   /* copy t into x according to permutation */
205   for (i = 0; i < n; i++) {
206     ii        = bs * i;
207     ir        = bs * r[i];
208     x[ir]     = t[ii];
209     x[ir + 1] = t[ii + 1];
210     x[ir + 2] = t[ii + 2];
211     x[ir + 3] = t[ii + 3];
212     x[ir + 4] = t[ii + 4];
213   }
214 
215   PetscCall(ISRestoreIndices(isrow, &rout));
216   PetscCall(ISRestoreIndices(iscol, &cout));
217   PetscCall(VecRestoreArrayRead(bb, &b));
218   PetscCall(VecRestoreArray(xx, &x));
219   PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n));
220   PetscFunctionReturn(PETSC_SUCCESS);
221 }
222