xref: /petsc/src/binding/petsc4py/test/test_ksp.py (revision ffeef943c8ee50edff320d8a3135bb0c94853e4c)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
9
class BaseTestKSP:
    """Solver-independent KSP checks shared by the concrete test cases.

    This class is a mixin: it does not derive from ``unittest.TestCase``,
    so the unittest loader does not collect it on its own.  Each concrete
    subclass also inherits from ``unittest.TestCase`` and sets ``KSP_TYPE``
    (and, optionally, ``PC_TYPE``).
    """

    # KSP type applied in setUp(); a falsy value skips ksp.setType().
    KSP_TYPE = None
    # PC type applied to the KSP's preconditioner in setUp(); a falsy value
    # skips pc.setType().
    PC_TYPE = None

    def setUp(self):
        """Create a fresh KSP on COMM_SELF configured from KSP_TYPE/PC_TYPE."""
        ksp = PETSc.KSP()
        ksp.create(PETSc.COMM_SELF)
        if self.KSP_TYPE:
            ksp.setType(self.KSP_TYPE)
        if self.PC_TYPE:
            pc = ksp.getPC()
            pc.setType(self.PC_TYPE)
        self.ksp = ksp

    def tearDown(self):
        """Drop the KSP handle, then run PETSc.garbage_cleanup()."""
        self.ksp = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        # The type chosen in setUp() is reported back, and setting the same
        # type again leaves it unchanged.
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)
        self.ksp.setType(self.KSP_TYPE)
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)

    def testTols(self):
        # getTolerances() must round-trip through setTolerances() and agree
        # with the individual rtol/atol/divtol/max_it properties.
        tols = self.ksp.getTolerances()
        self.ksp.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'divtol', 'max_it')
        tolvals = [getattr(self.ksp, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        # Every writable property must read back the value just assigned.
        ksp = self.ksp
        #
        ksp.appctx = (1, 2, 3)
        self.assertEqual(ksp.appctx, (1, 2, 3))
        ksp.appctx = None
        self.assertEqual(ksp.appctx, None)
        #
        side = ksp.pc_side
        ksp.pc_side = side
        self.assertEqual(ksp.pc_side, side)
        #
        nt = ksp.norm_type
        ksp.norm_type = nt
        self.assertEqual(ksp.norm_type, nt)
        #
        ksp.its = 1
        self.assertEqual(ksp.its, 1)
        ksp.its = 0
        self.assertEqual(ksp.its, 0)
        #
        ksp.norm = 1
        self.assertEqual(ksp.norm, 1)
        ksp.norm = 0
        self.assertEqual(ksp.norm, 0)
        #
        # A KSP that has not solved anything has an empty residual history.
        rh = ksp.history
        self.assertEqual(len(rh), 0)
        #
        # The is_converged / is_diverged / is_iterating flags must agree with
        # whichever reason was assigned last.
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITS
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertTrue(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertTrue(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITERATING
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertTrue(ksp.is_iterating)

    def testGetSetPC(self):
        # setPC() must hand ownership over: the KSP drops its reference to
        # the old PC and takes one on the new PC.  The exact reference counts
        # below depend on how many Python handles are alive, so the
        # statement order here matters.
        oldpc = self.ksp.getPC()
        self.assertEqual(oldpc.getRefCount(), 2)
        newpc = PETSc.PC()
        newpc.create(self.ksp.getComm())
        self.assertEqual(newpc.getRefCount(), 1)
        self.ksp.setPC(newpc)
        self.assertEqual(newpc.getRefCount(), 2)
        self.assertEqual(oldpc.getRefCount(), 1)
        oldpc.destroy()
        self.assertFalse(bool(oldpc))
        # getPC() now yields a handle equal to newpc; the expected count of 3
        # presumably covers the KSP plus the two Python handles newpc and pc.
        pc = self.ksp.getPC()
        self.assertTrue(bool(pc))
        self.assertEqual(pc, newpc)
        self.assertEqual(pc.getRefCount(), 3)
        newpc.destroy()
        self.assertFalse(bool(newpc))
        self.assertEqual(pc.getRefCount(), 2)

    def testSolve(self):
        # Solve with a 3x3 diagonal SEQAIJ matrix (diagonal 0.9/(i+1), then
        # shifted by 1).  Also exercise solution/residual building and
        # switching residual-history recording off again.
        A = PETSc.Mat().create(PETSc.COMM_SELF)
        A.setSizes([3, 3])
        A.setType(PETSc.Mat.Type.SEQAIJ)
        A.setPreallocationNNZ(1)
        for i in range(3):
            A.setValue(i, i, 0.9 / (i + 1))
        A.assemble()
        A.shift(1)
        x, b = A.createVecs()
        b.set(10)
        x.setRandom()
        self.ksp.setOperators(A)
        self.ksp.setConvergenceHistory()
        self.ksp.solve(b, x)
        u = x.duplicate()
        self.ksp.buildSolution(u)
        self.ksp.buildResidual(u)
        # Only checks that the history can be queried while recording is on;
        # its length depends on the solver type, so it is not asserted.
        self.ksp.getConvergenceHistory()
        self.ksp.setConvergenceHistory(0)
        rh = self.ksp.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        del A, x, b, u

    def testResetAndSolve(self):
        # The KSP must stay usable for solving after reset(), repeatedly.
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()

    def testSetMonitor(self):
        # setMonitor() must hold exactly one reference to the callback, and
        # monitorCancel() must release it and stop the callback from running.
        reshist = {}

        def monitor(ksp, its, rnorm):
            reshist[its] = rnorm

        refcnt = getrefcount(monitor)
        self.ksp.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.ksp.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        # A cancelled monitor must not have recorded anything.
        self.assertEqual(len(reshist), 0)
        # TODO: also run a solve while the monitor is attached and check that
        # reshist is filled.  Cover the predefined monitors too
        # (PETSc.KSP.Monitor DEFAULT / TRUE_RESIDUAL_NORM / SOLUTION) if the
        # binding provides them.

    def testSetConvergenceTest(self):
        # setConvergenceTest() must hold exactly one reference to the
        # callback, and resetting it to None must release that reference.
        def converged(ksp, its, rnorm):
            return its > 10

        refcnt = getrefcount(converged)
        self.ksp.setConvergenceTest(converged)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

    def testAddConvergenceTest(self):
        # addConvergenceTest() must hold one reference to the callback for
        # both prepend=True and prepend=False.  setConvergenceTest(None) must
        # release it, and solving must work in every state.
        def converged(ksp, its, rnorm):
            return True

        refcnt = getrefcount(converged)
        self.ksp.addConvergenceTest(converged, prepend=True)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.testSolve()
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)
        self.testSolve()
        self.ksp.addConvergenceTest(converged, prepend=False)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.testSolve()
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

    def testSetPreSolveTest(self):
        # The pre-solve hook must run during solve().  The KSP must hold one
        # reference to it until the hook is reset to None.
        check = {'val': 0}

        def presolve(ksp, rhs, x):
            check['val'] = 1

        refcnt = getrefcount(presolve)
        self.ksp.setPreSolve(presolve)
        self.assertEqual(getrefcount(presolve), refcnt + 1)
        self.testSolve()
        self.assertEqual(check['val'], 1)
        self.ksp.setPreSolve(None)
        self.assertEqual(getrefcount(presolve), refcnt)

    def testSetPostSolveTest(self):
        # The post-solve hook must run during solve().  The KSP must hold one
        # reference to it until the hook is reset to None.
        check = {'val': 0}

        def postsolve(ksp, rhs, x):
            check['val'] = 1

        refcnt = getrefcount(postsolve)
        self.ksp.setPostSolve(postsolve)
        self.assertEqual(getrefcount(postsolve), refcnt + 1)
        self.testSolve()
        self.assertEqual(check['val'], 1)
        self.ksp.setPostSolve(None)
        self.assertEqual(getrefcount(postsolve), refcnt)
215
216
217# --------------------------------------------------------------------
218
219
class TestKSPPREONLY(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type PREONLY and an LU PC."""

    KSP_TYPE = PETSc.KSP.Type.PREONLY
    PC_TYPE = PETSc.PC.Type.LU
223
224
class TestKSPRICHARDSON(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type RICHARDSON (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.RICHARDSON
227
228
class TestKSPCHEBYCHEV(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with the Chebyshev KSP type.

    Uses ``PETSc.KSP.Type.CHEBYSHEV`` when the binding defines it and falls
    back to the ``CHEBYCHEV`` spelling otherwise.
    """

    KSP_TYPE = (
        PETSc.KSP.Type.CHEBYSHEV
        if hasattr(PETSc.KSP.Type, 'CHEBYSHEV')
        else PETSc.KSP.Type.CHEBYCHEV
    )
234
235
class TestKSPCG(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type CG (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.CG
238
239
class TestKSPCGNE(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type CGNE (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.CGNE
242
243
class TestKSPSTCG(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type STCG (PC type unset).

    The module deletes this class when PETSc is built with complex scalars.
    """

    KSP_TYPE = PETSc.KSP.Type.STCG
246
247
class TestKSPBCGS(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type BCGS (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.BCGS
250
251
class TestKSPBCGSL(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type BCGSL (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.BCGSL
254
255
class TestKSPCGS(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type CGS (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.CGS
258
259
class TestKSPQCG(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type QCG and a JACOBI PC."""

    KSP_TYPE = PETSc.KSP.Type.QCG
    PC_TYPE = PETSc.PC.Type.JACOBI
263
264
class TestKSPBICG(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type BICG (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.BICG
267
268
class TestKSPGMRES(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type GMRES (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.GMRES
271
272
class TestKSPFGMRES(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type FGMRES (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.FGMRES
275
276
class TestKSPLSQR(BaseTestKSP, unittest.TestCase):
    """Run the shared KSP checks with KSP type LSQR (PC type unset)."""

    KSP_TYPE = PETSc.KSP.Type.LSQR
279
280
281# --------------------------------------------------------------------
282
# NumPy dtype kind 'c' covers exactly the complex type codes 'F', 'D' and
# 'G', so this drops the STCG case from complex-scalar PETSc builds.
# Presumably STCG is real-only; confirm against the PETSc KSPSTCG docs.
if PETSc.ScalarType().dtype.kind == 'c':
    del TestKSPSTCG
285
286# --------------------------------------------------------------------
287
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
290