# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount
import numpy as np

# --------------------------------------------------------------------


class Function:
    """Residual callback for a 2x2 nonlinear test system.

    Evaluates::

        f0 = x0*x0 + x0*x1 - 3
        f1 = x0*x1 + x1*x1 - 6

    The system has the roots (1, 2) and (-1, -2), which is why the solve
    tests below compare ``abs(x[i])``.
    """

    def __call__(self, snes, x, f):
        f[0] = (x[0] * x[0] + x[0] * x[1] - 3.0).item()
        f[1] = (x[0] * x[1] + x[1] * x[1] - 6.0).item()
        f.assemble()


class Jacobian:
    """Analytic Jacobian callback matching :class:`Function`.

    Writes the four partial derivatives into ``P`` and assembles it. ``J`` is
    only assembled (not filled) when it compares unequal to ``P``.
    """

    def __call__(self, snes, x, J, P):
        P[0, 0] = (2.0 * x[0] + x[1]).item()
        P[0, 1] = (x[0]).item()
        P[1, 0] = (x[1]).item()
        P[1, 1] = (x[0] + 2.0 * x[1]).item()
        P.assemble()
        if J != P:
            J.assemble()


# --------------------------------------------------------------------


class BaseTestSNES:
    """Shared SNES test suite.

    Concrete subclasses mix in ``unittest.TestCase`` and set ``SNES_TYPE``
    to the solver type to exercise.
    """

    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))
        dtol = self.snes.getDivergenceTolerance()
        self.assertGreater(dtol, 0)
        self.snes.setDivergenceTolerance(-1)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, -1)
        # The test expects PETSc.DEFAULT to leave the current value (-1)
        # unchanged. Re-read it from the solver so the assertion actually
        # checks SNES state rather than the stale local from above.
        self.snes.setDivergenceTolerance(PETSc.DEFAULT)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, -1)

    def testProperties(self):
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        #
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must hold exactly one reference.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        # (1, 2) is a root of Function, so the residual must vanish.
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Setting the same callback twice must hold exactly one reference.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): result unused; presumably just makes sure the KSP/PC
        # objects exist before computeJacobian is called.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        # Replacing the callback must release the previous one.
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        # One reference held by the solver, one by ``tmp``.
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        ksp = self.snes.getKSP()
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)

        def _update(snes, it, cnt):
            # ``cnt`` is a 0-d numpy array, so ``+=`` mutates it in place
            # and the count is visible to the enclosing test.
            cnt += 1

        cnt_up = np.array(0)
        self.snes.setUpdate(_update, (cnt_up,))

        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        self.snes.setUpdate(None)
        rh, ih = self.snes.getConvergenceHistory()
        # Residual norms and linear-iteration counts are recorded in lockstep.
        self.assertEqual(len(rh), len(ih))
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # Function has roots (1, 2) and (-1, -2); accept either sign.
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        self.assertEqual(self.snes.getIterationNumber(), cnt_up.item())
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # Clear in place so the monitor closure keeps writing to this dict.
        reshist.clear()
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # PETSc.DEFAULT is expected to keep the previously set version.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor'] = 'stdout'
        # opts['ksp_monitor'] = 'stdout'
        # opts['snes_view'] = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # The matrix-free solve is only exercised for non-trust-region types.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Presumably fills J once so its nonzero pattern exists before
        # finite-difference coloring is enabled.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        # FD Jacobian is less accurate, hence the looser tolerance.
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        # The nonlinear preconditioner is expected to share the appctx.
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))


# --------------------------------------------------------------------


class TestSNESLS(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS


class TestSNESTR(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()