xref: /petsc/src/binding/petsc4py/test/test_pc_py.py (revision df4cd43f92eaa320656440c40edb1046daee8f75)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
class BaseMyPC(object):
    """Common base for the toy preconditioners driven by PC_PYTHON_CLASS.

    Subclasses must implement ``apply`` (and ``applyM`` for the multi-vector
    application).  The transpose, symmetric and Richardson hooks default to
    ``apply``; the symmetric-left/right hooks are routed through ``applyS``
    so that overriding ``applyS`` alone covers both sides.
    """

    def setup(self, pc):
        """Precomputation hook; does nothing by default."""

    def reset(self, pc):
        """Hook for releasing setup data; does nothing by default."""

    def apply(self, pc, x, y):
        raise NotImplementedError

    def applyM(self, pc, x, y):
        raise NotImplementedError

    # ---- defaults expressed in terms of apply() / applyS() ----

    def applyT(self, pc, x, y):
        self.apply(pc, x, y)

    def applyS(self, pc, x, y):
        self.apply(pc, x, y)

    def applySL(self, pc, x, y):
        self.applyS(pc, x, y)

    def applySR(self, pc, x, y):
        self.applyS(pc, x, y)

    def applyRich(self, pc, x, y, w, tols):
        # The work vector `w` and the tolerances `tols` are not used.
        self.apply(pc, x, y)
28
class MyPCNone(BaseMyPC):
    """Identity preconditioner: the output is a plain copy of the input."""

    def apply(self, pc, x, y):
        # y <- x
        x.copy(y)

    # Multi-vector variant: the same copy, applied to Mat operands (Y <- X).
    applyM = apply
34
class MyPCJacobi(BaseMyPC):
    """Point-Jacobi preconditioner built from the diagonal of ``P``.

    ``setup`` stores the reciprocal of the diagonal of the preconditioning
    matrix in ``self.diag``; the apply methods scale by it.
    """

    def setup(self, pc):
        A, P = pc.getOperators()
        self.diag = P.getDiagonal()
        self.diag.reciprocal()  # diag <- 1 / diag(P)

    def reset(self, pc):
        """Release the stored diagonal.

        Safe to call before ``setup`` or more than once: when no diagonal
        is stored there is nothing to release and the call is a no-op.
        """
        diag = getattr(self, 'diag', None)
        if diag is not None:
            diag.destroy()
            del self.diag

    def apply(self, pc, x, y):
        # y <- diag .* x
        y.pointwiseMult(self.diag, x)

    def applyS(self, pc, x, y):
        # y <- sqrt(|diag|) .* x
        self.diag.copy(y)
        y.sqrtabs()
        y.pointwiseMult(y, x)

    def applyM(self, pc, x, y):
        # Multi-vector apply: copy X into Y, then left-scale Y by diag.
        x.copy(y)
        y.diagonalScale(L=self.diag)
52
class PC_PYTHON_CLASS(object):
    """Python context object for a ``PC`` of type ``python``.

    Every callback PETSc makes is counted in ``self.log`` (callback name ->
    number of calls).  The apply-style callbacks are forwarded to
    ``self.impl``, the toy preconditioner picked in ``setFromOptions``.
    """

    def __init__(self):
        self.impl = None  # chosen in setFromOptions()
        self.log = {}     # callback name -> call count

    def _log(self, method, *args):
        # Extra arguments are accepted for call-site symmetry only; just the
        # call count is recorded.
        self.log[method] = self.log.get(method, 0) + 1

    def _check_operands(self, pc, x, y, kind):
        # Type checks shared by the apply-style callbacks.
        assert isinstance(pc, PETSc.PC)
        assert isinstance(x, kind)
        assert isinstance(y, kind)

    def create(self, pc):
        self._log('create', pc)

    def destroy(self, pc):
        self._log('destroy')
        self.impl = None

    def reset(self, pc):
        # Only counted; self.impl.reset() is not called from here.
        self._log('reset', pc)

    def view(self, pc, vw):
        self._log('view', pc, vw)
        assert isinstance(pc, PETSc.PC)
        assert isinstance(vw, PETSc.Viewer)

    def setFromOptions(self, pc):
        """Instantiate ``self.impl`` from the ``impl`` option.

        The option is read under the PC's options prefix, defaults to
        ``'MyPCNone'``, and names a class looked up in this module.
        """
        self._log('setFromOptions', pc)
        assert isinstance(pc, PETSc.PC)
        impl_name = PETSc.Options(pc).getString('impl', 'MyPCNone')
        self.impl = globals()[impl_name]()

    def setUp(self, pc):
        self._log('setUp', pc)
        assert isinstance(pc, PETSc.PC)
        self.impl.setup(pc)

    def preSolve(self, pc, ksp, b, x):
        self._log('preSolve', pc, ksp, b, x)

    def postSolve(self, pc, ksp, b, x):
        self._log('postSolve', pc, ksp, b, x)

    def apply(self, pc, x, y):
        self._log('apply', pc, x, y)
        self._check_operands(pc, x, y, PETSc.Vec)
        self.impl.apply(pc, x, y)

    def applySymmetricLeft(self, pc, x, y):
        self._log('applySymmetricLeft', pc, x, y)
        self._check_operands(pc, x, y, PETSc.Vec)
        self.impl.applySL(pc, x, y)

    def applySymmetricRight(self, pc, x, y):
        self._log('applySymmetricRight', pc, x, y)
        self._check_operands(pc, x, y, PETSc.Vec)
        self.impl.applySR(pc, x, y)

    def applyTranspose(self, pc, x, y):
        self._log('applyTranspose', pc, x, y)
        self._check_operands(pc, x, y, PETSc.Vec)
        self.impl.applyT(pc, x, y)

    def matApply(self, pc, x, y):
        self._log('matApply', pc, x, y)
        self._check_operands(pc, x, y, PETSc.Mat)
        self.impl.applyM(pc, x, y)

    def applyRichardson(self, pc, x, y, w, tols):
        self._log('applyRichardson', pc, x, y, w, tols)
        self._check_operands(pc, x, y, PETSc.Vec)
        assert isinstance(w, PETSc.Vec)
        assert isinstance(tols, tuple)
        assert len(tols) == 4
        self.impl.applyRich(pc, x, y, w, tols)
127
128
class TestPCPYTHON(unittest.TestCase):
    """Drive a ``PC`` of type ``python`` whose context is ``PC_PYTHON_CLASS``.

    Checks use ``self.assert*`` rather than bare ``assert`` so they are not
    stripped when Python runs with ``-O`` and so failures report both values.
    """

    PC_TYPE = PETSc.PC.Type.PYTHON
    PC_PREFIX = 'test-'

    def setUp(self):
        pc = self.pc = PETSc.PC()
        pc.create(PETSc.COMM_SELF)
        pc.setType(self.PC_TYPE)
        module = __name__
        factory = 'PC_PYTHON_CLASS'
        self.pc.prefix = self.PC_PREFIX
        OptDB = PETSc.Options(self.pc)
        self.assertEqual(OptDB.prefix, self.pc.prefix)
        # Select the Python context class by its dotted "module.Class" name.
        OptDB['pc_python_type'] = '%s.%s' % (module, factory)
        try:
            self.pc.setFromOptions()
        finally:
            # Remove the option even if setFromOptions() fails, so it does
            # not linger in the options database for later tests.
            del OptDB['pc_python_type']
        ctx = self._getCtx()
        self.assertEqual(ctx.log['create'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        # References: `ctx`, getrefcount's own argument, and (presumably)
        # the one held by the PC.
        self.assertEqual(getrefcount(ctx), 3)

    def testGetType(self):
        ctx = self.pc.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.pc.getPythonType(), pytype)

    def tearDown(self):
        ctx = self._getCtx()
        self.pc.destroy() # XXX
        self.pc = None
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)
        # Only `ctx` and getrefcount's argument should still reference it.
        self.assertEqual(getrefcount(ctx), 2)

    def _prepare(self):
        """Attach an operator to the PC and return ``(A, x, y, X, Y)``.

        ``A`` is a 3x3 AIJ matrix with no entries set, shifted by 10
        (presumably 10*I).  It is used as both operators.  ``x`` is random
        and ``y`` is a matching output vector.  ``X``/``Y`` are assembled
        3x5 dense matrices for the multi-vector application.
        """
        A = PETSc.Mat().createAIJ([3, 3], comm=PETSc.COMM_SELF)
        A.setUp()
        A.assemble()
        A.shift(10)
        x, y = A.createVecs()
        x.setRandom()
        self.pc.setOperators(A, A)
        X = PETSc.Mat().createDense([3, 5], comm=PETSc.COMM_SELF).setUp()
        X.assemble()
        Y = PETSc.Mat().createDense([3, 5], comm=PETSc.COMM_SELF).setUp()
        Y.assemble()
        self.assertEqual((A, A), self.pc.getOperators())
        return A, x, y, X, Y

    def _getCtx(self):
        return self.pc.getPythonContext()

    def _applyMeth(self, meth):
        """Apply the PC through ``meth`` and check the callback counts."""
        A, x, y, X, Y = self._prepare()
        # Only one operand pair is actually run through the PC; the other is
        # made equal by copying so the identity check below covers both.
        if meth == 'matApply':
            getattr(self.pc, meth)(X, Y)
            x.copy(y)
        else:
            getattr(self.pc, meth)(x, y)
            X.copy(Y)
        log = self._getCtx().log
        if 'reset' not in log:
            self.assertEqual(log['setUp'], 1)
            self.assertEqual(log[meth], 1)
        else:
            # testResetAndApply resets before every application, so the
            # reset, setUp and method counts must all agree.
            nreset = log['reset']
            self.assertEqual(log['setUp'], nreset)
            self.assertEqual(log[meth], nreset)
        if isinstance(self._getCtx().impl, MyPCNone):
            self.assertTrue(y.equal(x))
            self.assertTrue(Y.equal(X))

    def testApply(self):
        self._applyMeth('apply')

    def testApplySymmetricLeft(self):
        self._applyMeth('applySymmetricLeft')

    def testApplySymmetricRight(self):
        self._applyMeth('applySymmetricRight')

    def testApplyTranspose(self):
        self._applyMeth('applyTranspose')

    def testApplyMat(self):
        self._applyMeth('matApply')

    # Disabled tests, kept for reference.
    ## def testApplyRichardson(self):
    ##     A, x, y, _, _ = self._prepare()
    ##     w = x.duplicate()
    ##     tols = 0,0,0,0
    ##     self.pc.applyRichardson(x,y,w,tols)
    ##     self.assertEqual(self._getCtx().log['setUp'], 1)
    ##     self.assertEqual(self._getCtx().log['applyRichardson'], 1)

    # NOTE(review): factory 'self._getCtx()' below looks like a
    # search-and-replace artifact; presumably 'PC_PYTHON_CLASS' was meant.
    # Confirm before re-enabling.
    ## def testView(self):
    ##     vw = PETSc.ViewerString(100, self.pc.comm)
    ##     self.pc.view(vw)
    ##     s = vw.getString()
    ##     assert 'python' in s
    ##     module = __name__
    ##     factory = 'self._getCtx()'
    ##     assert '.'.join([module, factory]) in s

    def testResetAndApply(self):
        self.pc.reset()
        self.testApply()
        self.pc.reset()
        self.testApply()
        self.pc.reset()

    def testKSPSolve(self):
        A, x, y, _, _ = self._prepare()
        ksp = PETSc.KSP().create(self.pc.comm)
        ksp.setType(PETSc.KSP.Type.PREONLY)
        self.assertEqual(self.pc.getRefCount(), 1)
        ksp.setPC(self.pc)
        self.assertEqual(self.pc.getRefCount(), 2)
        log = self._getCtx().log
        # normal ksp solve, twice
        ksp.solve(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['apply'], 1)
        self.assertEqual(log['preSolve'], 1)
        self.assertEqual(log['postSolve'], 1)
        ksp.solve(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['apply'], 2)
        self.assertEqual(log['preSolve'], 2)
        self.assertEqual(log['postSolve'], 2)
        # transpose ksp solve, twice
        ksp.solveTranspose(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['applyTranspose'], 1)
        ksp.solveTranspose(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['applyTranspose'], 2)
        del ksp # ksp.destroy()
        PETSc.garbage_cleanup()
        self.assertEqual(self.pc.getRefCount(), 1)

    def testGetSetContext(self):
        ctx = self.pc.getPythonContext()
        self.pc.setPythonContext(ctx)
        self.assertEqual(getrefcount(ctx), 3)
        del ctx
270
271
class TestPCPYTHON2(TestPCPYTHON):
    """Rerun the TestPCPYTHON suite with ``impl`` set to MyPCJacobi."""

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super(TestPCPYTHON2, self).setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, OptDB['impl'])
        finally:
            # Remove the option even on failure so it cannot leak into the
            # options database seen by later tests.
            del OptDB['impl']
280
class TestPCPYTHON3(TestPCPYTHON):
    """Rerun the suite with the context passed directly to createPython()."""

    def setUp(self):
        pc = self.pc = PETSc.PC()
        ctx = PC_PYTHON_CLASS()
        pc.createPython(ctx, comm=PETSc.COMM_SELF)
        self.pc.prefix = self.PC_PREFIX
        self.pc.setFromOptions()
        # self.assert* (not bare assert) so the checks survive `python -O`.
        self.assertEqual(self._getCtx().log['create'], 1)
        self.assertEqual(self._getCtx().log['setFromOptions'], 1)
290
class TestPCPYTHON4(TestPCPYTHON3):
    """Rerun the TestPCPYTHON3 suite with ``impl`` set to MyPCJacobi."""

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super(TestPCPYTHON4, self).setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, OptDB['impl'])
        finally:
            # Remove the option even on failure so it cannot leak into the
            # options database seen by later tests.
            del OptDB['impl']
299
300# --------------------------------------------------------------------
301
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module
    # (verbosity=1 is unittest's default).
    unittest.main(verbosity=1)
304