xref: /petsc/src/binding/petsc4py/test/test_object.py (revision 36d43d94b6d42e888c89e2d3ed68780aaa9faca1)
1import unittest
2
3from petsc4py import PETSc
4import numpy
5
6# --------------------------------------------------------------------
7
8
class BaseTestObject:
    """Generic checks shared by every petsc4py object type under test.

    Concrete test classes mix this in together with ``unittest.TestCase``
    and configure it through class attributes:

    * ``CLASS``   -- the petsc4py class to instantiate;
    * ``FACTORY`` -- name of the creation method invoked on the new instance;
    * ``TARGS`` / ``KARGS`` -- positional / keyword arguments for ``FACTORY``.
    """

    CLASS, FACTORY = None, None
    TARGS, KARGS = (), {}
    BUILD = None  # NOTE(review): not referenced anywhere in this module.

    def setUp(self):
        self.obj = self.CLASS()
        getattr(self.obj, self.FACTORY)(*self.TARGS, **self.KARGS)
        # If the factory left the handle empty, fall back to a plain create().
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        # Drop our reference, then let PETSc clean up any pending garbage.
        self.obj = None
        PETSc.garbage_cleanup()

    def testTypeRegistry(self):
        """The class id must map back to the Python class in the registry."""
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        typeobj = self.CLASS
        # A DMDA reports the DM class id, which the registry maps to PETSc.DM.
        if isinstance(self.obj, PETSc.DMDA):
            typeobj = PETSc.DM
        self.assertIs(type_reg[classid], typeobj)

    def testLogClass(self):
        """The logging class registered under the class name has our id."""
        name = self.CLASS.__name__
        # DMDA objects are logged under the generic DM class.
        if name == 'DMDA':
            name = 'DM'
        logcls = PETSc.Log.Class(name)
        classid = self.obj.getClassId()
        self.assertEqual(logcls.id, classid)

    def testClass(self):
        """The created object is exactly an instance of CLASS."""
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testId(self):
        """getId() is positive and agrees with the ``id`` property."""
        oid = self.obj.getId()
        self.assertGreater(oid, 0)
        self.assertEqual(self.obj.id, oid)

    def testNonZero(self):
        """A live object is truthy."""
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        """destroy() leaves the Python wrapper falsy."""
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        self.assertFalse(bool(self.obj))

    def testOptions(self):
        """Options prefixes can be set, replaced, appended to and cleared."""
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        self.obj.appendOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2 + prefix1)
        self.obj.setOptionsPrefix(None)
        self.assertIsNone(self.obj.getOptionsPrefix())
        self.obj.setFromOptions()

    def testName(self):
        """setName()/getName() round-trip, including restoring the original."""
        oldname = self.obj.getName()
        newname = f'{oldname}-{oldname}'
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        """The communicator is a PETSc.Comm and is either SELF or WORLD."""
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        """incRef/decRef adjust the count; dropping the last ref destroys."""
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        """C and Fortran handles are set while alive and cleared on destroy."""
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # Only compare the two handles when they have the same sign.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        """compose() takes a reference that compose(name, None) releases."""
        import copy

        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report explicitly instead of silently passing.
            self.skipTest(f'{self.CLASS.__name__} does not support deepcopy')
        # Always release the copy, even when an assertion below fails.
        try:
            self.assertEqual(myobj.getRefCount(), 1)
            self.obj.compose('myobj', myobj)
            self.assertIs(type(self.obj.query('myobj')), self.CLASS)
            self.assertEqual(self.obj.query('myobj'), myobj)
            self.assertEqual(myobj.getRefCount(), 2)
            self.obj.compose('myobj', None)
            self.assertEqual(myobj.getRefCount(), 1)
            self.assertIsNone(self.obj.query('myobj'))
        finally:
            myobj.destroy()

    def testProperties(self):
        """Each getter agrees with its corresponding property."""
        self.assertEqual(self.obj.getClassId(), self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(), self.obj.type)
        self.assertEqual(self.obj.getName(), self.obj.name)
        self.assertEqual(self.obj.getComm(), self.obj.comm)
        self.assertEqual(self.obj.getRefCount(), self.obj.refcount)

    def testShallowCopy(self):
        """copy.copy() shares the handle and adds one reference."""
        import copy

        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        """copy.deepcopy() yields an independent object with its own ref."""
        import copy

        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report explicitly instead of silently passing.
            self.skipTest(f'{self.CLASS.__name__} does not support deepcopy')
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
        del obj

    def testStateInspection(self):
        """The object state counter can be increased, set and restored."""
        state = self.obj.stateGet()
        self.obj.stateIncrease()
        self.assertLess(state, self.obj.stateGet())
        self.obj.stateSet(0)
        self.assertEqual(self.obj.stateGet(), 0)
        self.obj.stateSet(state)
        self.assertEqual(self.obj.stateGet(), state)
168        self.assertTrue(self.obj.stateGet() == state)
169
170
171# --------------------------------------------------------------------
172
173
class TestObjectRandom(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.Random`` made by ``create()``."""

    CLASS, FACTORY = PETSc.Random, 'create'
177
178
class TestObjectViewer(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.Viewer`` made by ``create()``."""

    CLASS, FACTORY = PETSc.Viewer, 'create'
182
183
class TestObjectIS(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.IS`` from ``createGeneral([])``."""

    CLASS, FACTORY, TARGS = PETSc.IS, 'createGeneral', ([],)
188
189
class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.LGMap`` from ``create([])``."""

    CLASS, FACTORY, TARGS = PETSc.LGMap, 'create', ([],)
194
195
class TestObjectAO(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.AO`` from ``createMapping([], [])``."""

    # NOTE(review): same configuration as TestObjectAOMapping in this module;
    # one of the two looks redundant.
    CLASS, FACTORY, TARGS = PETSc.AO, 'createMapping', ([], [])
200
201
class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.DMDA`` from ``create([3, 3, 3])``."""

    CLASS, FACTORY, TARGS = PETSc.DMDA, 'create', ([3, 3, 3],)
206
207
class TestObjectDS(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.DS`` made by ``create()``."""

    CLASS, FACTORY = PETSc.DS, 'create'
211
212
class TestObjectVec(BaseTestObject, unittest.TestCase):
    """Shared object checks on a size-0 sequential ``PETSc.Vec``."""

    CLASS = PETSc.Vec
    FACTORY = 'createSeq'
    TARGS = (0,)

    def setUp(self):
        # Use super() (as TestObjectTS does) so the cooperative MRO is honored.
        super().setUp()
        # Assemble the vector before the shared checks run.
        self.obj.assemble()
221
222
class TestObjectMat(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.Mat`` from ``createAIJ(0, ...)``."""

    CLASS = PETSc.Mat
    FACTORY = 'createAIJ'
    TARGS = (0,)
    KARGS = {'nnz': 0, 'comm': PETSc.COMM_SELF}

    def setUp(self):
        # Use super() (as TestObjectTS does) so the cooperative MRO is honored.
        super().setUp()
        # Assemble the matrix before the shared checks run.
        self.obj.assemble()
232
233
class TestObjectMatPartitioning(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.MatPartitioning`` from ``create()``."""

    CLASS, FACTORY = PETSc.MatPartitioning, 'create'
237
238
class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.NullSpace`` from ``create(True, [])``."""

    CLASS, FACTORY, TARGS = PETSc.NullSpace, 'create', (True, [])
243
244
class TestObjectKSP(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.KSP`` made by ``create()``."""

    CLASS, FACTORY = PETSc.KSP, 'create'
248
249
class TestObjectPC(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.PC`` made by ``create()``."""

    CLASS, FACTORY = PETSc.PC, 'create'
253
254
class TestObjectSNES(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.SNES`` made by ``create()``."""

    CLASS, FACTORY = PETSc.SNES, 'create'
258
259
class TestObjectTS(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.TS`` set up as nonlinear BEULER."""

    CLASS, FACTORY = PETSc.TS, 'create'

    def setUp(self):
        super().setUp()
        ts = self.obj
        # Pick a concrete problem type and integrator before the shared
        # checks run; problem type is set first, then the TS type.
        ts.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        ts.setType(PETSc.TS.Type.BEULER)
268
269
class TestObjectTAO(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.TAO`` made by ``create()``."""

    CLASS, FACTORY = PETSc.TAO, 'create'
273
274
class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.AO`` from ``createBasic([], [])``."""

    CLASS, FACTORY, TARGS = PETSc.AO, 'createBasic', ([], [])
279
280
class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.AO`` from ``createMapping([], [])``."""

    # NOTE(review): duplicates TestObjectAO's configuration in this module.
    CLASS, FACTORY, TARGS = PETSc.AO, 'createMapping', ([], [])
285
286
287# class TestObjectFE(BaseTestObject, unittest.TestCase):
288#     CLASS  = PETSc.FE
289#     FACTORY = 'create'
290#
291# class TestObjectQuad(BaseTestObject, unittest.TestCase):
292#     CLASS  = PETSc.Quad
293#     FACTORY = 'create'
294
295
class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.DMLabel`` from ``create('test')``."""

    CLASS, FACTORY, TARGS = PETSc.DMLabel, 'create', ('test',)
300
301
class TestObjectSpace(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.Space`` made by ``create()``."""

    CLASS, FACTORY = PETSc.Space, 'create'
305
306
class TestObjectDualSpace(BaseTestObject, unittest.TestCase):
    """Shared object checks on a ``PETSc.DualSpace`` made by ``create()``."""

    CLASS, FACTORY = PETSc.DualSpace, 'create'
310
311
312# --------------------------------------------------------------------
313
# The TAO object test is dropped when PETSc scalars are complex
# (presumably TAO is unavailable in that configuration -- confirm).
if numpy.dtype(PETSc.ScalarType).kind == 'c':
    del TestObjectTAO

if __name__ == '__main__':
    unittest.main()
319