# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------

class BaseTestKSP(object):
    """Shared KSP test cases, parameterized by class attributes.

    Concrete subclasses mix this class in together with
    ``unittest.TestCase`` and set ``KSP_TYPE`` (and optionally
    ``PC_TYPE``) to select the solver/preconditioner under test.
    """

    # KSP type applied via ``ksp.setType`` in ``setUp``; skipped when falsy.
    KSP_TYPE = None
    # PC type applied to the KSP's preconditioner in ``setUp``; skipped
    # when falsy.
    PC_TYPE = None

    def setUp(self):
        """Create a sequential KSP configured with KSP_TYPE / PC_TYPE."""
        ksp = PETSc.KSP()
        ksp.create(PETSc.COMM_SELF)
        if self.KSP_TYPE:
            ksp.setType(self.KSP_TYPE)
        if self.PC_TYPE:
            pc = ksp.getPC()
            pc.setType(self.PC_TYPE)
        self.ksp = ksp

    def tearDown(self):
        """Drop the KSP and let PETSc collect deferred garbage."""
        self.ksp = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        """The configured type is reported back and survives re-setting."""
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)
        self.ksp.setType(self.KSP_TYPE)
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)

    def testTols(self):
        """Tolerances round-trip and match the per-attribute accessors."""
        tols = self.ksp.getTolerances()
        self.ksp.setTolerances(*tols)
        # Attribute names in the same order as the getTolerances() tuple.
        tnames = ('rtol', 'atol', 'divtol', 'max_it')
        tolvals = [getattr(self.ksp, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Python-level properties read back what was assigned."""
        ksp = self.ksp
        #
        ksp.appctx = (1, 2, 3)
        self.assertEqual(ksp.appctx, (1, 2, 3))
        ksp.appctx = None
        self.assertEqual(ksp.appctx, None)
        #
        side = ksp.pc_side
        ksp.pc_side = side
        self.assertEqual(ksp.pc_side, side)
        #
        nt = ksp.norm_type
        ksp.norm_type = nt
        self.assertEqual(ksp.norm_type, nt)
        #
        ksp.its = 1
        self.assertEqual(ksp.its, 1)
        ksp.its = 0
        self.assertEqual(ksp.its, 0)
        #
        ksp.norm = 1
        self.assertEqual(ksp.norm, 1)
        ksp.norm = 0
        self.assertEqual(ksp.norm, 0)
        #
        # No solve has run yet, so the residual history is empty.
        rh = ksp.history
        self.assertEqual(len(rh), 0)
        #
        # Each converged reason must map to exactly one of the
        # converged / diverged / iterating predicates.
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITS
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertTrue(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertTrue(ksp.is_diverged)
        self.assertFalse(ksp.is_iterating)
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITERATING
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.is_converged)
        self.assertFalse(ksp.is_diverged)
        self.assertTrue(ksp.is_iterating)

    def testGetSetPC(self):
        """Replacing the PC transfers references as expected."""
        oldpc = self.ksp.getPC()
        self.assertEqual(oldpc.getRefCount(), 2)
        newpc = PETSc.PC()
        newpc.create(self.ksp.getComm())
        self.assertEqual(newpc.getRefCount(), 1)
        # setPC takes a reference to the new PC and releases the old one.
        self.ksp.setPC(newpc)
        self.assertEqual(newpc.getRefCount(), 2)
        self.assertEqual(oldpc.getRefCount(), 1)
        oldpc.destroy()
        self.assertFalse(bool(oldpc))
        pc = self.ksp.getPC()
        self.assertTrue(bool(pc))
        self.assertEqual(pc, newpc)
        self.assertEqual(pc.getRefCount(), 3)
        newpc.destroy()
        self.assertFalse(bool(newpc))
        self.assertEqual(pc.getRefCount(), 2)

    def testSolve(self):
        """Solve a small diagonal system and build solution/residual."""
        # 3x3 diagonal matrix with entries 0.9/(i+1), then shifted by 1.
        A = PETSc.Mat().create(PETSc.COMM_SELF)
        A.setSizes([3, 3])
        A.setType(PETSc.Mat.Type.SEQAIJ)
        A.setPreallocationNNZ(1)
        for i in range(3):
            A.setValue(i, i, 0.9/(i+1))
        A.assemble()
        A.shift(1)
        x, b = A.createVecs()
        b.set(10)
        x.setRandom()
        self.ksp.setOperators(A)
        self.ksp.setConvergenceHistory()
        self.ksp.solve(b, x)
        # Build the solution and the residual into separate work vectors:
        # the residual is laid out like b, the solution like x. (Building
        # the residual into u would overwrite the solution just built.)
        r = b.duplicate()
        u = x.duplicate()
        self.ksp.buildSolution(u)
        self.ksp.buildResidual(r)
        self.ksp.getConvergenceHistory()
        # After setConvergenceHistory(0) the reported history is empty.
        self.ksp.setConvergenceHistory(0)
        rh = self.ksp.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        del A, x, b

    def testResetAndSolve(self):
        """The KSP can be reset and reused for repeated solves."""
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()

    def testSetMonitor(self):
        """A Python monitor is referenced while set and released on cancel."""
        reshist = {}
        def monitor(ksp, its, rnorm):
            # Writes into whatever dict ``reshist`` is bound to at call time.
            reshist[its] = rnorm
        refcnt = getrefcount(monitor)
        self.ksp.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        reshist = {}
        self.ksp.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        # With the monitor cancelled, a solve must not record anything.
        self.testSolve()
        self.assertEqual(len(reshist), 0)

    def testSetConvergenceTest(self):
        """A Python convergence test is referenced while set, released on None."""
        def converged(ksp, its, rnorm):
            return its > 10
        refcnt = getrefcount(converged)
        self.ksp.setConvergenceTest(converged)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

# --------------------------------------------------------------------

class TestKSPPREONLY(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.PREONLY
    PC_TYPE = PETSc.PC.Type.LU

class TestKSPRICHARDSON(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.RICHARDSON

class TestKSPCHEBYCHEV(BaseTestKSP, unittest.TestCase):
    # Fall back to the CHEBYCHEV spelling when CHEBYSHEV is not defined
    # (presumably an older petsc4py release).
    try:
        KSP_TYPE = PETSc.KSP.Type.CHEBYSHEV
    except AttributeError:
        KSP_TYPE = PETSc.KSP.Type.CHEBYCHEV

class TestKSPCG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.CG

class TestKSPCGNE(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.CGNE

class TestKSPSTCG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.STCG

class TestKSPBCGS(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.BCGS

class TestKSPBCGSL(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.BCGSL

class TestKSPCGS(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.CGS

class TestKSPQCG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.QCG
    PC_TYPE = PETSc.PC.Type.JACOBI

class TestKSPBICG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.BICG

class TestKSPGMRES(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.GMRES

class TestKSPFGMRES(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.FGMRES

# --------------------------------------------------------------------

# 'F', 'D', 'G' are the NumPy type codes of the complex dtypes; STCG is
# not exercised when PETSc is built with complex scalars.
if PETSc.ScalarType().dtype.char in 'FDG':
    del TestKSPSTCG

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()