xref: /petsc/src/binding/petsc4py/test/test_lgmap.py (revision 552edb6364df478b294b3111f33a8f37ca096b20)
1from petsc4py import PETSc
2import unittest
3
4# --------------------------------------------------------------------
5
6
class BaseTestLGMap:
    """Shared checks for a non-blocked ``PETSc.LGMap``.

    Mixin: concrete subclasses must also inherit ``unittest.TestCase`` and,
    in ``setUp``, assign ``self.idx`` (the list of global indices) and
    ``self.lgmap`` (the mapping built from that list).
    """

    def _mk_idx(self, comm):
        """Return this rank's list of global indices for the test mapping.

        The list is the contiguous run of ``lsize`` (10) indices starting at
        ``10 * rank``, extended by one index below it on every rank but the
        first and by one index above it on every rank but the last, so
        adjacent ranks share entries.

        :param comm: PETSc communicator whose size and rank drive the layout.
        :returns: list of ints.
        """
        comm_size = comm.getSize()
        comm_rank = comm.getRank()
        lsize = 10
        first = lsize * comm_rank
        last = first + lsize
        if comm_rank > 0:
            first -= 1  # reach one index into the previous rank's run
        if comm_rank < comm_size - 1:
            last += 1  # reach one index into the next rank's run
        return list(range(first, last))

    def tearDown(self):
        # Drop the Python reference first so PETSc.garbage_cleanup() can
        # reclaim the underlying PETSc object.
        self.lgmap = None
        PETSc.garbage_cleanup()

    def testGetSize(self):
        # A non-blocked map has exactly one entry per index it was built from.
        size = self.lgmap.getSize()
        self.assertEqual(size, len(self.idx))

    def testGetIndices(self):
        size = self.lgmap.getSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        # Compare the whole list, not just a prefix of it.
        self.assertEqual(list(idx), self.idx)

    def testGetInfo(self):
        info = self.lgmap.getInfo()
        self.assertIsInstance(info, dict)
        # NOTE(review): presumably one entry per rank sharing indices with
        # this one (this rank included) -- 1 in serial, 2 or 3 otherwise.
        if self.lgmap.getComm().getSize() == 1:
            self.assertEqual(len(info), 1)
        else:
            self.assertGreater(len(info), 1)
            self.assertLess(len(info), 4)

    def testApply(self):
        idxin = list(range(self.lgmap.getSize()))
        # Mapping local positions 0..size-1 reproduces the index list itself.
        idxout = self.lgmap.apply(idxin)
        self.assertEqual(list(idxout), self.idx)
        # Second form: write the result into a caller-supplied array.
        self.lgmap.apply(idxin, idxout)
        self.assertEqual(list(idxout), self.idx)
        # Each global index occurs once in this rank's list, so the inverse
        # mapping recovers the local positions.
        back = self.lgmap.applyInverse(idxout)
        self.assertEqual(list(back), idxin)

    def testApplyIS(self):
        # Smoke test: mapping an index set as input must not raise.
        is_in = PETSc.IS().createStride(self.lgmap.getSize())
        _ = self.lgmap.apply(is_in)

    def testProperties(self):
        for prop in ('size', 'indices', 'info'):
            with self.subTest(prop=prop):
                self.assertTrue(hasattr(self.lgmap, prop))
57
58
59# --------------------------------------------------------------------
60
61
class TestLGMap(BaseTestLGMap, unittest.TestCase):
    """Run the shared LGMap checks on a map built from a plain index list."""

    def setUp(self):
        world = PETSc.COMM_WORLD
        self.idx = self._mk_idx(world)
        self.lgmap = PETSc.LGMap().create(self.idx, comm=world)
66
67
class TestLGMapIS(BaseTestLGMap, unittest.TestCase):
    """Run the shared LGMap checks on a map built from a ``PETSc.IS``."""

    def setUp(self):
        self.idx = self._mk_idx(PETSc.COMM_WORLD)
        self.iset = PETSc.IS().createGeneral(self.idx, comm=PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.iset)

    def tearDown(self):
        # Release the index set too, then defer to BaseTestLGMap.tearDown so
        # the map is dropped and PETSc.garbage_cleanup() runs, as for every
        # other subclass.
        self.iset = None
        super().tearDown()

    def testSameComm(self):
        # A map created from an IS is expected to share the IS's communicator.
        comm1 = self.lgmap.getComm()
        comm2 = self.iset.getComm()
        self.assertEqual(comm1, comm2)
82
83
84# --------------------------------------------------------------------
85
86
class TestLGMapBlock(unittest.TestCase):
    """Checks for an ``LGMap`` created with block size ``BS``.

    ``self.idx`` holds block indices; each block index ``b`` stands for the
    ``BS`` point indices ``b*BS, ..., b*BS + BS - 1``.
    """

    BS = 3

    def setUp(self):
        comm = PETSc.COMM_WORLD
        # Reuse the per-rank index layout of the non-blocked tests instead of
        # duplicating it here.
        self.idx = BaseTestLGMap._mk_idx(self, comm)
        self.lgmap = PETSc.LGMap().create(self.idx, self.BS, comm=comm)

    def tearDown(self):
        # Same cleanup as BaseTestLGMap: drop the reference, then let PETSc
        # reclaim the underlying object.
        self.lgmap = None
        PETSc.garbage_cleanup()

    def _check_info(self, info):
        """Assert the expected shape of a getInfo()/getBlockInfo() result."""
        self.assertIsInstance(info, dict)
        # NOTE(review): presumably one entry per rank sharing indices with
        # this one (this rank included) -- 1 in serial, 2 or 3 otherwise.
        if self.lgmap.getComm().getSize() == 1:
            self.assertEqual(len(info), 1)
        else:
            self.assertGreater(len(info), 1)
            self.assertLess(len(info), 4)

    def testGetSize(self):
        # Size counts point indices: block count times block size.
        size = self.lgmap.getSize()
        self.assertEqual(size, len(self.idx) * self.BS)

    def testGetBlockSize(self):
        bs = self.lgmap.getBlockSize()
        self.assertEqual(bs, self.BS)

    def testGetBlockIndices(self):
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getBlockIndices()
        self.assertEqual(len(idx), size // bs)
        self.assertEqual(list(idx), self.idx)

    def testGetIndices(self):
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        # Expand every block index into its bs consecutive point indices.
        expected = [val * bs + j for val in self.idx for j in range(bs)]
        self.assertEqual(list(idx), expected)

    def testGetBlockInfo(self):
        self._check_info(self.lgmap.getBlockInfo())

    def testGetInfo(self):
        self._check_info(self.lgmap.getInfo())
150
151
152# --------------------------------------------------------------------
153
if __name__ == "__main__":
    # Executed directly: discover and run every TestCase in this module.
    unittest.main()
156