xref: /petsc/src/binding/petsc4py/test/test_is.py (revision d5b43468fb8780a8feea140ccd6fa3e6a50411cc)
1from petsc4py import PETSc
2import unittest
3import random
4
5# --------------------------------------------------------------------
6
class BaseTestIS(object):
    """Tests shared by every PETSc index-set (IS) type.

    Concrete subclasses mix this class with ``unittest.TestCase``, set
    ``TYPE`` to the IS type they expect and build ``self.iset`` in
    ``setUp``.
    """

    # Expected result of ``self.iset.getType()``; overridden by subclasses.
    TYPE = None

    def tearDown(self):
        # Drop our reference, then ask petsc4py to clean up PETSc garbage.
        self.iset = None
        PETSc.garbage_cleanup()

    def testGetType(self):
        istype = self.iset.getType()
        self.assertEqual(istype, self.TYPE)

    def testGetSize(self):
        lsize = self.iset.getLocalSize()
        gsize = self.iset.getSize()
        # The local part can never exceed the global size.
        self.assertLessEqual(lsize, gsize)

    def testDuplicate(self):
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testCopy(self):
        iset = self.iset.copy()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testEqual(self):
        # An IS equals itself and its duplicate.
        self.assertTrue(self.iset.equal(self.iset))
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testSort(self):
        self.iset.sort()
        self.assertTrue(self.iset.isSorted())

    def testDifference(self):
        # Removing a set from itself leaves nothing locally.
        iset = self.iset.difference(self.iset)
        self.assertEqual(iset.getLocalSize(), 0)
        del iset

    def testComplement(self):
        # complement() needs a sorted set whose entries lie in [nmin, nmax).
        self.iset.sort()
        indices = self.iset.getIndices()
        nmin = indices.min()
        nmax = indices.max() + 1
        iset = self.iset.complement(nmin, nmax)
        # For a set without duplicates, complementing twice over the same
        # range must give back the original indices.
        back = iset.complement(nmin, nmax)
        self.assertTrue(self.iset.equal(back))
        del back
        del iset

    def testSum(self):
        # sum() is only exercised on a single process.
        if self.iset.getComm().getSize() > 1:
            return
        self.iset.sort()
        iset = self.iset.duplicate()
        iset.sum(self.iset)
        # Summing a sorted set with a copy of itself must not change it.
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testExpand(self):
        iset = self.iset.expand(self.iset)
        # Only compare when expand() kept the original IS type.
        if self.iset.type == iset.type:
            self.assertTrue(self.iset.equal(iset))
        del iset

    def testRenumber(self):
        # Smoke test of both call forms; the returned counts are not checked.
        _, is1 = self.iset.renumber()
        _, is2 = self.iset.renumber(self.iset)
        del is1
        del is2

    def testProperties(self):
        proplist = ['sizes', 'size', 'local_size', 'indices',
                    'permutation', 'identity', 'sorted']
        for prop in proplist:
            self.assertTrue(hasattr(self.iset, prop))

    def testArray(self):
        import numpy
        # Every array view of the IS holds one extra reference to it,
        # released when the view is deleted.
        refs = self.iset.getRefCount()
        arr1 = numpy.asarray(self.iset)
        self.assertEqual(self.iset.getRefCount(), refs+1)
        arr2 = self.iset.array
        self.assertEqual(self.iset.getRefCount(), refs+2)
        self.assertTrue((arr1 == arr2).all())
        del arr2
        self.assertEqual(self.iset.getRefCount(), refs+1)
        del arr1
        self.assertEqual(self.iset.getRefCount(), refs)
96
97
98# --------------------------------------------------------------------
99
class TestISGeneral(BaseTestIS, unittest.TestCase):
    """Run the shared IS tests on a general (explicit index list) IS."""

    TYPE = PETSc.IS.Type.GENERAL

    def setUp(self):
        # A random permutation of 0..9, kept for comparison in the tests.
        indices = list(range(10))
        random.shuffle(indices)
        self.idx = indices
        self.iset = PETSc.IS().createGeneral(indices)

    def testGetIndices(self):
        # The IS must return exactly the list it was built from, in order.
        got = list(self.iset.getIndices())
        self.assertEqual(self.idx, got)
112
113
class TestISStride(BaseTestIS, unittest.TestCase):
    """Run the shared IS tests on a strided IS."""

    TYPE = PETSc.IS.Type.STRIDE

    def setUp(self):
        # (size, start, step) of the strided index set.
        self.info = (10, 7, 3)
        self.iset = PETSc.IS().createStride(*self.info)

    def testGetIndices(self):
        size, start, step = self.info
        expected = list(range(start, start + size * step, step))
        self.assertEqual(list(self.iset.getIndices()), expected)

    def testToGeneral(self):
        # In-place conversion must switch the type to GENERAL.
        self.iset.toGeneral()
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.GENERAL)
130        self.assertEqual(self.iset.getType(), PETSc.IS.Type.GENERAL)
131
132
class TestISBlock(BaseTestIS, unittest.TestCase):
    """Run the shared IS tests on a block IS, plus block-specific checks."""

    TYPE = PETSc.IS.Type.BLOCK

    def setUp(self):
        self.bsize = 3
        # Block indices 0, 2, 4, 6, 8 in random order.
        self.index = list(range(0, 10, 2))
        random.shuffle(self.index)
        self.iset = PETSc.IS().createBlock(self.bsize, self.index)
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.BLOCK)

    def testGetSize(self):
        # Overrides the base check: each block index expands to exactly
        # bsize local entries. Compare in integers rather than via a float
        # quotient, so a size that is not a multiple of bsize is caught.
        lsize = self.iset.getLocalSize()
        self.assertEqual(lsize, self.bsize * len(self.index))

    def testGetBlockSize(self):
        bs = self.iset.getBlockSize()
        self.assertEqual(bs, self.bsize)

    def testGetBlockIndices(self):
        index = list(self.iset.getBlockIndices())
        self.assertEqual(index, self.index)

    def testGetIndices(self):
        # Block index i expands to point indices i*bs, ..., i*bs + bs - 1.
        bs = self.bsize
        expected = [i * bs + j
                    for i in self.iset.getBlockIndices()
                    for j in range(bs)]
        index = list(self.iset.getIndices())
        self.assertEqual(index, expected)
163        #self.assertEqual(index, idx)
164
165
166# --------------------------------------------------------------------
167
168if __name__ == '__main__':
169    unittest.main()
170