# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount
import numpy

# --------------------------------------------------------------------

class Function:
    """Residual callback for a 2x2 nonlinear test system.

    F(x) = [x0*x0 + x0*x1 - 3,
            x0*x1 + x1*x1 - 6]

    F vanishes at x = (1, 2) and at x = (-1, -2); the solver tests below
    therefore compare absolute values of the computed solution.
    """

    def __call__(self, snes, x, f):
        # The entries read from ``x`` are presumably NumPy scalars; .item()
        # stores plain Python numbers into ``f``.
        f[0] = (x[0]*x[0] + x[0]*x[1] - 3.0).item()
        f[1] = (x[0]*x[1] + x[1]*x[1] - 6.0).item()
        f.assemble()

class Jacobian:
    """Analytic Jacobian callback for :class:`Function`.

    dF/dx = [[2*x0 + x1, x0       ],
             [x1,        x0 + 2*x1]]

    Entries are written into the preconditioning matrix ``P``; the operator
    ``J`` is assembled separately only when it is a different object.
    """

    def __call__(self, snes, x, J, P):
        P[0,0] = (2.0*x[0] + x[1]).item()
        P[0,1] = (x[0]).item()
        P[1,0] = (x[1]).item()
        P[1,1] = (x[0] + 2.0*x[1]).item()
        P.assemble()
        if J != P: J.assemble()

# --------------------------------------------------------------------

class BaseTestSNES:
    """SNES tests shared by every solver type.

    Not a TestCase itself: concrete subclasses mix in ``unittest.TestCase``
    and set ``SNES_TYPE`` to the solver type under test.
    """

    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        # Round-trip the tolerances and check the matching properties agree.
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol','stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        snes = self.snes
        #
        snes.appctx = (1,2,3)
        self.assertEqual(snes.appctx, (1,2,3))
        snes.appctx = None
        self.assertEqual(snes.appctx, None)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # Each reason must map onto exactly one of converged/diverged/iterating.
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertTrue(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertTrue(snes.iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must not leak a second reference.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        # (1, 2) is a root of Function, so the residual must vanish.
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Setting the same callback twice must not leak a second reference.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): presumably forces creation of the KSP/PC before the
        # Jacobian is computed — confirm whether it is still required.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Clearing the update callback must drop the reference held by SNES.
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Replacing the callback releases the old one and retains the new one.
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        ksp = self.snes.getKSP()
        # Presumably one reference held by the SNES and one by ``ksp``.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2,3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # History was enabled before the solve, so it must hold entries now
        # (this result used to be read and then discarded unchecked).
        rh, ih = self.snes.getConvergenceHistory()
        self.assertGreater(len(rh), 0)
        # Re-arming the history with zero length empties it.
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2,3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        reshist = {}
        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm
        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # Empty the dict in place so it is obvious that the monitor closure
        # keeps writing into the same object.
        reshist.clear()
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT must leave the previously set version intact.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        # Debugging aid: uncomment to route options through a prefix and
        # get monitor/view output for the matrix-free solve.
        #self.snes.setOptionsPrefix('MF-')
        #opts = PETSc.Options(self.snes)
        #opts['mat_mffd_type'] = 'ds'
        #opts['snes_monitor'] = 'stdout'
        #opts['ksp_monitor'] = 'stdout'
        #opts['snes_view'] = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        # The matrix-free solve is not exercised for the trust-region solver.
        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
            x.setArray([2,3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill J once with the analytic entries, presumably so it carries a
        # nonzero pattern the finite-difference coloring can use.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2,3])
        b.set(0)
        self.snes.solve(b, x)
        # Looser tolerance than the analytic-Jacobian solves.
        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        # The nonlinear preconditioner inherits the outer solver's appctx.
        self.snes.appctx = (1,2,3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1,2,3))

# --------------------------------------------------------------------

class TestSNESLS(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS

class TestSNESTR(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()