from petsc4py import PETSc
import unittest
import random

# --------------------------------------------------------------------


class BaseTestIS(object):
    """Tests shared by every PETSc index-set (IS) implementation.

    This is a mixin, not a TestCase: concrete subclasses also inherit
    from ``unittest.TestCase``, set ``TYPE`` to the expected
    ``PETSc.IS.Type`` and create ``self.iset`` in ``setUp``.
    """

    TYPE = None

    def tearDown(self):
        # Drop our reference before asking PETSc to collect garbage.
        self.iset = None
        PETSc.garbage_cleanup()

    def testGetType(self):
        istype = self.iset.getType()
        self.assertEqual(istype, self.TYPE)

    def testGetSize(self):
        lsize = self.iset.getLocalSize()
        gsize = self.iset.getSize()
        self.assertTrue(lsize <= gsize)

    def testDuplicate(self):
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testCopy(self):
        iset = self.iset.copy()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testEqual(self):
        self.assertTrue(self.iset.equal(self.iset))
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testSort(self):
        self.iset.sort()
        self.assertTrue(self.iset.isSorted())

    def testDifference(self):
        # Removing a set from itself must leave nothing locally.
        iset = self.iset.difference(self.iset)
        self.assertEqual(iset.getLocalSize(), 0)
        del iset

    def testComplement(self):
        """Exercise complement() over the [min, max] range of the indices."""
        self.iset.sort()
        indices = self.iset.getIndices()
        nmin = indices.min()
        nmax = indices.max()
        iset = self.iset.complement(nmin, nmax + 1)
        # Complement of the complement: only checks that the call succeeds;
        # the result is not inspected.
        iset2 = iset.complement(nmin, nmax + 1)
        del iset2
        del iset

    def testSum(self):
        """Summing a sorted IS with a copy of itself leaves it unchanged."""
        if self.iset.getComm().getSize() > 1:
            # Report as skipped rather than silently passing.
            self.skipTest("IS.sum is only tested on a single process")
        self.iset.sort()
        iset = self.iset.duplicate()
        iset.sum(self.iset)
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testExpand(self):
        iset = self.iset.expand(self.iset)
        # Equality is only asserted when expand() kept the original type.
        if self.iset.type == iset.type:
            self.assertTrue(self.iset.equal(iset))
        del iset

    def testRenumber(self):
        # Only the returned index sets are used; the counts are ignored.
        _, is1 = self.iset.renumber()
        _, is2 = self.iset.renumber(self.iset)
        del is1
        del is2

    def testProperties(self):
        proplist = ['sizes', 'size', 'local_size', 'indices',
                    'permutation', 'identity', 'sorted']
        for prop in proplist:
            self.assertTrue(hasattr(self.iset, prop))

    def testArray(self):
        """Each numpy view of the IS holds one PETSc reference while alive."""
        import numpy
        refs = self.iset.getRefCount()
        arr1 = numpy.asarray(self.iset)
        self.assertEqual(self.iset.getRefCount(), refs + 1)
        arr2 = self.iset.array
        self.assertEqual(self.iset.getRefCount(), refs + 2)
        self.assertTrue((arr1 == arr2).all())
        del arr2
        self.assertEqual(self.iset.getRefCount(), refs + 1)
        del arr1
        self.assertEqual(self.iset.getRefCount(), refs)


# --------------------------------------------------------------------


class TestISGeneral(BaseTestIS, unittest.TestCase):
    """General IS built from a random permutation of range(10)."""

    TYPE = PETSc.IS.Type.GENERAL

    def setUp(self):
        self.idx = list(range(10))
        random.shuffle(self.idx)
        self.iset = PETSc.IS().createGeneral(self.idx)

    def testGetIndices(self):
        idx = self.iset.getIndices()
        self.assertEqual(self.idx, list(idx))


class TestISStride(BaseTestIS, unittest.TestCase):
    """Stride IS with (size, start, step) = (10, 7, 3)."""

    TYPE = PETSc.IS.Type.STRIDE

    def setUp(self):
        self.info = (10, 7, 3)
        size, start, step = self.info
        self.iset = PETSc.IS().createStride(size, start, step)

    def testGetIndices(self):
        size, start, step = self.info
        indices = [start + i * step for i in range(size)]
        self.assertEqual(list(self.iset.getIndices()), indices)

    def testToGeneral(self):
        self.iset.toGeneral()
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.GENERAL)


class TestISBlock(BaseTestIS, unittest.TestCase):
    """Block IS with block size 3 over shuffled block indices 0, 2, ..., 8."""

    TYPE = PETSc.IS.Type.BLOCK

    def setUp(self):
        self.bsize = 3
        self.index = list(range(0, 10, 2))
        random.shuffle(self.index)
        self.iset = PETSc.IS().createBlock(self.bsize, self.index)
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.BLOCK)

    def testGetSize(self):
        # Each block index contributes exactly bsize local entries; compare
        # integers exactly instead of a float quotient.
        lsize = self.iset.getLocalSize()
        self.assertEqual(lsize, self.bsize * len(self.index))

    def testGetBlockSize(self):
        bs = self.iset.getBlockSize()
        self.assertEqual(bs, self.bsize)

    def testGetBlockIndices(self):
        index = list(self.iset.getBlockIndices())
        self.assertEqual(index, self.index)

    def testGetIndices(self):
        # Expected expansion: block index i covers i*bs .. i*bs+bs-1.
        bs = self.bsize
        idx = [i * bs + j
               for i in self.iset.getBlockIndices()
               for j in range(bs)]
        index = list(self.iset.getIndices())
        # NOTE(review): this comparison is disabled in the original test;
        # getIndices() is still exercised above. Confirm whether it can be
        # re-enabled for the supported PETSc versions.
        # self.assertEqual(index, idx)


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()