xref: /petsc/src/mat/impls/baij/seq/ftn-kernels/fsolvebaij.F90 (revision c87ba875e4007ad659b117ea274f03d5f4cd5ea7)
!
!
!    Fortran kernels for the sparse triangular solves of a factored BAIJ
!    matrix with block size 4. These ONLY work for factorizations in the
!    NATURAL ORDERING, i.e. as used by MatSolve_SeqBAIJ_4_NaturalOrdering().
!
7#include <petsc/finclude/petscsys.h>
8!
9
!
!     Solve A x = b with the factored 4x4-block BAIJ matrix stored in
!     (ai, aj, adiag, a), natural ordering only: a forward solve with the
!     unit lower factor followed by a backward solve with the upper factor.
!     The 4x4 block-times-vector products are unrolled by hand.
!
!     n     - number of block rows (nothing is done when n <= 0)
!     x     - solution, 4*n entries, 0-based
!     ai    - block-row pointers (n+1 entries); block row i owns stored
!             blocks ai(i) .. ai(i+1)-1
!     aj    - block-column index of each stored block
!     adiag - index of the diagonal block of each block row; blocks before
!             it form the lower factor, blocks after it the upper factor
!     a     - block values, 16 per block, column-major within a block.
!             The diagonal block is applied by multiplication, so it must
!             already hold the inverse of U's diagonal block.
!     b     - right-hand side, 4*n entries
!
      subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
      implicit none
      MatScalar, intent(in)    :: a(0:*)
      PetscScalar, intent(out) :: x(0:*)
      PetscScalar, intent(in)  :: b(0:*)
      PetscInt, intent(in)     :: n
      PetscInt, intent(in)     :: ai(0:*)
      PetscInt, intent(in)     :: aj(0:*)
      PetscInt, intent(in)     :: adiag(0:*)

      PetscInt    i,j,jstart,jend
      PetscInt    idx,ax,jdx
      PetscScalar s1,s2,s3,s4
      PetscScalar x1,x2,x3,x4

      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))

!     An empty system has no block 0; writing x(0:3) would be out of bounds.
      if (n .le. 0) return
!
!     Forward solve with the unit lower factor.  Block row 0 has no
!     lower off-diagonal blocks, so its result is just b.
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      idx  = 0
      do i = 1, n-1
         jstart = ai(i)
         jend   = adiag(i) - 1
         ax     = 16*jstart
         idx    = idx + 4
         s1     = b(idx)
         s2     = b(idx+1)
         s3     = b(idx+2)
         s4     = b(idx+3)
         do j = jstart, jend
            jdx = 4*aj(j)
            x1  = x(jdx)
            x2  = x(jdx+1)
            x3  = x(jdx+2)
            x4  = x(jdx+3)
            s1  = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
            s2  = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
            s3  = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
            s4  = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
            ax  = ax + 16
         end do
         x(idx)   = s1
         x(idx+1) = s2
         x(idx+2) = s3
         x(idx+3) = s4
      end do
!
!     Backward solve with the upper factor.  idx now points at the last
!     block row (4*(n-1)) and walks back to 0.
!
      do i = n-1, 0, -1
         jstart = adiag(i) + 1
         jend   = ai(i+1) - 1
         ax     = 16*jstart
         s1     = x(idx)
         s2     = x(idx+1)
         s3     = x(idx+2)
         s4     = x(idx+3)
         do j = jstart, jend
            jdx = 4*aj(j)
            x1  = x(jdx)
            x2  = x(jdx+1)
            x3  = x(jdx+2)
            x4  = x(jdx+3)
            s1  = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
            s2  = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
            s3  = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
            s4  = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
            ax  = ax + 16
         end do
!        Multiply by the stored (inverted) diagonal block.
         ax       = 16*adiag(i)
         x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
         x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
         x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
         x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
         idx      = idx - 4
      end do

      end subroutine FortranSolveBAIJ4Unroll
100
!
!     Variant of FortranSolveBAIJ4Unroll() that does not perform a small
!     (BLAS 2 style) 4x4 block-times-vector product for each block.  Instead,
!     for every block row the needed pieces of x are gathered into the
!     contiguous work array w, and the row of blocks is applied with a single
!     pair of loops.
!
!     n     - number of block rows (nothing is done when n <= 0)
!     x     - solution, 4*n entries, 0-based
!     ai    - block-row pointers (n+1 entries); block row i owns stored
!             blocks ai(i) .. ai(i+1)-1
!     aj    - block-column index of each stored block
!     adiag - index of the diagonal block of each block row; blocks before
!             it form the lower factor, blocks after it the upper factor
!     a     - block values, 16 per block, column-major within a block.
!             The diagonal block is applied by multiplication, so it must
!             already hold the inverse of U's diagonal block.
!     b     - right-hand side, 4*n entries
!     w     - scratch gather buffer of at least 4*maxblk entries (see below)
!
      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar, intent(in)      :: a(0:*)
      PetscScalar, intent(out)   :: x(0:*)
      PetscScalar, intent(in)    :: b(0:*)
      PetscScalar, intent(inout) :: w(0:*)
      PetscInt, intent(in)       :: n,ai(0:*),aj(0:*),adiag(0:*)
!
!     Capacity of w, counted in 4-entry blocks.  500 is the overflow
!     threshold this kernel has always checked, so w is assumed to hold
!     4*500 entries.  NOTE(review): confirm this against the caller's
!     buffer.  Rows with more off-diagonal blocks than this are gathered
!     and applied in chunks of maxblk blocks, so w is never overrun.  The
!     accumulation order is unchanged.
!
      PetscInt, parameter :: maxblk = 500
      PetscInt    i,idx,ax
      PetscScalar s(0:3)

      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,w(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))

!     An empty system has no block 0; writing x(0:3) would be out of bounds.
      if (n .le. 0) return
!
!     Forward solve with the unit lower factor.  Block row 0 has no
!     lower off-diagonal blocks, so its result is just b.
!
      x(0:3) = b(0:3)
      idx    = 0
      do i = 1, n-1
         idx = idx + 4
         s   = b(idx:idx+3)
         call subtract_offdiag(ai(i), adiag(i) - 1, s)
         x(idx:idx+3) = s
      end do
!
!     Backward solve with the upper factor.  idx now points at the last
!     block row (4*(n-1)) and walks back to 0.
!
      do i = n-1, 0, -1
         s = x(idx:idx+3)
         call subtract_offdiag(adiag(i) + 1, ai(i+1) - 1, s)
!        Multiply by the stored (inverted) diagonal block.
         ax       = 16*adiag(i)
         x(idx)   = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
         x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
         x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
         x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
         idx      = idx - 4
      end do

      contains
!
!     acc <- acc - sum over stored blocks j = jstart..jend of A_j times
!     x(4*aj(j) : 4*aj(j)+3), where A_j is the 4x4 block at a(16*j).
!     The needed entries of x are gathered into w at most maxblk blocks at
!     a time.  An empty range (jstart > jend) leaves acc unchanged.
!
      subroutine subtract_offdiag(jstart, jend, acc)
      PetscInt, intent(in)       :: jstart, jend
      PetscScalar, intent(inout) :: acc(0:3)
      PetscInt  j,ii,jj,kstart,kend,kdx,jdx,ax,nn

      do kstart = jstart, jend, maxblk
         kend = min(kstart + maxblk - 1, jend)
!        Gather x for this chunk of blocks into w.
         kdx = 0
         do j = kstart, kend
            jdx          = 4*aj(j)
            w(kdx:kdx+3) = x(jdx:jdx+3)
            kdx          = kdx + 4
         end do
!        acc = acc - a(ax:)*w, treating the chunk as one 4 x (nn+1) matrix.
         ax = 16*kstart
         nn = 4*(kend - kstart + 1) - 1
         do ii = 0, 3
            do jj = 0, nn
               acc(ii) = acc(ii) - a(ax+4*jj+ii)*w(jj)
            end do
         end do
      end do
      end subroutine subtract_offdiag

      end subroutine FortranSolveBAIJ4
216
217