xref: /petsc/src/binding/petsc4py/test/test_pc_py.py (revision 9e48eb9f905fbc01704defd592b7e78e540d2a26)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
9
class BaseMyPC:
    """Common base for the toy preconditioner implementations in this file.

    ``setup`` and ``reset`` do nothing here.  ``apply`` and ``applyM`` are
    abstract (they raise ``NotImplementedError``).  The remaining hooks are
    routed to another hook of the same object, so a subclass that only
    provides ``apply`` gets the transpose, symmetric and Richardson variants
    through it.
    """

    def setup(self, pc):
        """Hook for building implementation data; a no-op in the base class."""

    def reset(self, pc):
        """Hook for releasing implementation data; a no-op in the base class."""

    def apply(self, pc, x, y):
        """Apply the preconditioner to ``x``, writing into ``y`` (abstract)."""
        raise NotImplementedError

    def applyT(self, pc, x, y):
        """Transpose application; forwarded unchanged to ``apply``."""
        self.apply(pc, x, y)

    def applyS(self, pc, x, y):
        """Symmetric application; forwarded unchanged to ``apply``."""
        self.apply(pc, x, y)

    def applySL(self, pc, x, y):
        """Left symmetric application; forwarded to ``applyS``."""
        self.applyS(pc, x, y)

    def applySR(self, pc, x, y):
        """Right symmetric application; forwarded to ``applyS``."""
        self.applyS(pc, x, y)

    def applyRich(self, pc, x, y, w, tols):
        """Richardson application; ``w`` and ``tols`` are ignored and the
        call is forwarded to ``apply``."""
        self.apply(pc, x, y)

    def applyM(self, pc, x, y):
        """Apply the preconditioner to a block of columns (abstract)."""
        raise NotImplementedError
37
38
class MyPCNone(BaseMyPC):
    """Identity preconditioner: the output is simply a copy of the input."""

    def apply(self, pc, x, y):
        """Copy ``x`` into ``y``; ``pc`` is not used."""
        x.copy(y)

    def applyM(self, pc, x, y):
        """Block variant of ``apply``: copy ``x`` into ``y``."""
        x.copy(y)
45
46
class MyPCJacobi(BaseMyPC):
    """Jacobi-style preconditioner built from the diagonal of the
    preconditioning matrix.

    ``self.diag`` exists only between ``setup`` and ``reset`` and holds the
    reciprocal of that diagonal.
    """

    def setup(self, pc):
        """Store the reciprocal diagonal of the PC's second operator.

        A diagonal left over from an earlier ``setup`` is destroyed first, so
        setting up repeatedly without an intervening ``reset`` does not just
        drop the old vector.
        """
        self.reset(pc)
        _, P = pc.getOperators()  # only the preconditioning matrix is used
        diag = P.getDiagonal()
        diag.reciprocal()
        self.diag = diag

    def reset(self, pc):
        """Destroy the stored diagonal, if there is one.

        Safe to call before ``setup`` or several times in a row; in those
        cases there is nothing to release and the call does nothing.
        """
        diag = getattr(self, 'diag', None)
        if diag is None:
            return
        del self.diag
        diag.destroy()

    def apply(self, pc, x, y):
        """Set ``y`` to the entrywise product of the stored diagonal and ``x``."""
        y.pointwiseMult(self.diag, x)

    def applyS(self, pc, x, y):
        """Symmetric half-application: ``y = sqrtabs(diag) * x`` entrywise."""
        # y is used as scratch for sqrtabs(diag) before being multiplied by
        # x, so self.diag itself is left untouched.
        self.diag.copy(y)
        y.sqrtabs()
        y.pointwiseMult(y, x)

    def applyM(self, pc, x, y):
        """Block variant: copy ``x`` into ``y``, then scale ``y`` by the stored
        diagonal via ``diagonalScale(L=...)``."""
        x.copy(y)
        y.diagonalScale(L=self.diag)
68
69
class PC_PYTHON_CLASS:
    """Python context used by the tests for a PC of type ``python``.

    Each hook records that it was called by incrementing a counter in
    ``self.log`` (hook name -> number of calls).  The set-up and apply hooks
    additionally delegate to ``self.impl``, the implementation object that
    ``setFromOptions`` instantiates.
    """

    def __init__(self):
        self.log = {}     # hook name -> call count; missing key means "never called"
        self.impl = None  # created in setFromOptions, dropped in destroy

    def _log(self, method, *args):
        """Count one call of hook ``method``; extra arguments are ignored."""
        self.log[method] = self.log.get(method, 0) + 1

    def create(self, pc):
        """Record the ``create`` hook."""
        self._log('create', pc)

    def destroy(self, pc):
        """Record the ``destroy`` hook and drop the implementation object."""
        self._log('destroy')
        self.impl = None

    def reset(self, pc):
        """Record the ``reset`` hook (the implementation is not notified)."""
        self._log('reset', pc)

    def view(self, pc, vw):
        """Record the ``view`` hook."""
        self._log('view', pc, vw)

    def setFromOptions(self, pc):
        """Record the hook and instantiate the implementation class.

        The class name is read from the ``impl`` option of ``pc``'s options
        (default ``'MyPCNone'``) and looked up among this module's globals.
        """
        self._log('setFromOptions', pc)
        impl_name = PETSc.Options(pc).getString('impl', 'MyPCNone')
        self.impl = globals()[impl_name]()

    def setUp(self, pc):
        """Record the hook, then let the implementation set itself up."""
        self._log('setUp', pc)
        self.impl.setup(pc)

    def preSolve(self, pc, ksp, b, x):
        """Record the ``preSolve`` hook."""
        self._log('preSolve', pc, ksp, b, x)

    def postSolve(self, pc, ksp, b, x):
        """Record the ``postSolve`` hook."""
        self._log('postSolve', pc, ksp, b, x)

    def apply(self, pc, x, y):
        """Record the hook and delegate to ``impl.apply``."""
        self._log('apply', pc, x, y)
        self.impl.apply(pc, x, y)

    def applySymmetricLeft(self, pc, x, y):
        """Record the hook and delegate to ``impl.applySL``."""
        self._log('applySymmetricLeft', pc, x, y)
        self.impl.applySL(pc, x, y)

    def applySymmetricRight(self, pc, x, y):
        """Record the hook and delegate to ``impl.applySR``."""
        self._log('applySymmetricRight', pc, x, y)
        self.impl.applySR(pc, x, y)

    def applyTranspose(self, pc, x, y):
        """Record the hook and delegate to ``impl.applyT``."""
        self._log('applyTranspose', pc, x, y)
        self.impl.applyT(pc, x, y)

    def matApply(self, pc, x, y):
        """Record the hook and delegate to ``impl.applyM``."""
        self._log('matApply', pc, x, y)
        self.impl.applyM(pc, x, y)

    def applyRichardson(self, pc, x, y, w, tols):
        """Record the hook and delegate to ``impl.applyRich``."""
        self._log('applyRichardson', pc, x, y, w, tols)
        self.impl.applyRich(pc, x, y, w, tols)
132
133
class TestPCPYTHON(unittest.TestCase):
    """Exercise a ``python``-type PC whose context class is selected through
    the ``pc_python_type`` option.

    The context is a ``PC_PYTHON_CLASS`` instance; the tests inspect its
    ``log`` counters to check which hooks PETSc invoked and how often.
    Subclasses override ``setUp`` to vary how the PC is built.
    """

    PC_TYPE = PETSc.PC.Type.PYTHON
    PC_PREFIX = 'test-'

    def setUp(self):
        """Create the PC and let ``setFromOptions`` build its Python context."""
        pc = self.pc = PETSc.PC()
        pc.create(PETSc.COMM_SELF)
        pc.setType(self.PC_TYPE)
        pc.prefix = self.PC_PREFIX
        OptDB = PETSc.Options(pc)
        self.assertEqual(OptDB.prefix, pc.prefix)
        OptDB['pc_python_type'] = f'{__name__}.PC_PYTHON_CLASS'
        try:
            pc.setFromOptions()
        finally:
            # Remove the entry even when setFromOptions fails, so the option
            # cannot leak into tests that run afterwards.
            del OptDB['pc_python_type']
        ctx = self._getCtx()
        self.assertEqual(ctx.log['create'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        # Expected holders: the local name, the getrefcount() argument and
        # (presumably) the PC itself -- tearDown expects one fewer once the
        # PC is destroyed.
        self.assertEqual(getrefcount(ctx), 3)

    def testGetType(self):
        """``getPythonType`` reports ``<module>.<class>`` of the context."""
        ctx = self.pc.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.pc.getPythonType(), pytype)

    def tearDown(self):
        """Destroy the PC and check the context was released exactly once."""
        ctx = self._getCtx()
        self.pc.destroy()  # XXX
        self.pc = None
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)
        # Only the local name and the getrefcount() argument should remain.
        self.assertEqual(getrefcount(ctx), 2)

    def _prepare(self):
        """Attach operators to the PC and build work vectors and matrices.

        Returns ``(A, x, y, X, Y)``: ``A`` is an initially empty 3x3 AIJ
        matrix shifted by 10 and used as both operators, ``x``/``y`` are
        vectors created from ``A`` (``x`` filled with random values), and
        ``X``/``Y`` are assembled 3x5 dense matrices.
        """
        A = PETSc.Mat().createAIJ([3, 3], comm=PETSc.COMM_SELF)
        A.setUp()
        A.assemble()
        A.shift(10)
        x, y = A.createVecs()
        x.setRandom()
        self.pc.setOperators(A, A)
        X = PETSc.Mat().createDense([3, 5], comm=PETSc.COMM_SELF).setUp()
        X.assemble()
        Y = PETSc.Mat().createDense([3, 5], comm=PETSc.COMM_SELF).setUp()
        Y.assemble()
        self.assertEqual((A, A), self.pc.getOperators())
        return A, x, y, X, Y

    def _getCtx(self):
        """Return the Python context currently attached to the PC."""
        return self.pc.getPythonContext()

    def _applyMeth(self, meth):
        """Call PC method ``meth`` once and verify the context's hook counts.

        ``matApply`` is driven with the dense matrices, every other method
        with the vectors.  The pair that was *not* passed to the PC is copied
        by hand so the identity checks at the end hold for both pairs.
        """
        _, x, y, X, Y = self._prepare()
        apply = getattr(self.pc, meth)
        if meth == 'matApply':
            apply(X, Y)
            x.copy(y)
        else:
            apply(x, y)
            X.copy(Y)
        log = self._getCtx().log
        if 'reset' not in log:
            self.assertEqual(log['setUp'], 1)
            self.assertEqual(log[meth], 1)
        else:
            # Reached via testResetAndApply: every reset must have been
            # followed by exactly one set-up and one application.
            nreset = log['reset']
            self.assertEqual(log['setUp'], nreset)
            self.assertEqual(log[meth], nreset)
        if isinstance(self._getCtx().impl, MyPCNone):
            # The identity implementation must reproduce its input.
            self.assertTrue(y.equal(x))
            self.assertTrue(Y.equal(X))

    def testApply(self):
        self._applyMeth('apply')

    def testApplySymmetricLeft(self):
        self._applyMeth('applySymmetricLeft')

    def testApplySymmetricRight(self):
        self._applyMeth('applySymmetricRight')

    def testApplyTranspose(self):
        self._applyMeth('applyTranspose')

    def testApplyMat(self):
        self._applyMeth('matApply')

    # NOTE(review): the disabled test below unpacks two values from
    # _prepare(), which now returns five; it needs updating before it can
    # be re-enabled.
    ## def testApplyRichardson(self):
    ##     x, y = self._prepare()
    ##     w = x.duplicate()
    ##     tols = 0,0,0,0
    ##     self.pc.applyRichardson(x,y,w,tols)
    ##     assert self._getCtx().log['setUp'] == 1
    ##     assert self._getCtx().log['applyRichardson'] == 1

    ## def testView(self):
    ##     vw = PETSc.ViewerString(100, self.pc.comm)
    ##     self.pc.view(vw)
    ##     s = vw.getString()
    ##     assert 'python' in s
    ##     module = __name__
    ##     factory = 'self._getCtx()'
    ##     assert '.'.join([module, factory]) in s

    def testResetAndApply(self):
        """Interleave resets with applications; the count checks live in
        ``_applyMeth``."""
        self.pc.reset()
        self.testApply()
        self.pc.reset()
        self.testApply()
        self.pc.reset()

    def testKSPSolve(self):
        """Drive the PC through a ``preonly`` KSP and count the hook calls."""
        A, x, y, _, _ = self._prepare()
        ksp = PETSc.KSP().create(self.pc.comm)
        ksp.setType(PETSc.KSP.Type.PREONLY)
        self.assertEqual(self.pc.getRefCount(), 1)
        ksp.setPC(self.pc)
        self.assertEqual(self.pc.getRefCount(), 2)
        log = self._getCtx().log
        # normal ksp solve, twice: set-up happens once, the rest per solve
        for nsolve in (1, 2):
            ksp.solve(x, y)
            self.assertEqual(log['setUp'], 1)
            self.assertEqual(log['apply'], nsolve)
            self.assertEqual(log['preSolve'], nsolve)
            self.assertEqual(log['postSolve'], nsolve)
        # transpose ksp solve, twice
        for nsolve in (1, 2):
            ksp.solveTranspose(x, y)
            self.assertEqual(log['setUp'], 1)
            self.assertEqual(log['applyTranspose'], nsolve)
        del ksp  # ksp.destroy()
        PETSc.garbage_cleanup()
        # With the KSP gone the PC must be back to a single reference.
        self.assertEqual(self.pc.getRefCount(), 1)

    def testGetSetContext(self):
        """Re-installing the current context must not change its refcount."""
        ctx = self.pc.getPythonContext()
        self.pc.setPythonContext(ctx)
        self.assertEqual(getrefcount(ctx), 3)
        del ctx
280
281
class TestPCPYTHON2(TestPCPYTHON):
    """Run the ``TestPCPYTHON`` cases with ``MyPCJacobi`` as implementation."""

    def setUp(self):
        """Select ``MyPCJacobi`` through the ``impl`` option, then set up as
        the parent class does."""
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super().setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, OptDB['impl'])
        finally:
            # Always remove the option, even if set-up or the check fails,
            # so it cannot change the implementation used by later tests.
            del OptDB['impl']
290
291
class TestPCPYTHON3(TestPCPYTHON):
    """Run the ``TestPCPYTHON`` cases on a PC built with ``createPython``
    from an already constructed context object."""

    def setUp(self):
        """Create the PC directly around a ``PC_PYTHON_CLASS`` instance."""
        context = PC_PYTHON_CLASS()
        self.pc = PETSc.PC()
        self.pc.createPython(context, comm=PETSc.COMM_SELF)
        self.pc.prefix = self.PC_PREFIX
        self.pc.setFromOptions()
        # Both hooks must have fired exactly once by now.
        for hook in ('create', 'setFromOptions'):
            self.assertTrue(self._getCtx().log[hook] == 1)
301
302
class TestPCPYTHON4(TestPCPYTHON3):
    """Run the ``TestPCPYTHON3`` cases with ``MyPCJacobi`` as implementation."""

    def setUp(self):
        """Select ``MyPCJacobi`` through the ``impl`` option, then set up as
        the parent class does."""
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super().setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, OptDB['impl'])
        finally:
            # Always remove the option, even if set-up or the check fails,
            # so it cannot change the implementation used by later tests.
            del OptDB['impl']
311
312
313# --------------------------------------------------------------------
314
# Run the test cases when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
317