xref: /petsc/src/binding/petsc4py/test/test_snes.py (revision 51adc54c70eadd4bc8bbbc2f72a2dc5ea185959b)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
9
class Function:
    """Residual callback for the 2x2 nonlinear test system.

    Writes F(x) = (x0*x0 + x0*x1 - 3, x0*x1 + x1*x1 - 6) into ``f`` as
    plain Python floats and then assembles ``f``.
    """

    def __call__(self, snes, x, f):
        x0, x1 = x[0], x[1]
        residual = (
            x0 * x0 + x0 * x1 - 3.0,
            x0 * x1 + x1 * x1 - 6.0,
        )
        # .item() turns the scalar produced by Vec indexing into a float.
        for row, value in enumerate(residual):
            f[row] = value.item()
        f.assemble()
15
16
class Jacobian:
    """Jacobian callback for the residual computed by ``Function``.

    Fills ``P`` with dF/dx evaluated at ``x`` and assembles it. ``J`` is
    assembled as well whenever it compares unequal to ``P``.
    """

    def __call__(self, snes, x, J, P):
        x0, x1 = x[0], x[1]
        entries = (
            ((0, 0), 2.0 * x0 + x1),
            ((0, 1), x0),
            ((1, 0), x1),
            ((1, 1), x0 + 2.0 * x1),
        )
        for index, value in entries:
            P[index] = value.item()
        P.assemble()
        # Keep `!=` rather than `is not`: petsc4py objects compare by the
        # wrapped handle, not by Python identity.
        if J != P:
            J.assemble()
26
27
28# --------------------------------------------------------------------
29
30
class BaseTestSNES:
    """Shared SNES test cases, mixed into ``unittest.TestCase`` subclasses.

    Subclasses select the solver under test through ``SNES_TYPE``. The
    nonlinear system comes from ``Function`` and ``Jacobian``. Its roots
    include (1, 2) and (-1, -2), which is why solution checks use ``abs()``.
    """

    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        # Drop our reference first, then ask petsc4py to clean up garbage
        # PETSc objects.
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        """The configured type is reported and survives being set again."""
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """Tolerances round-trip through the setters and the properties."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))
        dtol_initial = self.snes.getDivergenceTolerance()
        self.assertGreater(dtol_initial, 0)
        self.snes.setDivergenceTolerance(-1)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, -1)
        # Re-read after setting DEFAULT: asserting on the stale local (as
        # before) could never fail.
        # NOTE(review): whether PETSc.DEFAULT keeps the current value (-1)
        # or restores the initial default is decided by the C-level
        # SNESSetDivergenceTolerance; accept either until that is pinned.
        self.snes.setDivergenceTolerance(PETSc.DEFAULT)
        dtol = self.snes.getDivergenceTolerance()
        self.assertIn(dtol, (-1, dtol_initial))

    def testProperties(self):
        """Python properties mirror the corresponding getters/setters."""
        snes = self.snes
        # application context
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        # iteration counter
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        # residual norm
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        # no history is recorded before any solve
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # converged reason and the derived boolean flags
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        # use_* flags
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)
        ouse = snes.use_ksp
        self.assertEqual(ouse, snes.getUseKSP())
        snes.use_ksp = not ouse
        self.assertEqual(not ouse, snes.getUseKSP())
        snes.setUseKSP(ouse)
        self.assertEqual(ouse, snes.use_ksp)

    def testGetSetFunc(self):
        """setFunction stores exactly one reference to the callback."""
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting twice must replace, not accumulate, the stored reference.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """computeFunction evaluates the residual; (1, 2) is a root."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        """setJacobian stores one callback reference and reuses J as P."""
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        """computeJacobian runs the user Jacobian callback without error."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): result is discarded; presumably this just forces the
        # KSP/PC to exist before computeJacobian — confirm.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        """setUpdate holds one reference and releases it when replaced."""
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        # One reference is held by the SNES, one by `tmp`.
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        """The inner KSP reports a PETSc reference count of 2."""
        ksp = self.snes.getKSP()
        # NOTE(review): presumably one reference from the SNES and one from
        # this Python wrapper — confirm.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Solve F(x) = b with b = 0 through both solve() signatures."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # A history buffer was installed above, so the solve is expected to
        # have logged residual norms (this result used to be discarded).
        rh, ih = self.snes.getConvergenceHistory()
        self.assertGreater(len(rh), 0)
        self.assertEqual(len(rh), len(ih))
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface: solve() in place on the SNES-owned solution
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        """The SNES can be reset and solved again repeatedly."""
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """Monitors are called during solves and released on cancel."""
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # `monitor` closes over the variable, so it sees the rebound dict.
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        """Nonlinear step-failure counters start at 0 and limits round-trip."""
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        """Linear-solve failure counters start at 0 and limits round-trip."""
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Eisenstat-Walker flag and parameters round-trip."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT is expected to leave the version unchanged.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """The matrix-free flag toggles, and a matrix-free solve converges."""
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor']  = 'stdout'
        # opts['ksp_monitor']   = 'stdout'
        # opts['snes_view']     = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # The matrix-free solve is skipped for NEWTONTR (reason not recorded
        # here).
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        """The finite-difference flag toggles, and an FD solve converges."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill and assemble J once. NOTE(review): presumably this gives FD
        # coloring a nonzero pattern to work from — confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        """The nonlinear preconditioner inherits the outer appctx."""
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))
380
381
382# --------------------------------------------------------------------
383
384
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the NEWTONLS (line-search) solver."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
387
388
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the NEWTONTR (trust-region) solver."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
391
392
393# --------------------------------------------------------------------
394
395if __name__ == '__main__':
396    unittest.main()
397