xref: /petsc/src/binding/petsc4py/test/test_object.py (revision a32e9c995d3c9cc14233efbb30d372fdb63ce962)
1import unittest
2
3from petsc4py import PETSc
4
5# --------------------------------------------------------------------
6
class BaseTestObject(object):
    """Mixin with generic checks shared by all wrapped PETSc object types.

    Concrete test cases combine this mixin with ``unittest.TestCase`` and
    configure the class attributes below. ``setUp`` instantiates ``CLASS``
    and calls its ``FACTORY`` method with ``TARGS``/``KARGS`` to build
    ``self.obj``, which every test method then exercises.
    """

    # petsc4py class under test, and the name of the method on it that
    # creates the underlying PETSc object.
    CLASS, FACTORY = None, None
    # Positional and keyword arguments forwarded to the FACTORY method.
    TARGS, KARGS = (), {}
    # NOTE(review): not referenced by the tests in this file.
    BUILD = None

    def setUp(self):
        self.obj = self.CLASS()
        getattr(self.obj, self.FACTORY)(*self.TARGS, **self.KARGS)
        # If the factory call left the object uncreated (falsy), fall back
        # to a plain create().
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        # Drop our reference, then let PETSc clean up its garbage.
        self.obj = None
        PETSc.garbage_cleanup()

    def testTypeRegistry(self):
        """The class id must map back to the Python wrapper type."""
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        typeobj = self.CLASS
        # DMDA objects are registered under their base DM type.
        if isinstance(self.obj, PETSc.DMDA):
            typeobj = PETSc.DM
        self.assertIs(type_reg[classid], typeobj)

    def testLogClass(self):
        """The logging class looked up by name must match the class id."""
        name = self.CLASS.__name__
        # DMDA objects are logged under the DM class name.
        if name == 'DMDA':
            name = 'DM'
        logcls = PETSc.Log.Class(name)
        classid = self.obj.getClassId()
        self.assertEqual(logcls.id, classid)

    def testClass(self):
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testNonZero(self):
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        """destroy() must leave the wrapper falsy."""
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        self.assertFalse(bool(self.obj))
        # Disabled checks, kept for reference:
        ## self.assertRaises(PETSc.Error, self.obj.destroy)
        ## self.assertTrue(self.obj.this is this)

    def testOptions(self):
        """Setting the options prefix replaces the previous one."""
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        ## self.obj.appendOptionsPrefix(prefix1)
        ## self.assertEqual(self.obj.getOptionsPrefix(),
        ##                  prefix2 + prefix1)
        ## self.obj.prependOptionsPrefix(prefix1)
        ## self.assertEqual(self.obj.getOptionsPrefix(),
        ##                  prefix1 + prefix2 + prefix1)
        self.obj.setFromOptions()

    def testName(self):
        """setName()/getName() round-trip, including restoring the old name."""
        oldname = self.obj.getName()
        newname = '%s-%s' % (oldname, oldname)
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        """incRef/decRef adjust the count; dropping the last ref destroys."""
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        """C and Fortran handles are set while alive and cleared on destroy."""
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # Only compare the handles when they have the same sign.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        """compose() holds a reference on the attached object until removed."""
        import copy
        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report a skip instead of a silent pass.
            self.skipTest('deepcopy() not supported by %s'
                          % self.CLASS.__name__)
        self.assertEqual(myobj.getRefCount(), 1)
        self.obj.compose('myobj', myobj)
        self.assertIs(type(self.obj.query('myobj')), self.CLASS)
        self.assertEqual(self.obj.query('myobj'), myobj)
        self.assertEqual(myobj.getRefCount(), 2)
        # Composing None removes the attachment and releases the reference.
        self.obj.compose('myobj', None)
        self.assertEqual(myobj.getRefCount(), 1)
        self.assertIsNone(self.obj.query('myobj'))
        myobj.destroy()

    def testProperties(self):
        """Each property must agree with its getter."""
        self.assertEqual(self.obj.getClassId(),   self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(),      self.obj.type)
        self.assertEqual(self.obj.getName(),      self.obj.name)
        self.assertEqual(self.obj.getComm(),      self.obj.comm)
        self.assertEqual(self.obj.getRefCount(),  self.obj.refcount)

    def testShallowCopy(self):
        """copy.copy() yields a new wrapper sharing the same PETSc object."""
        import copy
        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        self.assertIsNot(obj, self.obj)
        self.assertTrue(obj == self.obj)
        self.assertIs(type(obj), type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        # Dropping the copy must release the extra reference.
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        """copy.deepcopy() yields an independent object with its own refcount."""
        import copy
        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report a skip instead of a silent pass.
            self.skipTest('deepcopy() not supported by %s'
                          % self.CLASS.__name__)
        self.assertIsNot(obj, self.obj)
        self.assertTrue(obj != self.obj)
        self.assertIs(type(obj), type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
        del obj

    def testStateInspection(self):
        """stateIncrease() bumps the state; stateSet() overrides it."""
        state = self.obj.stateGet()
        self.obj.stateIncrease()
        self.assertLess(state, self.obj.stateGet())
        self.obj.stateSet(0)
        self.assertEqual(self.obj.stateGet(), 0)
        self.obj.stateSet(state)
        self.assertEqual(self.obj.stateGet(), state)
159
160
161# --------------------------------------------------------------------
162
class TestObjectRandom(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.Random`` built via ``create()``."""
    CLASS, FACTORY = PETSc.Random, 'create'
166
class TestObjectViewer(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.Viewer`` built via ``create()``."""
    CLASS, FACTORY = PETSc.Viewer, 'create'
170
class TestObjectIS(BaseTestObject, unittest.TestCase):
    """Generic object checks on an index set from ``createGeneral([])``."""
    CLASS, FACTORY = PETSc.IS, 'createGeneral'
    TARGS = ([],)
175
class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.LGMap`` from ``create([])``."""
    CLASS, FACTORY = PETSc.LGMap, 'create'
    TARGS = ([],)
180
class TestObjectAO(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.AO`` from ``createMapping([], [])``.

    NOTE(review): this configuration is identical to TestObjectAOMapping
    later in the module, so the same checks run twice; consider dropping
    one of them.
    """
    CLASS, FACTORY = PETSc.AO, 'createMapping'
    TARGS = ([], [])
185
class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.DMDA`` from ``create([3, 3, 3])``."""
    CLASS, FACTORY = PETSc.DMDA, 'create'
    TARGS = ([3, 3, 3],)
190
class TestObjectDS(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.DS`` built via ``create()``."""
    CLASS, FACTORY = PETSc.DS, 'create'
194
class TestObjectVec(BaseTestObject, unittest.TestCase):
    """Generic object checks on a vector from ``createSeq(0)``."""
    CLASS, FACTORY = PETSc.Vec, 'createSeq'
    TARGS = (0,)

    def setUp(self):
        super(TestObjectVec, self).setUp()
        # Assemble the vector before any of the generic checks run.
        self.obj.assemble()
203
class TestObjectMat(BaseTestObject, unittest.TestCase):
    """Generic object checks on an AIJ matrix of size 0 on ``COMM_SELF``."""
    CLASS, FACTORY = PETSc.Mat, 'createAIJ'
    TARGS = (0,)
    KARGS = dict(nnz=0, comm=PETSc.COMM_SELF)

    def setUp(self):
        super(TestObjectMat, self).setUp()
        # Assemble the matrix before any of the generic checks run.
        self.obj.assemble()
213
class TestObjectMatPartitioning(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.MatPartitioning`` via ``create()``."""
    CLASS, FACTORY = PETSc.MatPartitioning, 'create'
217
class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.NullSpace`` from ``create(True, [])``."""
    CLASS, FACTORY = PETSc.NullSpace, 'create'
    TARGS = (True, [])
222
class TestObjectKSP(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.KSP`` built via ``create()``."""
    CLASS, FACTORY = PETSc.KSP, 'create'
226
class TestObjectPC(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.PC`` built via ``create()``."""
    CLASS, FACTORY = PETSc.PC, 'create'
230
class TestObjectSNES(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.SNES`` built via ``create()``."""
    CLASS, FACTORY = PETSc.SNES, 'create'
234
class TestObjectTS(BaseTestObject, unittest.TestCase):
    """Generic object checks on a nonlinear ``PETSc.TS`` of type BEULER."""
    CLASS, FACTORY = PETSc.TS, 'create'

    def setUp(self):
        BaseTestObject.setUp(self)
        # Give the solver a concrete problem type and method before testing.
        ts = self.obj
        ts.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        ts.setType(PETSc.TS.Type.BEULER)
242
class TestObjectTAO(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.TAO`` built via ``create()``.

    The module deletes this class when PETSc uses complex scalars.
    """
    CLASS, FACTORY = PETSc.TAO, 'create'
246
class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.AO`` from ``createBasic([], [])``."""
    CLASS, FACTORY = PETSc.AO, 'createBasic'
    TARGS = ([], [])
251
class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.AO`` from ``createMapping([], [])``.

    NOTE(review): duplicates the TestObjectAO configuration defined earlier
    in the module.
    """
    CLASS, FACTORY = PETSc.AO, 'createMapping'
    TARGS = ([], [])
256
257# class TestObjectFE(BaseTestObject, unittest.TestCase):
258#     CLASS  = PETSc.FE
259#     FACTORY = 'create'
260#
261# class TestObjectQuad(BaseTestObject, unittest.TestCase):
262#     CLASS  = PETSc.Quad
263#     FACTORY = 'create'
264
class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.DMLabel`` from ``create("test")``."""
    CLASS, FACTORY = PETSc.DMLabel, 'create'
    TARGS = ("test",)
269
class TestObjectSpace(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.Space`` built via ``create()``."""
    CLASS, FACTORY = PETSc.Space, 'create'
273
class TestObjectDualSpace(BaseTestObject, unittest.TestCase):
    """Generic object checks on a ``PETSc.DualSpace`` built via ``create()``."""
    CLASS, FACTORY = PETSc.DualSpace, 'create'
277
278# --------------------------------------------------------------------
279
280import numpy
281
282if numpy.iscomplexobj(PETSc.ScalarType()):
283    del TestObjectTAO
284
285if __name__ == '__main__':
286    unittest.main()
287