xref: /petsc/src/binding/petsc4py/test/test_lgmap.py (revision 4e278199b78715991f5c71ebbd945c1489263e6c)
1from petsc4py import PETSc
2import unittest
3
4# --------------------------------------------------------------------
5
class BaseTestLGMap(object):
    """Mixin holding LGMap checks shared by concrete TestCase subclasses.

    Subclasses must also derive from ``unittest.TestCase`` and, in ``setUp``,
    assign ``self.idx`` (the list of global indices produced by ``_mk_idx``)
    and ``self.lgmap`` (a ``PETSc.LGMap`` built from those indices).
    """

    def _mk_idx(self, comm):
        """Return the global indices mapped on this rank of *comm*.

        Rank ``r`` takes the contiguous range ``[10*r, 10*r + 10)``, widened
        by one index on the left (except on rank 0) and by one index on the
        right (except on the last rank). Neighbouring ranks therefore share
        indices at their boundaries.
        """
        comm_size = comm.getSize()
        comm_rank = comm.getRank()
        lsize = 10
        first = lsize * comm_rank
        last = first + lsize
        if comm_rank > 0:
            first -= 1  # overlap with the left neighbour's range
        if comm_rank < comm_size - 1:
            last += 1  # overlap with the right neighbour's range
        return list(range(first, last))

    def tearDown(self):
        # Drop the reference so the PETSc object can be destroyed.
        self.lgmap = None

    def testGetSize(self):
        self.assertGreaterEqual(self.lgmap.getSize(), 0)

    def testGetIndices(self):
        size = self.lgmap.getSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        # The map must hold exactly the indices it was created from, in order.
        self.assertEqual(list(idx), self.idx)

    def testGetInfo(self):
        info = self.lgmap.getInfo()
        self.assertIsInstance(info, dict)
        if self.lgmap.getComm().getSize() == 1:
            self.assertEqual(info, {})
        else:
            # With the overlapping layout from _mk_idx, 2 or 3 entries are
            # expected. NOTE(review): presumably this rank plus its one or two
            # neighbours -- confirm against LGMap.getInfo documentation.
            self.assertGreater(len(info), 1)
            self.assertLess(len(info), 4)

    def testApply(self):
        idxin = list(range(self.lgmap.getSize()))
        # Local position i maps to global index self.idx[i].
        idxout = self.lgmap.apply(idxin)
        self.assertEqual(list(idxout), self.idx)
        # Writing into a preallocated output must give the same result.
        self.lgmap.apply(idxin, idxout)
        self.assertEqual(list(idxout), self.idx)
        # idxout only contains indices held by this map, and _mk_idx yields
        # no duplicates, so mapping back recovers the local positions.
        invmap = self.lgmap.applyInverse(idxout)
        self.assertEqual(list(invmap), idxin)

    def testApplyIS(self):
        # Only checks that apply() accepts an IS. createStride is given just
        # a size, so its contents depend on petsc4py's default first/step and
        # the output is deliberately not inspected here.
        is_in = PETSc.IS().createStride(self.lgmap.getSize())
        self.lgmap.apply(is_in)

    def testProperties(self):
        for prop in ('size', 'indices', 'info'):
            self.assertTrue(hasattr(self.lgmap, prop))
57
58# --------------------------------------------------------------------
59
class TestLGMap(BaseTestLGMap, unittest.TestCase):
    """Run the shared LGMap checks on a map created from a list of indices."""

    def setUp(self):
        world = PETSc.COMM_WORLD
        self.idx = self._mk_idx(world)
        self.lgmap = PETSc.LGMap().create(self.idx, comm=world)
65
class TestLGMapIS(BaseTestLGMap, unittest.TestCase):
    """Run the shared LGMap checks on a map created from an index set (IS)."""

    def setUp(self):
        world = PETSc.COMM_WORLD
        self.idx = self._mk_idx(world)
        self.iset = PETSc.IS().createGeneral(self.idx, comm=world)
        # No comm is passed: testSameComm checks that the map ends up on
        # the same communicator as the IS it was built from.
        self.lgmap = PETSc.LGMap().create(self.iset)

    def tearDown(self):
        # Release the IS first; the base class then releases the map.
        self.iset = None
        super(TestLGMapIS, self).tearDown()

    def testSameComm(self):
        self.assertEqual(self.lgmap.getComm(), self.iset.getComm())
81
82# --------------------------------------------------------------------
83
84class TestLGMapBlock(unittest.TestCase):
85
86    BS = 3
87
88    def setUp(self):
89        comm = PETSc.COMM_WORLD
90        comm_size = comm.getSize()
91        comm_rank = comm.getRank()
92        lsize = 10
93        first = lsize * comm_rank
94        last  = first + lsize
95        if comm_rank > 0:
96            first -= 1
97        if comm_rank < (comm_size-1):
98            last += 1
99        self.idx = list(range(first, last))
100        bs = self.BS
101        self.lgmap = PETSc.LGMap().create(self.idx, bs, comm=PETSc.COMM_WORLD)
102
103    def tearDown(self):
104        self.lgmap = None
105
106    def testGetSize(self):
107        size = self.lgmap.getSize()
108        self.assertTrue(size >= 0)
109
110    def testGetBlockSize(self):
111        bs = self.lgmap.getBlockSize()
112        self.assertEqual(bs, self.BS)
113
114    def testGetBlockIndices(self):
115        size = self.lgmap.getSize()
116        bs = self.lgmap.getBlockSize()
117        idx = self.lgmap.getBlockIndices()
118        self.assertEqual(len(idx), size//bs)
119        for i, val in enumerate(self.idx):
120            self.assertEqual(idx[i], val)
121
122    def testGetIndices(self):
123        size = self.lgmap.getSize()
124        bs = self.lgmap.getBlockSize()
125        idx = self.lgmap.getIndices()
126        self.assertEqual(len(idx), size)
127        for i, val in enumerate(self.idx):
128            for j in range(bs):
129                self.assertEqual(idx[i*bs+j], val*bs+j)
130
131    def testGetBlockInfo(self):
132        info = self.lgmap.getBlockInfo()
133        self.assertEqual(type(info), dict)
134        if self.lgmap.getComm().getSize() == 1:
135            self.assertEqual(info, {})
136        else:
137            self.assertTrue(len(info) > 1)
138            self.assertTrue(len(info) < 4)
139
140    def testGetInfo(self):
141        info = self.lgmap.getInfo()
142        self.assertEqual(type(info), dict)
143        if self.lgmap.getComm().getSize() == 1:
144            self.assertEqual(info, {})
145        else:
146            self.assertTrue(len(info) > 1)
147            self.assertTrue(len(info) < 4)
148
149# --------------------------------------------------------------------
150
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()
153