xref: /petsc/src/mat/impls/baij/seq/ftn-kernels/fsolvebaij.F90 (revision 7e1a0bbe36d2be40a00a95404ece00db4857f70d)
1!
2!
3!    Fortran kernel for sparse triangular solve in the BAIJ matrix format
4! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
5! with MatSolve_SeqBAIJ_4_NaturalOrdering()
6!
7#include <petsc/finclude/petscsys.h>
8!
9
!
! FortranSolveBAIJ4Unroll - forward then backward triangular solve for a
! factored matrix stored in block-CSR form with 4x4 blocks.
! Every 4x4 block-times-vector product is written out by hand (unrolled).
!
! Arguments (all arrays are 0-based):
!   n     - number of block rows
!   x     - on exit, the solution (also holds the intermediate vector)
!   ai    - block-row pointers: block row i owns stored blocks ai(i)..ai(i+1)-1
!   aj    - block-column index of each stored block
!   adiag - position of the diagonal block of each block row
!   a     - block values, 16 per stored block; entry (r, c) of the block that
!           starts at offset ax is a(ax + 4*c + r)
!   b     - right-hand side
!
! The forward sweep uses only the blocks left of the diagonal and applies no
! diagonal, i.e. the lower factor is treated as unit lower triangular.
! The backward sweep multiplies by the stored diagonal block. That block
! presumably holds the inverted diagonal from the factorization (confirm
! against the factorization routine).
!
pure subroutine FortranSolveBAIJ4Unroll(n, x, ai, aj, adiag, a, b)
  use, intrinsic :: ISO_C_binding
  implicit none(type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i, j, jstart, jend
  PetscInt :: idx, ax, jdx
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16, a(1))
  PETSC_AssertAlignx(16, x(1))
  PETSC_AssertAlignx(16, b(1))
  PETSC_AssertAlignx(16, ai(1))
  PETSC_AssertAlignx(16, aj(1))
  PETSC_AssertAlignx(16, adiag(1))

  ! No block rows: nothing to solve. Return early, because x and b may not
  ! even hold the four entries touched below.
  if (n < 1) return

  !
  ! Forward solve. Block row 0 has no blocks left of its diagonal, so x_0 = b_0.
  !
  x(0:3) = b(0:3)
  do i = 1, n - 1
    idx = 4*i
    jstart = ai(i)
    jend = adiag(i) - 1        ! last block strictly left of the diagonal
    ax = 16*jstart
    s(0:3) = b(idx + 0:idx + 3)
    do j = jstart, jend
      jdx = 4*aj(j)
      s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
      s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
      s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
      s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
      ax = ax + 16
    end do
    x(idx + 0:idx + 3) = s(0:3)
  end do

  !
  ! Backward solve with the upper triangular part, last block row first.
  ! Every block column to the right of the diagonal is already final in x.
  !
  do i = n - 1, 0, -1
    idx = 4*i
    jstart = adiag(i) + 1      ! first block strictly right of the diagonal
    jend = ai(i + 1) - 1
    ax = 16*jstart
    s(0:3) = x(idx + 0:idx + 3)
    do j = jstart, jend
      jdx = 4*aj(j)
      s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
      s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
      s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
      s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
      ax = ax + 16
    end do
    ! Apply the stored diagonal block by multiplication.
    ax = 16*adiag(i)
    x(idx + 0) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
    x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
    x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
    x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
  end do
end subroutine FortranSolveBAIJ4Unroll
77
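!   Version that first gathers the entries of x needed by each block row into
!   the work array w, then runs one accumulation loop over the packed values
!   instead of calling a BLAS 2 operation for each row block.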
79!
!
! FortranSolveBAIJ4 - forward then backward triangular solve for a factored
! matrix stored in block-CSR form with 4x4 blocks.
! For each block row, the entries of x belonging to that row's off-diagonal
! block columns are first packed into the work array w. They are then
! consumed by a single loop over the packed values.
!
! Arguments (all arrays are 0-based):
!   n     - number of block rows
!   x     - on exit, the solution (also holds the intermediate vector)
!   ai    - block-row pointers: block row i owns stored blocks ai(i)..ai(i+1)-1
!   aj    - block-column index of each stored block
!   adiag - position of the diagonal block of each block row
!   a     - block values, 16 per stored block; entry (r, c) of the block that
!           starts at offset ax is a(ax + 4*c + r)
!   b     - right-hand side
!   w     - work array, overwritten; receives 4 values per packed block
!
! The forward sweep applies no diagonal (unit lower triangular). The backward
! sweep multiplies by the stored diagonal block, which presumably holds the
! inverted diagonal from the factorization (confirm).
!
pure subroutine FortranSolveBAIJ4(n, x, ai, aj, adiag, a, b, w)
  use, intrinsic :: ISO_C_binding
  implicit none(type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*), w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  ! Largest number of off-diagonal blocks that may be packed into w for one
  ! block row. Each block packs 4 values, so w must hold 4*500 = 2000
  ! entries. NOTE(review): presumably this matches the caller's work-array
  ! size; confirm against the C caller before changing it.
  PetscInt, parameter :: max_packed_blocks = 500

  PetscInt :: i, j, jj
  PetscInt :: jstart, jend, idx, ax, jdx, kdx
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16, a(1))
  PETSC_AssertAlignx(16, w(1))
  PETSC_AssertAlignx(16, x(1))
  PETSC_AssertAlignx(16, b(1))
  PETSC_AssertAlignx(16, ai(1))
  PETSC_AssertAlignx(16, aj(1))
  PETSC_AssertAlignx(16, adiag(1))

  ! No block rows: nothing to solve. Return early, because x and b may not
  ! even hold the four entries touched below.
  if (n < 1) return

  !
  ! Forward solve. Block row 0 has no blocks left of its diagonal, so x_0 = b_0.
  !
  x(0:3) = b(0:3)
  do i = 1, n - 1
    idx = 4*i
    jstart = ai(i)
    jend = adiag(i) - 1        ! last block strictly left of the diagonal

    if (jend - jstart + 1 > max_packed_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    ! Pack x for each off-diagonal block column into consecutive slots of w.
    kdx = 0
    do j = jstart, jend
      jdx = 4*aj(j)
      w(kdx:kdx + 3) = x(jdx:jdx + 3)
      kdx = kdx + 4
    end do

    ! s = b_i - (4 x kdx panel of a) * w(0:kdx-1).
    ! Column jj of the panel is a(ax + 4*jj : ax + 4*jj + 3), so stepping jj
    ! walks a contiguously. Each s(r) still accumulates in increasing jj order.
    ax = 16*jstart
    s(0:3) = b(idx:idx + 3)
    do jj = 0, kdx - 1
      s(0:3) = s(0:3) - a(ax + 4*jj:ax + 4*jj + 3)*w(jj)
    end do

    x(idx:idx + 3) = s(0:3)
  end do

  !
  ! Backward solve with the upper triangular part, last block row first.
  ! Every block column to the right of the diagonal is already final in x.
  !
  do i = n - 1, 0, -1
    idx = 4*i
    jstart = adiag(i) + 1      ! first block strictly right of the diagonal
    jend = ai(i + 1) - 1

    if (jend - jstart + 1 > max_packed_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    ! Pack x for each off-diagonal block column into consecutive slots of w.
    kdx = 0
    do j = jstart, jend
      jdx = 4*aj(j)
      w(kdx:kdx + 3) = x(jdx:jdx + 3)
      kdx = kdx + 4
    end do

    ! s = x_i - (4 x kdx panel of a) * w(0:kdx-1), walking a contiguously.
    ax = 16*jstart
    s(0:3) = x(idx:idx + 3)
    do jj = 0, kdx - 1
      s(0:3) = s(0:3) - a(ax + 4*jj:ax + 4*jj + 3)*w(jj)
    end do

    ! Apply the stored diagonal block by multiplication.
    ax = 16*adiag(i)
    x(idx + 0) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
    x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
    x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
    x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
  end do
end subroutine FortranSolveBAIJ4
170