xref: /petsc/src/mat/impls/baij/seq/baijsolvnat4.c (revision 174dc0c8cee294b82b85e4dd3b331b29396264fc)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
4 /*
5       Special case where the matrix was ILU(0) factored in the natural
6    ordering. This eliminates the need for the column and row permutation.
7 */
8 PetscErrorCode MatSolve_SeqBAIJ_4_NaturalOrdering_inplace(Mat A, Vec bb, Vec xx)
9 {
10   Mat_SeqBAIJ       *a  = (Mat_SeqBAIJ *)A->data;
11   PetscInt           n  = a->mbs;
12   const PetscInt    *ai = a->i, *aj = a->j;
13   const PetscInt    *diag = a->diag;
14   const MatScalar   *aa   = a->a;
15   PetscScalar       *x;
16   const PetscScalar *b;
17 
18   PetscFunctionBegin;
19   PetscCall(VecGetArrayRead(bb, &b));
20   PetscCall(VecGetArray(xx, &x));
21 
22 #if defined(PETSC_USE_FORTRAN_KERNEL_SOLVEBAIJ)
23   {
24     static PetscScalar w[2000]; /* very BAD need to fix */
25     fortransolvebaij4_(&n, x, ai, aj, diag, aa, b, w);
26   }
27 #elif defined(PETSC_USE_FORTRAN_KERNEL_SOLVEBAIJUNROLL)
28   fortransolvebaij4unroll_(&n, x, ai, aj, diag, aa, b);
29 #else
30   {
31     PetscScalar      s1, s2, s3, s4, x1, x2, x3, x4;
32     const MatScalar *v;
33     PetscInt         jdx, idt, idx, nz, i, ai16;
34     const PetscInt  *vi;
35 
36     /* forward solve the lower triangular */
37     idx  = 0;
38     x[0] = b[0];
39     x[1] = b[1];
40     x[2] = b[2];
41     x[3] = b[3];
42     for (i = 1; i < n; i++) {
43       v  = aa + 16 * ai[i];
44       vi = aj + ai[i];
45       nz = diag[i] - ai[i];
46       idx += 4;
47       s1 = b[idx];
48       s2 = b[1 + idx];
49       s3 = b[2 + idx];
50       s4 = b[3 + idx];
51       while (nz--) {
52         jdx = 4 * (*vi++);
53         x1  = x[jdx];
54         x2  = x[1 + jdx];
55         x3  = x[2 + jdx];
56         x4  = x[3 + jdx];
57         s1 -= v[0] * x1 + v[4] * x2 + v[8] * x3 + v[12] * x4;
58         s2 -= v[1] * x1 + v[5] * x2 + v[9] * x3 + v[13] * x4;
59         s3 -= v[2] * x1 + v[6] * x2 + v[10] * x3 + v[14] * x4;
60         s4 -= v[3] * x1 + v[7] * x2 + v[11] * x3 + v[15] * x4;
61         v += 16;
62       }
63       x[idx]     = s1;
64       x[1 + idx] = s2;
65       x[2 + idx] = s3;
66       x[3 + idx] = s4;
67     }
68     /* backward solve the upper triangular */
69     idt = 4 * (n - 1);
70     for (i = n - 1; i >= 0; i--) {
71       ai16 = 16 * diag[i];
72       v    = aa + ai16 + 16;
73       vi   = aj + diag[i] + 1;
74       nz   = ai[i + 1] - diag[i] - 1;
75       s1   = x[idt];
76       s2   = x[1 + idt];
77       s3   = x[2 + idt];
78       s4   = x[3 + idt];
79       while (nz--) {
80         idx = 4 * (*vi++);
81         x1  = x[idx];
82         x2  = x[1 + idx];
83         x3  = x[2 + idx];
84         x4  = x[3 + idx];
85         s1 -= v[0] * x1 + v[4] * x2 + v[8] * x3 + v[12] * x4;
86         s2 -= v[1] * x1 + v[5] * x2 + v[9] * x3 + v[13] * x4;
87         s3 -= v[2] * x1 + v[6] * x2 + v[10] * x3 + v[14] * x4;
88         s4 -= v[3] * x1 + v[7] * x2 + v[11] * x3 + v[15] * x4;
89         v += 16;
90       }
91       v          = aa + ai16;
92       x[idt]     = v[0] * s1 + v[4] * s2 + v[8] * s3 + v[12] * s4;
93       x[1 + idt] = v[1] * s1 + v[5] * s2 + v[9] * s3 + v[13] * s4;
94       x[2 + idt] = v[2] * s1 + v[6] * s2 + v[10] * s3 + v[14] * s4;
95       x[3 + idt] = v[3] * s1 + v[7] * s2 + v[11] * s3 + v[15] * s4;
96       idt -= 4;
97     }
98   }
99 #endif
100 
101   PetscCall(VecRestoreArrayRead(bb, &b));
102   PetscCall(VecRestoreArray(xx, &x));
103   PetscCall(PetscLogFlops(2.0 * 16 * (a->nz) - 4.0 * A->cmap->n));
104   PetscFunctionReturn(PETSC_SUCCESS);
105 }
106 
107 PetscErrorCode MatSolve_SeqBAIJ_4_NaturalOrdering(Mat A, Vec bb, Vec xx)
108 {
109   Mat_SeqBAIJ       *a = (Mat_SeqBAIJ *)A->data;
110   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *adiag = a->diag;
111   PetscInt           i, k, nz, idx, jdx, idt;
112   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
113   const MatScalar   *aa = a->a, *v;
114   PetscScalar       *x;
115   const PetscScalar *b;
116   PetscScalar        s1, s2, s3, s4, x1, x2, x3, x4;
117 
118   PetscFunctionBegin;
119   PetscCall(VecGetArrayRead(bb, &b));
120   PetscCall(VecGetArray(xx, &x));
121   /* forward solve the lower triangular */
122   idx  = 0;
123   x[0] = b[idx];
124   x[1] = b[1 + idx];
125   x[2] = b[2 + idx];
126   x[3] = b[3 + idx];
127   for (i = 1; i < n; i++) {
128     v   = aa + bs2 * ai[i];
129     vi  = aj + ai[i];
130     nz  = ai[i + 1] - ai[i];
131     idx = bs * i;
132     s1  = b[idx];
133     s2  = b[1 + idx];
134     s3  = b[2 + idx];
135     s4  = b[3 + idx];
136     for (k = 0; k < nz; k++) {
137       jdx = bs * vi[k];
138       x1  = x[jdx];
139       x2  = x[1 + jdx];
140       x3  = x[2 + jdx];
141       x4  = x[3 + jdx];
142       s1 -= v[0] * x1 + v[4] * x2 + v[8] * x3 + v[12] * x4;
143       s2 -= v[1] * x1 + v[5] * x2 + v[9] * x3 + v[13] * x4;
144       s3 -= v[2] * x1 + v[6] * x2 + v[10] * x3 + v[14] * x4;
145       s4 -= v[3] * x1 + v[7] * x2 + v[11] * x3 + v[15] * x4;
146 
147       v += bs2;
148     }
149 
150     x[idx]     = s1;
151     x[1 + idx] = s2;
152     x[2 + idx] = s3;
153     x[3 + idx] = s4;
154   }
155 
156   /* backward solve the upper triangular */
157   for (i = n - 1; i >= 0; i--) {
158     v   = aa + bs2 * (adiag[i + 1] + 1);
159     vi  = aj + adiag[i + 1] + 1;
160     nz  = adiag[i] - adiag[i + 1] - 1;
161     idt = bs * i;
162     s1  = x[idt];
163     s2  = x[1 + idt];
164     s3  = x[2 + idt];
165     s4  = x[3 + idt];
166 
167     for (k = 0; k < nz; k++) {
168       idx = bs * vi[k];
169       x1  = x[idx];
170       x2  = x[1 + idx];
171       x3  = x[2 + idx];
172       x4  = x[3 + idx];
173       s1 -= v[0] * x1 + v[4] * x2 + v[8] * x3 + v[12] * x4;
174       s2 -= v[1] * x1 + v[5] * x2 + v[9] * x3 + v[13] * x4;
175       s3 -= v[2] * x1 + v[6] * x2 + v[10] * x3 + v[14] * x4;
176       s4 -= v[3] * x1 + v[7] * x2 + v[11] * x3 + v[15] * x4;
177 
178       v += bs2;
179     }
180     /* x = inv_diagonal*x */
181     x[idt]     = v[0] * s1 + v[4] * s2 + v[8] * s3 + v[12] * s4;
182     x[1 + idt] = v[1] * s1 + v[5] * s2 + v[9] * s3 + v[13] * s4;
183     x[2 + idt] = v[2] * s1 + v[6] * s2 + v[10] * s3 + v[14] * s4;
184     x[3 + idt] = v[3] * s1 + v[7] * s2 + v[11] * s3 + v[15] * s4;
185   }
186 
187   PetscCall(VecRestoreArrayRead(bb, &b));
188   PetscCall(VecRestoreArray(xx, &x));
189   PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n));
190   PetscFunctionReturn(PETSC_SUCCESS);
191 }
192