# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount
import numpy as np

# --------------------------------------------------------------------


class Function:
    """Residual callback for the 2x2 nonlinear test system.

    Evaluates::

        F0(x) = x0*x0 + x0*x1 - 3
        F1(x) = x0*x1 + x1*x1 - 6

    Both x = (1, 2) and x = (-1, -2) are roots, which is why the tests
    compare ``abs(x[i])`` against the expected solution.
    """

    def __call__(self, snes, x, f):
        # ``.item()`` converts the indexed value to a plain Python scalar
        # before it is stored into the residual vector.
        f[0] = (x[0] * x[0] + x[0] * x[1] - 3.0).item()
        f[1] = (x[0] * x[1] + x[1] * x[1] - 6.0).item()
        f.assemble()


class Jacobian:
    """Analytic Jacobian callback matching :class:`Function`.

    Fills ``P`` with::

        [[2*x0 + x1, x0       ],
         [x1,        x0 + 2*x1]]

    and also assembles ``J`` when it compares unequal to ``P``.
    """

    def __call__(self, snes, x, J, P):
        P[0, 0] = (2.0 * x[0] + x[1]).item()
        P[0, 1] = (x[0]).item()
        P[1, 0] = (x[1]).item()
        P[1, 1] = (x[0] + 2.0 * x[1]).item()
        P.assemble()
        if J != P:
            J.assemble()


def _createJacobianMatrix():
    """Return a set-up (unassembled) 2x2 SEQAIJ matrix on COMM_SELF."""
    J = PETSc.Mat().create(PETSc.COMM_SELF)
    J.setSizes([2, 2])
    J.setType(PETSc.Mat.Type.SEQAIJ)
    J.setUp()
    return J


# --------------------------------------------------------------------


class BaseTestSNES:
    """SNES tests shared by every solver type.

    Concrete subclasses mix in ``unittest.TestCase`` and set
    ``SNES_TYPE``; a fresh SNES of that type is built for each test.
    """

    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        """getType reports SNES_TYPE, also after setting it again."""
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """Tolerances round-trip; divergence tolerance honours UNLIMITED/CURRENT."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))
        dtol = self.snes.getDivergenceTolerance()
        self.assertTrue(dtol > 0)
        self.snes.setDivergenceTolerance(PETSc.UNLIMITED)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, PETSc.UNLIMITED)
        # PETSc.CURRENT must leave the previously set value untouched.
        # Re-query it; asserting on the old ``dtol`` would check nothing.
        self.snes.setDivergenceTolerance(PETSc.CURRENT)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, PETSc.UNLIMITED)

    def testProperties(self):
        """Python properties mirror the underlying getters/setters."""
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertEqual(snes.appctx, None)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertTrue(len(rh) == 0)
        self.assertTrue(len(ih) == 0)
        #
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)
        ouse = snes.use_ksp
        self.assertEqual(ouse, snes.getUseKSP())
        snes.use_ksp = not ouse
        self.assertEqual(not ouse, snes.getUseKSP())
        snes.setUseKSP(ouse)
        self.assertEqual(ouse, snes.use_ksp)

    def testGetSetFunc(self):
        """The residual callback is stored once and returned by getFunction."""
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertTrue(func is None)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must still hold exactly one extra
        # reference (no leak from re-registration).
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """computeFunction evaluates the residual; (1, 2) is a root."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        """The Jacobian callback is stored once and returned by getJacobian."""
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertTrue(jac is None)
        J = _createJacobianMatrix()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # As with setFunction: repeated registration keeps one reference.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        """computeJacobian runs the registered Jacobian callback."""
        J = _createJacobianMatrix()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # Touch the KSP/PC before computing the Jacobian (presumably to make
        # sure they exist at that point; kept from the original test).
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        """setUpdate holds one reference and releases it when replaced/cleared."""
        self.assertTrue(self.snes.getUpdate() is None)
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertTrue(self.snes.getUpdate() is None)
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertTrue(tmp is upd2)
        # One reference held by the SNES, one by ``tmp``.
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertTrue(self.snes.getUpdate() is None)
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        """getKSP returns an object with PETSc reference count 2."""
        ksp = self.snes.getKSP()
        # Presumably one reference owned by the SNES and one by ``ksp``.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Newton solve from (2, 3) reaches |x| = (1, 2); update hook counts steps."""
        J = _createJacobianMatrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)

        def _update(snes, it, cnt):
            # ``cnt`` is a 0-d NumPy array, so ``+=`` mutates it in place
            # and the count is visible to the caller after the solve.
            cnt += 1

        cnt_up = np.array(0)
        self.snes.setUpdate(_update, (cnt_up,))

        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        self.snes.setUpdate(None)
        # History was enabled before the solve, so something was recorded.
        rh, ih = self.snes.getConvergenceHistory()
        self.assertTrue(len(rh) > 0)
        self.assertEqual(len(rh), len(ih))
        # A zero-length, reset history must report nothing.
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        self.assertEqual(self.snes.getIterationNumber(), int(cnt_up))
        # XXX: this convergence-test check belongs in a test of its own.
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertTrue(reason > 0)

        # Solve again through the solution vector owned by the SNES.
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        """The SNES can be reset and reused for repeated solves."""
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """A Python monitor sees iterations and is released by monitorCancel."""
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertTrue(len(reshist) > 0)
        # Rebinding is fine: ``monitor`` looks up ``reshist`` in this scope
        # at call time, so it writes into the new dict from here on.
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertTrue(len(reshist) == 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertTrue(reshist[1] == 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        """Nonlinear step-failure counters start at 0; the maximum is settable."""
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        """Linear-solve failure counters start at 0; the maximum is settable."""
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Eisenstat-Walker toggle and parameters round-trip; CURRENT keeps values."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        params['version'] = PETSc.CURRENT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """Matrix-free flag toggles; solve converges (checked for non-TR types)."""
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor'] = 'stdout'
        # opts['ksp_monitor'] = 'stdout'
        # opts['snes_view'] = 'stdout'
        J = _createJacobianMatrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        """Finite-difference coloring flag toggles and the solve converges."""
        J = _createJacobianMatrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Evaluate the Jacobian once by hand so J is assembled with entries
        # (presumably to give the coloring a nonzero pattern to work from).
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        """The nonlinear preconditioner inherits the outer appctx."""
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))

    def testTRAPI(self):
        """Trust-region setters round-trip; values are only checked for NEWTONTR.

        The setters are called for every SNES type; for non-TR types only
        the absence of errors is exercised.
        """
        newreg = (1, 2, 3)
        newup = (1, 2, 3, 4, 5)
        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
            defreg = self.snes.getTRTolerances()
            defup = self.snes.getTRUpdateParameters()
        self.snes.setTRTolerances(*newreg)
        self.snes.setTRUpdateParameters(*newup)
        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
            self.assertEqual(newreg, self.snes.getTRTolerances())
            self.assertEqual(newup, self.snes.getTRUpdateParameters())
        # Calling without arguments must leave the values unchanged.
        self.snes.setTRTolerances()
        self.snes.setTRUpdateParameters()
        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
            self.assertEqual(newreg, self.snes.getTRTolerances())
            self.assertEqual(newup, self.snes.getTRUpdateParameters())
        # DETERMINE restores the defaults captured at the start.
        self.snes.setTRTolerances(*(PETSc.DETERMINE,) * 3)
        self.snes.setTRUpdateParameters(*(PETSc.DETERMINE,) * 5)
        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
            self.assertEqual(defreg, self.snes.getTRTolerances())
            self.assertEqual(defup, self.snes.getTRUpdateParameters())


# --------------------------------------------------------------------


class TestSNESLS(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS


class TestSNESTR(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()