from petsc4py import PETSc
import unittest

# --------------------------------------------------------------------


def _local_indices(comm, lsize=10):
    """Return the global indices this rank uses to build the test LGMaps.

    Rank ``r`` starts from the contiguous range ``[lsize*r, lsize*(r+1))``.
    The range is widened by one index below it (unless ``r`` is the first
    rank) and by one index above it (unless ``r`` is the last rank), so
    neighbouring ranks share the indices at their range boundaries.
    """
    rank = comm.getRank()
    nranks = comm.getSize()
    first = lsize * rank
    last = first + lsize
    if rank > 0:
        first -= 1  # overlap with the previous rank's last index
    if rank < nranks - 1:
        last += 1  # overlap with the next rank's first index
    return list(range(first, last))


def _check_info(testcase, lgmap, info):
    """Assert that an LGMap info dict has a plausible number of entries.

    Exactly one entry is expected on a single process.  In parallel, with
    the one-index overlap produced by ``_local_indices``, between 2 and 3
    entries are expected (presumably this rank plus at most two neighbours
    -- NOTE(review): confirm against ISLocalToGlobalMappingGetInfo).
    """
    testcase.assertIsInstance(info, dict)
    if lgmap.getComm().getSize() == 1:
        testcase.assertEqual(len(info), 1)
    else:
        testcase.assertGreater(len(info), 1)
        testcase.assertLess(len(info), 4)


class BaseTestLGMap:
    """Shared LGMap tests.

    Concrete subclasses must set ``self.idx`` (the global indices the map
    was built from) and ``self.lgmap`` in ``setUp``.
    """

    def _mk_idx(self, comm):
        """Return this rank's test indices (see ``_local_indices``)."""
        return _local_indices(comm)

    def tearDown(self):
        self.lgmap = None
        PETSc.garbage_cleanup()

    def testGetSize(self):
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)

    def testGetIndices(self):
        size = self.lgmap.getSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        for i, val in enumerate(self.idx):
            self.assertEqual(idx[i], val)

    def testGetInfo(self):
        _check_info(self, self.lgmap, self.lgmap.getInfo())

    def testApply(self):
        idxin = list(range(self.lgmap.getSize()))
        idxout = self.lgmap.apply(idxin)
        # Local index i must map to the i-th global index the map was built from.
        self.assertEqual(list(idxout), self.idx)
        # Same mapping, this time passing a preallocated output array.
        self.lgmap.apply(idxin, idxout)
        self.assertEqual(list(idxout), self.idx)
        # Every global index is in the map, so the inverse recovers idxin.
        idxinv = self.lgmap.applyInverse(idxout)
        self.assertEqual(list(idxinv), idxin)

    def testApplyIS(self):
        is_in = PETSc.IS().createStride(self.lgmap.getSize())
        _ = self.lgmap.apply(is_in)

    def testProperties(self):
        for prop in ('size', 'indices', 'info'):
            self.assertTrue(hasattr(self.lgmap, prop))


# --------------------------------------------------------------------


class TestLGMap(BaseTestLGMap, unittest.TestCase):
    def setUp(self):
        self.idx = self._mk_idx(PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.idx, comm=PETSc.COMM_WORLD)


class TestLGMapIS(BaseTestLGMap, unittest.TestCase):
    def setUp(self):
        self.idx = self._mk_idx(PETSc.COMM_WORLD)
        self.iset = PETSc.IS().createGeneral(self.idx, comm=PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.iset)

    def tearDown(self):
        # Drop the IS as well, then defer to the base tearDown so that
        # PETSc.garbage_cleanup() also runs for this test case.
        self.iset = None
        super().tearDown()

    def testSameComm(self):
        comm1 = self.lgmap.getComm()
        comm2 = self.iset.getComm()
        self.assertEqual(comm1, comm2)


# --------------------------------------------------------------------


class TestLGMapBlock(unittest.TestCase):
    """LGMap tests for a map built from block indices with block size ``BS``."""

    BS = 3

    def setUp(self):
        self.idx = _local_indices(PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.idx, self.BS, comm=PETSc.COMM_WORLD)

    def tearDown(self):
        self.lgmap = None
        PETSc.garbage_cleanup()

    def testGetSize(self):
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)

    def testGetBlockSize(self):
        bs = self.lgmap.getBlockSize()
        self.assertEqual(bs, self.BS)

    def testGetBlockIndices(self):
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getBlockIndices()
        self.assertEqual(len(idx), size // bs)
        for i, val in enumerate(self.idx):
            self.assertEqual(idx[i], val)

    def testGetIndices(self):
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        # Block index `val` expands to the point indices val*bs .. val*bs + bs-1.
        for i, val in enumerate(self.idx):
            for j in range(bs):
                self.assertEqual(idx[i * bs + j], val * bs + j)

    def testGetBlockInfo(self):
        _check_info(self, self.lgmap, self.lgmap.getBlockInfo())

    def testGetInfo(self):
        _check_info(self, self.lgmap, self.lgmap.getInfo())


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()