!
!
! Fortran kernels for sparse triangular solves in the BAIJ matrix format
! with 4x4 blocks.
! These ONLY work for factorizations in the NATURAL ORDERING, i.e.
! with MatSolve_SeqBAIJ_4_NaturalOrdering()
!
#include <petsc/finclude/petscsys.h>
!

!
! Forward then backward block triangular solve, fully unrolled 4x4 blocks.
!
! Storage as read by the code (block row i, zero-based):
!   ai(i) .. adiag(i)-1        strictly lower blocks of row i; no diagonal
!                              scaling is applied in the forward sweep, so
!                              the lower factor has an implied identity
!                              diagonal
!   adiag(i)                   diagonal block of row i; it is applied as a
!                              multiplier in the backward sweep, so it
!                              presumably already holds the inverted
!                              diagonal block -- TODO confirm with the
!                              factorization routine
!   adiag(i)+1 .. ai(i+1)-1    strictly upper blocks of row i
!   aj(k)                      block-column index of stored block k
!   a(16*k : 16*k+15)          entries of block k, column-major:
!                              entry (r,c) is a(16*k + r + 4*c)
!
! n      number of block rows (x and b hold 4*n entries)
! x      on exit, the result of the forward and backward sweeps; it holds the
!        intermediate forward-sweep values in between
! ai     block-row pointers (n+1 entries are read)
! aj     block-column indices
! adiag  position of the diagonal block of each block row
! a      block values
! b      right-hand side (read only)
!
pure subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
  use, intrinsic :: ISO_C_binding
  implicit none (type, external)
  MatScalar,   intent(in)    :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in)    :: b(0:*)
  PetscInt,    intent(in)    :: n
  PetscInt,    intent(in)    :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt    :: i,j,jstart,jend
  PetscInt    :: idx,ax,jdx
  PetscScalar :: s(0:3)

  ! NOTE(review): these name element 1 of zero-based arrays; confirm against
  ! the PETSC_AssertAlignx definition that element 1 (not 0) is intended.
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! No block rows: nothing to solve, and the unconditional 4-entry copy
  ! below would touch entries that need not exist.
  if (n <= 0) return

  !
  ! Forward solve:  x_i = b_i - sum_{j < i} L_ij x_j
  ! Block row 0 has no lower blocks, so it is a plain copy.
  !
  x(0:3) = b(0:3)
  idx = 0
  do i=1,n-1
    jstart = ai(i)
    jend   = adiag(i) - 1
    ax     = 16*jstart
    idx    = idx + 4
    s(0:3) = b(idx+0:idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      s(0) = s(0)-(a(ax+0)*x(jdx+0)+a(ax+4)*x(jdx+1)+a(ax+ 8)*x(jdx+2)+a(ax+12)*x(jdx+3))
      s(1) = s(1)-(a(ax+1)*x(jdx+0)+a(ax+5)*x(jdx+1)+a(ax+ 9)*x(jdx+2)+a(ax+13)*x(jdx+3))
      s(2) = s(2)-(a(ax+2)*x(jdx+0)+a(ax+6)*x(jdx+1)+a(ax+10)*x(jdx+2)+a(ax+14)*x(jdx+3))
      s(3) = s(3)-(a(ax+3)*x(jdx+0)+a(ax+7)*x(jdx+1)+a(ax+11)*x(jdx+2)+a(ax+15)*x(jdx+3))
      ax = ax + 16
    end do
    x(idx+0:idx+3) = s(0:3)
  end do

  !
  ! Backward solve, last block row first:
  !   x_i = D_i * (x_i - sum_{j > i} U_ij x_j),  D_i = block at adiag(i)
  !
  idx = 4*(n-1)
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend   = ai(i+1) - 1
    ax     = 16*jstart
    s(0:3) = x(idx+0:idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      s(0) = s(0)-(a(ax+0)*x(jdx+0)+a(ax+4)*x(jdx+1)+a(ax+ 8)*x(jdx+2)+a(ax+12)*x(jdx+3))
      s(1) = s(1)-(a(ax+1)*x(jdx+0)+a(ax+5)*x(jdx+1)+a(ax+ 9)*x(jdx+2)+a(ax+13)*x(jdx+3))
      s(2) = s(2)-(a(ax+2)*x(jdx+0)+a(ax+6)*x(jdx+1)+a(ax+10)*x(jdx+2)+a(ax+14)*x(jdx+3))
      s(3) = s(3)-(a(ax+3)*x(jdx+0)+a(ax+7)*x(jdx+1)+a(ax+11)*x(jdx+2)+a(ax+15)*x(jdx+3))
      ax = ax + 16
    end do
    ax = 16*adiag(i)
    x(idx+0) = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+ 8)*s(2)+a(ax+12)*s(3)
    x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+ 9)*s(2)+a(ax+13)*s(3)
    x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4Unroll

!
! Version that does not call a BLAS 2 operation for each row block: for each
! block row the needed pieces of x are gathered into the work array w, and the
! off-diagonal blocks of that row are applied as one dense
! 4 x (4*nblocks) column-major panel times w.
!
! Arguments and storage layout are as for FortranSolveBAIJ4Unroll, plus:
! w      work array; receives 4 entries per off-diagonal block of the row
!        being processed (at most 4*max_blocks entries are written)
!
pure subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
  use, intrinsic :: ISO_C_binding
  implicit none (type, external)
  MatScalar,   intent(in)    :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*),w(0:*)
  PetscScalar, intent(in)    :: b(0:*)
  PetscInt,    intent(in)    :: n
  PetscInt,    intent(in)    :: ai(0:*), aj(0:*), adiag(0:*)

  ! Largest number of off-diagonal blocks per block row accepted before
  ! aborting; presumably the caller sizes w for 4*max_blocks entries --
  ! TODO confirm against the caller.
  PetscInt, parameter :: max_blocks = 500

  PetscInt    :: jj,i,j
  PetscInt    :: jstart,jend,idx,ax,jdx,kdx,nn
  PetscScalar :: s(0:3)

  ! NOTE(review): these name element 1 of zero-based arrays; confirm against
  ! the PETSC_AssertAlignx definition that element 1 (not 0) is intended.
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,w(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! No block rows: nothing to solve, and the unconditional 4-entry copy
  ! below would touch entries that need not exist.
  if (n <= 0) return

  !
  ! Forward solve:  x_i = b_i - sum_{j < i} L_ij x_j
  !
  x(0:3) = b(0:3)
  idx = 0
  do i=1,n-1
    jstart = ai(i)
    jend   = adiag(i) - 1
    if (jend - jstart + 1 > max_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    ! Pack the required part of x into the work array
    kdx = 0
    do j=jstart,jend
      jdx = 4*aj(j)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx = kdx + 4
    end do

    ax  = 16*jstart
    idx = idx + 4
    s(0:3) = b(idx:idx+3)
    !
    ! s = s - a(ax:)*w.  Column jj of the panel is a(ax+4*jj : ax+4*jj+3), so
    ! jj outermost walks a contiguously; each s(ii) still accumulates its
    ! terms in increasing jj order.
    !
    nn = 4*(jend - jstart + 1) - 1
    do jj=0,nn
      s(0:3) = s(0:3) - a(ax+4*jj:ax+4*jj+3)*w(jj)
    end do

    x(idx:idx+3) = s(0:3)
  end do

  !
  ! Backward solve, last block row first:
  !   x_i = D_i * (x_i - sum_{j > i} U_ij x_j),  D_i = block at adiag(i)
  !
  idx = 4*(n-1)
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend   = ai(i+1) - 1
    ax     = 16*jstart
    s(0:3) = x(idx:idx+3)
    if (jend - jstart + 1 > max_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    ! Pack each chunk of x needed by this block row
    kdx = 0
    do j=jstart,jend
      jdx = 4*aj(j)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx = kdx + 4
    end do

    ! s = s - a(ax:)*w, traversed contiguously as in the forward sweep
    nn = 4*(jend - jstart + 1) - 1
    do jj=0,nn
      s(0:3) = s(0:3) - a(ax+4*jj:ax+4*jj+3)*w(jj)
    end do

    ax = 16*adiag(i)
    x(idx)   = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+ 8)*s(2)+a(ax+12)*s(3)
    x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+ 9)*s(2)+a(ax+13)*s(3)
    x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4