xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision ebead697dbf761eb322f829370bbe90b3bd93fa3)
15808f684SSatish Balayfrom petsc4py import PETSc
253022affSStefano Zampiniimport unittest, numpy
35808f684SSatish Balayfrom sys import getrefcount
45808f684SSatish Balay# --------------------------------------------------------------------
55808f684SSatish Balay
class Matrix(object):
    """Bare python-Mat context that only provides the lifecycle hooks.

    No operation callbacks (mult, getDiagonal, ...) are defined, so every
    such operation on a Mat using this context is expected to fail.
    """

    def __init__(self):
        """Nothing to initialise."""

    def create(self, mat):
        """Called with the owning Mat when the context is attached; no-op."""

    def destroy(self, mat):
        """Called with the owning Mat when it is destroyed; no-op."""
165808f684SSatish Balay
class ScaledIdentity(Matrix):
    """Python Mat context representing ``s * I`` for a scalar factor ``s``.

    ``scale`` and ``shift`` update ``s`` directly. The ``product*``
    callbacks handle products in which this matrix is one operand. They form
    the product of the remaining operand(s) without the identity and then
    rescale by ``s``, or by ``s**2`` when the identity appears on both sides
    (this matrix as the inner factor of ``PtAP``/``RARt``).
    """

    s = 2.0  # class-level default; scale()/shift() store a per-instance value

    def scale(self, mat, s):
        self.s *= s

    def shift(self, mat, s):
        self.s += s

    def mult(self, mat, x, y):
        # y = s * x
        x.copy(y)
        y.scale(self.s)

    def duplicate(self, mat, op):
        """Return a new python Mat backed by a fresh ScaledIdentity.

        The factor is carried over, and the clone set up, only for
        ``COPY_VALUES``. Otherwise the clone keeps the class default ``s``.
        """
        clone = PETSc.Mat()
        clone_ctx = ScaledIdentity()
        clone.createPython(mat.getSizes(), clone_ctx, comm=mat.getComm())
        if op != PETSc.Mat.DuplicateOption.COPY_VALUES:
            return clone
        clone_ctx.s = self.s
        clone.setUp()
        return clone

    def getDiagonal(self, mat, vd):
        # every diagonal entry of s*I equals s
        vd.set(self.s)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        # Accept every product type here; unsupported ones are rejected in
        # productSymbolic/productNumeric with a RuntimeError.
        return True

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Give ``product`` the type, sizes and nonzero pattern of the result.

        The pattern is taken from a reference matrix computed without the
        identity factor. Its values are cleared at the end, because
        productNumeric fills them in. Temporaries that productNumeric reuses
        are kept in ``self.tmp``.
        """
        def take_layout_of(ref):
            # Configure `product` like `ref`, then copy ref's entries in.
            product.setType(ref.getType())
            product.setSizes(ref.getSizes())
            product.setUp()
            product.assemble()
            ref.copy(product)

        if producttype == 'AB':
            if mat is A:    # identity * B
                take_layout_of(B)
            elif mat is B:  # A * identity
                take_layout_of(A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:    # identity^T * B
                take_layout_of(B)
            elif mat is B:  # A^T * identity
                At = PETSc.Mat()
                A.transpose(At)
                take_layout_of(At)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:    # identity * B^T
                Bt = PETSc.Mat()
                B.transpose(Bt)
                take_layout_of(Bt)
            elif mat is B:  # A * identity^T
                take_layout_of(A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:    # P^T * identity * P
                self.tmp = PETSc.Mat()
                B.transposeMatMult(B, self.tmp)
                take_layout_of(self.tmp)
            elif mat is B:  # identity^T * A * identity
                take_layout_of(A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:    # R * identity * R^T
                self.tmp = PETSc.Mat()
                B.matTransposeMult(B, self.tmp)
                take_layout_of(self.tmp)
            elif mat is B:  # identity * A * identity^T
                take_layout_of(A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            # Drop the identity and multiply the two remaining factors.
            if mat is A:    # identity * B * C
                left, right = B, C
            elif mat is B:  # A * identity * C
                left, right = A, C
            elif mat is C:  # A * B * identity
                left, right = A, B
            else:
                raise RuntimeError('wrong configuration')
            self.tmp = PETSc.Mat()
            left.matMult(right, self.tmp)
            take_layout_of(self.tmp)
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Fill ``product`` with the values of the requested product.

        The values are formed without the identity factor, and the result is
        then multiplied by ``s``. The factor is ``s**2`` for PtAP/RARt when
        this matrix is the inner operand (identity on both sides).
        """
        squared = False
        if producttype == 'AB':
            if mat is A:    # identity * B
                B.copy(product, structure=True)
            elif mat is B:  # A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:    # identity^T * B
                B.copy(product, structure=True)
            elif mat is B:  # A^T * identity
                A.setTransposePrecursor(product)
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:    # identity * B^T
                B.setTransposePrecursor(product)
                B.transpose(product)
            elif mat is B:  # A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:    # P^T * identity * P
                B.transposeMatMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:  # identity^T * A * identity
                A.copy(product, structure=True)
                squared = True
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:    # R * identity * R^T
                B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:  # identity * A * identity^T
                A.copy(product, structure=True)
                squared = True
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:    # identity * B * C
                left, right = B, C
            elif mat is B:  # A * identity * C
                left, right = A, C
            elif mat is C:  # A * B * identity
                left, right = A, B
            else:
                raise RuntimeError('wrong configuration')
            left.matMult(right, self.tmp)
            self.tmp.copy(product, structure=True)
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        product.scale(self.s**2 if squared else self.s)
223ee6c7c31SStefano Zampini
class Diagonal(Matrix):
    """Python Mat context for a diagonal operator stored as one Vec ``D``.

    Each matrix operation is forwarded to the matching Vec operation on ``D``.
    """

    def create(self, mat):
        super(Diagonal, self).create(mat)
        mat.setUp()
        # D takes the layout of mat's left vector (see createVecLeft)
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        self.D.scale(a)

    def shift(self, mat, a):
        self.D.shift(a)

    def zeroEntries(self, mat):
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        # y_i = x_i * D_i
        y.pointwiseMult(x, self.D)

    def duplicate(self, mat, op):
        """Return a new python Mat backed by a Diagonal with its own ``D``.

        Entries of ``D`` are copied, and the clone set up, only for
        ``COPY_VALUES``.
        """
        clone = PETSc.Mat()
        clone_ctx = Diagonal()
        clone.createPython(mat.getSizes(), clone_ctx, comm=mat.getComm())
        clone_ctx.D = self.D.duplicate()
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            self.D.copy(clone_ctx.D)
            clone.setUp()
        return clone

    def getDiagonal(self, mat, vd):
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert (overwrite) or add ``vd`` into ``D``.

        ``im`` is either a bool (True means add) or a ``PETSc.InsertMode``.
        The bool case must be tested first, because ``True`` compares equal
        to the integer value of an InsertMode.
        """
        if isinstance(im, bool):
            add = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            add = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            add = True
        else:
            raise ValueError('wrong InsertMode %d'% im)
        if add:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        # D <- diag(vl) * D * diag(vr). Each side is applied only when its
        # Vec is truthy; presumably an absent scaling vector arrives as a
        # falsy (null) Vec handle -- NOTE(review): confirm with libpetsc4py.
        for side in (vl, vr):
            if side:
                self.D.pointwiseMult(self.D, side)
2775808f684SSatish Balay
2785808f684SSatish Balay# --------------------------------------------------------------------
2795808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Exercise a PETSc 'python' Mat built around the context class ``PYCLS``.

    With the bare ``Matrix`` context no operation callbacks exist, so the
    operation tests here expect PETSc to raise. Subclasses set ``PYCLS`` to a
    richer context and override the tests whose behavior that context
    provides.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        return self.A.getPythonContext()

    def setUp(self):
        N = self.N = 13
        self.A = PETSc.Mat()
        if 0: # command line way (kept for reference; never executed)
            self.A.create(self.COMM)
            self.A.setSizes([N,N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD,self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else: # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N,N], context, comm=self.COMM)
            self.A.setUp()
            self.assertTrue(self._getCtx() is context)
            # References: `context`, getrefcount's own argument, and one more
            # holder (presumably the Mat itself).
            self.assertEqual(getrefcount(context), 3)
            del context
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        ctx = self.A.getPythonContext()
        # Tests must not leave extra references to the context behind.
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy() # XXX
        self.A = None
        self.assertEqual(getrefcount(ctx), 2)

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertTrue(self._getCtx() is ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        f = lambda : self.A.zeroEntries()
        self.assertRaises(Exception, f)

    def testMult(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.mult(x, y)
        self.assertRaises(Exception, f)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.getDiagonal(d)
        self.assertRaises(Exception, f)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.setDiagonal(d)
        self.assertRaises(Exception, f)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.diagonalScale(x, y)
        self.assertRaises(Exception, f)

    def testDuplicate(self):
        # The bare context defines no duplicate() callback, so PETSc must
        # refuse both the copy-values and the structure-only variants.
        # (Call Mat.duplicate directly: the previous lambdas referenced an
        # undefined name and only "passed" because of the NameError.)
        self.assertRaises(Exception, self.A.duplicate, True)
        self.assertRaises(Exception, self.A.duplicate, False)

    def testSetVecType(self):
        self.A.setVecType('mpi')
        self.assertTrue('mpi' == self.A.getVecType())

    def testH2Opus(self):
        if not PETSc.Sys.hasExternalPackage("h2opus"):
            return
        if self.A.getComm().Get_size() > 1:
            return
        h = PETSc.Mat()

        # need matrix vector and its transpose for norm estimation
        AA = self.A.getPythonContext()
        if not hasattr(AA,'mult'):
            return
        # The bound method stored on the instance references AA. Remove it
        # even if something below fails; otherwise tearDown's refcount check
        # fails before A is destroyed.
        AA.multTranspose = AA.mult
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A,leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace((1,2,3),(10,20,30),self.A.getSize()[0],dtype=PETSc.RealType)
            h.createH2OpusFromMat(self.A,coords,leafsize=2)
            h.assemble()

            # test API
            h.H2OpusOrthogonalize()
            h.H2OpusCompress(1.e-1)

            # Low-rank update: compare against a dense reference h + U*U^T
            U = PETSc.Mat()
            U.createDense([h.getSizes()[0],3],comm=h.getComm())
            U.setUp()
            U.setRandom()

            he = PETSc.Mat()
            h.convert('dense',he)
            he.axpy(1.0, U.matTransposeMult(U))

            h.H2OpusLowRankUpdate(U)
            self.assertTrue(he.equal(h))

            h.destroy()
        finally:
            del AA.multTranspose

    def testGetType(self):
        ctx = self.A.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.A.getPythonType(), pytype)
412300d917bSStefano Zampini
class TestScaledIdentity(TestMatrix):
    """Run the TestMatrix checks against the ``ScaledIdentity`` (s*I) context."""

    PYCLS = 'ScaledIdentity'

    def testMult(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x,y)
        self.assertTrue(y.equal(s*x))

    def testMultTransposeSymmKnown(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        # The context has no multTranspose: it must work while the matrix is
        # flagged SYMMETRIC and fail once the flag is cleared.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(s*x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testMultTransposeNewMeth(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        # Temporarily alias mult as multTranspose on the instance. The bound
        # method keeps a reference to AA, so always remove it; otherwise
        # tearDown's refcount assertion fails before A is destroyed.
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            del AA.multTranspose
        self.assertTrue(y.equal(s*x))

    def testGetDiagonal(self):
        s = self._getCtx().s
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(s)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testDuplicate(self):
        # Without COPY_VALUES the clone keeps the class default factor (2.0).
        B = self.A.duplicate(False)
        self.assertEqual(B.getPythonContext().s, 2)
        B = self.A.duplicate(True)
        self.assertEqual(B.getPythonContext().s, self.A.getPythonContext().s)

    def testMatMat(self):
        s = self._getCtx().s
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()
        A = PETSc.Mat().create(self.COMM)
        A.setSizes(self.A.getSizes())
        A.setType(PETSc.Mat.Type.AIJ)
        A.setUp()
        A.setRandom(R)
        B = PETSc.Mat().create(self.COMM)
        B.setSizes(self.A.getSizes())
        B.setType(PETSc.Mat.Type.AIJ)
        B.setUp()
        B.setRandom(R)
        # Explicit AIJ copy of s*I, used as the reference operand.
        I = PETSc.Mat().create(self.COMM)
        I.setSizes(self.A.getSizes())
        I.setType(PETSc.Mat.Type.AIJ)
        I.setUp()
        I.assemble()
        I.shift(s)

        self.assertTrue(self.A.matMult(A).equal(I.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(I)))
        if self.A.getComm().Get_size() == 1:
            self.assertTrue(self.A.matTransposeMult(A).equal(I.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(I)))
        self.assertTrue(self.A.transposeMatMult(A).equal(I.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(I)))
        self.assertAlmostEqual((self.A.ptap(A) - I.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(I)).norm(), 0.0, places=5)
        if self.A.getComm().Get_size() == 1:
            self.assertAlmostEqual((self.A.rart(A) - I.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(I)).norm(), 0.0, places=5)
        self.assertAlmostEqual((self.A.matMatMult(A,B)-I.matMatMult(A,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(self.A,B)-A.matMatMult(I,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(B,self.A)-A.matMatMult(B,I)).norm(), 0.0, places=5)

    def testShift(self):
        sold = self._getCtx().s
        self.A.shift(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold - 0.5)

    def testScale(self):
        sold = self._getCtx().s
        self.A.scale(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold * -0.5)
50722fceea1SStefano Zampini
class TestDiagonal(TestMatrix):
    """Tests for a Python-context matrix whose context stores a diagonal.

    The context is reached through ``self._getCtx()`` (or
    ``self.A.getPythonContext()``) and keeps the diagonal entries in a Vec
    attribute ``D``.  Every check compares a PETSc operation on ``self.A``
    against that vector.
    """

    # Name of the Python context class.  Presumably consumed by
    # TestMatrix.setUp to build self.A -- confirm in the base class.
    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        # Put global index + 1 on the locally owned diagonal entries, which
        # makes every entry distinct and nonzero.
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i+1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        # zeroEntries on the matrix must clear the stored diagonal.
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        # For a diagonal matrix, A * ones == D.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x,y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        # While A is flagged SYMMETRIC, multTranspose must succeed and match
        # A * ones == D.  Once the flag is cleared, the same call is
        # expected to raise.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        # A multTranspose method attached to the context at run time must be
        # used by A.multTranspose.  The temporary instance attribute is
        # removed in ``finally`` so it does not outlive a failing call.
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testDuplicate(self):
        # A duplicate made without copying values must still have A's global
        # size.  A duplicate made with copied values must also carry an
        # equal diagonal.
        Bshallow = self.A.duplicate(False)
        self.assertEqual(Bshallow.getSize(), self.A.getSize())
        Bdeep = self.A.duplicate(True)
        self.assertTrue(Bdeep.getPythonContext().D.equal(self.A.getPythonContext().D))

    def testGetDiagonal(self):
        # getDiagonal must return the stored diagonal.
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        # setDiagonal must overwrite the stored diagonal.
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        # Scaling with constant left (2) and right (3) vectors multiplies
        # each diagonal entry by 2*3 == 6.
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old*6))

    def testCreateTranspose(self):
        # A is flagged SYMMETRIC so that A.multTranspose is allowed (see
        # testMultTransposeSymmKnown).  AT.mult must then match
        # A.multTranspose, and AT.multTranspose must match A.mult.
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        #
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        #
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))

    def testConvert(self):
        # Converting to each assembled format must preserve the matrix.
        # SBAIJ is valid here because a diagonal matrix is symmetric.
        for mtype in (PETSc.Mat.Type.AIJ, PETSc.Mat.Type.BAIJ,
                      PETSc.Mat.Type.SBAIJ, PETSc.Mat.Type.DENSE):
            B = self.A.convert(mtype, PETSc.Mat())
            self.assertTrue(B.equal(self.A), msg=mtype)

    def testShift(self):
        # Shifting the matrix by -0.5 must subtract 0.5 from every diagonal
        # entry.
        old = self._getCtx().D.copy()
        self.A.shift(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(old-0.5))

    def testScale(self):
        # Scaling the matrix by -0.5 must scale the diagonal by -0.5.
        old = self._getCtx().D.copy()
        self.A.scale(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(-0.5*old))
61322fceea1SStefano Zampini
61422fceea1SStefano Zampini
6155808f684SSatish Balay# --------------------------------------------------------------------
6165808f684SSatish Balay
if __name__ == '__main__':
    # When executed as a script, discover and run every TestCase defined in
    # this module (module='__main__' is unittest.main's default, spelled
    # out here for clarity).
    unittest.main(module='__main__')
619