# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount
import numpy as np

# --------------------------------------------------------------------


class Function:
    """Residual callback for the 2x2 nonlinear test system.

    Evaluates::

        f0 = x0*x0 + x0*x1 - 3
        f1 = x0*x1 + x1*x1 - 6

    which vanishes at ``(x0, x1) = (1, 2)``.
    """

    def __call__(self, snes, x, f):
        """Fill ``f`` with the residual at ``x`` and assemble it."""
        f[0] = (x[0] * x[0] + x[0] * x[1] - 3.0).item()
        f[1] = (x[0] * x[1] + x[1] * x[1] - 6.0).item()
        f.assemble()


class Jacobian:
    """Analytic Jacobian callback matching :class:`Function`."""

    def __call__(self, snes, x, J, P):
        """Fill ``P`` with the partial derivatives at ``x`` and assemble.

        ``J`` is additionally assembled when it compares unequal to ``P``.
        """
        P[0, 0] = (2.0 * x[0] + x[1]).item()
        P[0, 1] = (x[0]).item()
        P[1, 0] = (x[1]).item()
        P[1, 1] = (x[0] + 2.0 * x[1]).item()
        P.assemble()
        if J != P:
            J.assemble()


def _create_jacobian_matrix():
    """Return a set-up 2x2 ``SEQAIJ`` matrix created on ``COMM_SELF``."""
    mat = PETSc.Mat().create(PETSc.COMM_SELF)
    mat.setSizes([2, 2])
    mat.setType(PETSc.Mat.Type.SEQAIJ)
    mat.setUp()
    return mat


# --------------------------------------------------------------------


class BaseTestSNES:
    """Shared SNES test cases, parameterized by ``SNES_TYPE``.

    This class does not derive from ``unittest.TestCase`` itself; the
    concrete classes at the bottom of the file combine it with
    ``unittest.TestCase`` and set ``SNES_TYPE``.
    """

    SNES_TYPE = None

    def setUp(self):
        """Create a fresh SNES on ``COMM_SELF`` of the configured type."""
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        """Drop the solver reference and run PETSc garbage cleanup."""
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        """The type round-trips through ``getType``/``setType``."""
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """``getTolerances`` agrees with the tolerance attributes."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Exercise the property-style accessors of the solver."""
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        #
        # Each row: reason to set, then the expected truthiness of
        # (is_converged, is_diverged, is_iterating).
        Reason = PETSc.SNES.ConvergedReason
        expectations = (
            (Reason.CONVERGED_ITS, True, False, False),
            (Reason.DIVERGED_MAX_IT, False, True, False),
            (Reason.CONVERGED_ITERATING, False, False, True),
        )
        for reason, converged, diverged, iterating in expectations:
            snes.reason = reason
            self.assertEqual(snes.reason, reason)
            self.assertEqual(bool(snes.is_converged), converged)
            self.assertEqual(bool(snes.is_diverged), diverged)
            self.assertEqual(bool(snes.is_iterating), iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        """``setFunction`` holds exactly one extra reference to the callback."""
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must not accumulate references.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """``computeFunction`` yields a zero residual at the root (1, 2)."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        """``setJacobian`` holds exactly one extra reference to the callback."""
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = _create_jacobian_matrix()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Setting the same callback twice must not accumulate references.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        """``computeJacobian`` runs with a user-supplied callback."""
        J = _create_jacobian_matrix()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # The KSP/PC are fetched before computing the Jacobian; the result
        # is intentionally unused.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        """``setUpdate`` acquires and releases callback references correctly."""
        self.assertIsNone(self.snes.getUpdate())

        def upd(snes, it):
            return None

        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)

        def upd2(snes, it):
            return None

        refcnt2 = getrefcount(upd2)
        # Replacing the callback must release the previous one.
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        # One reference held by the solver plus one by ``tmp``.
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        """The KSP returned by ``getKSP`` reports a reference count of 2."""
        ksp = self.snes.getKSP()
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Solve the 2x2 system from the initial guess (2, 3)."""
        J = _create_jacobian_matrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)

        def _update(snes, it, cnt):
            cnt += 1

        # A 0-d NumPy array acts as a mutable counter: ``cnt += 1`` inside
        # the callback updates it in place.
        cnt_up = np.array(0)
        self.snes.setUpdate(_update, (cnt_up,))

        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        self.snes.setUpdate(None)
        rh, ih = self.snes.getConvergenceHistory()
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        self.assertEqual(self.snes.getIterationNumber(), cnt_up)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        """The solver remains usable across repeated ``reset`` calls."""
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """Monitors are called during solves and released on cancel."""
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        reshist.clear()
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        # With the monitor cancelled, a solve must not record anything.
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        """The max nonlinear step failure count round-trips (initially 1)."""
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        """The max linear solve failure count round-trips (initially 1)."""
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Toggle the EW flag and round-trip its parameter dictionary."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT is expected to leave the value untouched.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """Toggle the MF flag and solve with it enabled."""
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor'] = 'stdout'
        # opts['ksp_monitor'] = 'stdout'
        # opts['snes_view'] = 'stdout'
        J = _create_jacobian_matrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # The solve is skipped for the NEWTONTR solver type.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        """Toggle the FD flag and solve with it enabled."""
        J = _create_jacobian_matrix()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill and assemble J once by hand before enabling FD.
        # NOTE(review): presumably this establishes the nonzero pattern
        # needed by the finite-difference coloring -- confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        # Looser tolerance (4 places) than the analytic-Jacobian solves.
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        """The solver returned by ``getNPC`` exposes the same ``appctx``."""
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))


# --------------------------------------------------------------------


class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the ``NEWTONLS`` solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS


class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the ``NEWTONTR`` solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()