# Testing TSAdjoint and matrix-free Jacobian
# Basic usage:
#     python vanderpol.py
# Test implicit methods using implicit form:
#     python vanderpol.py -implicitform
# Test explicit methods:
#     python vanderpol.py -implicitform 0
# Test IMEX methods:
#     python vanderpol.py -imexform
# Matrix-free implementations can be enabled with an additional option -mf

import sys
import petsc4py
petsc4py.init(sys.argv)

from petsc4py import PETSc


class VDP(object):
    """Two-variable ODE with right-hand side
    (u[1], mu*((1 - u[0]**2)*u[1] - u[0])), offered in three forms.

    * evalFunction / evalJacobian / evalJacobianP give the explicit
      (right-hand side) form,
    * evalIFunction / evalIJacobian / evalIJacobianP give the implicit form,
    * with ``imex_`` set, the ``mu`` term is dropped from the right-hand
      side callbacks and only appears in the implicit ones.

    With ``mf_`` set, the Jacobian callbacks ignore the matrix handed in and
    fill dense matrices owned by this object (``Jim_``, ``JimP_``, ``Jex_``,
    ``JexP_``); the shell classes defined after this class read those.
    """
    n = 2                     # number of state variables
    comm = PETSc.COMM_SELF

    def __init__(self, mu_=1.0e3, mf_=False, imex_=False):
        """Store the parameter ``mu_`` and the two mode flags.

        The four dense work matrices are only created when ``mf_`` is true.
        """
        self.mu_ = mu_
        self.mf_ = mf_
        self.imex_ = imex_
        if not self.mf_:
            return

        def dense(ncols):
            # n-by-ncols dense matrix, set up and ready to receive entries
            mat = PETSc.Mat().createDense([self.n, ncols], comm=self.comm)
            mat.setUp()
            return mat

        self.Jim_ = dense(self.n)
        self.JimP_ = dense(1)
        self.Jex_ = dense(self.n)
        self.JexP_ = dense(1)

    def initialCondition(self, u):
        """Write the initial state into ``u``; ``u[1]`` depends on ``mu``."""
        mu = self.mu_
        u[0] = 2.0
        u[1] = -2.0/3.0 + 10.0/(81.0*mu) - 292.0/(2187.0*mu*mu)
        u.assemble()

    def evalFunction(self, ts, t, u, f):
        """Right-hand side ``f``; its second entry is zero in IMEX mode."""
        f[0] = u[1]
        if self.imex_:
            f[1] = 0.0
        else:
            x, y = u[0], u[1]
            f[1] = self.mu_*((1.-x*x)*y-x)
        f.assemble()

    def evalJacobian(self, ts, t, u, A, B):
        """Derivative of the right-hand side with respect to ``u``.

        Entries go into ``A``, or into ``self.Jex_`` in matrix-free mode.
        """
        J = self.Jex_ if self.mf_ else A
        J[0,0] = 0
        J[0,1] = 1.0
        if self.imex_:
            # second right-hand side component is identically zero
            J[1,0] = 0
            J[1,1] = 0
        else:
            mu = self.mu_
            x, y = u[0], u[1]
            J[1,0] = -mu*(2.0*y*x+1.)
            J[1,1] = mu*(1.0-x*x)
        J.assemble()
        if A != B:
            B.assemble()
        return True # same nonzero pattern

    def evalJacobianP(self, ts, t, u, C):
        """Derivative of the right-hand side with respect to ``mu``.

        Entries go into ``C``, or into ``self.JexP_`` in matrix-free mode.
        In IMEX mode no entries are written; the matrix is only assembled.
        """
        Jp = self.JexP_ if self.mf_ else C
        if not self.imex_:
            x, y = u[0], u[1]
            Jp[0,0] = 0
            Jp[1,0] = (1.-x*x)*y-x
        Jp.assemble()
        return True

    def evalIFunction(self, ts, t, u, udot, f):
        """Implicit-form residual ``f`` built from ``u`` and ``udot``.

        In IMEX mode the ``u[1]`` term is left out of the first entry
        (evalFunction supplies it as the right-hand side instead).
        """
        x, y = u[0], u[1]
        f[0] = udot[0] if self.imex_ else udot[0]-y
        f[1] = udot[1]-self.mu_*((1.-x*x)*y-x)
        f.assemble()

    def evalIJacobian(self, ts, t, u, udot, shift, A, B):
        """Jacobian of the implicit residual: ``shift`` on the diagonal plus
        the derivative with respect to ``u``.

        Entries go into ``A``, or into ``self.Jim_`` in matrix-free mode.
        """
        J = self.Jim_ if self.mf_ else A
        mu = self.mu_
        x, y = u[0], u[1]
        J[0,0] = shift
        # the -u[1] term only belongs to the residual outside IMEX mode
        J[0,1] = 0.0 if self.imex_ else -1.0
        J[1,0] = mu*(2.0*y*x+1.)
        J[1,1] = shift-mu*(1.0-x*x)
        J.assemble()
        if A != B:
            B.assemble()
        return True # same nonzero pattern

    def evalIJacobianP(self, ts, t, u, udot, shift, C):
        """Derivative of the implicit residual with respect to ``mu``.

        Entries go into ``C``, or into ``self.JimP_`` in matrix-free mode.
        """
        Jp = self.JimP_ if self.mf_ else C
        x, y = u[0], u[1]
        Jp[0,0] = 0
        Jp[1,0] = x-(1.-x*x)*y
        Jp.assemble()
        return True


class JacShell:
    """Python matrix context that forwards products to ``ode.Jex_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def mult(self, A, x, y):
        "y <- Jex_ * x"
        dense = self.ode_.Jex_
        dense.mult(x, y)

    def multTranspose(self, A, x, y):
        "y <- Jex_' * x"
        dense = self.ode_.Jex_
        dense.multTranspose(x, y)


class JacPShell:
    """Python matrix context that forwards transpose products to ``ode.JexP_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        "y <- JexP_' * x"
        dense = self.ode_.JexP_
        dense.multTranspose(x, y)


class IJacShell:
    """Python matrix context that forwards products to ``ode.Jim_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def mult(self, A, x, y):
        "y <- Jim_ * x"
        dense = self.ode_.Jim_
        dense.mult(x, y)

    def multTranspose(self, A, x, y):
        "y <- Jim_' * x"
        dense = self.ode_.Jim_
        dense.multTranspose(x, y)


class IJacPShell:
    """Python matrix context that forwards transpose products to ``ode.JimP_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        "y <- JimP_' * x"
        dense = self.ode_.JimP_
        dense.multTranspose(x, y)


# ---- command-line options -------------------------------------------------
OptDB = PETSc.Options()

mu_ = OptDB.getScalar('mu', 1.0e3)
mf_ = OptDB.getBool('mf', False)

implicitform_ = OptDB.getBool('implicitform', False)
imexform_ = OptDB.getBool('imexform', False)

ode = VDP(mu_, mf_, imexform_)

# ---- matrices handed to the TS callbacks ----------------------------------
if mf_:
    # 'python'-type matrices whose products are delegated to the dense
    # matrices stored on `ode` (see the *Shell classes above).
    Jim = PETSc.Mat().create()
    Jim.setSizes([ode.n, ode.n])
    Jim.setType('python')
    shell = IJacShell(ode)
    Jim.setPythonContext(shell)
    Jim.setUp()
    Jim.assemble()

    JimP = PETSc.Mat().create()
    JimP.setSizes([ode.n, 1])
    JimP.setType('python')
    shell = IJacPShell(ode)
    JimP.setPythonContext(shell)
    JimP.setUp()
    JimP.assemble()

    Jex = PETSc.Mat().create()
    Jex.setSizes([ode.n, ode.n])
    Jex.setType('python')
    shell = JacShell(ode)
    Jex.setPythonContext(shell)
    Jex.setUp()
    Jex.assemble()

    JexP = PETSc.Mat().create()
    JexP.setSizes([ode.n, 1])
    JexP.setType('python')
    shell = JacPShell(ode)
    JexP.setPythonContext(shell)
    JexP.setUp()
    JexP.zeroEntries()
    JexP.assemble()
else:
    # Plain dense matrices that the callbacks fill directly.
    Jim = PETSc.Mat().createDense([ode.n, ode.n], comm=ode.comm)
    Jim.setUp()
    JimP = PETSc.Mat().createDense([ode.n, 1], comm=ode.comm)
    JimP.setUp()
    Jex = PETSc.Mat().createDense([ode.n, ode.n], comm=ode.comm)
    Jex.setUp()
    JexP = PETSc.Mat().createDense([ode.n, 1], comm=ode.comm)
    JexP.setUp()

# ---- work vectors ---------------------------------------------------------
u = PETSc.Vec().createSeq(ode.n, comm=ode.comm)
f = u.duplicate()
# Two cost gradients: one pair (state part, parameter part) per entry.
adj_u = [PETSc.Vec().createSeq(ode.n, comm=ode.comm) for _ in range(2)]
adj_p = [PETSc.Vec().createSeq(1, comm=ode.comm) for _ in range(2)]

# ---- time stepper ---------------------------------------------------------
ts = PETSc.TS().create(comm=ode.comm)
ts.setProblemType(ts.ProblemType.NONLINEAR)

# Choose the integrator first, then register the callbacks it needs:
# IMEX uses both forms, -implicitform only the implicit one, and the
# default only the right-hand side one.
if imexform_:
    ts.setType(ts.Type.ARKIMEX)
elif implicitform_:
    ts.setType(ts.Type.CN)
else:
    ts.setType(ts.Type.RK)

if imexform_ or implicitform_:
    ts.setIFunction(ode.evalIFunction, f)
    ts.setIJacobian(ode.evalIJacobian, Jim)
    ts.setIJacobianP(ode.evalIJacobianP, JimP)

if imexform_ or not implicitform_:
    ts.setRHSFunction(ode.evalFunction, f)
    ts.setRHSJacobian(ode.evalJacobian, Jex)
    ts.setRHSJacobianP(ode.evalJacobianP, JexP)

ts.setSaveTrajectory()
ts.setTime(0.0)
ts.setTimeStep(0.001)
ts.setMaxTime(0.5)
ts.setMaxSteps(1000)
ts.setExactFinalTime(PETSc.TS.ExactFinalTime.MATCHSTEP)

ts.setFromOptions()
ode.initialCondition(u)
ts.solve(u)

# ---- adjoint solve --------------------------------------------------------
# Terminal values: adj_u[i] is the i-th unit vector, adj_p[i] is zero.
adj_u[0][0] = 1
adj_u[0][1] = 0
adj_u[0].assemble()
adj_u[1][0] = 0
adj_u[1][1] = 1
adj_u[1].assemble()
adj_p[0][0] = 0
adj_p[0].assemble()
adj_p[1][0] = 0
adj_p[1].assemble()

ts.setCostGradients(adj_u, adj_p)

ts.adjointSolve()

adj_u[0].view()
adj_u[1].view()
adj_p[0].view()
adj_p[1].view()


def compute_derp(du, dp):
    """Print ``du[1]`` times the mu-derivative of the initial ``u[1]``
    (see VDP.initialCondition), plus ``dp[0]``."""
    du1_dmu = -10.0/(81.0*mu_*mu_)+2.0*292.0/(2187.0*mu_*mu_*mu_)
    print(du[1]*du1_dmu+dp[0])


compute_derp(adj_u[0], adj_p[0])
compute_derp(adj_u[1], adj_p[1])

del ode, Jim, JimP, Jex, JexP, u, f, ts, adj_u, adj_p