# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------


class Function:
    """Residual callback for a 2x2 nonlinear test system.

    Evaluates::

        f0(x) = x0**2 + x0*x1 - 3
        f1(x) = x0*x1 + x1**2 - 6

    which vanishes at ``x = (1, 2)``, the solution the tests check for.
    """

    def __call__(self, snes, x, f):
        # ``.item()`` turns the scalar expression built from ``x`` into a
        # plain Python number before it is stored into ``f``.
        f[0] = (x[0] * x[0] + x[0] * x[1] - 3.0).item()
        f[1] = (x[0] * x[1] + x[1] * x[1] - 6.0).item()
        f.assemble()


class Jacobian:
    """Analytic Jacobian callback matching :class:`Function`.

    Fills and assembles ``P``; ``J`` is assembled separately only when
    it is a different object from ``P``.
    """

    def __call__(self, snes, x, J, P):
        P[0, 0] = (2.0 * x[0] + x[1]).item()  # d f0 / d x0
        P[0, 1] = (x[0]).item()  # d f0 / d x1
        P[1, 0] = (x[1]).item()  # d f1 / d x0
        P[1, 1] = (x[0] + 2.0 * x[1]).item()  # d f1 / d x1
        P.assemble()
        if J != P:
            # A distinct operator (e.g. a matrix-free one) must still be
            # assembled even though its entries were not set here.
            J.assemble()


# --------------------------------------------------------------------


class BaseTestSNES:
    """SNES tests shared by every concrete solver type.

    Concrete subclasses mix this in with :class:`unittest.TestCase` and
    set ``SNES_TYPE`` to the solver type under test.
    """

    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        self.snes = None
        PETSc.garbage_cleanup()

    @staticmethod
    def _createJacobianMatrix():
        """Return a new, set-up 2x2 SEQAIJ matrix on ``COMM_SELF``."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        return J

    def testGetSetType(self):
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))
        dtol = self.snes.getDivergenceTolerance()
        self.assertGreater(dtol, 0)
        self.snes.setDivergenceTolerance(-1)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, -1)
        self.snes.setDivergenceTolerance(PETSc.DEFAULT)
        # NOTE(review): ``dtol`` is not re-read after passing PETSc.DEFAULT,
        # so this only re-checks the stale value from above. Whether
        # DEFAULT keeps the current tolerance or restores the built-in
        # one appears to depend on the PETSc version -- confirm before
        # re-reading and tightening this assertion.
        self.assertEqual(dtol, -1)

    def testProperties(self):
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        #
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must not leak a reference.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        # (1, 2) is the root of Function, so the residual should vanish.
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = self._createJacobianMatrix()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Setting the same callback twice must not leak a reference.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        J = self._createJacobianMatrix()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # Result discarded: presumably called only so that the inner
        # KSP/PC objects exist before the Jacobian is computed.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Replacing the callback must release the previous one.
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        ksp = self.snes.getKSP()
        # Presumably one reference held by the SNES, one by ``ksp``.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        J = self._createJacobianMatrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        rh, ih = self.snes.getConvergenceHistory()
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        reshist = {}

        def monitor(snes, its, fgnorm):
            # Looks up ``reshist`` from the enclosing scope at call time,
            # so rebinding it below redirects where entries are stored.
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT is expected to leave the version unchanged.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor'] = 'stdout'
        # opts['ksp_monitor'] = 'stdout'
        # opts['snes_view'] = 'stdout'
        J = self._createJacobianMatrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # The matrix-free solve is not exercised for NEWTONTR.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        J = self._createJacobianMatrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill and assemble J once; presumably this gives finite-difference
        # coloring a nonzero pattern to work from.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        # The nonlinear preconditioner is expected to share the appctx.
        self.assertEqual(npc.appctx, (1, 2, 3))


# --------------------------------------------------------------------


class TestSNESLS(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS


class TestSNESTR(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()