xref: /petsc/src/binding/petsc4py/test/test_is.py (revision f97672e55eacc8688507b9471cd7ec2664d7f203)
1from petsc4py import PETSc
2import unittest
3import random
4
5# --------------------------------------------------------------------
6
class BaseTestIS(object):
    """Checks shared by every PETSc index-set (IS) type.

    This is a mixin rather than a ``TestCase``: concrete subclasses also
    inherit from ``unittest.TestCase``, set ``TYPE`` to the expected IS type
    name and build ``self.iset`` in ``setUp``.
    """

    # Expected value of ``self.iset.getType()``; overridden by subclasses.
    TYPE = None

    def tearDown(self):
        # Drop this test's reference to the index set.
        self.iset = None

    def testGetType(self):
        istype = self.iset.getType()
        self.assertEqual(istype, self.TYPE)

    def testGetSize(self):
        # The locally owned part can never exceed the global size.
        lsize = self.iset.getLocalSize()
        gsize = self.iset.getSize()
        self.assertLessEqual(lsize, gsize)

    def testDuplicate(self):
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testCopy(self):
        iset = self.iset.copy()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testEqual(self):
        # An IS equals itself and any duplicate of itself.
        self.assertTrue(self.iset.equal(self.iset))
        iset = self.iset.duplicate()
        self.assertTrue(self.iset.equal(iset))
        del iset

    def testSort(self):
        self.iset.sort()
        self.assertTrue(self.iset.isSorted())

    def testDifference(self):
        # Removing a set from itself leaves nothing locally.
        iset = self.iset.difference(self.iset)
        self.assertEqual(iset.getLocalSize(), 0)
        del iset

    def testComplement(self):
        # Sort first: PETSc's ISComplement expects a sorted index set.
        self.iset.sort()
        indices = self.iset.getIndices()
        nmin = indices.min()
        nmax = indices.max()
        nunique = len(set(indices.tolist()))
        # Complement within [nmin, nmax]; the upper bound passed to
        # complement() is exclusive, hence the +1.
        iset = self.iset.complement(nmin, nmax+1)
        # Every value in the range that is not in self.iset ends up here.
        self.assertEqual(iset.getLocalSize(), (nmax + 1 - nmin) - nunique)
        # Complementing again over the same range must restore the original
        # entries (equal() compares the sets of indices).
        iset2 = iset.complement(nmin, nmax+1)
        self.assertTrue(self.iset.equal(iset2))
        del iset, iset2

    def testSum(self):
        # Only exercised on a single process (PETSc documents ISSum as
        # sequential-only).
        if self.iset.getComm().getSize() > 1:
            self.skipTest('IS.sum() is only tested sequentially')
        self.iset.sort()
        iset = self.iset.duplicate()
        # The union of a set with itself is the set itself.  sum() returns
        # the union as a new IS, so that result is what gets checked.
        isum = iset.sum(self.iset)
        self.assertTrue(self.iset.equal(isum))
        del iset, isum

    def testExpand(self):
        iset = self.iset.expand(self.iset)
        # Only compared when expand() kept the original IS type.
        if self.iset.type == iset.type:
            self.assertTrue(self.iset.equal(iset))
        del iset

    def testRenumber(self):
        # Smoke test: renumber() without and with a multiplicity IS must run;
        # the returned counts are not checked here.
        (n1, is1) = self.iset.renumber()
        (n2, is2) = self.iset.renumber(self.iset)
        del is1
        del is2

    def testProperties(self):
        proplist = ['sizes', 'size', 'local_size', 'indices',
                    'permutation', 'identity', 'sorted']
        for prop in proplist:
            # Pass the name as the message so a failure says which is missing.
            self.assertTrue(hasattr(self.iset, prop), prop)

    def testArray(self):
        import numpy
        # Each live array view holds one extra reference to the IS, and
        # releasing the view gives it back.
        refs = self.iset.getRefCount()
        arr1 = numpy.asarray(self.iset)
        self.assertEqual(self.iset.getRefCount(), refs+1)
        arr2 = self.iset.array
        self.assertEqual(self.iset.getRefCount(), refs+2)
        self.assertTrue((arr1 == arr2).all())
        del arr2
        self.assertEqual(self.iset.getRefCount(), refs+1)
        del arr1
        self.assertEqual(self.iset.getRefCount(), refs)
95
96
97# --------------------------------------------------------------------
98
class TestISGeneral(BaseTestIS, unittest.TestCase):
    """Run the shared IS checks against a general (explicit-list) index set."""

    TYPE = PETSc.IS.Type.GENERAL

    def setUp(self):
        # A random permutation of 0..9, so the index set starts out unsorted.
        shuffled = list(range(10))
        random.shuffle(shuffled)
        self.idx = shuffled
        self.iset = PETSc.IS().createGeneral(shuffled)

    def testGetIndices(self):
        # The indices come back exactly as given, in the same order.
        self.assertEqual(self.idx, list(self.iset.getIndices()))
111
112
class TestISStride(BaseTestIS, unittest.TestCase):
    """Run the shared IS checks against a strided index set."""

    TYPE = PETSc.IS.Type.STRIDE

    def setUp(self):
        # (size, start, step) forwarded positionally to createStride.
        self.info = (10, 7, 3)
        self.iset = PETSc.IS().createStride(*self.info)

    def testGetIndices(self):
        # Entry k of a stride IS is start + k*step.
        count, first, stride = self.info
        expected = [first + stride*k for k in range(count)]
        actual = list(self.iset.getIndices())
        self.assertEqual(actual, expected)

    def testToGeneral(self):
        # toGeneral() converts in place, so the type itself must change.
        self.iset.toGeneral()
        converted_type = self.iset.getType()
        self.assertEqual(converted_type, PETSc.IS.Type.GENERAL)
130
131
class TestISBlock(BaseTestIS, unittest.TestCase):
    """Run the shared IS checks against a block index set."""

    TYPE = PETSc.IS.Type.BLOCK

    def setUp(self):
        self.bsize = 3
        # Block numbers 0, 2, 4, 6, 8 in random order.
        self.index = list(range(0, 10, 2))
        random.shuffle(self.index)
        self.iset = PETSc.IS().createBlock(self.bsize, self.index)
        self.assertEqual(self.iset.getType(), PETSc.IS.Type.BLOCK)

    def testGetSize(self):
        # Overrides the base check: every block contributes exactly bsize
        # local entries.  Compare integers rather than a float quotient.
        lsize = self.iset.getLocalSize()
        self.assertEqual(lsize, self.bsize * len(self.index))

    def testGetBlockSize(self):
        bs = self.iset.getBlockSize()
        self.assertEqual(bs, self.bsize)

    def testGetBlockIndices(self):
        # Block indices come back exactly as given, in the same order.
        index = list(self.iset.getBlockIndices())
        self.assertEqual(index, self.index)

    def testGetIndices(self):
        # Block number i expands to the bs consecutive entries
        # i*bs, i*bs+1, ..., i*bs+bs-1, following the block order.
        bs = self.bsize
        expected = [i*bs + j
                    for i in self.iset.getBlockIndices()
                    for j in range(bs)]
        index = list(self.iset.getIndices())
        self.assertEqual(index, expected)
163
164
165# --------------------------------------------------------------------
166
if __name__ == '__main__':
    # Executed as a script: run every TestCase defined in this module.
    unittest.main()
169