# xref: /petsc/src/binding/petsc4py/test/test_is.py (revision 552edb6364df478b294b3111f33a8f37ca096b20)
1from petsc4py import PETSc
2import unittest
3import random
4
5# --------------------------------------------------------------------
6
7
class BaseTestIS:
    """Checks shared by every PETSc index-set (IS) test case.

    This is a mixin: concrete subclasses also derive from
    ``unittest.TestCase``, set ``TYPE`` to the IS type they expect, and
    create ``self.iset`` in their ``setUp``.
    """

    # Expected value of ``self.iset.getType()``; subclasses override it.
    TYPE = None

    def tearDown(self):
        # Drop the test's reference to the IS, then ask PETSc to clean up.
        self.iset = None
        PETSc.garbage_cleanup()

    def testGetType(self):
        istype = self.iset.getType()
        self.assertEqual(istype, self.TYPE)

    def testGetSize(self):
        # The local size can never exceed the global size.
        lsize = self.iset.getLocalSize()
        gsize = self.iset.getSize()
        self.assertLessEqual(lsize, gsize)

    def testDuplicate(self):
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testCopy(self):
        iset = self.iset.copy()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testEqual(self):
        # An IS equals itself and any duplicate of itself.
        self.assertTrue(self.iset.equal(self.iset))
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testSort(self):
        self.iset.sort()
        self.assertTrue(self.iset.isSorted())

    def testDifference(self):
        # Removing a set from itself leaves nothing.
        iset = self.iset.difference(self.iset)
        self.assertEqual(iset.getLocalSize(), 0)
        del iset

    def testComplement(self):
        # ISComplement expects a sorted index set.
        self.iset.sort()
        indices = self.iset.getIndices()
        nmin = indices.min()
        nmax = indices.max()
        iset = self.iset.complement(nmin, nmax + 1)
        # Taking the complement again over the same range [nmin, nmax + 1)
        # must reproduce the original set. Previously this second result was
        # discarded, so the test asserted nothing.
        back = iset.complement(nmin, nmax + 1)
        self.assertTrue(self.iset.equal(back))
        del iset, back

    def testSum(self):
        # sum() is only exercised on a single process.
        if self.iset.getComm().getSize() > 1:
            return
        self.iset.sort()
        iset = self.iset.duplicate()
        # sum() returns a new IS holding the union. Capture it: the
        # duplicate itself is not the result.
        union = iset.sum(self.iset)
        # The union of a set with itself is the set itself.
        self.assertTrue(self.iset.equal(union))
        del iset, union

    def testExpand(self):
        iset = self.iset.expand(self.iset)
        # Only compare when expand() kept the original IS type.
        if self.iset.type == iset.type:
            self.assertTrue(self.iset.equal(iset))
        del iset

    def testRenumber(self):
        # Smoke test: renumber with and without a multiplicity IS.
        (n1, is1) = self.iset.renumber()
        (n2, is2) = self.iset.renumber(self.iset)
        del is1
        del is2

    def testProperties(self):
        proplist = [
            'sizes',
            'size',
            'local_size',
            'indices',
            'permutation',
            'identity',
            'sorted',
        ]
        for prop in proplist:
            self.assertTrue(hasattr(self.iset, prop))

    def testArray(self):
        import numpy

        # Each live array view (numpy.asarray(...) and .array) holds one
        # extra reference to the IS; dropping the view releases it.
        refs = self.iset.getRefCount()
        arr1 = numpy.asarray(self.iset)
        self.assertEqual(self.iset.getRefCount(), refs + 1)
        arr2 = self.iset.array
        self.assertEqual(self.iset.getRefCount(), refs + 2)
        self.assertTrue((arr1 == arr2).all())
        del arr2
        self.assertEqual(self.iset.getRefCount(), refs + 1)
        del arr1
        self.assertEqual(self.iset.getRefCount(), refs)
104
105
106# --------------------------------------------------------------------
107
108
class TestISGeneral(BaseTestIS, unittest.TestCase):
    """Run the shared IS checks against a general (explicit-list) IS."""

    TYPE = PETSc.IS.Type.GENERAL

    def setUp(self):
        # The indices 0..9 in random order.
        indices = list(range(10))
        random.shuffle(indices)
        self.idx = indices
        self.iset = PETSc.IS().createGeneral(indices)

    def testGetIndices(self):
        # A general IS keeps its indices in the order they were given.
        self.assertEqual(self.idx, list(self.iset.getIndices()))
120
121
class TestISStride(BaseTestIS, unittest.TestCase):
    """Run the shared IS checks against a strided IS."""

    TYPE = PETSc.IS.Type.STRIDE

    def setUp(self):
        # (size, start, step): ten entries starting at 7, spaced 3 apart.
        self.info = (10, 7, 3)
        self.iset = PETSc.IS().createStride(*self.info)

    def testGetIndices(self):
        count, first, stride = self.info
        expected = [first + stride * k for k in range(count)]
        actual = list(self.iset.getIndices())
        self.assertEqual(actual, expected)

    def testToGeneral(self):
        # After conversion, the same IS object reports the GENERAL type.
        self.iset.toGeneral()
        general = PETSc.IS.Type.GENERAL
        self.assertEqual(self.iset.getType(), general)
137        self.assertEqual(self.iset.getType(), PETSc.IS.Type.GENERAL)
138
139
class TestISBlock(BaseTestIS, unittest.TestCase):
    """Run the shared IS checks against a block IS."""

    TYPE = PETSc.IS.Type.BLOCK

    def setUp(self):
        # Block indices 0, 2, 4, 6, 8 in random order, with block size 3.
        self.bsize = 3
        self.index = list(range(0, 10, 2))
        random.shuffle(self.index)
        self.iset = PETSc.IS().createBlock(self.bsize, self.index)
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.BLOCK)

    def testGetSize(self):
        # Each block index expands to ``bsize`` entries. Compare with exact
        # integer arithmetic instead of a float quotient. Also keep the base
        # class's local <= global check, which this override would
        # otherwise skip.
        lsize = self.iset.getLocalSize()
        self.assertEqual(lsize, self.bsize * len(self.index))
        self.assertLessEqual(lsize, self.iset.getSize())

    def testGetBlockSize(self):
        bs = self.iset.getBlockSize()
        self.assertEqual(bs, self.bsize)

    def testGetBlockIndices(self):
        # Block indices come back in the order they were given.
        index = list(self.iset.getBlockIndices())
        self.assertEqual(index, self.index)

    def testGetIndices(self):
        # Block index b expands to the entries b*bs, b*bs + 1, ..., b*bs + bs - 1.
        bs = self.bsize
        expected = [
            blk * bs + offset
            for blk in self.iset.getBlockIndices()
            for offset in range(bs)
        ]
        self.assertEqual(list(self.iset.getIndices()), expected)
169        self.assertEqual(index, idx)
170
171
172# --------------------------------------------------------------------
173
if __name__ == '__main__':
    # Lets the module be run directly as a script; unittest discovers the
    # TestCase subclasses defined above.
    unittest.main()
176