! Time-dependent advection-reaction PDE in 1d. Demonstrates IMEX methods
!
!   u_t + a1*u_x = -k1*u + k2*v + s1
!   v_t + a2*v_x = k1*u - k2*v + s2
!   0 < x < 1
!   a1 = 1, k1 = 10^6, s1 = 0,
!   a2 = 0, k2 = 2*k1, s2 = 1
!
!   Initial conditions:
!   u(x,0) = 1 + s2*x
!   v(x,0) = k1/k2*u(x,0) + s2/k2
!
!   Upstream boundary conditions:
!   u(0,t) = 1-sin(12*t)^4
!

program main
#include <petsc/finclude/petscts.h>
#include <petsc/finclude/petscdmda.h>
  use petscts
  implicit none

!
!  Application context holding the data needed by the application-provided
!  call-back routines FormIFunction(), FormRHSFunction() and FormIJacobian().
!  It is a PetscReal array with six entries, two (one per component u, v)
!  for each problem parameter a, k, s:
!     user(1:2) = a,   user(3:4) = k,   user(5:6) = s
!  user_a, user_k and user_s are the 0-based offsets of each pair, so for
!  example k(2) is user(user_k + 2).
!
  PetscReal user(6)
  integer user_a, user_k, user_s
  parameter(user_a=0, user_k=2, user_s=4)

  TS ts
  SNES snes
  SNESLineSearch linesearch
  Vec X
  Mat J
  PetscInt mx
  PetscErrorCode ierr
  DM da
  PetscReal ftime, dt
  PetscReal one, pone
  PetscInt im11, i2
  PetscBool flg

! im11: default number of grid points; i2: unknowns per point and stencil width
  im11 = 11
  i2 = 2
  one = 1.0
! pone = 0.1 is the safety factor used for the initial time step below
  pone = one/10

  PetscCallA(PetscInitialize(ierr))

! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!  Create distributed array (DMDA) to manage parallel grid and vectors
!  (2 unknowns u, v per point, stencil width 2; the grid size can be
!  changed at run time, e.g. -da_grid_x 200)
! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(DMDACreate1d(PETSC_COMM_WORLD, DM_BOUNDARY_NONE, im11, i2, i2, PETSC_NULL_INTEGER, da, ierr))
  PetscCallA(DMSetFromOptions(da, ierr))
  PetscCallA(DMSetUp(da, ierr))

! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!    Extract global vectors from DMDA
! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(DMCreateGlobalVector(da, X, ierr))

! Initialize user application context
! Use zero-based indexing for command line parameters to match ex22.c:
!   -a0/-a1 set user(1:2), -k0/-k1 set user(3:4), -s0/-s1 set user(5:6).
! Each default is assigned before its option is queried, so an absent option
! keeps the default (flg is not inspected). The default for k1 is twice k0,
! taken after -k0 has been applied.
  user(user_a + 1) = 1.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-a0', user(user_a + 1), flg, ierr))
  user(user_a + 2) = 0.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-a1', user(user_a + 2), flg, ierr))
  user(user_k + 1) = 1000000.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-k0', user(user_k + 1), flg, ierr))
  user(user_k + 2) = 2*user(user_k + 1)
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-k1', user(user_k + 2), flg, ierr))
  user(user_s + 1) = 0.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-s0', user(user_s + 1), flg, ierr))
  user(user_s + 2) = 1.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-s1', user(user_s + 2), flg, ierr))

! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!    Create timestepping solver context
!    ARKIMEX: FormIFunction/FormIJacobian give the implicitly treated
!    (reaction) part, FormRHSFunction the explicitly treated (advection) part.
! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(TSCreate(PETSC_COMM_WORLD, ts, ierr))
  PetscCallA(TSSetDM(ts, da, ierr))
  PetscCallA(TSSetType(ts, TSARKIMEX, ierr))
  PetscCallA(TSSetRHSFunction(ts, PETSC_NULL_VEC, FormRHSFunction, user, ierr))
  PetscCallA(TSSetIFunction(ts, PETSC_NULL_VEC, FormIFunction, user, ierr))
  PetscCallA(DMSetMatType(da, MATAIJ, ierr))
  PetscCallA(DMCreateMatrix(da, J, ierr))
  PetscCallA(TSSetIJacobian(ts, J, J, FormIJacobian, user, ierr))

! Use the basic (full Newton step) line search in the implicit solves
  PetscCallA(TSGetSNES(ts, snes, ierr))
  PetscCallA(SNESGetLineSearch(snes, linesearch, ierr))
  PetscCallA(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC, ierr))

! Integrate to t = 1; the last step may step past ftime instead of being shortened
  ftime = 1.0
  PetscCallA(TSSetMaxTime(ts, ftime, ierr))
  PetscCallA(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_STEPOVER, ierr))

! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!  Set initial conditions
! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(FormInitialSolution(ts, X, user, ierr))
  PetscCallA(TSSetSolution(ts, X, ierr))
! mx is the global vector length, i.e. 2 unknowns per grid point
  PetscCallA(VecGetSize(X, mx, ierr))
!  Advective CFL, I don't know why it needs so much safety factor.
  dt = pone*max(user(user_a + 1), user(user_a + 2))/mx
  PetscCallA(TSSetTimeStep(ts, dt, ierr))

! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!   Set runtime options
! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(TSSetFromOptions(ts, ierr))

! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!  Solve nonlinear system
! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(TSSolve(ts, X, ierr))

! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
!  Free work space.
!  (snes and linesearch were obtained with Get* calls and are owned by ts)
! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(MatDestroy(J, ierr))
  PetscCallA(VecDestroy(X, ierr))
  PetscCallA(TSDestroy(ts, ierr))
  PetscCallA(DMDestroy(da, ierr))
  PetscCallA(PetscFinalize(ierr))
contains

! Small helper to extract the layout, result uses 1-based indexing.
!   mx       - number of grid points of the DMDA
!   xs, xe   - first and last locally owned point (inclusive, 1-based)
!   gxs, gxe - first and last point of the ghosted local range (inclusive, 1-based)
!   ierr     - PETSc error code
  subroutine GetLayout(da, mx, xs, xe, gxs, gxe, ierr)
    use petscdm
    implicit none

    DM da
    PetscInt mx, xs, xe, gxs, gxe
    PetscErrorCode ierr
    PetscInt nowned, nghosted

    PetscCall(DMDAGetInfo(da, PETSC_NULL_INTEGER, mx, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_ENUM, PETSC_NULL_ENUM, PETSC_NULL_ENUM, PETSC_NULL_ENUM, ierr))
    PetscCall(DMDAGetCorners(da, xs, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, nowned, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, ierr))
    PetscCall(DMDAGetGhostCorners(da, gxs, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, nghosted, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, ierr))

! PETSc reports 0-based starts. A 0-based start plus the count is one past
! the 0-based end, which is exactly the 1-based inclusive end, so form the
! ends first and only then shift the starts to 1-based.
    xe = xs + nowned
    gxe = gxs + nghosted
    xs = xs + 1
    gxs = gxs + 1
  end subroutine

  subroutine FormIFunctionLocal(mx, xs, xe, gxs, gxe, x, xdot, f, a, k, s, ierr)
! Implicit residual on the locally owned points xs..xe (1-based indices):
!   f(1,i) = udot + k(1)*u - k(2)*v - s(1)
!   f(2,i) = vdot - k(1)*u + k(2)*v - s(2)
! i.e. Xdot minus the reaction and source terms of the PDE.
! mx, gxs, gxe and a are not used; they keep the argument list parallel to
! the other *Local routines.
    implicit none
    PetscInt mx, xs, xe, gxs, gxe
    PetscScalar x(2, xs:xe)
    PetscScalar xdot(2, xs:xe)
    PetscScalar f(2, xs:xe)
    PetscReal a(2), k(2), s(2)
    PetscErrorCode ierr
    PetscInt i

! The caller wraps this call in PetscCall, which tests ierr right after it
! returns, so ierr must be defined here rather than inherited from whatever
! the caller last stored in it.
    ierr = 0
    do i = xs, xe
      f(1, i) = xdot(1, i) + k(1)*x(1, i) - k(2)*x(2, i) - s(1)
      f(2, i) = xdot(2, i) - k(1)*x(1, i) + k(2)*x(2, i) - s(2)
    end do
  end subroutine

      subroutine FormIFunction(ts, t, X, Xdot, F, user, ierr)
! IFunction registered with TSSetIFunction: the implicitly treated part
! F(t, X, Xdot) = Xdot + (reaction and source terms), evaluated pointwise by
! FormIFunctionLocal. No ghost values are needed, so the global arrays are
! used directly.
! user(6) holds the (a, k, s) parameter pairs. The offsets below are 0-based
! (e.g. k(1) is user(user_k + 1)), the same convention as the main program
! and FormIJacobian, so the local declarations agree with the host-associated
! names they shadow.
        use petscts
        implicit none

        TS ts
        PetscReal t
        Vec X, Xdot, F
        PetscReal user(6)
        PetscErrorCode ierr
        integer user_a, user_k, user_s
        parameter(user_a=0, user_k=2, user_s=4)

        DM da
        PetscInt mx, xs, xe, gxs, gxe
        PetscScalar, pointer :: xx(:), xxdot(:), ff(:)

        PetscCall(TSGetDM(ts, da, ierr))
        PetscCall(GetLayout(da, mx, xs, xe, gxs, gxe, ierr))

! Get access to vector data
        PetscCall(VecGetArrayRead(X, xx, ierr))
        PetscCall(VecGetArrayRead(Xdot, xxdot, ierr))
        PetscCall(VecGetArray(F, ff, ierr))

! Each pair is passed by its first element; the callee sees it as an array of 2
        PetscCall(FormIFunctionLocal(mx, xs, xe, gxs, gxe, xx, xxdot, ff, user(user_a + 1), user(user_k + 1), user(user_s + 1), ierr))

        PetscCall(VecRestoreArrayRead(X, xx, ierr))
        PetscCall(VecRestoreArrayRead(Xdot, xxdot, ierr))
        PetscCall(VecRestoreArray(F, ff, ierr))
      end subroutine

      subroutine FormRHSFunctionLocal(mx, xs, xe, gxs, gxe, t, x, f, a, k, s, ierr)
! Explicit advection term f(j,i) ~ -a(j) * d(x_j)/dx on the locally owned
! points xs..xe (1-based). x is the ghosted local array (indices gxs:gxe);
! the stencils reach up to two points on each side. mx is the number of grid
! points, hx = 1/mx. The point x = 0 is not part of the grid; its value is
! the inflow data u0t.
! k and s are not used; they keep the argument list parallel to the other
! *Local routines.
        implicit none
        PetscInt mx, xs, xe, gxs, gxe
        PetscReal t
        PetscScalar x(2, gxs:gxe), f(2, xs:xe)
        PetscReal a(2), k(2), s(2)
        PetscErrorCode ierr
        PetscInt i, j
        PetscReal hx, u0t(2)
        PetscReal one, two, three, six, twelve
        PetscReal half, third, twothird, sixth, twelfth

! The caller wraps this call in PetscCall, which tests ierr right after it
! returns, so ierr must be defined here.
        ierr = 0

! Stencil weights are formed in PetscReal arithmetic so they carry full
! working precision (default-real literals such as 1.0/3.0 would not).
        one = 1.0
        two = 2.0
        three = 3.0
        six = 6.0
        twelve = 12.0
        half = one/two
        third = one/three
        twothird = two/three
        sixth = one/six
        twelfth = one/twelve
        hx = one/mx

! Upstream boundary data: u(0,t) = 1 - sin(12 t)^4; v is given 0 at x = 0.
! Fortran prohibits a negative real base only with a REAL exponent; an
! integer exponent is standard-conforming for any sign, so no abs() is needed.
        u0t(1) = one - sin(twelve*t)**4
        u0t(2) = 0.0

        do i = xs, xe
          do j = 1, 2
            if (i == 1) then
              ! biased stencil; the left neighbour is the inflow value
              f(j, i) = a(j)/hx*(third*u0t(j) + half*x(j, i) - x(j, i + 1) &
                                 + sixth*x(j, i + 2))
            else if (i == 2) then
              ! centered stencil with the inflow value standing in for i - 2
              f(j, i) = a(j)/hx*(-twelfth*u0t(j) + twothird*x(j, i - 1) &
                                 - twothird*x(j, i + 1) + twelfth*x(j, i + 2))
            else if (i == mx - 1) then
              ! biased stencil that only uses one point to the right
              f(j, i) = a(j)/hx*(-sixth*x(j, i - 2) + x(j, i - 1) &
                                 - half*x(j, i) - third*x(j, i + 1))
            else if (i == mx) then
              ! backward difference at the last grid point
              f(j, i) = a(j)/hx*(-x(j, i) + x(j, i - 1))
            else
              ! fourth-order centered difference in the interior
              f(j, i) = a(j)/hx*(-twelfth*x(j, i - 2) &
                                 + twothird*x(j, i - 1) &
                                 - twothird*x(j, i + 1) + twelfth*x(j, i + 2))
            end if
          end do
        end do
      end subroutine

            subroutine FormRHSFunction(ts, t, X, F, user, ierr)
! RHSFunction registered with TSSetRHSFunction: the explicitly treated
! advection term G(t, X), evaluated by FormRHSFunctionLocal. That stencil
! reads neighbours up to two points away, so X is first scattered into a
! ghosted local vector.
! user(6) holds the (a, k, s) parameter pairs. The offsets below are 0-based
! (e.g. a(1) is user(user_a + 1)), the same convention as the main program
! and FormIJacobian, so the local declarations agree with the host-associated
! names they shadow.
              use petscts
              implicit none

              TS ts
              PetscReal t
              Vec X, F
              PetscReal user(6)
              PetscErrorCode ierr
              integer user_a, user_k, user_s
              parameter(user_a=0, user_k=2, user_s=4)
              DM da
              Vec Xloc
              PetscInt mx, xs, xe, gxs, gxe
              PetscScalar, pointer :: xx(:), ff(:)

              PetscCall(TSGetDM(ts, da, ierr))
              PetscCall(GetLayout(da, mx, xs, xe, gxs, gxe, ierr))

! Scatter ghost points to a local vector using the two-step process
! DMGlobalToLocalBegin(), DMGlobalToLocalEnd(). Code placed between these
! two calls could overlap computation with the communication.
              PetscCall(DMGetLocalVector(da, Xloc, ierr))
              PetscCall(DMGlobalToLocalBegin(da, X, INSERT_VALUES, Xloc, ierr))
              PetscCall(DMGlobalToLocalEnd(da, X, INSERT_VALUES, Xloc, ierr))

! Get access to vector data (ghosted input, owned-only output)
              PetscCall(VecGetArrayRead(Xloc, xx, ierr))
              PetscCall(VecGetArray(F, ff, ierr))

! Each pair is passed by its first element; the callee sees it as an array of 2
              PetscCall(FormRHSFunctionLocal(mx, xs, xe, gxs, gxe, t, xx, ff, user(user_a + 1), user(user_k + 1), user(user_s + 1), ierr))

              PetscCall(VecRestoreArrayRead(Xloc, xx, ierr))
              PetscCall(VecRestoreArray(F, ff, ierr))
              PetscCall(DMRestoreLocalVector(da, Xloc, ierr))
            end subroutine

! ---------------------------------------------------------------------
!
!  FormIJacobian - assemble the Jacobian of the implicit function,
!     Jpre = dF/dX + shift*dF/dXdot,
!  for F as evaluated in FormIFunctionLocal. F couples u and v only at the
!  same grid point, so the matrix is block diagonal with one identical 2x2
!  block per point.
!
  subroutine FormIJacobian(ts, t, X, Xdot, shift, J, Jpre, user, ierr)
    use petscts
    implicit none

    TS ts
    PetscReal t, shift
    Vec X, Xdot
    Mat J, Jpre
    PetscReal user(6)
    PetscErrorCode ierr
    integer user_a, user_k, user_s
    parameter(user_a=0, user_k=2, user_s=4)

    DM da
    PetscInt mx, xs, xe, gxs, gxe
    PetscInt node, lidx, nblk
    PetscReal kr1, kr2
    PetscScalar blk(4)

    PetscCall(TSGetDM(ts, da, ierr))
    PetscCall(GetLayout(da, mx, xs, xe, gxs, gxe, ierr))

    kr1 = user(user_k + 1)
    kr2 = user(user_k + 2)

! The 2x2 block, row by row (PETSc's default row-oriented input):
!   [ dF1/du  dF1/dv ]   [ shift + k(1)   -k(2)        ]
!   [ dF2/du  dF2/dv ] = [ -k(1)          shift + k(2) ]
! It does not depend on the grid point, so it is built once.
    blk(1) = shift + kr1
    blk(2) = -kr2
    blk(3) = -kr1
    blk(4) = shift + kr2

    nblk = 1
    do node = xs, xe
! node and gxs are 1-based, so node - gxs is the 0-based block index in the
! ghosted local numbering expected by MatSetValuesBlockedLocal.
      lidx = node - gxs
      PetscCall(MatSetValuesBlockedLocal(Jpre, nblk, [lidx], nblk, [lidx], blk, INSERT_VALUES, ierr))
    end do

    PetscCall(MatAssemblyBegin(Jpre, MAT_FINAL_ASSEMBLY, ierr))
    PetscCall(MatAssemblyEnd(Jpre, MAT_FINAL_ASSEMBLY, ierr))
! A distinct operator matrix, if one was supplied, must be assembled as well.
    if (J /= Jpre) then
      PetscCall(MatAssemblyBegin(J, MAT_FINAL_ASSEMBLY, ierr))
      PetscCall(MatAssemblyEnd(J, MAT_FINAL_ASSEMBLY, ierr))
    end if
  end subroutine

                subroutine FormInitialSolutionLocal(mx, xs, xe, gxs, gxe, x, a, k, s, ierr)
! Initial state on the locally owned points xs..xe (1-based), point i at
! r = i/mx:
!   u = 1 + s(2)*r
!   v = k(1)/k(2)*u + s(2)/k(2)   (1/k(2) replaced by 1 when k(2) == 0)
! mx is the number of grid points; gxs, gxe and a are not used.
                  implicit none
                  PetscInt mx, xs, xe, gxs, gxe
                  PetscScalar x(2, xs:xe)
                  PetscReal a(2), k(2), s(2)
                  PetscErrorCode ierr

                  PetscInt i
                  PetscReal one, hx, r, ik

! The caller wraps this call in PetscCall, which tests ierr right after it
! returns, so ierr must be defined here.
                  ierr = 0
                  one = 1.0
                  hx = one/mx

! The reciprocal rate does not depend on the grid point; guard the
! division so that k(2) = 0 does not produce Inf.
                  if (k(2) /= 0.0) then
                    ik = one/k(2)
                  else
                    ik = one
                  end if

                  do i = xs, xe
                    r = i*hx
                    x(1, i) = one + s(2)*r
                    x(2, i) = k(1)*ik*x(1, i) + s(2)*ik
                  end do
                end subroutine

                    subroutine FormInitialSolution(ts, X, user, ierr)
! Fill the global vector X with the initial state, computed pointwise by
! FormInitialSolutionLocal on the locally owned points.
! user(6) holds the (a, k, s) parameter pairs. The offsets below are 0-based
! (e.g. s(1) is user(user_s + 1)), the same convention as the main program
! and FormIJacobian, so the local declarations agree with the host-associated
! names they shadow.
                      use petscts
                      implicit none

                      TS ts
                      Vec X
                      PetscReal user(6)
                      PetscErrorCode ierr
                      integer user_a, user_k, user_s
                      parameter(user_a=0, user_k=2, user_s=4)

                      DM da
                      PetscInt mx, xs, xe, gxs, gxe
                      PetscScalar, pointer :: xx(:)

                      PetscCall(TSGetDM(ts, da, ierr))
                      PetscCall(GetLayout(da, mx, xs, xe, gxs, gxe, ierr))

! Get access to vector data
                      PetscCall(VecGetArray(X, xx, ierr))

! Each pair is passed by its first element; the callee sees it as an array of 2
                      PetscCall(FormInitialSolutionLocal(mx, xs, xe, gxs, gxe, xx, user(user_a + 1), user(user_k + 1), user(user_s + 1), ierr))

                      PetscCall(VecRestoreArray(X, xx, ierr))
                    end subroutine
                    end program
!/*TEST
!
!    test:
!      args: -da_grid_x 200 -ts_arkimex_type 4
!      requires: !single
!      output_file: output/empty.out
!
!TEST*/
