xref: /petsc/src/mat/impls/baij/seq/ftn-kernels/fsolvebaij.F90 (revision f13dfd9ea68e0ddeee984e65c377a1819eab8a8a)
1!
2!
3!    Fortran kernel for sparse triangular solve in the BAIJ matrix format
4! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
5! with MatSolve_SeqBAIJ_4_NaturalOrdering()
6!
7#include <petsc/finclude/petscsys.h>
8!
9
10      subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
11      implicit none
12      MatScalar   a(0:*)
13      PetscScalar x(0:*)
14      PetscScalar b(0:*)
15      PetscInt    n
16      PetscInt    ai(0:*)
17      PetscInt    aj(0:*)
18      PetscInt    adiag(0:*)
19
20      PetscInt    i,j,jstart,jend
21      PetscInt    idx,ax,jdx
22      PetscScalar s1,s2,s3,s4
23      PetscScalar x1,x2,x3,x4
24!
25!     Forward Solve
26!
27      PETSC_AssertAlignx(16,a(1))
28      PETSC_AssertAlignx(16,x(1))
29      PETSC_AssertAlignx(16,b(1))
30      PETSC_AssertAlignx(16,ai(1))
31      PETSC_AssertAlignx(16,aj(1))
32      PETSC_AssertAlignx(16,adiag(1))
33
34         x(0) = b(0)
35         x(1) = b(1)
36         x(2) = b(2)
37         x(3) = b(3)
38         idx  = 0
39         do 20 i=1,n-1
40            jstart = ai(i)
41            jend   = adiag(i) - 1
42            ax    = 16*jstart
43            idx    = idx + 4
44            s1     = b(idx)
45            s2     = b(idx+1)
46            s3     = b(idx+2)
47            s4     = b(idx+3)
48            do 30 j=jstart,jend
49              jdx   = 4*aj(j)
50
51              x1    = x(jdx)
52              x2    = x(jdx+1)
53              x3    = x(jdx+2)
54              x4    = x(jdx+3)
55              s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
56              s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
57              s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
58              s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
59              ax = ax + 16
60 30         continue
61            x(idx)   = s1
62            x(idx+1) = s2
63            x(idx+2) = s3
64            x(idx+3) = s4
65 20      continue
66
67!
68!     Backward solve the upper triangular
69!
70         do 40 i=n-1,0,-1
71            jstart  = adiag(i) + 1
72            jend    = ai(i+1) - 1
73            ax     = 16*jstart
74            s1      = x(idx)
75            s2      = x(idx+1)
76            s3      = x(idx+2)
77            s4      = x(idx+3)
78            do 50 j=jstart,jend
79              jdx   = 4*aj(j)
80              x1    = x(jdx)
81              x2    = x(jdx+1)
82              x3    = x(jdx+2)
83              x4    = x(jdx+3)
84              s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
85              s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
86              s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
87              s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
88              ax = ax + 16
89 50         continue
90            ax      = 16*adiag(i)
91            x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
92            x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
93            x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
94            x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
95            idx      = idx - 4
96 40      continue
97      end
98
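!   Version that gathers the needed entries of x into the work array w
!   and applies each block row with explicit loops rather than calling
!   a BLAS 2 operation for each row block
!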
101      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
102      implicit none
103      MatScalar   a(0:*)
104      PetscScalar x(0:*),b(0:*),w(0:*)
105      PetscInt  n,ai(0:*),aj(0:*),adiag(0:*)
106      PetscInt  ii,jj,i,j
107
108      PetscInt  jstart,jend,idx,ax,jdx,kdx,nn
109      PetscScalar s(0:3)
110
111!
112!     Forward Solve
113!
114
115      PETSC_AssertAlignx(16,a(1))
116      PETSC_AssertAlignx(16,w(1))
117      PETSC_AssertAlignx(16,x(1))
118      PETSC_AssertAlignx(16,b(1))
119      PETSC_AssertAlignx(16,ai(1))
120      PETSC_AssertAlignx(16,aj(1))
121      PETSC_AssertAlignx(16,adiag(1))
122
123      x(0) = b(0)
124      x(1) = b(1)
125      x(2) = b(2)
126      x(3) = b(3)
127      idx  = 0
128      do 20 i=1,n-1
129!
130!        Pack required part of vector into work array
131!
132         kdx    = 0
133         jstart = ai(i)
134         jend   = adiag(i) - 1
135         if (jend - jstart .ge. 500) then
136           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
137         endif
138         do 30 j=jstart,jend
139
140           jdx       = 4*aj(j)
141
142           w(kdx)    = x(jdx)
143           w(kdx+1)  = x(jdx+1)
144           w(kdx+2)  = x(jdx+2)
145           w(kdx+3)  = x(jdx+3)
146           kdx       = kdx + 4
147 30      continue
148
149         ax       = 16*jstart
150         idx      = idx + 4
151         s(0)     = b(idx)
152         s(1)     = b(idx+1)
153         s(2)     = b(idx+2)
154         s(3)     = b(idx+3)
155!
156!    s = s - a(ax:)*w
157!
158         nn = 4*(jend - jstart + 1) - 1
159         do 100, ii=0,3
160           do 110, jj=0,nn
161             s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
162 110       continue
163 100     continue
164
165         x(idx)   = s(0)
166         x(idx+1) = s(1)
167         x(idx+2) = s(2)
168         x(idx+3) = s(3)
169 20   continue
170
171!
172!     Backward solve the upper triangular
173!
174      do 40 i=n-1,0,-1
175         jstart    = adiag(i) + 1
176         jend      = ai(i+1) - 1
177         ax        = 16*jstart
178         s(0)      = x(idx)
179         s(1)      = x(idx+1)
180         s(2)      = x(idx+2)
181         s(3)      = x(idx+3)
182!
183!   Pack each chunk of vector needed
184!
185         kdx = 0
186         if (jend - jstart .ge. 500) then
187           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
188         endif
189         do 50 j=jstart,jend
190           jdx      = 4*aj(j)
191           w(kdx)   = x(jdx)
192           w(kdx+1) = x(jdx+1)
193           w(kdx+2) = x(jdx+2)
194           w(kdx+3) = x(jdx+3)
195           kdx      = kdx + 4
196 50      continue
197         nn = 4*(jend - jstart + 1) - 1
198         do 200, ii=0,3
199           do 210, jj=0,nn
200             s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
201 210       continue
202 200     continue
203
204         ax      = 16*adiag(i)
205         x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
206         x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
207         x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
208         x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
209         idx     = idx - 4
210 40   continue
211
212      end
213