xref: /petsc/src/binding/petsc4py/demo/legacy/ode/vanderpol.py (revision bcee047adeeb73090d7e36cc71e39fc287cdbb97)
# Testing TSAdjoint and matrix-free Jacobians on the Van der Pol oscillator
# Basic usage (explicit Runge-Kutta, the default):
#     python vanderpol.py
# Test implicit methods using the implicit form:
#     python vanderpol.py -implicitform
# Test explicit methods:
#     python vanderpol.py -implicitform 0
# Test IMEX methods:
#     python vanderpol.py -imexform
# Matrix-free implementations can be enabled with the additional option -mf
11
12import sys, petsc4py
13petsc4py.init(sys.argv)
14
15from petsc4py import PETSc
16
class VDP(object):
    """Van der Pol oscillator ``u0' = u1``, ``u1' = mu*((1 - u0**2)*u1 - u0)``.

    Supplies PETSc TS callbacks for three formulations:

    * explicit: ``evalFunction`` / ``evalJacobian`` / ``evalJacobianP``
      describe the whole right-hand side ``G(t, u)``;
    * implicit: ``evalIFunction`` / ``evalIJacobian`` / ``evalIJacobianP``
      describe ``F(t, u, udot) = udot - G(t, u)``;
    * IMEX (``imex_=True``): the ``u0' = u1`` term stays in the RHS part and
      the mu-dependent ``u1`` equation goes entirely to the implicit part.

    The ``*P`` callbacks are Jacobians with respect to the single parameter
    ``mu``, as used by TSAdjoint.  With ``mf_=True`` every Jacobian is written
    into a private dense matrix (``Jex_``, ``JexP_``, ``Jim_``, ``JimP_``)
    instead of the matrix PETSc passes in; the shell contexts (``JacShell``
    and friends) apply those private matrices.
    """

    n = 2  # number of state components (u0, u1)
    comm = PETSc.COMM_SELF  # communicator for every matrix this class creates

    def __init__(self, mu_=1.0e3, mf_=False, imex_=False):
        self.mu_ = mu_      # the parameter mu
        self.mf_ = mf_      # fill private matrices for the matrix-free shells
        self.imex_ = imex_  # use the IMEX splitting
        if self.mf_:
            # Backing storage read by the shell matrix contexts.
            self.Jim_ = self._dense(self.n, self.n)
            self.JimP_ = self._dense(self.n, 1)
            self.Jex_ = self._dense(self.n, self.n)
            self.JexP_ = self._dense(self.n, 1)

    def _dense(self, rows, cols):
        """Return a new set-up (not yet assembled) dense ``rows x cols`` matrix."""
        mat = PETSc.Mat().createDense([rows, cols], comm=self.comm)
        mat.setUp()
        return mat

    def initialCondition(self, u):
        """Write the mu-dependent initial state into ``u`` and assemble it."""
        mu = self.mu_
        u[0] = 2.0
        u[1] = -2.0/3.0 + 10.0/(81.0*mu) - 292.0/(2187.0*mu*mu)
        u.assemble()

    def evalFunction(self, ts, t, u, f):
        """RHS function G(t, u); in IMEX mode only ``u0' = u1`` is kept here."""
        mu = self.mu_
        f[0] = u[1]
        if self.imex_:
            f[1] = 0.0
        else:
            f[1] = mu*((1.-u[0]*u[0])*u[1]-u[0])
        f.assemble()

    def evalJacobian(self, ts, t, u, A, B):
        """Jacobian dG/du of the RHS function, written into ``A`` (or ``Jex_``)."""
        J = self.Jex_ if self.mf_ else A
        mu = self.mu_
        J[0,0] = 0
        J[0,1] = 1.0
        if self.imex_:
            J[1,0] = 0
            J[1,1] = 0
        else:
            J[1,0] = -mu*(2.0*u[1]*u[0]+1.)
            J[1,1] = mu*(1.0-u[0]*u[0])
        J.assemble()
        if A != B:
            B.assemble()
        return True # same nonzero pattern

    def evalJacobianP(self, ts, t, u, C):
        """Jacobian dG/dmu of the RHS function, written into ``C`` (or ``JexP_``).

        Only ``G[1]`` can depend on mu: dG[1]/dmu = (1 - u0**2)*u1 - u0.  In
        IMEX mode the RHS part is ``(u1, 0)``, so the derivative is zero.
        """
        Jp = self.JexP_ if self.mf_ else C
        # Fill and assemble in every mode so the matrix is never left
        # unassembled; previously the IMEX branch skipped both steps.
        Jp[0,0] = 0
        if self.imex_:
            Jp[1,0] = 0
        else:
            Jp[1,0] = (1.-u[0]*u[0])*u[1]-u[0]
        Jp.assemble()
        return True

    def evalIFunction(self, ts, t, u, udot, f):
        """Implicit residual F(t, u, udot); IMEX mode moves ``-u1`` to the RHS."""
        mu = self.mu_
        if self.imex_:
            f[0] = udot[0]
        else:
            f[0] = udot[0]-u[1]
        f[1] = udot[1]-mu*((1.-u[0]*u[0])*u[1]-u[0])
        f.assemble()

    def evalIJacobian(self, ts, t, u, udot, shift, A, B):
        """``shift*dF/dudot + dF/du``, written into ``A`` (or ``Jim_``)."""
        J = self.Jim_ if self.mf_ else A
        mu = self.mu_
        J[0,0] = shift
        J[0,1] = 0.0 if self.imex_ else -1.0
        J[1,0] = mu*(2.0*u[1]*u[0]+1.)
        J[1,1] = shift-mu*(1.0-u[0]*u[0])
        J.assemble()
        if A != B:
            B.assemble()
        return True # same nonzero pattern

    def evalIJacobianP(self, ts, t, u, udot, shift, C):
        """dF/dmu of the implicit residual, written into ``C`` (or ``JimP_``).

        dF[1]/dmu = u0 - (1 - u0**2)*u1 in both implicit and IMEX mode.
        """
        Jp = self.JimP_ if self.mf_ else C
        Jp[0,0] = 0
        Jp[1,0] = u[0]-(1.-u[0]*u[0])*u[1]
        Jp.assemble()
        return True
107
class JacShell:
    """Python-Mat context applying the RHS Jacobian ``ode.Jex_`` matrix-free.

    The shell stores no entries itself; it forwards every product to the
    dense matrix that ``VDP.evalJacobian`` fills when ``mf_`` is enabled.
    """

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        """Compute y = Jex_^T x."""
        jac = self.ode_.Jex_
        jac.multTranspose(x, y)

    def mult(self, A, x, y):
        """Compute y = Jex_ x."""
        jac = self.ode_.Jex_
        jac.mult(x, y)
117
class JacPShell:
    """Python-Mat context applying ``ode.JexP_`` (RHS Jacobian w.r.t. mu).

    Only the transpose product is provided (presumably all the adjoint
    solve needs); it is forwarded to the dense backing matrix.
    """

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        """Compute y = JexP_^T x."""
        jacp = self.ode_.JexP_
        jacp.multTranspose(x, y)
124
class IJacShell:
    """Python-Mat context applying the implicit Jacobian ``ode.Jim_`` matrix-free.

    Products are forwarded to the dense matrix that ``VDP.evalIJacobian``
    fills when ``mf_`` is enabled.
    """

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        """Compute y = Jim_^T x."""
        jac = self.ode_.Jim_
        jac.multTranspose(x, y)

    def mult(self, A, x, y):
        """Compute y = Jim_ x."""
        jac = self.ode_.Jim_
        jac.mult(x, y)
134
class IJacPShell:
    """Python-Mat context applying ``ode.JimP_`` (implicit Jacobian w.r.t. mu).

    Only the transpose product is provided (presumably all the adjoint
    solve needs); it is forwarded to the dense backing matrix.
    """

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        """Compute y = JimP_^T x."""
        jacp = self.ode_.JimP_
        jacp.multTranspose(x, y)
141
# Command-line options (petsc4py.init(sys.argv) above hands argv to PETSc):
#   -mu <scalar>     the parameter mu (default 1e3)
#   -mf              use matrix-free (python-type) shell Jacobians
#   -implicitform    integrate the fully implicit form
#   -imexform        integrate the IMEX split (takes precedence over
#                    -implicitform)
OptDB = PETSc.Options()
mu_, mf_ = OptDB.getScalar('mu', 1.0e3), OptDB.getBool('mf', False)
implicitform_, imexform_ = (OptDB.getBool('implicitform', False),
                            OptDB.getBool('imexform', False))

ode = VDP(mu_=mu_, mf_=mf_, imex_=imexform_)
151
def _dense_mat(rows, cols):
    """Return a set-up dense ``rows x cols`` matrix on the ODE communicator."""
    mat = PETSc.Mat().createDense([rows, cols], comm=ode.comm)
    mat.setUp()
    return mat


def _shell_mat(rows, cols, context):
    """Return a set-up python-type (matrix-free) matrix driven by ``context``.

    Created on ``ode.comm`` like every other object in this script; the
    default communicator would not match the sequential TS and vectors
    when run under MPI.
    """
    mat = PETSc.Mat().create(comm=ode.comm)
    mat.setSizes([rows, cols])
    mat.setType('python')
    mat.setPythonContext(context)
    mat.setUp()
    return mat


if not mf_:
    # Assembled Jacobians: the VDP callbacks write straight into these.
    Jim = _dense_mat(ode.n, ode.n)
    JimP = _dense_mat(ode.n, 1)
    Jex = _dense_mat(ode.n, ode.n)
    JexP = _dense_mat(ode.n, 1)
else:
    # Matrix-free Jacobians: each shell forwards products to the dense
    # matrices that the VDP callbacks fill inside ``ode``.
    Jim = _shell_mat(ode.n, ode.n, IJacShell(ode))
    Jim.assemble()
    JimP = _shell_mat(ode.n, 1, IJacPShell(ode))
    JimP.assemble()
    Jex = _shell_mat(ode.n, ode.n, JacShell(ode))
    Jex.assemble()
    JexP = _shell_mat(ode.n, 1, JacPShell(ode))
    # No zeroEntries() here: the shell holds no entries, and JacPShell
    # defines no zeroEntries for petsc4py to forward the call to.
    JexP.assemble()
191
# State vector and residual storage.
u = PETSc.Vec().createSeq(ode.n, comm=ode.comm)
f = u.duplicate()
# Adjoint storage, one entry per cost function (two costs, one per state
# component): adj_u[i] has one slot per state, adj_p[i] one slot for mu.
adj_u = [PETSc.Vec().createSeq(ode.n, comm=ode.comm) for _ in range(ode.n)]
adj_p = [PETSc.Vec().createSeq(1, comm=ode.comm) for _ in range(ode.n)]
200
# Time integrator.  -imexform wins over -implicitform; with neither option
# an explicit Runge-Kutta method is used.
ts = PETSc.TS().create(comm=ode.comm)
ts.setProblemType(ts.ProblemType.NONLINEAR)

if imexform_:
    ts.setType(ts.Type.ARKIMEX)
elif implicitform_:
    ts.setType(ts.Type.CN)
else:
    ts.setType(ts.Type.RK)

use_implicit = imexform_ or implicitform_
use_explicit = imexform_ or not implicitform_

if use_implicit:
    ts.setIFunction(ode.evalIFunction, f)
    ts.setIJacobian(ode.evalIJacobian, Jim)
    ts.setIJacobianP(ode.evalIJacobianP, JimP)
if use_explicit:
    ts.setRHSFunction(ode.evalFunction, f)
    ts.setRHSJacobian(ode.evalJacobian, Jex)
    ts.setRHSJacobianP(ode.evalJacobianP, JexP)

ts.setSaveTrajectory()  # keep the forward trajectory for the adjoint solve
ts.setTime(0.0)
ts.setTimeStep(0.001)
ts.setMaxTime(0.5)
ts.setMaxSteps(1000)
ts.setExactFinalTime(ts.ExactFinalTime.MATCHSTEP)

ts.setFromOptions()  # command-line options override the defaults above
ode.initialCondition(u)
ts.solve(u)
234
def _seed_cost_gradients(lams, mus):
    """Seed cost i = u_i(T): identity columns for d/du, zeros for d/dmu."""
    for i, lam in enumerate(lams):
        for j in range(ode.n):
            lam[j] = 1 if i == j else 0
        lam.assemble()
    for mu_seed in mus:
        mu_seed[0] = 0
        mu_seed.assemble()


_seed_cost_gradients(adj_u, adj_p)
ts.setCostGradients(adj_u, adj_p)
ts.adjointSolve()

# After the adjoint solve the vectors hold sensitivities w.r.t. u(0) and mu.
adj_u[0].view()
adj_u[1].view()
adj_p[0].view()
adj_p[1].view()
254
def compute_derp(du, dp, mu=None):
    """Print and return the total derivative of one cost with respect to mu.

    The initial condition depends on mu through
    ``u1(0) = -2/3 + 10/(81 mu) - 292/(2187 mu**2)`` (``u0(0) = 2`` does
    not), so the chain rule adds ``du[1] * d u1(0)/d mu`` to the direct
    parameter sensitivity ``dp[0]``.

    :param du: sensitivity of the cost w.r.t. the initial state (``adj_u[i]``).
    :param dp: sensitivity of the cost w.r.t. mu (``adj_p[i]``).
    :param mu: parameter value; defaults to the module-level ``mu_`` read
        from the options database.
    :returns: the printed derivative.
    """
    if mu is None:
        mu = mu_
    # d u1(0) / d mu
    du1_0_dmu = -10.0/(81.0*mu*mu)+2.0*292.0/(2187.0*mu*mu*mu)
    derp = du[1]*du1_0_dmu+dp[0]
    print(derp)
    return derp
257
# One derivative report per cost function.
for du_i, dp_i in zip(adj_u, adj_p):
    compute_derp(du_i, dp_i)

# Drop every PETSc handle (including the loop variables) explicitly.
del ode, Jim, JimP, Jex, JexP, u, f, ts, adj_u, adj_p, du_i, dp_i
262