xref: /petsc/src/binding/petsc4py/test/test_snes.py (revision 2d240a76ffdefda4afcdc14c5a261325b30d32dc)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
9
class Function:
    """Residual callback for the 2x2 nonlinear test system.

    Evaluates::

        F0(x) = x0*x0 + x0*x1 - 3
        F1(x) = x0*x1 + x1*x1 - 6

    which vanishes at ``x = (1, 2)``.
    """

    def __call__(self, snes, x, f):
        """Write F(x) into ``f`` and assemble it.

        ``snes`` is accepted to match the callback signature but is not used.
        """
        # Read each component once instead of re-indexing for every term.
        x0, x1 = x[0], x[1]
        # ``.item()`` converts the arithmetic result to a plain Python scalar
        # before it is stored.
        f[0] = (x0 * x0 + x0 * x1 - 3.0).item()
        f[1] = (x0 * x1 + x1 * x1 - 6.0).item()
        f.assemble()
15
16
class Jacobian:
    """Analytic Jacobian callback matching :class:`Function`.

    Fills the 2x2 matrix of partial derivatives::

        [ 2*x0 + x1      x0        ]
        [ x1             x0 + 2*x1 ]
    """

    def __call__(self, snes, x, J, P):
        """Write the Jacobian at ``x`` into ``P`` and assemble the matrices.

        Only ``P`` receives values; ``J`` is merely assembled when it
        compares unequal to ``P``. ``snes`` is accepted to match the
        callback signature but is not used.
        """
        # Read each component once instead of re-indexing for every entry.
        x0, x1 = x[0], x[1]
        P[0, 0] = (2.0 * x0 + x1).item()
        P[0, 1] = (x0).item()
        P[1, 0] = (x1).item()
        P[1, 1] = (x0 + 2.0 * x1).item()
        P.assemble()
        if J != P:
            J.assemble()
26
27
28# --------------------------------------------------------------------
29
30
class BaseTestSNES:
    """Shared SNES test suite, meant to be mixed into ``unittest.TestCase``.

    Concrete subclasses set ``SNES_TYPE`` to the solver type under test.
    ``setUp`` builds a fresh solver on ``PETSc.COMM_SELF`` for every test and
    ``tearDown`` drops it again.
    """

    SNES_TYPE = None

    def setUp(self):
        """Create the solver under test and store it as ``self.snes``."""
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        """Release the solver and run PETSc's garbage cleanup."""
        self.snes = None
        PETSc.garbage_cleanup()

    @staticmethod
    def _createMat():
        """Return a set-up 2x2 SEQAIJ matrix on ``PETSc.COMM_SELF``."""
        mat = PETSc.Mat().create(PETSc.COMM_SELF)
        mat.setSizes([2, 2])
        mat.setType(PETSc.Mat.Type.SEQAIJ)
        mat.setUp()
        return mat

    def testGetSetType(self):
        """The type set in ``setUp`` is reported and can be set again."""
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """Tolerances round-trip and agree with the matching properties."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Exercise the read/write properties exposed on the solver."""
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        #
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        """Setting the residual callback holds exactly one reference to it."""
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Registering the same callback twice must still cost one reference.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """The residual evaluated at (1, 2) is zero."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        """Setting the Jacobian callback holds exactly one reference to it."""
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = self._createMat()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Registering the same callback twice must still cost one reference.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        """``computeJacobian`` runs with the registered callback."""
        J = self._createMat()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): the KSP/PC are fetched before computing the Jacobian;
        # presumably this forces their creation -- confirm.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        """Update callbacks are referenced once while set, released after."""
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        # Replacing the callback must release the previous one.
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        # One reference from the solver plus one from ``tmp``.
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        """The KSP returned by the solver has a PETSc reference count of 2."""
        ksp = self.snes.getKSP()
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Solve from the initial guess (2, 3) and expect the root (1, 2)."""
        J = self._createMat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # Query the history once after the solve; the result is not inspected.
        self.snes.getConvergenceHistory()
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        """The solver can be reset and reused for repeated solves."""
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """Monitors are called during solves and released on cancel."""
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # Empty the shared dict in place so the closure keeps seeing it.
        reshist.clear()
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        """Nonlinear step-failure counters and their limit round-trip."""
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        """Linear solve-failure counters and their limit round-trip."""
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Toggle the EW flag and round-trip its parameter dictionary."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT is expected to leave the version at 1.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """Toggle the MF flag and solve with it enabled (except NEWTONTR)."""
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor']  = 'stdout'
        # opts['ksp_monitor']   = 'stdout'
        # opts['snes_view']     = 'stdout'
        J = self._createMat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # The solve is skipped for the NEWTONTR solver type.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        """Toggle the FD flag and solve with it enabled."""
        J = self._createMat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Call the analytic callback once by hand so J has entries set and is
        # assembled before FD is switched on.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        # Looser tolerance (4 places) than the analytic-Jacobian solves.
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        """The NPC obtained from the solver reports the solver's ``appctx``."""
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))
367
368
369# --------------------------------------------------------------------
370
371
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES suite with the ``NEWTONLS`` solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
374
375
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES suite with the ``NEWTONTR`` solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
378
379
380# --------------------------------------------------------------------
381
# Run the suite when the module is executed as a script.
if __name__ == '__main__':
    unittest.main()
384