xref: /petsc/src/binding/petsc4py/test/test_snes.py (revision 27f49a208b01d2e827ab9db411a2d16003fe9262)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6import numpy
7
8# --------------------------------------------------------------------
9
class Function:
    """Residual callback for the 2x2 nonlinear test system.

    Evaluates::

        F0(x) = x0*x0 + x0*x1 - 3
        F1(x) = x0*x1 + x1*x1 - 6

    whose only solutions are (1, 2) and (-1, -2).
    """

    def __call__(self, snes, x, f):
        # Read both components once; ``.item()`` converts each result
        # (presumably a numpy scalar) into a plain Python number for ``f``.
        a, b = x[0], x[1]
        f[0] = (a*a + a*b - 3.0).item()
        f[1] = (a*b + b*b - 6.0).item()
        f.assemble()
15
class Jacobian:
    """Analytic Jacobian callback for the residual evaluated by ``Function``.

    Fills the preconditioning matrix ``P`` with::

        [[2*x0 + x1, x0       ],
         [x1,        x0 + 2*x1]]

    and also assembles ``J`` whenever it compares unequal to ``P``.
    """

    def __call__(self, snes, x, J, P):
        a, b = x[0], x[1]
        P[0,0] = (2.0*a + b).item()
        P[0,1] = a.item()
        P[1,0] = b.item()
        P[1,1] = (a + 2.0*b).item()
        P.assemble()
        # J is only assembled (never filled) here when it is not P.
        if J != P:
            J.assemble()
24
25# --------------------------------------------------------------------
26
class BaseTestSNES(object):
    """Shared SNES test cases, parametrised by the ``SNES_TYPE`` attribute.

    Concrete test classes combine this mixin with ``unittest.TestCase`` and
    set ``SNES_TYPE``. The solve-based tests use the 2x2 system implemented
    by ``Function`` / ``Jacobian``, whose solutions are (1, 2) and (-1, -2);
    that is presumably why solution components are compared by absolute value.
    """

    # SNES type requested in setUp(); a falsy value skips setType().
    SNES_TYPE = None

    def setUp(self):
        """Create a fresh sequential SNES, applying ``SNES_TYPE`` if set."""
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        """Drop the SNES reference and call ``PETSc.garbage_cleanup()``."""
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        """getType() reports SNES_TYPE, both initially and after re-setting it."""
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """Tolerances round-trip and match the rtol/atol/stol/max_it properties."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Python-level SNES properties read back what was assigned."""
        snes = self.snes
        #
        snes.appctx = (1,2,3)
        self.assertEqual(snes.appctx, (1,2,3))
        snes.appctx = None
        self.assertEqual(snes.appctx, None)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        # No history has been requested yet, so both arrays are empty.
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        #
        # The converged/diverged/iterating flags must follow `reason`.
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertTrue(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertTrue(snes.iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        """setFunction()/getFunction() round-trip without leaking references.

        The SNES must hold exactly one extra reference to the callback,
        however many times it is set or queried.
        """
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        # The second item returned by getFunction() is a sequence whose
        # first entry is the callback itself.
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """computeFunction() evaluates the residual; it vanishes at (1, 2)."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        """setJacobian()/getJacobian() round-trip without leaking references.

        With only J given, the preconditioning matrix must be J itself.
        """
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        """computeJacobian() fills J with the analytic Jacobian at x."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): result unused; presumably kept so the KSP/PC exist
        # before computeJacobian() runs.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)
        # At x = (1, 2): [[2*1 + 2, 1], [2, 1 + 2*2]] == [[4, 1], [2, 5]].
        self.assertAlmostEqual(J[0, 0], 4.0)
        self.assertAlmostEqual(J[0, 1], 1.0)
        self.assertAlmostEqual(J[1, 0], 2.0)
        self.assertAlmostEqual(J[1, 1], 5.0)

    def testGetSetUpd(self):
        """setUpdate() keeps exactly one reference to the current callback.

        Replacing the callback or clearing it with None must release the
        previously held reference.
        """
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd),  refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        # `tmp` is one extra reference on top of the one held by the SNES.
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        """getKSP() returns a KSP whose PETSc reference count is 2."""
        ksp = self.snes.getKSP()
        # NOTE(review): presumably one reference is owned by the SNES and
        # one by the Python wrapper `ksp`.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Newton solve from (2, 3) converges to (1, 2) via both solve() forms.

        Also checks that the convergence history is recorded and that it can
        be reset.
        """
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2,3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # The solve above must have recorded at least one history entry, with
        # residual norms and iteration counts of equal length.
        rh, ih = self.snes.getConvergenceHistory()
        self.assertGreater(len(rh), 0)
        self.assertEqual(len(rh), len(ih))
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2,3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        """The SNES can be reset and reused for complete solves."""
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """A Python monitor is called during solves and released on cancel."""
        reshist = {}
        def monitor(snes, its, fgnorm):
            # `reshist` is a closure cell, so rebinding it below redirects
            # future records into the new dict.
            reshist[its] = fgnorm
        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        """Nonlinear step-failure counters start at 0; the max round-trips."""
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        """Linear-solve failure counters start at 0; the max round-trips."""
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Eisenstat-Walker (EW) flag toggles and its parameters round-trip."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT must leave the previously set version intact.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """The matrix-free flag toggles; the MF solve still reaches (1, 2)."""
        #self.snes.setOptionsPrefix('MF-')
        #opts = PETSc.Options(self.snes)
        #opts['mat_mffd_type'] = 'ds'
        #opts['snes_monitor']  = 'stdout'
        #opts['ksp_monitor']   = 'stdout'
        #opts['snes_view']     = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # NOTE(review): the solve is skipped for NEWTONTR; the reason is not
        # visible in this file.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2,3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        """The FD flag toggles; the FD-Jacobian solve reaches (1, 2)."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Evaluate the analytic Jacobian once (before x is set) so that J
        # holds assembled entries; presumably this gives the finite-difference
        # coloring a nonzero pattern to work from.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2,3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        """The nonlinear preconditioner from getNPC() sees the parent's appctx."""
        self.snes.appctx = (1,2,3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1,2,3))
362
363# --------------------------------------------------------------------
364
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared BaseTestSNES suite with SNES type NEWTONLS."""

    # BaseTestSNES is listed first so its setUp/tearDown take precedence.
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
367
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared BaseTestSNES suite with SNES type NEWTONTR."""

    # BaseTestSNES is listed first so its setUp/tearDown take precedence.
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
370
371# --------------------------------------------------------------------
372
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()
375