xref: /petsc/src/mat/impls/baij/seq/baijsolvtran3.c (revision 58d68138c660dfb4e9f5b03334792cd4f2ffd7cc)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
4 PetscErrorCode MatSolveTranspose_SeqBAIJ_3_inplace(Mat A, Vec bb, Vec xx) {
5   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
6   IS                 iscol = a->col, isrow = a->row;
7   const PetscInt    *r, *c, *rout, *cout;
8   const PetscInt    *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j;
9   PetscInt           i, nz, idx, idt, ii, ic, ir, oidx;
10   const MatScalar   *aa = a->a, *v;
11   PetscScalar        s1, s2, s3, x1, x2, x3, *x, *t;
12   const PetscScalar *b;
13 
14   PetscFunctionBegin;
15   PetscCall(VecGetArrayRead(bb, &b));
16   PetscCall(VecGetArray(xx, &x));
17   t = a->solve_work;
18 
19   PetscCall(ISGetIndices(isrow, &rout));
20   r = rout;
21   PetscCall(ISGetIndices(iscol, &cout));
22   c = cout;
23 
24   /* copy the b into temp work space according to permutation */
25   ii = 0;
26   for (i = 0; i < n; i++) {
27     ic        = 3 * c[i];
28     t[ii]     = b[ic];
29     t[ii + 1] = b[ic + 1];
30     t[ii + 2] = b[ic + 2];
31     ii += 3;
32   }
33 
34   /* forward solve the U^T */
35   idx = 0;
36   for (i = 0; i < n; i++) {
37     v  = aa + 9 * diag[i];
38     /* multiply by the inverse of the block diagonal */
39     x1 = t[idx];
40     x2 = t[1 + idx];
41     x3 = t[2 + idx];
42     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3;
43     s2 = v[3] * x1 + v[4] * x2 + v[5] * x3;
44     s3 = v[6] * x1 + v[7] * x2 + v[8] * x3;
45     v += 9;
46 
47     vi = aj + diag[i] + 1;
48     nz = ai[i + 1] - diag[i] - 1;
49     while (nz--) {
50       oidx = 3 * (*vi++);
51       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3;
52       t[oidx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3;
53       t[oidx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3;
54       v += 9;
55     }
56     t[idx]     = s1;
57     t[1 + idx] = s2;
58     t[2 + idx] = s3;
59     idx += 3;
60   }
61   /* backward solve the L^T */
62   for (i = n - 1; i >= 0; i--) {
63     v   = aa + 9 * diag[i] - 9;
64     vi  = aj + diag[i] - 1;
65     nz  = diag[i] - ai[i];
66     idt = 3 * i;
67     s1  = t[idt];
68     s2  = t[1 + idt];
69     s3  = t[2 + idt];
70     while (nz--) {
71       idx = 3 * (*vi--);
72       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3;
73       t[idx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3;
74       t[idx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3;
75       v -= 9;
76     }
77   }
78 
79   /* copy t into x according to permutation */
80   ii = 0;
81   for (i = 0; i < n; i++) {
82     ir        = 3 * r[i];
83     x[ir]     = t[ii];
84     x[ir + 1] = t[ii + 1];
85     x[ir + 2] = t[ii + 2];
86     ii += 3;
87   }
88 
89   PetscCall(ISRestoreIndices(isrow, &rout));
90   PetscCall(ISRestoreIndices(iscol, &cout));
91   PetscCall(VecRestoreArrayRead(bb, &b));
92   PetscCall(VecRestoreArray(xx, &x));
93   PetscCall(PetscLogFlops(2.0 * 9 * (a->nz) - 3.0 * A->cmap->n));
94   PetscFunctionReturn(0);
95 }
96 
97 PetscErrorCode MatSolveTranspose_SeqBAIJ_3(Mat A, Vec bb, Vec xx) {
98   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
99   IS                 iscol = a->col, isrow = a->row;
100   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag;
101   const PetscInt    *r, *c, *rout, *cout;
102   PetscInt           nz, idx, idt, j, i, oidx, ii, ic, ir;
103   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
104   const MatScalar   *aa = a->a, *v;
105   PetscScalar        s1, s2, s3, x1, x2, x3, *x, *t;
106   const PetscScalar *b;
107 
108   PetscFunctionBegin;
109   PetscCall(VecGetArrayRead(bb, &b));
110   PetscCall(VecGetArray(xx, &x));
111   t = a->solve_work;
112 
113   PetscCall(ISGetIndices(isrow, &rout));
114   r = rout;
115   PetscCall(ISGetIndices(iscol, &cout));
116   c = cout;
117 
118   /* copy b into temp work space according to permutation */
119   for (i = 0; i < n; i++) {
120     ii        = bs * i;
121     ic        = bs * c[i];
122     t[ii]     = b[ic];
123     t[ii + 1] = b[ic + 1];
124     t[ii + 2] = b[ic + 2];
125   }
126 
127   /* forward solve the U^T */
128   idx = 0;
129   for (i = 0; i < n; i++) {
130     v  = aa + bs2 * diag[i];
131     /* multiply by the inverse of the block diagonal */
132     x1 = t[idx];
133     x2 = t[1 + idx];
134     x3 = t[2 + idx];
135     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3;
136     s2 = v[3] * x1 + v[4] * x2 + v[5] * x3;
137     s3 = v[6] * x1 + v[7] * x2 + v[8] * x3;
138     v -= bs2;
139 
140     vi = aj + diag[i] - 1;
141     nz = diag[i] - diag[i + 1] - 1;
142     for (j = 0; j > -nz; j--) {
143       oidx = bs * vi[j];
144       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3;
145       t[oidx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3;
146       t[oidx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3;
147       v -= bs2;
148     }
149     t[idx]     = s1;
150     t[1 + idx] = s2;
151     t[2 + idx] = s3;
152     idx += bs;
153   }
154   /* backward solve the L^T */
155   for (i = n - 1; i >= 0; i--) {
156     v   = aa + bs2 * ai[i];
157     vi  = aj + ai[i];
158     nz  = ai[i + 1] - ai[i];
159     idt = bs * i;
160     s1  = t[idt];
161     s2  = t[1 + idt];
162     s3  = t[2 + idt];
163     for (j = 0; j < nz; j++) {
164       idx = bs * vi[j];
165       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3;
166       t[idx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3;
167       t[idx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3;
168       v += bs2;
169     }
170   }
171 
172   /* copy t into x according to permutation */
173   for (i = 0; i < n; i++) {
174     ii        = bs * i;
175     ir        = bs * r[i];
176     x[ir]     = t[ii];
177     x[ir + 1] = t[ii + 1];
178     x[ir + 2] = t[ii + 2];
179   }
180 
181   PetscCall(ISRestoreIndices(isrow, &rout));
182   PetscCall(ISRestoreIndices(iscol, &cout));
183   PetscCall(VecRestoreArrayRead(bb, &b));
184   PetscCall(VecRestoreArray(xx, &x));
185   PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n));
186   PetscFunctionReturn(0);
187 }
188