# xref: /petsc/src/binding/petsc4py/test/test_object.py (revision 34c645fd3b0199e05bec2fcc32d3597bfeb7f4f2)
import unittest

from petsc4py import PETSc
import numpy
5
# --------------------------------------------------------------------
7
8
class BaseTestObject:
    """Mixin of generic checks run against one kind of PETSc object.

    Concrete test cases combine this mixin with ``unittest.TestCase`` and
    configure it through class attributes:

    * ``CLASS`` -- type instantiated with no arguments in ``setUp``;
    * ``FACTORY`` -- name of the method called on the fresh instance;
    * ``TARGS`` / ``KARGS`` -- positional and keyword arguments passed to
      that method.

    ``BUILD`` is not read by any method of this class.
    """

    CLASS, FACTORY = None, None
    TARGS, KARGS = (), {}
    BUILD = None

    # NOTE: test methods are described with comments instead of docstrings
    # because unittest prints a test's docstring in its verbose report.

    def setUp(self):
        # Instantiate CLASS, then call the configured factory method on it.
        self.obj = self.CLASS()
        factory = getattr(self.obj, self.FACTORY)
        factory(*self.TARGS, **self.KARGS)
        # Fall back to a plain create() if the object is still falsy.
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        # Drop our reference before asking PETSc to clean up garbage.
        self.obj = None
        PETSc.garbage_cleanup()

    def testTypeRegistry(self):
        # The class id must map back to the Python type in the registry.
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        typeobj = self.CLASS
        # DMDA instances are looked up under PETSc.DM.
        if isinstance(self.obj, PETSc.DMDA):
            typeobj = PETSc.DM
        self.assertIs(type_reg[classid], typeobj)

    def testLogClass(self):
        # The logging class looked up by name must carry the same class id.
        name = self.CLASS.__name__
        # DMDA is looked up under the name 'DM'.
        if name == 'DMDA':
            name = 'DM'
        logcls = PETSc.Log.Class(name)
        classid = self.obj.getClassId()
        self.assertEqual(logcls.id, classid)

    def testClass(self):
        # The object must be exactly of type CLASS, not a subclass.
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testId(self):
        # getId() is positive and agrees with the ``id`` property.
        oid = self.obj.getId()
        self.assertGreater(oid, 0)
        self.assertEqual(self.obj.id, oid)

    def testNonZero(self):
        # A set-up object is truthy.
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        # destroy() turns a truthy object into a falsy one.
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        self.assertFalse(bool(self.obj))
        ## self.assertRaises(PETSc.Error, self.obj.destroy)
        ## self.assertTrue(self.obj.this is this)

    def testOptions(self):
        # --- options prefix: set, replace, append, reset ---
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        self.obj.appendOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2 + prefix1)
        self.obj.setOptionsPrefix(None)
        self.assertIsNone(self.obj.getOptionsPrefix())
        self.obj.setFromOptions()

        # --- options handler: counts its calls in an object attribute ---
        def opts_handler(obj):
            n = obj.getAttr('opts_handler_called')
            obj.setAttr('opts_handler_called', n + 1)
            # The handler must receive the very object under test.
            self.assertEqual(type(self.obj), type(obj))
            self.assertEqual(self.obj.klass, obj.klass)
            self.assertEqual(self.obj.type, obj.type)
            self.assertEqual(self.obj.id, obj.id)

        # Classes for which the handler call count is not checked.
        missing = frozenset(
            (
                'AO',
                'DMLabel',
                'PetscDualSpace',
                'IS',
                'ISLocalToGlobalMapping',
                'MatPartitioning',
                'MatNullSpace',
                'PetscRandom',
                'PetscViewer',
            )
        )
        # Run twice: re-installing the handler must not make it fire twice.
        for _ in range(2):
            self.obj.setAttr('opts_handler_called', 0)
            self.obj.setOptionsHandler(opts_handler)
            self.obj.setFromOptions()
            if self.obj.klass not in missing:
                self.assertEqual(self.obj.getAttr('opts_handler_called'), 1)

        # Clearing the handler with None stops further calls.
        self.obj.setAttr('opts_handler_called', 0)
        self.obj.setOptionsHandler(None)
        self.obj.setFromOptions()
        self.assertFalse(self.obj.getAttr('opts_handler_called'))

        # Same expectation after destroying all handlers.
        self.obj.destroyOptionsHandlers()
        self.obj.setFromOptions()
        self.assertFalse(self.obj.getAttr('opts_handler_called'))

    def testName(self):
        # setName()/getName() round-trip, then restore the original name.
        oldname = self.obj.getName()
        newname = f'{oldname}-{oldname}'
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        # The communicator is a PETSc.Comm equal to COMM_SELF or COMM_WORLD.
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        # incRef()/decRef() move the reference count by exactly one.
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        # Dropping the last reference leaves a falsy object.
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        # Both handle flavours are non-zero while the object is alive.
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # They are only compared when they share a sign.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        # After destroy() both are falsy.
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        import copy

        # Classes whose deep copy raises NotImplementedError are not checked.
        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            return
        self.assertEqual(myobj.getRefCount(), 1)
        # compose() takes a reference; query() returns the same object.
        self.obj.compose('myobj', myobj)
        self.assertIs(type(self.obj.query('myobj')), self.CLASS)
        self.assertEqual(self.obj.query('myobj'), myobj)
        self.assertEqual(myobj.getRefCount(), 2)
        # Composing None under the same key releases that reference.
        self.obj.compose('myobj', None)
        self.assertEqual(myobj.getRefCount(), 1)
        self.assertIsNone(self.obj.query('myobj'))
        myobj.destroy()

    def testProperties(self):
        # Each property mirrors the corresponding getter.
        self.assertEqual(self.obj.getClassId(), self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(), self.obj.type)
        self.assertEqual(self.obj.getName(), self.obj.name)
        self.assertEqual(self.obj.getComm(), self.obj.comm)
        self.assertEqual(self.obj.getRefCount(), self.obj.refcount)

    def testShallowCopy(self):
        import copy

        # A shallow copy is a distinct Python object that compares equal to
        # the original and adds one to the reference count.
        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        # Deleting the copy gives the reference back.
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        import copy

        rc = self.obj.getRefCount()
        # Classes whose deep copy raises NotImplementedError are not checked.
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            return
        # A deep copy compares unequal to the original, leaves the
        # original's reference count alone and starts with a count of 1.
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
        del obj

    def testStateInspection(self):
        # stateIncrease() raises the state; stateSet() overrides it.
        state = self.obj.stateGet()
        self.obj.stateIncrease()
        self.assertLess(state, self.obj.stateGet())
        self.obj.stateSet(0)
        self.assertEqual(self.obj.stateGet(), 0)
        # Restore the state observed at the start of the test.
        self.obj.stateSet(state)
        self.assertEqual(self.obj.stateGet(), state)
204
205
# --------------------------------------------------------------------
207
208
class TestObjectRandom(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.Random built by create()."""

    CLASS = PETSc.Random
    FACTORY = 'create'
212
213
class TestObjectViewer(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.Viewer built by create()."""

    CLASS = PETSc.Viewer
    FACTORY = 'create'
217
218
class TestObjectIS(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.IS.

    The object is built with ``createGeneral([])``.
    """

    CLASS = PETSc.IS
    FACTORY = 'createGeneral'
    TARGS = ([],)
223
224
class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.LGMap.

    The object is built with ``create([])``.
    """

    CLASS = PETSc.LGMap
    FACTORY = 'create'
    TARGS = ([],)
229
230
class TestObjectAO(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.AO.

    The object is built with ``createMapping([], [])``.
    """

    # NOTE(review): this configuration is identical to TestObjectAOMapping
    # defined later in this file.
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])
235
236
class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.DMDA.

    The object is built with ``create([3, 3, 3])``.
    """

    CLASS = PETSc.DMDA
    FACTORY = 'create'
    TARGS = ([3, 3, 3],)
241
242
class TestObjectDS(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.DS built by create()."""

    CLASS = PETSc.DS
    FACTORY = 'create'
246
247
class TestObjectVec(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.Vec.

    The object is built with ``createSeq(0)`` and assembled before each test.
    """

    CLASS = PETSc.Vec
    FACTORY = 'createSeq'
    TARGS = (0,)

    def setUp(self):
        # Build the vector through the shared fixture, then assemble it.
        super().setUp()
        self.obj.assemble()
256
257
class TestObjectMat(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.Mat.

    The object is built with ``createAIJ(0, nnz=0, comm=PETSc.COMM_SELF)``
    and assembled before each test.
    """

    CLASS = PETSc.Mat
    FACTORY = 'createAIJ'
    TARGS = (0,)
    KARGS = {'nnz': 0, 'comm': PETSc.COMM_SELF}

    def setUp(self):
        # Build the matrix through the shared fixture, then assemble it.
        super().setUp()
        self.obj.assemble()
267
268
class TestObjectMatPartitioning(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.MatPartitioning.

    The object is built with ``create()``.
    """

    CLASS = PETSc.MatPartitioning
    FACTORY = 'create'
272
273
class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.NullSpace.

    The object is built with ``create(True, [])``.
    """

    CLASS = PETSc.NullSpace
    FACTORY = 'create'
    TARGS = (True, [])
278
279
class TestObjectKSP(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.KSP built by create()."""

    CLASS = PETSc.KSP
    FACTORY = 'create'
283
284
class TestObjectPC(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.PC built by create()."""

    CLASS = PETSc.PC
    FACTORY = 'create'
288
289
class TestObjectSNES(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.SNES built by create()."""

    CLASS = PETSc.SNES
    FACTORY = 'create'
293
294
class TestObjectTS(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.TS.

    The object is built with ``create()``, then given the NONLINEAR problem
    type and the BEULER type before each test.
    """

    CLASS = PETSc.TS
    FACTORY = 'create'

    def setUp(self):
        # Build the TS through the shared fixture, then configure it.
        super().setUp()
        self.obj.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        self.obj.setType(PETSc.TS.Type.BEULER)
303
304
class TestObjectTAO(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.TAO built by create().

    This class is deleted at the bottom of the module when the PETSc scalar
    type is complex.
    """

    CLASS = PETSc.TAO
    FACTORY = 'create'
308
309
class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.AO.

    The object is built with ``createBasic([], [])``.
    """

    CLASS = PETSc.AO
    FACTORY = 'createBasic'
    TARGS = ([], [])
314
315
class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.AO.

    The object is built with ``createMapping([], [])``.
    """

    # NOTE(review): this configuration is identical to TestObjectAO defined
    # earlier in this file.
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])
320
321
# class TestObjectFE(BaseTestObject, unittest.TestCase):
#     CLASS  = PETSc.FE
#     FACTORY = 'create'
#
# class TestObjectQuad(BaseTestObject, unittest.TestCase):
#     CLASS  = PETSc.Quad
#     FACTORY = 'create'
329
330
class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.DMLabel.

    The object is built with ``create('test')``.
    """

    CLASS = PETSc.DMLabel
    FACTORY = 'create'
    TARGS = ('test',)
335
336
class TestObjectSpace(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.Space built by create()."""

    CLASS = PETSc.Space
    FACTORY = 'create'
340
341
class TestObjectDualSpace(BaseTestObject, unittest.TestCase):
    """Run the BaseTestObject checks on a PETSc.DualSpace built by create()."""

    CLASS = PETSc.DualSpace
    FACTORY = 'create'
345
346
# --------------------------------------------------------------------
348
# Remove the TAO test case from the module namespace when the PETSc scalar
# type is complex, so it is not available for collection.
# NOTE(review): presumably TAO is unsupported with complex scalars -- confirm.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestObjectTAO

# Allow running this file directly as a script.
if __name__ == '__main__':
    unittest.main()
354