# xref: /petsc/src/binding/petsc4py/test/test_snes.py (revision 7d766218e3bc62b9d20359dfd3a8f5ff55ad6b2a)
# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
# --------------------------------------------------------------------
8
9
class Function:
    """Residual callback for the SNES tests.

    Evaluates F(x) = (x0*x0 + x0*x1 - 3, x0*x1 + x1*x1 - 6), which vanishes
    at x = (1, 2). Each component is stored into ``f`` as a plain Python
    scalar (via ``.item()``), and ``f`` is assembled afterwards.
    """

    def __call__(self, snes, x, f):
        x0, x1 = x[0], x[1]
        residual = (
            x0 * x0 + x0 * x1 - 3.0,
            x0 * x1 + x1 * x1 - 6.0,
        )
        for index, value in enumerate(residual):
            f[index] = value.item()
        f.assemble()
15
16
class Jacobian:
    """Analytic Jacobian callback matching the residual of ``Function``.

    Writes dF/dx = [[2*x0 + x1, x0], [x1, x0 + 2*x1]] into the
    preconditioner matrix ``P`` and assembles it. The operator ``J`` is
    never filled here; it is only assembled, and only when it compares
    unequal to ``P``.
    """

    def __call__(self, snes, x, J, P):
        x0, x1 = x[0], x[1]
        entries = (
            ((0, 0), 2.0 * x0 + x1),
            ((0, 1), x0),
            ((1, 0), x1),
            ((1, 1), x0 + 2.0 * x1),
        )
        for (row, col), value in entries:
            P[row, col] = value.item()
        P.assemble()
        # NOTE(review): presumably J is a distinct (e.g. matrix-free) operator
        # that only needs assembling to pick up the new state — confirm.
        if J != P:
            J.assemble()
26
27
# --------------------------------------------------------------------
29
30
class BaseTestSNES:
    """Shared SNES test cases, parameterized by the ``SNES_TYPE`` attribute.

    This mixin is not a ``unittest.TestCase`` itself, so the test loader does
    not collect it on its own; concrete classes combine it with
    ``unittest.TestCase`` and set ``SNES_TYPE``. The solve tests use the
    module-level ``Function``/``Jacobian`` callables and expect the root
    x = (1, 2) starting from the initial guess (2, 3).
    """

    # Type handed to ``SNES.setType`` in ``setUp``; a falsy value skips it.
    SNES_TYPE = None

    def setUp(self):
        # Every test gets a fresh sequential SNES.
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        # Release our reference first so PETSc.garbage_cleanup() can reclaim
        # objects whose destruction was deferred (NOTE(review): assumes
        # petsc4py defers destruction of such objects — confirm).
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        # Re-applying the current type must leave it unchanged.
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        # setTolerances() must accept what getTolerances() returns, and the
        # individual tolerance properties must agree with it.
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))
        # The divergence tolerance starts out positive.
        default_dtol = self.snes.getDivergenceTolerance()
        self.assertGreater(default_dtol, 0)
        # An explicit value must round-trip through the getter.
        self.snes.setDivergenceTolerance(-1)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, -1)
        # PETSc.DEFAULT resets the divergence tolerance to its default value
        # (see the PETSc SNESSetDivergenceTolerance manual page). The value
        # must be read again here: asserting on the ``dtol`` obtained above
        # would not test this call at all.
        self.snes.setDivergenceTolerance(PETSc.DEFAULT)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, default_dtol)

    def testProperties(self):
        snes = self.snes
        # appctx stores an arbitrary Python object, including None.
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        # Iteration count and residual norm are writable.
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        # Nothing has been solved yet, so the history is empty.
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # For each reason below, exactly one of is_converged / is_diverged /
        # is_iterating must be true.
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        # Eisenstat-Walker, matrix-free and finite-difference Jacobian
        # options are all off on a fresh SNES.
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        # A fresh SNES has no residual vector and no callback.
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice keeps a single extra reference.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        # The callback is the first element of the returned context, and
        # reading it back must not change its reference count.
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        # The residual vanishes at the root x = (1, 2).
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        # A fresh SNES has no Jacobian matrices and no callback.
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Setting the same callback twice keeps a single extra reference.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        # With only J given it also serves as the preconditioner matrix, and
        # reading the setup back must not change the reference count.
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): the KSP/PC are fetched and discarded, presumably so
        # they exist before computeJacobian() runs — confirm intent.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        # Installing the same callback again must not add another reference.
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Clearing the callback drops the reference held by the SNES.
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        # Replacing one callback with another releases the old one.
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        # getUpdate() hands back the stored object itself; ``tmp`` holds one
        # more reference until it is deleted.
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        ksp = self.snes.getKSP()
        # NOTE(review): presumably one reference from the SNES and one from
        # the returned Python wrapper — confirm.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # The solve must have recorded a history: at least one residual
        # norm, with residual and iteration arrays of matching length.
        rh, ih = self.snes.getConvergenceHistory()
        self.assertGreater(len(rh), 0)
        self.assertEqual(len(rh), len(ih))
        # Resetting with length 0 leaves an empty history.
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX: this convergence-test check does not really belong here.
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # Solve again through the no-argument interface, reusing the
        # solution vector held by the SNES.
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        # reset() must leave the SNES usable for a full solve, repeatedly.
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # After monitorCancel() the SNES drops the callback and a solve must
        # not invoke it.
        reshist.clear()
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        # monitor() invokes the installed callback directly.
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        # TODO: also cover the built-in monitors, e.g.
        # Monitor = PETSc.SNES.Monitor
        # self.snes.setMonitor(Monitor())
        # self.snes.setMonitor(Monitor.DEFAULT)
        # self.snes.setMonitor(Monitor.SOLUTION)
        # self.snes.setMonitor(Monitor.RESIDUAL)
        # self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        # Fresh counters are zero; the maximum defaults to 1 and is settable.
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        # Fresh counters are zero; the maximum defaults to 1 and is settable.
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        # Parameters round-trip through getParamsEW()/setParamsEW().
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT for a parameter keeps its current value.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        # Handy options when debugging this test:
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor']  = 'stdout'
        # opts['ksp_monitor']   = 'stdout'
        # opts['snes_view']     = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # NOTE(review): the matrix-free solve is skipped for NEWTONTR; the
        # reason is not recorded here — confirm.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill and assemble J once from the analytic Jacobian at the freshly
        # created x (NOTE(review): presumably so J has a nonzero pattern for
        # the finite-difference coloring — confirm).
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        # The nonlinear preconditioner sees the outer SNES's appctx.
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))
373        self.assertEqual(npc.appctx, (1, 2, 3))
374
375
# --------------------------------------------------------------------
377
378
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Runs the shared SNES tests with the NEWTONLS (line-search Newton) type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
381
382
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Runs the shared SNES tests with the NEWTONTR (trust-region Newton) type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
385
386
# --------------------------------------------------------------------
388
if __name__ == '__main__':
    # Executing the module directly runs every TestCase defined in it.
    unittest.main()
391