from petsc4py import PETSc
import unittest

# --------------------------------------------------------------------


def _make_local_indices(comm, lsize=10):
    """Return this rank's list of global indices for the test layout.

    Each rank starts from the contiguous range ``[lsize*rank, lsize*rank + lsize)``.
    Every rank except the first also takes the index just before that range,
    and every rank except the last also takes the index just after it, so the
    index lists of neighbouring ranks overlap.

    :param comm: a PETSc communicator (``getSize``/``getRank`` are used).
    :param lsize: length of each rank's base contiguous range.
    :returns: a ``list`` of ``int`` global indices, in increasing order.
    """
    comm_size = comm.getSize()
    comm_rank = comm.getRank()
    first = lsize * comm_rank
    last = first + lsize
    if comm_rank > 0:
        first -= 1  # extend down into the previous rank's range
    if comm_rank < (comm_size - 1):
        last += 1  # extend up into the next rank's range
    return list(range(first, last))


def _check_info(testcase, info, comm):
    """Assert the expected shape of a ``getInfo``/``getBlockInfo`` result.

    On a single process the info must be an empty ``dict``. Otherwise it must
    be a ``dict`` with more than one and fewer than four entries (presumably
    the local rank plus at most two neighbours in the layout produced by
    ``_make_local_indices`` -- confirm against PETSc's
    ``ISLocalToGlobalMappingGetInfo``).
    """
    testcase.assertIsInstance(info, dict)
    if comm.getSize() == 1:
        testcase.assertEqual(info, {})
    else:
        testcase.assertGreater(len(info), 1)
        testcase.assertLess(len(info), 4)


class BaseTestLGMap(object):
    """Tests shared by every block-size-1 ``LGMap`` fixture.

    Subclasses must set ``self.idx`` (the global indices the map was built
    from) and ``self.lgmap`` (the ``PETSc.LGMap`` under test) in ``setUp``.
    """

    def _mk_idx(self, comm):
        """Return this rank's test indices; see ``_make_local_indices``."""
        return _make_local_indices(comm)

    def tearDown(self):
        self.lgmap = None

    def testGetSize(self):
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)
        # With block size 1 there is exactly one map entry per index.
        self.assertEqual(size, len(self.idx))

    def testGetIndices(self):
        size = self.lgmap.getSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        self.assertEqual(list(idx), self.idx)

    def testGetInfo(self):
        _check_info(self, self.lgmap.getInfo(), self.lgmap.getComm())

    def testApply(self):
        idxin = list(range(self.lgmap.getSize()))
        # Local -> global: local index i maps to the i-th construction index.
        idxout = self.lgmap.apply(idxin)
        self.assertEqual(list(idxout), self.idx)
        # Same mapping, this time writing into a caller-supplied result array.
        self.lgmap.apply(idxin, idxout)
        self.assertEqual(list(idxout), self.idx)
        # Global -> local: every index we just produced maps back to its slot.
        invmap = self.lgmap.applyInverse(idxout)
        self.assertEqual(list(invmap), idxin)

    def testApplyIS(self):
        # Smoke test: mapping an index set through the LGMap must not raise.
        is_in = PETSc.IS().createStride(self.lgmap.getSize())
        self.lgmap.apply(is_in)

    def testProperties(self):
        for prop in ('size', 'indices', 'info'):
            self.assertTrue(hasattr(self.lgmap, prop), prop)

# --------------------------------------------------------------------


class TestLGMap(BaseTestLGMap, unittest.TestCase):
    """``LGMap`` created directly from a Python list of indices."""

    def setUp(self):
        self.idx = self._mk_idx(PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.idx, comm=PETSc.COMM_WORLD)


class TestLGMapIS(BaseTestLGMap, unittest.TestCase):
    """``LGMap`` created from a general ``PETSc.IS``."""

    def setUp(self):
        self.idx = self._mk_idx(PETSc.COMM_WORLD)
        self.iset = PETSc.IS().createGeneral(self.idx, comm=PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.iset)

    def tearDown(self):
        self.iset = None
        super(TestLGMapIS, self).tearDown()

    def testSameComm(self):
        # An LGMap built from an IS should live on the IS's communicator.
        comm1 = self.lgmap.getComm()
        comm2 = self.iset.getComm()
        self.assertEqual(comm1, comm2)

# --------------------------------------------------------------------


class TestLGMapBlock(unittest.TestCase):
    """``LGMap`` created from block indices with block size ``BS``."""

    BS = 3

    def setUp(self):
        self.idx = _make_local_indices(PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.idx, self.BS,
                                          comm=PETSc.COMM_WORLD)

    def tearDown(self):
        self.lgmap = None

    def testGetSize(self):
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)
        # getSize() counts unblocked entries: BS of them per block index.
        self.assertEqual(size, len(self.idx) * self.BS)

    def testGetBlockSize(self):
        bs = self.lgmap.getBlockSize()
        self.assertEqual(bs, self.BS)

    def testGetBlockIndices(self):
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getBlockIndices()
        self.assertEqual(len(idx), size // bs)
        self.assertEqual(list(idx), self.idx)

    def testGetIndices(self):
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        # Block index b expands to the bs consecutive indices b*bs .. b*bs+bs-1.
        expected = [val * bs + j for val in self.idx for j in range(bs)]
        self.assertEqual(list(idx), expected)

    def testGetBlockInfo(self):
        _check_info(self, self.lgmap.getBlockInfo(), self.lgmap.getComm())

    def testGetInfo(self):
        _check_info(self, self.lgmap.getInfo(), self.lgmap.getComm())
# --------------------------------------------------------------------


def _run_as_script():
    """Run every TestCase in this module with the stdlib unittest runner."""
    unittest.main()


if __name__ == '__main__':
    _run_as_script()