"""Generic ``PETSc.Object`` behaviour tests, run against many PETSc classes."""

import unittest

from petsc4py import PETSc
import numpy

# --------------------------------------------------------------------


class BaseTestObject:
    """Mixin with tests shared by every PETSc object type.

    Concrete subclasses combine this mixin with ``unittest.TestCase`` and set:

    * ``CLASS``   -- the petsc4py class under test;
    * ``FACTORY`` -- name of the method called on a fresh instance to build it;
    * ``TARGS`` / ``KARGS`` -- positional / keyword arguments for ``FACTORY``.
    """

    CLASS, FACTORY = None, None
    TARGS, KARGS = (), {}
    # NOTE(review): BUILD is not referenced in this file; kept for compatibility.
    BUILD = None

    # Class names (``obj.klass``) for which testOptions does not require the
    # user options handler to have been called by setFromOptions().
    _NO_OPTS_HANDLER = frozenset(
        [
            'AO',
            'DMLabel',
            'PetscDualSpace',
            'IS',
            'ISLocalToGlobalMapping',
            'MatPartitioning',
            'MatNullSpace',
            'PetscRandom',
            'PetscViewer',
        ]
    )

    def setUp(self):
        self.obj = self.CLASS()
        getattr(self.obj, self.FACTORY)(*self.TARGS, **self.KARGS)
        # If the factory left the object empty (falsy), fall back to create().
        if not self.obj:
            self.obj.create()

    def tearDown(self):
        self.obj = None
        PETSc.garbage_cleanup()

    def _registeredClass(self):
        """Return the class this object is registered/logged under.

        DM subtypes (e.g. DMDA) are looked up under ``PETSc.DM`` rather than
        their own Python class; every other type uses ``self.CLASS``.
        """
        if issubclass(self.CLASS, PETSc.DM):
            return PETSc.DM
        return self.CLASS

    def testTypeRegistry(self):
        type_reg = PETSc.__type_registry__
        classid = self.obj.getClassId()
        self.assertIs(type_reg[classid], self._registeredClass())

    def testLogClass(self):
        logcls = PETSc.Log.Class(self._registeredClass().__name__)
        self.assertEqual(logcls.id, self.obj.getClassId())

    def testClass(self):
        self.assertIsInstance(self.obj, self.CLASS)
        self.assertIs(type(self.obj), self.CLASS)

    def testId(self):
        oid = self.obj.getId()
        self.assertGreater(oid, 0)
        self.assertEqual(self.obj.id, oid)

    def testNonZero(self):
        self.assertTrue(bool(self.obj))

    def testDestroy(self):
        self.assertTrue(bool(self.obj))
        self.obj.destroy()
        self.assertFalse(bool(self.obj))

    def testOptions(self):
        self.assertFalse(self.obj.getOptionsPrefix())
        prefix1 = 'my_'
        self.obj.setOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix1)
        prefix2 = 'opt_'
        self.obj.setOptionsPrefix(prefix2)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2)
        self.obj.appendOptionsPrefix(prefix1)
        self.assertEqual(self.obj.getOptionsPrefix(), prefix2 + prefix1)
        self.obj.setOptionsPrefix(None)
        self.assertIsNone(self.obj.getOptionsPrefix())
        self.obj.setFromOptions()

        def opts_handler(obj):
            # Count invocations and check the handler receives this object.
            n = obj.getAttr('opts_handler_called')
            obj.setAttr('opts_handler_called', n + 1)
            self.assertIs(type(obj), type(self.obj))
            self.assertEqual(self.obj.klass, obj.klass)
            self.assertEqual(self.obj.type, obj.type)
            self.assertEqual(self.obj.id, obj.id)

        # Run the install / remove / destroy sequence twice so that
        # re-registering a handler after destroying the handlers is exercised.
        for _ in range(2):
            self.obj.setAttr('opts_handler_called', 0)
            self.obj.setOptionsHandler(opts_handler)
            self.obj.setFromOptions()
            if self.obj.klass not in self._NO_OPTS_HANDLER:
                self.assertEqual(self.obj.getAttr('opts_handler_called'), 1)

            # Clearing the handler: setFromOptions must not call it.
            self.obj.setAttr('opts_handler_called', 0)
            self.obj.setOptionsHandler(None)
            self.obj.setFromOptions()
            self.assertFalse(self.obj.getAttr('opts_handler_called'))

            self.obj.destroyOptionsHandlers()
            self.obj.setFromOptions()
            self.assertFalse(self.obj.getAttr('opts_handler_called'))

    def testName(self):
        oldname = self.obj.getName()
        newname = f'{oldname}-{oldname}'
        self.obj.setName(newname)
        self.assertEqual(self.obj.getName(), newname)
        self.obj.setName(oldname)
        self.assertEqual(self.obj.getName(), oldname)

    def testComm(self):
        comm = self.obj.getComm()
        self.assertIsInstance(comm, PETSc.Comm)
        self.assertIn(comm, [PETSc.COMM_SELF, PETSc.COMM_WORLD])

    def testRefCount(self):
        self.assertEqual(self.obj.getRefCount(), 1)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.incRef()
        self.assertEqual(self.obj.getRefCount(), 3)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 2)
        self.obj.decRef()
        self.assertEqual(self.obj.getRefCount(), 1)
        # Dropping the last reference leaves the Python wrapper empty.
        self.obj.decRef()
        self.assertFalse(bool(self.obj))

    def testHandle(self):
        self.assertTrue(self.obj.handle)
        self.assertTrue(self.obj.fortran)
        h, f = self.obj.handle, self.obj.fortran
        # The two handles are only compared when they have the same sign.
        if (h > 0 and f > 0) or (h < 0 and f < 0):
            self.assertEqual(h, f)
        self.obj.destroy()
        self.assertFalse(self.obj.handle)
        self.assertFalse(self.obj.fortran)

    def testComposeQuery(self):
        import copy

        try:
            myobj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report as skipped rather than silently passing.
            self.skipTest('deepcopy is not implemented for this type')
        self.assertEqual(myobj.getRefCount(), 1)
        self.obj.compose('myobj', myobj)
        self.assertIs(type(self.obj.query('myobj')), self.CLASS)
        self.assertEqual(self.obj.query('myobj'), myobj)
        # Composing takes a reference; composing None releases it.
        self.assertEqual(myobj.getRefCount(), 2)
        self.obj.compose('myobj', None)
        self.assertEqual(myobj.getRefCount(), 1)
        self.assertIsNone(self.obj.query('myobj'))
        myobj.destroy()

    def testProperties(self):
        self.assertEqual(self.obj.getClassId(), self.obj.classid)
        self.assertEqual(self.obj.getClassName(), self.obj.klass)
        self.assertEqual(self.obj.getType(), self.obj.type)
        self.assertEqual(self.obj.getName(), self.obj.name)
        self.assertEqual(self.obj.getComm(), self.obj.comm)
        self.assertEqual(self.obj.getRefCount(), self.obj.refcount)

    def testShallowCopy(self):
        import copy

        rc = self.obj.getRefCount()
        obj = copy.copy(self.obj)
        # A shallow copy is a new wrapper around the same PETSc object.
        self.assertIsNot(obj, self.obj)
        self.assertEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(obj.getRefCount(), rc + 1)
        del obj
        self.assertEqual(self.obj.getRefCount(), rc)

    def testDeepCopy(self):
        import copy

        rc = self.obj.getRefCount()
        try:
            obj = copy.deepcopy(self.obj)
        except NotImplementedError:
            # Report as skipped rather than silently passing.
            self.skipTest('deepcopy is not implemented for this type')
        # A deep copy is a distinct PETSc object with its own reference.
        self.assertIsNot(obj, self.obj)
        self.assertNotEqual(obj, self.obj)
        self.assertIsInstance(obj, type(self.obj))
        self.assertEqual(self.obj.getRefCount(), rc)
        self.assertEqual(obj.getRefCount(), 1)
        del obj

    def testStateInspection(self):
        state = self.obj.stateGet()
        self.obj.stateIncrease()
        self.assertLess(state, self.obj.stateGet())
        self.obj.stateSet(0)
        self.assertEqual(self.obj.stateGet(), 0)
        self.obj.stateSet(state)
        self.assertEqual(self.obj.stateGet(), state)


# --------------------------------------------------------------------


class TestObjectRandom(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Random
    FACTORY = 'create'


class TestObjectViewer(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Viewer
    FACTORY = 'create'


class TestObjectIS(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.IS
    FACTORY = 'createGeneral'
    TARGS = ([],)


class TestObjectLGMap(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.LGMap
    FACTORY = 'create'
    TARGS = ([],)


# NOTE(review): identical to TestObjectAOMapping below; kept so that both
# module-level names remain available.
class TestObjectAO(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])


class TestObjectDMDA(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DMDA
    FACTORY = 'create'
    TARGS = ([3, 3, 3],)


class TestObjectDS(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DS
    FACTORY = 'create'


class TestObjectVec(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Vec
    FACTORY = 'createSeq'
    TARGS = (0,)

    def setUp(self):
        super().setUp()
        self.obj.assemble()


class TestObjectMat(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Mat
    FACTORY = 'createAIJ'
    TARGS = (0,)
    KARGS = {'nnz': 0, 'comm': PETSc.COMM_SELF}

    def setUp(self):
        super().setUp()
        self.obj.assemble()


class TestObjectMatPartitioning(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.MatPartitioning
    FACTORY = 'create'


class TestObjectNullSpace(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.NullSpace
    FACTORY = 'create'
    TARGS = (True, [])


class TestObjectKSP(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.KSP
    FACTORY = 'create'


class TestObjectPC(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.PC
    FACTORY = 'create'


class TestObjectSNES(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.SNES
    FACTORY = 'create'


class TestObjectTS(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.TS
    FACTORY = 'create'

    def setUp(self):
        super().setUp()
        self.obj.setProblemType(PETSc.TS.ProblemType.NONLINEAR)
        self.obj.setType(PETSc.TS.Type.BEULER)


class TestObjectTAO(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.TAO
    FACTORY = 'create'


class TestObjectAOBasic(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.AO
    FACTORY = 'createBasic'
    TARGS = ([], [])


class TestObjectAOMapping(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.AO
    FACTORY = 'createMapping'
    TARGS = ([], [])


# Disabled tests, kept for reference:
# class TestObjectFE(BaseTestObject, unittest.TestCase):
#     CLASS = PETSc.FE
#     FACTORY = 'create'
#
# class TestObjectQuad(BaseTestObject, unittest.TestCase):
#     CLASS = PETSc.Quad
#     FACTORY = 'create'


class TestObjectDMLabel(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DMLabel
    FACTORY = 'create'
    TARGS = ('test',)


class TestObjectSpace(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.Space
    FACTORY = 'create'


class TestObjectDualSpace(BaseTestObject, unittest.TestCase):
    CLASS = PETSc.DualSpace
    FACTORY = 'create'


# --------------------------------------------------------------------

# TAO tests are only run when PETSc uses real scalars
# (presumably TAO is unavailable in complex-scalar builds).
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestObjectTAO

if __name__ == '__main__':
    unittest.main()