# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount

from test_ksp import BaseTestKSP

# --------------------------------------------------------------------


class MyKSP(object):
    """Base Python context for a ``KSPPYTHON`` solver.

    Holds a list of work vectors obtained from the KSP in :meth:`setUp`
    and provides :meth:`loop`, the per-iteration residual bookkeeping and
    convergence check shared by the concrete solvers below.
    """

    def __init__(self):
        # Also initialised here so that releasing the work vectors is safe
        # even if create() has not been called on this context yet.
        self.work = []

    def create(self, ksp):
        self.work = []

    def destroy(self, ksp):
        # Destroy the vectors *and* drop the references (like reset), so no
        # destroyed Vec handles are left behind in self.work.
        self._release_work()

    def setUp(self, ksp):
        # Request two right work vectors; subclasses unpack them as r, z.
        # Slice assignment keeps the same list object in place.
        self.work[:] = ksp.getWorkVecs(right=2, left=None)

    def reset(self, ksp):
        self._release_work()

    def _release_work(self):
        """Destroy every work vector and empty ``self.work`` in place."""
        for v in self.work:
            v.destroy()
        del self.work[:]

    def loop(self, ksp, r):
        """Record the norm of ``r`` for this iteration and test convergence.

        The norm of ``r`` is stored as the KSP residual norm, appended to the
        convergence history and passed to the monitors. Then the KSP
        convergence test is run. While it reports no reason, the iteration
        counter is advanced. Otherwise the reason is stored on the KSP.

        Returns the reason from the convergence test. It is falsy while the
        solve should continue, so callers write ``while not self.loop(...)``.
        """
        its = ksp.getIterationNumber()
        rnorm = r.norm()
        ksp.setResidualNorm(rnorm)
        ksp.logConvergenceHistory(rnorm)
        ksp.monitor(its, rnorm)
        reason = ksp.callConvergenceTest(its, rnorm)
        if not reason:
            ksp.setIterationNumber(its + 1)
        else:
            ksp.setConvergedReason(reason)
        return reason


class MyRichardson(MyKSP):
    """Preconditioned Richardson iteration ``x <- x + P (b - A x)``.

    Convergence is tested on the norm of the preconditioned residual ``z``.
    """

    def solve(self, ksp, b, x):
        A, _ = ksp.getOperators()
        P = ksp.getPC()
        r, z = self.work
        # One sweep is always performed before the first convergence check.
        # After that, sweeps alternate with checks until loop() reports a
        # reason.
        while True:
            A.mult(x, r)        # r = A x
            r.aypx(-1, b)       # r = b - A x
            P.apply(r, z)       # z = P r
            x.axpy(1, z)        # x = x + z
            if self.loop(ksp, z):
                break


class MyCG(MyKSP):
    """Conjugate gradient; ``solve`` never applies the preconditioner."""

    def setUp(self, ksp):
        super(MyCG, self).setUp(ksp)
        # Two extra vectors with the layout of the first work vector:
        # the search direction d and the product q = A d.
        d = self.work[0].duplicate()
        q = d.duplicate()
        self.work += [d, q]

    def solve(self, ksp, b, x):
        A, _ = ksp.getOperators()
        # The second base work vector is not needed by unpreconditioned CG.
        r, _, d, q = self.work
        # Initial residual r = b - A x and search direction d = r.
        A.mult(x, r)
        r.aypx(-1, b)
        r.copy(d)
        delta = r.dot(r)
        while not self.loop(ksp, r):
            A.mult(d, q)
            alpha = delta / d.dot(q)
            x.axpy(+alpha, d)
            r.axpy(-alpha, q)
            delta_old = delta
            delta = r.dot(r)
            beta = delta / delta_old
            d.aypx(beta, r)     # d = r + beta * d

# --------------------------------------------------------------------


class BaseTestKSPPYTHON(BaseTestKSP):
    """Shared tests for a ``KSPPYTHON`` solver driven by ``ContextClass``.

    Concrete subclasses below mix in ``unittest.TestCase`` and set
    ``ContextClass``. They also set ``PC_TYPE``, which is presumably
    consumed by ``BaseTestKSP``.
    """

    KSP_TYPE = PETSc.KSP.Type.PYTHON
    ContextClass = None

    def setUp(self):
        super(BaseTestKSPPYTHON, self).setUp()
        ctx = self.ContextClass()
        self.ksp.setPythonContext(ctx)

    def testGetType(self):
        # The reported Python type is "<module>.<class>" of the context.
        ctx = self.ksp.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.ksp.getPythonType(), pytype)

    def tearDown(self):
        self.ksp.destroy()
        PETSc.garbage_cleanup()


class TestKSPPYTHON_RICH(BaseTestKSPPYTHON, unittest.TestCase):
    PC_TYPE = PETSc.PC.Type.JACOBI
    ContextClass = MyRichardson


class TestKSPPYTHON_CG(BaseTestKSPPYTHON, unittest.TestCase):
    PC_TYPE = PETSc.PC.Type.NONE
    ContextClass = MyCG

# --------------------------------------------------------------------


if __name__ == '__main__':
    unittest.main()