# xref: /petsc/src/binding/petsc4py/test/test_object.py (revision bef158480efac06de457f7a665168877ab3c2fd7)
import unittest

from petsc4py import PETSc
3
4# --------------------------------------------------------------------
5
class BaseTestObject(object):
    """Generic checks shared by the test cases of every PETSc object type.

    Concrete test cases mix this class with ``unittest.TestCase`` and set
    the class attributes below.  ``setUp`` builds the object under test as
    ``CLASS().<FACTORY>(*TARGS, **KARGS)`` and stores it in ``self.obj``.
    """

    # PETSc wrapper class under test and the name of its constructor method.
    CLASS, FACTORY = None, None
    # Positional / keyword arguments forwarded to the FACTORY method.  These
    # are class-level defaults shared by all subclasses: override, never mutate.
    TARGS, KARGS = (), {}
    # Not referenced by any method of this class.
    BUILD = None

    def setUp(self):
        self.obj = self.CLASS()
        factory = getattr(self.obj, self.FACTORY)
        factory(*self.TARGS, **self.KARGS)
        # A falsy wrapper has no live object behind it (testDestroy shows a
        # destroyed object is falsy), so fall back to the generic create().
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        self.obj = None

    def _registeredClass(self):
        """Return the class objects of ``CLASS`` are expected to register as.

        DMDA objects are expected under their parent ``DM`` class; every
        other class under test is expected under itself.
        """
        if issubclass(self.CLASS, PETSc.DMDA):
            return PETSc.DM
        return self.CLASS

    def testTypeRegistry(self):
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        self.assertIs(type_reg[classid], self._registeredClass())

    def testLogClass(self):
        logcls = PETSc.Log.Class(self._registeredClass().__name__)
        self.assertEqual(logcls.id, self.obj.getClassId())

    def testClass(self):
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testNonZero(self):
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        self.assertFalse(bool(self.obj))
        ## self.assertRaises(PETSc.Error, self.obj.destroy)
        ## self.assertTrue(self.obj.this is this)

    def testOptions(self):
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        # Setting a new prefix replaces the previous one.
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        ## self.obj.appendOptionsPrefix(prefix1)
        ## self.assertEqual(self.obj.getOptionsPrefix(),
        ##                  prefix2 + prefix1)
        ## self.obj.prependOptionsPrefix(prefix1)
        ## self.assertEqual(self.obj.getOptionsPrefix(),
        ##                  prefix1 + prefix2 + prefix1)
        self.obj.setFromOptions()

    def testName(self):
        oldname = self.obj.getName()
        newname = '%s-%s' % (oldname, oldname)
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        # Restore the original name and check the round trip.
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        # Dropping the last reference leaves the wrapper empty (falsy).
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # When both handles have the same sign they must be identical.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        import copy
        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the check as skipped instead of as a silent pass.
            self.skipTest('deepcopy() is not implemented for %s'
                          % type(self.obj).__name__)
        self.assertEqual(myobj.getRefCount(), 1)
        # Composing takes a reference; composing None releases it again.
        self.obj.compose('myobj', myobj)
        self.assertIs(type(self.obj.query('myobj')), self.CLASS)
        self.assertEqual(self.obj.query('myobj'), myobj)
        self.assertEqual(myobj.getRefCount(), 2)
        self.obj.compose('myobj', None)
        self.assertEqual(myobj.getRefCount(), 1)
        self.assertIsNone(self.obj.query('myobj'))
        myobj.destroy()

    def testProperties(self):
        self.assertEqual(self.obj.getClassId(),   self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(),      self.obj.type)
        self.assertEqual(self.obj.getName(),      self.obj.name)
        self.assertEqual(self.obj.getComm(),      self.obj.comm)
        self.assertEqual(self.obj.getRefCount(),  self.obj.refcount)

    def testShallowCopy(self):
        import copy
        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        # A shallow copy is a new wrapper sharing the same PETSc object,
        # so it compares equal and holds one extra reference.
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIs(type(obj), type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        import copy
        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the check as skipped instead of as a silent pass.
            self.skipTest('deepcopy() is not implemented for %s'
                          % type(self.obj).__name__)
        # A deep copy is a distinct object: unequal, with its own refcount,
        # and leaving the original's refcount untouched.
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIs(type(obj), type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
        del obj
148
149# --------------------------------------------------------------------
150
class TestObjectRandom(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.Random`` built via ``create()``."""
    CLASS = PETSc.Random
    FACTORY = 'create'
154
class TestObjectViewer(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.Viewer`` built via ``create()``."""
    CLASS = PETSc.Viewer
    FACTORY = 'create'
158
class TestObjectIS(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.IS`` built from an empty index list."""
    CLASS = PETSc.IS
    FACTORY = 'createGeneral'
    TARGS = ([],)
163
class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.LGMap`` built from an empty list."""
    CLASS = PETSc.LGMap
    FACTORY = 'create'
    TARGS = ([],)
168
class TestObjectAO(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.AO`` built via ``createMapping([], [])``.

    NOTE(review): TestObjectAOMapping in this module uses exactly the same
    configuration, so this case currently duplicates it.
    """
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])
173
class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.DMDA`` built via ``create([3, 3, 3])``."""
    CLASS = PETSc.DMDA
    FACTORY = 'create'
    TARGS = ([3, 3, 3],)
178
class TestObjectDS(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.DS`` built via ``create()``."""
    CLASS = PETSc.DS
    FACTORY = 'create'
182
class TestObjectVec(BaseTestObject, unittest.TestCase):
    """Shared object checks for a length-0 sequential ``PETSc.Vec``."""
    CLASS = PETSc.Vec
    FACTORY = 'createSeq'
    TARGS = (0,)

    def setUp(self):
        # Build the vector through the generic factory path, then assemble.
        super(TestObjectVec, self).setUp()
        vec = self.obj
        vec.assemble()
191
class TestObjectScatter(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.Scatter``.

    ``setUp`` is overridden (and does not call the base ``setUp``): the
    scatter is created from two empty sequential vectors and two empty
    index sets, so ``FACTORY`` is not used by this case.
    """
    CLASS = PETSc.Scatter
    FACTORY = 'create'

    def setUp(self):
        vec_from = PETSc.Vec().createSeq(0)
        vec_to = PETSc.Vec().createSeq(0)
        is_from = PETSc.IS().createGeneral([])
        is_to = PETSc.IS().createGeneral([])
        self.obj = PETSc.Scatter().create(vec_from, is_from, vec_to, is_to)
        # Release the local references to the building blocks.
        del vec_from, vec_to, is_from, is_to
200
class TestObjectMat(BaseTestObject, unittest.TestCase):
    """Shared object checks for an empty AIJ ``PETSc.Mat`` on COMM_SELF."""
    CLASS = PETSc.Mat
    FACTORY = 'createAIJ'
    TARGS = (0,)
    KARGS = {'nnz': 0, 'comm': PETSc.COMM_SELF}

    def setUp(self):
        # Build the matrix through the generic factory path, then assemble.
        super(TestObjectMat, self).setUp()
        mat = self.obj
        mat.assemble()
210
class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.NullSpace`` built via ``create(True, [])``."""
    CLASS = PETSc.NullSpace
    FACTORY = 'create'
    TARGS = (True, [])
215
class TestObjectKSP(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.KSP`` built via ``create()``."""
    CLASS = PETSc.KSP
    FACTORY = 'create'
219
class TestObjectPC(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.PC`` built via ``create()``."""
    CLASS = PETSc.PC
    FACTORY = 'create'
223
class TestObjectSNES(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.SNES`` built via ``create()``."""
    CLASS = PETSc.SNES
    FACTORY = 'create'
227
class TestObjectTS(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.TS``.

    After the generic setup the solver is configured as a NONLINEAR
    problem using the BEULER type.
    """
    CLASS = PETSc.TS
    FACTORY = 'create'

    def setUp(self):
        BaseTestObject.setUp(self)
        ts = self.obj
        ts.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        ts.setType(PETSc.TS.Type.BEULER)
235
class TestObjectTAO(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.TAO`` built via ``create()``.

    The complex-scalar check at the bottom of this module disables this
    case on builds whose ``PETSc.ScalarType`` is complex.
    """
    CLASS = PETSc.TAO
    FACTORY = 'create'
239
class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.AO`` built via ``createBasic([], [])``."""
    CLASS = PETSc.AO
    FACTORY = 'createBasic'
    TARGS = ([], [])
244
class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.AO`` built via ``createMapping([], [])``.

    NOTE(review): TestObjectAO in this module uses exactly the same
    configuration, so one of the two cases is redundant.
    """
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])
249
250# class TestObjectFE(BaseTestObject, unittest.TestCase):
251#     CLASS  = PETSc.FE
252#     FACTORY = 'create'
253#
254# class TestObjectQuad(BaseTestObject, unittest.TestCase):
255#     CLASS  = PETSc.Quad
256#     FACTORY = 'create'
257
class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    """Shared object checks for ``PETSc.DMLabel`` built via ``create('test')``."""
    CLASS = PETSc.DMLabel
    FACTORY = 'create'
    TARGS = ('test',)
262
263# --------------------------------------------------------------------
264
import numpy

# On builds whose scalar type is complex, mark the TAO test case as skipped
# instead of deleting it, so the runner reports it rather than hiding it.
# NOTE(review): presumably TAO is not usable with complex scalars — confirm
# against the PETSc build configuration.
if numpy.iscomplexobj(PETSc.ScalarType()):
    TestObjectTAO = unittest.skip(
        'TAO is not tested with complex scalars')(TestObjectTAO)

if __name__ == '__main__':
    unittest.main()
271