xref: /petsc/src/binding/petsc4py/test/test_ksp.py (revision fbf9dbe564678ed6eff1806adbc4c4f01b9743f4)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
class BaseTestKSP(object):
    """Shared KSP test suite, meant to be mixed into ``unittest.TestCase``.

    Concrete subclasses choose the solver under test through the
    ``KSP_TYPE`` and ``PC_TYPE`` class attributes.  ``setUp`` creates a
    fresh ``PETSc.KSP`` on ``PETSc.COMM_SELF`` for every test and stores
    it as ``self.ksp``.

    The ``test*`` methods are described with ``#`` comments rather than
    docstrings on purpose: ``unittest`` uses the first docstring line of a
    test method as its short description in verbose reports, so adding
    docstrings would change how the tests are reported.
    """

    # KSP type set in setUp; a falsy value leaves the type untouched.
    KSP_TYPE = None
    # PC type set on the KSP's preconditioner in setUp; falsy means untouched.
    PC_TYPE = None

    def setUp(self):
        """Create ``self.ksp`` and apply the configured KSP/PC types."""
        ksp = PETSc.KSP()
        ksp.create(PETSc.COMM_SELF)
        if self.KSP_TYPE:
            ksp.setType(self.KSP_TYPE)
        if self.PC_TYPE:
            pc = ksp.getPC()
            pc.setType(self.PC_TYPE)
        self.ksp = ksp

    def tearDown(self):
        """Drop the reference to the KSP and run PETSc's garbage cleanup."""
        self.ksp = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        # The type applied in setUp is reported back, and re-applying the
        # same type leaves it unchanged.
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)
        self.ksp.setType(self.KSP_TYPE)
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)

    def testTols(self):
        # Round-trip the tolerances through the setter, then check that the
        # per-tolerance attributes agree with what getTolerances() returned.
        tols = self.ksp.getTolerances()
        self.ksp.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'divtol', 'max_it')
        tolvals = [getattr(self.ksp, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        # Exercise the attribute-style accessors of the KSP object.
        ksp = self.ksp
        # appctx accepts an arbitrary object and can be reset to None.
        ksp.appctx = (1, 2, 3)
        self.assertEqual(ksp.appctx, (1, 2, 3))
        ksp.appctx = None
        self.assertEqual(ksp.appctx, None)
        # pc_side: writing back the current value is a no-op.
        side = ksp.pc_side
        ksp.pc_side = side
        self.assertEqual(ksp.pc_side, side)
        # norm_type: writing back the current value is a no-op.
        nt = ksp.norm_type
        ksp.norm_type = nt
        self.assertEqual(ksp.norm_type, nt)
        # its is writable and reads back the assigned value.
        ksp.its = 1
        self.assertEqual(ksp.its, 1)
        ksp.its = 0
        self.assertEqual(ksp.its, 0)
        # norm is writable and reads back the assigned value.
        ksp.norm = 1
        self.assertEqual(ksp.norm, 1)
        ksp.norm = 0
        self.assertEqual(ksp.norm, 0)
        # No solve has run on this KSP yet, so the history must be empty.
        rh = ksp.history
        self.assertEqual(len(rh), 0)
        # reason drives the three mutually exclusive status flags.
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITS
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertTrue(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertTrue(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITERATING
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertTrue(ksp.is_iterating)

    def testGetSetPC(self):
        # Track reference counts while swapping the KSP's preconditioner.
        # NOTE(review): the expected counts presumably reflect one reference
        # held by the KSP plus one per live Python handle -- confirm.
        oldpc = self.ksp.getPC()
        self.assertEqual(oldpc.getRefCount(), 2)
        newpc = PETSc.PC()
        newpc.create(self.ksp.getComm())
        self.assertEqual(newpc.getRefCount(), 1)
        self.ksp.setPC(newpc)
        self.assertEqual(newpc.getRefCount(), 2)
        self.assertEqual(oldpc.getRefCount(), 1)
        oldpc.destroy()
        self.assertFalse(bool(oldpc))
        pc = self.ksp.getPC()
        self.assertTrue(bool(pc))
        self.assertEqual(pc, newpc)
        self.assertEqual(pc.getRefCount(), 3)
        # Destroying one handle invalidates it but leaves the other usable.
        newpc.destroy()
        self.assertFalse(bool(newpc))
        self.assertEqual(pc.getRefCount(), 2)

    def testSolve(self):
        # Solve a 3x3 system whose matrix has the diagonal entries
        # 0.9/(i+1), followed by A.shift(1).
        A = PETSc.Mat().create(PETSc.COMM_SELF)
        A.setSizes([3, 3])
        A.setType(PETSc.Mat.Type.SEQAIJ)
        A.setPreallocationNNZ(1)
        for i in range(3):
            A.setValue(i, i, 0.9 / (i + 1))
        A.assemble()
        A.shift(1)
        x, b = A.createVecs()
        b.set(10)
        x.setRandom()
        self.ksp.setOperators(A)
        self.ksp.setConvergenceHistory()
        self.ksp.solve(b, x)
        r = b.duplicate()
        u = x.duplicate()
        self.ksp.buildSolution(u)
        # Build the residual into its own vector; previously ``r`` was
        # created but unused and ``u`` was passed to both calls.
        self.ksp.buildResidual(r)
        # Exercise the getter while history recording is still active.
        rh = self.ksp.getConvergenceHistory()
        # With a length of 0 the reported history must be empty.
        self.ksp.setConvergenceHistory(0)
        rh = self.ksp.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        del A, x, b

    def testResetAndSolve(self):
        # The KSP must remain usable for a full solve after each reset().
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()

    def testSetMonitor(self):
        # setMonitor must take exactly one reference to the callback and
        # monitorCancel must release it again.
        reshist = {}
        def monitor(ksp, its, rnorm):
            reshist[its] = rnorm
        refcnt = getrefcount(monitor)
        self.ksp.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        ## self.testSolve()
        # Rebinding is visible to ``monitor`` through its closure.
        reshist = {}
        self.ksp.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        # The monitor was cancelled, so the solve must not record anything.
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        ## Monitor = PETSc.KSP.Monitor
        ## self.ksp.setMonitor(Monitor())
        ## self.ksp.setMonitor(Monitor.DEFAULT)
        ## self.ksp.setMonitor(Monitor.TRUE_RESIDUAL_NORM)
        ## self.ksp.setMonitor(Monitor.SOLUTION)

    def testSetConvergenceTest(self):
        # setConvergenceTest must take exactly one reference to the callback
        # and release it when the test is reset to None.
        def converged(ksp, its, rnorm):
            return its > 10
        refcnt = getrefcount(converged)
        self.ksp.setConvergenceTest(converged)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)
167
168# --------------------------------------------------------------------
169
class TestKSPPREONLY(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type PREONLY and PC type LU."""

    PC_TYPE = PETSc.PC.Type.LU
    KSP_TYPE = PETSc.KSP.Type.PREONLY
173
class TestKSPRICHARDSON(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type RICHARDSON.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.RICHARDSON
176
class TestKSPCHEBYCHEV(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with the Chebyshev KSP type.

    The type constant is looked up as ``CHEBYSHEV`` first; when
    ``PETSc.KSP.Type`` has no such attribute, ``CHEBYCHEV`` is used instead.
    """

    KSP_TYPE = (
        PETSc.KSP.Type.CHEBYSHEV
        if hasattr(PETSc.KSP.Type, 'CHEBYSHEV')
        else PETSc.KSP.Type.CHEBYCHEV
    )
182
class TestKSPCG(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type CG.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.CG
185
class TestKSPCGNE(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type CGNE.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.CGNE
188
class TestKSPSTCG(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type STCG.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.STCG
191
class TestKSPBCGS(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type BCGS.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.BCGS
194
class TestKSPBCGSL(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type BCGSL.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.BCGSL
197
class TestKSPCGS(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type CGS.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.CGS
200
class TestKSPQCG(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type QCG and PC type JACOBI."""

    PC_TYPE = PETSc.PC.Type.JACOBI
    KSP_TYPE = PETSc.KSP.Type.QCG
204
class TestKSPBICG(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type BICG.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.BICG
207
class TestKSPGMRES(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type GMRES.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.GMRES
210
class TestKSPFGMRES(BaseTestKSP, unittest.TestCase):
    """Run the BaseTestKSP suite with KSP type FGMRES.

    ``PC_TYPE`` is inherited as ``None``, so ``setUp`` does not set a
    preconditioner type.
    """

    KSP_TYPE = PETSc.KSP.Type.FGMRES
213
214# --------------------------------------------------------------------
215
# Remove the STCG test case from the module (so it is never collected) when
# the dtype character of PETSc's scalar type is 'F', 'D' or 'G'.
# NOTE(review): these look like NumPy's complex type codes, i.e. STCG is
# presumably skipped for complex-scalar PETSc builds -- confirm.
if PETSc.ScalarType().dtype.char in 'FDG':
    del TestKSPSTCG
218
219# --------------------------------------------------------------------
220
# Run every TestCase defined in this module when executed as a script.
if __name__ == '__main__':
    unittest.main()
223