xref: /petsc/src/binding/petsc4py/test/test_pc_py.py (revision b69d2765e9c8cde17308f61ee677dcd992b8a9cf)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
9
class BaseMyPC:
    """Base for the pure-Python preconditioner implementations below.

    Subclasses must provide ``apply`` (vector) and ``applyM`` (matrix).
    The transpose, symmetric and Richardson variants fall back to
    ``apply``; the left/right symmetric variants go through ``applyS``.
    """

    def setup(self, pc):
        """No-op setup hook."""

    def reset(self, pc):
        """No-op reset hook."""

    def apply(self, pc, x, y):
        raise NotImplementedError

    def applyT(self, pc, x, y):
        """Transpose application: defaults to the plain application."""
        self.apply(pc, x, y)

    def applyS(self, pc, x, y):
        """Symmetric application: defaults to the plain application."""
        self.apply(pc, x, y)

    def applySL(self, pc, x, y):
        """Left symmetric application: routed through applyS."""
        self.applyS(pc, x, y)

    def applySR(self, pc, x, y):
        """Right symmetric application: routed through applyS."""
        self.applyS(pc, x, y)

    def applyRich(self, pc, x, y, w, tols):
        """Richardson application: the work vector *w* and *tols* are ignored."""
        self.apply(pc, x, y)

    def applyM(self, pc, x, y):
        raise NotImplementedError
37
38
class MyPCNone(BaseMyPC):
    """Identity preconditioner: the output is a plain copy of the input."""

    def apply(self, pc, x, y):
        """Copy the input *x* into the output *y*."""
        x.copy(y)

    # The matrix variant performs exactly the same copy, so share the body.
    applyM = apply
45
46
class MyPCJacobi(BaseMyPC):
    """Jacobi (diagonal) preconditioner built from the diagonal of P."""

    def setup(self, pc):
        """Store the reciprocal of the preconditioning matrix's diagonal.

        Any diagonal left over from a previous setup is released first so
        repeated setups do not keep stale vectors around.
        """
        self.reset(pc)
        _, P = pc.getOperators()
        self.diag = P.getDiagonal()
        # Keep 1/d so every application is a pointwise product.
        self.diag.reciprocal()

    def reset(self, pc):
        """Release the stored diagonal.

        Safe to call before setup() or more than once: only a diagonal
        that actually exists is destroyed.
        """
        diag = getattr(self, 'diag', None)
        if diag is not None:
            diag.destroy()
            del self.diag

    def apply(self, pc, x, y):
        # y = (1/d) * x, entrywise
        y.pointwiseMult(self.diag, x)

    def applyS(self, pc, x, y):
        # y = sqrt(|1/d|) * x, entrywise: build the scaling in y, then apply it
        self.diag.copy(y)
        y.sqrtabs()
        y.pointwiseMult(y, x)

    def applyM(self, pc, x, y):
        # Y = X left-scaled by the stored reciprocal diagonal
        x.copy(y)
        y.diagonalScale(L=self.diag)
68
69
class PC_PYTHON_CLASS:
    """Python context for a PETSc PC of type 'python'.

    Every hook bumps a per-hook call counter in ``self.log``. Application
    and setup hooks then forward to the implementation object held in
    ``self.impl``, which ``setFromOptions`` instantiates.
    """

    def __init__(self):
        self.impl = None  # set by setFromOptions(), cleared by destroy()
        self.log = {}     # hook name -> number of calls

    def _log(self, method, *args):
        # The extra positional arguments are accepted so every hook can pass
        # its arguments uniformly, but only the counter is recorded.
        self.log[method] = self.log.get(method, 0) + 1

    def _forward(self, hook, target, pc, *args):
        # Count *hook*, then call impl.<target> with the same arguments.
        self._log(hook, pc, *args)
        getattr(self.impl, target)(pc, *args)

    def create(self, pc):
        self._log('create', pc)

    def destroy(self, pc):
        self._log('destroy')
        self.impl = None

    def reset(self, pc):
        # Counted only; not forwarded to impl.
        self._log('reset', pc)

    def view(self, pc, vw):
        self._log('view', pc, vw)

    def setFromOptions(self, pc):
        self._log('setFromOptions', pc)
        # The implementation class is named by the 'impl' option (read with
        # the PC's options prefix) and looked up among this module's globals.
        impl_name = PETSc.Options(pc).getString('impl', 'MyPCNone')
        self.impl = globals()[impl_name]()

    def setUp(self, pc):
        self._forward('setUp', 'setup', pc)

    def preSolve(self, pc, ksp, b, x):
        self._log('preSolve', pc, ksp, b, x)

    def postSolve(self, pc, ksp, b, x):
        self._log('postSolve', pc, ksp, b, x)

    def apply(self, pc, x, y):
        self._forward('apply', 'apply', pc, x, y)

    def applySymmetricLeft(self, pc, x, y):
        self._forward('applySymmetricLeft', 'applySL', pc, x, y)

    def applySymmetricRight(self, pc, x, y):
        self._forward('applySymmetricRight', 'applySR', pc, x, y)

    def applyTranspose(self, pc, x, y):
        self._forward('applyTranspose', 'applyT', pc, x, y)

    def matApply(self, pc, x, y):
        self._forward('matApply', 'applyM', pc, x, y)

    def applyRichardson(self, pc, x, y, w, tols):
        self._forward('applyRichardson', 'applyRich', pc, x, y, w, tols)
132
133
class TestPCPYTHON(unittest.TestCase):
    """Tests for a PC of type PC_TYPE backed by PC_PYTHON_CLASS.

    setUp selects the Python context through the 'pc_python_type' option
    under PC_PREFIX; subclasses rerun every test with other setups.
    """

    PC_TYPE = PETSc.PC.Type.PYTHON
    PC_PREFIX = 'test-'

    def setUp(self):
        pc = self.pc = PETSc.PC()
        pc.create(PETSc.COMM_SELF)
        pc.setType(self.PC_TYPE)
        module = __name__
        factory = 'PC_PYTHON_CLASS'
        pc.prefix = self.PC_PREFIX
        OptDB = PETSc.Options(pc)
        self.assertEqual(OptDB.prefix, pc.prefix)
        OptDB['pc_python_type'] = f'{module}.{factory}'
        try:
            pc.setFromOptions()
        finally:
            # Remove the option even if setFromOptions fails, so it cannot
            # linger in the global options database for later tests.
            del OptDB['pc_python_type']
        log = self._getCtx().log
        self.assertEqual(log['create'], 1)
        self.assertEqual(log['setFromOptions'], 1)
        # Exactly one reference besides the temporary passed to getrefcount.
        self.assertEqual(getrefcount(self._getCtx()), 2)

    def testGetType(self):
        ctx = self.pc.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.pc.getPythonType(), pytype)

    def tearDown(self):
        ctx = self._getCtx()
        # XXX: destroyed explicitly so the context's destroy hook has run
        # before the check below.
        self.pc.destroy()
        self.pc = None
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)

    def _prepare(self):
        """Attach operators to the PC and return ``(A, x, y, X, Y)``.

        A is an empty 3x3 AIJ matrix shifted by 10 (used as both operators),
        x is a random vector and y a compatible vector, and X, Y are
        assembled 3x5 dense matrices with no entries set.
        """
        A = PETSc.Mat().createAIJ([3, 3], comm=PETSc.COMM_SELF)
        A.setUp()
        A.assemble()
        A.shift(10)
        x, y = A.createVecs()
        x.setRandom()
        self.pc.setOperators(A, A)
        X = PETSc.Mat().createDense([3, 5], comm=PETSc.COMM_SELF).setUp()
        X.assemble()
        Y = PETSc.Mat().createDense([3, 5], comm=PETSc.COMM_SELF).setUp()
        Y.assemble()
        self.assertEqual((A, A), self.pc.getOperators())
        return A, x, y, X, Y

    def _getCtx(self):
        return self.pc.getPythonContext()

    def _applyMeth(self, meth):
        """Call PC method *meth* once and check the context's call counters.

        'matApply' works on the dense pair (X, Y); every other method on the
        vector pair (x, y). The pair not used by *meth* is filled by a plain
        copy so the identity check for MyPCNone covers both pairs.
        """
        A, x, y, X, Y = self._prepare()
        if meth == 'matApply':
            getattr(self.pc, meth)(X, Y)
            x.copy(y)
        else:
            getattr(self.pc, meth)(x, y)
            X.copy(Y)
        log = self._getCtx().log
        if 'reset' not in log:
            self.assertEqual(log['setUp'], 1)
            self.assertEqual(log[meth], 1)
        else:
            # Expected: setup and the method ran once per reset.
            nreset = log['reset']
            self.assertEqual(nreset, log['setUp'])
            self.assertEqual(nreset, log[meth])
        if isinstance(self._getCtx().impl, MyPCNone):
            self.assertTrue(y.equal(x))
            self.assertTrue(Y.equal(X))

    def testApply(self):
        self._applyMeth('apply')

    def testApplySymmetricLeft(self):
        self._applyMeth('applySymmetricLeft')

    def testApplySymmetricRight(self):
        self._applyMeth('applySymmetricRight')

    def testApplyTranspose(self):
        self._applyMeth('applyTranspose')

    def testApplyMat(self):
        self._applyMeth('matApply')

    # NOTE(review): the two tests below are disabled (commented out) and
    # never run; confirm they pass before re-enabling them.
    ## def testApplyRichardson(self):
    ##     A, x, y, _, _ = self._prepare()
    ##     w = x.duplicate()
    ##     tols = 0,0,0,0
    ##     self.pc.applyRichardson(x,y,w,tols)
    ##     assert self._getCtx().log['setUp'] == 1
    ##     assert self._getCtx().log['applyRichardson'] == 1

    ## def testView(self):
    ##     vw = PETSc.ViewerString(100, self.pc.comm)
    ##     self.pc.view(vw)
    ##     s = vw.getString()
    ##     assert 'python' in s
    ##     module = __name__
    ##     factory = 'PC_PYTHON_CLASS'
    ##     assert '.'.join([module, factory]) in s

    def testResetAndApply(self):
        self.pc.reset()
        self.testApply()
        self.pc.reset()
        self.testApply()
        self.pc.reset()

    def testKSPSolve(self):
        A, x, y, _, _ = self._prepare()
        ksp = PETSc.KSP().create(self.pc.comm)
        ksp.setType(PETSc.KSP.Type.PREONLY)
        self.assertEqual(self.pc.getRefCount(), 1)
        ksp.setPC(self.pc)
        self.assertEqual(self.pc.getRefCount(), 2)
        log = self._getCtx().log
        # normal ksp solve, twice
        ksp.solve(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['apply'], 1)
        self.assertEqual(log['preSolve'], 1)
        self.assertEqual(log['postSolve'], 1)
        ksp.solve(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['apply'], 2)
        self.assertEqual(log['preSolve'], 2)
        self.assertEqual(log['postSolve'], 2)
        # transpose ksp solve, twice
        ksp.solveTranspose(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['applyTranspose'], 1)
        ksp.solveTranspose(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['applyTranspose'], 2)
        del ksp  # ksp.destroy()
        PETSc.garbage_cleanup()
        self.assertEqual(self.pc.getRefCount(), 1)

    def testGetSetContext(self):
        self.pc.setPythonContext(self._getCtx())
        self.assertEqual(getrefcount(self.pc.getPythonContext()), 2)
276
277
class TestPCPYTHON2(TestPCPYTHON):
    """Rerun the TestPCPYTHON tests with the 'impl' option set to MyPCJacobi."""

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super().setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, OptDB['impl'])
        finally:
            # Always remove the option: left behind, it would make every
            # later test using this prefix silently pick MyPCJacobi.
            del OptDB['impl']
286
287
class TestPCPYTHON3(TestPCPYTHON):
    """Rerun the TestPCPYTHON tests on a PC built with PC.createPython()."""

    def setUp(self):
        # Hand the context object to PETSc directly instead of naming its
        # class through the 'pc_python_type' option.
        context = PC_PYTHON_CLASS()
        self.pc = PETSc.PC()
        self.pc.createPython(context, comm=PETSc.COMM_SELF)
        self.pc.prefix = self.PC_PREFIX
        self.pc.setFromOptions()
        counts = self._getCtx().log
        self.assertTrue(counts['create'] == 1)
        self.assertTrue(counts['setFromOptions'] == 1)
297
298
class TestPCPYTHON4(TestPCPYTHON3):
    """Rerun the TestPCPYTHON3 tests with the 'impl' option set to MyPCJacobi."""

    def setUp(self):
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = 'MyPCJacobi'
        try:
            super().setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, OptDB['impl'])
        finally:
            # Always remove the option: left behind, it would make every
            # later test using this prefix silently pick MyPCJacobi.
            del OptDB['impl']
307
308
309# --------------------------------------------------------------------
310
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()
313