# xref: /petsc/src/binding/petsc4py/test/test_ksp.py (revision 0c5727118538cc82ab6050202f4496ec4e39cf67)
1# --------------------------------------------------------------------
2
import unittest
from sys import getrefcount

from petsc4py import PETSc
6
7# --------------------------------------------------------------------
8
9
class BaseTestKSP:
    """KSP checks shared by every concrete KSP test case in this module.

    This is a mixin, not a ``unittest.TestCase``: concrete subclasses
    inherit from both and set ``KSP_TYPE`` (and optionally ``PC_TYPE``)
    to select the solver configuration under test.
    """

    # KSP type passed to KSP.setType(); None leaves the type untouched.
    KSP_TYPE = None
    # PC type passed to PC.setType(); None leaves the type untouched.
    PC_TYPE = None

    def setUp(self):
        """Create a KSP on PETSc.COMM_SELF configured from KSP_TYPE/PC_TYPE."""
        ksp = PETSc.KSP()
        ksp.create(PETSc.COMM_SELF)
        if self.KSP_TYPE:
            ksp.setType(self.KSP_TYPE)
        if self.PC_TYPE:
            pc = ksp.getPC()
            pc.setType(self.PC_TYPE)
        self.ksp = ksp

    def tearDown(self):
        """Drop the KSP reference, then run PETSc's garbage cleanup."""
        self.ksp = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        """The configured type is reported by getType() and survives setType()."""
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)
        self.ksp.setType(self.KSP_TYPE)
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)

    def testTols(self):
        """getTolerances() round-trips and agrees with the tolerance properties."""
        tols = self.ksp.getTolerances()
        self.ksp.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'divtol', 'max_it')
        tolvals = [getattr(self.ksp, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Python-level KSP properties store and return the assigned values."""
        ksp = self.ksp
        # application context
        ksp.appctx = (1, 2, 3)
        self.assertEqual(ksp.appctx, (1, 2, 3))
        ksp.appctx = None
        self.assertIsNone(ksp.appctx)
        # preconditioner side: write back the value just read
        side = ksp.pc_side
        ksp.pc_side = side
        self.assertEqual(ksp.pc_side, side)
        # norm type: write back the value just read
        nt = ksp.norm_type
        ksp.norm_type = nt
        self.assertEqual(ksp.norm_type, nt)
        # iteration count
        ksp.its = 1
        self.assertEqual(ksp.its, 1)
        ksp.its = 0
        self.assertEqual(ksp.its, 0)
        # residual norm
        ksp.norm = 1
        self.assertEqual(ksp.norm, 1)
        ksp.norm = 0
        self.assertEqual(ksp.norm, 0)
        # a freshly created KSP has an empty residual history
        rh = ksp.history
        self.assertEqual(len(rh), 0)
        # is_converged / is_diverged / is_iterating must follow ksp.reason
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITS
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertTrue(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertTrue(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITERATING
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertTrue(ksp.is_iterating)

    def testGetSetPC(self):
        """setPC() swaps the preconditioner and moves the KSP's reference."""
        oldpc = self.ksp.getPC()
        # referenced by the KSP and by this Python wrapper
        self.assertEqual(oldpc.getRefCount(), 2)
        newpc = PETSc.PC()
        newpc.create(self.ksp.getComm())
        self.assertEqual(newpc.getRefCount(), 1)
        self.ksp.setPC(newpc)
        # the KSP now references newpc and has released oldpc
        self.assertEqual(newpc.getRefCount(), 2)
        self.assertEqual(oldpc.getRefCount(), 1)
        oldpc.destroy()
        self.assertFalse(bool(oldpc))
        pc = self.ksp.getPC()
        self.assertTrue(bool(pc))
        self.assertEqual(pc, newpc)
        # KSP + newpc wrapper + pc wrapper
        self.assertEqual(pc.getRefCount(), 3)
        newpc.destroy()
        self.assertFalse(bool(newpc))
        self.assertEqual(pc.getRefCount(), 2)

    def testSolve(self, solve_only=False):
        """Solve a small diagonal system with the configured KSP.

        The operator is the 3x3 matrix diag(1 + 0.9/(i+1)), the right-hand
        side is all 10s and the initial guess is random.

        :param solve_only: when true, only set operators and solve; other
            tests use this when they just need a solve to happen.  When
            false, the solution/residual builders and the
            convergence-history API are exercised as well.
        """
        A = PETSc.Mat().create(PETSc.COMM_SELF)
        A.setSizes([3, 3])
        A.setType(PETSc.Mat.Type.SEQAIJ)
        A.setPreallocationNNZ(1)
        for i in range(3):
            A.setValue(i, i, 0.9 / (i + 1))
        A.assemble()
        # Diagonal entries become 1 + 0.9/(i+1), all in (1, 2): the operator
        # is symmetric positive definite and well conditioned.
        A.shift(1)
        x, b = A.createVecs()
        b.set(10)
        x.setRandom()
        self.ksp.setOperators(A)
        if not solve_only:
            self.ksp.setConvergenceHistory()
        self.ksp.solve(b, x)
        if not solve_only:
            u = x.duplicate()
            self.ksp.buildSolution(u)
            self.ksp.buildResidual(u)
            # The recorded history must be retrievable; its length is not
            # asserted here.
            self.ksp.getConvergenceHistory()
            # Reconfiguring with length 0 must leave an empty history.
            self.ksp.setConvergenceHistory(0)
            rh = self.ksp.getConvergenceHistory()
            self.assertEqual(len(rh), 0)
            u.destroy()
        del A, x, b

    def testResetAndSolve(self):
        """reset() leaves the KSP usable for further solves."""
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()

    def testSetMonitor(self):
        """setMonitor() keeps a reference to the callback; monitorCancel() drops it."""
        reshist = {}

        def monitor(ksp, its, rnorm):
            # For 'cg'/'stcg' the CG objective value is recorded as well.
            if ksp.type in ('cg', 'stcg'):
                reshist[its] = {'r': rnorm, 'o': ksp.getCGObjectiveValue()}
            else:
                reshist[its] = rnorm

        refcnt = getrefcount(monitor)
        self.ksp.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve(solve_only=True)
        # Empty the history in place: the closure keeps writing into this
        # very dict, which is what the final assertion inspects.
        reshist.clear()
        self.ksp.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve(solve_only=True)
        # a cancelled monitor must not be invoked again
        self.assertEqual(len(reshist), 0)
        # NOTE: monitor variants sketched but not exercised:
        ## Monitor = PETSc.KSP.Monitor
        ## self.ksp.setMonitor(Monitor())
        ## self.ksp.setMonitor(Monitor.DEFAULT)
        ## self.ksp.setMonitor(Monitor.TRUE_RESIDUAL_NORM)
        ## self.ksp.setMonitor(Monitor.SOLUTION)

    def testSetConvergenceTest(self):
        """setConvergenceTest() holds the callback; passing None releases it."""
        def converged(ksp, its, rnorm):
            return its > 10

        refcnt = getrefcount(converged)
        self.ksp.setConvergenceTest(converged)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

    def testAddConvergenceTest(self):
        """addConvergenceTest() (prepend or append) holds the callback until cleared.

        A full solve is run in every state to check the solver still works.
        """
        def converged(ksp, its, rnorm):
            return True

        refcnt = getrefcount(converged)
        self.ksp.addConvergenceTest(converged, prepend=True)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.testSolve()
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)
        self.testSolve()
        self.ksp.addConvergenceTest(converged, prepend=False)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.testSolve()
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

    def testSetPreSolveTest(self):
        """The pre-solve hook runs during a solve and is released by setPreSolve(None)."""
        check = {'val': 0}

        def presolve(ksp, rhs, x):
            check['val'] = 1

        refcnt = getrefcount(presolve)
        self.ksp.setPreSolve(presolve)
        self.assertEqual(getrefcount(presolve), refcnt + 1)
        self.testSolve()
        self.assertEqual(check['val'], 1)
        self.ksp.setPreSolve(None)
        self.assertEqual(getrefcount(presolve), refcnt)

    def testSetPostSolveTest(self):
        """The post-solve hook runs during a solve and is released by setPostSolve(None)."""
        check = {'val': 0}

        def postsolve(ksp, rhs, x):
            check['val'] = 1

        refcnt = getrefcount(postsolve)
        self.ksp.setPostSolve(postsolve)
        self.assertEqual(getrefcount(postsolve), refcnt + 1)
        self.testSolve()
        self.assertEqual(check['val'], 1)
        self.ksp.setPostSolve(None)
        self.assertEqual(getrefcount(postsolve), refcnt)
220
221
222# --------------------------------------------------------------------
223
224
class TestKSPPREONLY(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.PREONLY`` and an LU preconditioner."""

    PC_TYPE = PETSc.PC.Type.LU
    KSP_TYPE = PETSc.KSP.Type.PREONLY
228
229
class TestKSPRICHARDSON(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.RICHARDSON``."""

    KSP_TYPE = PETSc.KSP.Type.RICHARDSON
232
233
class TestKSPCHEBYCHEV(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with the Chebyshev KSP type.

    Uses ``KSP.Type.CHEBYSHEV`` when available, otherwise falls back to
    the ``CHEBYCHEV`` spelling.
    """

    if hasattr(PETSc.KSP.Type, 'CHEBYSHEV'):
        KSP_TYPE = PETSc.KSP.Type.CHEBYSHEV
    else:
        KSP_TYPE = PETSc.KSP.Type.CHEBYCHEV
239
240
class TestKSPCG(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.CG``."""

    KSP_TYPE = PETSc.KSP.Type.CG
243
244
class TestKSPCGNE(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.CGNE``."""

    KSP_TYPE = PETSc.KSP.Type.CGNE
247
248
class TestKSPSTCG(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.STCG`` (real-scalar builds only)."""

    KSP_TYPE = PETSc.KSP.Type.STCG
251
252
class TestKSPBCGS(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.BCGS``."""

    KSP_TYPE = PETSc.KSP.Type.BCGS
255
256
class TestKSPBCGSL(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.BCGSL``."""

    KSP_TYPE = PETSc.KSP.Type.BCGSL
259
260
class TestKSPCGS(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.CGS``."""

    KSP_TYPE = PETSc.KSP.Type.CGS
263
264
class TestKSPQCG(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.QCG`` and a Jacobi preconditioner."""

    PC_TYPE = PETSc.PC.Type.JACOBI
    KSP_TYPE = PETSc.KSP.Type.QCG
268
269
class TestKSPBICG(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.BICG``."""

    KSP_TYPE = PETSc.KSP.Type.BICG
272
273
class TestKSPGMRES(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.GMRES``."""

    KSP_TYPE = PETSc.KSP.Type.GMRES
276
277
class TestKSPFGMRES(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.FGMRES``."""

    KSP_TYPE = PETSc.KSP.Type.FGMRES
280
281
class TestKSPLSQR(BaseTestKSP, unittest.TestCase):
    """Shared KSP checks with ``KSP.Type.LSQR``."""

    KSP_TYPE = PETSc.KSP.Type.LSQR
284
285
286# --------------------------------------------------------------------
287
288if PETSc.ScalarType().dtype.char in 'FDG':
289    del TestKSPSTCG
290
291# --------------------------------------------------------------------
292
293if __name__ == '__main__':
294    unittest.main()
295