xref: /petsc/src/binding/petsc4py/test/test_ksp.py (revision 030f984af8d8bb4c203755d35bded3c05b3d83ce)
# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
# --------------------------------------------------------------------
8
class BaseTestKSP(object):
    """Solver-independent KSP checks shared by the concrete test cases.

    This class is a mixin rather than a ``unittest.TestCase``.  Each concrete
    test class inherits from it *and* from ``unittest.TestCase`` and sets
    ``KSP_TYPE`` (and optionally ``PC_TYPE``).  ``setUp`` builds a fresh KSP
    on ``PETSc.COMM_SELF`` with those types before every test method.
    """

    # Type names applied in setUp; a falsy value skips the matching
    # setType() call and leaves whatever type KSP.create() chose.
    KSP_TYPE = None
    PC_TYPE = None

    def setUp(self):
        """Create the KSP under test and apply KSP_TYPE / PC_TYPE."""
        ksp = PETSc.KSP()
        ksp.create(PETSc.COMM_SELF)
        if self.KSP_TYPE:
            ksp.setType(self.KSP_TYPE)
        if self.PC_TYPE:
            pc = ksp.getPC()
            pc.setType(self.PC_TYPE)
        self.ksp = ksp

    def tearDown(self):
        """Drop this test's reference to the KSP."""
        self.ksp = None

    def testGetSetType(self):
        """getType() reports KSP_TYPE, and setting it again keeps it."""
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)
        self.ksp.setType(self.KSP_TYPE)
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)

    def testTols(self):
        """Tolerances round-trip via set/get and agree with the properties."""
        tols = self.ksp.getTolerances()
        self.ksp.setTolerances(*tols)
        # getTolerances() is compared positionally against these properties.
        tnames = ('rtol', 'atol', 'divtol', 'max_it')
        tolvals = [getattr(self.ksp, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Python-level properties read back what was assigned to them."""
        ksp = self.ksp
        # application context
        ksp.appctx = (1,2,3)
        self.assertEqual(ksp.appctx, (1,2,3))
        ksp.appctx = None
        self.assertEqual(ksp.appctx, None)
        # preconditioner side
        side = ksp.pc_side
        ksp.pc_side = side
        self.assertEqual(ksp.pc_side, side)
        # norm type
        nt = ksp.norm_type
        ksp.norm_type = nt
        self.assertEqual(ksp.norm_type, nt)
        # iteration count
        ksp.its = 1
        self.assertEqual(ksp.its, 1)
        ksp.its = 0
        self.assertEqual(ksp.its, 0)
        # residual norm
        ksp.norm = 1
        self.assertEqual(ksp.norm, 1)
        ksp.norm = 0
        self.assertEqual(ksp.norm, 0)
        # no solve has run, so the residual history is empty
        rh = ksp.history
        self.assertEqual(len(rh), 0)
        # converged reason and the converged/diverged/iterating flags
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITS
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertTrue(ksp.converged)
        self.assertFalse(ksp.diverged)
        self.assertFalse(ksp.iterating)
        reason = PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.converged)
        self.assertTrue(ksp.diverged)
        self.assertFalse(ksp.iterating)
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITERATING
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.converged)
        self.assertFalse(ksp.diverged)
        self.assertTrue(ksp.iterating)

    def testGetSetPC(self):
        """setPC() swaps the KSP's preconditioner with correct refcounts."""
        oldpc = self.ksp.getPC()
        # References: the KSP's own plus the ``oldpc`` wrapper.
        self.assertEqual(oldpc.getRefCount(), 2)
        newpc = PETSc.PC()
        newpc.create(self.ksp.getComm())
        self.assertEqual(newpc.getRefCount(), 1)
        self.ksp.setPC(newpc)
        # The KSP now references newpc and has released oldpc.
        self.assertEqual(newpc.getRefCount(), 2)
        self.assertEqual(oldpc.getRefCount(), 1)
        oldpc.destroy()
        self.assertFalse(bool(oldpc))
        pc = self.ksp.getPC()
        self.assertTrue(bool(pc))
        self.assertEqual(pc, newpc)
        # KSP + newpc wrapper + pc wrapper all refer to the same object.
        self.assertEqual(pc.getRefCount(), 3)
        newpc.destroy()
        self.assertFalse(bool(newpc))
        self.assertEqual(pc.getRefCount(), 2)

    def testSolve(self):
        """Solve a small diagonal system and exercise the history API."""
        # 3x3 diagonal matrix; after shift(1) the diagonal is 1 + 0.9/(i+1).
        A = PETSc.Mat().create(PETSc.COMM_SELF)
        A.setSizes([3,3])
        A.setType(PETSc.Mat.Type.SEQAIJ)
        A.setPreallocationNNZ(1)
        for i in range(3):
            A.setValue(i, i, 0.9/(i+1))
        A.assemble()
        A.shift(1)
        x, b = A.createVecs()
        b.set(10)
        x.setRandom()
        self.ksp.setOperators(A)
        self.ksp.setConvergenceHistory()
        self.ksp.solve(b, x)
        # Build the solution and the residual into separate work vectors;
        # writing the residual into ``u`` would clobber the built solution.
        r = b.duplicate()
        u = x.duplicate()
        self.ksp.buildSolution(u)
        self.ksp.buildResidual(r)
        # Exercise history retrieval after a solve.
        self.ksp.getConvergenceHistory()
        # With a history length of 0 the recorded history reads back empty.
        self.ksp.setConvergenceHistory(0)
        rh = self.ksp.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        del A, x, b

    def testResetAndSolve(self):
        """The KSP remains usable for solves across repeated reset() calls."""
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()

    def testSetMonitor(self):
        """setMonitor() keeps a reference to the callback; cancel drops it."""
        reshist = {}
        def monitor(ksp, its, rnorm):
            reshist[its] = rnorm
        refcnt = getrefcount(monitor)
        self.ksp.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        ## self.testSolve()
        # ``monitor`` closes over the ``reshist`` variable, so rebinding it
        # here makes the callback write into this fresh, empty dict.
        reshist = {}
        self.ksp.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        # Once cancelled, a solve must not invoke the callback.
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        ## Monitor = PETSc.KSP.Monitor
        ## self.ksp.setMonitor(Monitor())
        ## self.ksp.setMonitor(Monitor.DEFAULT)
        ## self.ksp.setMonitor(Monitor.TRUE_RESIDUAL_NORM)
        ## self.ksp.setMonitor(Monitor.SOLUTION)

    def testSetConvergenceTest(self):
        """setConvergenceTest() holds the callback until reset with None."""
        def converged(ksp, its, rnorm):
            return its > 10
        refcnt = getrefcount(converged)
        self.ksp.setConvergenceTest(converged)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)
166
# --------------------------------------------------------------------
168
class TestKSPPREONLY(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type PREONLY and an LU PC."""

    PC_TYPE = PETSc.PC.Type.LU
    KSP_TYPE = PETSc.KSP.Type.PREONLY
172
class TestKSPRICHARDSON(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type RICHARDSON."""

    KSP_TYPE = PETSc.KSP.Type.RICHARDSON
175
class TestKSPCHEBYCHEV(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with the Chebyshev KSP type.

    The ``CHEBYSHEV`` spelling is preferred; ``CHEBYCHEV`` is used only when
    ``PETSc.KSP.Type`` has no ``CHEBYSHEV`` attribute.
    """

    KSP_TYPE = (PETSc.KSP.Type.CHEBYSHEV
                if hasattr(PETSc.KSP.Type, 'CHEBYSHEV')
                else PETSc.KSP.Type.CHEBYCHEV)
181
class TestKSPCG(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type CG."""

    KSP_TYPE = PETSc.KSP.Type.CG
184
class TestKSPCGNE(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type CGNE."""

    KSP_TYPE = PETSc.KSP.Type.CGNE
187
class TestKSPSTCG(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type STCG.

    The module deletes this class when PETSc's scalar type is complex.
    """

    KSP_TYPE = PETSc.KSP.Type.STCG
190
class TestKSPBCGS(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type BCGS."""

    KSP_TYPE = PETSc.KSP.Type.BCGS
193
class TestKSPBCGSL(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type BCGSL."""

    KSP_TYPE = PETSc.KSP.Type.BCGSL
196
class TestKSPCGS(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type CGS."""

    KSP_TYPE = PETSc.KSP.Type.CGS
199
class TestKSPQCG(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type QCG and a JACOBI PC."""

    PC_TYPE = PETSc.PC.Type.JACOBI
    KSP_TYPE = PETSc.KSP.Type.QCG
203
class TestKSPBICG(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type BICG."""

    KSP_TYPE = PETSc.KSP.Type.BICG
206
class TestKSPGMRES(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type GMRES."""

    KSP_TYPE = PETSc.KSP.Type.GMRES
209
class TestKSPFGMRES(BaseTestKSP, unittest.TestCase):
    """Run the shared BaseTestKSP checks with KSP type FGMRES."""

    KSP_TYPE = PETSc.KSP.Type.FGMRES
212
# --------------------------------------------------------------------
214
# Drop the STCG test case when PETSc's scalar type is complex
# (NumPy dtype characters 'F', 'D' and 'G').
if PETSc.ScalarType().dtype.char in ('F', 'D', 'G'):
    del TestKSPSTCG
217
# --------------------------------------------------------------------
219
if __name__ == "__main__":
    # Running the module directly executes every test case defined above.
    unittest.main()
222