xref: /petsc/src/binding/petsc4py/test/test_pc_py.py (revision bef158480efac06de457f7a665168877ab3c2fd7)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
class BaseMyPC(object):
    """Default behaviour shared by the toy preconditioner implementations.

    Subclasses must provide ``apply``.  Every other operation falls back to
    it: the transpose and Richardson variants call ``apply`` directly, and
    the two symmetric halves go through ``applyS`` (which itself defaults to
    ``apply``).
    """

    def setup(self, pc):
        """Prepare for application; nothing to do by default."""
        return None

    def reset(self, pc):
        """Release set-up data; nothing to do by default."""
        return None

    def apply(self, pc, x, y):
        """Compute ``y`` from ``x``; concrete subclasses must override this."""
        raise NotImplementedError

    def applyT(self, pc, x, y):
        """Transpose application; defaults to the plain application."""
        self.apply(pc, x, y)

    def applyS(self, pc, x, y):
        """Symmetric (half) application; defaults to the plain application."""
        self.apply(pc, x, y)

    def applySL(self, pc, x, y):
        """Left symmetric half; delegates to ``applyS``."""
        self.applyS(pc, x, y)

    def applySR(self, pc, x, y):
        """Right symmetric half; delegates to ``applyS``."""
        self.applyS(pc, x, y)

    def applyRich(self, pc, x, y, w, tols):
        """Richardson application; ``w`` and ``tols`` are ignored here."""
        self.apply(pc, x, y)
26
class MyPCNone(BaseMyPC):
    """Identity preconditioner: ``y`` receives a copy of ``x``."""

    def apply(self, pc, x, y):
        # Vec.copy(result) writes the entries of x into y; pc is unused.
        x.copy(y)
30
class MyPCJacobi(BaseMyPC):
    """Point-Jacobi preconditioner built from the diagonal of ``P``.

    ``setup`` caches the element-wise reciprocal of ``diag(P)`` in
    ``self.diag``.  ``apply`` multiplies ``x`` element-wise by that vector.
    ``applyS`` (used for both symmetric halves) multiplies by the element-wise
    square root of its absolute value instead.
    """

    def setup(self, pc):
        # Only the preconditioning matrix P is needed; the operator is ignored.
        _, P = pc.getOperators()
        self.diag = P.getDiagonal()
        self.diag.reciprocal()

    def reset(self, pc):
        # A reset may arrive before any setup, or twice in a row.  In that
        # case there is no cached diagonal to release, so do nothing instead
        # of failing on the missing attribute.
        diag = getattr(self, 'diag', None)
        if diag is None:
            return
        diag.destroy()
        del self.diag

    def apply(self, pc, x, y):
        # y <- diag .* x
        y.pointwiseMult(self.diag, x)

    def applyS(self, pc, x, y):
        # y <- sqrt(|diag|) .* x, built in place in y
        self.diag.copy(y)
        y.sqrtabs()
        y.pointwiseMult(y, x)
45
class PC_PYTHON_CLASS(object):
    """Python context for a PCPYTHON preconditioner used by the tests below.

    Every hook increments a per-hook counter in ``self.log`` (hook name ->
    number of calls).  Hooks that do real work forward to ``self.impl``.
    ``setFromOptions`` builds ``self.impl`` from the module-level class named
    by the ``impl`` option.
    """

    def __init__(self):
        self.log = {}
        self.impl = None

    def _log(self, method, *args):
        """Count one call of ``method``; any extra ``args`` are ignored."""
        self.log[method] = self.log.get(method, 0) + 1

    @staticmethod
    def _check(pc, *vecs):
        """Assert that ``pc`` is a PETSc PC and each of ``vecs`` a PETSc Vec."""
        assert isinstance(pc, PETSc.PC)
        for vec in vecs:
            assert isinstance(vec, PETSc.Vec)

    # -- life cycle ------------------------------------------------------

    def create(self, pc):
        self._log('create', pc)

    def destroy(self, pc):
        self._log('destroy')
        self.impl = None

    def reset(self, pc):
        self._log('reset', pc)

    def view(self, pc, vw):
        self._log('view', pc, vw)
        self._check(pc)
        assert isinstance(vw, PETSc.Viewer)

    def setFromOptions(self, pc):
        self._log('setFromOptions', pc)
        self._check(pc)
        # The 'impl' option (looked up under the PC's prefix) names one of the
        # implementation classes defined at module level.
        impl_name = PETSc.Options(pc).getString('impl', 'MyPCNone')
        self.impl = globals()[impl_name]()

    def setUp(self, pc):
        self._log('setUp', pc)
        self._check(pc)
        self.impl.setup(pc)

    # -- KSP solve hooks -------------------------------------------------

    def preSolve(self, pc, ksp, b, x):
        self._log('preSolve', pc, ksp, b, x)

    def postSolve(self, pc, ksp, b, x):
        self._log('postSolve', pc, ksp, b, x)

    # -- application hooks -----------------------------------------------

    def apply(self, pc, x, y):
        self._log('apply', pc, x, y)
        self._check(pc, x, y)
        self.impl.apply(pc, x, y)

    def applySymmetricLeft(self, pc, x, y):
        self._log('applySymmetricLeft', pc, x, y)
        self._check(pc, x, y)
        self.impl.applySL(pc, x, y)

    def applySymmetricRight(self, pc, x, y):
        self._log('applySymmetricRight', pc, x, y)
        self._check(pc, x, y)
        self.impl.applySR(pc, x, y)

    def applyTranspose(self, pc, x, y):
        self._log('applyTranspose', pc, x, y)
        self._check(pc, x, y)
        self.impl.applyT(pc, x, y)

    def applyRichardson(self, pc, x, y, w, tols):
        self._log('applyRichardson', pc, x, y, w, tols)
        self._check(pc, x, y, w)
        # tols must be a 4-tuple.
        assert isinstance(tols, tuple)
        assert len(tols) == 4
        self.impl.applyRich(pc, x, y, w, tols)
115
116
class TestPCPYTHON(unittest.TestCase):
    """Exercise a PETSc PC of type PYTHON driven by PC_PYTHON_CLASS.

    setUp attaches the context by its dotted name through the
    'pc_python_type' option under PC_PREFIX.  The subclasses below rerun
    every test method here with other implementations or set-up paths.
    """

    # PC type under test.
    PC_TYPE = PETSc.PC.Type.PYTHON

    # Options prefix given to the PC; options set via PETSc.Options(self.pc)
    # are therefore stored under 'test-'.
    PC_PREFIX = 'test-'

    def setUp(self):
        """Create the PC and select PC_PYTHON_CLASS as its Python context."""
        pc = self.pc = PETSc.PC()
        pc.create(PETSc.COMM_SELF)
        pc.setType(self.PC_TYPE)
        module = __name__
        factory = 'PC_PYTHON_CLASS'
        self.pc.prefix = self.PC_PREFIX
        OptDB = PETSc.Options(self.pc)
        assert OptDB.prefix == self.pc.prefix
        # Register the context by "<module>.<class>" name, configure the PC
        # from it, then remove the option again.
        OptDB['pc_python_type'] = '%s.%s' % (module, factory)
        self.pc.setFromOptions()
        del OptDB['pc_python_type']
        assert self._getCtx().log['create'] == 1
        assert self._getCtx().log['setFromOptions'] == 1
        ctx = self._getCtx()
        # Presumably: one reference held by the PC, one by the local `ctx`,
        # and one temporary for getrefcount's own argument.
        self.assertEqual(getrefcount(ctx), 3)

    def tearDown(self):
        """Destroy the PC and check the context was released exactly once."""
        ctx = self._getCtx()
        self.pc.destroy() # XXX
        self.pc = None
        assert ctx.log['destroy'] == 1
        # One reference fewer than in setUp: the destroyed PC no longer
        # holds the context.
        self.assertEqual(getrefcount(ctx), 2)

    def _prepare(self):
        """Attach a small matrix as both operators; return (A, x, y).

        A is a 3x3 AIJ matrix, assembled empty and then shifted by 10
        (presumably 10*I -- confirm MatShift fills the missing diagonal).
        x holds random values and y is the output vector.
        """
        A = PETSc.Mat().createAIJ([3,3], comm=PETSc.COMM_SELF)
        A.setUp()
        A.assemble()
        A.shift(10)
        x, y = A.createVecs()
        x.setRandom()
        self.pc.setOperators(A, A)
        assert (A,A) == self.pc.getOperators()
        return A, x, y

    def _getCtx(self):
        """Return the Python context object attached to self.pc."""
        return self.pc.getPythonContext()

    def _applyMeth(self, meth):
        """Call PC method `meth` on (x, y) and verify the context's counters.

        If no reset has happened, setUp and `meth` must each have run exactly
        once.  Otherwise the reset, setUp and `meth` counts must all agree.
        For the identity implementation MyPCNone, y must equal x.
        """
        A, x, y = self._prepare()
        getattr(self.pc, meth)(x,y)
        if 'reset' not in self._getCtx().log:
            assert self._getCtx().log['setUp'] == 1
            assert self._getCtx().log[meth] == 1
        else:
            nreset = self._getCtx().log['reset']
            nsetup = self._getCtx().log['setUp']
            nmeth  = self._getCtx().log[meth]
            assert (nreset == nsetup)
            assert (nreset == nmeth)
        if isinstance(self._getCtx().impl, MyPCNone):
            self.assertTrue(y.equal(x))
    # One test per application method; see _applyMeth for the checks.
    def testApply(self):
        self._applyMeth('apply')
    def testApplySymmetricLeft(self):
        self._applyMeth('applySymmetricLeft')
    def testApplySymmetricRight(self):
        self._applyMeth('applySymmetricRight')
    def testApplyTranspose(self):
        self._applyMeth('applyTranspose')
    # NOTE(review): the two tests below are disabled.  testApplyRichardson
    # unpacks two values from _prepare(), which returns three (A, x, y).
    # testView checks for factory 'self._getCtx()', whereas setUp registers
    # 'PC_PYTHON_CLASS'.  Both need fixing before they are re-enabled.
    ## def testApplyRichardson(self):
    ##     x, y = self._prepare()
    ##     w = x.duplicate()
    ##     tols = 0,0,0,0
    ##     self.pc.applyRichardson(x,y,w,tols)
    ##     assert self._getCtx().log['setUp'] == 1
    ##     assert self._getCtx().log['applyRichardson'] == 1

    ## def testView(self):
    ##     vw = PETSc.ViewerString(100, self.pc.comm)
    ##     self.pc.view(vw)
    ##     s = vw.getString()
    ##     assert 'python' in s
    ##     module = __name__
    ##     factory = 'self._getCtx()'
    ##     assert '.'.join([module, factory]) in s

    # Alternate reset and apply; _applyMeth checks that the reset, setUp and
    # apply counters stay in step.
    def testResetAndApply(self):
        self.pc.reset()
        self.testApply()
        self.pc.reset()
        self.testApply()
        self.pc.reset()

    # Drive the PC from a PREONLY KSP.  setUp must run only once across
    # repeated solves, while apply/preSolve/postSolve (and applyTranspose for
    # transpose solves) are counted per solve.  The PC's PETSc reference
    # count goes 1 -> 2 when attached to the KSP and back to 1 once the KSP
    # is released.
    def testKSPSolve(self):
        A, x, y = self._prepare()
        ksp = PETSc.KSP().create(self.pc.comm)
        ksp.setType(PETSc.KSP.Type.PREONLY)
        assert self.pc.getRefCount() == 1
        ksp.setPC(self.pc)
        assert self.pc.getRefCount() == 2
        # normal ksp solve, twice
        ksp.solve(x,y)
        assert self._getCtx().log['setUp'    ] == 1
        assert self._getCtx().log['apply'    ] == 1
        assert self._getCtx().log['preSolve' ] == 1
        assert self._getCtx().log['postSolve'] == 1
        ksp.solve(x,y)
        assert self._getCtx().log['setUp'    ] == 1
        assert self._getCtx().log['apply'    ] == 2
        assert self._getCtx().log['preSolve' ] == 2
        assert self._getCtx().log['postSolve'] == 2
        # transpose ksp solve, twice
        ksp.solveTranspose(x,y)
        assert self._getCtx().log['setUp'         ] == 1
        assert self._getCtx().log['applyTranspose'] == 1
        ksp.solveTranspose(x,y)
        assert self._getCtx().log['setUp'         ] == 1
        assert self._getCtx().log['applyTranspose'] == 2
        del ksp # ksp.destroy()
        assert self.pc.getRefCount() == 1

    # Re-setting the same context must not add a reference: the count stays
    # at 3, as in setUp (presumably PC + local + getrefcount argument).
    def testGetSetContext(self):
        ctx = self.pc.getPythonContext()
        self.pc.setPythonContext(ctx)
        self.assertEqual(getrefcount(ctx), 3)
        del ctx
240
241
class TestPCPYTHON2(TestPCPYTHON):
    """Rerun the TestPCPYTHON suite with the MyPCJacobi implementation.

    The implementation is selected through the prefixed 'impl' option, which
    PC_PYTHON_CLASS.setFromOptions reads.
    """

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super(TestPCPYTHON2, self).setUp()
            clsname = type(self._getCtx().impl).__name__
            assert clsname == OptDB['impl']
        finally:
            # unittest skips tearDown when setUp raises, so always drop the
            # option here; otherwise it would leak into later test cases.
            del OptDB['impl']
250
class TestPCPYTHON3(TestPCPYTHON):
    """Rerun the TestPCPYTHON suite, attaching the context via createPython().

    Here the context instance is handed to the PC directly instead of being
    named through the 'pc_python_type' option.
    """

    def setUp(self):
        self.pc = PETSc.PC()
        context = PC_PYTHON_CLASS()
        self.pc.createPython(context, comm=PETSc.COMM_SELF)
        self.pc.prefix = self.PC_PREFIX
        self.pc.setFromOptions()
        counts = self._getCtx().log
        assert counts['create'] == 1
        assert counts['setFromOptions'] == 1
260
class TestPCPYTHON4(TestPCPYTHON3):
    """Rerun the createPython()-based suite with the MyPCJacobi implementation.

    The implementation is selected through the prefixed 'impl' option, which
    PC_PYTHON_CLASS.setFromOptions reads.
    """

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super(TestPCPYTHON4, self).setUp()
            clsname = type(self._getCtx().impl).__name__
            assert clsname == OptDB['impl']
        finally:
            # unittest skips tearDown when setUp raises, so always drop the
            # option here; otherwise it would leak into later test cases.
            del OptDB['impl']
269
270# --------------------------------------------------------------------
271
if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()
274