xref: /petsc/src/binding/petsc4py/test/test_object.py (revision 85f25e71687eef93ec659366612fd9d9ee335aac)
1import unittest
2
3from petsc4py import PETSc
4import numpy
5
6# --------------------------------------------------------------------
7
8
class BaseTestObject:
    """Mixin holding the checks shared by every PETSc object type under test.

    Concrete test cases inherit from this mixin *and* ``unittest.TestCase``
    and configure it through the class attributes below.  ``setUp`` builds
    ``self.obj`` as ``CLASS()`` followed by
    ``getattr(obj, FACTORY)(*TARGS, **KARGS)``.

    The individual tests are described with ``#`` comments instead of
    docstrings on purpose: unittest prints the first docstring line of a
    test in verbose mode, and the reported test names should not change.
    """

    # CLASS is the type to instantiate; FACTORY is the name of the method
    # called on the fresh instance to actually create the object.
    CLASS, FACTORY = None, None
    # Positional and keyword arguments forwarded to the factory method.
    TARGS, KARGS = (), {}
    # Not read by any method of this class.
    BUILD = None

    def setUp(self):
        # Instantiate an empty wrapper, then let the factory populate it.
        self.obj = self.CLASS()
        getattr(self.obj, self.FACTORY)(*self.TARGS, **self.KARGS)
        # Fall back to a plain create() if the factory left the object falsy.
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        # Drop our reference first, then ask PETSc to clean up garbage.
        self.obj = None
        PETSc.garbage_cleanup()

    def testTypeRegistry(self):
        # The registry entry for the object's class id must be its Python
        # type; DMDA instances are expected to be registered as PETSc.DM.
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        typeobj = PETSc.DM if isinstance(self.obj, PETSc.DMDA) else self.CLASS
        self.assertIs(type_reg[classid], typeobj)

    def testLogClass(self):
        # The log class looked up by name must carry the object's class id;
        # DMDA is looked up under the name 'DM'.
        name = self.CLASS.__name__
        if name == 'DMDA':
            name = 'DM'
        logcls = PETSc.Log.Class(name)
        classid = self.obj.getClassId()
        self.assertEqual(logcls.id, classid)

    def testClass(self):
        # The object must be exactly a CLASS instance, not a subclass.
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testNonZero(self):
        # A created object is truthy.
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        # destroy() turns a truthy object into a falsy one.
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        self.assertFalse(bool(self.obj))
        ## self.assertRaises(PETSc.Error, self.obj.destroy)
        ## self.assertTrue(self.obj.this is this)

    def testOptions(self):
        # Exercise set/append/reset of the options prefix, starting from an
        # empty prefix, and finish with a setFromOptions() call.
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        self.obj.appendOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2 + prefix1)
        self.obj.setOptionsPrefix(None)
        self.assertIsNone(self.obj.getOptionsPrefix())
        self.obj.setFromOptions()

    def testName(self):
        # Rename the object and then restore the original name.
        oldname = self.obj.getName()
        newname = f'{oldname}-{oldname}'
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        # The communicator is a PETSc.Comm equal to COMM_SELF or COMM_WORLD.
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        # Walk the reference count 1 -> 3 -> 1; the final decRef() must
        # leave the object falsy.
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        # Both handles are non-zero while the object is alive and zero
        # after destroy().
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # The two handles are only compared when they share the same sign.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        # compose() attaches an object under a name (taking a reference),
        # query() returns it, and composing None detaches it again.
        import copy

        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the test as skipped instead of silently passing.
            self.skipTest(f'{type(self.obj).__name__} does not support deepcopy')
        self.assertEqual(myobj.getRefCount(), 1)
        self.obj.compose('myobj', myobj)
        self.assertIs(type(self.obj.query('myobj')), self.CLASS)
        self.assertEqual(self.obj.query('myobj'), myobj)
        self.assertEqual(myobj.getRefCount(), 2)
        self.obj.compose('myobj', None)
        self.assertEqual(myobj.getRefCount(), 1)
        self.assertIsNone(self.obj.query('myobj'))
        myobj.destroy()

    def testProperties(self):
        # Every getter must agree with the matching property.
        self.assertEqual(self.obj.getClassId(), self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(), self.obj.type)
        self.assertEqual(self.obj.getName(), self.obj.name)
        self.assertEqual(self.obj.getComm(), self.obj.comm)
        self.assertEqual(self.obj.getRefCount(), self.obj.refcount)

    def testShallowCopy(self):
        # A shallow copy is a distinct Python object that compares equal and
        # holds one extra reference until it is deleted.
        import copy

        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        # A deep copy compares unequal to the original, has its own
        # reference count of 1 and leaves the original's count untouched.
        import copy

        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the test as skipped instead of silently passing.
            self.skipTest(f'{type(self.obj).__name__} does not support deepcopy')
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
        del obj

    def testStateInspection(self):
        # stateIncrease() must raise the state; stateSet() must be readable
        # back through stateGet().  The initial state is restored at the end.
        state = self.obj.stateGet()
        self.obj.stateIncrease()
        self.assertLess(state, self.obj.stateGet())
        self.obj.stateSet(0)
        self.assertEqual(self.obj.stateGet(), 0)
        self.obj.stateSet(state)
        self.assertEqual(self.obj.stateGet(), state)
164
165
166# --------------------------------------------------------------------
167
168
class TestObjectRandom(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.Random`` built by ``create()``."""

    CLASS, FACTORY = PETSc.Random, 'create'
172
173
class TestObjectViewer(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.Viewer`` built by ``create()``."""

    CLASS, FACTORY = PETSc.Viewer, 'create'
177
178
class TestObjectIS(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.IS`` built by ``createGeneral([])``."""

    CLASS, FACTORY = PETSc.IS, 'createGeneral'
    # A single positional argument: an empty list.
    TARGS = ([],)
183
184
class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.LGMap`` built by ``create([])``."""

    CLASS, FACTORY = PETSc.LGMap, 'create'
    # A single positional argument: an empty list.
    TARGS = ([],)
189
190
class TestObjectAO(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.AO`` built by ``createMapping([], [])``."""

    # NOTE(review): TestObjectAOMapping in this file uses the same CLASS,
    # FACTORY and TARGS, so both classes exercise the same configuration.
    CLASS, FACTORY = PETSc.AO, 'createMapping'
    # Two positional arguments, both empty lists.
    TARGS = ([], [])
195
196
class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.DMDA`` built by ``create([3, 3, 3])``."""

    CLASS, FACTORY = PETSc.DMDA, 'create'
    # A single positional argument: the list [3, 3, 3].
    TARGS = ([3, 3, 3],)
201
202
class TestObjectDS(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.DS`` built by ``create()``."""

    CLASS, FACTORY = PETSc.DS, 'create'
206
207
class TestObjectVec(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.Vec`` built by ``createSeq(0)``."""

    CLASS, FACTORY = PETSc.Vec, 'createSeq'
    # A single positional argument: 0.
    TARGS = (0,)

    def setUp(self):
        # Build the vector through the shared fixture, then assemble it
        # before any test runs.
        super().setUp()
        self.obj.assemble()
216
217
class TestObjectMat(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.Mat`` built by ``createAIJ``.

    The factory is called as ``createAIJ(0, nnz=0, comm=PETSc.COMM_SELF)``.
    """

    CLASS, FACTORY = PETSc.Mat, 'createAIJ'
    # A single positional argument: 0.
    TARGS = (0,)
    # Keyword arguments forwarded to the factory.
    KARGS = dict(nnz=0, comm=PETSc.COMM_SELF)

    def setUp(self):
        # Build the matrix through the shared fixture, then assemble it
        # before any test runs.
        super().setUp()
        self.obj.assemble()
227
228
class TestObjectMatPartitioning(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.MatPartitioning`` built by ``create()``."""

    CLASS, FACTORY = PETSc.MatPartitioning, 'create'
232
233
class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.NullSpace`` built by ``create(True, [])``."""

    CLASS, FACTORY = PETSc.NullSpace, 'create'
    # Two positional arguments: True and an empty list.
    TARGS = (True, [])
238
239
class TestObjectKSP(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.KSP`` built by ``create()``."""

    CLASS, FACTORY = PETSc.KSP, 'create'
243
244
class TestObjectPC(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.PC`` built by ``create()``."""

    CLASS, FACTORY = PETSc.PC, 'create'
248
249
class TestObjectSNES(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.SNES`` built by ``create()``."""

    CLASS, FACTORY = PETSc.SNES, 'create'
253
254
class TestObjectTS(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.TS`` built by ``create()``.

    After creation the solver is configured as a nonlinear problem of type
    ``BEULER``.
    """

    CLASS, FACTORY = PETSc.TS, 'create'

    def setUp(self):
        super().setUp()
        # Configure the freshly created time stepper before each test.
        ts = self.obj
        ts.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        ts.setType(PETSc.TS.Type.BEULER)
263
264
class TestObjectTAO(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.TAO`` built by ``create()``.

    This class is deleted at the bottom of the module when
    ``PETSc.ScalarType()`` is complex.
    """

    CLASS, FACTORY = PETSc.TAO, 'create'
268
269
class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.AO`` built by ``createBasic([], [])``."""

    CLASS, FACTORY = PETSc.AO, 'createBasic'
    # Two positional arguments, both empty lists.
    TARGS = ([], [])
274
275
class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.AO`` built by ``createMapping([], [])``."""

    CLASS, FACTORY = PETSc.AO, 'createMapping'
    # Two positional arguments, both empty lists.
    TARGS = ([], [])
280
281
282# class TestObjectFE(BaseTestObject, unittest.TestCase):
283#     CLASS  = PETSc.FE
284#     FACTORY = 'create'
285#
286# class TestObjectQuad(BaseTestObject, unittest.TestCase):
287#     CLASS  = PETSc.Quad
288#     FACTORY = 'create'
289
290
class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.DMLabel`` built by ``create('test')``."""

    CLASS, FACTORY = PETSc.DMLabel, 'create'
    # A single positional argument: the string 'test'.
    TARGS = ('test',)
295
296
class TestObjectSpace(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.Space`` built by ``create()``."""

    CLASS, FACTORY = PETSc.Space, 'create'
300
301
class TestObjectDualSpace(BaseTestObject, unittest.TestCase):
    """Run the shared object checks on a ``PETSc.DualSpace`` built by ``create()``."""

    CLASS, FACTORY = PETSc.DualSpace, 'create'
305
306
307# --------------------------------------------------------------------
308
# Remove the TAO test case from the module namespace when
# PETSc.ScalarType() is a complex number, so the test loader never sees it.
# NOTE(review): presumably TAO is not usable with complex scalars — confirm.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestObjectTAO

# Run the suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
314