# xref: /petsc/src/binding/petsc4py/test/test_pc_py.py (revision fffbe07892e3829f28d4ec4d935109d0474c54e5)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
class BaseMyPC(object):
    """Default behaviour shared by the test preconditioner implementations.

    Subclasses must provide ``apply``. Every other vector operation falls
    back to it, either directly or through ``applyS`` for the symmetric
    left/right halves. ``applyM`` has no fallback and must be overridden
    before it can be used.
    """

    def setup(self, pc):
        """Prepare for applications; nothing to do by default."""

    def reset(self, pc):
        """Release data created by ``setup``; nothing to do by default."""

    def apply(self, pc, x, y):
        """Compute ``y`` from ``x``; subclasses must override this."""
        raise NotImplementedError

    def applyT(self, pc, x, y):
        """Transpose application; falls back to ``apply``.

        The fallback is only exact when the preconditioner is symmetric.
        """
        self.apply(pc, x, y)

    def applyS(self, pc, x, y):
        """Symmetric half-application; falls back to ``apply``."""
        self.apply(pc, x, y)

    def applySL(self, pc, x, y):
        """Left symmetric application; delegates to ``applyS``."""
        self.applyS(pc, x, y)

    def applySR(self, pc, x, y):
        """Right symmetric application; delegates to ``applyS``."""
        self.applyS(pc, x, y)

    def applyRich(self, pc, x, y, w, tols):
        """Richardson application: ignores ``w``/``tols``, one ``apply``."""
        self.apply(pc, x, y)

    def applyM(self, pc, x, y):
        """Apply to a block of vectors (Mat); subclasses must override."""
        raise NotImplementedError
28
class MyPCNone(BaseMyPC):
    """Identity preconditioner: the output is a plain copy of the input."""

    def apply(self, pc, x, y):
        """Copy ``x`` into ``y``."""
        x.copy(y)

    # The block (Mat) application is the same copy, so reuse the function.
    applyM = apply
34
class MyPCJacobi(BaseMyPC):
    """Jacobi preconditioner: scales by the inverse diagonal of ``P``.

    ``setup`` caches ``1/diag(P)`` in ``self.diag``, where ``P`` is the
    preconditioning matrix returned by ``pc.getOperators()``. ``reset``
    releases that cached vector.
    """

    def setup(self, pc):
        # Release a diagonal left over from an earlier setup that was not
        # followed by reset(), instead of silently dropping it.
        self.reset(pc)
        A, P = pc.getOperators()
        self.diag = P.getDiagonal()
        self.diag.reciprocal()

    def reset(self, pc):
        # Safe to call before setup() or more than once: in those cases
        # there is no cached diagonal to release.
        diag = getattr(self, 'diag', None)
        if diag is None:
            return
        diag.destroy()
        del self.diag

    def apply(self, pc, x, y):
        # y = x ./ diag(P); self.diag already holds the reciprocal.
        y.pointwiseMult(self.diag, x)

    def applyS(self, pc, x, y):
        # Symmetric half-application: y = sqrt(|1/diag(P)|) .* x.
        self.diag.copy(y)
        y.sqrtabs()
        y.pointwiseMult(y, x)

    def applyM(self, pc, x, y):
        # Block version of apply: copy X into Y, then left-scale Y
        # (scale its rows) by the reciprocal diagonal.
        x.copy(y)
        y.diagonalScale(L=self.diag)
52
class PC_PYTHON_CLASS(object):
    """Python context for a ``PC`` of type ``python`` used by the tests.

    Each callback bumps a counter in ``self.log`` (method name -> number of
    calls). Operation callbacks are then forwarded to ``self.impl``, an
    implementation object chosen by the ``impl`` option in
    ``setFromOptions`` (default ``MyPCNone``).
    """

    def __init__(self):
        self.log = {}
        self.impl = None

    def _log(self, method, *args):
        # The extra arguments only mirror each callback's signature at the
        # call sites; nothing but the call count is recorded.
        self.log[method] = self.log.get(method, 0) + 1

    @staticmethod
    def _check_args(pc, kind, *operands):
        # Sanity-check the objects handed to a callback: the PC first,
        # then each operand in order.
        assert isinstance(pc, PETSc.PC)
        for operand in operands:
            assert isinstance(operand, kind)

    def create(self, pc):
        self._log('create', pc)

    def destroy(self, pc):
        self._log('destroy')
        self.impl = None

    def reset(self, pc):
        # Only counted; the implementation object is not reset here.
        self._log('reset', pc)

    def view(self, pc, vw):
        self._log('view', pc, vw)
        self._check_args(pc, PETSc.Viewer, vw)

    def setFromOptions(self, pc):
        self._log('setFromOptions', pc)
        assert isinstance(pc, PETSc.PC)
        # Look the implementation class up by name among this module's
        # globals and instantiate it.
        impl_name = PETSc.Options(pc).getString('impl', 'MyPCNone')
        self.impl = globals()[impl_name]()

    def setUp(self, pc):
        self._log('setUp', pc)
        assert isinstance(pc, PETSc.PC)
        self.impl.setup(pc)

    def preSolve(self, pc, ksp, b, x):
        self._log('preSolve', pc, ksp, b, x)

    def postSolve(self, pc, ksp, b, x):
        self._log('postSolve', pc, ksp, b, x)

    def apply(self, pc, x, y):
        self._log('apply', pc, x, y)
        self._check_args(pc, PETSc.Vec, x, y)
        self.impl.apply(pc, x, y)

    def applySymmetricLeft(self, pc, x, y):
        self._log('applySymmetricLeft', pc, x, y)
        self._check_args(pc, PETSc.Vec, x, y)
        self.impl.applySL(pc, x, y)

    def applySymmetricRight(self, pc, x, y):
        self._log('applySymmetricRight', pc, x, y)
        self._check_args(pc, PETSc.Vec, x, y)
        self.impl.applySR(pc, x, y)

    def applyTranspose(self, pc, x, y):
        self._log('applyTranspose', pc, x, y)
        self._check_args(pc, PETSc.Vec, x, y)
        self.impl.applyT(pc, x, y)

    def matApply(self, pc, x, y):
        self._log('matApply', pc, x, y)
        self._check_args(pc, PETSc.Mat, x, y)
        self.impl.applyM(pc, x, y)

    def applyRichardson(self, pc, x, y, w, tols):
        self._log('applyRichardson', pc, x, y, w, tols)
        self._check_args(pc, PETSc.Vec, x, y, w)
        # Tolerances arrive as a 4-tuple.
        assert isinstance(tols, tuple)
        assert len(tols) == 4
        self.impl.applyRich(pc, x, y, w, tols)
127
128
class TestPCPYTHON(unittest.TestCase):
    """Exercise a ``PC`` of type ``python`` backed by ``PC_PYTHON_CLASS``.

    The context's call log checks how often PETSc dispatches to each
    callback. ``getrefcount`` checks the references held to the context
    while the PC is alive and after it is destroyed.
    """

    PC_TYPE = PETSc.PC.Type.PYTHON

    PC_PREFIX = 'test-'

    def setUp(self):
        self.pc = PETSc.PC()
        self.pc.create(PETSc.COMM_SELF)
        self.pc.setType(self.PC_TYPE)
        self.pc.prefix = self.PC_PREFIX
        opts = PETSc.Options(self.pc)
        assert opts.prefix == self.pc.prefix
        # Select the context class defined in this module, then remove the
        # option once setFromOptions() has consumed it.
        opts['pc_python_type'] = __name__ + '.PC_PYTHON_CLASS'
        self.pc.setFromOptions()
        del opts['pc_python_type']
        ctx = self._getCtx()
        assert ctx.log['create'] == 1
        assert ctx.log['setFromOptions'] == 1
        # The local, getrefcount's argument, and the PC's own reference.
        self.assertEqual(getrefcount(ctx), 3)

    def tearDown(self):
        ctx = self._getCtx()
        self.pc.destroy() # XXX
        self.pc = None
        assert ctx.log['destroy'] == 1
        # Once the PC is gone only the local and the argument remain.
        self.assertEqual(getrefcount(ctx), 2)

    def _prepare(self):
        # 3x3 AIJ operator with no inserted entries, shifted by 10 on the
        # diagonal; used as both operator and preconditioning matrix.
        A = PETSc.Mat().createAIJ([3,3], comm=PETSc.COMM_SELF)
        A.setUp()
        A.assemble()
        A.shift(10)
        x, y = A.createVecs()
        x.setRandom()
        self.pc.setOperators(A, A)
        # A pair of 3x5 dense matrices for the matApply path.
        X, Y = [PETSc.Mat().createDense([3,5], comm=PETSc.COMM_SELF).setUp()
                for _ in range(2)]
        X.assemble()
        Y.assemble()
        assert (A,A) == self.pc.getOperators()
        return A, x, y, X, Y

    def _getCtx(self):
        return self.pc.getPythonContext()

    def _applyMeth(self, meth):
        _, x, y, X, Y = self._prepare()
        method = getattr(self.pc, meth)
        if meth == 'matApply':
            method(X, Y)
            # Only the Mat pair was used; make the Vec pair match trivially.
            x.copy(y)
        else:
            method(x, y)
            # Only the Vec pair was used; make the Mat pair match trivially.
            X.copy(Y)
        log = self._getCtx().log
        if 'reset' in log:
            # Each reset in the test is followed by one setUp and one call.
            nreset, nsetup, nmeth = log['reset'], log['setUp'], log[meth]
            assert (nreset == nsetup)
            assert (nreset == nmeth)
        else:
            assert log['setUp'] == 1
            assert log[meth] == 1
        if isinstance(self._getCtx().impl, MyPCNone):
            # The identity preconditioner must reproduce its input.
            self.assertTrue(y.equal(x))
            self.assertTrue(Y.equal(X))

    def testApply(self):
        self._applyMeth('apply')

    def testApplySymmetricLeft(self):
        self._applyMeth('applySymmetricLeft')

    def testApplySymmetricRight(self):
        self._applyMeth('applySymmetricRight')

    def testApplyTranspose(self):
        self._applyMeth('applyTranspose')

    def testApplyMat(self):
        self._applyMeth('matApply')

    # NOTE(review): disabled test. _prepare() now returns (A, x, y, X, Y),
    # so the unpacking below must be updated before re-enabling it.
    ## def testApplyRichardson(self):
    ##     x, y = self._prepare()
    ##     w = x.duplicate()
    ##     tols = 0,0,0,0
    ##     self.pc.applyRichardson(x,y,w,tols)
    ##     assert self._getCtx().log['setUp'] == 1
    ##     assert self._getCtx().log['applyRichardson'] == 1

    # NOTE(review): disabled test. setUp registers
    # '<module>.PC_PYTHON_CLASS', so the expected factory string below
    # looks stale; confirm before re-enabling.
    ## def testView(self):
    ##     vw = PETSc.ViewerString(100, self.pc.comm)
    ##     self.pc.view(vw)
    ##     s = vw.getString()
    ##     assert 'python' in s
    ##     module = __name__
    ##     factory = 'self._getCtx()'
    ##     assert '.'.join([module, factory]) in s

    def testResetAndApply(self):
        self.pc.reset()
        self.testApply()
        self.pc.reset()
        self.testApply()
        self.pc.reset()

    def testKSPSolve(self):
        A, x, y, _, _ = self._prepare()
        ksp = PETSc.KSP().create(self.pc.comm)
        ksp.setType(PETSc.KSP.Type.PREONLY)
        # Attaching the PC to the KSP adds one PETSc reference to it.
        assert self.pc.getRefCount() == 1
        ksp.setPC(self.pc)
        assert self.pc.getRefCount() == 2
        # Two normal solves: one setUp, other hooks once per solve.
        for nsolves in (1, 2):
            ksp.solve(x, y)
            log = self._getCtx().log
            assert log['setUp'    ] == 1
            assert log['apply'    ] == nsolves
            assert log['preSolve' ] == nsolves
            assert log['postSolve'] == nsolves
        # Two transpose solves: still no additional setUp.
        for nsolves in (1, 2):
            ksp.solveTranspose(x, y)
            log = self._getCtx().log
            assert log['setUp'         ] == 1
            assert log['applyTranspose'] == nsolves
        del ksp # ksp.destroy()
        assert self.pc.getRefCount() == 1

    def testGetSetContext(self):
        ctx = self.pc.getPythonContext()
        # Re-installing the same context must not add a reference.
        self.pc.setPythonContext(ctx)
        self.assertEqual(getrefcount(ctx), 3)
        del ctx
264
265
class TestPCPYTHON2(TestPCPYTHON):
    """Run the ``TestPCPYTHON`` suite with the ``MyPCJacobi`` implementation."""

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        # Picked up by PC_PYTHON_CLASS.setFromOptions during the base setUp.
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super(TestPCPYTHON2, self).setUp()
            clsname = type(self._getCtx().impl).__name__
            assert clsname == OptDB['impl']
        finally:
            # Always remove the option, even when setUp fails. tearDown is
            # not run after a failed setUp, so otherwise the option would
            # stay set and could leak into later test cases.
            del OptDB['impl']
274
class TestPCPYTHON3(TestPCPYTHON):
    """Run the ``TestPCPYTHON`` suite on a PC built with ``createPython``."""

    def setUp(self):
        self.pc = PETSc.PC()
        context = PC_PYTHON_CLASS()
        # Install the context object directly instead of naming its class
        # through the options database.
        self.pc.createPython(context, comm=PETSc.COMM_SELF)
        self.pc.prefix = self.PC_PREFIX
        self.pc.setFromOptions()
        log = self._getCtx().log
        assert log['create'] == 1
        assert log['setFromOptions'] == 1
284
class TestPCPYTHON4(TestPCPYTHON3):
    """Run the ``TestPCPYTHON3`` suite with the ``MyPCJacobi`` implementation."""

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        # Picked up by PC_PYTHON_CLASS.setFromOptions during the base setUp.
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super(TestPCPYTHON4, self).setUp()
            clsname = type(self._getCtx().impl).__name__
            assert clsname == OptDB['impl']
        finally:
            # Always remove the option, even when setUp fails. tearDown is
            # not run after a failed setUp, so otherwise the option would
            # stay set and could leak into later test cases.
            del OptDB['impl']
293
294# --------------------------------------------------------------------
295
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main(module='__main__')
298