xref: /petsc/src/binding/petsc4py/test/test_ksp_py.py (revision 552edb6364df478b294b3111f33a8f37ca096b20)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from test_ksp import BaseTestKSP
6
7# --------------------------------------------------------------------
8
9
class MyKSP:
    """Base Python context for a ``KSP`` of type ``python``.

    Keeps a list of work vectors in ``self.work`` and provides ``loop``,
    which records one iteration (residual norm, history, monitors) and runs
    the solver's convergence test. Subclasses implement ``solve``.
    """

    def __init__(self):
        # Start with an empty work list so destroy()/reset() are safe even
        # if create() was never invoked on this context.
        self.work = []

    def create(self, ksp):
        """Initialize per-solver state: an empty list of work vectors."""
        self.work = []

    def destroy(self, ksp):
        """Destroy any work vectors still held and forget them."""
        for v in self.work:
            v.destroy()
        # Drop the references to the destroyed vectors (as reset() does) so
        # a later reset()/destroy() does not touch them again.
        del self.work[:]

    def setUp(self, ksp):
        """Request ``right=2`` work vectors from ``ksp.getWorkVecs``.

        The contents of ``self.work`` are replaced in place (slice
        assignment), so the list object itself is preserved.
        """
        self.work[:] = ksp.getWorkVecs(right=2, left=None)

    def reset(self, ksp):
        """Destroy the work vectors and empty the work list."""
        for v in self.work:
            v.destroy()
        del self.work[:]

    def loop(self, ksp, r):
        """Record one iteration and run the convergence test.

        Stores ``r.norm()`` as the solver's residual norm, logs it in the
        convergence history and calls the monitors. If the convergence test
        returns a falsy reason the iteration number is incremented;
        otherwise the reason is stored on the solver.

        :param ksp: the solver this context is attached to.
        :param r: vector whose norm is used as the residual norm.
        :returns: the reason from ``ksp.callConvergenceTest``; falsy while
            the iteration should continue.
        """
        its = ksp.getIterationNumber()
        rnorm = r.norm()
        ksp.setResidualNorm(rnorm)
        ksp.logConvergenceHistory(rnorm)
        ksp.monitor(its, rnorm)
        reason = ksp.callConvergenceTest(its, rnorm)
        if not reason:
            ksp.setIterationNumber(its + 1)
        else:
            ksp.setConvergedReason(reason)
        return reason
41
42
class MyRichardson(MyKSP):
    """Richardson iteration ``x <- x + PC(b - A x)`` driven by ``loop``.

    The vector passed to ``loop`` is the preconditioned residual, so its
    norm is the one reported and tested for convergence.
    """

    def solve(self, ksp, b, x):
        """Update ``x`` in place until ``loop`` returns a truthy reason."""
        A, _ = ksp.getOperators()
        pc = ksp.getPC()
        res, pres = self.work  # residual, preconditioned residual
        while True:
            # res = b - A x
            A.mult(x, res)
            res.aypx(-1, b)
            # pres = PC applied to res; x += pres
            pc.apply(res, pres)
            x.axpy(1, pres)
            # The norm checked here belongs to the residual computed before
            # the update just applied to x.
            if self.loop(ksp, pres):
                break
58
59
class MyCG(MyKSP):
    """Preconditioned conjugate gradient method.

    Work vectors: the residual ``r`` and preconditioned residual ``z``
    (from ``MyKSP.setUp``), the search direction ``d`` and ``q = A d``.
    With a ``none`` preconditioner (identity) this reduces to plain CG.
    Convergence is tested on the norm of the unpreconditioned residual
    ``r``.
    """

    def setUp(self, ksp):
        """Allocate the base work vectors plus ``d`` and ``q``."""
        super().setUp(ksp)
        d = self.work[0].duplicate()
        q = d.duplicate()
        self.work += [d, q]

    def solve(self, ksp, b, x):
        """Solve ``A x = b`` in place, starting from the guess in ``x``."""
        A, _ = ksp.getOperators()
        P = ksp.getPC()
        r, z, d, q = self.work
        # r = b - A x;  z = PC(r);  d = z
        A.mult(x, r)
        r.aypx(-1, b)
        P.apply(r, z)
        z.copy(d)
        delta = r.dot(z)
        while not self.loop(ksp, r):
            A.mult(d, q)
            alpha = delta / d.dot(q)
            x.axpy(+alpha, d)
            r.axpy(-alpha, q)
            P.apply(r, z)
            delta_old = delta
            delta = r.dot(z)
            beta = delta / delta_old
            # d = z + beta * d
            d.aypx(beta, z)
86
87
88# --------------------------------------------------------------------
89
90
class BaseTestKSPPYTHON(BaseTestKSP):
    """Run the ``BaseTestKSP`` tests on a ``KSP`` of type ``python``.

    Concrete subclasses also derive from ``unittest.TestCase`` and set
    ``ContextClass`` to the Python context class attached in ``setUp``.
    """

    KSP_TYPE = PETSc.KSP.Type.PYTHON
    # Context class instantiated in setUp(); subclasses must override it.
    ContextClass = None

    def setUp(self):
        super().setUp()
        ctx = self.ContextClass()
        self.ksp.setPythonContext(ctx)

    def testGetType(self):
        # getPythonType() must report the context as '<module>.<class>'.
        ctx = self.ksp.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        # assertEqual shows both values on failure, unlike assertTrue(a == b).
        self.assertEqual(self.ksp.getPythonType(), pytype)

    def tearDown(self):
        # Destroy the solver explicitly, then run petsc4py's garbage cleanup.
        self.ksp.destroy()
        PETSc.garbage_cleanup()
108
109
class TestKSPPYTHON_RICH(BaseTestKSPPYTHON, unittest.TestCase):
    """Python-implemented Richardson iteration with a Jacobi preconditioner."""

    ContextClass = MyRichardson
    PC_TYPE = PETSc.PC.Type.JACOBI
113
114
class TestKSPPYTHON_CG(BaseTestKSPPYTHON, unittest.TestCase):
    """Python-implemented conjugate gradient with no preconditioner."""

    ContextClass = MyCG
    PC_TYPE = PETSc.PC.Type.NONE
118
119
120# --------------------------------------------------------------------
121
if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script.
    unittest.main(module='__main__')
124