"""Generic tests for the petsc4py object interface.

Every ``TestObject*`` case below mixes :class:`BaseTestObject` into
:class:`unittest.TestCase` and only declares how to build one concrete
PETSc object.  The shared test methods then exercise the common object
API on it: type registry, ids, options prefixes, names, communicators,
reference counting, handles, compose/query, copying and state.
"""

import unittest

from petsc4py import PETSc
import numpy

# --------------------------------------------------------------------


class BaseTestObject:
    """Mixin holding the tests shared by every wrapped PETSc object type.

    Concrete subclasses (which also inherit from ``unittest.TestCase``) set:

    * ``CLASS``   -- the petsc4py class under test;
    * ``FACTORY`` -- name of the ``CLASS`` method that builds the object;
    * ``TARGS``   -- positional arguments passed to ``FACTORY``;
    * ``KARGS``   -- keyword arguments passed to ``FACTORY``.
    """

    CLASS, FACTORY = None, None
    # KARGS is a shared class-level default; subclasses rebind it
    # (see TestObjectMat) rather than mutating it in place.
    TARGS, KARGS = (), {}
    BUILD = None  # not referenced by the tests in this module

    def setUp(self):
        self.obj = self.CLASS()
        getattr(self.obj, self.FACTORY)(*self.TARGS, **self.KARGS)
        # If the factory left the handle empty, fall back to create().
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        self.obj = None
        PETSc.garbage_cleanup()

    def testTypeRegistry(self):
        """The object's class id maps back to its Python type."""
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        typeobj = self.CLASS
        # The registry maps DMDA objects to the generic PETSc.DM type.
        if isinstance(self.obj, PETSc.DMDA):
            typeobj = PETSc.DM
        self.assertIs(type_reg[classid], typeobj)

    def testLogClass(self):
        """The log class named after the type carries the object's class id."""
        name = self.CLASS.__name__
        # DMDA is logged under the generic 'DM' class name.
        if name == 'DMDA':
            name = 'DM'
        logcls = PETSc.Log.Class(name)
        classid = self.obj.getClassId()
        self.assertEqual(logcls.id, classid)

    def testClass(self):
        """The built object is exactly an instance of CLASS."""
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testId(self):
        """getId() is positive and agrees with the ``id`` property."""
        oid = self.obj.getId()
        self.assertGreater(oid, 0)
        self.assertEqual(self.obj.id, oid)

    def testNonZero(self):
        """A freshly built object is truthy."""
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        """destroy() turns a live object into a falsy one."""
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        self.assertFalse(bool(self.obj))

    def testOptions(self):
        """Options prefixes can be set, replaced, appended to and cleared."""
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        self.obj.appendOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2 + prefix1)
        self.obj.setOptionsPrefix(None)
        self.assertIsNone(self.obj.getOptionsPrefix())
        self.obj.setFromOptions()

    def testName(self):
        """A new name round-trips through getName() and can be restored."""
        oldname = self.obj.getName()
        newname = f'{oldname}-{oldname}'
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        """The object lives on COMM_SELF or COMM_WORLD."""
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        """incRef/decRef adjust the reference count one step at a time."""
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        # Dropping the last reference leaves the wrapper empty.
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        """C and Fortran handles are set while alive and cleared on destroy."""
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # Only compare the two handles when they have the same sign.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        """compose() attaches an object (taking a reference); query() finds it."""
        import copy

        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report as skipped instead of silently passing.
            self.skipTest(f'deepcopy not implemented for {self.CLASS.__name__}')
        self.assertEqual(myobj.getRefCount(), 1)
        self.obj.compose('myobj', myobj)
        self.assertIs(type(self.obj.query('myobj')), self.CLASS)
        self.assertEqual(self.obj.query('myobj'), myobj)
        self.assertEqual(myobj.getRefCount(), 2)
        # Composing None detaches the object and releases the reference.
        self.obj.compose('myobj', None)
        self.assertEqual(myobj.getRefCount(), 1)
        self.assertIsNone(self.obj.query('myobj'))
        myobj.destroy()

    def testProperties(self):
        """Property accessors mirror their getter methods."""
        self.assertEqual(self.obj.getClassId(), self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(), self.obj.type)
        self.assertEqual(self.obj.getName(), self.obj.name)
        self.assertEqual(self.obj.getComm(), self.obj.comm)
        self.assertEqual(self.obj.getRefCount(), self.obj.refcount)

    def testShallowCopy(self):
        """copy.copy wraps the same object and holds one extra reference."""
        import copy

        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        # Releasing the copy must give the extra reference back.
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        """copy.deepcopy builds a distinct object with its own reference."""
        import copy

        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report as skipped instead of silently passing.
            self.skipTest(f'deepcopy not implemented for {self.CLASS.__name__}')
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)

    def testStateInspection(self):
        """The object state counter can be increased, reset and restored."""
        state = self.obj.stateGet()
        self.obj.stateIncrease()
        self.assertLess(state, self.obj.stateGet())
        self.obj.stateSet(0)
        self.assertEqual(self.obj.stateGet(), 0)
        self.obj.stateSet(state)
        self.assertEqual(self.obj.stateGet(), state)


# --------------------------------------------------------------------


class TestObjectRandom(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Random
    FACTORY = 'create'


class TestObjectViewer(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Viewer
    FACTORY = 'create'


class TestObjectIS(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.IS
    FACTORY = 'createGeneral'
    TARGS = ([],)


class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.LGMap
    FACTORY = 'create'
    TARGS = ([],)


# NOTE(review): same configuration as TestObjectAOMapping below.
class TestObjectAO(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])


class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DMDA
    FACTORY = 'create'
    TARGS = ([3, 3, 3],)


class TestObjectDS(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DS
    FACTORY = 'create'


class TestObjectVec(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Vec
    FACTORY = 'createSeq'
    TARGS = (0,)

    def setUp(self):
        super().setUp()
        self.obj.assemble()


class TestObjectMat(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Mat
    FACTORY = 'createAIJ'
    TARGS = (0,)
    KARGS = {'nnz': 0, 'comm': PETSc.COMM_SELF}

    def setUp(self):
        super().setUp()
        self.obj.assemble()


class TestObjectMatPartitioning(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.MatPartitioning
    FACTORY = 'create'


class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.NullSpace
    FACTORY = 'create'
    TARGS = (True, [])


class TestObjectKSP(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.KSP
    FACTORY = 'create'


class TestObjectPC(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.PC
    FACTORY = 'create'


class TestObjectSNES(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.SNES
    FACTORY = 'create'


class TestObjectTS(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.TS
    FACTORY = 'create'

    def setUp(self):
        super().setUp()
        self.obj.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        self.obj.setType(PETSc.TS.Type.BEULER)


class TestObjectTAO(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.TAO
    FACTORY = 'create'


class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.AO
    FACTORY = 'createBasic'
    TARGS = ([], [])


class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])


# class TestObjectFE(BaseTestObject, unittest.TestCase):
#     CLASS = PETSc.FE
#     FACTORY = 'create'
#
# class TestObjectQuad(BaseTestObject, unittest.TestCase):
#     CLASS = PETSc.Quad
#     FACTORY = 'create'


class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    """Shared object tests on a DMLabel built with the name 'test'."""

    CLASS, FACTORY = PETSc.DMLabel, 'create'
    TARGS = ('test',)


class TestObjectSpace(BaseTestObject, unittest.TestCase):
    """Shared object tests on a default-created Space."""

    CLASS, FACTORY = PETSc.Space, 'create'


class TestObjectDualSpace(BaseTestObject, unittest.TestCase):
    """Shared object tests on a default-created DualSpace."""

    CLASS, FACTORY = PETSc.DualSpace, 'create'


# --------------------------------------------------------------------

# The TAO case only runs with real scalars; with complex PETSc scalars
# it is removed from the module so unittest never collects it.
# NOTE(review): presumably because TAO is unavailable in complex builds.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestObjectTAO

if __name__ == '__main__':
    unittest.main()