xref: /petsc/src/binding/petsc4py/test/test_object.py (revision 9371c9d470a9602b6d10a8bf50c9b2280a79e45a)
1import unittest
2
3from petsc4py import PETSc
4
5# --------------------------------------------------------------------
6
class BaseTestObject(object):
    """Generic checks run against every wrapped PETSc object type.

    This is a mixin: concrete test cases inherit from it *and* from
    ``unittest.TestCase`` and configure the object under test through
    class attributes:

    ``CLASS``
        The petsc4py class to instantiate.
    ``FACTORY``
        Name of the method invoked on the new instance to build the
        underlying PETSc object (e.g. ``'create'``).
    ``TARGS``, ``KARGS``
        Positional and keyword arguments passed to ``FACTORY``.
    ``BUILD``
        Declared for subclasses; not read by any method here.

    Checks that rely on ``copy.deepcopy`` are reported as skipped (not
    silently passed) when the object type raises ``NotImplementedError``.
    """

    CLASS, FACTORY = None, None
    TARGS, KARGS = (), {}
    BUILD = None

    def setUp(self):
        self.obj = self.CLASS()
        getattr(self.obj, self.FACTORY)(*self.TARGS, **self.KARGS)
        # If the factory call left the wrapper empty (falsy), fall back
        # to a plain create().
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        # Drop this test's reference to the object under test.
        self.obj = None

    def _registeredClass(self):
        """Return the class expected in the type registry and log tables.

        DMDA objects report the DM class id, so they are expected to be
        registered and logged as ``PETSc.DM`` rather than ``PETSc.DMDA``.
        """
        if isinstance(self.obj, PETSc.DMDA):
            return PETSc.DM
        return self.CLASS

    def testTypeRegistry(self):
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        self.assertIs(type_reg[classid], self._registeredClass())

    def testLogClass(self):
        logcls = PETSc.Log.Class(self._registeredClass().__name__)
        self.assertEqual(logcls.id, self.obj.getClassId())

    def testClass(self):
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testNonZero(self):
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        # A destroyed wrapper evaluates as false.
        self.assertFalse(bool(self.obj))
        ## self.assertRaises(PETSc.Error, self.obj.destroy)
        ## self.assertTrue(self.obj.this is this)

    def testOptions(self):
        # A freshly built object has no options prefix.
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        # Setting a new prefix replaces (does not extend) the old one.
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        ## self.obj.appendOptionsPrefix(prefix1)
        ## self.assertEqual(self.obj.getOptionsPrefix(),
        ##                  prefix2 + prefix1)
        ## self.obj.prependOptionsPrefix(prefix1)
        ## self.assertEqual(self.obj.getOptionsPrefix(),
        ##                  prefix1 + prefix2 + prefix1)
        self.obj.setFromOptions()

    def testName(self):
        oldname = self.obj.getName()
        newname = '%s-%s' % (oldname, oldname)
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        # Each incRef/decRef moves the count by exactly one; releasing
        # the last reference leaves the wrapper empty (falsy).
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # The C and Fortran handles are only compared when they share a sign.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        import copy
        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the missing coverage instead of passing silently.
            self.skipTest('deepcopy not supported for %s objects'
                          % type(self.obj).__name__)
        try:
            self.assertEqual(myobj.getRefCount(), 1)
            # Composing takes a reference on myobj ...
            self.obj.compose('myobj', myobj)
            self.assertIs(type(self.obj.query('myobj')), self.CLASS)
            self.assertEqual(self.obj.query('myobj'), myobj)
            self.assertEqual(myobj.getRefCount(), 2)
            # ... and composing None under the same key releases it.
            self.obj.compose('myobj', None)
            self.assertEqual(myobj.getRefCount(), 1)
            self.assertIsNone(self.obj.query('myobj'))
        finally:
            # Release the copy even when an assertion above fails.
            myobj.destroy()

    def testProperties(self):
        # Each getter must agree with its property alias.
        self.assertEqual(self.obj.getClassId(),   self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(),      self.obj.type)
        self.assertEqual(self.obj.getName(),      self.obj.name)
        self.assertEqual(self.obj.getComm(),      self.obj.comm)
        self.assertEqual(self.obj.getRefCount(),  self.obj.refcount)

    def testShallowCopy(self):
        import copy
        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        # A shallow copy is a new wrapper around the same PETSc object,
        # so it compares equal and bumps the reference count.
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIs(type(obj), type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        import copy
        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the missing coverage instead of passing silently.
            self.skipTest('deepcopy not supported for %s objects'
                          % type(self.obj).__name__)
        # A deep copy is a distinct PETSc object with its own refcount.
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIs(type(obj), type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
        del obj

    def testStateInspection(self):
        state = self.obj.stateGet()
        self.obj.stateIncrease()
        self.assertLess(state, self.obj.stateGet())
        self.obj.stateSet(0)
        self.assertEqual(self.obj.stateGet(), 0)
        self.obj.stateSet(state)
        self.assertEqual(self.obj.stateGet(), state)
158
159
160# --------------------------------------------------------------------
161
class TestObjectRandom(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.Random`` built by ``create()``."""
    CLASS, FACTORY = PETSc.Random, 'create'
165
class TestObjectViewer(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.Viewer`` built by ``create()``."""
    CLASS, FACTORY = PETSc.Viewer, 'create'
169
class TestObjectIS(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.IS`` built by ``createGeneral([])``."""
    CLASS, FACTORY = PETSc.IS, 'createGeneral'
    TARGS = ([],)
174
class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.LGMap`` built by ``create([])``."""
    CLASS, FACTORY = PETSc.LGMap, 'create'
    TARGS = ([],)
179
class TestObjectAO(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.AO`` via ``createMapping([], [])``.

    NOTE(review): this configuration is identical to TestObjectAOMapping,
    so the same checks run twice.
    """
    CLASS, FACTORY = PETSc.AO, 'createMapping'
    TARGS = ([], [])
184
class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.DMDA`` built by ``create([3, 3, 3])``."""
    CLASS, FACTORY = PETSc.DMDA, 'create'
    TARGS = ([3, 3, 3],)
189
class TestObjectDS(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.DS`` built by ``create()``."""
    CLASS, FACTORY = PETSc.DS, 'create'
193
class TestObjectVec(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.Vec`` built by ``createSeq(0)``."""
    CLASS, FACTORY, TARGS = PETSc.Vec, 'createSeq', (0,)

    def setUp(self):
        # Shared construction first, then assemble the vector.
        super(TestObjectVec, self).setUp()
        self.obj.assemble()
202
class TestObjectMat(BaseTestObject, unittest.TestCase):
    """Generic object checks on an AIJ ``PETSc.Mat`` on ``COMM_SELF``."""
    CLASS, FACTORY = PETSc.Mat, 'createAIJ'
    TARGS = (0,)
    KARGS = dict(nnz=0, comm=PETSc.COMM_SELF)

    def setUp(self):
        # Shared construction first, then assemble the matrix.
        super(TestObjectMat, self).setUp()
        self.obj.assemble()
212
class TestObjectMatPartitioning(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.MatPartitioning`` via ``create()``."""
    CLASS, FACTORY = PETSc.MatPartitioning, 'create'
216
class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.NullSpace`` via ``create(True, [])``."""
    CLASS, FACTORY = PETSc.NullSpace, 'create'
    TARGS = (True, [])
221
class TestObjectKSP(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.KSP`` built by ``create()``."""
    CLASS, FACTORY = PETSc.KSP, 'create'
225
class TestObjectPC(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.PC`` built by ``create()``."""
    CLASS, FACTORY = PETSc.PC, 'create'
229
class TestObjectSNES(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.SNES`` built by ``create()``."""
    CLASS, FACTORY = PETSc.SNES, 'create'
233
class TestObjectTS(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.TS``.

    After the shared construction the solver is given problem type
    NONLINEAR and type BEULER.
    """
    CLASS, FACTORY = PETSc.TS, 'create'

    def setUp(self):
        BaseTestObject.setUp(self)
        ts = self.obj
        ts.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        ts.setType(PETSc.TS.Type.BEULER)
241
class TestObjectTAO(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.TAO`` built by ``create()``."""
    CLASS, FACTORY = PETSc.TAO, 'create'
245
class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.AO`` via ``createBasic([], [])``."""
    CLASS, FACTORY = PETSc.AO, 'createBasic'
    TARGS = ([], [])
250
class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.AO`` via ``createMapping([], [])``.

    NOTE(review): TestObjectAO uses this exact configuration as well.
    """
    CLASS, FACTORY = PETSc.AO, 'createMapping'
    TARGS = ([], [])
255
256# class TestObjectFE(BaseTestObject, unittest.TestCase):
257#     CLASS  = PETSc.FE
258#     FACTORY = 'create'
259#
260# class TestObjectQuad(BaseTestObject, unittest.TestCase):
261#     CLASS  = PETSc.Quad
262#     FACTORY = 'create'
263
class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.DMLabel`` built by ``create("test")``."""
    CLASS, FACTORY = PETSc.DMLabel, 'create'
    TARGS = ("test",)
268
class TestObjectSpace(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.Space`` built by ``create()``."""
    CLASS, FACTORY = PETSc.Space, 'create'
272
class TestObjectDualSpace(BaseTestObject, unittest.TestCase):
    """Generic object checks on ``PETSc.DualSpace`` built by ``create()``."""
    CLASS, FACTORY = PETSc.DualSpace, 'create'
276
277# --------------------------------------------------------------------
278
import numpy

# The TAO test case is not run when PETSc uses complex scalars.  Mark it
# as skipped rather than deleting it, so the omission appears in the
# test report instead of disappearing silently.
if numpy.iscomplexobj(PETSc.ScalarType()):
    TestObjectTAO = unittest.skip(
        'TAO object tests are not run with complex scalars')(TestObjectTAO)

if __name__ == '__main__':
    unittest.main()
286