xref: /petsc/src/binding/petsc4py/test/test_snes.py (revision 36d43d94b6d42e888c89e2d3ed68780aaa9faca1)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6import numpy as np
7
8# --------------------------------------------------------------------
9
10
class Function:
    """Residual callback for the 2x2 nonlinear system used by the tests.

    With ``x = (x0, x1)`` the residual written into ``f`` is::

        f0 = x0*x0 + x0*x1 - 3
        f1 = x0*x1 + x1*x1 - 6

    so ``(1, 2)`` is a root of the system.
    """

    def __call__(self, snes, x, f):
        # ``snes`` is part of the callback signature but is not used here.
        x0, x1 = x[0], x[1]
        residual = (
            x0 * x0 + x0 * x1 - 3.0,
            x0 * x1 + x1 * x1 - 6.0,
        )
        for row, value in enumerate(residual):
            # .item() turns the array scalar into a plain Python number.
            f[row] = value.item()
        f.assemble()
16
17
class Jacobian:
    """Jacobian callback matching the residual defined by ``Function``.

    The four partial derivatives are written into ``P``::

        [[2*x0 + x1, x0       ],
         [x1,        x0 + 2*x1]]

    ``P`` is always assembled; ``J`` is assembled as well when it compares
    unequal to ``P``.
    """

    def __call__(self, snes, x, J, P):
        # ``snes`` is part of the callback signature but is not used here.
        x0, x1 = x[0], x[1]
        entries = {
            (0, 0): 2.0 * x0 + x1,
            (0, 1): x0,
            (1, 0): x1,
            (1, 1): x0 + 2.0 * x1,
        }
        for (row, col), value in entries.items():
            # .item() turns the array scalar into a plain Python number.
            P[row, col] = value.item()
        P.assemble()
        if J != P:
            J.assemble()
27
28
29
30# --------------------------------------------------------------------
31
32
class BaseTestSNES:
    """Shared SNES test cases, combined with ``unittest.TestCase`` by subclasses.

    Subclasses set ``SNES_TYPE`` to the solver type under test. ``setUp``
    creates a fresh solver on ``PETSc.COMM_SELF`` for every test and
    ``tearDown`` drops it again.
    """

    # Solver type applied in setUp; a falsy value leaves the type unset.
    SNES_TYPE = None

    # -- helpers (names do not start with "test", so they are not collected) --

    @staticmethod
    def _new_seqaij_2x2():
        # Empty 2x2 SEQAIJ matrix on COMM_SELF, set up and ready to be filled.
        mat = PETSc.Mat().create(PETSc.COMM_SELF)
        mat.setSizes([2, 2])
        mat.setType(PETSc.Mat.Type.SEQAIJ)
        mat.setUp()
        return mat

    def _assert_root(self, vec, places=5):
        # The tests expect the solver to land on |x0| = 1, |x1| = 2.
        self.assertAlmostEqual(abs(vec[0]), 1.0, places=places)
        self.assertAlmostEqual(abs(vec[1]), 2.0, places=places)

    # -- fixtures --

    def setUp(self):
        solver = PETSc.SNES()
        solver.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            solver.setType(self.SNES_TYPE)
        self.snes = solver

    def tearDown(self):
        # Drop our reference first, then run PETSc's garbage cleanup.
        self.snes = None
        PETSc.garbage_cleanup()

    # -- tests --

    def testGetSetType(self):
        expected = self.SNES_TYPE
        self.assertEqual(self.snes.getType(), expected)
        # Setting the same type again must leave the reported type unchanged.
        self.snes.setType(expected)
        self.assertEqual(self.snes.getType(), expected)

    def testTols(self):
        snes = self.snes
        tols = snes.getTolerances()
        snes.setTolerances(*tols)
        # The same values must be visible through the attribute interface.
        names = ('rtol', 'atol', 'stol', 'max_it')
        from_attrs = tuple(getattr(snes, name) for name in names)
        self.assertEqual(tuple(tols), from_attrs)
        dtol = snes.getDivergenceTolerance()
        self.assertGreater(dtol, 0)
        snes.setDivergenceTolerance(-1)
        dtol = snes.getDivergenceTolerance()
        self.assertEqual(dtol, -1)
        snes.setDivergenceTolerance(PETSc.DEFAULT)
        # NOTE(review): dtol is not re-read after the call above, so this
        # re-checks the value fetched before it. Confirm the intended
        # post-condition before tightening this assertion.
        self.assertEqual(dtol, -1)

    def testProperties(self):
        snes = self.snes
        # appctx round-trips arbitrary objects, including None.
        for ctx in ((1, 2, 3), None):
            snes.appctx = ctx
            self.assertEqual(snes.appctx, ctx)
        # its / norm are writable and read back what was written.
        for attr in ('its', 'norm'):
            for value in (1, 0):
                setattr(snes, attr, value)
                self.assertEqual(getattr(snes, attr), value)
        # Both history sequences are empty here.
        residuals, iterations = snes.history
        self.assertEqual(len(residuals), 0)
        self.assertEqual(len(iterations), 0)
        # Each reason must switch on exactly one of the three status flags:
        # (reason, is_converged, is_diverged, is_iterating)
        Reason = PETSc.SNES.ConvergedReason
        expectations = (
            (Reason.CONVERGED_ITS, True, False, False),
            (Reason.DIVERGED_MAX_IT, False, True, False),
            (Reason.CONVERGED_ITERATING, False, False, True),
        )
        for reason, converged, diverged, iterating in expectations:
            snes.reason = reason
            self.assertEqual(snes.reason, reason)
            self.assertIs(bool(snes.is_converged), converged)
            self.assertIs(bool(snes.is_diverged), diverged)
            self.assertIs(bool(snes.is_iterating), iterating)
        # All three switches are off here.
        for flag in ('use_ew', 'use_mf', 'use_fd'):
            self.assertFalse(getattr(snes, flag))

    def testGetSetFunc(self):
        vec0, ctx0 = self.snes.getFunction()
        self.assertFalse(vec0)
        self.assertIsNone(ctx0)
        residual = PETSc.Vec().createSeq(2)
        callback = Function()
        baseline = getrefcount(callback)
        # Registering the same callback twice must hold exactly one reference.
        self.snes.setFunction(callback, residual)
        self.snes.setFunction(callback, residual)
        self.assertEqual(getrefcount(callback), baseline + 1)
        # Repeated getFunction calls must not add references to the callback.
        for _ in range(2):
            vec, ctx = self.snes.getFunction()
            self.assertEqual(residual, vec)
            self.assertEqual(callback, ctx[0])
            self.assertEqual(getrefcount(callback), baseline + 1)

    def testCompFunc(self):
        residual = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), residual)
        x = residual.duplicate()
        y = residual.duplicate()
        # (1, 2) is a root of Function, so both residual entries should vanish.
        x[0] = 1
        x[1] = 2
        self.snes.computeFunction(x, y)
        for row in range(2):
            self.assertAlmostEqual(abs(y[row]), 0.0, places=5)

    def testGetSetJac(self):
        amat0, pmat0, ctx0 = self.snes.getJacobian()
        self.assertFalse(amat0)
        self.assertFalse(pmat0)
        self.assertIsNone(ctx0)
        J = self._new_seqaij_2x2()
        callback = Jacobian()
        baseline = getrefcount(callback)
        # Registering the same callback twice must hold exactly one reference.
        self.snes.setJacobian(callback, J)
        self.snes.setJacobian(callback, J)
        self.assertEqual(getrefcount(callback), baseline + 1)
        # Repeated getJacobian calls must not add references to the callback.
        for _ in range(2):
            amat, pmat, ctx = self.snes.getJacobian()
            self.assertEqual(J, amat)
            self.assertEqual(amat, pmat)
            self.assertEqual(callback, ctx[0])
            self.assertEqual(getrefcount(callback), baseline + 1)

    def testCompJac(self):
        J = self._new_seqaij_2x2()
        self.snes.setJacobian(Jacobian(), J)
        x = PETSc.Vec().createSeq(2)
        x[0] = 1
        x[1] = 2
        # The KSP and its PC are fetched before computeJacobian is called;
        # the returned objects themselves are not used.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        self.assertIsNone(self.snes.getUpdate())

        def first(snes, it):
            return None

        base1 = getrefcount(first)
        # Setting the same callback repeatedly holds a single reference.
        self.snes.setUpdate(first)
        self.assertEqual(getrefcount(first), base1 + 1)
        self.snes.setUpdate(first)
        self.assertEqual(getrefcount(first), base1 + 1)
        # Clearing the callback releases that reference.
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(first), base1)
        self.snes.setUpdate(first)
        self.assertEqual(getrefcount(first), base1 + 1)

        def second(snes, it):
            return None

        base2 = getrefcount(second)
        # Replacing the callback releases the old one and holds the new one.
        self.snes.setUpdate(second)
        self.assertEqual(getrefcount(first), base1)
        self.assertEqual(getrefcount(second), base2 + 1)
        # A local binding of the stored callback adds one more reference.
        held = self.snes.getUpdate()[0]
        self.assertIs(held, second)
        self.assertEqual(getrefcount(second), base2 + 2)
        del held
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(second), base2)

    def testGetKSP(self):
        inner = self.snes.getKSP()
        self.assertEqual(inner.getRefCount(), 2)

    def testSolve(self):
        J = self._new_seqaij_2x2()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)

        def _update(snes, it, cnt):
            # In-place increment of the 0-d array passed as extra argument.
            cnt += 1

        update_calls = np.array(0)
        self.snes.setUpdate(_update, (update_calls,))

        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        self.snes.setUpdate(None)
        # Read the history once, then reset it and check that it is empty.
        self.snes.getConvergenceHistory()
        self.snes.setConvergenceHistory(0, reset=True)
        residuals, iterations = self.snes.getConvergenceHistory()
        self.assertEqual(len(residuals), 0)
        self.assertEqual(len(iterations), 0)
        self._assert_root(x)
        # The update callback must have fired once per nonlinear iteration.
        self.assertEqual(self.snes.getIterationNumber(), update_calls)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface: solve again through the stored solution vector,
        # calling solve() without arguments
        sol = self.snes.getSolution()
        sol.setArray([2, 3])
        self.snes.solve()
        self._assert_root(sol)

    def testResetAndSolve(self):
        # reset / solve / reset / solve / reset
        for _ in range(2):
            self.snes.reset()
            self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        seen = {}

        def monitor(snes, its, fgnorm):
            seen[its] = fgnorm

        baseline = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), baseline + 1)
        self.testSolve()
        self.assertGreater(len(seen), 0)
        # Rebinding is visible to the closure; nothing may be recorded after
        # the monitor has been cancelled.
        seen = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), baseline)
        self.testSolve()
        self.assertEqual(len(seen), 0)
        # Trigger the monitor by hand and check the recorded value.
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(seen[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        snes = self.snes
        self.assertEqual(snes.getIterationNumber(), 0)
        self.assertEqual(snes.getNonlinearStepFailures(), 0)
        self.assertEqual(snes.getMaxNonlinearStepFailures(), 1)
        # Change the limit, then put it back to 1.
        for limit in (5, 1):
            snes.setMaxNonlinearStepFailures(limit)
            self.assertEqual(snes.getMaxNonlinearStepFailures(), limit)

    def testSetGetLinFails(self):
        snes = self.snes
        self.assertEqual(snes.getLinearSolveIterations(), 0)
        self.assertEqual(snes.getLinearSolveFailures(), 0)
        self.assertEqual(snes.getMaxLinearSolveFailures(), 1)
        # Change the limit, then put it back to 1.
        for limit in (5, 1):
            snes.setMaxLinearSolveFailures(limit)
            self.assertEqual(snes.getMaxLinearSolveFailures(), limit)

    def testEW(self):
        snes = self.snes
        for enabled in (False, True):
            snes.setUseEW(enabled)
            self.assertIs(bool(snes.getUseEW()), enabled)
        # An explicit version is stored ...
        ew = snes.getParamsEW()
        ew['version'] = 1
        snes.setParamsEW(**ew)
        ew = snes.getParamsEW()
        self.assertEqual(ew['version'], 1)
        # ... and passing PETSc.DEFAULT afterwards still reads back 1.
        ew['version'] = PETSc.DEFAULT
        snes.setParamsEW(**ew)
        ew = snes.getParamsEW()
        self.assertEqual(ew['version'], 1)

    def testMF(self):
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor']  = 'stdout'
        # opts['ksp_monitor']   = 'stdout'
        # opts['snes_view']     = 'stdout'
        J = self._new_seqaij_2x2()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        # Off initially, still off after an explicit False, on after True.
        self.assertFalse(self.snes.getUseMF())
        for enabled in (False, True):
            self.snes.setUseMF(enabled)
            self.assertIs(bool(self.snes.getUseMF()), enabled)
        self.snes.setFromOptions()
        # The solve is skipped for the NEWTONTR type.
        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
            return
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self._assert_root(x)

    def testFDColor(self):
        J = self._new_seqaij_2x2()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        callback = Jacobian()
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(callback, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill and assemble J once by calling the Jacobian callback directly.
        callback(self.snes, x, J, J)
        for enabled in (False, True):
            self.snes.setUseFD(enabled)
            self.assertIs(bool(self.snes.getUseFD()), enabled)
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        # This test checks the root to 4 places instead of 5.
        self._assert_root(x, places=4)

    def testNPC(self):
        # The NPC returned by getNPC() reports the appctx set on the solver.
        self.snes.appctx = (1, 2, 3)
        inner = self.snes.getNPC()
        self.assertEqual(inner.appctx, (1, 2, 3))
384
385
386# --------------------------------------------------------------------
387
388
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    # Runs every BaseTestSNES case with the solver type set to NEWTONLS.
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
391
392
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    # Runs every BaseTestSNES case with the solver type set to NEWTONTR.
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
395
396
397# --------------------------------------------------------------------
398
# Run the test cases through unittest when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
401