xref: /petsc/src/binding/petsc4py/test/test_snes.py (revision f97672e55eacc8688507b9471cd7ec2664d7f203)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6import numpy
7
8# --------------------------------------------------------------------
9
class Function:
    """Residual callback for the 2x2 nonlinear test system.

    Writes F(x) into ``f``::

        F0 = x0*x0 + x0*x1 - 3
        F1 = x0*x1 + x1*x1 - 6

    The system has the roots (1, 2) and (-1, -2).
    """

    def __call__(self, snes, x, f):
        # ``snes`` belongs to the callback signature but is not needed here.
        x0 = x[0]
        x1 = x[1]
        # ``.item()`` converts the (presumably NumPy) scalar to a Python float.
        f[0] = (x0 * x0 + x0 * x1 - 3.0).item()
        f[1] = (x0 * x1 + x1 * x1 - 6.0).item()
        f.assemble()
14        f.assemble()
15
class Jacobian:
    """Jacobian callback matching ``Function``.

    Fills the preconditioning matrix ``P`` with::

        [[2*x0 + x1, x0       ],
         [x1,        x0 + 2*x1]]

    and assembles ``J`` as well when it is a different matrix from ``P``.
    """

    def __call__(self, snes, x, J, P):
        x0, x1 = x[0], x[1]
        entries = (
            ((0, 0), 2.0 * x0 + x1),
            ((0, 1), x0),
            ((1, 0), x1),
            ((1, 1), x0 + 2.0 * x1),
        )
        # Store plain Python floats, row by row, in the same order as before.
        for ij, value in entries:
            P[ij] = value.item()
        P.assemble()
        # Avoid assembling twice when the operator and the preconditioner
        # are the same matrix.
        if J != P:
            J.assemble()
23        if J != P: J.assemble()
24
25# --------------------------------------------------------------------
26
class BaseTestSNES:
    """Solver-independent tests for ``PETSc.SNES``.

    This is not a ``TestCase`` itself. Concrete subclasses mix it with
    ``unittest.TestCase`` and set ``SNES_TYPE``. Each test gets a fresh SNES
    on ``PETSc.COMM_SELF`` from ``setUp``.

    The solve tests use the ``Function``/``Jacobian`` system. Its roots are
    (1, 2) and (-1, -2), so computed solutions are compared through ``abs()``.
    """

    # Type handed to ``SNES.setType`` in setUp; a falsy value skips setType.
    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        # Release this test's reference to the SNES.
        self.snes = None

    def testGetSetType(self):
        """The type set in setUp is reported, and re-setting it is a no-op."""
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """getTolerances/setTolerances round-trip and match the properties."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        # Property names in the same order as the getTolerances() tuple.
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Python-level properties read back what was assigned."""
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        # No convergence history has been requested yet.
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        #
        # Each reason maps to exactly one of converged/diverged/iterating.
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertTrue(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertTrue(snes.iterating)
        #
        # Eisenstat-Walker, matrix-free and FD Jacobians are all off by default.
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        """setFunction stores the callback with exactly one extra reference."""
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must not leak a reference.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        # getFunction returns (vec, (callback, args, kargs)).
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """computeFunction at the root (1, 2) yields a ~zero residual."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        """setJacobian stores the callback with exactly one extra reference."""
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Setting the same callback twice must not leak a reference.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        # With only J given, the preconditioning matrix is J as well.
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        """computeJacobian runs the Python Jacobian callback without error."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): presumably forces creation of the inner KSP/PC before
        # computeJacobian; the result is intentionally discarded.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        """setUpdate holds one reference to the current callback only."""
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Clearing the callback releases its reference.
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Replacing the callback releases the previous one.
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        # +1 held by the SNES, +1 by ``tmp``.
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        """The inner KSP is shared between the SNES and the caller."""
        ksp = self.snes.getKSP()
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Solve from (2, 3) to the root, records history, then re-solve."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # The solve must have logged at least one step, with one residual
        # norm and one linear-iteration count per entry.
        rh, ih = self.snes.getConvergenceHistory()
        self.assertGreater(len(rh), 0)
        self.assertEqual(len(rh), len(ih))
        # A zero-length history buffer reports nothing.
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        # A zero residual norm must be reported as converged (positive reason).
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface: solve() with no arguments reuses the stored solution
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        """The SNES can be reset and reused for repeated solves."""
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """Python monitors are called during solves and released on cancel."""
        reshist = {}
        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm
        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # ``monitor`` closes over the variable, not the dict, so rebinding
        # ``reshist`` here makes later calls record into the new dict.
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        """Nonlinear step-failure counters start at 0 and the limit at 1."""
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        """Linear-solve failure counters start at 0 and the limit at 1."""
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Eisenstat-Walker toggling and parameter round-trip."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT leaves the previously set version unchanged.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """Matrix-free Newton still converges to the root."""
        #self.snes.setOptionsPrefix('MF-')
        #opts = PETSc.Options(self.snes)
        #opts['mat_mffd_type'] = 'ds'
        #opts['snes_monitor']  = 'stdout'
        #opts['ksp_monitor']   = 'stdout'
        #opts['snes_view']     = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        """Finite-difference (colored) Jacobian still converges to the root."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill and assemble J once (at x = 0); presumably this provides the
        # nonzero pattern that FD coloring needs — confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
355
356# --------------------------------------------------------------------
357
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the Newton line-search solver."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
360
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the Newton trust-region solver."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
363
364# --------------------------------------------------------------------
365
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()
368