xref: /petsc/src/binding/petsc4py/test/test_lgmap.py (revision d5b43468fb8780a8feea140ccd6fa3e6a50411cc)
1from petsc4py import PETSc
2import unittest
3
4# --------------------------------------------------------------------
5
class BaseTestLGMap(object):
    """Shared test cases for ``PETSc.LGMap`` objects.

    This class is a mixin: concrete subclasses also derive from
    ``unittest.TestCase`` and provide a ``setUp`` that sets

    * ``self.idx``   -- the list of global indices the map was built from
    * ``self.lgmap`` -- the ``PETSc.LGMap`` under test
    """

    def _mk_idx(self, comm):
        """Return the list of global indices for the calling rank.

        Every rank gets 10 consecutive indices.  All ranks but the first
        also take the index just before their range, and all ranks but
        the last take the index just after it, so the index sets of
        neighbouring ranks overlap.
        """
        comm_size = comm.getSize()
        comm_rank = comm.getRank()
        lsize = 10
        first = lsize * comm_rank
        last = first + lsize
        if comm_rank > 0:
            first -= 1
        if comm_rank < (comm_size-1):
            last += 1
        return list(range(first, last))

    def tearDown(self):
        """Drop the reference to the map, then call PETSc's garbage cleanup."""
        self.lgmap = None
        PETSc.garbage_cleanup()

    def testGetSize(self):
        """The reported size is never negative."""
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)

    def testGetIndices(self):
        """getIndices() returns exactly the indices the map was built from."""
        size = self.lgmap.getSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        # Compare whole lists so a failure reports the first differing entry.
        self.assertEqual(list(idx), self.idx)

    def testGetInfo(self):
        """getInfo() returns a dict; empty on one process, 2 or 3 entries otherwise."""
        info = self.lgmap.getInfo()
        self.assertIsInstance(info, dict)
        if self.lgmap.getComm().getSize() == 1:
            self.assertEqual(info, {})
        else:
            self.assertGreater(len(info), 1)
            self.assertLess(len(info), 4)

    def testApply(self):
        """apply() maps local to global indices and applyInverse() undoes it."""
        idxin = list(range(self.lgmap.getSize()))
        idxout = self.lgmap.apply(idxin)
        # Local index i must map to the i-th global index of the map.
        self.assertEqual(list(idxout), self.idx)
        # Second form: write the result into a caller-supplied array.
        self.lgmap.apply(idxin, idxout)
        self.assertEqual(list(idxout), self.idx)
        # The global indices are unique on each rank, so the inverse
        # mapping must give back the original local indices.
        invmap = self.lgmap.applyInverse(idxout)
        self.assertEqual(list(invmap), idxin)

    def testApplyIS(self):
        """Smoke test: apply() accepts a ``PETSc.IS`` as its input."""
        is_in = PETSc.IS().createStride(self.lgmap.getSize())
        # Only checks that the call succeeds; the result is not inspected.
        is_out = self.lgmap.apply(is_in)

    def testProperties(self):
        """The map exposes the ``size``, ``indices`` and ``info`` attributes."""
        for prop in ('size', 'indices', 'info'):
            self.assertTrue(hasattr(self.lgmap, prop))
58
59# --------------------------------------------------------------------
60
class TestLGMap(BaseTestLGMap, unittest.TestCase):
    """Run the shared LGMap tests on a map created from an index list."""

    def setUp(self):
        """Build the index list and create the map on ``PETSc.COMM_WORLD``."""
        world = PETSc.COMM_WORLD
        indices = self._mk_idx(world)
        self.idx = indices
        self.lgmap = PETSc.LGMap().create(indices, comm=world)

66
class TestLGMapIS(BaseTestLGMap, unittest.TestCase):
    """Run the shared LGMap tests on a map created from a ``PETSc.IS``."""

    def setUp(self):
        """Create an index set on ``PETSc.COMM_WORLD`` and a map from it."""
        self.idx   = self._mk_idx(PETSc.COMM_WORLD)
        self.iset  = PETSc.IS().createGeneral(self.idx, comm=PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.iset)

    def tearDown(self):
        """Drop the index set and the map, then call PETSc's garbage cleanup."""
        self.iset  = None
        self.lgmap = None
        # This override replaces BaseTestLGMap.tearDown, so the garbage
        # cleanup done there has to be repeated here.
        PETSc.garbage_cleanup()

    def testSameComm(self):
        """The map reports the same communicator as the index set it came from."""
        comm1 = self.lgmap.getComm()
        comm2 = self.iset.getComm()
        self.assertEqual(comm1, comm2)

82
83# --------------------------------------------------------------------
84
class TestLGMapBlock(unittest.TestCase):
    """Tests for a ``PETSc.LGMap`` created with a block size."""

    # Block size passed to LGMap.create() in setUp.
    BS = 3

    def setUp(self):
        """Create a blocked map on ``PETSc.COMM_WORLD``.

        Every rank gets 10 consecutive block indices.  All ranks but the
        first also take the index just before their range, and all ranks
        but the last take the index just after it.
        """
        comm = PETSc.COMM_WORLD
        comm_size = comm.getSize()
        comm_rank = comm.getRank()
        lsize = 10
        first = lsize * comm_rank
        last = first + lsize
        if comm_rank > 0:
            first -= 1
        if comm_rank < (comm_size-1):
            last += 1
        self.idx = list(range(first, last))
        self.lgmap = PETSc.LGMap().create(self.idx, self.BS, comm=comm)

    def tearDown(self):
        """Drop the reference to the map, then call PETSc's garbage cleanup."""
        self.lgmap = None
        # Same cleanup as BaseTestLGMap.tearDown performs.
        PETSc.garbage_cleanup()

    def _check_info(self, info):
        """Check an info result: a dict, empty on one process, else 2 or 3 entries."""
        self.assertIsInstance(info, dict)
        if self.lgmap.getComm().getSize() == 1:
            self.assertEqual(info, {})
        else:
            self.assertGreater(len(info), 1)
            self.assertLess(len(info), 4)

    def testGetSize(self):
        """The reported size is never negative."""
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)

    def testGetBlockSize(self):
        """getBlockSize() returns the block size the map was created with."""
        bs = self.lgmap.getBlockSize()
        self.assertEqual(bs, self.BS)

    def testGetBlockIndices(self):
        """getBlockIndices() returns the block indices passed to create()."""
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getBlockIndices()
        self.assertEqual(len(idx), size//bs)
        self.assertEqual(list(idx), self.idx)

    def testGetIndices(self):
        """getIndices() expands block index ``v`` to ``v*bs .. v*bs + bs - 1``."""
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        expected = [val*bs + j for val in self.idx for j in range(bs)]
        self.assertEqual(list(idx), expected)

    def testGetBlockInfo(self):
        """getBlockInfo() returns a dict of the expected size."""
        self._check_info(self.lgmap.getBlockInfo())

    def testGetInfo(self):
        """getInfo() returns a dict of the expected size."""
        self._check_info(self.lgmap.getInfo())

149
150# --------------------------------------------------------------------
151
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
154