# --------------------------------------------------------------------

import unittest
from sys import getrefcount

from petsc4py import PETSc

# --------------------------------------------------------------------


def _create_jacobian_mat():
    """Return a set-up 2x2 ``SEQAIJ`` matrix created on ``PETSc.COMM_SELF``."""
    mat = PETSc.Mat().create(PETSc.COMM_SELF)
    mat.setSizes([2, 2])
    mat.setType(PETSc.Mat.Type.SEQAIJ)
    mat.setUp()
    return mat


class Function:
    """Residual callback for the two-equation nonlinear test system.

    Evaluates::

        f[0] = x[0]*x[0] + x[0]*x[1] - 3
        f[1] = x[0]*x[1] + x[1]*x[1] - 6

    Both components vanish at ``x = (1, 2)``, which is the solution (up to
    sign, since the tests compare absolute values) checked below.
    """

    def __call__(self, snes, x, f):
        # ``snes`` is part of the callback signature but is not used here.
        # ``.item()`` turns each computed entry into a plain scalar before it
        # is stored (presumably ``x[i]`` yields a NumPy scalar -- TODO confirm).
        f[0] = (x[0] * x[0] + x[0] * x[1] - 3.0).item()
        f[1] = (x[0] * x[1] + x[1] * x[1] - 6.0).item()
        f.assemble()


class Jacobian:
    """Jacobian callback matching :class:`Function`.

    Fills ``P`` with the partial derivatives of the residual::

        [[2*x[0] + x[1],  x[0]          ],
         [x[1],           x[0] + 2*x[1] ]]
    """

    def __call__(self, snes, x, J, P):
        # ``snes`` is part of the callback signature but is not used here.
        P[0, 0] = (2.0 * x[0] + x[1]).item()
        P[0, 1] = (x[0]).item()
        P[1, 0] = (x[1]).item()
        P[1, 1] = (x[0] + 2.0 * x[1]).item()
        P.assemble()
        # Only assemble ``J`` separately when it is a different matrix from
        # ``P``; otherwise it has just been assembled above.
        if J != P:
            J.assemble()


# --------------------------------------------------------------------


class BaseTestSNES:
    """Shared SNES tests; concrete subclasses choose ``SNES_TYPE``.

    This class is deliberately not a ``unittest.TestCase``: the subclasses
    at the bottom of the file mix it in with ``unittest.TestCase``.
    """

    # SNES type name applied in ``setUp``; overridden by each subclass.
    SNES_TYPE = None

    def setUp(self):
        # Fresh solver on COMM_SELF for every test.
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        # Drop our reference, then ask PETSc to run its garbage cleanup.
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        # The type reported must match SNES_TYPE before and after re-setting.
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        # getTolerances() must agree with the individual tolerance attributes.
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        snes = self.snes
        # application context round-trip
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        # iteration number
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        # function norm
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        # convergence history is empty on a solver that has not run
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # converged reason and the three derived status flags
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        # these options are all off on a freshly created solver
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must hold exactly one extra
        # reference to it.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        # Repeated getFunction() calls must not change the reference count.
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        # (1, 2) is a root of Function, so the residual must be ~0.
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = _create_jacobian_mat()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Setting the same callback twice must hold exactly one extra
        # reference to it.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        # Repeated getJacobian() calls must not change the reference count.
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        # Smoke test: computeJacobian() must run without raising.
        J = _create_jacobian_mat()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        # Tracks the reference count of the update callback through
        # set / re-set / replace / clear.
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Replacing the callback must release the previous one.
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        # ``tmp`` is one more reference on top of the one held by the solver.
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        ksp = self.snes.getKSP()
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        J = _create_jacobian_mat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # Exercise the getter right after the solve; the result is not
        # inspected.
        self.snes.getConvergenceHistory()
        # After a reset with length 0 the history must be empty.
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        # The solver must remain usable across repeated reset() calls.
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # Empty the dict the monitor writes into, so any entry seen after
        # monitorCancel() would prove the monitor is still being invoked.
        reshist.clear()
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        # A manual monitor() call must be forwarded to the callback.
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        # The maximum is expected to start at 1 and round-trip through the
        # setter.
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        # The maximum is expected to start at 1 and round-trip through the
        # setter.
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT is expected to leave the previously set
        # version (1) in place.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        # Disabled option setup, kept for reference:
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor'] = 'stdout'
        # opts['ksp_monitor'] = 'stdout'
        # opts['snes_view'] = 'stdout'
        J = _create_jacobian_mat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # The solve is skipped for the NEWTONTR solver type.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        J = _create_jacobian_mat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill and assemble J once by hand before enabling the
        # finite-difference option.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        # Checked to 4 places here, versus 5 in the other solve tests.
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        # The NPC returned by getNPC() is expected to report the same appctx
        # that was set on the outer solver.
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))


# --------------------------------------------------------------------


class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the ``NEWTONLS`` solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS


class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the ``NEWTONTR`` solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()