# Source: petsc/src/binding/petsc4py/test/test_snes.py (revision a4af0ceea8a251db97ee0dc5c0d52d4adf50264a)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6import numpy
7
8# --------------------------------------------------------------------
9
class Function:
    """Residual callback for a 2x2 nonlinear test system.

    Evaluates ``F(x) = (x0*x0 + x0*x1 - 3, x0*x1 + x1*x1 - 6)``, which
    vanishes at ``x = (1, 2)``, stores it into *f* and assembles *f*.
    """

    def __call__(self, snes, x, f):
        # x[i] presumably yields NumPy scalars; .item() converts each
        # result to a plain Python scalar before it is stored in f.
        x0 = x[0]
        x1 = x[1]
        first = x0*x0 + x0*x1 - 3.0
        f[0] = first.item()
        second = x0*x1 + x1*x1 - 6.0
        f[1] = second.item()
        f.assemble()
15
class Jacobian:
    """Jacobian callback for the residual computed by ``Function``.

    Fills the 2x2 matrix *P* with the partial derivatives of
    ``F(x) = (x0*x0 + x0*x1 - 3, x0*x1 + x1*x1 - 6)`` and assembles it.
    If *J* is a different object from *P*, *J* is only assembled.
    """

    def __call__(self, snes, x, J, P):
        x0 = x[0]
        x1 = x[1]
        # Row 0: dF0/dx0 = 2*x0 + x1, dF0/dx1 = x0.
        P[0, 0] = (2.0*x0 + x1).item()
        P[0, 1] = x0.item()
        # Row 1: dF1/dx0 = x1, dF1/dx1 = x0 + 2*x1.
        P[1, 0] = x1.item()
        P[1, 1] = (x0 + 2.0*x1).item()
        P.assemble()
        # A separate operator J gets no values here, only assembly
        # (presumably it is matrix-free or filled elsewhere - confirm).
        if J != P:
            J.assemble()
24
25# --------------------------------------------------------------------
26
class BaseTestSNES:
    """SNES test cases shared by every concrete solver type.

    Concrete subclasses mix this class with ``unittest.TestCase`` and set
    ``SNES_TYPE``. ``setUp`` creates a fresh SNES on ``PETSc.COMM_SELF``
    (of that type, when one is given) before every test.
    """

    # SNES type passed to setType() in setUp; a falsy value leaves it unset.
    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        # Drop the Python reference to the solver created in setUp.
        self.snes = None

    @staticmethod
    def _createJacobian():
        """Return a 2x2 SEQAIJ matrix on COMM_SELF, created, sized and set up."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        return J

    def testGetSetType(self):
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        # getTolerances() must round-trip through setTolerances() and agree
        # with the matching attribute accessors.
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # Each reason must switch exactly one of the three status predicates on.
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertTrue(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertTrue(snes.iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        # The SNES must keep exactly one reference to the callback, even
        # when it is set twice, and getFunction() must not add more.
        refcnt = getrefcount(func)
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        # (1, 2) is a root of Function, so the computed residual is ~0.
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = self._createJacobian()
        jac = Jacobian()
        # Same reference-count contract as for the residual callback.
        refcnt = getrefcount(jac)
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        J = self._createJacobian()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # Presumably forces the KSP/PC to exist before computeJacobian - confirm.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Clearing the update hook must release the stored reference.
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Replacing the hook releases the old callback and holds the new one.
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        ksp = self.snes.getKSP()
        # Presumably one reference held by the SNES and one by `ksp` - confirm.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        J = self._createJacobian()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # The solve must have recorded convergence history...
        rh, ih = self.snes.getConvergenceHistory()
        self.assertGreater(len(rh), 0)
        # ...and resetting it to length 0 must empty it.
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # Starting from (2, 3) the solver must reach the root (1, 2).
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

    def testResetAndSolve(self):
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # `monitor` closes over the variable, so it writes into this new dict.
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)

    def testSetGetStepFails(self):
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT must leave the previously set version intact.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        J = self._createJacobian()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        J = self._createJacobian()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Evaluate the Jacobian into J once; presumably this gives J the
        # nonzero structure that finite-difference coloring needs - confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
348
349# --------------------------------------------------------------------
350
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with ``PETSc.SNES.Type.NEWTONLS``."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
353
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with ``PETSc.SNES.Type.NEWTONTR``."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
356
357# --------------------------------------------------------------------
358
if __name__ == '__main__':
    # Executed directly (not imported): run every TestCase in this module.
    unittest.main()
361