# xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision 030f984af8d8bb4c203755d35bded3c05b3d83ce)
1from petsc4py import PETSc
2import unittest
3from sys import getrefcount
# --------------------------------------------------------------------
5
class Matrix(object):
    """Base Python context for a ``Mat`` built with ``createPython``.

    It carries no state and implements no matrix operation at all; only
    the ``create``/``destroy`` hooks exist, and both are no-ops.
    """

    def __init__(self):
        """Construct the context; there is nothing to initialize."""

    def create(self, mat):
        """Creation hook receiving the owning Mat; a no-op in the base class."""

    def destroy(self, mat):
        """Destruction hook receiving the owning Mat; a no-op in the base class."""
16
class Identity(Matrix):
    """Python matrix context that acts as the identity operator.

    Only the diagonal query and the forward product are provided; no
    transpose product is defined here.
    """

    def getDiagonal(self, mat, vd):
        """Fill ``vd`` with ones, i.e. the diagonal of the identity."""
        vd.set(1)

    def mult(self, mat, x, y):
        """Compute ``y = I x`` by copying ``x`` into ``y``."""
        x.copy(y)
24
class Diagonal(Matrix):
    """Python matrix context representing ``diag(D)``.

    The diagonal entries live in the vector ``self.D``, which is obtained
    from ``mat.createVecLeft()`` when the context is created.
    """

    def create(self, mat):
        """Run the base hook, set up ``mat`` and allocate the diagonal vector."""
        super(Diagonal, self).create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Release the diagonal vector, then run the base hook."""
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        """D <- a * D."""
        self.D.scale(a)

    def shift(self, mat, a):
        """D <- D + a, i.e. add ``a`` times the identity."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """D <- 0."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """y <- D .* x (entrywise product)."""
        y.pointwiseMult(x, self.D)

    def getDiagonal(self, mat, vd):
        """Copy the stored diagonal into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert ``vd`` into, or add it to, the stored diagonal.

        ``im`` may be a bool (True = add, False = insert) or one of
        ``PETSc.InsertMode.INSERT_VALUES`` / ``ADD_VALUES``; any other
        value raises ValueError.
        """
        if isinstance(im, bool):
            add = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            add = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            add = True
        else:
            raise ValueError('wrong InsertMode %d' % im)
        if add:
            self.D.axpy(1, vd)  # D <- D + vd
        else:
            vd.copy(self.D)     # D <- vd

    def diagonalScale(self, mat, vl, vr):
        """D <- vl .* D .* vr; a falsy scaling vector (e.g. None) is skipped."""
        for factor in (vl, vr):
            if factor:
                self.D.pointwiseMult(self.D, factor)
68
# --------------------------------------------------------------------
70
class TestMatrix(unittest.TestCase):
    """Tests for a Python ``Mat`` whose context class is named by ``PYCLS``.

    With the bare ``Matrix`` context no operation is implemented, so every
    operation exercised here must raise. Subclasses override ``PYCLS`` and
    the operations they support. The context's reference count is checked
    around creation and destruction to catch leaks.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        """Return the Python context attached to ``self.A``."""
        return self.A.getPythonContext()

    def setUp(self):
        """Create the N x N Python matrix ``self.A`` with a ``PYCLS`` context."""
        N = self.N = 10
        self.A = PETSc.Mat()
        if 0:
            # Disabled alternative: select the context class through the
            # options database instead of passing an instance directly.
            self.A.create(self.COMM)
            self.A.setSizes([N, N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD, self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else:
            # Instantiate the context class by name from this module's globals.
            context = globals()[self.PYCLS]()
            self.A.createPython([N, N], context, comm=self.COMM)
            self.A.setUp()
            self.assertTrue(self._getCtx() is context)
            # Held by: the local name, the Mat, and getrefcount's argument.
            self.assertEqual(getrefcount(context), 3)
            del context
            # With the local gone, only the Mat (plus the argument) remains.
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        """Destroy ``self.A`` and check that it drops its context reference."""
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy()
        self.A = None
        self.assertEqual(getrefcount(ctx), 2)

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertTrue(self._getCtx() is ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        self.assertRaises(Exception, self.A.zeroEntries)

    def testMult(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.mult, x, y)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.getDiagonal, d)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.setDiagonal, d)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.diagonalScale, x, y)
143
144
class TestIdentity(TestMatrix):
    """Tests for a Python ``Mat`` backed by the ``Identity`` context."""

    PYCLS = 'Identity'

    def testMult(self):
        """A x must reproduce x."""
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(x))

    def testMultTransposeSymmKnown(self):
        """multTranspose works while SYMMETRIC is set and raises once cleared.

        The context itself defines no ``multTranspose``.
        """
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testMultTransposeNewMeth(self):
        """A ``multTranspose`` attached to the context at run time is used."""
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # Always undo the patch, even if multTranspose raised: the bound
            # method AA.mult references AA, so leaving it installed would
            # create a context -> method -> context reference cycle.
            del AA.multTranspose
        self.assertTrue(y.equal(x))

    def testGetDiagonal(self):
        """The diagonal of the identity is all ones."""
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(1)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))
180
181
class TestDiagonal(TestMatrix):
    """Tests for a Python ``Mat`` backed by the ``Diagonal`` context.

    ``setUp`` fills the diagonal with global index + 1, i.e. 1, 2, ..., N.
    """

    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i + 1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        """With x = 1, A x must equal the stored diagonal."""
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        """multTranspose works while SYMMETRIC is set and raises once cleared."""
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testMultTransposeNewMeth(self):
        """A ``multTranspose`` attached to the context at run time is used."""
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # Always undo the patch, even if multTranspose raised: the bound
            # method AA.mult references AA, so leaving it installed would
            # create a context -> method -> context reference cycle.
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        """Scaling by 2 on the left and 3 on the right multiplies D by 6."""
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old * 6))

    def testCreateTranspose(self):
        """A transpose wrapper of symmetric A agrees with A in both products."""
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        try:
            x, y = A.createVecs()
            xt, yt = AT.createVecs()
            # A^T y computed through A must match AT applied to y.
            y.setRandom()
            A.multTranspose(y, x)
            y.copy(xt)
            AT.mult(xt, yt)
            self.assertTrue(yt.equal(x))
            # A x must match the transpose product of AT.
            x.setRandom()
            A.mult(x, y)
            x.copy(yt)
            AT.multTranspose(yt, xt)
            self.assertTrue(xt.equal(y))
        finally:
            # AT is built on top of A. Destroy it explicitly, even if an
            # assertion failed and a traceback keeps this frame alive, so it
            # is gone before tearDown checks that destroying A releases the
            # Python context.
            AT.destroy()
265
266
# --------------------------------------------------------------------
268
if __name__ == '__main__':
    # Collect and run this module's TestCase classes via unittest's runner;
    # module='__main__' is the documented default, spelled out explicitly.
    unittest.main(module='__main__')
271