xref: /petsc/src/mat/impls/baij/seq/baijsolvtran2.c (revision 31d78bcd2b98084dc1368b20eb1129c8b9fb39fe)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
MatSolveTranspose_SeqBAIJ_2_inplace(Mat A,Vec bb,Vec xx)4 PetscErrorCode MatSolveTranspose_SeqBAIJ_2_inplace(Mat A, Vec bb, Vec xx)
5 {
6   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
7   IS                 iscol = a->col, isrow = a->row;
8   const PetscInt    *r, *c, *rout, *cout;
9   const PetscInt    *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j;
10   PetscInt           i, nz, idx, idt, ii, ic, ir, oidx;
11   const MatScalar   *aa = a->a, *v;
12   PetscScalar        s1, s2, x1, x2, *x, *t;
13   const PetscScalar *b;
14 
15   PetscFunctionBegin;
16   PetscCall(VecGetArrayRead(bb, &b));
17   PetscCall(VecGetArray(xx, &x));
18   t = a->solve_work;
19 
20   PetscCall(ISGetIndices(isrow, &rout));
21   r = rout;
22   PetscCall(ISGetIndices(iscol, &cout));
23   c = cout;
24 
25   /* copy the b into temp work space according to permutation */
26   ii = 0;
27   for (i = 0; i < n; i++) {
28     ic        = 2 * c[i];
29     t[ii]     = b[ic];
30     t[ii + 1] = b[ic + 1];
31     ii += 2;
32   }
33 
34   /* forward solve the U^T */
35   idx = 0;
36   for (i = 0; i < n; i++) {
37     v = aa + 4 * diag[i];
38     /* multiply by the inverse of the block diagonal */
39     x1 = t[idx];
40     x2 = t[1 + idx];
41     s1 = v[0] * x1 + v[1] * x2;
42     s2 = v[2] * x1 + v[3] * x2;
43     v += 4;
44 
45     vi = aj + diag[i] + 1;
46     nz = ai[i + 1] - diag[i] - 1;
47     while (nz--) {
48       oidx = 2 * (*vi++);
49       t[oidx] -= v[0] * s1 + v[1] * s2;
50       t[oidx + 1] -= v[2] * s1 + v[3] * s2;
51       v += 4;
52     }
53     t[idx]     = s1;
54     t[1 + idx] = s2;
55     idx += 2;
56   }
57   /* backward solve the L^T */
58   for (i = n - 1; i >= 0; i--) {
59     v   = aa + 4 * diag[i] - 4;
60     vi  = aj + diag[i] - 1;
61     nz  = diag[i] - ai[i];
62     idt = 2 * i;
63     s1  = t[idt];
64     s2  = t[1 + idt];
65     while (nz--) {
66       idx = 2 * (*vi--);
67       t[idx] -= v[0] * s1 + v[1] * s2;
68       t[idx + 1] -= v[2] * s1 + v[3] * s2;
69       v -= 4;
70     }
71   }
72 
73   /* copy t into x according to permutation */
74   ii = 0;
75   for (i = 0; i < n; i++) {
76     ir        = 2 * r[i];
77     x[ir]     = t[ii];
78     x[ir + 1] = t[ii + 1];
79     ii += 2;
80   }
81 
82   PetscCall(ISRestoreIndices(isrow, &rout));
83   PetscCall(ISRestoreIndices(iscol, &cout));
84   PetscCall(VecRestoreArrayRead(bb, &b));
85   PetscCall(VecRestoreArray(xx, &x));
86   PetscCall(PetscLogFlops(2.0 * 4 * (a->nz) - 2.0 * A->cmap->n));
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
MatSolveTranspose_SeqBAIJ_2(Mat A,Vec bb,Vec xx)90 PetscErrorCode MatSolveTranspose_SeqBAIJ_2(Mat A, Vec bb, Vec xx)
91 {
92   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
93   IS                 iscol = a->col, isrow = a->row;
94   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag;
95   const PetscInt    *r, *c, *rout, *cout;
96   PetscInt           nz, idx, idt, j, i, oidx, ii, ic, ir;
97   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
98   const MatScalar   *aa = a->a, *v;
99   PetscScalar        s1, s2, x1, x2, *x, *t;
100   const PetscScalar *b;
101 
102   PetscFunctionBegin;
103   PetscCall(VecGetArrayRead(bb, &b));
104   PetscCall(VecGetArray(xx, &x));
105   t = a->solve_work;
106 
107   PetscCall(ISGetIndices(isrow, &rout));
108   r = rout;
109   PetscCall(ISGetIndices(iscol, &cout));
110   c = cout;
111 
112   /* copy b into temp work space according to permutation */
113   for (i = 0; i < n; i++) {
114     ii        = bs * i;
115     ic        = bs * c[i];
116     t[ii]     = b[ic];
117     t[ii + 1] = b[ic + 1];
118   }
119 
120   /* forward solve the U^T */
121   idx = 0;
122   for (i = 0; i < n; i++) {
123     v = aa + bs2 * diag[i];
124     /* multiply by the inverse of the block diagonal */
125     x1 = t[idx];
126     x2 = t[1 + idx];
127     s1 = v[0] * x1 + v[1] * x2;
128     s2 = v[2] * x1 + v[3] * x2;
129     v -= bs2;
130 
131     vi = aj + diag[i] - 1;
132     nz = diag[i] - diag[i + 1] - 1;
133     for (j = 0; j > -nz; j--) {
134       oidx = bs * vi[j];
135       t[oidx] -= v[0] * s1 + v[1] * s2;
136       t[oidx + 1] -= v[2] * s1 + v[3] * s2;
137       v -= bs2;
138     }
139     t[idx]     = s1;
140     t[1 + idx] = s2;
141     idx += bs;
142   }
143   /* backward solve the L^T */
144   for (i = n - 1; i >= 0; i--) {
145     v   = aa + bs2 * ai[i];
146     vi  = aj + ai[i];
147     nz  = ai[i + 1] - ai[i];
148     idt = bs * i;
149     s1  = t[idt];
150     s2  = t[1 + idt];
151     for (j = 0; j < nz; j++) {
152       idx = bs * vi[j];
153       t[idx] -= v[0] * s1 + v[1] * s2;
154       t[idx + 1] -= v[2] * s1 + v[3] * s2;
155       v += bs2;
156     }
157   }
158 
159   /* copy t into x according to permutation */
160   for (i = 0; i < n; i++) {
161     ii        = bs * i;
162     ir        = bs * r[i];
163     x[ir]     = t[ii];
164     x[ir + 1] = t[ii + 1];
165   }
166 
167   PetscCall(ISRestoreIndices(isrow, &rout));
168   PetscCall(ISRestoreIndices(iscol, &cout));
169   PetscCall(VecRestoreArrayRead(bb, &b));
170   PetscCall(VecRestoreArray(xx, &x));
171   PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n));
172   PetscFunctionReturn(PETSC_SUCCESS);
173 }
174