!
!
!    Fortran kernels for sparse triangular solve in the BAIJ matrix
!    format with 4x4 blocks.
! These ONLY work for factorizations in the NATURAL ORDERING, i.e.
! with MatSolve_SeqBAIJ_4_NaturalOrdering()
!
! Storage layout read by both kernels:
!   - block entries ai(i) .. adiag(i)-1 of block row i are the
!     strictly lower blocks, adiag(i) is the diagonal block and
!     adiag(i)+1 .. ai(i+1)-1 are the strictly upper blocks;
!   - aj(k) is the block-column index of block entry k;
!   - block k occupies a(16*k) .. a(16*k+15) column by column, i.e.
!     entry (r,c) of the block is a(16*k + r + 4*c).
! The forward sweep never divides by a lower diagonal block and the
! backward sweep MULTIPLIES by the diagonal block, so (presumably)
! L has unit diagonal blocks and the stored diagonal blocks are
! already inverted - confirm against the factorization routine.
!
#include <petsc/finclude/petscsys.h>
!

!
! Solve L U x = b for n block rows, fully unrolling every 4x4 block
! product.
!
!   n      - number of block rows (x and b have 4*n entries)
!   x      - on exit the solution (its input contents are ignored)
!   ai     - ai(i+1)-1 is the last block entry of block row i
!   aj     - block-column index of each block entry
!   adiag  - position of the diagonal block of each block row
!   a      - block values, 16 per block entry
!   b      - right-hand side (not modified)
!
      subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
      implicit none
      MatScalar, intent(in) :: a(0:*)
      PetscScalar, intent(out) :: x(0:*)
      PetscScalar, intent(in) :: b(0:*)
      PetscInt, intent(in) :: n
      PetscInt, intent(in) :: ai(0:*)
      PetscInt, intent(in) :: aj(0:*)
      PetscInt, intent(in) :: adiag(0:*)

      PetscInt i,j,jstart,jend
      PetscInt idx,ax,jdx
      PetscScalar s1,s2,s3,s4
      PetscScalar x1,x2,x3,x4

      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))

! Nothing to solve; the unconditional copy of the first block row
! below would otherwise read b and write x out of bounds.
      if (n .le. 0) return
!
! Forward solve with the lower triangular factor
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      idx = 0
      do i = 1, n-1
        jstart = ai(i)
        jend = adiag(i) - 1
        ax = 16*jstart
        idx = idx + 4
        s1 = b(idx)
        s2 = b(idx+1)
        s3 = b(idx+2)
        s4 = b(idx+3)
        do j = jstart, jend
          jdx = 4*aj(j)
          x1 = x(jdx)
          x2 = x(jdx+1)
          x3 = x(jdx+2)
          x4 = x(jdx+3)
          s1 = s1-(a(ax)*x1 +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
          s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
          s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
          s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
          ax = ax + 16
        end do
        x(idx) = s1
        x(idx+1) = s2
        x(idx+2) = s3
        x(idx+3) = s4
      end do
!
! Backward solve with the upper triangular factor; idx now points
! at the last block row.
!
      do i = n-1, 0, -1
        jstart = adiag(i) + 1
        jend = ai(i+1) - 1
        ax = 16*jstart
        s1 = x(idx)
        s2 = x(idx+1)
        s3 = x(idx+2)
        s4 = x(idx+3)
        do j = jstart, jend
          jdx = 4*aj(j)
          x1 = x(jdx)
          x2 = x(jdx+1)
          x3 = x(jdx+2)
          x4 = x(jdx+3)
          s1 = s1-(a(ax)*x1 +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
          s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
          s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
          s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
          ax = ax + 16
        end do
! Multiply by the stored diagonal block of this block row
        ax = 16*adiag(i)
        x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
        x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
        x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
        x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
        idx = idx - 4
      end do
      end subroutine FortranSolveBAIJ4Unroll

!
! Version that does not call a BLAS 2 operation for each block row:
! the pieces of x needed by a block row are packed into the
! contiguous work array w and the row update becomes a simple loop.
!
! Arguments are as for FortranSolveBAIJ4Unroll, plus
!   w      - workspace; at most its first 4*wblocks entries are used.
!            NOTE(review): the caller is assumed to provide at least
!            that many entries (the previous code warned once a row
!            needed more than 500 blocks) - confirm against caller.
!
      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar, intent(in) :: a(0:*)
      PetscScalar, intent(out) :: x(0:*)
      PetscScalar, intent(in) :: b(0:*)
      PetscScalar, intent(inout) :: w(0:*)
      PetscInt, intent(in) :: n
      PetscInt, intent(in) :: ai(0:*)
      PetscInt, intent(in) :: aj(0:*)
      PetscInt, intent(in) :: adiag(0:*)
! Capacity of w measured in 4-entry blocks
      PetscInt, parameter :: wblocks = 500

      PetscInt i,idx,d
      PetscScalar s(0:3)

      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,w(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))

! Nothing to solve; the unconditional copy of the first block row
! below would otherwise read b and write x out of bounds.
      if (n .le. 0) return
!
! Forward solve with the lower triangular factor
!
      x(0:3) = b(0:3)
      idx = 0
      do i = 1, n-1
        idx = idx + 4
        s = b(idx:idx+3)
        call subtractoffdiag(ai(i), adiag(i) - 1, s)
        x(idx:idx+3) = s
      end do
!
! Backward solve with the upper triangular factor; idx now points
! at the last block row.
!
      do i = n-1, 0, -1
        s = x(idx:idx+3)
        call subtractoffdiag(adiag(i) + 1, ai(i+1) - 1, s)
! Multiply by the stored diagonal block of this block row
        d = 16*adiag(i)
        x(idx)   = a(d)*s(0)  +a(d+4)*s(1)+a(d+8)*s(2) +a(d+12)*s(3)
        x(idx+1) = a(d+1)*s(0)+a(d+5)*s(1)+a(d+9)*s(2) +a(d+13)*s(3)
        x(idx+2) = a(d+2)*s(0)+a(d+6)*s(1)+a(d+10)*s(2)+a(d+14)*s(3)
        x(idx+3) = a(d+3)*s(0)+a(d+7)*s(1)+a(d+11)*s(2)+a(d+15)*s(3)
        idx = idx - 4
      end do

      contains

!
! t := t - sum over block entries j = jfirst..jlast of
!      (4x4 block j) * x(4*aj(j) : 4*aj(j)+3)
! The x pieces are packed into w at most wblocks blocks at a time,
! so w is never overrun.  Each t(k) still receives its terms one by
! one in increasing block order, so the result is identical to
! packing the whole block row at once.  An empty range (jlast <
! jfirst) leaves t unchanged.
!
      subroutine subtractoffdiag(jfirst,jlast,t)
      PetscInt, intent(in) :: jfirst
      PetscInt, intent(in) :: jlast
      PetscScalar, intent(inout) :: t(0:3)
      PetscInt c0,c1,j,jdx,kdx,k,m,nn,ax

      c0 = jfirst
      do while (c0 .le. jlast)
        c1 = min(jlast, c0 + wblocks - 1)
! Gather the x pieces referenced by block entries c0..c1
        kdx = 0
        do j = c0, c1
          jdx = 4*aj(j)
          w(kdx)   = x(jdx)
          w(kdx+1) = x(jdx+1)
          w(kdx+2) = x(jdx+2)
          w(kdx+3) = x(jdx+3)
          kdx = kdx + 4
        end do
! a(ax+4*m+k) is entry (k, mod(m,4)) of block entry c0 + m/4, and
! w(m) is the matching x entry
        ax = 16*c0
        nn = 4*(c1 - c0 + 1) - 1
        do k = 0, 3
          do m = 0, nn
            t(k) = t(k) - a(ax+4*m+k)*w(m)
          end do
        end do
        c0 = c1 + 1
      end do
      end subroutine subtractoffdiag

      end subroutine FortranSolveBAIJ4