xref: /petsc/src/binding/petsc4py/test/test_ksp_py.py (revision 552edb6364df478b294b3111f33a8f37ca096b20)
15808f684SSatish Balay# --------------------------------------------------------------------
25808f684SSatish Balay
35808f684SSatish Balayfrom petsc4py import PETSc
45808f684SSatish Balayimport unittest
5*6f336411SStefano Zampinifrom test_ksp import BaseTestKSP
65808f684SSatish Balay
75808f684SSatish Balay# --------------------------------------------------------------------
85808f684SSatish Balay
95808f684SSatish Balay
class MyKSP:
    """Base Python context for a KSP of type ``python``.

    Keeps the solver's work vectors in ``self.work`` and provides ``loop``,
    which performs the per-iteration bookkeeping (residual norm, convergence
    history, monitors, convergence test) for the ``solve`` methods that
    subclasses implement.
    """

    def __init__(self):
        # Also initialized in ``create``; doing it here makes ``setUp``,
        # ``reset`` and ``destroy`` safe even if ``create`` was never called.
        self.work = []

    def create(self, ksp):
        """Start with an empty list of work vectors."""
        self.work = []

    def destroy(self, ksp):
        """Destroy any remaining work vectors and forget them."""
        self._free_work()

    def setUp(self, ksp):
        """Fetch two right work vectors from *ksp* into ``self.work``.

        The list's contents are replaced in place (slice assignment).
        """
        self.work[:] = ksp.getWorkVecs(right=2, left=None)

    def reset(self, ksp):
        """Destroy the work vectors and empty the list."""
        self._free_work()

    def loop(self, ksp, r):
        """Record residual *r* for the current iteration and test convergence.

        Sets the residual norm on *ksp*, logs it in the convergence history,
        calls the monitors, and runs the convergence test. A falsy reason
        advances the iteration counter. Otherwise the reason is stored on
        *ksp* with ``setConvergedReason``.

        Returns the reason, so ``while not self.loop(ksp, r): ...`` keeps
        iterating until the test reports a truthy reason.
        """
        its = ksp.getIterationNumber()
        rnorm = r.norm()
        ksp.setResidualNorm(rnorm)
        ksp.logConvergenceHistory(rnorm)
        ksp.monitor(its, rnorm)
        reason = ksp.callConvergenceTest(its, rnorm)
        if not reason:
            ksp.setIterationNumber(its + 1)
        else:
            ksp.setConvergedReason(reason)
        return reason

    def _free_work(self):
        """Destroy every vector in ``self.work`` and empty the list."""
        for vec in self.work:
            vec.destroy()
        # Clear in place so no stale, already-destroyed handles remain.
        del self.work[:]
415808f684SSatish Balay
425808f684SSatish Balay
class MyRichardson(MyKSP):
    """Preconditioned Richardson iteration as a Python KSP context.

    Each step forms r = b - A x, applies the preconditioner z = P(r) and
    updates x += z. Convergence is checked on the norm of z.
    """

    def solve(self, ksp, b, x):
        """Iterate ``x <- x + P(b - A x)`` until ``self.loop`` returns a reason.

        One update of *x* is always performed before the first convergence
        check.
        """
        A, _ = ksp.getOperators()  # second returned operator is unused here
        P = ksp.getPC()
        r, z = self.work
        while True:
            A.mult(x, r)    # r = A x
            r.aypx(-1, b)   # r = b - A x
            P.apply(r, z)   # z = P(r)
            x.axpy(1, z)    # x = x + z
            if self.loop(ksp, z):
                break
585808f684SSatish Balay
595808f684SSatish Balay
class MyCG(MyKSP):
    """Unpreconditioned conjugate gradient as a Python KSP context."""

    def setUp(self, ksp):
        """Fetch the base work vectors and append two more: d and q."""
        super().setUp(ksp)
        d = self.work[0].duplicate()
        q = d.duplicate()
        self.work += [d, q]

    def solve(self, ksp, b, x):
        """Run CG on ``A x = b`` until ``self.loop`` returns a reason.

        Convergence is checked on the norm of the true residual r.
        """
        # The preconditioner is deliberately not applied; the test case for
        # this context configures PC type NONE.
        A, _ = ksp.getOperators()
        r, _, d, q = self.work  # the second work vector is unused by CG
        A.mult(x, r)
        r.aypx(-1, b)           # r = b - A x
        r.copy(d)               # d = r
        delta = r.dot(r)
        while not self.loop(ksp, r):
            A.mult(d, q)        # q = A d
            alpha = delta / d.dot(q)
            x.axpy(+alpha, d)
            r.axpy(-alpha, q)
            delta_old = delta
            delta = r.dot(r)
            beta = delta / delta_old
            d.aypx(beta, r)     # d = r + beta * d
865808f684SSatish Balay
87*6f336411SStefano Zampini
885808f684SSatish Balay# --------------------------------------------------------------------
895808f684SSatish Balay
905808f684SSatish Balay
class BaseTestKSPPYTHON(BaseTestKSP):
    """Shared tests for a KSP of type ``python``.

    Concrete subclasses set ``ContextClass``, which ``setUp`` instantiates and
    installs as the KSP's Python context. Subclasses also set ``PC_TYPE``,
    presumably consumed by ``BaseTestKSP`` (not visible here; confirm).
    """

    KSP_TYPE = PETSc.KSP.Type.PYTHON
    ContextClass = None  # must be overridden by concrete subclasses

    def setUp(self):
        super().setUp()
        ctx = self.ContextClass()
        self.ksp.setPythonContext(ctx)

    def testGetType(self):
        """The reported Python type is ``<module>.<class>`` of the context."""
        ctx = self.ksp.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(self.ksp.getPythonType(), pytype)

    def tearDown(self):
        # Destroy the KSP explicitly and collect garbage so the Python context
        # and its work vectors are released, then chain to the base teardown
        # (mirroring setUp, which calls super().setUp()).
        self.ksp.destroy()
        PETSc.garbage_cleanup()
        super().tearDown()
10862e5d2d2SJDBetteridge
109*6f336411SStefano Zampini
class TestKSPPYTHON_RICH(BaseTestKSPPYTHON, unittest.TestCase):
    """Python-KSP tests using the Richardson context with a Jacobi PC."""

    ContextClass = MyRichardson
    PC_TYPE = PETSc.PC.Type.JACOBI
1135808f684SSatish Balay
114*6f336411SStefano Zampini
class TestKSPPYTHON_CG(BaseTestKSPPYTHON, unittest.TestCase):
    """Python-KSP tests using the CG context with no preconditioner."""

    ContextClass = MyCG
    PC_TYPE = PETSc.PC.Type.NONE
1185808f684SSatish Balay
119*6f336411SStefano Zampini
1205808f684SSatish Balay# --------------------------------------------------------------------
1215808f684SSatish Balay
if __name__ == '__main__':
    # Run the test cases defined in this module when executed as a script.
    unittest.main()
124