1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3
MatSolveTranspose_SeqBAIJ_1(Mat A,Vec bb,Vec xx)4 PetscErrorCode MatSolveTranspose_SeqBAIJ_1(Mat A, Vec bb, Vec xx)
5 {
6 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data;
7 IS iscol = a->col, isrow = a->row;
8 const PetscInt *rout, *cout, *r, *c, *adiag = a->diag, *ai = a->i, *aj = a->j, *vi;
9 PetscInt i, n = a->mbs, j;
10 PetscInt nz;
11 PetscScalar *x, *tmp, s1;
12 const MatScalar *aa = a->a, *v;
13 const PetscScalar *b;
14
15 PetscFunctionBegin;
16 PetscCall(VecGetArrayRead(bb, &b));
17 PetscCall(VecGetArray(xx, &x));
18 tmp = a->solve_work;
19
20 PetscCall(ISGetIndices(isrow, &rout));
21 r = rout;
22 PetscCall(ISGetIndices(iscol, &cout));
23 c = cout;
24
25 /* copy the b into temp work space according to permutation */
26 for (i = 0; i < n; i++) tmp[i] = b[c[i]];
27
28 /* forward solve the U^T */
29 for (i = 0; i < n; i++) {
30 v = aa + adiag[i + 1] + 1;
31 vi = aj + adiag[i + 1] + 1;
32 nz = adiag[i] - adiag[i + 1] - 1;
33 s1 = tmp[i];
34 s1 *= v[nz]; /* multiply by inverse of diagonal entry */
35 for (j = 0; j < nz; j++) tmp[vi[j]] -= s1 * v[j];
36 tmp[i] = s1;
37 }
38
39 /* backward solve the L^T */
40 for (i = n - 1; i >= 0; i--) {
41 v = aa + ai[i];
42 vi = aj + ai[i];
43 nz = ai[i + 1] - ai[i];
44 s1 = tmp[i];
45 for (j = 0; j < nz; j++) tmp[vi[j]] -= s1 * v[j];
46 }
47
48 /* copy tmp into x according to permutation */
49 for (i = 0; i < n; i++) x[r[i]] = tmp[i];
50
51 PetscCall(ISRestoreIndices(isrow, &rout));
52 PetscCall(ISRestoreIndices(iscol, &cout));
53 PetscCall(VecRestoreArrayRead(bb, &b));
54 PetscCall(VecRestoreArray(xx, &x));
55
56 PetscCall(PetscLogFlops(2.0 * a->nz - A->cmap->n));
57 PetscFunctionReturn(PETSC_SUCCESS);
58 }
59
MatSolveTranspose_SeqBAIJ_1_inplace(Mat A,Vec bb,Vec xx)60 PetscErrorCode MatSolveTranspose_SeqBAIJ_1_inplace(Mat A, Vec bb, Vec xx)
61 {
62 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data;
63 IS iscol = a->col, isrow = a->row;
64 const PetscInt *r, *c, *rout, *cout;
65 const PetscInt *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j;
66 PetscInt i, nz;
67 const MatScalar *aa = a->a, *v;
68 PetscScalar s1, *x, *t;
69 const PetscScalar *b;
70
71 PetscFunctionBegin;
72 PetscCall(VecGetArrayRead(bb, &b));
73 PetscCall(VecGetArray(xx, &x));
74 t = a->solve_work;
75
76 PetscCall(ISGetIndices(isrow, &rout));
77 r = rout;
78 PetscCall(ISGetIndices(iscol, &cout));
79 c = cout;
80
81 /* copy the b into temp work space according to permutation */
82 for (i = 0; i < n; i++) t[i] = b[c[i]];
83
84 /* forward solve the U^T */
85 for (i = 0; i < n; i++) {
86 v = aa + diag[i];
87 /* multiply by the inverse of the block diagonal */
88 s1 = (*v++) * t[i];
89 vi = aj + diag[i] + 1;
90 nz = ai[i + 1] - diag[i] - 1;
91 while (nz--) t[*vi++] -= (*v++) * s1;
92 t[i] = s1;
93 }
94 /* backward solve the L^T */
95 for (i = n - 1; i >= 0; i--) {
96 v = aa + diag[i] - 1;
97 vi = aj + diag[i] - 1;
98 nz = diag[i] - ai[i];
99 s1 = t[i];
100 while (nz--) t[*vi--] -= (*v--) * s1;
101 }
102
103 /* copy t into x according to permutation */
104 for (i = 0; i < n; i++) x[r[i]] = t[i];
105
106 PetscCall(ISRestoreIndices(isrow, &rout));
107 PetscCall(ISRestoreIndices(iscol, &cout));
108 PetscCall(VecRestoreArrayRead(bb, &b));
109 PetscCall(VecRestoreArray(xx, &x));
110 PetscCall(PetscLogFlops(2.0 * (a->nz) - A->cmap->n));
111 PetscFunctionReturn(PETSC_SUCCESS);
112 }
113