# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from test_ksp import BaseTestKSP

# --------------------------------------------------------------------


class MyKSP:
    """Base Python context for a ``KSP`` of type ``PYTHON``.

    Owns the list of work vectors used by the concrete solvers below and
    provides :meth:`loop`, the shared end-of-iteration bookkeeping step
    (residual norm, convergence history, monitors, convergence test).
    """

    def __init__(self):
        # Also initialized in create(); doing it here means reset() and
        # destroy() never fail on a missing attribute.
        self.work = []

    def create(self, ksp):
        """Start with an empty list of work vectors."""
        self.work = []

    def destroy(self, ksp):
        """Destroy every work vector and drop the references to them."""
        for v in self.work:
            v.destroy()
        # Match reset(): forget the destroyed vectors so a later call does
        # not operate on stale handles.
        del self.work[:]

    def setUp(self, ksp):
        """Allocate two right work vectors (used as ``r`` and ``z``)."""
        self.work[:] = ksp.getWorkVecs(right=2, left=None)

    def reset(self, ksp):
        """Destroy every work vector and empty the work list."""
        for v in self.work:
            v.destroy()
        del self.work[:]

    def loop(self, ksp, r):
        """Run the end-of-iteration bookkeeping for residual vector *r*.

        Stores and logs ``r.norm()``, calls the monitors and then the
        convergence test. If the test returns a falsy reason the iteration
        counter is advanced; otherwise the reason is recorded on *ksp*.

        Returns the convergence reason (falsy while iteration should go on).
        """
        its = ksp.getIterationNumber()
        rnorm = r.norm()
        ksp.setResidualNorm(rnorm)
        ksp.logConvergenceHistory(rnorm)
        ksp.monitor(its, rnorm)
        reason = ksp.callConvergenceTest(its, rnorm)
        if not reason:
            ksp.setIterationNumber(its + 1)
        else:
            ksp.setConvergedReason(reason)
        return reason


class MyRichardson(MyKSP):
    """Preconditioned Richardson iteration ``x <- x + P(b - A x)``."""

    def solve(self, ksp, b, x):
        """Update *x* in place until :meth:`loop` reports a reason.

        Convergence is tested on the preconditioned residual ``z``, after
        each update of *x*.
        """
        A, _ = ksp.getOperators()
        P = ksp.getPC()
        r, z = self.work
        while True:
            A.mult(x, r)
            r.aypx(-1, b)  # r = b - A x
            P.apply(r, z)  # z = P r
            x.axpy(1, z)   # x = x + z
            if self.loop(ksp, z):
                break


class MyCG(MyKSP):
    """Conjugate gradient method without preconditioning.

    The PC is never applied here; the CG test case below uses PC type NONE.
    """

    def setUp(self, ksp):
        """Allocate ``r``, ``z`` (base class) plus ``d`` and ``q`` vectors."""
        super().setUp(ksp)
        d = self.work[0].duplicate()
        q = d.duplicate()
        self.work += [d, q]

    def solve(self, ksp, b, x):
        """Update *x* in place with CG until :meth:`loop` reports a reason.

        Convergence is tested on the (unpreconditioned) residual ``r``.
        """
        A, _ = ksp.getOperators()
        # z is allocated by MyKSP.setUp but not needed without a PC.
        r, _z, d, q = self.work
        A.mult(x, r)
        r.aypx(-1, b)  # r = b - A x
        r.copy(d)      # initial search direction d = r
        delta = r.dot(r)
        while not self.loop(ksp, r):
            A.mult(d, q)  # q = A d
            alpha = delta / d.dot(q)
            x.axpy(+alpha, d)
            r.axpy(-alpha, q)
            delta_old = delta
            delta = r.dot(r)
            beta = delta / delta_old
            d.aypx(beta, r)  # d = r + beta * d


# --------------------------------------------------------------------


class BaseTestKSPPYTHON(BaseTestKSP):
    """Shared KSP tests run on a ``PYTHON`` KSP driven by ``ContextClass``."""

    KSP_TYPE = PETSc.KSP.Type.PYTHON
    ContextClass = None  # concrete subclasses set the Python context class

    def setUp(self):
        super().setUp()
        ctx = self.ContextClass()
        self.ksp.setPythonContext(ctx)

    def testGetType(self):
        """The reported Python type is ``<module>.<class>`` of the context."""
        ctx = self.ksp.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.ksp.getPythonType(), pytype)

    def tearDown(self):
        self.ksp.destroy()
        PETSc.garbage_cleanup()


class TestKSPPYTHON_RICH(BaseTestKSPPYTHON, unittest.TestCase):
    PC_TYPE = PETSc.PC.Type.JACOBI
    ContextClass = MyRichardson


class TestKSPPYTHON_CG(BaseTestKSPPYTHON, unittest.TestCase):
    PC_TYPE = PETSc.PC.Type.NONE
    ContextClass = MyCG


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()