xref: /petsc/src/binding/petsc4py/test/test_ts.py (revision 030f984af8d8bb4c203755d35bded3c05b3d83ce)
1import unittest
2from petsc4py import PETSc
3from sys import getrefcount
4
5# --------------------------------------------------------------------
6
class MyODE:
    """
    Componentwise test problem du/dt + u**2 = 0 with u0,u1,u2 = 1,2,3.

    Supplies the explicit form (rhsfunction/rhsjacobian, F(u) = -u**2)
    and the implicit form (ifunction/ijacobian, G(u, du) = du + u**2)
    as PETSc.TS callbacks, plus a monitor.  Every callback increments its
    own ``*_calls`` counter so the tests can check which callbacks ran.
    """

    def __init__(self):
        # Invocation counters inspected by the tests.  presolve, update and
        # postsolve have counters here but no matching callback in this class.
        self.rhsfunction_calls = 0
        self.rhsjacobian_calls = 0
        self.ifunction_calls = 0
        self.ijacobian_calls = 0
        self.presolve_calls = 0
        self.update_calls = 0
        self.postsolve_calls = 0
        self.monitor_calls = 0

    def rhsfunction(self, ts, t, u, F):
        """Explicit right-hand side: store -(u*u) into the output vector F."""
        self.rhsfunction_calls += 1
        neg_square = -(u * u)   # fresh vector holding -u**2
        neg_square.copy(F)      # copy the result into the caller's F

    def rhsjacobian(self, ts, t, u, J, P):
        """Explicit Jacobian: P = diag(-2*u); J is also assembled if distinct.

        Returns True (the original ``same_nz`` flag).
        """
        self.rhsjacobian_calls += 1
        P.zeroEntries()
        P.setDiagonal(-2 * u)
        P.assemble()
        if J != P:
            J.assemble()
        return True  # same_nz

    def ifunction(self, ts, t, u, du, F):
        """Implicit residual: store du + u*u into the output vector F."""
        self.ifunction_calls += 1
        residual = du + u * u
        residual.copy(F)

    def ijacobian(self, ts, t, u, du, a, J, P):
        """Implicit Jacobian with shift ``a``: P = diag(a + 2*u).

        dG/du = 2*u and dG/d(du) = 1, so the shifted Jacobian is a + 2*u on
        the diagonal.  J is also assembled if it is a different matrix.
        Returns True (the original ``same_nz`` flag).
        """
        self.ijacobian_calls += 1
        P.zeroEntries()
        P.setDiagonal(a + 2 * u)
        P.assemble()
        if J != P:
            J.assemble()
        return True  # same_nz

    def monitor(self, ts, s, t, u):
        """TS monitor: count the call; s, t, u only feed the disabled print."""
        self.monitor_calls += 1
        # Read the current step size and solution norm from the TS; they are
        # only consumed by the commented-out progress print below.
        step_size = ts.time_step
        sol_norm = ts.vec_sol.norm()
        # PETSc.Sys.Print('TS: step %2d, T:%f, dT:%f, u:%f'
        #                 % (s, t, step_size, sol_norm))
60
61
class BaseTestTSNonlinear(object):
    """Common fixture: a fresh nonlinear PETSc TS of type ``TYPE`` per test.

    Subclasses (combined with unittest.TestCase) set ``TYPE`` to a
    PETSc.TS.Type value.
    """

    TYPE = None

    def setUp(self):
        ts = PETSc.TS().create(PETSc.COMM_SELF)
        self.ts = ts
        ts.setExactFinalTime(PETSc.TS.ExactFinalTime.STEPOVER)
        ts.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        ts.setType(self.TYPE)
        # Scalars of dtype char 'f' (float32) or 'F' (complex64): relax the
        # SNES relative tolerance, presumably because the default is too
        # tight for single precision.
        single_precision = PETSc.ScalarType().dtype.char in 'fF'
        if single_precision:
            ts.getSNES().setTolerances(rtol=1e-6)

    def tearDown(self):
        # Drop the fixture's reference to the TS.
        self.ts = None
79
80
class BaseTestTSNonlinearRHS(BaseTestTSNonlinear):
    """TS tests that drive MyODE through its explicit (RHS) callbacks."""

    def testSolveRHS(self):
        """Solve with RHS function/Jacobian; check callbacks, dict, monitor."""
        ts = self.ts
        dct = ts.getDict()
        self.assertIsNotNone(dct)
        self.assertIs(type(dct), dict)

        ode = MyODE()
        J = PETSc.Mat().create(ts.comm)
        J.setSizes(3)
        J.setFromOptions()
        J.setUp()
        u, f = J.createVecs()

        ts.setAppCtx(ode)
        ts.setRHSFunction(ode.rhsfunction, f)
        ts.setRHSJacobian(ode.rhsjacobian, J, J)
        ts.setMonitor(ode.monitor)

        # Disable preconditioning for the inner linear solves.
        ts.snes.ksp.pc.setType('none')

        T0, dT, nT = 0.00, 0.1, 10
        T = T0 + nT * dT
        ts.setTime(T0)
        ts.setTimeStep(dT)
        ts.setMaxTime(T)
        ts.setMaxSteps(nT)
        ts.setFromOptions()
        u[0], u[1], u[2] = 1, 2, 3
        ts.solve(u)

        self.assertGreater(ode.rhsfunction_calls, 0)
        self.assertGreater(ode.rhsjacobian_calls, 0)

        # The registrations above are expected under these keys in the
        # TS's Python-side dict.
        dct = ts.getDict()
        for key in ('__appctx__', '__rhsfunction__',
                    '__rhsjacobian__', '__monitor__'):
            self.assertIn(key, dct)

        # An explicit monitor call reaches the registered monitor exactly
        # once; after monitorCancel() it must not be reached at all.
        n = ode.monitor_calls
        ts.monitor(ts.step_number, ts.time)
        self.assertEqual(ode.monitor_calls, n + 1)
        n = ode.monitor_calls
        ts.monitorCancel()
        ts.monitor(ts.step_number, ts.time)
        self.assertEqual(ode.monitor_calls, n)

    def testFDColorRHS(self):
        """Solve using an SNES finite-difference Jacobian (setUseFD) on a
        preallocated AIJ matrix with one nonzero per row."""
        ts = self.ts
        ode = MyODE()
        J = PETSc.Mat().create(ts.comm)
        J.setSizes(5)
        J.setType('aij')
        J.setPreallocationNNZ(nnz=1)
        u, f = J.createVecs()

        ts.setAppCtx(ode)
        ts.setRHSFunction(ode.rhsfunction, f)
        ts.setRHSJacobian(ode.rhsjacobian, J, J)
        ts.setMonitor(ode.monitor)

        T0, dT, nT = 0.00, 0.1, 10
        T = T0 + nT * dT
        ts.setTime(T0)
        ts.setTimeStep(dT)
        ts.setMaxTime(T)
        ts.setMaxSteps(nT)
        ts.setFromOptions()
        # NOTE(review): the system has 5 rows but only u[0..2] are set here;
        # u[3], u[4] rely on createVecs() zero-initialization -- confirm.
        u[0], u[1], u[2] = 1, 2, 3

        ts.setSolution(u)
        # Assemble J once before setUp(), presumably so the FD setup sees
        # its nonzero (diagonal) structure.
        ode.rhsjacobian(ts, 0, u, J, J)
        ts.setUp()
        ts.snes.setUseFD(True)
        ts.solve(u)

        # The solve must have evaluated the RHS function.
        self.assertGreater(ode.rhsfunction_calls, 0)

    def testResetAndSolveRHS(self):
        """A TS can be reset, re-configured and solved again, repeatedly."""
        for _ in range(2):
            self.ts.reset()
            self.ts.setStepNumber(0)
            self.testSolveRHS()
        self.ts.reset()
165        self.ts.reset()
166
class BaseTestTSNonlinearI(BaseTestTSNonlinear):
    """TS tests that drive MyODE through its implicit (IFunction) callbacks."""

    def testSolveI(self):
        """Solve with IFunction/IJacobian; check callbacks, dict, monitor."""
        ts = self.ts
        dct = ts.getDict()
        self.assertIsNotNone(dct)
        self.assertIs(type(dct), dict)

        ode = MyODE()
        J = PETSc.Mat().create(ts.comm)
        J.setSizes(3)
        J.setFromOptions()
        J.setUp()
        u, f = J.createVecs()

        ts.setAppCtx(ode)
        ts.setIFunction(ode.ifunction, f)
        ts.setIJacobian(ode.ijacobian, J, J)
        ts.setMonitor(ode.monitor)

        # Disable preconditioning for the inner linear solves.
        ts.snes.ksp.pc.setType('none')

        T0, dT, nT = 0.00, 0.1, 10
        T = T0 + nT * dT
        ts.setTime(T0)
        ts.setTimeStep(dT)
        ts.setMaxTime(T)
        ts.setMaxSteps(nT)
        ts.setFromOptions()
        u[0], u[1], u[2] = 1, 2, 3
        ts.solve(u)

        self.assertGreater(ode.ifunction_calls, 0)
        self.assertGreater(ode.ijacobian_calls, 0)

        # The registrations above are expected under these keys in the
        # TS's Python-side dict.
        dct = ts.getDict()
        for key in ('__appctx__', '__ifunction__',
                    '__ijacobian__', '__monitor__'):
            self.assertIn(key, dct)

        # An explicit monitor call reaches the registered monitor exactly
        # once; after monitorCancel() it must not be reached at all.
        n = ode.monitor_calls
        ts.monitor(ts.step_number, ts.time)
        self.assertEqual(ode.monitor_calls, n + 1)
        n = ode.monitor_calls
        ts.monitorCancel()
        ts.monitor(ts.step_number, ts.time)
        self.assertEqual(ode.monitor_calls, n)

    def testFDColorI(self):
        """Solve using an SNES finite-difference Jacobian (setUseFD) on a
        preallocated AIJ matrix with one nonzero per row."""
        ts = self.ts
        ode = MyODE()
        J = PETSc.Mat().create(ts.comm)
        J.setSizes(5)
        J.setType('aij')
        J.setPreallocationNNZ(nnz=1)
        J.setFromOptions()
        u, f = J.createVecs()

        ts.setAppCtx(ode)
        ts.setIFunction(ode.ifunction, f)
        ts.setIJacobian(ode.ijacobian, J, J)
        ts.setMonitor(ode.monitor)

        T0, dT, nT = 0.00, 0.1, 10
        T = T0 + nT * dT
        ts.setTime(T0)
        ts.setTimeStep(dT)
        ts.setMaxTime(T)
        ts.setMaxSteps(nT)
        ts.setFromOptions()
        # NOTE(review): the system has 5 rows but only u[0..2] are set here;
        # u[3], u[4] rely on createVecs() zero-initialization -- confirm.
        u[0], u[1], u[2] = 1, 2, 3

        ts.setSolution(u)
        # Assemble J once (du = 0, shift a = 1) before setUp(), presumably
        # so the FD setup sees its nonzero (diagonal) structure.
        ode.ijacobian(ts, 0, u, 0 * u, 1, J, J)
        ts.setUp()
        ts.snes.setUseFD(True)
        ts.solve(u)

        # The solve must have evaluated the implicit function.
        self.assertGreater(ode.ifunction_calls, 0)

    def testResetAndSolveI(self):
        """A TS can be reset, re-configured and solved again, repeatedly."""
        for _ in range(2):
            self.ts.reset()
            self.ts.setStepNumber(0)
            self.testSolveI()
        self.ts.reset()
252        self.ts.reset()
253
class TestTSBeuler(BaseTestTSNonlinearRHS,
                   BaseTestTSNonlinearI,
                   unittest.TestCase):
    """RHS and implicit TS tests run with the BEULER (backward Euler) type."""
    TYPE = PETSc.TS.Type.BEULER
256    TYPE = PETSc.TS.Type.BEULER
257
class TestTSCN(BaseTestTSNonlinearRHS,
               BaseTestTSNonlinearI,
               unittest.TestCase):
    """RHS and implicit TS tests run with the CN (Crank-Nicolson) type."""
    TYPE = PETSc.TS.Type.CN
260    TYPE = PETSc.TS.Type.CN
261
class TestTSTheta(BaseTestTSNonlinearRHS,
                  BaseTestTSNonlinearI,
                  unittest.TestCase):
    """RHS and implicit TS tests run with the THETA method type."""
    TYPE = PETSc.TS.Type.THETA
264    TYPE = PETSc.TS.Type.THETA
265
class TestTSAlpha(BaseTestTSNonlinearRHS,
                  BaseTestTSNonlinearI,
                  unittest.TestCase):
    """RHS and implicit TS tests run with the ALPHA (generalized-alpha) type."""
    TYPE = PETSc.TS.Type.ALPHA
268    TYPE = PETSc.TS.Type.ALPHA
269
270# --------------------------------------------------------------------
271
if __name__ == '__main__':
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()
274
275# --------------------------------------------------------------------
276