"""Unit tests for ``PETSc.LGMap`` (local-to-global index mappings)."""

from petsc4py import PETSc
import unittest

# --------------------------------------------------------------------


def _overlapping_indices(comm, lsize=10):
    """Return this rank's global indices for an overlapping 1-D layout.

    Rank ``r`` covers ``lsize`` consecutive indices starting at
    ``lsize * r``.  Every rank except the first also takes the index just
    before its range, and every rank except the last also takes the index
    just after it, so neighbouring ranks share indices at their boundary.

    :param comm: communicator providing ``getRank()`` / ``getSize()``.
    :param lsize: number of indices each rank covers before the overlap.
    :returns: list of global indices for the calling rank.
    """
    rank = comm.getRank()
    first = lsize * rank
    last = first + lsize
    if rank > 0:
        first -= 1  # overlap with the previous rank
    if rank < comm.getSize() - 1:
        last += 1  # overlap with the next rank
    return list(range(first, last))


def _check_info(test, lgmap, info):
    """Assert the shape of a ``getInfo()`` / ``getBlockInfo()`` result.

    On a single process the result must be an empty dict.  In parallel the
    dict must have 2 or 3 entries (presumably the calling rank plus its one
    or two neighbours in the overlapping layout -- see PETSc's
    ``ISLocalToGlobalMappingGetInfo``).

    :param test: the running ``unittest.TestCase`` (used for assertions).
    :param lgmap: the mapping that produced ``info``.
    :param info: the dict returned by the mapping's info query.
    """
    test.assertIsInstance(info, dict)
    if lgmap.getComm().getSize() == 1:
        test.assertEqual(info, {})
    else:
        test.assertGreater(len(info), 1)
        test.assertLess(len(info), 4)


# --------------------------------------------------------------------


class BaseTestLGMap(object):
    """Shared checks for a scalar (block size 1) ``LGMap``.

    Subclasses must set ``self.idx`` (the global indices the mapping was
    built from) and ``self.lgmap`` in ``setUp``.
    """

    def _mk_idx(self, comm):
        """Return this rank's global indices for the test layout."""
        return _overlapping_indices(comm)

    def tearDown(self):
        self.lgmap = None
        # Collectively destroy PETSc objects whose Python refs were dropped.
        PETSc.garbage_cleanup()

    def testGetSize(self):
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)

    def testGetIndices(self):
        size = self.lgmap.getSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        for i, val in enumerate(self.idx):
            self.assertEqual(idx[i], val)

    def testGetInfo(self):
        _check_info(self, self.lgmap, self.lgmap.getInfo())

    def testApply(self):
        # Local indices 0..n-1 must map onto the global indices used to
        # build the mapping, and the inverse must map them back.
        idxin = list(range(self.lgmap.getSize()))
        idxout = self.lgmap.apply(idxin)
        self.assertEqual(list(idxout), self.idx)
        # apply() also accepts a caller-supplied output array.
        self.lgmap.apply(idxin, idxout)
        self.assertEqual(list(idxout), self.idx)
        invmap = self.lgmap.applyInverse(idxout)
        self.assertEqual(list(invmap), idxin)

    def testApplyIS(self):
        # Smoke test: applying the mapping to a stride IS must not fail.
        is_in = PETSc.IS().createStride(self.lgmap.getSize())
        is_out = self.lgmap.apply(is_in)

    def testProperties(self):
        for prop in ('size', 'indices', 'info'):
            self.assertTrue(hasattr(self.lgmap, prop))


# --------------------------------------------------------------------


class TestLGMap(BaseTestLGMap, unittest.TestCase):
    """``LGMap`` created directly from a list of global indices."""

    def setUp(self):
        self.idx = self._mk_idx(PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.idx, comm=PETSc.COMM_WORLD)


class TestLGMapIS(BaseTestLGMap, unittest.TestCase):
    """``LGMap`` created from a general index set (``IS``)."""

    def setUp(self):
        self.idx = self._mk_idx(PETSc.COMM_WORLD)
        self.iset = PETSc.IS().createGeneral(self.idx, comm=PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.iset)

    def tearDown(self):
        self.iset = None
        # Delegate to the base so the lgmap is dropped and garbage_cleanup
        # runs here too.
        super(TestLGMapIS, self).tearDown()

    def testSameComm(self):
        comm1 = self.lgmap.getComm()
        comm2 = self.iset.getComm()
        self.assertEqual(comm1, comm2)


# --------------------------------------------------------------------


class TestLGMapBlock(unittest.TestCase):
    """``LGMap`` with block size ``BS`` built from block indices."""

    BS = 3

    def setUp(self):
        self.idx = _overlapping_indices(PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(
            self.idx, self.BS, comm=PETSc.COMM_WORLD
        )

    def tearDown(self):
        self.lgmap = None
        # Match BaseTestLGMap: collectively destroy released PETSc objects.
        PETSc.garbage_cleanup()

    def testGetSize(self):
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)

    def testGetBlockSize(self):
        bs = self.lgmap.getBlockSize()
        self.assertEqual(bs, self.BS)

    def testGetBlockIndices(self):
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getBlockIndices()
        self.assertEqual(len(idx), size // bs)
        for i, val in enumerate(self.idx):
            self.assertEqual(idx[i], val)

    def testGetIndices(self):
        # Each block index b expands to the scalar indices b*bs .. b*bs+bs-1.
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        for i, val in enumerate(self.idx):
            for j in range(bs):
                self.assertEqual(idx[i * bs + j], val * bs + j)

    def testGetBlockInfo(self):
        _check_info(self, self.lgmap, self.lgmap.getBlockInfo())

    def testGetInfo(self):
        _check_info(self, self.lgmap, self.lgmap.getInfo())


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()