# --------------------------------------------------------------------
"""Unit tests for ``petsc4py.PETSc.KSP`` exercised over many Krylov solver types."""

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------


class BaseTestKSP:
    """Shared KSP test cases, parameterized by solver and preconditioner type.

    This class does not derive from ``unittest.TestCase``; the concrete
    ``TestKSP*`` classes below combine it with ``unittest.TestCase`` and set
    ``KSP_TYPE`` (and optionally ``PC_TYPE``).
    """

    KSP_TYPE = None  # KSP type name; a falsy value leaves the KSP type unset
    PC_TYPE = None  # PC type name; a falsy value leaves the PC type unset

    def setUp(self):
        # Create a KSP on COMM_SELF and apply the configured solver/PC types.
        ksp = PETSc.KSP()
        ksp.create(PETSc.COMM_SELF)
        if self.KSP_TYPE:
            ksp.setType(self.KSP_TYPE)
        if self.PC_TYPE:
            pc = ksp.getPC()
            pc.setType(self.PC_TYPE)
        self.ksp = ksp

    def tearDown(self):
        # Drop the Python handle, then ask PETSc to clean up collected objects.
        self.ksp = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        """The type set in setUp is reported back, and re-setting it is stable."""
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)
        self.ksp.setType(self.KSP_TYPE)
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)

    def testTols(self):
        """getTolerances() round-trips and matches the per-attribute properties."""
        tols = self.ksp.getTolerances()
        self.ksp.setTolerances(*tols)
        # Order must match the tuple returned by getTolerances().
        tnames = ('rtol', 'atol', 'divtol', 'max_it')
        tolvals = [getattr(self.ksp, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Python properties on KSP read back what was written to them."""
        ksp = self.ksp
        #
        ksp.appctx = (1, 2, 3)
        self.assertEqual(ksp.appctx, (1, 2, 3))
        ksp.appctx = None
        self.assertEqual(ksp.appctx, None)
        #
        side = ksp.pc_side
        ksp.pc_side = side
        self.assertEqual(ksp.pc_side, side)
        #
        nt = ksp.norm_type
        ksp.norm_type = nt
        self.assertEqual(ksp.norm_type, nt)
        #
        ksp.its = 1
        self.assertEqual(ksp.its, 1)
        ksp.its = 0
        self.assertEqual(ksp.its, 0)
        #
        ksp.norm = 1
        self.assertEqual(ksp.norm, 1)
        ksp.norm = 0
        self.assertEqual(ksp.norm, 0)
        #
        # No solve has run yet, so no residual history is recorded.
        rh = ksp.history
        self.assertEqual(len(rh), 0)
        #
        # Each converged-reason category maps to exactly one of the
        # is_converged / is_diverged / is_iterating flags.
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITS
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertTrue(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertTrue(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITERATING
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertTrue(ksp.is_iterating)

    def testGetSetPC(self):
        """setPC swaps the KSP's preconditioner and adjusts PETSc refcounts."""
        # The expected counts below are PETSc-level reference counts; they are
        # sensitive to how many Python handles exist, so statement order matters.
        oldpc = self.ksp.getPC()
        self.assertEqual(oldpc.getRefCount(), 2)
        newpc = PETSc.PC()
        newpc.create(self.ksp.getComm())
        self.assertEqual(newpc.getRefCount(), 1)
        # After setPC the KSP holds a reference to newpc and has released oldpc.
        self.ksp.setPC(newpc)
        self.assertEqual(newpc.getRefCount(), 2)
        self.assertEqual(oldpc.getRefCount(), 1)
        oldpc.destroy()
        self.assertFalse(bool(oldpc))
        # `pc` and `newpc` are two Python handles to the same PC object.
        pc = self.ksp.getPC()
        self.assertTrue(bool(pc))
        self.assertEqual(pc, newpc)
        self.assertEqual(pc.getRefCount(), 3)
        newpc.destroy()
        self.assertFalse(bool(newpc))
        self.assertEqual(pc.getRefCount(), 2)

    def testSolve(self):
        """Solve a small diagonal system and exercise solution/residual helpers."""
        # 3x3 diagonal SEQAIJ matrix with entries 0.9/(i+1), then shifted by 1.
        A = PETSc.Mat().create(PETSc.COMM_SELF)
        A.setSizes([3, 3])
        A.setType(PETSc.Mat.Type.SEQAIJ)
        A.setPreallocationNNZ(1)
        for i in range(3):
            A.setValue(i, i, 0.9 / (i + 1))
        A.assemble()
        A.shift(1)
        x, b = A.createVecs()
        b.set(10)
        x.setRandom()
        self.ksp.setOperators(A)
        self.ksp.setConvergenceHistory()
        self.ksp.solve(b, x)
        u = x.duplicate()
        self.ksp.buildSolution(u)
        self.ksp.buildResidual(u)
        # Exercise history retrieval while recording is enabled; the contents
        # depend on the solver type, so they are not asserted here.
        self.ksp.getConvergenceHistory()
        # Re-configuring with length 0 must leave an empty history.
        self.ksp.setConvergenceHistory(0)
        rh = self.ksp.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        del A, x, b

    def testResetAndSolve(self):
        """The KSP can be reset and reused for further solves."""
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()

    def testSetMonitor(self):
        """setMonitor keeps a reference to the callback; monitorCancel drops it."""
        reshist = {}

        def monitor(ksp, its, rnorm):
            # Reads the enclosing `reshist` variable at call time, so the
            # rebinding of `reshist` below is what this callback sees.
            reshist[its] = rnorm

        # getrefcount includes its own argument reference, so compare against
        # a baseline rather than an absolute value.
        refcnt = getrefcount(monitor)
        self.ksp.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        # NOTE: the checks below marked `##` are disabled in this test.
        ## self.testSolve()
        reshist = {}
        self.ksp.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        # With the monitor cancelled, a solve must not record anything.
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        ## Monitor = PETSc.KSP.Monitor
        ## self.ksp.setMonitor(Monitor())
        ## self.ksp.setMonitor(Monitor.DEFAULT)
        ## self.ksp.setMonitor(Monitor.TRUE_RESIDUAL_NORM)
        ## self.ksp.setMonitor(Monitor.SOLUTION)

    def testSetConvergenceTest(self):
        """setConvergenceTest holds the callback; passing None releases it."""
        def converged(ksp, its, rnorm):
            return its > 10

        refcnt = getrefcount(converged)
        self.ksp.setConvergenceTest(converged)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

    def testAddConvergenceTest(self):
        """addConvergenceTest works with prepend=True and prepend=False."""
        def converged(ksp, its, rnorm):
            return True

        refcnt = getrefcount(converged)
        self.ksp.addConvergenceTest(converged, prepend=True)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.testSolve()
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)
        self.testSolve()
        self.ksp.addConvergenceTest(converged, prepend=False)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.testSolve()
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

    def testSetPreSolveTest(self):
        """A pre-solve hook is invoked during solve and released on None."""
        # A dict is used so the nested function can mutate shared state.
        check = {'val': 0}

        def presolve(ksp, rhs, x):
            check['val'] = 1

        refcnt = getrefcount(presolve)
        self.ksp.setPreSolve(presolve)
        self.assertEqual(getrefcount(presolve), refcnt + 1)
        self.testSolve()
        self.assertEqual(check['val'], 1)
        self.ksp.setPreSolve(None)
        self.assertEqual(getrefcount(presolve), refcnt)

    def testSetPostSolveTest(self):
        """A post-solve hook is invoked during solve and released on None."""
        check = {'val': 0}

        def postsolve(ksp, rhs, x):
            check['val'] = 1

        refcnt = getrefcount(postsolve)
        self.ksp.setPostSolve(postsolve)
        self.assertEqual(getrefcount(postsolve), refcnt + 1)
        self.testSolve()
        self.assertEqual(check['val'], 1)
        self.ksp.setPostSolve(None)
        self.assertEqual(getrefcount(postsolve), refcnt)


# --------------------------------------------------------------------


class TestKSPPREONLY(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.PREONLY
    PC_TYPE = PETSc.PC.Type.LU


class TestKSPRICHARDSON(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.RICHARDSON


class TestKSPCHEBYCHEV(BaseTestKSP, unittest.TestCase):
    # Accept either spelling of the Chebyshev type name, depending on which
    # attribute this petsc4py build provides.
    try:
        KSP_TYPE = PETSc.KSP.Type.CHEBYSHEV
    except AttributeError:
        KSP_TYPE = PETSc.KSP.Type.CHEBYCHEV


class TestKSPCG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.CG


class TestKSPCGNE(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.CGNE


class TestKSPSTCG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.STCG


class TestKSPBCGS(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.BCGS


class TestKSPBCGSL(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.BCGSL


class TestKSPCGS(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.CGS


class TestKSPQCG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.QCG
    PC_TYPE = PETSc.PC.Type.JACOBI


class TestKSPBICG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.BICG


class TestKSPGMRES(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.GMRES


class TestKSPFGMRES(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.FGMRES


class TestKSPLSQR(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.LSQR


# --------------------------------------------------------------------

# 'F', 'D' and 'G' are the NumPy type characters of the complex dtypes; the
# STCG tests are skipped on complex-scalar builds (presumably because STCG is
# not supported there -- confirm against the PETSc KSPSTCG docs).
if PETSc.ScalarType().dtype.char in 'FDG':
    del TestKSPSTCG

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()