# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------


class BaseTestKSP:
    """Shared KSP test cases, parameterized by class attributes.

    Concrete subclasses mix this in together with ``unittest.TestCase`` and
    set ``KSP_TYPE`` (and optionally ``PC_TYPE``).  ``setUp`` builds a fresh
    KSP on ``PETSc.COMM_SELF`` with those types before every test.
    """

    # KSP type passed to ``KSP.setType`` in setUp; None keeps the default.
    KSP_TYPE = None
    # PC type passed to ``PC.setType`` in setUp; None keeps the default.
    PC_TYPE = None

    def setUp(self):
        """Create a sequential KSP configured with the class's KSP/PC types."""
        ksp = PETSc.KSP()
        ksp.create(PETSc.COMM_SELF)
        if self.KSP_TYPE:
            ksp.setType(self.KSP_TYPE)
        if self.PC_TYPE:
            pc = ksp.getPC()
            pc.setType(self.PC_TYPE)
        self.ksp = ksp

    def tearDown(self):
        """Drop the KSP reference and call ``PETSc.garbage_cleanup()``."""
        self.ksp = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        """The configured type is reported and survives being set again."""
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)
        self.ksp.setType(self.KSP_TYPE)
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)

    def testTols(self):
        """Round-trip the tolerances and cross-check the property accessors."""
        tols = self.ksp.getTolerances()
        self.ksp.setTolerances(*tols)
        # Property names in the same order as the getTolerances() tuple.
        tnames = ('rtol', 'atol', 'divtol', 'max_it')
        tolvals = [getattr(self.ksp, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Exercise the KSP Python properties.

        Covers appctx, pc_side, norm_type, its, norm, history and reason.
        Also checks that is_converged/is_diverged/is_iterating agree with
        the reason that was set.
        """
        ksp = self.ksp
        #
        ksp.appctx = (1, 2, 3)
        self.assertEqual(ksp.appctx, (1, 2, 3))
        ksp.appctx = None
        self.assertIsNone(ksp.appctx)
        #
        side = ksp.pc_side
        ksp.pc_side = side
        self.assertEqual(ksp.pc_side, side)
        #
        nt = ksp.norm_type
        ksp.norm_type = nt
        self.assertEqual(ksp.norm_type, nt)
        #
        ksp.its = 1
        self.assertEqual(ksp.its, 1)
        ksp.its = 0
        self.assertEqual(ksp.its, 0)
        #
        ksp.norm = 1
        self.assertEqual(ksp.norm, 1)
        ksp.norm = 0
        self.assertEqual(ksp.norm, 0)
        #
        # No solve has run yet, so no residual history is recorded.
        rh = ksp.history
        self.assertEqual(len(rh), 0)
        #
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITS
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertTrue(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertTrue(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITERATING
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertTrue(ksp.is_iterating)

    def testGetSetPC(self):
        """Replace the KSP's PC and track PETSc reference counts throughout."""
        oldpc = self.ksp.getPC()
        self.assertEqual(oldpc.getRefCount(), 2)
        newpc = PETSc.PC()
        newpc.create(self.ksp.getComm())
        self.assertEqual(newpc.getRefCount(), 1)
        # setPC takes a reference to the new PC and releases the old one.
        self.ksp.setPC(newpc)
        self.assertEqual(newpc.getRefCount(), 2)
        self.assertEqual(oldpc.getRefCount(), 1)
        oldpc.destroy()
        self.assertFalse(bool(oldpc))
        pc = self.ksp.getPC()
        self.assertTrue(bool(pc))
        self.assertEqual(pc, newpc)
        # pc and newpc wrap the same PETSc object; each wrapper appears to
        # hold its own reference.
        self.assertEqual(pc.getRefCount(), 3)
        newpc.destroy()
        self.assertFalse(bool(newpc))
        self.assertEqual(pc.getRefCount(), 2)

    def testSolve(self, solve_only=False):
        """Solve a small diagonal system with the configured KSP.

        Other tests also call this as a helper.  When ``solve_only`` is
        true, only the solve runs.  Convergence-history recording and the
        post-solve buildSolution/buildResidual/history checks are skipped.
        """
        # 3x3 diagonal matrix; after shift(1) the diagonal is 1 + 0.9/(i+1).
        A = PETSc.Mat().create(PETSc.COMM_SELF)
        A.setSizes([3, 3])
        A.setType(PETSc.Mat.Type.SEQAIJ)
        A.setPreallocationNNZ(1)
        for i in range(3):
            A.setValue(i, i, 0.9 / (i + 1))
        A.assemble()
        A.shift(1)
        x, b = A.createVecs()
        b.set(10)
        x.setRandom()
        self.ksp.setOperators(A)
        if not solve_only:
            self.ksp.setConvergenceHistory()
        self.ksp.solve(b, x)
        if not solve_only:
            u = x.duplicate()
            self.ksp.buildSolution(u)
            self.ksp.buildResidual(u)
            # Read the history recorded during the solve.  Then re-configure
            # it with length 0 and check that the result is empty.
            self.ksp.getConvergenceHistory()
            self.ksp.setConvergenceHistory(0)
            rh = self.ksp.getConvergenceHistory()
            self.assertEqual(len(rh), 0)
            u.destroy()
        del A, x, b

    def testResetAndSolve(self):
        """The KSP can be reset and reused for further solves."""
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()

    def testSetMonitor(self):
        """Register and cancel a Python monitor.

        The KSP must hold exactly one extra reference to the callback while
        it is registered.  After ``monitorCancel`` the callback must not be
        invoked.
        """
        reshist = {}

        def monitor(ksp, its, rnorm):
            if ksp.type in ('cg', 'stcg'):
                reshist[its] = {'r': rnorm, 'o': ksp.getCGObjectiveValue()}
            else:
                reshist[its] = rnorm

        refcnt = getrefcount(monitor)
        self.ksp.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve(solve_only=True)
        # Empty the dict in place: the monitor closure writes to this object.
        reshist.clear()
        self.ksp.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve(solve_only=True)
        self.assertEqual(len(reshist), 0)
        # TODO: the built-in monitor variants below are not exercised yet.
        ## Monitor = PETSc.KSP.Monitor
        ## self.ksp.setMonitor(Monitor())
        ## self.ksp.setMonitor(Monitor.DEFAULT)
        ## self.ksp.setMonitor(Monitor.TRUE_RESIDUAL_NORM)
        ## self.ksp.setMonitor(Monitor.SOLUTION)

    def testSetConvergenceTest(self):
        """Setting a convergence test holds a reference; None releases it."""
        def converged(ksp, its, rnorm):
            return its > 10

        refcnt = getrefcount(converged)
        self.ksp.setConvergenceTest(converged)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

    def testAddConvergenceTest(self):
        """Prepended and appended convergence tests hold one reference each."""
        def converged(ksp, its, rnorm):
            return True

        refcnt = getrefcount(converged)
        self.ksp.addConvergenceTest(converged, prepend=True)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.testSolve()
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)
        self.testSolve()
        self.ksp.addConvergenceTest(converged, prepend=False)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.testSolve()
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

    def testSetPreSolveTest(self):
        """A registered pre-solve hook runs during solve and can be removed."""
        check = {'val': 0}

        def presolve(ksp, rhs, x):
            check['val'] = 1

        refcnt = getrefcount(presolve)
        self.ksp.setPreSolve(presolve)
        self.assertEqual(getrefcount(presolve), refcnt + 1)
        self.testSolve()
        self.assertEqual(check['val'], 1)
        self.ksp.setPreSolve(None)
        self.assertEqual(getrefcount(presolve), refcnt)

    def testSetPostSolveTest(self):
        """A registered post-solve hook runs during solve and can be removed."""
        check = {'val': 0}

        def postsolve(ksp, rhs, x):
            check['val'] = 1

        refcnt = getrefcount(postsolve)
        self.ksp.setPostSolve(postsolve)
        self.assertEqual(getrefcount(postsolve), refcnt + 1)
        self.testSolve()
        self.assertEqual(check['val'], 1)
        self.ksp.setPostSolve(None)
        self.assertEqual(getrefcount(postsolve), refcnt)


# --------------------------------------------------------------------


class TestKSPPREONLY(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with PREONLY and an LU preconditioner."""
    KSP_TYPE = PETSc.KSP.Type.PREONLY
    # Presumably LU is chosen so a single PC application solves the system
    # directly — TODO confirm.
    PC_TYPE = PETSc.PC.Type.LU


class TestKSPRICHARDSON(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with RICHARDSON and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.RICHARDSON


class TestKSPCHEBYCHEV(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with the Chebyshev KSP and the default PC.

    The type constant is spelled ``CHEBYSHEV`` when available.  Otherwise
    the legacy ``CHEBYCHEV`` spelling is used.
    """
    if hasattr(PETSc.KSP.Type, 'CHEBYSHEV'):
        KSP_TYPE = PETSc.KSP.Type.CHEBYSHEV
    else:
        KSP_TYPE = PETSc.KSP.Type.CHEBYCHEV


class TestKSPCG(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with CG and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.CG


class TestKSPCGNE(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with CGNE and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.CGNE


class TestKSPSTCG(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with STCG and the default PC.

    This class is deleted at module level on complex-scalar builds (see the
    ScalarType check near the end of the file).
    """
    KSP_TYPE = PETSc.KSP.Type.STCG


class TestKSPBCGS(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with BCGS and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.BCGS


class TestKSPBCGSL(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with BCGSL and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.BCGSL


class TestKSPCGS(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with CGS and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.CGS


class TestKSPQCG(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with QCG and a Jacobi preconditioner."""
    KSP_TYPE = PETSc.KSP.Type.QCG
    # NOTE(review): presumably QCG needs a symmetric PC, hence Jacobi rather
    # than the default — confirm against the PETSc KSPQCG docs.
    PC_TYPE = PETSc.PC.Type.JACOBI


class TestKSPBICG(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with BICG and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.BICG


class TestKSPGMRES(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with GMRES and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.GMRES


class TestKSPFGMRES(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with FGMRES and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.FGMRES


class TestKSPLSQR(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP tests with LSQR and the default PC."""
    KSP_TYPE = PETSc.KSP.Type.LSQR


# --------------------------------------------------------------------

# Drop the STCG test case when PETSc's scalar type is complex.  Complex
# dtypes are exactly the numpy type codes 'F', 'D' and 'G', i.e. dtype kind
# 'c'.  Presumably STCG is unsupported for complex scalars — TODO confirm.
if PETSc.ScalarType().dtype.kind == 'c':
    del TestKSPSTCG

# --------------------------------------------------------------------

# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
