xref: /petsc/src/binding/petsc4py/test/test_pc_py.py (revision 11486bccf1aeea1ca5536228f99d437b39bdaca6)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
class BaseMyPC:
    """Default behaviour shared by the toy preconditioners used in these tests.

    Concrete subclasses override ``apply`` (vector operand) and ``applyM``
    (matrix operand).  Every other variant is routed back to them:
    transpose and symmetric -> ``apply``; symmetric-left / symmetric-right
    -> ``applyS``; Richardson -> ``apply`` with ``w`` and ``tols`` ignored.
    """

    def setup(self, pc):
        """Hook run when the PC is set up; the base class builds nothing."""

    def reset(self, pc):
        """Hook run when the PC is reset; the base class holds no state."""

    def apply(self, pc, x, y):
        """Apply the preconditioner to vector ``x``, writing into ``y``."""
        raise NotImplementedError

    def applyT(self, pc, x, y):
        """Transpose application; by default the same as ``apply``."""
        self.apply(pc, x, y)

    def applyS(self, pc, x, y):
        """Symmetric application; by default the same as ``apply``."""
        self.apply(pc, x, y)

    def applySL(self, pc, x, y):
        """Symmetric-left application; delegates to ``applyS``."""
        self.applyS(pc, x, y)

    def applySR(self, pc, x, y):
        """Symmetric-right application; delegates to ``applyS``."""
        self.applyS(pc, x, y)

    def applyRich(self, pc, x, y, w, tols):
        """Richardson application: a single ``apply``; ``w``/``tols`` unused."""
        self.apply(pc, x, y)

    def applyM(self, pc, x, y):
        """Matrix-operand variant of ``apply`` (``x`` in, ``y`` out)."""
        raise NotImplementedError
27        raise NotImplementedError
28
class MyPCNone(BaseMyPC):
    """Identity preconditioner: the output is a verbatim copy of the input."""

    def apply(self, pc, x, y):
        """Copy ``x`` into ``y`` (works for vector and matrix operands alike)."""
        x.copy(y)

    # The matrix variant is the very same copy, so reuse the implementation.
    applyM = apply
34
class MyPCJacobi(BaseMyPC):
    """Point-Jacobi preconditioner built from the diagonal of ``P``.

    ``setup`` stores the reciprocal of the preconditioning matrix's diagonal
    in ``self.diag``; the application methods scale by that cached vector and
    ``reset`` releases it.
    """

    def setup(self, pc):
        """Cache ``1/diag(P)`` for the preconditioning matrix attached to ``pc``."""
        _, P = pc.getOperators()  # only the preconditioning matrix is needed
        self.diag = P.getDiagonal()
        self.diag.reciprocal()

    def reset(self, pc):
        """Destroy the cached diagonal, if there is one.

        Safe to call before ``setup`` or several times in a row: with no
        cached diagonal this is a no-op instead of an ``AttributeError``.
        """
        diag = getattr(self, 'diag', None)
        if diag is None:
            return
        diag.destroy()
        del self.diag

    def apply(self, pc, x, y):
        """Set ``y`` to ``x`` scaled entrywise by the cached ``1/diag(P)``."""
        y.pointwiseMult(self.diag, x)

    def applyS(self, pc, x, y):
        """Set ``y`` to ``x`` scaled entrywise by ``sqrt(|1/diag(P)|)``."""
        self.diag.copy(y)
        y.sqrtabs()
        y.pointwiseMult(y, x)  # in-place: y <- y .* x

    def applyM(self, pc, x, y):
        """Copy matrix ``x`` into ``y``, then left-scale ``y`` by the cached diagonal."""
        x.copy(y)
        y.diagonalScale(L=self.diag)
52
class PC_PYTHON_CLASS(object):
    """Python context installed in a ``PC`` of type ``python`` by the tests.

    Every callback bumps a counter in ``self.log`` (callback name -> number
    of calls), sanity-checks its argument types, and forwards the numerical
    work to ``self.impl``.  ``setFromOptions`` instantiates ``self.impl`` from
    the ``impl`` option, defaulting to ``MyPCNone``.
    """

    def __init__(self):
        self.log = {}     # callback name -> invocation count
        self.impl = None  # concrete implementation, chosen in setFromOptions

    def _log(self, method, *args):
        # Extra positional arguments are accepted and ignored so each call
        # site can simply pass along what it received.
        self.log[method] = self.log.get(method, 0) + 1

    @staticmethod
    def _check(pc, kind, *operands):
        # Shared type checks: ``pc`` must be a PC and every operand a ``kind``.
        assert isinstance(pc, PETSc.PC)
        for operand in operands:
            assert isinstance(operand, kind)

    # -- lifecycle callbacks -------------------------------------------------

    def create(self, pc):
        """Record creation of the PC."""
        self._log('create', pc)

    def destroy(self, pc):
        """Record destruction and drop the implementation object."""
        self._log('destroy')
        self.impl = None

    def reset(self, pc):
        """Record a reset (the implementation is not notified)."""
        self._log('reset', pc)

    def view(self, pc, vw):
        """Record a view request; nothing is written to the viewer."""
        self._log('view', pc, vw)
        self._check(pc, PETSc.Viewer, vw)

    def setFromOptions(self, pc):
        """Instantiate ``self.impl`` from the ``impl`` option of ``pc``."""
        self._log('setFromOptions', pc)
        assert isinstance(pc, PETSc.PC)
        name = PETSc.Options(pc).getString('impl', 'MyPCNone')
        # The option value names a class defined at module level in this file.
        self.impl = globals()[name]()

    def setUp(self, pc):
        """Record set-up and let the implementation prepare itself."""
        self._log('setUp', pc)
        assert isinstance(pc, PETSc.PC)
        self.impl.setup(pc)

    def preSolve(self, pc, ksp, b, x):
        """Record the pre-solve hook."""
        self._log('preSolve', pc, ksp, b, x)

    def postSolve(self, pc, ksp, b, x):
        """Record the post-solve hook."""
        self._log('postSolve', pc, ksp, b, x)

    # -- application callbacks -----------------------------------------------

    def apply(self, pc, x, y):
        """Forward ``y = M(x)`` to ``impl.apply``."""
        self._log('apply', pc, x, y)
        self._check(pc, PETSc.Vec, x, y)
        self.impl.apply(pc, x, y)

    def applySymmetricLeft(self, pc, x, y):
        """Forward to ``impl.applySL``."""
        self._log('applySymmetricLeft', pc, x, y)
        self._check(pc, PETSc.Vec, x, y)
        self.impl.applySL(pc, x, y)

    def applySymmetricRight(self, pc, x, y):
        """Forward to ``impl.applySR``."""
        self._log('applySymmetricRight', pc, x, y)
        self._check(pc, PETSc.Vec, x, y)
        self.impl.applySR(pc, x, y)

    def applyTranspose(self, pc, x, y):
        """Forward to ``impl.applyT``."""
        self._log('applyTranspose', pc, x, y)
        self._check(pc, PETSc.Vec, x, y)
        self.impl.applyT(pc, x, y)

    def matApply(self, pc, x, y):
        """Forward the matrix-operand application to ``impl.applyM``."""
        self._log('matApply', pc, x, y)
        self._check(pc, PETSc.Mat, x, y)
        self.impl.applyM(pc, x, y)

    def applyRichardson(self, pc, x, y, w, tols):
        """Forward to ``impl.applyRich``; ``tols`` must be a 4-tuple."""
        self._log('applyRichardson', pc, x, y, w, tols)
        self._check(pc, PETSc.Vec, x, y, w)
        assert isinstance(tols, tuple)
        assert len(tols) == 4
        self.impl.applyRich(pc, x, y, w, tols)
127
128
class TestPCPYTHON(unittest.TestCase):
    """Tests for a ``PC`` of type ``python`` driven by ``PC_PYTHON_CLASS``.

    ``setUp`` builds the PC through the options database under ``PC_PREFIX``;
    subclasses override ``setUp`` to vary the implementation or the
    construction path and inherit every test below.  Checks use
    ``self.assert*`` rather than bare ``assert`` so they still run under
    ``python -O``.
    """

    PC_TYPE = PETSc.PC.Type.PYTHON
    PC_PREFIX = 'test-'

    def setUp(self):
        pc = self.pc = PETSc.PC()
        pc.create(PETSc.COMM_SELF)
        pc.setType(self.PC_TYPE)
        module = __name__
        factory = 'PC_PYTHON_CLASS'
        pc.prefix = self.PC_PREFIX
        OptDB = PETSc.Options(pc)
        self.assertEqual(OptDB.prefix, pc.prefix)
        OptDB['pc_python_type'] = '%s.%s' % (module, factory)
        try:
            pc.setFromOptions()
        finally:
            # Never leave the option behind for later tests, even on failure.
            del OptDB['pc_python_type']
        ctx = self._getCtx()
        self.assertEqual(ctx.log['create'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        # References: the PC itself, the local ``ctx`` and getrefcount's argument.
        self.assertEqual(getrefcount(ctx), 3)

    def testGetType(self):
        ctx = self.pc.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.pc.getPythonType(), pytype)

    def tearDown(self):
        ctx = self._getCtx()
        self.pc.destroy() # XXX
        self.pc = None
        self.assertEqual(ctx.log['destroy'], 1)
        # Once the PC is gone only ``ctx`` and getrefcount's argument remain.
        self.assertEqual(getrefcount(ctx), 2)

    def _prepare(self):
        """Attach a 3x3 diagonal (10*I) operator; return ``(A, x, y, X, Y)``.

        ``x`` is random, ``y`` is its companion vector, and ``X``/``Y`` are
        assembled 3x5 dense matrices for the matrix-operand application.
        """
        A = PETSc.Mat().createAIJ([3,3], comm=PETSc.COMM_SELF)
        A.setUp()
        A.assemble()
        A.shift(10)
        x, y = A.createVecs()
        x.setRandom()
        self.pc.setOperators(A, A)
        X = PETSc.Mat().createDense([3,5], comm=PETSc.COMM_SELF).setUp()
        X.assemble()
        Y = PETSc.Mat().createDense([3,5], comm=PETSc.COMM_SELF).setUp()
        Y.assemble()
        self.assertEqual((A, A), self.pc.getOperators())
        return A, x, y, X, Y

    def _getCtx(self):
        return self.pc.getPythonContext()

    def _applyMeth(self, meth):
        """Call PC method ``meth`` once and check the context's call counters."""
        _, x, y, X, Y = self._prepare()
        method = getattr(self.pc, meth)
        if meth == 'matApply':
            method(X, Y)
            x.copy(y)  # keep the vector pair consistent for the identity check
        else:
            method(x, y)
            X.copy(Y)  # keep the matrix pair consistent for the identity check
        ctx = self._getCtx()
        log = ctx.log
        if 'reset' not in log:
            self.assertEqual(log['setUp'], 1)
            self.assertEqual(log[meth], 1)
        else:
            # testResetAndApply interleaves reset and apply, so the three
            # counters must advance in lockstep.
            self.assertEqual(log['reset'], log['setUp'])
            self.assertEqual(log['reset'], log[meth])
        if isinstance(ctx.impl, MyPCNone):
            self.assertTrue(y.equal(x))
            self.assertTrue(Y.equal(X))

    def testApply(self):
        self._applyMeth('apply')

    def testApplySymmetricLeft(self):
        self._applyMeth('applySymmetricLeft')

    def testApplySymmetricRight(self):
        self._applyMeth('applySymmetricRight')

    def testApplyTranspose(self):
        self._applyMeth('applyTranspose')

    def testApplyMat(self):
        self._applyMeth('matApply')

    # Disabled tests, kept for reference.
    ## def testApplyRichardson(self):
    ##     _, x, y, _, _ = self._prepare()
    ##     w = x.duplicate()
    ##     tols = 0,0,0,0
    ##     self.pc.applyRichardson(x,y,w,tols)
    ##     assert self._getCtx().log['setUp'] == 1
    ##     assert self._getCtx().log['applyRichardson'] == 1

    # NOTE(review): ``factory`` below does not look like a class name; confirm
    # the expected view text before re-enabling this test.
    ## def testView(self):
    ##     vw = PETSc.ViewerString(100, self.pc.comm)
    ##     self.pc.view(vw)
    ##     s = vw.getString()
    ##     assert 'python' in s
    ##     module = __name__
    ##     factory = 'self._getCtx()'
    ##     assert '.'.join([module, factory]) in s

    def testResetAndApply(self):
        self.pc.reset()
        self.testApply()
        self.pc.reset()
        self.testApply()
        self.pc.reset()

    def testKSPSolve(self):
        _, x, y, _, _ = self._prepare()
        ksp = PETSc.KSP().create(self.pc.comm)
        ksp.setType(PETSc.KSP.Type.PREONLY)
        self.assertEqual(self.pc.getRefCount(), 1)
        ksp.setPC(self.pc)
        self.assertEqual(self.pc.getRefCount(), 2)
        log = self._getCtx().log
        # normal ksp solve, twice: setUp happens once, the rest once per solve
        for nsolve in (1, 2):
            ksp.solve(x, y)
            self.assertEqual(log['setUp'], 1)
            self.assertEqual(log['apply'], nsolve)
            self.assertEqual(log['preSolve'], nsolve)
            self.assertEqual(log['postSolve'], nsolve)
        # transpose ksp solve, twice: no additional setUp
        for nsolve in (1, 2):
            ksp.solveTranspose(x, y)
            self.assertEqual(log['setUp'], 1)
            self.assertEqual(log['applyTranspose'], nsolve)
        del ksp # ksp.destroy()
        self.assertEqual(self.pc.getRefCount(), 1)

    def testGetSetContext(self):
        ctx = self.pc.getPythonContext()
        self.pc.setPythonContext(ctx)
        # Re-installing the same context must not add or drop a reference.
        self.assertEqual(getrefcount(ctx), 3)
        del ctx
268
269
class TestPCPYTHON2(TestPCPYTHON):
    """Rerun the base tests with the ``impl`` option set to ``MyPCJacobi``."""

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super(TestPCPYTHON2, self).setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, OptDB['impl'])
        finally:
            # Always drop the option, even on failure: other test classes
            # read ``impl`` under the same prefix and expect the default.
            del OptDB['impl']
277        del OptDB['impl']
278
class TestPCPYTHON3(TestPCPYTHON):
    """Rerun the base tests with the context installed via ``createPython``."""

    def setUp(self):
        pc = self.pc = PETSc.PC()
        ctx = PC_PYTHON_CLASS()
        pc.createPython(ctx, comm=PETSc.COMM_SELF)
        pc.prefix = self.PC_PREFIX
        pc.setFromOptions()
        # Query through the PC so the check also proves it holds the context.
        log = self._getCtx().log
        self.assertEqual(log['create'], 1)
        self.assertEqual(log['setFromOptions'], 1)
287        assert self._getCtx().log['setFromOptions'] == 1
288
class TestPCPYTHON4(TestPCPYTHON3):
    """Like TestPCPYTHON3 (``createPython``), but with ``impl`` = ``MyPCJacobi``."""

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super(TestPCPYTHON4, self).setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, OptDB['impl'])
        finally:
            # Always drop the option, even on failure: other test classes
            # read ``impl`` under the same prefix and expect the default.
            del OptDB['impl']
296        del OptDB['impl']
297
298# --------------------------------------------------------------------
299
if __name__ == '__main__':
    # Executed as a script: run every TestCase defined in this module.
    unittest.main()
302