# --------------------------------------------------------------------
"""Unit tests for the petsc4py ``PETSc.KSP`` wrapper.

``BaseTestKSP`` holds the shared test logic. Each concrete
``TestKSP<TYPE>`` class mixes it with ``unittest.TestCase`` and selects
the KSP type (and optionally a PC type) to exercise.
"""

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------


class BaseTestKSP(object):
    """Shared KSP tests; subclasses set KSP_TYPE and optionally PC_TYPE."""

    KSP_TYPE = None  # KSP type name applied in setUp (skipped when falsy)
    PC_TYPE = None   # PC type name applied in setUp (skipped when falsy)

    def setUp(self):
        """Create a KSP on COMM_SELF and apply KSP_TYPE / PC_TYPE if set."""
        ksp = PETSc.KSP()
        ksp.create(PETSc.COMM_SELF)
        if self.KSP_TYPE:
            ksp.setType(self.KSP_TYPE)
        if self.PC_TYPE:
            pc = ksp.getPC()
            pc.setType(self.PC_TYPE)
        self.ksp = ksp

    def tearDown(self):
        """Drop the reference to the KSP created in setUp."""
        self.ksp = None

    def testGetSetType(self):
        """getType() reports KSP_TYPE, also after setting it again."""
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)
        self.ksp.setType(self.KSP_TYPE)
        self.assertEqual(self.ksp.getType(), self.KSP_TYPE)

    def testTols(self):
        """Tolerances round-trip and match the rtol/atol/divtol/max_it attributes."""
        tols = self.ksp.getTolerances()
        self.ksp.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'divtol', 'max_it')
        tolvals = [getattr(self.ksp, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Property setters/getters round-trip; reason drives the status flags."""
        ksp = self.ksp
        #
        ksp.appctx = (1, 2, 3)
        self.assertEqual(ksp.appctx, (1, 2, 3))
        ksp.appctx = None
        self.assertIsNone(ksp.appctx)
        #
        side = ksp.pc_side
        ksp.pc_side = side
        self.assertEqual(ksp.pc_side, side)
        #
        nt = ksp.norm_type
        ksp.norm_type = nt
        self.assertEqual(ksp.norm_type, nt)
        #
        ksp.its = 1
        self.assertEqual(ksp.its, 1)
        ksp.its = 0
        self.assertEqual(ksp.its, 0)
        #
        ksp.norm = 1
        self.assertEqual(ksp.norm, 1)
        ksp.norm = 0
        self.assertEqual(ksp.norm, 0)
        #
        rh = ksp.history
        self.assertEqual(len(rh), 0)
        #
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITS
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertTrue(ksp.converged)
        self.assertFalse(ksp.diverged)
        self.assertFalse(ksp.iterating)
        reason = PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.converged)
        self.assertTrue(ksp.diverged)
        self.assertFalse(ksp.iterating)
        reason = PETSc.KSP.ConvergedReason.CONVERGED_ITERATING
        ksp.reason = reason
        self.assertEqual(ksp.reason, reason)
        self.assertFalse(ksp.converged)
        self.assertFalse(ksp.diverged)
        self.assertTrue(ksp.iterating)

    def testGetSetPC(self):
        """Replacing the KSP's PC updates reference counts as expected."""
        oldpc = self.ksp.getPC()
        self.assertEqual(oldpc.getRefCount(), 2)
        newpc = PETSc.PC()
        newpc.create(self.ksp.getComm())
        self.assertEqual(newpc.getRefCount(), 1)
        self.ksp.setPC(newpc)
        self.assertEqual(newpc.getRefCount(), 2)
        self.assertEqual(oldpc.getRefCount(), 1)
        oldpc.destroy()
        self.assertFalse(bool(oldpc))
        pc = self.ksp.getPC()
        self.assertTrue(bool(pc))
        self.assertEqual(pc, newpc)
        self.assertEqual(pc.getRefCount(), 3)
        newpc.destroy()
        self.assertFalse(bool(newpc))
        self.assertEqual(pc.getRefCount(), 2)

    def testSolve(self):
        """Solve a small 3x3 diagonal SEQAIJ system and exercise result APIs."""
        A = PETSc.Mat().create(PETSc.COMM_SELF)
        A.setSizes([3, 3])
        A.setType(PETSc.Mat.Type.SEQAIJ)
        A.setPreallocationNNZ(1)
        for i in range(3):
            A.setValue(i, i, 0.9/(i+1))
        A.assemble()
        A.shift(1)
        x, b = A.createVecs()
        b.set(10)
        x.setRandom()
        self.ksp.setOperators(A)
        self.ksp.setConvergenceHistory()
        self.ksp.solve(b, x)
        r = b.duplicate()
        u = x.duplicate()
        self.ksp.buildSolution(u)
        # Build the residual into its own vector (laid out like b) rather
        # than overwriting the solution copy u; r was otherwise unused.
        self.ksp.buildResidual(r)
        self.ksp.getConvergenceHistory()
        # Passing 0 disables history recording; the history is then empty.
        self.ksp.setConvergenceHistory(0)
        rh = self.ksp.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        del A, x, b

    def testResetAndSolve(self):
        """The KSP can be reset and reused for further solves."""
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()
        self.testSolve()
        self.ksp.reset()

    def testSetMonitor(self):
        """setMonitor holds one reference; monitorCancel releases it."""
        reshist = {}

        def monitor(ksp, its, rnorm):
            reshist[its] = rnorm

        refcnt = getrefcount(monitor)
        self.ksp.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        ## self.testSolve()
        # Rebinding reshist also rebinds the closure cell seen by monitor,
        # so anything recorded from here on lands in this fresh dict.
        reshist = {}
        self.ksp.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        # The monitor was cancelled, so the solve must not have recorded anything.
        self.assertEqual(len(reshist), 0)
        ## Monitor = PETSc.KSP.Monitor
        ## self.ksp.setMonitor(Monitor())
        ## self.ksp.setMonitor(Monitor.DEFAULT)
        ## self.ksp.setMonitor(Monitor.TRUE_RESIDUAL_NORM)
        ## self.ksp.setMonitor(Monitor.SOLUTION)

    def testSetConvergenceTest(self):
        """setConvergenceTest holds one reference; passing None releases it."""
        def converged(ksp, its, rnorm):
            if its > 10:
                return True
            return False

        refcnt = getrefcount(converged)
        self.ksp.setConvergenceTest(converged)
        self.assertEqual(getrefcount(converged), refcnt + 1)
        self.ksp.setConvergenceTest(None)
        self.assertEqual(getrefcount(converged), refcnt)

# --------------------------------------------------------------------


class TestKSPPREONLY(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.PREONLY
    PC_TYPE = PETSc.PC.Type.LU


class TestKSPRICHARDSON(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.RICHARDSON


class TestKSPCHEBYCHEV(BaseTestKSP, unittest.TestCase):
    # Prefer the CHEBYSHEV spelling; fall back to CHEBYCHEV when the
    # installed petsc4py only provides the older attribute name.
    try:
        KSP_TYPE = PETSc.KSP.Type.CHEBYSHEV
    except AttributeError:
        KSP_TYPE = PETSc.KSP.Type.CHEBYCHEV


class TestKSPCG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.CG


class TestKSPCGNE(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.CGNE


class TestKSPSTCG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.STCG


class TestKSPBCGS(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.BCGS


class TestKSPBCGSL(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.BCGSL


class TestKSPCGS(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.CGS


class TestKSPQCG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.QCG
    PC_TYPE = PETSc.PC.Type.JACOBI


class TestKSPBICG(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.BICG


class TestKSPGMRES(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.GMRES


class TestKSPFGMRES(BaseTestKSP, unittest.TestCase):
    KSP_TYPE = PETSc.KSP.Type.FGMRES

# --------------------------------------------------------------------

# NumPy dtype chars 'F', 'D' and 'G' are the complex types. STCG is not
# tested in complex-scalar builds (presumably unsupported there; confirm).
if PETSc.ScalarType().dtype.char in 'FDG':
    del TestKSPSTCG

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()