xref: /petsc/src/mat/impls/baij/seq/baijsolvtran2.c (revision 58d68138c660dfb4e9f5b03334792cd4f2ffd7cc)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
4 PetscErrorCode MatSolveTranspose_SeqBAIJ_2_inplace(Mat A, Vec bb, Vec xx) {
5   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
6   IS                 iscol = a->col, isrow = a->row;
7   const PetscInt    *r, *c, *rout, *cout;
8   const PetscInt    *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j;
9   PetscInt           i, nz, idx, idt, ii, ic, ir, oidx;
10   const MatScalar   *aa = a->a, *v;
11   PetscScalar        s1, s2, x1, x2, *x, *t;
12   const PetscScalar *b;
13 
14   PetscFunctionBegin;
15   PetscCall(VecGetArrayRead(bb, &b));
16   PetscCall(VecGetArray(xx, &x));
17   t = a->solve_work;
18 
19   PetscCall(ISGetIndices(isrow, &rout));
20   r = rout;
21   PetscCall(ISGetIndices(iscol, &cout));
22   c = cout;
23 
24   /* copy the b into temp work space according to permutation */
25   ii = 0;
26   for (i = 0; i < n; i++) {
27     ic        = 2 * c[i];
28     t[ii]     = b[ic];
29     t[ii + 1] = b[ic + 1];
30     ii += 2;
31   }
32 
33   /* forward solve the U^T */
34   idx = 0;
35   for (i = 0; i < n; i++) {
36     v  = aa + 4 * diag[i];
37     /* multiply by the inverse of the block diagonal */
38     x1 = t[idx];
39     x2 = t[1 + idx];
40     s1 = v[0] * x1 + v[1] * x2;
41     s2 = v[2] * x1 + v[3] * x2;
42     v += 4;
43 
44     vi = aj + diag[i] + 1;
45     nz = ai[i + 1] - diag[i] - 1;
46     while (nz--) {
47       oidx = 2 * (*vi++);
48       t[oidx] -= v[0] * s1 + v[1] * s2;
49       t[oidx + 1] -= v[2] * s1 + v[3] * s2;
50       v += 4;
51     }
52     t[idx]     = s1;
53     t[1 + idx] = s2;
54     idx += 2;
55   }
56   /* backward solve the L^T */
57   for (i = n - 1; i >= 0; i--) {
58     v   = aa + 4 * diag[i] - 4;
59     vi  = aj + diag[i] - 1;
60     nz  = diag[i] - ai[i];
61     idt = 2 * i;
62     s1  = t[idt];
63     s2  = t[1 + idt];
64     while (nz--) {
65       idx = 2 * (*vi--);
66       t[idx] -= v[0] * s1 + v[1] * s2;
67       t[idx + 1] -= v[2] * s1 + v[3] * s2;
68       v -= 4;
69     }
70   }
71 
72   /* copy t into x according to permutation */
73   ii = 0;
74   for (i = 0; i < n; i++) {
75     ir        = 2 * r[i];
76     x[ir]     = t[ii];
77     x[ir + 1] = t[ii + 1];
78     ii += 2;
79   }
80 
81   PetscCall(ISRestoreIndices(isrow, &rout));
82   PetscCall(ISRestoreIndices(iscol, &cout));
83   PetscCall(VecRestoreArrayRead(bb, &b));
84   PetscCall(VecRestoreArray(xx, &x));
85   PetscCall(PetscLogFlops(2.0 * 4 * (a->nz) - 2.0 * A->cmap->n));
86   PetscFunctionReturn(0);
87 }
88 
89 PetscErrorCode MatSolveTranspose_SeqBAIJ_2(Mat A, Vec bb, Vec xx) {
90   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
91   IS                 iscol = a->col, isrow = a->row;
92   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag;
93   const PetscInt    *r, *c, *rout, *cout;
94   PetscInt           nz, idx, idt, j, i, oidx, ii, ic, ir;
95   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
96   const MatScalar   *aa = a->a, *v;
97   PetscScalar        s1, s2, x1, x2, *x, *t;
98   const PetscScalar *b;
99 
100   PetscFunctionBegin;
101   PetscCall(VecGetArrayRead(bb, &b));
102   PetscCall(VecGetArray(xx, &x));
103   t = a->solve_work;
104 
105   PetscCall(ISGetIndices(isrow, &rout));
106   r = rout;
107   PetscCall(ISGetIndices(iscol, &cout));
108   c = cout;
109 
110   /* copy b into temp work space according to permutation */
111   for (i = 0; i < n; i++) {
112     ii        = bs * i;
113     ic        = bs * c[i];
114     t[ii]     = b[ic];
115     t[ii + 1] = b[ic + 1];
116   }
117 
118   /* forward solve the U^T */
119   idx = 0;
120   for (i = 0; i < n; i++) {
121     v  = aa + bs2 * diag[i];
122     /* multiply by the inverse of the block diagonal */
123     x1 = t[idx];
124     x2 = t[1 + idx];
125     s1 = v[0] * x1 + v[1] * x2;
126     s2 = v[2] * x1 + v[3] * x2;
127     v -= bs2;
128 
129     vi = aj + diag[i] - 1;
130     nz = diag[i] - diag[i + 1] - 1;
131     for (j = 0; j > -nz; j--) {
132       oidx = bs * vi[j];
133       t[oidx] -= v[0] * s1 + v[1] * s2;
134       t[oidx + 1] -= v[2] * s1 + v[3] * s2;
135       v -= bs2;
136     }
137     t[idx]     = s1;
138     t[1 + idx] = s2;
139     idx += bs;
140   }
141   /* backward solve the L^T */
142   for (i = n - 1; i >= 0; i--) {
143     v   = aa + bs2 * ai[i];
144     vi  = aj + ai[i];
145     nz  = ai[i + 1] - ai[i];
146     idt = bs * i;
147     s1  = t[idt];
148     s2  = t[1 + idt];
149     for (j = 0; j < nz; j++) {
150       idx = bs * vi[j];
151       t[idx] -= v[0] * s1 + v[1] * s2;
152       t[idx + 1] -= v[2] * s1 + v[3] * s2;
153       v += bs2;
154     }
155   }
156 
157   /* copy t into x according to permutation */
158   for (i = 0; i < n; i++) {
159     ii        = bs * i;
160     ir        = bs * r[i];
161     x[ir]     = t[ii];
162     x[ir + 1] = t[ii + 1];
163   }
164 
165   PetscCall(ISRestoreIndices(isrow, &rout));
166   PetscCall(ISRestoreIndices(iscol, &cout));
167   PetscCall(VecRestoreArrayRead(bb, &b));
168   PetscCall(VecRestoreArray(xx, &x));
169   PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n));
170   PetscFunctionReturn(0);
171 }
172