xref: /petsc/src/binding/petsc4py/test/test_object.py (revision 8a7d4057d9226490dba4e1a062f54f84e7d90861) !
1import unittest
2
3from petsc4py import PETSc
4import numpy
5
6# --------------------------------------------------------------------
7
8
class BaseTestObject:
    """Generic checks shared by every wrapped PETSc object type.

    This is a mixin: concrete test cases inherit from it *and* from
    ``unittest.TestCase`` and set ``CLASS``/``FACTORY`` (plus optionally
    ``TARGS``/``KARGS``).  ``setUp`` instantiates ``CLASS`` and calls the
    method named ``FACTORY`` on it with ``*TARGS, **KARGS``; if the object
    is still falsy afterwards it falls back to ``create()``.
    """

    # Wrapper class under test and the name of its constructor method.
    CLASS, FACTORY = None, None
    # Positional and keyword arguments forwarded to the FACTORY method.
    # They are class attributes shared by all instances, so treat them as
    # read-only.
    TARGS, KARGS = (), {}
    # Not referenced by the checks in this class.
    BUILD = None

    # Class names (``obj.klass``) for which testOptions does not assert that
    # the registered options handler ran exactly once per setFromOptions().
    _OPTIONS_HANDLER_EXEMPT = frozenset((
        'AO',
        'DMLabel',
        'PetscDualSpace',
        'IS',
        'ISLocalToGlobalMapping',
        'MatPartitioning',
        'MatNullSpace',
        'PetscRandom',
        'PetscViewer',
    ))

    @staticmethod
    def _skips_set_from_options(obj):
        """Return True if the option checks must not call setFromOptions().

        That is the case when ``obj`` has a non-empty type whose string form
        is ``'da'`` (a DMDA).  NOTE(review): presumably because
        setFromOptions() is not supported on an already set-up DMDA --
        confirm.
        """
        objtype = obj.getType()
        return bool(objtype) and str(objtype) == 'da'

    def setUp(self):
        self.obj = self.CLASS()
        getattr(self.obj, self.FACTORY)(*self.TARGS, **self.KARGS)
        # If the factory left the object empty, fall back to create().
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        # Drop our reference, then run PETSc's garbage cleanup.
        self.obj = None
        PETSc.garbage_cleanup()

    def testTypeRegistry(self):
        # The class id must map back to the wrapper class in the type
        # registry; DMDA objects are looked up as the generic DM class.
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        typeobj = PETSc.DM if isinstance(self.obj, PETSc.DMDA) else self.CLASS
        self.assertIs(type_reg[classid], typeobj)

    def testLogClass(self):
        # The log class registered under the wrapper's name (DM for DMDA)
        # must carry the same class id as the object.
        name = self.CLASS.__name__
        if name == 'DMDA':
            name = 'DM'
        logcls = PETSc.Log.Class(name)
        self.assertEqual(logcls.id, self.obj.getClassId())

    def testClass(self):
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testId(self):
        oid = self.obj.getId()
        self.assertGreater(oid, 0)
        self.assertEqual(self.obj.id, oid)

    def testNonZero(self):
        self.assertTrue(self.obj)

    def testDestroy(self):
        self.assertTrue(self.obj)
        self.obj.destroy()
        self.assertFalse(self.obj)

    def testOptions(self):
        # Options prefix: set, replace, append and finally clear it.
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        self.obj.appendOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2 + prefix1)
        self.obj.setOptionsPrefix(None)
        self.assertIsNone(self.obj.getOptionsPrefix())
        if not self._skips_set_from_options(self.obj):
            self.obj.setFromOptions()

        def opts_handler(obj):
            # Count invocations and check the handler receives our object.
            n = obj.getAttr('opts_handler_called')
            obj.setAttr('opts_handler_called', n + 1)
            self.assertEqual(type(self.obj), type(obj))
            self.assertEqual(self.obj.klass, obj.klass)
            self.assertEqual(self.obj.type, obj.type)
            self.assertEqual(self.obj.id, obj.id)

        # Registering the handler twice must still yield a single call per
        # setFromOptions() for the non-exempt classes.
        for _ in range(2):
            self.obj.setAttr('opts_handler_called', 0)
            self.obj.setOptionsHandler(opts_handler)
            if not self._skips_set_from_options(self.obj):
                self.obj.setFromOptions()
                if self.obj.klass not in self._OPTIONS_HANDLER_EXEMPT:
                    self.assertEqual(
                        self.obj.getAttr('opts_handler_called'), 1
                    )

        # Clearing the handler (either way) must stop further invocations.
        if not self._skips_set_from_options(self.obj):
            self.obj.setAttr('opts_handler_called', 0)
            self.obj.setOptionsHandler(None)
            self.obj.setFromOptions()
            self.assertFalse(self.obj.getAttr('opts_handler_called'))

            self.obj.destroyOptionsHandlers()
            self.obj.setFromOptions()
            self.assertFalse(self.obj.getAttr('opts_handler_called'))

    def testName(self):
        oldname = self.obj.getName()
        newname = f'{oldname}-{oldname}'
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        # Dropping the last reference leaves the wrapper empty.
        self.obj.decRef()
        self.assertFalse(self.obj)

    def testHandle(self):
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # Only compare the two handles when they have the same sign.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        import copy

        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the missing coverage instead of passing silently.
            self.skipTest(
                f'deepcopy is not implemented for {self.CLASS.__name__}'
            )
        self.assertEqual(myobj.getRefCount(), 1)
        # Composing takes a reference; composing None releases it.
        self.obj.compose('myobj', myobj)
        self.assertIs(type(self.obj.query('myobj')), self.CLASS)
        self.assertEqual(self.obj.query('myobj'), myobj)
        self.assertEqual(myobj.getRefCount(), 2)
        self.obj.compose('myobj', None)
        self.assertEqual(myobj.getRefCount(), 1)
        self.assertIsNone(self.obj.query('myobj'))
        myobj.destroy()

    def testProperties(self):
        # Each property must agree with its getter.
        self.assertEqual(self.obj.getClassId(), self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(), self.obj.type)
        self.assertEqual(self.obj.getName(), self.obj.name)
        self.assertEqual(self.obj.getComm(), self.obj.comm)
        self.assertEqual(self.obj.getRefCount(), self.obj.refcount)

    def testShallowCopy(self):
        import copy

        # A shallow copy is a new wrapper sharing the same PETSc object,
        # so it compares equal and holds one extra reference.
        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        import copy

        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the missing coverage instead of passing silently.
            self.skipTest(
                f'deepcopy is not implemented for {self.CLASS.__name__}'
            )
        # A deep copy is a distinct object with its own single reference.
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
        del obj

    def testStateInspection(self):
        state = self.obj.stateGet()
        self.obj.stateIncrease()
        self.assertLess(state, self.obj.stateGet())
        self.obj.stateSet(0)
        self.assertEqual(self.obj.stateGet(), 0)
        self.obj.stateSet(state)
        self.assertEqual(self.obj.stateGet(), state)
207
208
209# --------------------------------------------------------------------
210
211
class TestObjectRandom(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.Random``."""

    FACTORY = 'create'
    CLASS = PETSc.Random
215
216
class TestObjectViewer(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.Viewer``."""

    FACTORY = 'create'
    CLASS = PETSc.Viewer
220
221
class TestObjectIS(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.IS``.

    The index set is built with ``createGeneral([])`` (no indices).
    """

    TARGS = ([],)
    FACTORY = 'createGeneral'
    CLASS = PETSc.IS
226
227
class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.LGMap``.

    The mapping is built with ``create([])`` (empty index list).
    """

    TARGS = ([],)
    FACTORY = 'create'
    CLASS = PETSc.LGMap
232
233
class TestObjectAO(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.AO``.

    The ordering is built with ``createMapping([], [])``.
    NOTE(review): this configuration is identical to TestObjectAOMapping
    further down; one of the two is redundant.
    """

    TARGS = ([], [])
    FACTORY = 'createMapping'
    CLASS = PETSc.AO
238
239
class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.DMDA``.

    The DMDA is built with ``create([3, 3, 3])``.
    """

    TARGS = ([3, 3, 3],)
    FACTORY = 'create'
    CLASS = PETSc.DMDA
244
245
class TestObjectDS(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.DS``."""

    FACTORY = 'create'
    CLASS = PETSc.DS
249
250
class TestObjectVec(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.Vec``.

    The vector is built with ``createSeq(0)`` and assembled in ``setUp``.
    """

    CLASS = PETSc.Vec
    FACTORY = 'createSeq'
    TARGS = (0,)

    def setUp(self):
        # Build the object through the shared mixin logic (resolved via the
        # MRO, consistent with TestObjectTS), then assemble it.
        super().setUp()
        self.obj.assemble()
259
260
class TestObjectMat(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.Mat``.

    The matrix is built with ``createAIJ(0, nnz=0, comm=PETSc.COMM_SELF)``
    and assembled in ``setUp``.
    """

    CLASS = PETSc.Mat
    FACTORY = 'createAIJ'
    TARGS = (0,)
    KARGS = {'nnz': 0, 'comm': PETSc.COMM_SELF}

    def setUp(self):
        # Build the object through the shared mixin logic (resolved via the
        # MRO, consistent with TestObjectTS), then assemble it.
        super().setUp()
        self.obj.assemble()
270
271
class TestObjectMatPartitioning(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.MatPartitioning``."""

    FACTORY = 'create'
    CLASS = PETSc.MatPartitioning
275
276
class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.NullSpace``.

    The null space is built with ``create(True, [])``.
    """

    TARGS = (True, [])
    FACTORY = 'create'
    CLASS = PETSc.NullSpace
281
282
class TestObjectKSP(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.KSP``."""

    FACTORY = 'create'
    CLASS = PETSc.KSP
286
287
class TestObjectPC(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.PC``."""

    FACTORY = 'create'
    CLASS = PETSc.PC
291
292
class TestObjectSNES(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.SNES``."""

    FACTORY = 'create'
    CLASS = PETSc.SNES
296
297
class TestObjectTS(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.TS``.

    After creation the solver is configured as a nonlinear problem using
    the backward-Euler (``BEULER``) type.
    """

    FACTORY = 'create'
    CLASS = PETSc.TS

    def setUp(self):
        super().setUp()
        ts = self.obj
        ts.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        ts.setType(PETSc.TS.Type.BEULER)
306
307
class TestObjectTAO(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.TAO``.

    The module-level check at the end of this file deletes this class when
    PETSc's scalar type is complex.
    """

    FACTORY = 'create'
    CLASS = PETSc.TAO
311
312
class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.AO``.

    The ordering is built with ``createBasic([], [])``.
    """

    TARGS = ([], [])
    FACTORY = 'createBasic'
    CLASS = PETSc.AO
317
318
class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.AO``.

    The ordering is built with ``createMapping([], [])``.
    NOTE(review): this configuration is identical to TestObjectAO earlier
    in the file; one of the two is redundant.
    """

    TARGS = ([], [])
    FACTORY = 'createMapping'
    CLASS = PETSc.AO
323
324
325# class TestObjectFE(BaseTestObject, unittest.TestCase):
326#     CLASS  = PETSc.FE
327#     FACTORY = 'create'
328#
329# class TestObjectQuad(BaseTestObject, unittest.TestCase):
330#     CLASS  = PETSc.Quad
331#     FACTORY = 'create'
332
333
class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.DMLabel``.

    The label is built with ``create('test')``.
    """

    TARGS = ('test',)
    FACTORY = 'create'
    CLASS = PETSc.DMLabel
338
339
class TestObjectSpace(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.Space``."""

    FACTORY = 'create'
    CLASS = PETSc.Space
343
344
class TestObjectDualSpace(BaseTestObject, unittest.TestCase):
    """Run the generic object checks against ``PETSc.DualSpace``."""

    FACTORY = 'create'
    CLASS = PETSc.DualSpace
348
349
350# --------------------------------------------------------------------
351
# Drop the TAO test case when PETSc's scalar type is complex.
# NOTE(review): presumably because TAO is not available in complex-scalar
# PETSc builds -- confirm.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestObjectTAO

# Run all test cases in this module when executed as a script.
if __name__ == '__main__':
    unittest.main()
357