xref: /petsc/src/mat/impls/baij/seq/ftn-kernels/fsolvebaij.F90 (revision ccb4e88a40f0b86eaeca07ff64c64e4de2fae686)
1!
2!
3!    Fortran kernel for sparse triangular solve in the BAIJ matrix format
4! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
5! with MatSolve_SeqBAIJ_4_NaturalOrdering()
6!
7#include <petsc/finclude/petscsys.h>
8!
9
10      subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
11      implicit none
12      MatScalar   a(0:*)
13      PetscScalar x(0:*)
14      PetscScalar b(0:*)
15      PetscInt    n
16      PetscInt    ai(0:*)
17      PetscInt    aj(0:*)
18      PetscInt    adiag(0:*)
19
20      PetscInt    i,j,jstart,jend
21      PetscInt    idx,ax,jdx
22      PetscScalar s1,s2,s3,s4
23      PetscScalar x1,x2,x3,x4
24!
25!     Forward Solve
26!
27      PETSC_AssertAlignx(16,a(1))
28      PETSC_AssertAlignx(16,x(1))
29      PETSC_AssertAlignx(16,b(1))
30      PETSC_AssertAlignx(16,ai(1))
31      PETSC_AssertAlignx(16,aj(1))
32      PETSC_AssertAlignx(16,adiag(1))
33
34         x(0) = b(0)
35         x(1) = b(1)
36         x(2) = b(2)
37         x(3) = b(3)
38         idx  = 0
39         do 20 i=1,n-1
40            jstart = ai(i)
41            jend   = adiag(i) - 1
42            ax    = 16*jstart
43            idx    = idx + 4
44            s1     = b(idx)
45            s2     = b(idx+1)
46            s3     = b(idx+2)
47            s4     = b(idx+3)
48            do 30 j=jstart,jend
49              jdx   = 4*aj(j)
50
51              x1    = x(jdx)
52              x2    = x(jdx+1)
53              x3    = x(jdx+2)
54              x4    = x(jdx+3)
55              s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
56              s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
57              s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
58              s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
59              ax = ax + 16
60 30         continue
61            x(idx)   = s1
62            x(idx+1) = s2
63            x(idx+2) = s3
64            x(idx+3) = s4
65 20      continue
66
67!
68!     Backward solve the upper triangular
69!
70         do 40 i=n-1,0,-1
71            jstart  = adiag(i) + 1
72            jend    = ai(i+1) - 1
73            ax     = 16*jstart
74            s1      = x(idx)
75            s2      = x(idx+1)
76            s3      = x(idx+2)
77            s4      = x(idx+3)
78            do 50 j=jstart,jend
79              jdx   = 4*aj(j)
80              x1    = x(jdx)
81              x2    = x(jdx+1)
82              x3    = x(jdx+2)
83              x4    = x(jdx+3)
84              s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
85              s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
86              s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
87              s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
88              ax = ax + 16
89 50         continue
90            ax      = 16*adiag(i)
91            x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
92            x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
93            x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
94            x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
95            idx      = idx - 4
96 40      continue
97      return
98      end
99
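!   Version that packs the needed entries of x into the work array w and
!   applies the blocks of each block row with an explicit loop nest
!   instead of calling a BLAS 2 operation for each block row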
101!
102      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
103      implicit none
104      MatScalar   a(0:*)
105      PetscScalar x(0:*),b(0:*),w(0:*)
106      PetscInt  n,ai(0:*),aj(0:*),adiag(0:*)
107      PetscInt  ii,jj,i,j
108
109      PetscInt  jstart,jend,idx,ax,jdx,kdx,nn
110      PetscScalar s(0:3)
111
112!
113!     Forward Solve
114!
115
116      PETSC_AssertAlignx(16,a(1))
117      PETSC_AssertAlignx(16,w(1))
118      PETSC_AssertAlignx(16,x(1))
119      PETSC_AssertAlignx(16,b(1))
120      PETSC_AssertAlignx(16,ai(1))
121      PETSC_AssertAlignx(16,aj(1))
122      PETSC_AssertAlignx(16,adiag(1))
123
124      x(0) = b(0)
125      x(1) = b(1)
126      x(2) = b(2)
127      x(3) = b(3)
128      idx  = 0
129      do 20 i=1,n-1
130!
131!        Pack required part of vector into work array
132!
133         kdx    = 0
134         jstart = ai(i)
135         jend   = adiag(i) - 1
136         if (jend - jstart .ge. 500) then
137           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
138         endif
139         do 30 j=jstart,jend
140
141           jdx       = 4*aj(j)
142
143           w(kdx)    = x(jdx)
144           w(kdx+1)  = x(jdx+1)
145           w(kdx+2)  = x(jdx+2)
146           w(kdx+3)  = x(jdx+3)
147           kdx       = kdx + 4
148 30      continue
149
150         ax       = 16*jstart
151         idx      = idx + 4
152         s(0)     = b(idx)
153         s(1)     = b(idx+1)
154         s(2)     = b(idx+2)
155         s(3)     = b(idx+3)
156!
157!    s = s - a(ax:)*w
158!
159         nn = 4*(jend - jstart + 1) - 1
160         do 100, ii=0,3
161           do 110, jj=0,nn
162             s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
163 110       continue
164 100     continue
165
166         x(idx)   = s(0)
167         x(idx+1) = s(1)
168         x(idx+2) = s(2)
169         x(idx+3) = s(3)
170 20   continue
171
172!
173!     Backward solve the upper triangular
174!
175      do 40 i=n-1,0,-1
176         jstart    = adiag(i) + 1
177         jend      = ai(i+1) - 1
178         ax        = 16*jstart
179         s(0)      = x(idx)
180         s(1)      = x(idx+1)
181         s(2)      = x(idx+2)
182         s(3)      = x(idx+3)
183!
184!   Pack each chunk of vector needed
185!
186         kdx = 0
187         if (jend - jstart .ge. 500) then
188           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
189         endif
190         do 50 j=jstart,jend
191           jdx      = 4*aj(j)
192           w(kdx)   = x(jdx)
193           w(kdx+1) = x(jdx+1)
194           w(kdx+2) = x(jdx+2)
195           w(kdx+3) = x(jdx+3)
196           kdx      = kdx + 4
197 50      continue
198         nn = 4*(jend - jstart + 1) - 1
199         do 200, ii=0,3
200           do 210, jj=0,nn
201             s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
202 210       continue
203 200     continue
204
205         ax      = 16*adiag(i)
206         x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
207         x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
208         x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
209         x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
210         idx     = idx - 4
211 40   continue
212
213      return
214      end
215
216