xref: /petsc/src/binding/petsc4py/test/test_snes.py (revision 4cc2b5b5fe4ffd09e5956b56d7cdc4f43e324103)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6import numpy as np
7
8# --------------------------------------------------------------------
9
10
class Function:
    """Residual callback for the two-equation nonlinear test system.

    Evaluates::

        f[0] = x0*x0 + x0*x1 - 3
        f[1] = x0*x1 + x1*x1 - 6
    """

    def __call__(self, snes, x, f):
        """Write the residual at ``x`` into ``f`` and assemble ``f``.

        ``snes`` is accepted to match the callback signature but is unused.
        """
        x0, x1 = x[0], x[1]
        residual = (
            x0 * x0 + x0 * x1 - 3.0,
            x0 * x1 + x1 * x1 - 6.0,
        )
        # ``.item()`` converts each entry to a plain Python scalar before it
        # is stored.
        for row, value in enumerate(residual):
            f[row] = value.item()
        f.assemble()
16
17
class Jacobian:
    """Analytic Jacobian callback for the two-equation test system."""

    def __call__(self, snes, x, J, P):
        """Fill ``P`` with the 2x2 Jacobian at ``x`` and assemble it.

        ``J`` is assembled as well when it compares unequal to ``P``.
        ``snes`` is accepted to match the callback signature but is unused.
        """
        x0, x1 = x[0], x[1]
        # Entries are listed in row-major order; dicts preserve insertion
        # order, so they are written to ``P`` in this order.
        entries = {
            (0, 0): 2.0 * x0 + x1,
            (0, 1): x0,
            (1, 0): x1,
            (1, 1): x0 + 2.0 * x1,
        }
        for (row, col), value in entries.items():
            P[row, col] = value.item()
        P.assemble()
        if J != P:
            J.assemble()
27
28
29
30# --------------------------------------------------------------------
31
32
class BaseTestSNES:
    """Solver-agnostic SNES test cases shared by the concrete test classes.

    This mixin is combined with ``unittest.TestCase`` by subclasses that set
    ``SNES_TYPE``.  Test methods are documented with ``#`` comments rather
    than docstrings so that unittest keeps reporting them by method name.
    """

    # SNES type name installed by setUp(); a falsy value skips setType().
    SNES_TYPE = None

    def setUp(self):
        # Every test gets a fresh solver created on COMM_SELF.
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        # Drop our reference first, then ask PETSc to clean up garbage.
        self.snes = None
        PETSc.garbage_cleanup()

    def _createJacobianMat(self):
        """Return a new, set-up 2x2 SEQAIJ matrix on ``COMM_SELF``."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        return J

    def testGetSetType(self):
        # Re-setting the current type must leave the reported type unchanged.
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        # Round-trip the tolerances and compare them with the properties.
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        snes = self.snes
        # application context
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        # iteration number
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        # function norm
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        # convergence history is empty on a fresh solver
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # converged reason and the flags derived from it
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        # optional features are off by default
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Installing the same callback twice must hold only one extra
        # reference, and reading it back must not add further ones.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        # (1, 2) is a root of the test system, so the residual must vanish.
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = self._createJacobianMat()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Installing the same callback twice must hold only one extra
        # reference, and reading it back must not add further ones.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        # Smoke test: computeJacobian() must run without raising.
        J = self._createJacobianMat()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): the result is unused; presumably this forces the
        # KSP/PC to exist before the Jacobian is computed -- confirm.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        self.assertIsNone(self.snes.getUpdate())

        def upd(snes, it):
            return None

        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)

        def upd2(snes, it):
            return None

        refcnt2 = getrefcount(upd2)
        # Replacing the callback must release the previous one.
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        # One reference is held by ``ksp`` here and one by the solver.
        ksp = self.snes.getKSP()
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        J = self._createJacobianMat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)

        def _update(snes, it, cnt):
            # ``cnt`` is a 0-d NumPy array, so ``+=`` mutates it in place and
            # the count stays visible through ``cnt_up`` below.
            cnt += 1

        cnt_up = np.array(0)
        self.snes.setUpdate(_update, (cnt_up,))

        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        self.snes.setUpdate(None)
        # Read the history once after the solve, then reset it and check
        # that it comes back empty.
        rh, ih = self.snes.getConvergenceHistory()
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # The update callback must have run once per nonlinear iteration.
        self.assertEqual(self.snes.getIterationNumber(), cnt_up)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertTrue(reason > 0)

        # test interface: solve() without arguments reuses the solver's own
        # solution vector
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        # The solver must remain usable across repeated reset() calls.
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # After cancelling, a solve must neither call nor retain the monitor.
        reshist.clear()
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        # A manual monitor() call is forwarded to the installed callback.
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        # Counters start at zero; the failure limit starts at 1 and
        # round-trips through its setter.
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        # Same checks as testSetGetStepFails, for the linear-solve counters.
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT must keep the previously set version.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor']  = 'stdout'
        # opts['ksp_monitor']   = 'stdout'
        # opts['snes_view']     = 'stdout'
        J = self._createJacobianMat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # The solve itself is skipped for the NEWTONTR solver type.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        J = self._createJacobianMat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # NOTE(review): the callback is invoked by hand so that J is filled
        # and assembled before FD is switched on; presumably this provides
        # the nonzero pattern used for coloring -- confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        # This test checks the solution to 4 places instead of 5.
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        # The nonlinear preconditioner must report the parent's appctx.
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))
377
378
379# --------------------------------------------------------------------
380
381
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES test cases with the ``NEWTONLS`` solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
384
385
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES test cases with the ``NEWTONTR`` solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
388
389
390# --------------------------------------------------------------------
391
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main()
394