# xref: /petsc/src/binding/petsc4py/test/test_object.py (revision b698fc57f0bea7237255b29c1b77df0acc362ffd)
import copy
import unittest

from petsc4py import PETSc
3
4# --------------------------------------------------------------------
5
class BaseTestObject(object):
    """Generic checks shared by the test cases for petsc4py object classes.

    This is a mixin: concrete test cases inherit from it *and* from
    ``unittest.TestCase`` and describe how the object under test is built
    through the class attributes below.
    """

    # CLASS   -- petsc4py class to instantiate (e.g. ``PETSc.Vec``).
    # FACTORY -- name of the method called on the fresh instance to build
    #            the underlying PETSc object (e.g. ``'create'``).
    CLASS, FACTORY = None, None
    # Positional / keyword arguments forwarded to FACTORY (read-only here).
    TARGS, KARGS = (), {}
    # Not read by any of the test methods below.
    BUILD = None

    def setUp(self):
        """Build ``self.obj`` as ``CLASS()`` followed by ``FACTORY(...)``."""
        self.obj = self.CLASS()
        getattr(self.obj, self.FACTORY)(*self.TARGS, **self.KARGS)
        # If the factory left the wrapper empty (falsy), fall back to create().
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        """Drop the test's reference to the object under test."""
        self.obj = None

    def testTypeRegistry(self):
        """The type registry maps the object's class id to CLASS."""
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        typeobj = self.CLASS
        # DMDA objects are registered under the generic DM class.
        if isinstance(self.obj, PETSc.DMDA):
            typeobj = PETSc.DM
        self.assertIs(type_reg[classid], typeobj)

    def testLogClass(self):
        """The log class named after CLASS carries the object's class id."""
        name = self.CLASS.__name__
        # DMDA objects are logged under the generic DM class.
        if name == 'DMDA':
            name = 'DM'
        logcls = PETSc.Log.Class(name)
        classid = self.obj.getClassId()
        self.assertEqual(logcls.id, classid)

    def testClass(self):
        """The object is an instance of exactly CLASS (not a subclass)."""
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testNonZero(self):
        """A freshly built object is truthy."""
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        """destroy() turns a live object falsy."""
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        self.assertFalse(bool(self.obj))
        ## self.assertRaises(PETSc.Error, self.obj.destroy)
        ## self.assertTrue(self.obj.this is this)

    def testOptions(self):
        """setOptionsPrefix() replaces the prefix; setFromOptions() runs."""
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        prefix2 = 'opt_'
        # A second call replaces (does not extend) the previous prefix.
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        ## self.obj.appendOptionsPrefix(prefix1)
        ## self.assertEqual(self.obj.getOptionsPrefix(),
        ##                  prefix2 + prefix1)
        ## self.obj.prependOptionsPrefix(prefix1)
        ## self.assertEqual(self.obj.getOptionsPrefix(),
        ##                  prefix1 + prefix2 + prefix1)
        self.obj.setFromOptions()

    def testName(self):
        """setName()/getName() round-trip, then restore the original name."""
        oldname = self.obj.getName()
        newname = '%s-%s' % (oldname, oldname)
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        """The object lives on either COMM_SELF or COMM_WORLD."""
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        """incRef()/decRef() adjust the count; dropping the last ref destroys."""
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        # Releasing the last reference leaves the wrapper empty (falsy).
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        """C and Fortran handles are set while alive and cleared on destroy."""
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # Only compare the two handles when they have the same sign
        # (presumably a pointer may map to integers of different sign in
        # the two representations -- TODO confirm).
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        """compose()/query() attach an object and hold a reference to it."""
        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the missing coverage instead of silently passing.
            self.skipTest('copy.deepcopy() not supported by %s'
                          % self.CLASS.__name__)
        try:
            self.assertEqual(myobj.getRefCount(), 1)
            self.obj.compose('myobj', myobj)
            self.assertIs(type(self.obj.query('myobj')), self.CLASS)
            self.assertEqual(self.obj.query('myobj'), myobj)
            # Composing takes an extra reference ...
            self.assertEqual(myobj.getRefCount(), 2)
            # ... which composing None releases again.
            self.obj.compose('myobj', None)
            self.assertEqual(myobj.getRefCount(), 1)
            self.assertIsNone(self.obj.query('myobj'))
        finally:
            # Release the copy even when an assertion above fails.
            myobj.destroy()

    def testProperties(self):
        """Each property agrees with its corresponding getter."""
        self.assertEqual(self.obj.getClassId(),   self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(),      self.obj.type)
        self.assertEqual(self.obj.getName(),      self.obj.name)
        self.assertEqual(self.obj.getComm(),      self.obj.comm)
        self.assertEqual(self.obj.getRefCount(),  self.obj.refcount)

    def testShallowCopy(self):
        """copy.copy() yields a new wrapper sharing the same PETSc object."""
        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIs(type(obj), type(self.obj))
        # The shared object gains one reference per shallow copy ...
        self.assertEqual(obj.getRefCount(), rc + 1)
        # ... and loses it again when the copy is dropped.
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        """copy.deepcopy() yields an independent object of the same type."""
        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report the missing coverage instead of silently passing.
            self.skipTest('copy.deepcopy() not supported by %s'
                          % self.CLASS.__name__)
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIs(type(obj), type(self.obj))
        # The original's reference count is untouched; the copy is fresh.
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
148
149# --------------------------------------------------------------------
150
class TestObjectRandom(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.Random`` built with ``create()``."""

    CLASS, FACTORY = PETSc.Random, 'create'
154
class TestObjectViewer(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.Viewer`` built with ``create()``."""

    CLASS, FACTORY = PETSc.Viewer, 'create'
158
class TestObjectIS(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.IS`` from an empty index list."""

    CLASS, FACTORY = PETSc.IS, 'createGeneral'
    TARGS = ([],)
163
class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.LGMap`` from an empty index list."""

    CLASS, FACTORY = PETSc.LGMap, 'create'
    TARGS = ([],)
168
class TestObjectAO(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.AO`` via ``createMapping([], [])``.

    NOTE(review): this configuration is identical to TestObjectAOMapping,
    so the same checks run twice; one of the two looks redundant.
    """

    CLASS, FACTORY = PETSc.AO, 'createMapping'
    TARGS = ([], [])
173
class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.DMDA`` from ``create([3, 3, 3])``."""

    CLASS, FACTORY = PETSc.DMDA, 'create'
    TARGS = ([3, 3, 3],)
178
class TestObjectDS(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.DS`` built with ``create()``."""

    CLASS, FACTORY = PETSc.DS, 'create'
182
class TestObjectVec(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.Vec`` from ``createSeq(0)``."""

    CLASS, FACTORY = PETSc.Vec, 'createSeq'
    TARGS = (0,)

    def setUp(self):
        """Build the vector via the shared setUp, then assemble it."""
        super(TestObjectVec, self).setUp()
        self.obj.assemble()
191
class TestObjectMat(BaseTestObject, unittest.TestCase):
    """Generic object checks for an AIJ ``PETSc.Mat`` on ``COMM_SELF``."""

    CLASS, FACTORY = PETSc.Mat, 'createAIJ'
    TARGS = (0,)
    KARGS = {'nnz': 0, 'comm': PETSc.COMM_SELF}

    def setUp(self):
        """Build the matrix via the shared setUp, then assemble it."""
        super(TestObjectMat, self).setUp()
        self.obj.assemble()
201
class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.NullSpace`` from ``create(True, [])``."""

    CLASS, FACTORY = PETSc.NullSpace, 'create'
    TARGS = (True, [])
206
class TestObjectKSP(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.KSP`` built with ``create()``."""

    CLASS, FACTORY = PETSc.KSP, 'create'
210
class TestObjectPC(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.PC`` built with ``create()``."""

    CLASS, FACTORY = PETSc.PC, 'create'
214
class TestObjectSNES(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.SNES`` built with ``create()``."""

    CLASS, FACTORY = PETSc.SNES, 'create'
218
class TestObjectTS(BaseTestObject, unittest.TestCase):
    """Generic object checks for a nonlinear ``BEULER`` ``PETSc.TS``."""

    CLASS, FACTORY = PETSc.TS, 'create'

    def setUp(self):
        """Build the TS via the shared setUp, then pick problem type and method."""
        super(TestObjectTS, self).setUp()
        ts, ts_cls = self.obj, PETSc.TS
        ts.setProblemType(ts_cls.ProblemType.NONLINEAR)
        ts.setType(ts_cls.Type.BEULER)
226
class TestObjectTAO(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.TAO`` built with ``create()``.

    The module deletes this case when PETSc scalars are complex.
    """

    CLASS, FACTORY = PETSc.TAO, 'create'
230
class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.AO`` via ``createBasic([], [])``."""

    CLASS, FACTORY = PETSc.AO, 'createBasic'
    TARGS = ([], [])
235
class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.AO`` via ``createMapping([], [])``.

    NOTE(review): this configuration is identical to TestObjectAO, so the
    same checks run twice; one of the two looks redundant.
    """

    CLASS, FACTORY = PETSc.AO, 'createMapping'
    TARGS = ([], [])
240
241# class TestObjectFE(BaseTestObject, unittest.TestCase):
242#     CLASS  = PETSc.FE
243#     FACTORY = 'create'
244#
245# class TestObjectQuad(BaseTestObject, unittest.TestCase):
246#     CLASS  = PETSc.Quad
247#     FACTORY = 'create'
248
class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    """Generic object checks for a ``PETSc.DMLabel`` created with a name."""

    CLASS, FACTORY = PETSc.DMLabel, 'create'
    TARGS = ("test",)
253
254# --------------------------------------------------------------------
255
import numpy

# Drop the TAO test case when PETSc scalars are complex (presumably TAO is
# not usable in complex-scalar builds -- TODO confirm).
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestObjectTAO

# Run the whole suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
262