!
!
! Fortran kernel for sparse triangular solve in the BAIJ matrix format
! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
! with MatSolve_SeqBAIJ_4_NaturalOrdering()
!
#include <petsc/finclude/petscsys.h>
!

! Solve (L U) x = b for a 4x4-block BAIJ factorization in the natural
! ordering, with every 4x4 block product written out by hand.
!
! All index arrays are 0-based.  For block row i:
!   adiag(i)                    - position of the diagonal block
!   ai(i) .. adiag(i)-1         - strictly lower blocks (factor L)
!   adiag(i)+1 .. ai(i+1)-1     - strictly upper blocks (factor U)
!   aj(j)                       - block column of block j
! Block j occupies the 16 consecutive entries a(16*j : 16*j+15), stored
! column-major: entry (r, c) is a(16*j + 4*c + r).
!
! The forward sweep never touches a diagonal block, i.e. L is treated as
! having an identity diagonal.  The backward sweep applies the diagonal block
! by multiplication, so it must already hold the inverse of U's diagonal
! block (presumably stored that way by the factorization -- confirm there).
!
! n     - number of block rows; x and b must hold at least 4*n entries.
! x     - on return, the solution.
! b     - right-hand side (read only).
pure subroutine FortranSolveBAIJ4Unroll(n, x, ai, aj, adiag, a, b)
  use, intrinsic :: ISO_C_binding
  implicit none(type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i, j, jstart, jend
  PetscInt :: idx, ax, jdx
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16, a(1))
  PETSC_AssertAlignx(16, x(1))
  PETSC_AssertAlignx(16, b(1))
  PETSC_AssertAlignx(16, ai(1))
  PETSC_AssertAlignx(16, aj(1))
  PETSC_AssertAlignx(16, adiag(1))

  ! Empty matrix: nothing to solve.  Without this guard the unconditional
  ! copy of the first block below would access x(0:3) and b(0:3) out of bounds.
  if (n <= 0) return

  !
  ! Forward solve with the unit lower triangular factor L.
  ! Block row 0 has no lower blocks, so its solution is b itself.
  !
  x(0:3) = b(0:3)
  do i = 1, n - 1
    idx = 4*i
    jstart = ai(i)
    jend = adiag(i) - 1
    ax = 16*jstart
    s(0:3) = b(idx + 0:idx + 3)
    do j = jstart, jend
      jdx = 4*aj(j)
      ! s = s - L(i, j) * x(block aj(j)), block stored column-major
      s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
      s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
      s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
      s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
      ax = ax + 16
    end do
    x(idx + 0:idx + 3) = s(0:3)
  end do

  !
  ! Backward solve with the upper triangular factor U, last block row first.
  ! x of later block rows is already final when it is read here.
  !
  do i = n - 1, 0, -1
    idx = 4*i
    jstart = adiag(i) + 1
    jend = ai(i + 1) - 1
    ax = 16*jstart
    s(0:3) = x(idx + 0:idx + 3)
    do j = jstart, jend
      jdx = 4*aj(j)
      s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
      s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
      s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
      s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
      ax = ax + 16
    end do
    ! x_i = D_i * s, where D_i is the stored (inverted) diagonal block
    ax = 16*adiag(i)
    x(idx + 0) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
    x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
    x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
    x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
  end do
end subroutine FortranSolveBAIJ4Unroll

! Version that does not call a BLAS-2 operation for each row block
!
! Solve (L U) x = b for a 4x4-block BAIJ factorization in the natural
! ordering.  Instead of one 4x4 block product per off-diagonal block, each
! block row first gathers the needed pieces of x into the work array w.  It
! then applies that row's whole contiguous segment of the factor in one loop.
!
! All index arrays are 0-based.  For block row i:
!   adiag(i)                    - position of the diagonal block
!   ai(i) .. adiag(i)-1         - strictly lower blocks (factor L)
!   adiag(i)+1 .. ai(i+1)-1     - strictly upper blocks (factor U)
!   aj(j)                       - block column of block j
! Block j occupies a(16*j : 16*j+15), stored column-major (entry (r, c) is
! a(16*j + 4*c + r)).  Consecutive blocks of a row segment are adjacent, so a
! segment of k blocks is a column-major 4 x 4k matrix starting at a(16*jstart).
!
! L is treated as having an identity diagonal.  The diagonal block of U is
! applied by multiplication, so it must already hold the inverse (presumably
! stored that way by the factorization -- confirm there).
!
! n     - number of block rows; x and b must hold at least 4*n entries.
! x     - on return, the solution.
! b     - right-hand side (read only).
! w     - scratch space for the gathered x entries of one row segment.
pure subroutine FortranSolveBAIJ4(n, x, ai, aj, adiag, a, b, w)
  use, intrinsic :: ISO_C_binding
  implicit none(type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*), w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  ! Largest number of blocks in one triangular row segment that may be
  ! gathered into w.  Assumes the caller provides at least 4*max_row_blocks
  ! entries of w (NOTE(review): confirm against the C caller).
  integer, parameter :: max_row_blocks = 500

  PetscInt :: i, j, jj
  PetscInt :: jstart, jend, idx, ax, jdx, kdx, nn
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16, a(1))
  PETSC_AssertAlignx(16, w(1))
  PETSC_AssertAlignx(16, x(1))
  PETSC_AssertAlignx(16, b(1))
  PETSC_AssertAlignx(16, ai(1))
  PETSC_AssertAlignx(16, aj(1))
  PETSC_AssertAlignx(16, adiag(1))

  ! Empty matrix: nothing to solve.  Without this guard the unconditional
  ! copy of the first block below would access x(0:3) and b(0:3) out of bounds.
  if (n <= 0) return

  !
  ! Forward solve with the unit lower triangular factor L.
  ! Block row 0 has no lower blocks, so its solution is b itself.
  !
  x(0:3) = b(0:3)
  do i = 1, n - 1
    idx = 4*i
    jstart = ai(i)
    jend = adiag(i) - 1
    if (jend - jstart + 1 > max_row_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    ! Pack the x blocks referenced by this row's lower segment into w.
    kdx = 0
    do j = jstart, jend
      jdx = 4*aj(j)
      w(kdx:kdx + 3) = x(jdx:jdx + 3)
      kdx = kdx + 4
    end do

    ! s = b_i - (row segment of L) * w.  Column jj of the 4 x (nn+1)
    ! segment is a(ax + 4*jj : ax + 4*jj + 3); walking columns keeps the
    ! access to a contiguous.
    ax = 16*jstart
    nn = 4*(jend - jstart + 1) - 1
    s(0:3) = b(idx:idx + 3)
    do jj = 0, nn
      s(0:3) = s(0:3) - a(ax + 4*jj:ax + 4*jj + 3)*w(jj)
    end do

    x(idx:idx + 3) = s(0:3)
  end do

  !
  ! Backward solve with the upper triangular factor U, last block row first.
  !
  do i = n - 1, 0, -1
    idx = 4*i
    jstart = adiag(i) + 1
    jend = ai(i + 1) - 1
    if (jend - jstart + 1 > max_row_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    ! Pack the (already final) x blocks referenced by the upper segment.
    kdx = 0
    do j = jstart, jend
      jdx = 4*aj(j)
      w(kdx:kdx + 3) = x(jdx:jdx + 3)
      kdx = kdx + 4
    end do

    ! s = x_i - (row segment of U) * w
    ax = 16*jstart
    nn = 4*(jend - jstart + 1) - 1
    s(0:3) = x(idx:idx + 3)
    do jj = 0, nn
      s(0:3) = s(0:3) - a(ax + 4*jj:ax + 4*jj + 3)*w(jj)
    end do

    ! x_i = D_i * s, where D_i is the stored (inverted) diagonal block
    ax = 16*adiag(i)
    x(idx) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
    x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
    x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
    x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
  end do
end subroutine FortranSolveBAIJ4