# --------------------------------------------------------------------
# Unit tests for the petsc4py SNES (nonlinear solver) wrapper.
#
# The shared tests live in BaseTestSNES and run once for each concrete
# SNES type subclass at the bottom of this module. The nonlinear system
# used throughout is
#
#     x0**2 + x0*x1 - 3 = 0
#     x0*x1 + x1**2 - 6 = 0
#
# It has the roots (1, 2) and (-1, -2). That is why the solve checks
# compare abs(x[i]) against 1.0 and 2.0.

from petsc4py import PETSc
import unittest
from sys import getrefcount
import numpy  # NOTE(review): not referenced below; confirm it is not needed before removing

# --------------------------------------------------------------------

class Function:
    """Residual callback for the 2x2 test system.

    Writes F(x) into the vector ``f`` and assembles it. The ``snes``
    argument is accepted for the callback signature but not used.
    """
    def __call__(self, snes, x, f):
        f[0] = (x[0]*x[0] + x[0]*x[1] - 3.0).item()
        f[1] = (x[0]*x[1] + x[1]*x[1] - 6.0).item()
        f.assemble()

class Jacobian:
    """Jacobian callback for the 2x2 test system.

    Fills the analytic Jacobian dF/dx into ``P`` and assembles it. When
    ``J`` is a different matrix from ``P``, ``J`` is only assembled, not
    filled. The ``snes`` argument is not used.
    """
    def __call__(self, snes, x, J, P):
        P[0,0] = (2.0*x[0] + x[1]).item()
        P[0,1] = (x[0]).item()
        P[1,0] = (x[1]).item()
        P[1,1] = (x[0] + 2.0*x[1]).item()
        P.assemble()
        if J != P: J.assemble()

# --------------------------------------------------------------------

class BaseTestSNES(object):
    """Shared SNES tests.

    This is deliberately not a unittest.TestCase, so it is not collected on
    its own. Concrete subclasses mix in unittest.TestCase and set SNES_TYPE
    to the solver type under test.
    """

    SNES_TYPE = None

    def setUp(self):
        # A fresh sequential SNES for every test, with the requested type if any.
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        # Drop the test's reference to the solver object.
        self.snes = None

    def testGetSetType(self):
        # The type set in setUp is reported back, and setting the same type
        # again leaves it unchanged.
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        # Round-trip the tolerances. The tuple from getTolerances() must match
        # the rtol/atol/stol/max_it properties, in that order.
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol','stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        # Exercise the Python-level properties exposed on SNES.
        snes = self.snes
        # appctx stores arbitrary Python objects, including None.
        snes.appctx = (1,2,3)
        self.assertEqual(snes.appctx, (1,2,3))
        snes.appctx = None
        self.assertEqual(snes.appctx, None)
        # The iteration counter is settable.
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        # The residual norm is settable.
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        # The convergence history of a fresh solver is empty.
        rh, ih = snes.history
        self.assertTrue(len(rh)==0)
        self.assertTrue(len(ih)==0)
        # Setting the reason drives the converged/diverged/iterating flags;
        # exactly one of them is true for each reason tried here.
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertTrue(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertTrue(snes.iterating)
        # Eisenstat-Walker, matrix-free and finite-difference modes are off
        # by default.
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        # Residual callback registration and Python reference counting.
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertTrue(func is None)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Registering the same callback twice must hold exactly one extra
        # reference, not two.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        # getFunction() returns the vector and a sequence whose first item is
        # the callback. Repeated calls must not change the refcount.
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        # (1, 2) is a root of the test system, so F(x) must vanish there.
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        # Jacobian callback registration and Python reference counting.
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertTrue(jac is None)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Registering the same callback twice must hold exactly one extra
        # reference.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        # Only J was passed, so the reported preconditioner matrix is
        # expected to equal J as well.
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        # Smoke test: computeJacobian at (1, 2) must run without error. The
        # resulting matrix values are not checked.
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # Result discarded. NOTE(review): presumably this forces the KSP/PC
        # to be created before computeJacobian; confirm.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        # Update callback registration, replacement, clearing and refcounts.
        self.assertTrue(self.snes.getUpdate() is None)
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        # Setting the same callback repeatedly holds one extra reference.
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Clearing with None releases that reference.
        self.snes.setUpdate(None)
        self.assertTrue(self.snes.getUpdate() is None)
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Replacing upd with upd2 releases upd and holds upd2.
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        # tmp is a second local reference, hence +2 while it is alive.
        tmp = self.snes.getUpdate()[0]
        self.assertTrue(tmp is upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertTrue(self.snes.getUpdate() is None)
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        # This is the PETSc-level object refcount, not Python's. It is
        # expected to be 2, presumably one reference held by the SNES and one
        # by this `ksp` wrapper.
        ksp = self.snes.getKSP()
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        # Full Newton solve of F(x) = b with b = 0, from the initial guess (2, 3).
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2,3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # The first history read only exercises the call; its result is
        # overwritten. After a reset the history must be empty.
        rh, ih = self.snes.getConvergenceHistory()
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        # XXX this test should not be here !
        # Call the convergence test directly with zero norms. It expects a
        # positive reason code, i.e. a CONVERGED_* value under PETSc's
        # SNESConvergedReason convention.
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertTrue(reason > 0)

        # test interface
        # Calling solve() with no arguments updates the SNES's own solution
        # vector in place, which is the vector `x` obtained here.
        x = self.snes.getSolution()
        x.setArray([2,3])
        self.snes.solve()
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        # reset() between solves must leave the SNES reusable.
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        # A Python monitor records (iteration -> residual norm) while it is
        # installed and stops recording after monitorCancel().
        reshist = {}
        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm
        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertTrue(len(reshist) > 0)
        # Rebinding works because `monitor` closes over the variable
        # `reshist`, not over the original dict object.
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertTrue(len(reshist) == 0)
        # Invoke the installed monitors directly. `monitor` stores 7 under key 1.
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertTrue(reshist[1] == 7)
        # Disabled checks for other monitor kinds, kept for reference:
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        # Nonlinear step-failure counters. The default maximum is 1 and it
        # round-trips through the setter.
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        # Linear-solve failure counters. The default maximum is 1 and it
        # round-trips through the setter.
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        # Eisenstat-Walker toggle and parameter round-trip.
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT must leave the previously set version (1) in place.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        # Matrix-free mode: the toggle round-trips and the solve still reaches
        # the root. The commented-out option setup below is kept for manual
        # debugging.
        #self.snes.setOptionsPrefix('MF-')
        #opts = PETSc.Options(self.snes)
        #opts['mat_mffd_type'] = 'ds'
        #opts['snes_monitor'] = 'stdout'
        #opts['ksp_monitor'] = 'stdout'
        #opts['snes_view'] = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        x.setArray([2,3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        # Finite-difference (coloring) Jacobian: the toggle round-trips and
        # the solve still reaches the root.
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2,2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Fill and assemble J once by hand at the still-unset x.
        # NOTE(review): presumably this gives J the nonzero pattern that
        # finite-difference coloring needs; confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2,3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

# --------------------------------------------------------------------

class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the NEWTONLS (line-search Newton) type."""
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS

class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the NEWTONTR (trust-region Newton) type."""
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()