# Testing TSAdjoint and matrix-free Jacobian
# Basic usage:
#     python vanderpol.py
# Test implicit methods using implicit form:
#     python vanderpol.py -implicitform
# Test explicit methods:
#     python vanderpol.py -implicitform 0
# Test IMEX methods:
#     python vanderpol.py -imexform
# Matrix-free implementations can be enabled with an additional option -mf
#
# The script integrates the van der Pol oscillator
#     u0' = u1
#     u1' = mu * ((1 - u0^2) * u1 - u0)
# forward in time and then runs TSAdjoint to obtain the sensitivities of the
# two final-state components with respect to the initial state and to mu.

import sys
import petsc4py

petsc4py.init(sys.argv)

from petsc4py import PETSc


def _dense_mat(rows, cols, comm):
    """Create and set up a dense ``rows x cols`` matrix on *comm*."""
    mat = PETSc.Mat().createDense([rows, cols], comm=comm)
    mat.setUp()
    return mat


def _shell_mat(rows, cols, context, comm):
    """Create an assembled MATPYTHON matrix whose products go to *context*.

    The shell stores no entries of its own (so there is nothing to zero);
    *context* forwards mult/multTranspose to a dense matrix that the VDP
    Jacobian callbacks fill in explicitly.
    """
    mat = PETSc.Mat().create(comm=comm)
    mat.setSizes([rows, cols])
    mat.setType('python')
    mat.setPythonContext(context)
    mat.setUp()
    mat.assemble()
    return mat


class VDP:
    """Van der Pol oscillator with the callbacks TS/TSAdjoint needs.

    Parameters
    ----------
    mu_ : float
        Stiffness parameter mu; also the single parameter for TSAdjoint.
    mf_ : bool
        If true, Jacobians are written into the private dense matrices
        ``Jim_``, ``JimP_``, ``Jex_``, ``JexP_`` (read by the matrix-free
        shells) instead of into the matrices PETSc passes in.
    imex_ : bool
        If true, use the IMEX split: the mu-term of the second equation is
        handled by the implicit (IFunction) part and the explicit (RHS)
        part carries only ``u0' = u1``.
    """

    n = 2  # number of state variables
    comm = PETSc.COMM_SELF

    def __init__(self, mu_=1.0e3, mf_=False, imex_=False):
        self.mu_ = mu_
        self.mf_ = mf_
        self.imex_ = imex_
        if self.mf_:
            # Backing storage that the shell contexts delegate to.
            self.Jim_ = _dense_mat(self.n, self.n, self.comm)
            self.JimP_ = _dense_mat(self.n, 1, self.comm)
            self.Jex_ = _dense_mat(self.n, self.n, self.comm)
            self.JexP_ = _dense_mat(self.n, 1, self.comm)

    def initialCondition(self, u):
        """Fill *u* with the mu-dependent initial state.

        The derivative of ``u[1]`` with respect to mu is added back in
        ``compute_derp``.
        """
        mu = self.mu_
        u[0] = 2.0
        u[1] = -2.0 / 3.0 + 10.0 / (81.0 * mu) - 292.0 / (2187.0 * mu * mu)
        u.assemble()

    def evalFunction(self, ts, t, u, f):
        """RHS function G(t, u); in IMEX form only its first row is nonzero."""
        mu = self.mu_
        f[0] = u[1]
        if self.imex_:
            f[1] = 0.0
        else:
            f[1] = mu * ((1.0 - u[0] * u[0]) * u[1] - u[0])
        f.assemble()

    def evalJacobian(self, ts, t, u, A, B):
        """RHS Jacobian dG/du, written to *A* (or to ``Jex_`` when matrix-free)."""
        J = self.Jex_ if self.mf_ else A
        mu = self.mu_
        J[0, 0] = 0
        J[0, 1] = 1.0
        if self.imex_:
            J[1, 0] = 0
            J[1, 1] = 0
        else:
            J[1, 0] = -mu * (2.0 * u[1] * u[0] + 1.0)
            J[1, 1] = mu * (1.0 - u[0] * u[0])
        J.assemble()
        if A != B:
            B.assemble()

    def evalJacobianP(self, ts, t, u, C):
        """RHS parameter Jacobian dG/dmu, written to *C* (or ``JexP_``)."""
        Jp = self.JexP_ if self.mf_ else C
        Jp[0, 0] = 0
        if self.imex_:
            # The explicit part has no mu-dependence in IMEX form; write the
            # zero explicitly instead of relying on the initial matrix state.
            Jp[1, 0] = 0
        else:
            Jp[1, 0] = (1.0 - u[0] * u[0]) * u[1] - u[0]
        Jp.assemble()

    def evalIFunction(self, ts, t, u, udot, f):
        """Implicit residual F(t, u, udot); omits the ``-u1`` term in IMEX form."""
        mu = self.mu_
        if self.imex_:
            f[0] = udot[0]
        else:
            f[0] = udot[0] - u[1]
        f[1] = udot[1] - mu * ((1.0 - u[0] * u[0]) * u[1] - u[0])
        f.assemble()

    def evalIJacobian(self, ts, t, u, udot, shift, A, B):
        """Implicit Jacobian shift * dF/dudot + dF/du, written to *A* (or ``Jim_``)."""
        J = self.Jim_ if self.mf_ else A
        mu = self.mu_
        J[0, 0] = shift
        J[0, 1] = 0.0 if self.imex_ else -1.0
        J[1, 0] = mu * (2.0 * u[1] * u[0] + 1.0)
        J[1, 1] = shift - mu * (1.0 - u[0] * u[0])
        J.assemble()
        if A != B:
            B.assemble()

    def evalIJacobianP(self, ts, t, u, udot, shift, C):
        """Implicit parameter Jacobian dF/dmu, written to *C* (or ``JimP_``)."""
        Jp = self.JimP_ if self.mf_ else C
        Jp[0, 0] = 0
        Jp[1, 0] = u[0] - (1.0 - u[0] * u[0]) * u[1]
        Jp.assemble()


class JacShell:
    """Shell context applying the RHS Jacobian stored in ``ode.Jex_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def mult(self, A, x, y):
        "y <- A * x"
        self.ode_.Jex_.mult(x, y)

    def multTranspose(self, A, x, y):
        "y <- A' * x"
        self.ode_.Jex_.multTranspose(x, y)


class JacPShell:
    """Shell context applying the transposed RHS parameter Jacobian ``ode.JexP_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        "y <- A' * x"
        self.ode_.JexP_.multTranspose(x, y)


class IJacShell:
    """Shell context applying the implicit Jacobian stored in ``ode.Jim_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def mult(self, A, x, y):
        "y <- A * x"
        self.ode_.Jim_.mult(x, y)

    def multTranspose(self, A, x, y):
        "y <- A' * x"
        self.ode_.Jim_.multTranspose(x, y)


class IJacPShell:
    """Shell context applying the transposed implicit parameter Jacobian ``ode.JimP_``."""

    def __init__(self, ode):
        self.ode_ = ode

    def multTranspose(self, A, x, y):
        "y <- A' * x"
        self.ode_.JimP_.multTranspose(x, y)


OptDB = PETSc.Options()

mu_ = OptDB.getScalar('mu', 1.0e3)
mf_ = OptDB.getBool('mf', False)

implicitform_ = OptDB.getBool('implicitform', False)
imexform_ = OptDB.getBool('imexform', False)

ode = VDP(mu_, mf_, imexform_)

# Jim/JimP: implicit Jacobian and its parameter Jacobian;
# Jex/JexP: RHS Jacobian and its parameter Jacobian.
if not mf_:
    Jim = _dense_mat(ode.n, ode.n, ode.comm)
    JimP = _dense_mat(ode.n, 1, ode.comm)
    Jex = _dense_mat(ode.n, ode.n, ode.comm)
    JexP = _dense_mat(ode.n, 1, ode.comm)
else:
    # Shells live on the same communicator as the vectors and the TS.
    Jim = _shell_mat(ode.n, ode.n, IJacShell(ode), ode.comm)
    JimP = _shell_mat(ode.n, 1, IJacPShell(ode), ode.comm)
    Jex = _shell_mat(ode.n, ode.n, JacShell(ode), ode.comm)
    JexP = _shell_mat(ode.n, 1, JacPShell(ode), ode.comm)

u = PETSc.Vec().createSeq(ode.n, comm=ode.comm)
f = u.duplicate()
# One (lambda, mu) adjoint pair per cost function; the costs are the ode.n
# components of the final state, and there is a single parameter (mu).
adj_u = [PETSc.Vec().createSeq(ode.n, comm=ode.comm) for _ in range(ode.n)]
adj_p = [PETSc.Vec().createSeq(1, comm=ode.comm) for _ in range(ode.n)]

ts = PETSc.TS().create(comm=ode.comm)
ts.setProblemType(ts.ProblemType.NONLINEAR)

if imexform_:
    ts.setType(ts.Type.ARKIMEX)
    ts.setIFunction(ode.evalIFunction, f)
    ts.setIJacobian(ode.evalIJacobian, Jim)
    ts.setIJacobianP(ode.evalIJacobianP, JimP)
    ts.setRHSFunction(ode.evalFunction, f)
    ts.setRHSJacobian(ode.evalJacobian, Jex)
    ts.setRHSJacobianP(ode.evalJacobianP, JexP)
elif implicitform_:
    ts.setType(ts.Type.CN)
    ts.setIFunction(ode.evalIFunction, f)
    ts.setIJacobian(ode.evalIJacobian, Jim)
    ts.setIJacobianP(ode.evalIJacobianP, JimP)
else:
    ts.setType(ts.Type.RK)
    ts.setRHSFunction(ode.evalFunction, f)
    ts.setRHSJacobian(ode.evalJacobian, Jex)
    ts.setRHSJacobianP(ode.evalJacobianP, JexP)

# The trajectory must be saved for the subsequent adjoint solve.
ts.setSaveTrajectory()
ts.setTime(0.0)
ts.setTimeStep(0.001)
ts.setMaxTime(0.5)
ts.setMaxSteps(1000)
ts.setExactFinalTime(PETSc.TS.ExactFinalTime.MATCHSTEP)

ts.setFromOptions()
ode.initialCondition(u)
ts.solve(u)

# Seed the cost gradients at the final time: lambda_i = e_i, mu_i = 0.
for i, (lam, mup) in enumerate(zip(adj_u, adj_p)):
    lam.set(0.0)
    lam[i] = 1.0
    lam.assemble()
    mup.set(0.0)
    mup.assemble()

ts.setCostGradients(adj_u, adj_p)

ts.adjointSolve()

for vec in adj_u + adj_p:
    vec.view()


def compute_derp(du, dp):
    """Print and return the total derivative of one cost with respect to mu.

    ``dp[0]`` is the parameter sensitivity computed by TSAdjoint; the other
    term adds the chain-rule contribution through the mu-dependent initial
    value ``u[1]`` set in ``VDP.initialCondition``.
    """
    du1_dmu = -10.0 / (81.0 * mu_ * mu_) + 2.0 * 292.0 / (2187.0 * mu_ * mu_ * mu_)
    derp = du[1] * du1_dmu + dp[0]
    print(derp)
    return derp


compute_derp(adj_u[0], adj_p[0])
compute_derp(adj_u[1], adj_p[1])

del ode, Jim, JimP, Jex, JexP, u, f, ts, adj_u, adj_p