# Source: /petsc/src/binding/petsc4py/demo/legacy/ode/vanderpol.py (revision d53c7dcfa6b151504f6ca0aeb7a446ad68fe4d68)
# Tests TSAdjoint and matrix-free Jacobians on the Van der Pol oscillator.
# Basic usage:
#     python vanderpol.py
# Test implicit methods using the implicit form:
#     python vanderpol.py -implicitform
# Test explicit methods (the default):
#     python vanderpol.py -implicitform 0
# Test IMEX methods:
#     python vanderpol.py -imexform
# Matrix-free implementations can be enabled with the additional option -mf.

import sys
import petsc4py

# Hand the command-line arguments to petsc4py before PETSc is imported, so that
# PETSc is initialized with them and options such as -mu, -mf, -implicitform
# and -imexform can be read from the options database further down.
petsc4py.init(sys.argv)

from petsc4py import PETSc


class VDP:
    """Van der Pol oscillator and its PETSc TS callbacks.

    The state ``u = (u0, u1)`` evolves as::

        u0' = u1
        u1' = mu * ((1 - u0**2) * u1 - u0)

    Callbacks are provided for the explicit (RHS) form, the implicit form
    ``F(t, u, udot) = 0`` and, when ``imex_`` is set, an IMEX split in which
    the RHS carries only ``u0' = u1`` and the implicit part carries the
    ``mu`` term. The ``*JacobianP`` callbacks give derivatives with respect
    to ``mu`` for the adjoint sensitivity analysis.

    Parameters
    ----------
    mu_ : float
        Stiffness parameter ``mu``.
    mf_ : bool
        Matrix-free mode: the Jacobian callbacks write into the dense
        matrices ``Jim_``, ``JimP_``, ``Jex_`` and ``JexP_`` owned by this
        object instead of into the matrices passed in by TS.
    imex_ : bool
        Use the IMEX split described above.
    """

    n = 2  # number of state components
    comm = PETSc.COMM_SELF  # single-process communicator for this object's matrices

    def __init__(self, mu_=1.0e3, mf_=False, imex_=False):
        self.mu_ = mu_
        self.mf_ = mf_
        self.imex_ = imex_
        if self.mf_:
            # Dense storage for the Jacobian values in matrix-free mode: the
            # callbacks below fill these in place of the TS-provided matrices.
            self.Jim_ = PETSc.Mat().createDense([self.n, self.n], comm=self.comm)
            self.Jim_.setUp()
            self.JimP_ = PETSc.Mat().createDense([self.n, 1], comm=self.comm)
            self.JimP_.setUp()
            self.Jex_ = PETSc.Mat().createDense([self.n, self.n], comm=self.comm)
            self.Jex_.setUp()
            self.JexP_ = PETSc.Mat().createDense([self.n, 1], comm=self.comm)
            self.JexP_.setUp()

    def initialCondition(self, u):
        """Write the initial state into ``u`` and assemble it.

        ``u1(0)`` is a polynomial in ``1/mu``, so the initial condition itself
        depends on the parameter ``mu``.
        """
        mu = self.mu_
        u[0] = 2.0
        u[1] = -2.0 / 3.0 + 10.0 / (81.0 * mu) - 292.0 / (2187.0 * mu * mu)
        u.assemble()

    def evalFunction(self, ts, t, u, f):
        """RHS callback: store G(t, u) in ``f`` (``G1 = 0`` in IMEX mode)."""
        mu = self.mu_
        f[0] = u[1]
        if self.imex_:
            f[1] = 0.0
        else:
            f[1] = mu * ((1.0 - u[0] * u[0]) * u[1] - u[0])
        f.assemble()

    def evalJacobian(self, ts, t, u, A, B):
        """RHS Jacobian callback: dG/du.

        Values go into ``A``, or into ``self.Jex_`` in matrix-free mode.
        ``B`` is assembled separately only when it is a different matrix.
        """
        J = self.Jex_ if self.mf_ else A
        mu = self.mu_
        J[0, 0] = 0
        J[0, 1] = 1.0
        if self.imex_:
            J[1, 0] = 0
            J[1, 1] = 0
        else:
            J[1, 0] = -mu * (2.0 * u[1] * u[0] + 1.0)
            J[1, 1] = mu * (1.0 - u[0] * u[0])
        J.assemble()
        if A != B:
            B.assemble()

    def evalJacobianP(self, ts, t, u, C):
        """RHS parameter-Jacobian callback: dG/dmu, an n x 1 matrix.

        Values go into ``C``, or into ``self.JexP_`` in matrix-free mode.
        """
        Jp = self.JexP_ if self.mf_ else C
        Jp[0, 0] = 0
        if self.imex_:
            # In IMEX mode the RHS does not involve mu. Write the zero
            # explicitly and assemble (as evalJacobian does for its IMEX
            # entries) instead of leaving the matrix untouched and relying on
            # its initial contents.
            Jp[1, 0] = 0
        else:
            Jp[1, 0] = (1.0 - u[0] * u[0]) * u[1] - u[0]
        Jp.assemble()

    def evalIFunction(self, ts, t, u, udot, f):
        """Implicit-form callback: store F(t, u, udot) in ``f``.

        ``F0 = udot0 - u1`` (just ``udot0`` in IMEX mode, where ``u1`` is
        supplied by the RHS) and ``F1 = udot1 - mu*((1 - u0**2)*u1 - u0)``.
        """
        mu = self.mu_
        if self.imex_:
            f[0] = udot[0]
        else:
            f[0] = udot[0] - u[1]
        f[1] = udot[1] - mu * ((1.0 - u[0] * u[0]) * u[1] - u[0])
        f.assemble()

    def evalIJacobian(self, ts, t, u, udot, shift, A, B):
        """Implicit Jacobian callback: dF/du + shift * dF/dudot.

        Values go into ``A``, or into ``self.Jim_`` in matrix-free mode.
        ``B`` is assembled separately only when it is a different matrix.
        """
        J = self.Jim_ if self.mf_ else A
        mu = self.mu_
        J[0, 0] = shift
        # dF0/du1 is -1, except in IMEX mode where u1 is handled by the RHS.
        J[0, 1] = 0.0 if self.imex_ else -1.0
        J[1, 0] = mu * (2.0 * u[1] * u[0] + 1.0)
        J[1, 1] = shift - mu * (1.0 - u[0] * u[0])
        J.assemble()
        if A != B:
            B.assemble()

    def evalIJacobianP(self, ts, t, u, udot, shift, C):
        """Implicit parameter-Jacobian callback: dF/dmu, an n x 1 matrix.

        Values go into ``C``, or into ``self.JimP_`` in matrix-free mode.
        ``shift`` is unused.
        """
        Jp = self.JimP_ if self.mf_ else C
        Jp[0, 0] = 0
        Jp[1, 0] = u[0] - (1.0 - u[0] * u[0]) * u[1]
        Jp.assemble()


class JacShell:
    """Shell-matrix context for the RHS Jacobian in matrix-free mode.

    Every product is delegated to ``ode.Jex_``, the dense matrix that
    ``VDP.evalJacobian`` fills when matrix-free mode is on.
    """

    def __init__(self, ode):
        self.ode_ = ode

    def _target(self):
        # The dense matrix currently holding the RHS Jacobian values.
        return self.ode_.Jex_

    def mult(self, A, x, y):
        """Compute y = J x; the shell Mat ``A`` itself is not consulted."""
        self._target().mult(x, y)

    def multTranspose(self, A, x, y):
        """Compute y = J^T x; the shell Mat ``A`` itself is not consulted."""
        self._target().multTranspose(x, y)


class JacPShell:
    """Shell-matrix context for the RHS Jacobian with respect to ``mu``.

    Only the transposed product is provided; it is forwarded to
    ``ode.JexP_``, which ``VDP.evalJacobianP`` fills in matrix-free mode.
    """

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        """Compute y = Jp^T x using ``ode.JexP_`` (``A`` is unused)."""
        dense = self.ode_.JexP_
        dense.multTranspose(x, y)


class IJacShell:
    """Shell-matrix context for the implicit Jacobian in matrix-free mode.

    Products are forwarded to ``ode.Jim_``, the dense matrix that
    ``VDP.evalIJacobian`` fills when matrix-free mode is on.
    """

    def __init__(self, ode):
        self.ode_ = ode

    def mult(self, A, x, y):
        """Compute y = J x with J taken from ``ode.Jim_`` (``A`` is unused)."""
        implicit_jac = self.ode_.Jim_
        implicit_jac.mult(x, y)

    def multTranspose(self, A, x, y):
        """Compute y = J^T x with J taken from ``ode.Jim_`` (``A`` is unused)."""
        implicit_jac = self.ode_.Jim_
        implicit_jac.multTranspose(x, y)


class IJacPShell:
    """Shell-matrix context for the implicit Jacobian with respect to ``mu``.

    Only the transposed product is provided; it is forwarded to
    ``ode.JimP_``, which ``VDP.evalIJacobianP`` fills in matrix-free mode.
    """

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        """Compute y = Jp^T x using ``ode.JimP_`` (``A`` is unused)."""
        param_jac = self.ode_.JimP_
        param_jac.multTranspose(x, y)


# Runtime options, read from the PETSc options database:
#   -mu <value>     stiffness parameter (default 1e3)
#   -mf             use Python shell (matrix-free) Jacobians
#   -implicitform   integrate the implicit form instead of the RHS form
#   -imexform       use the IMEX split
OptDB = PETSc.Options()

mu_ = OptDB.getScalar('mu', 1.0e3)
mf_ = OptDB.getBool('mf', False)

implicitform_ = OptDB.getBool('implicitform', False)
imexform_ = OptDB.getBool('imexform', False)

ode = VDP(mu_, mf_, imexform_)


def _create_dense_mat(rows, cols, comm):
    """Return a set-up dense Mat of shape rows x cols on ``comm``."""
    mat = PETSc.Mat().createDense([rows, cols], comm=comm)
    mat.setUp()
    return mat


def _create_shell_mat(rows, cols, context, comm):
    """Return an assembled Python-shell Mat of shape rows x cols.

    Products are delegated to ``context`` (one of the *Shell classes). The
    Mat is created on ``comm`` -- the communicator of the TS and vectors --
    rather than on PETSc's default communicator, so all objects agree even
    when the script is launched on more than one process.
    """
    mat = PETSc.Mat().create(comm=comm)
    mat.setSizes([rows, cols])
    mat.setType('python')
    mat.setPythonContext(context)
    mat.setUp()
    mat.assemble()
    return mat


# Jacobian matrices handed to TS: Jim/JimP for the implicit form (d/du and
# d/dmu), Jex/JexP for the RHS form.
if not mf_:
    Jim = _create_dense_mat(ode.n, ode.n, ode.comm)
    JimP = _create_dense_mat(ode.n, 1, ode.comm)
    Jex = _create_dense_mat(ode.n, ode.n, ode.comm)
    JexP = _create_dense_mat(ode.n, 1, ode.comm)
else:
    # Shell matrices forward their products to the dense matrices that the
    # VDP callbacks fill in matrix-free mode. No zeroEntries() here: a Python
    # Mat forwards that call to its context, and the shell contexts do not
    # implement it.
    Jim = _create_shell_mat(ode.n, ode.n, IJacShell(ode), ode.comm)
    JimP = _create_shell_mat(ode.n, 1, IJacPShell(ode), ode.comm)
    Jex = _create_shell_mat(ode.n, ode.n, JacShell(ode), ode.comm)
    JexP = _create_shell_mat(ode.n, 1, JacPShell(ode), ode.comm)

# State vector plus a work vector with the same layout for F / G values.
u = PETSc.Vec().createSeq(ode.n, comm=ode.comm)
f = u.duplicate()

# Cost-gradient storage for two cost functions: adj_u[i] receives the
# sensitivity to the initial state, adj_p[i] the sensitivity to mu.
adj_u = [PETSc.Vec().createSeq(ode.n, comm=ode.comm) for _ in range(2)]
adj_p = [PETSc.Vec().createSeq(1, comm=ode.comm) for _ in range(2)]

ts = PETSc.TS().create(comm=ode.comm)
ts.setProblemType(ts.ProblemType.NONLINEAR)

# Integrator choice and the callbacks each one needs:
#   -imexform      -> ARKIMEX with both the implicit (I*) and RHS callbacks
#   -implicitform  -> Crank-Nicolson with the implicit callbacks only
#   default        -> explicit Runge-Kutta with the RHS callbacks only
if imexform_:
    ts_type = ts.Type.ARKIMEX
elif implicitform_:
    ts_type = ts.Type.CN
else:
    ts_type = ts.Type.RK
ts.setType(ts_type)

needs_implicit = imexform_ or implicitform_
needs_rhs = imexform_ or not implicitform_
if needs_implicit:
    ts.setIFunction(ode.evalIFunction, f)
    ts.setIJacobian(ode.evalIJacobian, Jim)
    ts.setIJacobianP(ode.evalIJacobianP, JimP)
if needs_rhs:
    ts.setRHSFunction(ode.evalFunction, f)
    ts.setRHSJacobian(ode.evalJacobian, Jex)
    ts.setRHSJacobianP(ode.evalJacobianP, JexP)

# Integrate from t = 0 to 0.5 with an initial step of 1e-3 (at most 1000
# steps), keeping the trajectory so the adjoint solve can use it.
ts.setSaveTrajectory()
ts.setTime(0.0)
ts.setTimeStep(0.001)
ts.setMaxTime(0.5)
ts.setMaxSteps(1000)
ts.setExactFinalTime(PETSc.TS.ExactFinalTime.MATCHSTEP)

# Command-line options may override any of the settings above.
ts.setFromOptions()
ode.initialCondition(u)
ts.solve(u)

# Seed the cost gradients: adj_u[i] is the i-th unit vector (cost i picks out
# state component i), and the direct mu-sensitivities start at zero.
for i, lam in enumerate(adj_u):
    for j in range(ode.n):
        lam[j] = 1 if i == j else 0
    lam.assemble()
for mu_sens in adj_p:
    mu_sens[0] = 0
    mu_sens.assemble()

ts.setCostGradients(adj_u, adj_p)

ts.adjointSolve()

for vec in adj_u + adj_p:
    vec.view()
# Drop the loop names so no extra Vec references outlive the lists.
del i, j, lam, mu_sens, vec


def compute_derp(du, dp, mu=None):
    """Print and return the total derivative of a cost function w.r.t. mu.

    ``dp[0]`` is the sensitivity to ``mu`` with the initial state held fixed
    and ``du`` the sensitivity to the initial state. Because the initial
    condition ``u1(0) = -2/3 + 10/(81 mu) - 292/(2187 mu^2)`` also depends
    on ``mu``, the chain-rule term ``du[1] * d u1(0)/d mu`` is added.

    Parameters
    ----------
    du : indexable
        Sensitivity with respect to the initial state; ``du[1]`` is used.
    dp : indexable
        Sensitivity with respect to ``mu``; ``dp[0]`` is used.
    mu : float, optional
        Parameter value. Defaults to the module-level ``mu_`` (the ``-mu``
        option).

    Returns
    -------
    The value that was printed.
    """
    if mu is None:
        mu = mu_
    # d/dmu of the initial condition u1(0).
    du1_0_dmu = -10.0 / (81.0 * mu * mu) + 2.0 * 292.0 / (2187.0 * mu * mu * mu)
    total = du[1] * du1_0_dmu + dp[0]
    print(total)
    return total


# Report dJ_i/dmu for both cost functions, then release the script's PETSc
# objects explicitly.
for lam, mu_sens in zip(adj_u, adj_p):
    compute_derp(lam, mu_sens)
del lam, mu_sens

del ode, Jim, JimP, Jex, JexP, u, f, ts, adj_u, adj_p