"""Unit tests for petsc4py index sets (``PETSc.IS``).

The checks shared by every index-set type live in ``BaseTestIS``; each
concrete ``TestIS*`` class builds a general, stride, or block index set in
``setUp`` and adds type-specific checks.
"""

from petsc4py import PETSc
import unittest
import random

# --------------------------------------------------------------------


class BaseTestIS:
    """Checks shared by every index-set type.

    Subclasses mix this in with ``unittest.TestCase``, set ``TYPE`` to the
    expected ``PETSc.IS.Type`` and create ``self.iset`` in ``setUp``.
    """

    # Expected value of ``self.iset.getType()``; overridden by subclasses.
    TYPE = None

    def tearDown(self):
        # Drop our reference to the index set, then ask PETSc to clean up
        # garbage so objects do not outlive the test.
        self.iset = None
        PETSc.garbage_cleanup()

    def testGetType(self):
        istype = self.iset.getType()
        self.assertEqual(istype, self.TYPE)

    def testGetSize(self):
        lsize = self.iset.getLocalSize()
        gsize = self.iset.getSize()
        self.assertTrue(lsize <= gsize)

    def testDuplicate(self):
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testCopy(self):
        iset = self.iset.copy()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testEqual(self):
        """An index set equals itself and its duplicate."""
        self.assertTrue(self.iset.equal(self.iset))
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testSort(self):
        self.iset.sort()
        self.assertTrue(self.iset.isSorted())

    def testDifference(self):
        """The difference of a set with itself is empty."""
        iset = self.iset.difference(self.iset)
        self.assertEqual(iset.getLocalSize(), 0)
        del iset

    def testComplement(self):
        """Complementing twice over the same range recovers the index set."""
        # Sort first; the PETSc ISComplement documentation requires sorted
        # input.
        self.iset.sort()
        indices = self.iset.getIndices()
        nmin = indices.min()
        nmax = indices.max()
        # [nmin, nmax + 1) covers every index held by the set, so the
        # complement of the complement must be the original set.
        iset = self.iset.complement(nmin, nmax + 1)
        restored = iset.complement(nmin, nmax + 1)
        self.assertTrue(self.iset.equal(restored))
        del iset, restored

    def testSum(self):
        """Summing a set into a copy of itself leaves the copy unchanged."""
        # Only checked on a single process; parallel runs return early.
        if self.iset.getComm().getSize() > 1:
            return
        self.iset.sort()
        iset = self.iset.duplicate()
        iset.sum(self.iset)
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testExpand(self):
        """Expanding a set by itself gives an equal set when the type is kept."""
        iset = self.iset.expand(self.iset)
        if self.iset.type == iset.type:
            self.assertTrue(self.iset.equal(iset))
        del iset

    def testRenumber(self):
        """Smoke test: renumber() runs with and without an IS argument."""
        # Only that the calls succeed is checked; the counts are not asserted.
        (n1, is1) = self.iset.renumber()
        (n2, is2) = self.iset.renumber(self.iset)
        del is1
        del is2

    def testProperties(self):
        """The Pythonic property accessors exist on the index set."""
        proplist = [
            'sizes',
            'size',
            'local_size',
            'indices',
            'permutation',
            'identity',
            'sorted',
        ]
        for prop in proplist:
            self.assertTrue(hasattr(self.iset, prop))

    def testArray(self):
        """Each NumPy view holds one IS reference until it is deleted."""
        import numpy

        refs = self.iset.getRefCount()
        arr1 = numpy.asarray(self.iset)
        self.assertEqual(self.iset.getRefCount(), refs + 1)
        arr2 = self.iset.array
        self.assertEqual(self.iset.getRefCount(), refs + 2)
        self.assertTrue((arr1 == arr2).all())
        del arr2
        self.assertEqual(self.iset.getRefCount(), refs + 1)
        del arr1
        self.assertEqual(self.iset.getRefCount(), refs)


# --------------------------------------------------------------------


class TestISGeneral(BaseTestIS, unittest.TestCase):
    """General index set built from a shuffled permutation of 0..9."""

    TYPE = PETSc.IS.Type.GENERAL

    def setUp(self):
        self.idx = list(range(10))
        random.shuffle(self.idx)
        self.iset = PETSc.IS().createGeneral(self.idx)

    def testGetIndices(self):
        # A general IS keeps the indices in the order they were given.
        idx = self.iset.getIndices()
        self.assertEqual(self.idx, list(idx))


class TestISStride(BaseTestIS, unittest.TestCase):
    """Stride index set with size 10, first index 7 and step 3."""

    TYPE = PETSc.IS.Type.STRIDE

    def setUp(self):
        self.info = (10, 7, 3)
        size, start, step = self.info
        self.iset = PETSc.IS().createStride(size, start, step)

    def testGetIndices(self):
        size, start, step = self.info
        indices = [start + i * step for i in range(size)]
        self.assertEqual(list(self.iset.getIndices()), indices)

    def testToGeneral(self):
        self.iset.toGeneral()
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.GENERAL)


class TestISBlock(BaseTestIS, unittest.TestCase):
    """Block index set with block size 3 over shuffled block indices."""

    TYPE = PETSc.IS.Type.BLOCK

    def setUp(self):
        self.bsize = 3
        self.index = list(range(0, 10, 2))
        random.shuffle(self.index)
        self.iset = PETSc.IS().createBlock(self.bsize, self.index)
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.BLOCK)

    def testGetSize(self):
        # Exact integer check: each block index contributes bsize entries.
        lsize = self.iset.getLocalSize()
        self.assertEqual(lsize, self.bsize * len(self.index))

    def testGetBlockSize(self):
        bs = self.iset.getBlockSize()
        self.assertEqual(bs, self.bsize)

    def testGetBlockIndices(self):
        index = list(self.iset.getBlockIndices())
        self.assertEqual(index, self.index)

    def testGetIndices(self):
        """Expanded indices are i*bs .. i*bs + bs - 1 for each block i."""
        bs = self.bsize
        idx = [i * bs + j for i in self.iset.getBlockIndices() for j in range(bs)]
        index = list(self.iset.getIndices())
        self.assertEqual(index, idx)


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()