xref: /petsc/src/mat/impls/baij/seq/ftn-kernels/fsolvebaij.F90 (revision 605a06ccce2b6060581a2b5350eea15cf005ca7e) !
1!
2!
!    Fortran kernels for the sparse triangular solves of a factored matrix
!    stored in the BAIJ matrix format with 4x4 blocks.
!    These ONLY work for factorizations in the NATURAL ORDERING, i.e.
!    with MatSolve_SeqBAIJ_4_NaturalOrdering()
6!
7#include <petsc/finclude/petscsys.h>
8!
9
pure subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
  !
  ! Forward and backward triangular solve with 4x4 blocks, with every
  ! block-times-vector product written out explicitly (unrolled).
  !
  ! All arrays are indexed from 0.
  !   n     - number of block rows
  !   x     - on return the solution; block row i occupies x(4*i:4*i+3)
  !   ai    - blocks ai(i) .. ai(i+1)-1 belong to block row i
  !   aj    - block column index of every stored block
  !   adiag - adiag(i) is the position of the diagonal block of row i;
  !           blocks before it are used by the forward solve, blocks
  !           after it by the backward solve
  !   a     - block values, 16 scalars per block, stored column by column
  !   b     - right-hand side, laid out like x
  !
  ! NOTE(review): the diagonal block is multiplied into the result in the
  ! backward solve, so it is presumably stored already inverted by the
  ! factorization - confirm against the caller.
  !
  use, intrinsic :: ISO_C_binding
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i,j,jstart,jend
  PetscInt :: idx,ax,jdx
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! With no block rows there is nothing to solve; returning here avoids
  ! touching x(0:3) and b(0:3) of empty vectors.
  if (n <= 0) return

  !
  ! Forward Solve
  !
  ! Block row 0 is copied unchanged: the forward solve applies no diagonal
  ! block and row 0 is given no off-diagonal blocks here.
  x(0:3) = b(0:3)
  idx  = 0
  do i=1,n-1
    jstart = ai(i)
    jend   = adiag(i) - 1
    ax     = 16*jstart
    idx    = idx + 4
    s(0:3) = b(idx+0:idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)

      ! s = s - A_block * x(block column), A_block stored column by column
      s(0) = s(0)-(a(ax+0)*x(jdx+0)+a(ax+4)*x(jdx+1)+a(ax+ 8)*x(jdx+2)+a(ax+12)*x(jdx+3))
      s(1) = s(1)-(a(ax+1)*x(jdx+0)+a(ax+5)*x(jdx+1)+a(ax+ 9)*x(jdx+2)+a(ax+13)*x(jdx+3))
      s(2) = s(2)-(a(ax+2)*x(jdx+0)+a(ax+6)*x(jdx+1)+a(ax+10)*x(jdx+2)+a(ax+14)*x(jdx+3))
      s(3) = s(3)-(a(ax+3)*x(jdx+0)+a(ax+7)*x(jdx+1)+a(ax+11)*x(jdx+2)+a(ax+15)*x(jdx+3))
      ax = ax + 16
    end do
    x(idx+0:idx+3) = s(0:3)
  end do

  !
  ! Backward solve the upper triangular
  !
  ! idx enters this loop as 4*(n-1), i.e. pointing at the last block row,
  ! and is stepped down by 4 together with i.
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend   = ai(i+1) - 1
    ax     = 16*jstart
    s(0:3) = x(idx+0:idx+3)
    do j=jstart,jend
      jdx   = 4*aj(j)
      s(0) = s(0)-(a(ax+0)*x(jdx+0)+a(ax+4)*x(jdx+1)+a(ax+ 8)*x(jdx+2)+a(ax+12)*x(jdx+3))
      s(1) = s(1)-(a(ax+1)*x(jdx+0)+a(ax+5)*x(jdx+1)+a(ax+ 9)*x(jdx+2)+a(ax+13)*x(jdx+3))
      s(2) = s(2)-(a(ax+2)*x(jdx+0)+a(ax+6)*x(jdx+1)+a(ax+10)*x(jdx+2)+a(ax+14)*x(jdx+3))
      s(3) = s(3)-(a(ax+3)*x(jdx+0)+a(ax+7)*x(jdx+1)+a(ax+11)*x(jdx+2)+a(ax+15)*x(jdx+3))
      ax = ax + 16
    end do
    ! x(block row i) = D_block * s, with D_block the stored diagonal block
    ax      = 16*adiag(i)
    x(idx+0) = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+ 8)*s(2)+a(ax+12)*s(3)
    x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+ 9)*s(2)+a(ax+13)*s(3)
    x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx      = idx - 4
  end do
end subroutine FortranSolveBAIJ4Unroll
77
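!   Version that packs the needed entries of x into the work array w and
!   applies each block row with explicit loops; it does not call a BLAS 2
!   operation for each row block
!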
pure subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
  !
  ! Forward and backward triangular solve with 4x4 blocks. For every block
  ! row the entries of x that the row needs are first packed contiguously
  ! into w, then the whole row is applied as one (4 x 4*nblocks) matrix
  ! times vector update.
  !
  ! All arrays are indexed from 0.
  !   n     - number of block rows
  !   x     - on return the solution; block row i occupies x(4*i:4*i+3)
  !   ai    - blocks ai(i) .. ai(i+1)-1 belong to block row i
  !   aj    - block column index of every stored block
  !   adiag - adiag(i) is the position of the diagonal block of row i
  !   a     - block values, 16 scalars per block, stored column by column
  !   b     - right-hand side, laid out like x
  !   w     - scratch space, overwritten; at most 4*maxblocks entries are
  !           written
  !
  ! Stops with an error when a row has more than maxblocks blocks on one
  ! side of the diagonal.
  !
  ! NOTE(review): the diagonal block is multiplied into the result in the
  ! backward solve, so it is presumably stored already inverted by the
  ! factorization - confirm against the caller.
  !
  use, intrinsic :: ISO_C_binding
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*),w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  ! Largest number of blocks that may be packed into w for one row.
  ! NOTE(review): presumably matches the size of the caller's work array
  ! (4*maxblocks scalars) - confirm against the caller.
  PetscInt, parameter :: maxblocks = 500

  PetscInt :: ii,jj,i,j
  PetscInt :: jstart,jend,idx,ax,jdx,kdx,nn
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,w(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! With no block rows there is nothing to solve; returning here avoids
  ! touching x(0:3) and b(0:3) of empty vectors.
  if (n <= 0) return

  !
  !     Forward Solve
  !
  x(0:3) = b(0:3)
  idx  = 0
  do i=1,n-1
    !
    ! Pack required part of vector into work array
    !
    kdx    = 0
    jstart = ai(i)
    jend   = adiag(i) - 1

    if (jend - jstart >= maxblocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    do j=jstart,jend
      jdx       = 4*aj(j)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx       = kdx + 4
    end do

    ax     = 16*jstart
    idx    = idx + 4
    s(0:3) = b(idx:idx+3)
    !
    !    s = s - a(ax:)*w
    !
    ! a(ax:) is read as a 4 x (nn+1) column-major matrix. The column index
    ! jj is the outer loop so a is traversed contiguously; each s(ii)
    ! still accumulates its terms in increasing jj order.
    !
    nn = 4*(jend - jstart + 1) - 1
    do jj=0,nn
      do ii=0,3
        s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
      end do
    end do

    x(idx:idx+3) = s(0:3)
  end do
  !
  ! Backward solve the upper triangular
  !
  ! idx enters this loop as 4*(n-1), i.e. pointing at the last block row,
  ! and is stepped down by 4 together with i.
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend   = ai(i+1) - 1
    ax     = 16*jstart
    s(0:3) = x(idx:idx+3)
    !
    !   Pack each chunk of vector needed
    !
    kdx = 0
    if (jend - jstart >= maxblocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    do j=jstart,jend
      jdx = 4*aj(j)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx = kdx + 4
    end do
    !
    !   s = s - a(ax:)*w, traversing a contiguously as in the forward solve
    !
    nn = 4*(jend - jstart + 1) - 1
    do jj=0,nn
      do ii=0,3
        s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
      end do
    end do

    ! x(block row i) = D_block * s, with D_block the stored diagonal block
    ax      = 16*adiag(i)
    x(idx)  = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+ 8)*s(2)+a(ax+12)*s(3)
    x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+ 9)*s(2)+a(ax+13)*s(3)
    x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx     = idx - 4
  end do
end subroutine FortranSolveBAIJ4
170