"""Unit tests for petsc4py index sets (``PETSc.IS``).

``BaseTestIS`` holds checks shared by every IS type.  Each concrete
``TestIS*`` class mixes it with ``unittest.TestCase``, sets ``TYPE`` to
the expected ``PETSc.IS.Type`` and builds ``self.iset`` in ``setUp``.
"""

from petsc4py import PETSc
import unittest
import random

# --------------------------------------------------------------------


class BaseTestIS(object):
    """Type-independent IS tests.

    Subclasses must define ``TYPE`` and create ``self.iset`` in ``setUp``.
    """

    TYPE = None

    def tearDown(self):
        # Drop this test's reference to the IS.
        self.iset = None

    def testGetType(self):
        istype = self.iset.getType()
        self.assertEqual(istype, self.TYPE)

    def testGetSize(self):
        # The local portion can never exceed the global size.
        lsize = self.iset.getLocalSize()
        gsize = self.iset.getSize()
        self.assertLessEqual(lsize, gsize)

    def testDuplicate(self):
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))

    def testCopy(self):
        iset = self.iset.copy()
        self.assertTrue(self.iset.equal(iset))

    def testEqual(self):
        # An IS equals itself and any duplicate of itself.
        self.assertTrue(self.iset.equal(self.iset))
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))

    def testSort(self):
        self.iset.sort()
        self.assertTrue(self.iset.isSorted())

    def testDifference(self):
        # Removing a set from itself leaves nothing locally.
        iset = self.iset.difference(self.iset)
        self.assertEqual(iset.getLocalSize(), 0)

    def testComplement(self):
        # Sort first: ISComplement expects a sorted index set.
        self.iset.sort()
        indices = self.iset.getIndices()
        nmin, nmax = indices.min(), indices.max()
        # Complementing twice over the same range [nmin, nmax+1) must
        # give back the original indices.
        iset = self.iset.complement(nmin, nmax + 1)
        roundtrip = iset.complement(nmin, nmax + 1)
        self.assertTrue(self.iset.equal(roundtrip))

    def testSum(self):
        if self.iset.getComm().getSize() > 1:
            self.skipTest('sequential-only test')
        self.iset.sort()
        iset = self.iset.duplicate()
        # Summing a set with a copy of itself must not change it.
        iset.sum(self.iset)
        self.assertTrue(self.iset.equal(iset))

    def testExpand(self):
        iset = self.iset.expand(self.iset)
        # Only compare when expand() kept the original IS type.
        if self.iset.type == iset.type:
            self.assertTrue(self.iset.equal(iset))

    def testRenumber(self):
        # Smoke test: renumber() with no argument and with self.iset
        # passed in must both succeed.
        _, is1 = self.iset.renumber()
        _, is2 = self.iset.renumber(self.iset)
        del is1, is2

    def testProperties(self):
        proplist = ['sizes', 'size', 'local_size', 'indices',
                    'permutation', 'identity', 'sorted']
        for prop in proplist:
            self.assertTrue(hasattr(self.iset, prop), prop)

    def testArray(self):
        import numpy
        # Each live array view holds one extra reference to the IS;
        # releasing the views must restore the original refcount.
        refs = self.iset.getRefCount()
        arr1 = numpy.asarray(self.iset)
        self.assertEqual(self.iset.getRefCount(), refs + 1)
        arr2 = self.iset.array
        self.assertEqual(self.iset.getRefCount(), refs + 2)
        self.assertTrue((arr1 == arr2).all())
        del arr2
        self.assertEqual(self.iset.getRefCount(), refs + 1)
        del arr1
        self.assertEqual(self.iset.getRefCount(), refs)


# --------------------------------------------------------------------


class TestISGeneral(BaseTestIS, unittest.TestCase):
    """General IS built from a shuffled permutation of 0..9."""

    TYPE = PETSc.IS.Type.GENERAL

    def setUp(self):
        self.idx = list(range(10))
        random.shuffle(self.idx)
        self.iset = PETSc.IS().createGeneral(self.idx)

    def testGetIndices(self):
        # A general IS returns its indices in creation order.
        idx = self.iset.getIndices()
        self.assertEqual(self.idx, list(idx))


class TestISStride(BaseTestIS, unittest.TestCase):
    """Stride IS with (size, start, step) = (10, 7, 3)."""

    TYPE = PETSc.IS.Type.STRIDE

    def setUp(self):
        self.info = (10, 7, 3)
        size, start, step = self.info
        self.iset = PETSc.IS().createStride(size, start, step)

    def testGetIndices(self):
        size, start, step = self.info
        indices = [start + i * step for i in range(size)]
        self.assertEqual(list(self.iset.getIndices()), indices)

    def testToGeneral(self):
        self.iset.toGeneral()
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.GENERAL)


class TestISBlock(BaseTestIS, unittest.TestCase):
    """Block IS with block size 3 over shuffled block indices 0,2,4,6,8."""

    TYPE = PETSc.IS.Type.BLOCK

    def setUp(self):
        self.bsize = 3
        self.index = list(range(0, 10, 2))
        random.shuffle(self.index)
        self.iset = PETSc.IS().createBlock(self.bsize, self.index)
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.BLOCK)

    def testGetSize(self):
        # Each block index contributes bsize entries; compare exactly in
        # integers rather than via float division.
        lsize = self.iset.getLocalSize()
        self.assertEqual(lsize, self.bsize * len(self.index))

    def testGetBlockSize(self):
        bs = self.iset.getBlockSize()
        self.assertEqual(bs, self.bsize)

    def testGetBlockIndices(self):
        index = list(self.iset.getBlockIndices())
        self.assertEqual(index, self.index)

    def testGetIndices(self):
        # Expected point indices: each block index i expands to
        # i*bs, i*bs+1, ..., i*bs+bs-1, in block order.
        bs = self.bsize
        idx = [i * bs + j
               for i in self.iset.getBlockIndices()
               for j in range(bs)]
        index = list(self.iset.getIndices())
        self.assertEqual(index, idx)


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()