"""Generic object-interface tests run against many petsc4py classes.

``BaseTestObject`` holds the shared test methods; each concrete
``TestObject*`` class below only selects the PETSc class to instantiate
and the factory method (plus arguments) used to create it.
"""

import unittest

from petsc4py import PETSc
import numpy

# --------------------------------------------------------------------


class BaseTestObject:
    """Mixin with tests exercised on every object type listed below.

    Subclasses combine this mixin with ``unittest.TestCase`` and set:

    * ``CLASS``   -- the petsc4py class to instantiate.
    * ``FACTORY`` -- name of the method called on the fresh instance to
      create the underlying object.
    * ``TARGS`` / ``KARGS`` -- positional / keyword arguments forwarded
      to the factory method.
    """

    CLASS, FACTORY = None, None
    # KARGS is shared by all subclasses that do not override it; it is
    # only ever read (unpacked with **), never mutated.
    TARGS, KARGS = (), {}
    BUILD = None  # not referenced anywhere in this module

    def setUp(self):
        """Instantiate ``CLASS`` and create it through ``FACTORY``."""
        self.obj = self.CLASS()
        getattr(self.obj, self.FACTORY)(*self.TARGS, **self.KARGS)
        # Fall back to a plain create() if the factory left the object
        # in a falsy state.
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        """Drop the object and run PETSc's garbage cleanup."""
        self.obj = None
        PETSc.garbage_cleanup()

    def _canSetFromOptions(self):
        """Return True unless the object's type is ``'da'``.

        Every ``setFromOptions()`` call in ``testOptions`` is guarded by
        this condition, so objects of type ``'da'`` never receive it.
        NOTE(review): the reason 'da' must be skipped is not visible in
        this file -- confirm before relaxing the guard.
        """
        objtype = self.obj.getType()
        return not objtype or str(objtype) != 'da'

    def testTypeRegistry(self):
        """The type registry maps the object's class id to its class."""
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        typeobj = self.CLASS
        # For a DMDA the registry entry is expected to be PETSc.DM.
        if isinstance(self.obj, PETSc.DMDA):
            typeobj = PETSc.DM
        self.assertIs(type_reg[classid], typeobj)

    def testLogClass(self):
        """The log class looked up by name carries the object's class id."""
        name = self.CLASS.__name__
        # DMDA is looked up under the name 'DM'.
        if name == 'DMDA':
            name = 'DM'
        logcls = PETSc.Log.Class(name)
        classid = self.obj.getClassId()
        self.assertEqual(logcls.id, classid)

    def testClass(self):
        """The object is an instance of exactly ``CLASS``."""
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testId(self):
        """``getId()`` is positive and agrees with the ``id`` property."""
        oid = self.obj.getId()
        self.assertGreater(oid, 0)
        self.assertEqual(self.obj.id, oid)

    def testNonZero(self):
        """A created object is truthy."""
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        """``destroy()`` turns a truthy object into a falsy one."""
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        self.assertFalse(bool(self.obj))
        ## self.assertRaises(PETSc.Error, self.obj.destroy)
        ## self.assertTrue(self.obj.this is this)

    def testOptions(self):
        """Exercise the options prefix API and the options handlers."""
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        self.obj.appendOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2 + prefix1)
        self.obj.setOptionsPrefix(None)
        self.assertEqual(self.obj.getOptionsPrefix(), None)
        if self._canSetFromOptions():
            self.obj.setFromOptions()

        def opts_handler(obj):
            # Count the invocation on the object itself and check that
            # the handler received the object it was installed on.
            n = obj.getAttr('opts_handler_called')
            obj.setAttr('opts_handler_called', n + 1)
            self.assertEqual(type(self.obj), type(obj))
            self.assertEqual(self.obj.klass, obj.klass)
            self.assertEqual(self.obj.type, obj.type)
            self.assertEqual(self.obj.id, obj.id)

        # Class names for which the handler call count is not asserted.
        # NOTE(review): presumably setFromOptions() does not run options
        # handlers for these classes -- confirm against PETSc.
        missing = (
            'AO',
            'DMLabel',
            'PetscDualSpace',
            'IS',
            'ISLocalToGlobalMapping',
            'MatPartitioning',
            'MatNullSpace',
            'PetscRandom',
            'PetscViewer',
        )

        # Installing the handler a second time must still result in
        # exactly one call per setFromOptions().
        for _ in range(2):
            self.obj.setAttr('opts_handler_called', 0)
            self.obj.setOptionsHandler(opts_handler)
            if self._canSetFromOptions():
                self.obj.setFromOptions()
                if self.obj.klass not in missing:
                    self.assertEqual(
                        self.obj.getAttr('opts_handler_called'), 1
                    )

        if self._canSetFromOptions():
            # With the handler reset to None it must not be called.
            self.obj.setAttr('opts_handler_called', 0)
            self.obj.setOptionsHandler(None)
            self.obj.setFromOptions()
            self.assertFalse(self.obj.getAttr('opts_handler_called'))

            # Same expectation after destroying all handlers.
            self.obj.destroyOptionsHandlers()
            self.obj.setFromOptions()
            self.assertFalse(self.obj.getAttr('opts_handler_called'))

    def testName(self):
        """``setName()`` / ``getName()`` round-trip, restoring the old name."""
        oldname = self.obj.getName()
        newname = f'{oldname}-{oldname}'
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        """The communicator is a PETSc.Comm, either SELF or WORLD."""
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        """``incRef()`` / ``decRef()`` move the reference count by one."""
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        # Dropping the last reference leaves a falsy object.
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        """``handle`` and ``fortran`` are set while alive, cleared by destroy."""
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # The two values are only compared when they share a sign.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        """``compose()`` / ``query()`` attach, return and detach an object."""
        import copy

        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Classes without deep-copy support cannot run this test.
            return
        self.assertEqual(myobj.getRefCount(), 1)
        self.obj.compose('myobj', myobj)
        self.assertIs(type(self.obj.query('myobj')), self.CLASS)
        self.assertEqual(self.obj.query('myobj'), myobj)
        # Composing takes one extra reference on the composed object ...
        self.assertEqual(myobj.getRefCount(), 2)
        # ... and composing None under the same key releases it again.
        self.obj.compose('myobj', None)
        self.assertEqual(myobj.getRefCount(), 1)
        self.assertEqual(self.obj.query('myobj'), None)
        myobj.destroy()

    def testProperties(self):
        """Each property agrees with its corresponding getter method."""
        self.assertEqual(self.obj.getClassId(), self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(), self.obj.type)
        self.assertEqual(self.obj.getName(), self.obj.name)
        self.assertEqual(self.obj.getComm(), self.obj.comm)
        self.assertEqual(self.obj.getRefCount(), self.obj.refcount)

    def testShallowCopy(self):
        """``copy.copy`` gives an equal wrapper holding one more reference."""
        import copy

        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        # Deleting the copy must give the extra reference back.
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        """``copy.deepcopy`` gives a distinct object with its own refcount."""
        import copy

        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Classes without deep-copy support cannot run this test.
            return
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
        del obj

    def testStateInspection(self):
        """``stateIncrease()`` / ``stateSet()`` are visible via ``stateGet()``."""
        state = self.obj.stateGet()
        self.obj.stateIncrease()
        self.assertLess(state, self.obj.stateGet())
        self.obj.stateSet(0)
        self.assertEqual(self.obj.stateGet(), 0)
        # Restore the state observed at the start of the test.
        self.obj.stateSet(state)
        self.assertEqual(self.obj.stateGet(), state)


# --------------------------------------------------------------------


class TestObjectRandom(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Random
    FACTORY = 'create'


class TestObjectViewer(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Viewer
    FACTORY = 'create'


class TestObjectIS(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.IS
    FACTORY = 'createGeneral'
    TARGS = ([],)


class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.LGMap
    FACTORY = 'create'
    TARGS = ([],)


class TestObjectAO(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])


class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DMDA
    FACTORY = 'create'
    TARGS = ([3, 3, 3],)


class TestObjectDS(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DS
    FACTORY = 'create'


class TestObjectVec(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Vec
    FACTORY = 'createSeq'
    TARGS = (0,)

    def setUp(self):
        """Create the vector, then assemble it before each test."""
        super().setUp()
        self.obj.assemble()


class TestObjectMat(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Mat
    FACTORY = 'createAIJ'
    TARGS = (0,)
    KARGS = {'nnz': 0, 'comm': PETSc.COMM_SELF}

    def setUp(self):
        """Create the matrix, then assemble it before each test."""
        super().setUp()
        self.obj.assemble()


class TestObjectMatPartitioning(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.MatPartitioning
    FACTORY = 'create'


class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.NullSpace
    FACTORY = 'create'
    TARGS = (True, [])


class TestObjectKSP(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.KSP
    FACTORY = 'create'


class TestObjectPC(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.PC
    FACTORY = 'create'


class TestObjectSNES(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.SNES
    FACTORY = 'create'


class TestObjectTS(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.TS
    FACTORY = 'create'

    def setUp(self):
        """Create the TS, then give it a problem type and a concrete type."""
        super().setUp()
        self.obj.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        self.obj.setType(PETSc.TS.Type.BEULER)


class TestObjectTAO(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.TAO
    FACTORY = 'create'


class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.AO
    FACTORY = 'createBasic'
    TARGS = ([], [])


class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])


# class TestObjectFE(BaseTestObject, unittest.TestCase):
#     CLASS = PETSc.FE
#     FACTORY = 'create'
#
# class TestObjectQuad(BaseTestObject, unittest.TestCase):
#     CLASS = PETSc.Quad
#     FACTORY = 'create'


class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DMLabel
    FACTORY = 'create'
    TARGS = ('test',)


class TestObjectSpace(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Space
    FACTORY = 'create'


class TestObjectDualSpace(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DualSpace
    FACTORY = 'create'


# --------------------------------------------------------------------

# The TAO tests are removed when PETSc's scalar type is complex.
# NOTE(review): presumably TAO is unavailable in complex builds -- confirm.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestObjectTAO

if __name__ == '__main__':
    unittest.main()