/* Source: petsc/src/mat/impls/baij/seq/baijsolvtran4.c (revision 31d78bcd2b98084dc1368b20eb1129c8b9fb39fe) */
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
MatSolveTranspose_SeqBAIJ_4_inplace(Mat A,Vec bb,Vec xx)4 PetscErrorCode MatSolveTranspose_SeqBAIJ_4_inplace(Mat A, Vec bb, Vec xx)
5 {
6   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
7   IS                 iscol = a->col, isrow = a->row;
8   const PetscInt    *r, *c, *rout, *cout;
9   const PetscInt    *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j;
10   PetscInt           i, nz, idx, idt, ii, ic, ir, oidx;
11   const MatScalar   *aa = a->a, *v;
12   PetscScalar        s1, s2, s3, s4, x1, x2, x3, x4, *x, *t;
13   const PetscScalar *b;
14 
15   PetscFunctionBegin;
16   PetscCall(VecGetArrayRead(bb, &b));
17   PetscCall(VecGetArray(xx, &x));
18   t = a->solve_work;
19 
20   PetscCall(ISGetIndices(isrow, &rout));
21   r = rout;
22   PetscCall(ISGetIndices(iscol, &cout));
23   c = cout;
24 
25   /* copy the b into temp work space according to permutation */
26   ii = 0;
27   for (i = 0; i < n; i++) {
28     ic        = 4 * c[i];
29     t[ii]     = b[ic];
30     t[ii + 1] = b[ic + 1];
31     t[ii + 2] = b[ic + 2];
32     t[ii + 3] = b[ic + 3];
33     ii += 4;
34   }
35 
36   /* forward solve the U^T */
37   idx = 0;
38   for (i = 0; i < n; i++) {
39     v = aa + 16 * diag[i];
40     /* multiply by the inverse of the block diagonal */
41     x1 = t[idx];
42     x2 = t[1 + idx];
43     x3 = t[2 + idx];
44     x4 = t[3 + idx];
45     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4;
46     s2 = v[4] * x1 + v[5] * x2 + v[6] * x3 + v[7] * x4;
47     s3 = v[8] * x1 + v[9] * x2 + v[10] * x3 + v[11] * x4;
48     s4 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4;
49     v += 16;
50 
51     vi = aj + diag[i] + 1;
52     nz = ai[i + 1] - diag[i] - 1;
53     while (nz--) {
54       oidx = 4 * (*vi++);
55       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4;
56       t[oidx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4;
57       t[oidx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4;
58       t[oidx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4;
59       v += 16;
60     }
61     t[idx]     = s1;
62     t[1 + idx] = s2;
63     t[2 + idx] = s3;
64     t[3 + idx] = s4;
65     idx += 4;
66   }
67   /* backward solve the L^T */
68   for (i = n - 1; i >= 0; i--) {
69     v   = aa + 16 * diag[i] - 16;
70     vi  = aj + diag[i] - 1;
71     nz  = diag[i] - ai[i];
72     idt = 4 * i;
73     s1  = t[idt];
74     s2  = t[1 + idt];
75     s3  = t[2 + idt];
76     s4  = t[3 + idt];
77     while (nz--) {
78       idx = 4 * (*vi--);
79       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4;
80       t[idx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4;
81       t[idx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4;
82       t[idx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4;
83       v -= 16;
84     }
85   }
86 
87   /* copy t into x according to permutation */
88   ii = 0;
89   for (i = 0; i < n; i++) {
90     ir        = 4 * r[i];
91     x[ir]     = t[ii];
92     x[ir + 1] = t[ii + 1];
93     x[ir + 2] = t[ii + 2];
94     x[ir + 3] = t[ii + 3];
95     ii += 4;
96   }
97 
98   PetscCall(ISRestoreIndices(isrow, &rout));
99   PetscCall(ISRestoreIndices(iscol, &cout));
100   PetscCall(VecRestoreArrayRead(bb, &b));
101   PetscCall(VecRestoreArray(xx, &x));
102   PetscCall(PetscLogFlops(2.0 * 16 * (a->nz) - 4.0 * A->cmap->n));
103   PetscFunctionReturn(PETSC_SUCCESS);
104 }
105 
MatSolveTranspose_SeqBAIJ_4(Mat A,Vec bb,Vec xx)106 PetscErrorCode MatSolveTranspose_SeqBAIJ_4(Mat A, Vec bb, Vec xx)
107 {
108   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
109   IS                 iscol = a->col, isrow = a->row;
110   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag;
111   const PetscInt    *r, *c, *rout, *cout;
112   PetscInt           nz, idx, idt, j, i, oidx, ii, ic, ir;
113   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
114   const MatScalar   *aa = a->a, *v;
115   PetscScalar        s1, s2, s3, s4, x1, x2, x3, x4, *x, *t;
116   const PetscScalar *b;
117 
118   PetscFunctionBegin;
119   PetscCall(VecGetArrayRead(bb, &b));
120   PetscCall(VecGetArray(xx, &x));
121   t = a->solve_work;
122 
123   PetscCall(ISGetIndices(isrow, &rout));
124   r = rout;
125   PetscCall(ISGetIndices(iscol, &cout));
126   c = cout;
127 
128   /* copy b into temp work space according to permutation */
129   for (i = 0; i < n; i++) {
130     ii        = bs * i;
131     ic        = bs * c[i];
132     t[ii]     = b[ic];
133     t[ii + 1] = b[ic + 1];
134     t[ii + 2] = b[ic + 2];
135     t[ii + 3] = b[ic + 3];
136   }
137 
138   /* forward solve the U^T */
139   idx = 0;
140   for (i = 0; i < n; i++) {
141     v = aa + bs2 * diag[i];
142     /* multiply by the inverse of the block diagonal */
143     x1 = t[idx];
144     x2 = t[1 + idx];
145     x3 = t[2 + idx];
146     x4 = t[3 + idx];
147     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4;
148     s2 = v[4] * x1 + v[5] * x2 + v[6] * x3 + v[7] * x4;
149     s3 = v[8] * x1 + v[9] * x2 + v[10] * x3 + v[11] * x4;
150     s4 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4;
151     v -= bs2;
152 
153     vi = aj + diag[i] - 1;
154     nz = diag[i] - diag[i + 1] - 1;
155     for (j = 0; j > -nz; j--) {
156       oidx = bs * vi[j];
157       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4;
158       t[oidx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4;
159       t[oidx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4;
160       t[oidx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4;
161       v -= bs2;
162     }
163     t[idx]     = s1;
164     t[1 + idx] = s2;
165     t[2 + idx] = s3;
166     t[3 + idx] = s4;
167     idx += bs;
168   }
169   /* backward solve the L^T */
170   for (i = n - 1; i >= 0; i--) {
171     v   = aa + bs2 * ai[i];
172     vi  = aj + ai[i];
173     nz  = ai[i + 1] - ai[i];
174     idt = bs * i;
175     s1  = t[idt];
176     s2  = t[1 + idt];
177     s3  = t[2 + idt];
178     s4  = t[3 + idt];
179     for (j = 0; j < nz; j++) {
180       idx = bs * vi[j];
181       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4;
182       t[idx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4;
183       t[idx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4;
184       t[idx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4;
185       v += bs2;
186     }
187   }
188 
189   /* copy t into x according to permutation */
190   for (i = 0; i < n; i++) {
191     ii        = bs * i;
192     ir        = bs * r[i];
193     x[ir]     = t[ii];
194     x[ir + 1] = t[ii + 1];
195     x[ir + 2] = t[ii + 2];
196     x[ir + 3] = t[ii + 3];
197   }
198 
199   PetscCall(ISRestoreIndices(isrow, &rout));
200   PetscCall(ISRestoreIndices(iscol, &cout));
201   PetscCall(VecRestoreArrayRead(bb, &b));
202   PetscCall(VecRestoreArray(xx, &x));
203   PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n));
204   PetscFunctionReturn(PETSC_SUCCESS);
205 }
206