xref: /petsc/src/binding/petsc4py/test/test_ksp_py.py (revision bef158480efac06de457f7a665168877ab3c2fd7)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
class MyKSP(object):
    """Base class for the Python KSP contexts exercised by these tests.

    Owns a list of work vectors (``self.work``) and provides ``loop``, the
    per-iteration bookkeeping (residual norm, convergence history, monitors,
    convergence test) that subclasses call from their ``solve`` method.
    """

    def __init__(self):
        # Initialized here as well as in create() so that destroy()/reset()
        # never fail with AttributeError on a context that was not attached.
        self.work = []

    def create(self, ksp):
        """Start with no work vectors."""
        self.work = []

    def _release_work(self):
        """Destroy every held work vector and empty ``self.work`` in place."""
        for v in self.work:
            v.destroy()
        # Clear in place so no destroyed Vec objects remain referenced.
        del self.work[:]

    def destroy(self, ksp):
        """Release all work vectors."""
        self._release_work()

    def setUp(self, ksp):
        """Obtain two work vectors from ``ksp.getWorkVecs(right=2)``.

        Vectors from a previous setup are destroyed first so repeated
        setups do not leak them.
        """
        self._release_work()
        self.work[:] = ksp.getWorkVecs(right=2, left=None)

    def reset(self, ksp):
        """Release all work vectors; a later setUp() re-acquires them."""
        self._release_work()

    def loop(self, ksp, r):
        """Record one iteration for vector ``r`` and test for convergence.

        Computes ``r.norm()``, stores it as the KSP residual norm, appends it
        to the convergence history and calls the monitors. It then runs the
        convergence test. If the returned reason is falsy, the iteration
        counter is advanced; otherwise the reason is stored on the KSP.

        Returns:
            The convergence reason (falsy while the iteration should go on).
        """
        its = ksp.getIterationNumber()
        rnorm = r.norm()
        ksp.setResidualNorm(rnorm)
        ksp.logConvergenceHistory(rnorm)
        ksp.monitor(its, rnorm)
        reason = ksp.callConvergenceTest(its, rnorm)
        if not reason:
            ksp.setIterationNumber(its + 1)
        else:
            ksp.setConvergedReason(reason)
        return reason
41
class MyRichardson(MyKSP):
    """Preconditioned Richardson iteration.

    Each sweep forms the residual ``b - A x``, applies the KSP's
    preconditioner to it and adds the result to ``x``. Convergence is judged
    on the norm of that correction.
    """

    def solve(self, ksp, b, x):
        # Only the first operator is needed to form residuals.
        A, _ = ksp.getOperators()
        pc = ksp.getPC()
        res, corr = self.work
        converged = False
        while not converged:
            A.mult(x, res)        # res = A x
            res.aypx(-1, b)       # res = b - A x
            pc.apply(res, corr)   # corr = preconditioner applied to res
            x.axpy(1, corr)       # x += corr
            converged = self.loop(ksp, corr)
58
class MyCG(MyKSP):
    """Preconditioned conjugate gradient method.

    Applies the KSP's preconditioner at every iteration. With
    ``PC.Type.NONE`` the preconditioner copies ``r`` into ``z``, so the
    method reduces exactly to plain CG. Convergence is tested on the norm
    of the unpreconditioned residual ``r``.
    """

    def setUp(self, ksp):
        super(MyCG, self).setUp(ksp)
        # Two extra vectors beyond the base pair (r, z):
        # d holds the search direction, q holds A applied to d.
        d = self.work[0].duplicate()
        q = d.duplicate()
        self.work += [d, q]

    def solve(self, ksp, b, x):
        A, _ = ksp.getOperators()
        P = ksp.getPC()
        r, z, d, q = self.work
        # Initial residual r = b - A x, preconditioned residual z,
        # first search direction d = z.
        A.mult(x, r)
        r.aypx(-1, b)
        P.apply(r, z)
        z.copy(d)
        delta = r.dot(z)
        while not self.loop(ksp, r):
            A.mult(d, q)
            alpha = delta / d.dot(q)
            x.axpy(+alpha, d)
            r.axpy(-alpha, q)
            P.apply(r, z)
            delta_old = delta
            delta = r.dot(z)
            beta = delta / delta_old
            d.aypx(beta, z)  # d = z + beta * d
86
87# --------------------------------------------------------------------
88
89from test_ksp import BaseTestKSP
90
class BaseTestKSPPYTHON(BaseTestKSP):
    """Run the shared KSP tests against a ``KSP.Type.PYTHON`` solver.

    Concrete subclasses set ``ContextClass`` to the Python class that
    implements the solver. A fresh instance of it is attached to
    ``self.ksp`` for every test.
    """

    KSP_TYPE = PETSc.KSP.Type.PYTHON
    ContextClass = None

    def setUp(self):
        super(BaseTestKSPPYTHON, self).setUp()
        self.ksp.setPythonContext(self.ContextClass())
100
class TestKSPPYTHON_RICH(BaseTestKSPPYTHON, unittest.TestCase):
    """Python Richardson solver, preconditioned with Jacobi."""

    ContextClass = MyRichardson
    PC_TYPE = PETSc.PC.Type.JACOBI
104
class TestKSPPYTHON_CG(BaseTestKSPPYTHON, unittest.TestCase):
    """Python conjugate gradient solver, run without a preconditioner."""

    ContextClass = MyCG
    PC_TYPE = PETSc.PC.Type.NONE
108
109# --------------------------------------------------------------------
110
if __name__ == '__main__':
    # When executed as a script, run the test cases defined in this module.
    unittest.main()
113