# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------


class Function:
    """Residual callback for the 2x2 nonlinear test system.

    Evaluates::

        f0 = x0*x0 + x0*x1 - 3
        f1 = x0*x1 + x1*x1 - 6

    which vanishes at ``x = (1, 2)``, the solution the solve tests expect.
    """

    def __call__(self, snes, x, f):
        f[0] = (x[0] * x[0] + x[0] * x[1] - 3.0).item()
        f[1] = (x[0] * x[1] + x[1] * x[1] - 6.0).item()
        f.assemble()


class Jacobian:
    """Analytic Jacobian callback matching :class:`Function`.

    Fills ``P`` with the partial derivatives of the residual and assembles
    it; ``J`` is assembled as well when it is a different matrix.
    """

    def __call__(self, snes, x, J, P):
        P[0, 0] = (2.0 * x[0] + x[1]).item()
        P[0, 1] = (x[0]).item()
        P[1, 0] = (x[1]).item()
        P[1, 1] = (x[0] + 2.0 * x[1]).item()
        P.assemble()
        if J != P:
            J.assemble()


def _new_jacobian_matrix():
    """Return a set-up 2x2 sequential AIJ matrix on ``COMM_SELF``."""
    mat = PETSc.Mat().create(PETSc.COMM_SELF)
    mat.setSizes([2, 2])
    mat.setType(PETSc.Mat.Type.SEQAIJ)
    mat.setUp()
    return mat


# --------------------------------------------------------------------


class BaseTestSNES:
    """Shared SNES test cases, mixed into one ``TestCase`` per solver type.

    Subclasses set ``SNES_TYPE``; ``setUp`` creates a fresh solver of that
    type on ``COMM_SELF`` for every test.

    Test methods carry ``#`` comments rather than docstrings so that
    unittest's verbose output keeps showing the method names.
    """

    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        # Drop our reference first, then let PETSc run its garbage cleanup.
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        # Re-setting the same type must leave the reported type unchanged.
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        # Round-trip the tolerances and compare with the attribute views.
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = tuple(getattr(self.snes, name) for name in tnames)
        self.assertEqual(tuple(tols), tolvals)
        dtol = self.snes.getDivergenceTolerance()
        self.assertGreater(dtol, 0)
        self.snes.setDivergenceTolerance(-1)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, -1)
        self.snes.setDivergenceTolerance(PETSc.DEFAULT)
        # NOTE(review): ``dtol`` is not re-read after the call above, so this
        # assertion only re-checks the previous value. Whether a re-read
        # would still yield -1 depends on how PETSc interprets DEFAULT here;
        # confirm before tightening this check.
        self.assertEqual(dtol, -1)

    def testProperties(self):
        snes = self.snes
        # Application context round-trip.
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        # Iteration counter is writable.
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        # Function norm is writable.
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        # No solve has run yet, so both history arrays are empty.
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # Each reason must switch on exactly one of the three status flags:
        # (reason, is_converged, is_diverged, is_iterating).
        Reason = PETSc.SNES.ConvergedReason
        expectations = (
            (Reason.CONVERGED_ITS, True, False, False),
            (Reason.DIVERGED_MAX_IT, False, True, False),
            (Reason.CONVERGED_ITERATING, False, False, True),
        )
        for reason, converged, diverged, iterating in expectations:
            snes.reason = reason
            self.assertEqual(snes.reason, reason)
            self.assertEqual(bool(snes.is_converged), converged)
            self.assertEqual(bool(snes.is_diverged), diverged)
            self.assertEqual(bool(snes.is_iterating), iterating)
        # Option flags: off on a fresh solver; use_ksp toggles and restores.
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)
        ouse = snes.use_ksp
        self.assertEqual(ouse, snes.getUseKSP())
        snes.use_ksp = not ouse
        self.assertEqual(not ouse, snes.getUseKSP())
        snes.setUseKSP(ouse)
        self.assertEqual(ouse, snes.use_ksp)

    def testGetSetFunc(self):
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must hold exactly one extra
        # reference, and reading it back must not add more.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        # (1, 2) is the root of the test system, so the residual is ~0.
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = _new_jacobian_matrix()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Same reference-count contract as testGetSetFunc.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        J = _new_jacobian_matrix()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): result unused; presumably this only forces the
        # KSP/PC to exist before computeJacobian runs - confirm.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        self.assertIsNone(self.snes.getUpdate())

        def upd(snes, it):
            return None

        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Clearing the callback releases the solver's reference.
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)

        def upd2(snes, it):
            return None

        refcnt2 = getrefcount(upd2)
        # Replacing the callback releases the old one and retains the new.
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        # ``tmp`` is one more reference on top of the solver's own.
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        ksp = self.snes.getKSP()
        # NOTE(review): presumably one reference is held by the SNES and
        # one by the returned wrapper - confirm.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        J = _new_jacobian_matrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # First read only exercises the getter after a solve; the values
        # are replaced by the post-reset read below.
        rh, ih = self.snes.getConvergenceHistory()
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        # The solver must remain usable across repeated reset() calls.
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # Rebinding is visible to ``monitor``: the closure looks the name
        # up at call time, so later calls would write into this new dict.
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        # A manual monitor() call must reach the re-registered callback.
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        # The maximum starts at 1, accepts a new value, and can be restored.
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        # The maximum starts at 1, accepts a new value, and can be restored.
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing DEFAULT is expected to leave the stored version at 1.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor'] = 'stdout'
        # opts['ksp_monitor'] = 'stdout'
        # opts['snes_view'] = 'stdout'
        J = _new_jacobian_matrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # The solve is skipped for the NEWTONTR solver type.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        J = _new_jacobian_matrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill and assemble J once by calling the callback directly.
        # NOTE(review): presumably needed so the nonzero pattern exists
        # before finite-difference coloring is enabled - confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        # One decimal place looser than the analytic-Jacobian solves.
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        # The NPC must report the appctx that was set on the outer solver.
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))


# --------------------------------------------------------------------


class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the NEWTONLS solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS


class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the NEWTONTR solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()