xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision b458e8f169278db94fa1d489e1d3db410fc8a4c8)
1from petsc4py import PETSc
2import unittest
3from sys import getrefcount
4
5# --------------------------------------------------------------------
6
class Matrix(object):
    """Base Python context for a PETSc 'python' matrix.

    It only provides the ``create``/``destroy`` lifecycle hooks, and both
    are no-ops. Subclasses in this module extend the hooks and add matrix
    operations such as ``mult`` or ``getDiagonal``.
    """

    def __init__(self):
        """Build a context that carries no state of its own."""

    def create(self, mat):
        """Creation hook receiving the matrix object; nothing to do here."""

    def destroy(self, mat):
        """Destruction hook receiving the matrix object; nothing to do here."""
17
class Identity(Matrix):
    """Python context that behaves as the identity operator."""

    def mult(self, mat, x, y):
        """Compute y = I x by copying x into y."""
        x.copy(y)

    def getDiagonal(self, mat, vd):
        """Fill vd with ones, the diagonal of the identity."""
        vd.set(1)
25
class Diagonal(Matrix):
    """Python context for a diagonal matrix stored in the vector ``self.D``."""

    def create(self, mat):
        """Set up ``mat``, then allocate ``D`` with the matrix's left-vector layout."""
        super(Diagonal,self).create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Release ``D`` first, then run the base-class destruction hook."""
        self.D.destroy()
        super(Diagonal,self).destroy(mat)

    def scale(self, mat, a):
        """A <- a * A, i.e. scale every diagonal entry by ``a``."""
        self.D.scale(a)

    def shift(self, mat, a):
        """A <- A + a * I, i.e. add ``a`` to every diagonal entry."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Set every diagonal entry to zero."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """y <- D .* x (entrywise product with the diagonal)."""
        y.pointwiseMult(x, self.D)

    def getDiagonal(self, mat, vd):
        """Copy the stored diagonal into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert ``vd`` into, or add it to, the stored diagonal.

        ``im`` is either a bool (True means add, False means insert) or a
        ``PETSc.InsertMode``. Only INSERT_VALUES and ADD_VALUES are
        accepted; any other mode raises ValueError.
        """
        # First resolve the requested mode to a single "accumulate?" flag.
        if isinstance(im, bool):
            accumulate = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            accumulate = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            accumulate = True
        else:
            raise ValueError('wrong InsertMode %d'% im)
        # Then apply it: D += vd, or D <- vd.
        if accumulate:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """D <- vl .* D .* vr, applying left then right.

        A side whose vector is falsy is skipped (presumably an absent/empty
        Vec handed in by PETSc; confirm against the petsc4py bridge).
        """
        for side in (vl, vr):
            if side:
                self.D.pointwiseMult(self.D, side)
69
70# --------------------------------------------------------------------
71
class TestMatrix(unittest.TestCase):
    """Tests for a PETSc 'python' Mat driven by a Python context object.

    Each test gets a fresh N x N matrix whose context is an instance of the
    module-level class named by PYCLS. With the base ``Matrix`` context,
    which defines no matrix operations, every operation tested here is
    expected to raise. Subclasses select another context by overriding
    PYCLS, and override the tests whose expectations change.

    setUp and tearDown assert exact ``sys.getrefcount`` values for the
    context, so a test must not leave extra references to the context alive
    when it returns.
    """

    COMM = PETSc.COMM_WORLD
    # Module/class names joined as '<module>.<class>' for the
    # 'mat_python_type' option; only the disabled options-database branch
    # of setUp uses PYMOD.
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        """Return the Python context currently attached to self.A."""
        return self.A.getPythonContext()

    def setUp(self):
        """Create self.A as an N x N 'python' Mat backed by a PYCLS instance."""
        N = self.N = 10
        self.A = PETSc.Mat()
        if 0: # command line way
            # Disabled alternative: choose the context class through the
            # options database instead of passing an instance directly.
            # Change the condition to a true value to exercise this path.
            self.A.create(self.COMM)
            self.A.setSizes([N,N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD,self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else: # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N,N], context, comm=self.COMM)
            self.A.setUp()
            self.assertTrue(self._getCtx() is context)
            # getrefcount counts its own argument, so 3 = that argument +
            # the local name 'context' + (presumably) the Mat's reference.
            self.assertEqual(getrefcount(context), 3)
            del context
            # With the local name gone, only the Mat's reference and
            # getrefcount's argument are expected to remain.
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        """Destroy self.A and check that the Mat let go of its context."""
        ctx = self.A.getPythonContext()
        # 3 = getrefcount's argument + local 'ctx' + the Mat's reference.
        # Any extra reference left behind by a test makes this fail.
        self.assertEqual(getrefcount(ctx), 3)
        # NOTE(review): the meaning of the XXX marker below is not stated.
        self.A.destroy() # XXX
        self.A = None
        # Expect only 'ctx' and getrefcount's argument: the destroyed Mat
        # must no longer hold the context (i.e. nothing leaked).
        self.assertEqual(getrefcount(ctx), 2)
        # Debugging aid: uncomment to list whatever still refers to ctx.
        #import gc,pprint; pprint.pprint(gc.get_referrers(ctx))

    def testBasic(self):
        # The context returned by the Mat is stable across calls, and
        # holding it in a local adds exactly one reference.
        ctx = self.A.getPythonContext()
        self.assertTrue(self._getCtx() is ctx)
        self.assertEqual(getrefcount(ctx), 3)

    # The tests below call operations the base Matrix context does not
    # define, and expect each call to raise. Subclasses override the ones
    # their context implements.

    def testZeroEntries(self):
        f = lambda : self.A.zeroEntries()
        self.assertRaises(Exception, f)

    def testMult(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.mult(x, y)
        self.assertRaises(Exception, f)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.getDiagonal(d)
        self.assertRaises(Exception, f)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.setDiagonal(d)
        self.assertRaises(Exception, f)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.diagonalScale(x, y)
        self.assertRaises(Exception, f)
144
class TestIdentity(TestMatrix):
    """Re-run the TestMatrix checks with the ``Identity`` Python context.

    Identity implements ``mult`` and ``getDiagonal``, so the overrides here
    expect those to succeed. The inherited tests for operations Identity
    leaves undefined still expect an exception.
    """

    PYCLS = 'Identity'

    def testMult(self):
        # y = I x must reproduce x exactly.
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x,y)
        self.assertTrue(y.equal(x))

    def testMultTransposeSymmKnown(self):
        # Identity defines no multTranspose. The call is expected to succeed
        # while the SYMMETRIC option is set, and to raise once it is cleared.
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        # A multTranspose attached to the context at runtime must be picked
        # up by the Mat.
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            # Undo the patch even if the call raises. The stored bound method
            # references AA itself, so leaving it in place would add a
            # reference that tearDown's getrefcount check would trip over.
            del AA.multTranspose
        self.assertTrue(y.equal(x))

    def testGetDiagonal(self):
        # The identity's diagonal is all ones.
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(1)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))
180
181
class TestDiagonal(TestMatrix):
    """Re-run the TestMatrix checks with the ``Diagonal`` Python context.

    setUp stores global-index + 1 (i.e. 1, 2, ..., N) on the diagonal via
    ``A.setDiagonal``. Most expectations below compare against the
    context's own ``D`` vector.
    """

    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        # Fill the locally owned rows with i+1, then install as the diagonal.
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i+1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        # A * ones == diag(A).
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x,y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        # Diagonal defines no multTranspose. The call is expected to succeed
        # while SYMMETRIC is set, and to raise once the option is cleared.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        # A multTranspose attached to the context at runtime must be picked
        # up by the Mat.
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            # Undo the patch even if the call raises. The stored bound method
            # references AA itself, so leaving it in place would add a
            # reference that tearDown's getrefcount check would trip over.
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        # Scaling by 2 on the left and 3 on the right multiplies every
        # diagonal entry by 6.
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old*6))

    def testCreateTranspose(self):
        # AT = transpose(A): AT.mult must match A.multTranspose, and
        # AT.multTranspose must match A.mult. SYMMETRIC is set presumably
        # because Diagonal defines no multTranspose (compare
        # testMultTransposeSymmKnown).
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        #
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        #
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
265
266# --------------------------------------------------------------------
267
# When executed as a script, hand control to unittest's command-line runner,
# loading the TestCase classes defined in this (the __main__) module.
if __name__ == '__main__':
    unittest.main(module='__main__')
270