# --------------------------------------------------------------------
"""Unit tests for the petsc4py ``SNES`` and ``SNESLineSearch`` bindings."""

from petsc4py import PETSc
import unittest
from sys import getrefcount
import numpy as np

# --------------------------------------------------------------------


class Function:
    """Residual callback for the 2x2 nonlinear test system.

    F(x) = [x0*x0 + x0*x1 - 3, x0*x1 + x1*x1 - 6], whose roots are
    x = (1, 2) and x = (-1, -2); the tests therefore compare ``abs(x[i])``.
    """

    def __call__(self, snes, x, f):
        f[0] = (x[0] * x[0] + x[0] * x[1] - 3.0).item()
        f[1] = (x[0] * x[1] + x[1] * x[1] - 6.0).item()
        f.assemble()


class Jacobian:
    """Jacobian callback for :class:`Function`.

    Fills the preconditioning matrix ``P`` with dF/dx and assembles ``J``
    as well when it is a distinct object.
    """

    def __call__(self, snes, x, J, P):
        P[0, 0] = (2.0 * x[0] + x[1]).item()
        P[0, 1] = (x[0]).item()
        P[1, 0] = (x[1]).item()
        P[1, 1] = (x[0] + 2.0 * x[1]).item()
        P.assemble()
        if J != P:
            J.assemble()


class FunctionAL:
    """Residual callback for the arc-length (NEWTONAL) variant.

    Same system as :class:`Function`, but the constant terms are scaled by
    the solver's current load parameter.
    """

    def __call__(self, snes, x, f):
        load = snes.getNewtonALLoadParameter()
        f[0] = (x[0] * x[0] + x[0] * x[1] - 3.0 * load).item()
        f[1] = (x[0] * x[1] + x[1] * x[1] - 6.0 * load).item()
        f.assemble()


class FunctionALLoad:
    """Load-vector callback for NEWTONAL: fills ``f`` with (3, 6).

    These are the coefficients that :class:`FunctionAL` multiplies by the
    load parameter.
    """

    def __call__(self, snes, x, f):
        f[0] = 3.0
        f[1] = 6.0
        f.assemble()


# --------------------------------------------------------------------


class BaseTestSNES:
    """Tests shared by every SNES type.

    Concrete subclasses mix in ``unittest.TestCase`` and set ``SNES_TYPE``.
    """

    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """Round-trip the tolerances and the divergence tolerance."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))
        dtol = self.snes.getDivergenceTolerance()
        self.assertGreater(dtol, 0)
        self.snes.setDivergenceTolerance(PETSc.UNLIMITED)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, PETSc.UNLIMITED)
        # PETSc.CURRENT must leave the stored value untouched. Re-query the
        # solver so the assertion checks its state rather than the stale
        # local variable.
        self.snes.setDivergenceTolerance(PETSc.CURRENT)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, PETSc.UNLIMITED)

    def testProperties(self):
        """Exercise the Python-level property accessors of SNES."""
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertEqual(snes.appctx, None)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        #
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)
        ouse = snes.use_ksp
        self.assertEqual(ouse, snes.getUseKSP())
        snes.use_ksp = not ouse
        self.assertEqual(not ouse, snes.getUseKSP())
        snes.setUseKSP(ouse)
        self.assertEqual(ouse, snes.use_ksp)

    def testGetSetFunc(self):
        """setFunction stores exactly one reference, however often it is called."""
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """The residual vanishes at the known root x = (1, 2)."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        """setJacobian stores exactly one reference and uses J as P."""
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # Result intentionally unused; presumably this only makes sure the
        # KSP/PC objects exist before the Jacobian is computed.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        """setUpdate holds one reference and releases it on reset/replace."""
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        ksp = self.snes.getKSP()
        # Presumably one reference is held by the SNES and one by ``ksp``.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Solve the test system and check the solution and update counter."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        if self.snes.getType() == PETSc.SNES.Type.NEWTONAL:
            self.snes.setFunction(FunctionAL(), r)
            self.snes.setNewtonALCorrectionType(
                PETSc.SNES.NewtonALCorrectionType.EXACT)
            self.snes.setNewtonALFunction(FunctionALLoad())
        else:
            self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)

        def _update(snes, it, cnt):
            # ``cnt`` is a 0-d numpy array: ``+=`` mutates it in place, so
            # the running count is visible through ``cnt_up`` below.
            cnt += 1

        cnt_up = np.array(0)
        self.snes.setUpdate(_update, (cnt_up,))

        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        self.snes.setUpdate(None)
        rh, ih = self.snes.getConvergenceHistory()
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        if self.snes.getType() != PETSc.SNES.Type.NEWTONAL:
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        self.assertEqual(self.snes.getIterationNumber(), cnt_up)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        if self.snes.getType() != PETSc.SNES.Type.NEWTONAL:
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """Monitors are invoked during solve and released by monitorCancel."""
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # Rebinding is seen by ``monitor`` because it closes over the name.
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Toggle Eisenstat-Walker and round-trip its parameters."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # PETSc.CURRENT must keep the previously set value.
        params['version'] = PETSc.CURRENT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """Matrix-free Jacobian toggle; full solve only for NEWTONLS."""
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor'] = 'stdout'
        # opts['ksp_monitor'] = 'stdout'
        # opts['snes_view'] = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        if self.snes.getType() == PETSc.SNES.Type.NEWTONLS:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        """Finite-difference (coloring) Jacobian toggle and solve."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Evaluate once so J has its nonzero pattern assembled.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        if self.snes.getType() != PETSc.SNES.Type.NEWTONAL:
            self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        """The nonlinear preconditioner sees the outer solver's appctx."""
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))

    def testTRAPI(self):
        """Trust-region setters; values are only checked for NEWTONTR.

        Calling the setters with no arguments keeps the current values, and
        PETSc.DETERMINE restores the defaults captured at the start.
        """
        newreg = (1, 2, 3)
        newup = (1, 2, 3, 4, 5)
        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
            defreg = self.snes.getTRTolerances()
            defup = self.snes.getTRUpdateParameters()
        self.snes.setTRTolerances(*newreg)
        self.snes.setTRUpdateParameters(*newup)
        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
            self.assertEqual(newreg, self.snes.getTRTolerances())
            self.assertEqual(newup, self.snes.getTRUpdateParameters())
        self.snes.setTRTolerances()
        self.snes.setTRUpdateParameters()
        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
            self.assertEqual(newreg, self.snes.getTRTolerances())
            self.assertEqual(newup, self.snes.getTRUpdateParameters())
        self.snes.setTRTolerances(*(PETSc.DETERMINE,) * 3)
        self.snes.setTRUpdateParameters(*(PETSc.DETERMINE,) * 5)
        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
            self.assertEqual(defreg, self.snes.getTRTolerances())
            self.assertEqual(defup, self.snes.getTRUpdateParameters())

# --------------------------------------------------------------------


class TestSNESLS(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS


class TestSNESTR(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR


class TestSNESAL(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONAL


# --------------------------------------------------------------------


class TestSNESLineSearchAPI(unittest.TestCase):
    """Tests for the standalone SNESLineSearch object API."""

    def test_create_destroy(self):
        ls = PETSc.SNESLineSearch()
        ls.create()
        ls.destroy()

    def test_type_set_get(self):
        ls = PETSc.SNESLineSearch().create()
        ls.setType(PETSc.SNESLineSearch.Type.BASIC)
        typ = ls.getType()
        self.assertEqual(typ, 'basic')
        ls.destroy()

    def test_tolerances_set_get(self):
        ls = PETSc.SNESLineSearch().create()
        ls.setTolerances(rtol=0.125, atol=3, minstep=4, ltol=5, maxstep=6, max_its=7)
        # getTolerances returns its values in a different order than the
        # keyword arguments above.
        minstep, maxstep, rtol, atol, ltol, max_its = ls.getTolerances()
        self.assertEqual(rtol, 0.125)
        self.assertEqual(atol, 3)
        self.assertEqual(minstep, 4)
        self.assertEqual(ltol, 5)
        self.assertEqual(maxstep, 6)
        self.assertEqual(max_its, 7)
        ls.destroy()

    def test_order_set_get(self):
        ls = PETSc.SNESLineSearch().create()
        ls.setOrder(2)
        order = ls.getOrder()
        self.assertEqual(order, 2)
        ls.destroy()

    def test_set_from_options(self):
        ls = PETSc.SNESLineSearch().create()
        ls.setFromOptions()
        # ls.view()
        ls.destroy()

    def test_snes_linesearch_property(self):
        snes = PETSc.SNES().create()
        ls = snes.getLineSearch()
        self.assertIsInstance(ls, PETSc.SNESLineSearch)
        # Set/get via property
        self.assertEqual(snes.linesearch, ls)
        snes.linesearch = ls
        self.assertEqual(snes.linesearch, ls)
        snes.destroy()


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()