!
!
!    Fortran kernels for the sparse triangular solves of a factored BAIJ
!    matrix with 4x4 blocks.
!    These ONLY work for factorizations in the NATURAL ORDERING, i.e.
!    with MatSolve_SeqBAIJ_4_NaturalOrdering()
!
#include <petsc/finclude/petscsys.h>
!

!> Solve (L*U) x = b for a 4x4-block BAIJ factorization, writing every
!> 4x4 block-times-vector product out by hand.
!>
!> Storage as read by the code below (all indices 0-based):
!>   - block row i of the strictly lower factor is entries ai(i) .. adiag(i)-1,
!>     block row i of the strictly upper factor is entries adiag(i)+1 .. ai(i+1)-1,
!>     and aj(j) is the block column of entry j;
!>   - entry j is the 4x4 block a(16*j : 16*j+15) stored column by column,
!>     so element (r, c) of that block is a(16*j + 4*c + r);
!>   - the forward sweep applies no diagonal block (unit lower factor); the
!>     backward sweep multiplies by the block at adiag(i), which is therefore
!>     presumably the inverted diagonal block (NOTE(review): confirm against
!>     the factorization routine).
!>
!> n      number of block rows (the vectors have 4*n entries)
!> x      receives the solution; earlier block rows are read back during both sweeps
!> ai, aj, adiag, a   factor in the layout described above
!> b      right-hand side
pure subroutine FortranSolveBAIJ4Unroll(n, x, ai, aj, adiag, a, b)
  use, intrinsic :: ISO_C_binding
  implicit none(type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i, j, jstart, jend
  PetscInt :: idx, ax, jdx
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16, a(1))
  PETSC_AssertAlignx(16, x(1))
  PETSC_AssertAlignx(16, b(1))
  PETSC_AssertAlignx(16, ai(1))
  PETSC_AssertAlignx(16, aj(1))
  PETSC_AssertAlignx(16, adiag(1))

  ! Nothing to solve for an empty system; without this guard the first
  ! block copy below would touch b(0:3) and x(0:3) out of bounds.
  if (n <= 0) return

  !
  ! Forward solve with the unit lower triangular factor
  !
  x(0:3) = b(0:3)
  idx = 0
  do i = 1, n - 1
    jstart = ai(i)
    jend = adiag(i) - 1
    ax = 16*jstart
    idx = idx + 4
    s(0:3) = b(idx + 0:idx + 3)
    ! s = s - (block j) * x(block row aj(j)) for each strictly lower block of row i
    do j = jstart, jend
      jdx = 4*aj(j)

      s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
      s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
      s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
      s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
      ax = ax + 16
    end do
    x(idx + 0:idx + 3) = s(0:3)
  end do

  !
  ! Backward solve with the upper triangular factor (idx = 4*(n-1) here)
  !
  do i = n - 1, 0, -1
    jstart = adiag(i) + 1
    jend = ai(i + 1) - 1
    ax = 16*jstart
    s(0:3) = x(idx + 0:idx + 3)
    do j = jstart, jend
      jdx = 4*aj(j)
      s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
      s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
      s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
      s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
      ax = ax + 16
    end do
    ! x(block row i) = (block at adiag(i)) * s
    ax = 16*adiag(i)
    x(idx + 0) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
    x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
    x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
    x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4Unroll

!
! Version that, instead of calling a BLAS 2 operation for each row block,
! packs the pieces of x needed by a block row into the work array w and
! applies them with a plain loop.
!
!> Same solve, same arguments and same storage layout as
!> FortranSolveBAIJ4Unroll, plus:
!> w      scratch array, overwritten; it must hold at least
!>        4*max_work_blocks entries (see the constant below)
!>
!> Stops with 'Overflowing vector FortranSolveBAIJ4()' if any block row has
!> more than max_work_blocks off-diagonal blocks in either factor.
pure subroutine FortranSolveBAIJ4(n, x, ai, aj, adiag, a, b, w)
  use, intrinsic :: ISO_C_binding
  implicit none(type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*), w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  ! Largest number of off-diagonal blocks one block row may contribute to w.
  ! Each packed block takes 4 entries, so w needs 4*max_work_blocks (2000)
  ! entries. NOTE(review): must match the work array supplied by the caller.
  PetscInt, parameter :: max_work_blocks = 500

  PetscInt :: ii, jj, i, j
  PetscInt :: jstart, jend, idx, ax, jdx, kdx, nn
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16, a(1))
  PETSC_AssertAlignx(16, w(1))
  PETSC_AssertAlignx(16, x(1))
  PETSC_AssertAlignx(16, b(1))
  PETSC_AssertAlignx(16, ai(1))
  PETSC_AssertAlignx(16, aj(1))
  PETSC_AssertAlignx(16, adiag(1))

  ! Nothing to solve for an empty system; without this guard the first
  ! block copy below would touch b(0:3) and x(0:3) out of bounds.
  if (n <= 0) return

  !
  ! Forward solve with the unit lower triangular factor
  !
  x(0:3) = b(0:3)
  idx = 0
  do i = 1, n - 1
    !
    ! Pack the required part of the vector into the work array
    !
    kdx = 0
    jstart = ai(i)
    jend = adiag(i) - 1

    if (jend - jstart + 1 > max_work_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    do j = jstart, jend
      jdx = 4*aj(j)
      w(kdx:kdx + 3) = x(jdx:jdx + 3)
      kdx = kdx + 4
    end do

    ax = 16*jstart
    idx = idx + 4
    s(0:3) = b(idx:idx + 3)
    !
    ! s = s - a(ax:)*w : the row's consecutive blocks form one column-major
    ! 4 x (nn+1) matrix whose element (ii, jj) is a(ax + 4*jj + ii).
    ! The inner loop runs over ii so a is read with unit stride; each s(ii)
    ! still accumulates over jj in increasing order.
    !
    nn = 4*(jend - jstart + 1) - 1
    do jj = 0, nn
      do ii = 0, 3
        s(ii) = s(ii) - a(ax + 4*jj + ii)*w(jj)
      end do
    end do

    x(idx:idx + 3) = s(0:3)
  end do
  !
  ! Backward solve with the upper triangular factor (idx = 4*(n-1) here)
  !
  do i = n - 1, 0, -1
    jstart = adiag(i) + 1
    jend = ai(i + 1) - 1
    ax = 16*jstart
    s(0:3) = x(idx:idx + 3)
    !
    ! Pack each chunk of the vector needed
    !
    kdx = 0
    if (jend - jstart + 1 > max_work_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    do j = jstart, jend
      jdx = 4*aj(j)
      w(kdx:kdx + 3) = x(jdx:jdx + 3)
      kdx = kdx + 4
    end do
    ! s = s - a(ax:)*w, same column-major layout and loop order as above
    nn = 4*(jend - jstart + 1) - 1
    do jj = 0, nn
      do ii = 0, 3
        s(ii) = s(ii) - a(ax + 4*jj + ii)*w(jj)
      end do
    end do

    ! x(block row i) = (block at adiag(i)) * s
    ax = 16*adiag(i)
    x(idx) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
    x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
    x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
    x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4