# Testing TSAdjoint and matrix-free Jacobian
# Basic usage:
#     python vanderpol.py
# Test implicit methods using implicit form:
#     python vanderpol.py -implicitform
# Test explicit methods:
#     python vanderpol.py -implicitform 0
# Test IMEX methods:
#     python vanderpol.py -imexform
# Matrix-free implementations can be enabled with an additional option -mf
#
# Model (van der Pol oscillator, stiffness parameter mu set with -mu):
#     u0' = u1
#     u1' = mu * ((1 - u0**2) * u1 - u0)

import sys
import petsc4py

petsc4py.init(sys.argv)

from petsc4py import PETSc


def _create_dense(nrows, ncols, comm):
    """Return a set-up dense ``nrows x ncols`` matrix on ``comm``."""
    mat = PETSc.Mat().createDense([nrows, ncols], comm=comm)
    mat.setUp()
    return mat


def _create_python_mat(nrows, ncols, context, comm):
    """Return an assembled 'python'-type matrix on ``comm``.

    Its operations (``mult``, ``multTranspose``, ...) are the methods of
    ``context``.
    """
    mat = PETSc.Mat().create(comm=comm)
    mat.setSizes([nrows, ncols])
    mat.setType('python')
    mat.setPythonContext(context)
    mat.setUp()
    mat.assemble()
    return mat


class VDP:
    """Van der Pol oscillator callbacks for PETSc TS and TSAdjoint.

    Explicit form:  u' = G(u), G = [u1, mu * ((1 - u0**2) * u1 - u0)].
    Implicit form:  F(u, u') = u' - G(u) = 0.
    IMEX form:      F(u, u') = [u0', u1' - mu * ((1 - u0**2) * u1 - u0)],
                    G(u)     = [u1, 0]  (the mu-term is treated implicitly).

    The ``*Jacobian`` callbacks give derivatives with respect to u, and the
    ``*JacobianP`` callbacks give derivatives with respect to the parameter mu.
    With ``mf_`` set, the callbacks store their values in the private dense
    matrices ``Jim_``/``JimP_``/``Jex_``/``JexP_``, which the shell classes
    below apply. Otherwise they write into the matrices passed in by TS.
    """

    n = 2
    comm = PETSc.COMM_SELF

    def __init__(self, mu_=1.0e3, mf_=False, imex_=False):
        self.mu_ = mu_
        self.mf_ = mf_
        self.imex_ = imex_
        if self.mf_:
            self.Jim_ = _create_dense(self.n, self.n, self.comm)
            self.JimP_ = _create_dense(self.n, 1, self.comm)
            self.Jex_ = _create_dense(self.n, self.n, self.comm)
            self.JexP_ = _create_dense(self.n, 1, self.comm)

    def initialCondition(self, u):
        """Set u(0) = [2, -2/3 + 10/(81 mu) - 292/(2187 mu^2)]."""
        mu = self.mu_
        u[0] = 2.0
        u[1] = -2.0 / 3.0 + 10.0 / (81.0 * mu) - 292.0 / (2187.0 * mu * mu)
        u.assemble()

    def evalFunction(self, ts, t, u, f):
        """RHS function G(u). In IMEX form its second component is zero."""
        mu = self.mu_
        f[0] = u[1]
        if self.imex_:
            f[1] = 0.0
        else:
            f[1] = mu * ((1.0 - u[0] * u[0]) * u[1] - u[0])
        f.assemble()

    def evalJacobian(self, ts, t, u, A, B):
        """RHS Jacobian dG/du."""
        J = self.Jex_ if self.mf_ else A
        mu = self.mu_
        J[0, 0] = 0
        J[0, 1] = 1.0
        if self.imex_:
            J[1, 0] = 0
            J[1, 1] = 0
        else:
            J[1, 0] = -mu * (2.0 * u[1] * u[0] + 1.0)
            J[1, 1] = mu * (1.0 - u[0] * u[0])
        J.assemble()
        if A != B:
            B.assemble()

    def evalJacobianP(self, ts, t, u, C):
        """RHS parameter Jacobian dG/dmu (an n x 1 matrix)."""
        Jp = self.JexP_ if self.mf_ else C
        Jp[0, 0] = 0
        if self.imex_:
            # In IMEX form G does not depend on mu. Write the zero explicitly
            # rather than relying on the matrix being zero-initialized.
            Jp[1, 0] = 0
        else:
            Jp[1, 0] = (1.0 - u[0] * u[0]) * u[1] - u[0]
        Jp.assemble()

    def evalIFunction(self, ts, t, u, udot, f):
        """Implicit function F(u, u')."""
        mu = self.mu_
        if self.imex_:
            f[0] = udot[0]
        else:
            f[0] = udot[0] - u[1]
        f[1] = udot[1] - mu * ((1.0 - u[0] * u[0]) * u[1] - u[0])
        f.assemble()

    def evalIJacobian(self, ts, t, u, udot, shift, A, B):
        """Implicit Jacobian shift * dF/du' + dF/du."""
        J = self.Jim_ if self.mf_ else A
        mu = self.mu_
        if self.imex_:
            J[0, 0] = shift
            J[0, 1] = 0.0
        else:
            J[0, 0] = shift
            J[0, 1] = -1.0
        J[1, 0] = mu * (2.0 * u[1] * u[0] + 1.0)
        J[1, 1] = shift - mu * (1.0 - u[0] * u[0])
        J.assemble()
        if A != B:
            B.assemble()

    def evalIJacobianP(self, ts, t, u, udot, shift, C):
        """Implicit parameter Jacobian dF/dmu (an n x 1 matrix)."""
        Jp = self.JimP_ if self.mf_ else C
        Jp[0, 0] = 0
        Jp[1, 0] = u[0] - (1.0 - u[0] * u[0]) * u[1]
        Jp.assemble()


class JacShell:
    """Matrix-free RHS Jacobian: applies ``ode.Jex_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def mult(self, A, x, y):
        "y <- A * x"
        self.ode_.Jex_.mult(x, y)

    def multTranspose(self, A, x, y):
        "y <- A' * x"
        self.ode_.Jex_.multTranspose(x, y)


class JacPShell:
    """Matrix-free RHS parameter Jacobian (transpose product only)."""

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        "y <- A' * x"
        self.ode_.JexP_.multTranspose(x, y)


class IJacShell:
    """Matrix-free implicit Jacobian: applies ``ode.Jim_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def mult(self, A, x, y):
        "y <- A * x"
        self.ode_.Jim_.mult(x, y)

    def multTranspose(self, A, x, y):
        "y <- A' * x"
        self.ode_.Jim_.multTranspose(x, y)


class IJacPShell:
    """Matrix-free implicit parameter Jacobian (transpose product only)."""

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        "y <- A' * x"
        self.ode_.JimP_.multTranspose(x, y)


OptDB = PETSc.Options()

mu_ = OptDB.getScalar('mu', 1.0e3)
mf_ = OptDB.getBool('mf', False)

implicitform_ = OptDB.getBool('implicitform', False)
imexform_ = OptDB.getBool('imexform', False)

ode = VDP(mu_, mf_, imexform_)

if not mf_:
    # Assembled Jacobians: the callbacks write straight into these.
    Jim = _create_dense(ode.n, ode.n, ode.comm)
    JimP = _create_dense(ode.n, 1, ode.comm)
    Jex = _create_dense(ode.n, ode.n, ode.comm)
    JexP = _create_dense(ode.n, 1, ode.comm)
else:
    # Matrix-free Jacobians: shells applying ode's private dense matrices.
    # They live on the same communicator as the vectors and the TS.
    # No zeroEntries() call: a 'python' Mat forwards it to its context,
    # and none of the shell contexts implement it (they hold no storage).
    Jim = _create_python_mat(ode.n, ode.n, IJacShell(ode), ode.comm)
    JimP = _create_python_mat(ode.n, 1, IJacPShell(ode), ode.comm)
    Jex = _create_python_mat(ode.n, ode.n, JacShell(ode), ode.comm)
    JexP = _create_python_mat(ode.n, 1, JacPShell(ode), ode.comm)

u = PETSc.Vec().createSeq(ode.n, comm=ode.comm)
f = u.duplicate()
# One gradient pair per cost function: adj_u[i] has the size of the state,
# and adj_p[i] has one entry per parameter (mu).
adj_u = [PETSc.Vec().createSeq(ode.n, comm=ode.comm) for _ in range(2)]
adj_p = [PETSc.Vec().createSeq(1, comm=ode.comm) for _ in range(2)]

ts = PETSc.TS().create(comm=ode.comm)
ts.setProblemType(ts.ProblemType.NONLINEAR)

if imexform_:
    # mu-term through the IFunction, the rest through the RHSFunction.
    ts.setType(ts.Type.ARKIMEX)
    ts.setIFunction(ode.evalIFunction, f)
    ts.setIJacobian(ode.evalIJacobian, Jim)
    ts.setIJacobianP(ode.evalIJacobianP, JimP)
    ts.setRHSFunction(ode.evalFunction, f)
    ts.setRHSJacobian(ode.evalJacobian, Jex)
    ts.setRHSJacobianP(ode.evalJacobianP, JexP)
elif implicitform_:
    ts.setType(ts.Type.CN)
    ts.setIFunction(ode.evalIFunction, f)
    ts.setIJacobian(ode.evalIJacobian, Jim)
    ts.setIJacobianP(ode.evalIJacobianP, JimP)
else:
    ts.setType(ts.Type.RK)
    ts.setRHSFunction(ode.evalFunction, f)
    ts.setRHSJacobian(ode.evalJacobian, Jex)
    ts.setRHSJacobianP(ode.evalJacobianP, JexP)

# The adjoint solve needs the forward trajectory, so record it.
ts.setSaveTrajectory()
# Forward integration settings. Any of them may be overridden from the
# command line by setFromOptions().
ts.setTime(0.0)
ts.setTimeStep(0.001)
ts.setMaxTime(0.5)
ts.setMaxSteps(1000)
ts.setExactFinalTime(PETSc.TS.ExactFinalTime.MATCHSTEP)

ts.setFromOptions()
ode.initialCondition(u)
ts.solve(u)


def _seed_cost_gradients(lambdas, mus):
    """Set the adjoint terminal values.

    lambdas[0] becomes [1, 0] and lambdas[1] becomes [0, 1].
    mus[i][0] is set to 0.
    """
    for lam, (first, second) in zip(lambdas, ((1, 0), (0, 1))):
        lam[0] = first
        lam[1] = second
        lam.assemble()
    for mu_grad in mus:
        mu_grad[0] = 0
        mu_grad.assemble()


def _view_gradients(lambdas, mus):
    """View every gradient vector: all of ``lambdas`` first, then ``mus``."""
    for vec in [*lambdas, *mus]:
        vec.view()


_seed_cost_gradients(adj_u, adj_p)

ts.setCostGradients(adj_u, adj_p)

ts.adjointSolve()

_view_gradients(adj_u, adj_p)


def compute_derp(du, dp):
    """Print the total derivative of a cost function with respect to mu.

    The initial state depends on mu only through
    u1(0) = -2/3 + 10/(81 mu) - 292/(2187 mu^2). The chain rule therefore
    adds du[1] * d u1(0)/d mu to the parameter gradient dp[0].
    """
    du1_init_dmu = -10.0 / (81.0 * mu_ * mu_) + 2.0 * 292.0 / (2187.0 * mu_ * mu_ * mu_)
    print(du[1] * du1_init_dmu + dp[0])


compute_derp(adj_u[0], adj_p[0])
compute_derp(adj_u[1], adj_p[1])

del ode, Jim, JimP, Jex, JexP, u, f, ts, adj_u, adj_p