xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision 22fceea1769ca91bdb9988b063eaa3e47b647107)
15808f684SSatish Balayfrom petsc4py import PETSc
253022affSStefano Zampiniimport unittest, numpy
35808f684SSatish Balayfrom sys import getrefcount
45808f684SSatish Balay# --------------------------------------------------------------------
55808f684SSatish Balay
class Matrix(object):
    """Minimal context for a PETSc matrix of type ``python``.

    Only the lifecycle hooks are defined, and they do nothing; the
    operations TestMatrix exercises on a matrix using this context are
    expected to raise because no implementation is provided.
    """

    def __init__(self):
        """Create the context; it carries no state."""

    def create(self, mat):
        """Creation hook receiving the owning PETSc matrix *mat*; nothing to do."""

    def destroy(self, mat):
        """Destruction hook receiving the owning PETSc matrix *mat*; nothing to do."""
165808f684SSatish Balay
class ScaledIdentity(Matrix):
    """Python shell matrix representing the scaled identity ``s * I``.

    ``s`` starts from the class default 2.0 and is updated in place by
    ``scale`` and ``shift``.  The MatProduct hooks support the product
    types AB, AtB, ABt, PtAP, RARt and ABC with this matrix in any of the
    operand slots checked below.
    """

    s = 2.0

    def scale(self, mat, s):
        """Turn ``self.s * I`` into ``(self.s * s) * I``."""
        self.s = self.s * s

    def shift(self, mat, s):
        """Turn ``self.s * I`` into ``(self.s + s) * I``."""
        self.s = self.s + s

    def mult(self, mat, x, y):
        """Compute ``y = self.s * x``."""
        x.copy(y)
        y.scale(self.s)

    def getDiagonal(self, mat, vd):
        """Fill *vd* with the constant diagonal value ``self.s``."""
        vd.set(self.s)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        """Accept every product type; unsupported ones raise later in productSymbolic."""
        return True

    @staticmethod
    def _prepareLike(product, template):
        # Give *product* the type and sizes of *template*, assemble it, and
        # copy *template* into it (values are zeroed by the caller afterwards).
        product.setType(template.getType())
        product.setSizes(template.getSizes())
        product.setUp()
        product.assemble()
        template.copy(product)

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Symbolic phase: give *product* the layout of the requested product.

        *product* is built as a copy of a template with the right layout --
        the other operand, its transpose, or an auxiliary product stored in
        ``self.tmp`` (reused by ``productNumeric``) -- and is then zeroed.

        Raises RuntimeError when *mat* is not in one of the handled operand
        slots, or when *producttype* is not supported.
        """
        if producttype == 'AB':
            if mat is A:        # product = identity * B
                template = B
            elif mat is B:      # product = A * identity
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:        # product = identity^T * B
                template = B
            elif mat is B:      # product = A^T * identity
                template = PETSc.Mat()
                A.transpose(template)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:        # product = identity * B^T
                template = PETSc.Mat()
                B.transpose(template)
            elif mat is B:      # product = A * identity^T
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:        # product = P^T * identity * P  (P is B)
                self.tmp = PETSc.Mat()
                B.transposeMatMult(B, self.tmp)
                template = self.tmp
            elif mat is B:      # product = identity^T * A * identity
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:        # product = R * identity * R^t  (R is B)
                self.tmp = PETSc.Mat()
                B.matTransposeMult(B, self.tmp)
                template = self.tmp
            elif mat is B:      # product = identity * A * identity^T
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:        # product = identity * B * C
                left, right = B, C
            elif mat is B:      # product = A * identity * C
                left, right = A, C
            elif mat is C:      # product = A * B * identity
                left, right = A, B
            else:
                raise RuntimeError('wrong configuration')
            self.tmp = PETSc.Mat()
            left.matMult(right, self.tmp)
            template = self.tmp
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        self._prepareLike(product, template)
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Numeric phase: fill *product* with the values of the requested product.

        The product of the remaining operands is copied (or transposed) into
        *product*, which is then scaled by ``self.s``; when this matrix is the
        outer factor of PtAP/RARt it appears twice, so ``self.s**2`` is used.

        Raises RuntimeError for an unhandled operand slot or product type.
        """
        factor = self.s
        if producttype == 'AB':
            if mat is A:        # product = identity * B
                B.copy(product, structure=True)
            elif mat is B:      # product = A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:        # product = identity^T * B
                B.copy(product, structure=True)
            elif mat is B:      # product = A^T * identity
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:        # product = identity * B^T
                B.transpose(product)
            elif mat is B:      # product = A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:        # product = P^T * identity * P
                B.transposeMatMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:      # product = identity^T * A * identity
                A.copy(product, structure=True)
                factor = self.s**2
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:        # product = R * identity * R^t
                B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:      # product = identity * A * identity^T
                A.copy(product, structure=True)
                factor = self.s**2
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:        # product = identity * B * C
                B.matMult(C, self.tmp)
            elif mat is B:      # product = A * identity * C
                A.matMult(C, self.tmp)
            elif mat is C:      # product = A * B * identity
                A.matMult(B, self.tmp)
            else:
                raise RuntimeError('wrong configuration')
            self.tmp.copy(product, structure=True)
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        product.scale(factor)
212ee6c7c31SStefano Zampini
class Diagonal(Matrix):
    """Python shell matrix acting as a diagonal scaling by the Vec ``self.D``."""

    def create(self, mat):
        """Run the base hook, set up *mat*, and allocate ``self.D`` via ``mat.createVecLeft()``."""
        super(Diagonal, self).create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Destroy the diagonal Vec, then run the base hook."""
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        """D <- a * D."""
        self.D.scale(a)

    def shift(self, mat, a):
        """D <- D + a, entrywise."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Set every diagonal entry to zero."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """y <- x .* D (entrywise product)."""
        y.pointwiseMult(x, self.D)

    def getDiagonal(self, mat, vd):
        """Copy the diagonal into *vd*."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert or add *vd* into the diagonal.

        *im* is either a bool (True means add) or a ``PETSc.InsertMode``
        (INSERT_VALUES or ADD_VALUES); any other mode raises ValueError.
        """
        if isinstance(im, bool):
            accumulate = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            accumulate = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            accumulate = True
        else:
            raise ValueError('wrong InsertMode %d'% im)
        if accumulate:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """D <- vl .* D .* vr, skipping a scaling vector that evaluates false.

        NOTE(review): a falsy Vec here presumably means "no vector supplied"
        (null handle) -- confirm against petsc4py's MatPython glue.
        """
        for side in (vl, vr):
            if side:
                self.D.pointwiseMult(self.D, side)
2565808f684SSatish Balay
2575808f684SSatish Balay# --------------------------------------------------------------------
2585808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Exercise a ``python``-type PETSc matrix built from context class ``PYCLS``.

    With the base ``Matrix`` context no operation is implemented, so every
    operation test here expects the PETSc call to raise.  Subclasses change
    ``PYCLS`` and override the tests for the operations their context
    provides.  The ``getrefcount`` checks verify that the Mat holds exactly
    one reference to its Python context and drops it on destroy.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        """Return the Python context attached to ``self.A``."""
        return self.A.getPythonContext()

    def _createFromOptions(self, N):
        # Choose the context class through the options database.
        self.A.create(self.COMM)
        self.A.setSizes([N, N])
        self.A.setType('python')
        opts = PETSc.Options(self.A)
        opts['mat_python_type'] = '%s.%s' % (self.PYMOD, self.PYCLS)
        self.A.setFromOptions()
        self.A.setUp()
        del opts['mat_python_type']
        self.assertIsNotNone(self._getCtx())

    def _createFromContext(self, N):
        # Instantiate the context class directly and hand it to createPython.
        context = globals()[self.PYCLS]()
        self.A.createPython([N, N], context, comm=self.COMM)
        self.A.setUp()
        self.assertIs(self._getCtx(), context)
        # local name + reference held by the Mat + getrefcount's argument
        self.assertEqual(getrefcount(context), 3)
        del context
        # only the Mat's reference remains (+ getrefcount's argument)
        self.assertEqual(getrefcount(self._getCtx()), 2)

    def setUp(self):
        N = self.N = 10
        self.A = PETSc.Mat()
        if 0:  # command line way
            self._createFromOptions(N)
        else:  # python way
            self._createFromContext(N)

    def tearDown(self):
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy() # XXX
        self.A = None
        # after destroy only the local name (+ getrefcount's argument) is left
        self.assertEqual(getrefcount(ctx), 2)
        #import gc,pprint; pprint.pprint(gc.get_referrers(ctx))

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertIs(self._getCtx(), ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        self.assertRaises(Exception, self.A.zeroEntries)

    def testMult(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.mult, x, y)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.getDiagonal, d)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.setDiagonal, d)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.diagonalScale, x, y)
3315808f684SSatish Balay
class TestScaledIdentity(TestMatrix):
    """Tests for the ``ScaledIdentity`` context, i.e. the operator ``s * I``."""

    PYCLS = 'ScaledIdentity'

    def testMult(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x,y)
        self.assertTrue(y.equal(s*x))

    def testMultTransposeSymmKnown(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(s*x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        # The bound method stored on the context references the context
        # itself; always remove it, otherwise a failure here would also break
        # the reference-count assertions in tearDown (which then skip destroy).
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            del AA.multTranspose
        self.assertTrue(y.equal(s*x))

    def testGetDiagonal(self):
        s = self._getCtx().s
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(s)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testMatMat(self):
        s = self._getCtx().s
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()
        A = PETSc.Mat().create(self.COMM)
        A.setSizes(self.A.getSizes())
        A.setType(PETSc.Mat.Type.AIJ)
        A.setUp()
        A.setRandom(R)
        B = PETSc.Mat().create(self.COMM)
        B.setSizes(self.A.getSizes())
        B.setType(PETSc.Mat.Type.AIJ)
        B.setUp()
        B.setRandom(R)
        # explicit AIJ reference operator equal to s * I
        I = PETSc.Mat().create(self.COMM)
        I.setSizes(self.A.getSizes())
        I.setType(PETSc.Mat.Type.AIJ)
        I.setUp()
        I.assemble()
        I.shift(s)

        self.assertTrue(self.A.matMult(A).equal(I.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(I)))
        if self.A.getComm().Get_size() == 1:
            self.assertTrue(self.A.matTransposeMult(A).equal(I.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(I)))
        self.assertTrue(self.A.transposeMatMult(A).equal(I.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(I)))
        self.assertAlmostEqual((self.A.ptap(A) - I.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(I)).norm(), 0.0, places=5)
        if self.A.getComm().Get_size() == 1:
            self.assertAlmostEqual((self.A.rart(A) - I.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(I)).norm(), 0.0, places=5)
        self.assertAlmostEqual((self.A.matMatMult(A,B)-I.matMatMult(A,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(self.A,B)-A.matMatMult(I,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(B,self.A)-A.matMatMult(B,I)).norm(), 0.0, places=5)

    def testH2Opus(self):
        # Report as skipped (not silently passed) when the test cannot run.
        if not PETSc.Sys.hasExternalPackage("h2opus"):
            self.skipTest('PETSc built without h2opus')
        if self.A.getComm().Get_size() > 1:
            self.skipTest('h2opus test runs on a single process only')
        h = PETSc.Mat()

        # need transpose operation for norm estimation
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A,leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace((1,2,3),(10,20,30),self.A.getSize()[0],dtype=PETSc.RealType)
            h.createH2OpusFromMat(self.A,coords,leafsize=2)
            h.assemble()
            h.destroy()
        finally:
            # drop the self-referencing bound method so tearDown's
            # reference-count checks stay valid even if a step above failed
            del AA.multTranspose

    def testShift(self):
        sold = self._getCtx().s
        self.A.shift(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold - 0.5)

    def testScale(self):
        sold = self._getCtx().s
        self.A.scale(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold * -0.5)
444*22fceea1SStefano Zampini
class TestDiagonal(TestMatrix):
    """Tests for the ``Diagonal`` context, initialized with D = [1, 2, ..., N]."""

    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i+1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x,y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        # The bound method stored on the context references the context
        # itself; always remove it, otherwise a failure here would also break
        # the reference-count assertions in tearDown (which then skip destroy).
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old*6))

    def testCreateTranspose(self):
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        #
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        #
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
        del A

    def testConvert(self):
        self.assertTrue(self.A.convert(PETSc.Mat.Type.AIJ,PETSc.Mat()).equal(self.A))
        self.assertTrue(self.A.convert(PETSc.Mat.Type.BAIJ,PETSc.Mat()).equal(self.A))
        self.assertTrue(self.A.convert(PETSc.Mat.Type.SBAIJ,PETSc.Mat()).equal(self.A))
        self.assertTrue(self.A.convert(PETSc.Mat.Type.DENSE,PETSc.Mat()).equal(self.A))

    def testShift(self):
        old = self._getCtx().D.copy()
        self.A.shift(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(old-0.5))

    def testScale(self):
        old = self._getCtx().D.copy()
        self.A.scale(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(-0.5*old))
545*22fceea1SStefano Zampini
546*22fceea1SStefano Zampini
5475808f684SSatish Balay# --------------------------------------------------------------------
5485808f684SSatish Balay
if __name__ == '__main__':
    # When executed as a script, run every TestCase defined in this module.
    unittest.main()
551