xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision b69d2765e9c8cde17308f61ee677dcd992b8a9cf)
15808f684SSatish Balayfrom petsc4py import PETSc
26f336411SStefano Zampiniimport unittest
36f336411SStefano Zampiniimport numpy
45808f684SSatish Balayfrom sys import getrefcount
55808f684SSatish Balay# --------------------------------------------------------------------
65808f684SSatish Balay
75808f684SSatish Balay
class Matrix:
    """Minimal Python context for a ``PETSc.Mat`` of type ``'python'``.

    It provides only the lifecycle hooks (``create``, ``destroy``,
    ``setUp``) and deliberately defines no operator methods such as
    ``mult``.
    """

    # Class-level default; ``setUp`` shadows it with a per-instance counter.
    setupcalled = 0

    def __init__(self):
        """Build a context that holds no state of its own."""

    def create(self, mat):
        """Hook run when the context is attached to ``mat``; does nothing."""

    def destroy(self, mat):
        """Hook run when ``mat`` releases the context; does nothing."""

    def setUp(self, mat):
        """Count one more invocation of the setup hook."""
        self.setupcalled = self.setupcalled + 1
2222fceea1SStefano Zampini
23f575958eSStefano Zampini
class ScaledIdentity(Matrix):
    """Python Mat context representing the operator ``s * I``.

    The factor ``s`` defaults to 2.0 and is changed in place by ``scale``
    and ``shift``.  The ``product*`` hooks let PETSc form matrix products in
    which this matrix is one of the operands ``A``, ``B`` or ``C``; the
    ``mat`` argument identifies which one.
    """

    s = 2.0

    def scale(self, mat, s):
        """Multiply the factor by ``s``."""
        self.s *= s

    def shift(self, mat, s):
        """Add ``s`` to the factor, since ``c*I + s*I == (c + s)*I``."""
        self.s += s

    def mult(self, mat, x, y):
        """Compute ``y = s * x``."""
        x.copy(y)
        y.scale(self.s)

    def duplicate(self, mat, op):
        """Return a new python Mat backed by a fresh ``ScaledIdentity``.

        The factor (and a ``setUp`` call) is carried over only for
        ``COPY_VALUES``; otherwise the copy keeps the class default ``s``.
        """
        dmat = PETSc.Mat()
        dctx = ScaledIdentity()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            dctx.s = self.s
            dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Set every entry of ``vd`` to the factor."""
        vd.set(self.s)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        """Claim support for every product type.

        Types other than AB, AtB, ABt, PtAP, RARt and ABC are rejected later
        by ``productSymbolic``.
        """
        return True

    @staticmethod
    def _setupProductLike(product, M, temporary=False):
        """Give ``product`` the type, sizes and entries of ``M``.

        When ``temporary`` is true, ``M`` was built only for this purpose and
        is destroyed right after the copy instead of being left to garbage
        collection.
        """
        product.setType(M.getType())
        product.setSizes(M.getSizes())
        product.setUp()
        product.assemble()
        M.copy(product)
        if temporary:
            M.destroy()

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Give ``product`` the type, sizes and nonzero structure of the result.

        The identity operand is treated as ``I`` here; the factor ``s`` is
        applied in ``productNumeric``.  ``product`` is zeroed on return.

        Raises:
            RuntimeError: if ``mat`` is none of the operands used by the
                product type, or if the product type is not supported.
        """
        setup = self._setupProductLike
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                setup(product, B)
            elif mat is B:  # product = A * identity
                setup(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                setup(product, B)
            elif mat is B:  # product = A^T * identity
                tmp = PETSc.Mat()
                A.transpose(tmp)
                setup(product, tmp, temporary=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                tmp = PETSc.Mat()
                B.transpose(tmp)
                setup(product, tmp, temporary=True)
            elif mat is B:  # product = A * identity^T
                setup(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                tmp = PETSc.Mat()
                B.transposeMatMult(B, tmp)
                setup(product, tmp, temporary=True)
            elif mat is B:  # product = identity^T * A * identity
                setup(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                tmp = PETSc.Mat()
                B.matTransposeMult(B, tmp)
                setup(product, tmp, temporary=True)
            elif mat is B:  # product = identity * A * identity^T
                setup(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:  # product = identity * B * C
                tmp = PETSc.Mat()
                B.matMult(C, tmp)
                setup(product, tmp, temporary=True)
            elif mat is B:  # product = A * identity * C
                tmp = PETSc.Mat()
                A.matMult(C, tmp)
                setup(product, tmp, temporary=True)
            elif mat is C:  # product = A * B * identity
                tmp = PETSc.Mat()
                A.matMult(B, tmp)
                setup(product, tmp, temporary=True)
            else:
                raise RuntimeError('wrong configuration')
        else:
            raise RuntimeError(f'Product {producttype} not implemented')
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Fill ``product`` with the values of the requested product.

        The identity operand contributes a factor ``s``.  In PtAP and RARt
        with ``mat is B`` it appears on both sides, hence ``s**2``.

        Raises:
            RuntimeError: in the same situations as ``productSymbolic``.
        """
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A^T * identity
                # product was not created by A.transpose(); registering A as
                # its transpose precursor is presumably what lets transpose()
                # refill it in place -- see MatTransposeSetPrecursor.
                A.setTransposePrecursor(product)
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                B.setTransposePrecursor(product)
                B.transpose(product)
            elif mat is B:  # product = A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                tmp = PETSc.Mat()
                B.transposeMatMult(B, tmp)
                tmp.copy(product, structure=True)
                tmp.destroy()
                product.scale(self.s)
            elif mat is B:  # product = identity^T * A * identity
                A.copy(product, structure=True)
                product.scale(self.s**2)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                tmp = PETSc.Mat()
                B.matTransposeMult(B, tmp)
                tmp.copy(product, structure=True)
                tmp.destroy()
                product.scale(self.s)
            elif mat is B:  # product = identity * A * identity^T
                A.copy(product, structure=True)
                product.scale(self.s**2)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            tmp = PETSc.Mat()
            if mat is A:  # product = identity * B * C
                B.matMult(C, tmp)
            elif mat is B:  # product = A * identity * C
                A.matMult(C, tmp)
            elif mat is C:  # product = A * B * identity
                A.matMult(B, tmp)
            else:
                raise RuntimeError('wrong configuration')
            tmp.copy(product, structure=True)
            tmp.destroy()
            product.scale(self.s)
        else:
            raise RuntimeError(f'Product {producttype} not implemented')
2336f336411SStefano Zampini            raise RuntimeError(f'Product {producttype} not implemented')
2346f336411SStefano Zampini
235ee6c7c31SStefano Zampini
class Diagonal(Matrix):
    """Python Mat context for a diagonal matrix stored in the Vec ``D``."""

    def create(self, mat):
        """Set up ``mat`` and allocate ``D`` with the layout of its rows."""
        super().create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Destroy ``D`` and then run the base-class hook."""
        self.D.destroy()
        super().destroy(mat)

    def scale(self, mat, a):
        """Multiply every diagonal entry by ``a``."""
        self.D.scale(a)

    def shift(self, mat, a):
        """Add ``a`` to every diagonal entry."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Set every diagonal entry to zero."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """Compute ``y = D .* x`` (entrywise product)."""
        y.pointwiseMult(x, self.D)

    def duplicate(self, mat, op):
        """Return a new python Mat backed by a fresh ``Diagonal`` context.

        The copy always gets its own ``D`` with the layout of ``self.D``; the
        entries (and a ``setUp`` call) are carried over only for
        ``COPY_VALUES``.
        """
        dmat = PETSc.Mat()
        dctx = Diagonal()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        # createPython() presumably already ran dctx.create(), which allocated
        # a D of its own (TODO confirm).  Destroy it explicitly before it is
        # replaced -- matching the explicit lifecycle in destroy() -- instead
        # of leaving it to garbage collection, which petsc4py may defer for
        # parallel objects.
        stale = getattr(dctx, 'D', None)
        if stale is not None:
            stale.destroy()
        dctx.D = self.D.duplicate()
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            self.D.copy(dctx.D)
            dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Copy the diagonal entries into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert (overwrite) or add ``vd`` into the diagonal.

        ``im`` is either a bool (True means add) or a ``PETSc.InsertMode``.

        Raises:
            ValueError: for any other insert mode.
        """
        # bool is checked first because it is also an int subclass
        if isinstance(im, bool):
            add = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            add = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            add = True
        else:
            raise ValueError('wrong InsertMode %d' % im)
        if add:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """Replace ``D`` by ``vl .* D .* vr``; either vector may be absent.

        Absence is tested by truthiness rather than ``is None``; presumably a
        missing vector arrives as a Vec wrapping a null handle (falsy) --
        TODO confirm.
        """
        if vl:
            self.D.pointwiseMult(self.D, vl)
        if vr:
            self.D.pointwiseMult(self.D, vr)
2896f336411SStefano Zampini            self.D.pointwiseMult(self.D, vr)
2906f336411SStefano Zampini
2915808f684SSatish Balay
2925808f684SSatish Balay# --------------------------------------------------------------------
2935808f684SSatish Balay
2945808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Tests for a python-type ``PETSc.Mat`` whose context class is ``PYCLS``.

    With the default ``Matrix`` context no operator is implemented, so most
    tests here expect the operation to raise; subclasses select a richer
    context through ``PYCLS`` and override those tests.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'
    # Create the Mat with a None context and attach the context afterwards.
    CREATE_WITH_NONE = False
    # Build the Mat through the options database ('mat_python_type')
    # instead of handing the context object to createPython().
    CREATE_FROM_OPTIONS = False

    def _getCtx(self):
        """Return the python context attached to ``self.A``."""
        return self.A.getPythonContext()

    def setUp(self):
        """Create the ``N x N`` python Mat ``self.A`` with ``N == 13``."""
        N = self.N = 13
        self.A = PETSc.Mat()
        if self.CREATE_FROM_OPTIONS:  # command line way
            self.A.create(self.COMM)
            self.A.setSizes([N, N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = f'{self.PYMOD}.{self.PYCLS}'
            self.A.setFromOptions()
            del OptDB['mat_python_type']
            self.assertIsNotNone(self._getCtx())
        else:  # python way
            context = globals()[self.PYCLS]()
            if self.CREATE_WITH_NONE:  # test passing None as context
                self.A.createPython([N, N], None, comm=self.COMM)
                self.A.setPythonContext(context)
                self.A.setUp()
            else:
                self.A.createPython([N, N], context, comm=self.COMM)
            self.assertIs(self._getCtx(), context)
            # references: the local name, the Mat, getrefcount's argument
            self.assertEqual(getrefcount(context), 3)
            del context
            # after dropping the local only the Mat keeps the context alive
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        """Check that no stray context references remain; destroy ``self.A``.

        The Mat is destroyed and PETSc garbage is collected even when the
        reference-count check fails.
        """
        try:
            self.assertEqual(getrefcount(self._getCtx()), 2)
        finally:
            # destroy explicitly rather than relying on Python GC
            self.A.destroy()
            self.A = None
            PETSc.garbage_cleanup()

    def testBasic(self):
        self.assertEqual(getrefcount(self._getCtx()), 2)
        ctx = self.A.getPythonContext()
        self.assertIs(self._getCtx(), ctx)

    def testSetUp(self):
        ctx = self.A.getPythonContext()
        before = ctx.setupcalled
        # the Mat is already set up, so the context hook must not run again
        self.A.setUp()
        self.assertEqual(ctx.setupcalled, before)
        # after re-attaching the context, setUp must call the hook once more
        self.A.setPythonContext(ctx)
        self.A.setUp()
        self.assertEqual(ctx.setupcalled, before + 1)

    def testZeroEntries(self):
        with self.assertRaises(Exception):
            self.A.zeroEntries()

    def testMult(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.mult(x, y)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        with self.assertRaises(Exception):
            self.A.getDiagonal(d)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        with self.assertRaises(Exception):
            self.A.setDiagonal(d)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.diagonalScale(x, y)

    def testDuplicate(self):
        with self.assertRaises(Exception):
            self.A.duplicate(True)
        with self.assertRaises(Exception):
            self.A.duplicate(False)

    def testSetVecType(self):
        self.A.setVecType('mpi')
        self.assertEqual(self.A.getVecType(), 'mpi')

    def testH2Opus(self):
        if not PETSc.Sys.hasExternalPackage('h2opus'):
            self.skipTest('PETSc was built without h2opus')
        if self.A.getComm().Get_size() > 1:
            self.skipTest('h2opus test runs on a single process only')
        # need matrix vector and its transpose for norm estimation
        # (checked before binding a local so a skip leaves no extra reference)
        if not hasattr(self._getCtx(), 'mult'):
            self.skipTest('python context does not implement mult')
        AA = self.A.getPythonContext()
        h = PETSc.Mat()

        # The bound method stored on AA references AA itself; it must be
        # removed even if the test fails, otherwise tearDown's refcount
        # check fails too and hides the real error.
        AA.multTranspose = AA.mult
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A, leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace(
                (1, 2, 3), (10, 20, 30), self.A.getSize()[0], dtype=PETSc.RealType
            )
            h.createH2OpusFromMat(self.A, coords, leafsize=2)
            h.assemble()

            # test API
            h.H2OpusOrthogonalize()
            h.H2OpusCompress(1.0e-1)

            # Low-rank update
            U = PETSc.Mat()
            U.createDense([h.getSizes()[0], 3], comm=h.getComm())
            U.setUp()
            U.setRandom()

            he = PETSc.Mat()
            h.convert('dense', he)
            he.axpy(1.0, U.matTransposeMult(U))

            h.H2OpusLowRankUpdate(U)
            self.assertTrue(he.equal(h))
        finally:
            h.destroy()
            del AA.multTranspose

    def testGetType(self):
        ctx = self.A.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.A.getPythonType(), pytype)
437ebead697SStefano Zampini        self.assertTrue(self.A.getPythonType() == pytype)
438300d917bSStefano Zampini
4395808f684SSatish Balay
class TestScaledIdentity(TestMatrix):
    """Run the python-Mat tests against the ``ScaledIdentity`` context."""

    PYCLS = 'ScaledIdentity'

    def testMult(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(s * x))

    def testMultTransposeSymmKnown(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        # the context has no multTranspose: it must succeed only while the
        # matrix is flagged symmetric
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(s * x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # the stored bound method references AA; drop it even on failure
            # so tearDown's refcount check sees only the Mat's reference
            del AA.multTranspose
        self.assertTrue(y.equal(s * x))

    def testGetDiagonal(self):
        s = self._getCtx().s
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(s)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testDuplicate(self):
        s = self._getCtx().s
        DupOpt = PETSc.Mat.DuplicateOption
        # (arguments to duplicate(), expected factor of the copy); only
        # copying the values carries the factor, otherwise the default 2
        cases = [
            ((), 2),
            ((False,), 2),
            ((True,), s),
            ((DupOpt.DO_NOT_COPY_VALUES,), 2),
            ((DupOpt.SHARE_NONZERO_PATTERN,), 2),
            ((DupOpt.COPY_VALUES,), s),
        ]
        for args, expected in cases:
            with self.subTest(args=args):
                B = self.A.duplicate(*args)
                self.assertEqual(B.getPythonContext().s, expected)

    def _randomAIJ(self, R):
        """Return a random AIJ matrix with the sizes of ``self.A``."""
        M = PETSc.Mat().create(self.COMM)
        M.setSizes(self.A.getSizes())
        M.setType(PETSc.Mat.Type.AIJ)
        M.setPreallocationNNZ(None)
        M.setRandom(R)
        return M

    def testMatMat(self):
        s = self._getCtx().s
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()
        A = self._randomAIJ(R)
        B = self._randomAIJ(R)
        # explicit AIJ representation of s * I to compare against
        Id = PETSc.Mat().create(self.COMM)
        Id.setSizes(self.A.getSizes())
        Id.setType(PETSc.Mat.Type.AIJ)
        Id.setUp()
        Id.assemble()
        Id.shift(s)

        # matTransposeMult and rart are only exercised on one process
        serial = self.A.getComm().Get_size() == 1

        self.assertTrue(self.A.matMult(A).equal(Id.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(Id)))
        if serial:
            self.assertTrue(self.A.matTransposeMult(A).equal(Id.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(Id)))
        self.assertTrue(self.A.transposeMatMult(A).equal(Id.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(Id)))
        self.assertAlmostEqual((self.A.ptap(A) - Id.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(Id)).norm(), 0.0, places=5)
        if serial:
            self.assertAlmostEqual((self.A.rart(A) - Id.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(Id)).norm(), 0.0, places=5)
        self.assertAlmostEqual(
            (self.A.matMatMult(A, B) - Id.matMatMult(A, B)).norm(), 0.0, places=5
        )
        self.assertAlmostEqual(
            (A.matMatMult(self.A, B) - A.matMatMult(Id, B)).norm(), 0.0, places=5
        )
        self.assertAlmostEqual(
            (A.matMatMult(B, self.A) - A.matMatMult(B, Id)).norm(), 0.0, places=5
        )

    def testShift(self):
        sold = self._getCtx().s
        self.A.shift(-0.5)
        self.assertEqual(self._getCtx().s, sold - 0.5)

    def testScale(self):
        sold = self._getCtx().s
        self.A.scale(-0.5)
        self.assertEqual(self._getCtx().s, sold * -0.5)

    def testDiagonalMat(self):
        s = self._getCtx().s
        B = PETSc.Mat().createConstantDiagonal(
            self.A.getSizes(), s, comm=self.A.getComm()
        )
        self.assertTrue(self.A.equal(B))
5539e7eb791SStefano Zampini        self.assertTrue(self.A.equal(B))
5549e7eb791SStefano Zampini
5555808f684SSatish Balay
5566f336411SStefano Zampiniclass TestDiagonal(TestMatrix):
5575808f684SSatish Balay    PYCLS = 'Diagonal'
558b2584804SStefano Zampini    CREATE_WITH_NONE = True
5595808f684SSatish Balay
5605808f684SSatish Balay    def setUp(self):
5616f336411SStefano Zampini        super().setUp()
5625808f684SSatish Balay        D = self.A.createVecLeft()
5635808f684SSatish Balay        s, e = D.getOwnershipRange()
5645808f684SSatish Balay        for i in range(s, e):
5655808f684SSatish Balay            D[i] = i + 1
5665808f684SSatish Balay        D.assemble()
5675808f684SSatish Balay        self.A.setDiagonal(D)
5685808f684SSatish Balay
5695808f684SSatish Balay    def testZeroEntries(self):
5705808f684SSatish Balay        self.A.zeroEntries()
5715808f684SSatish Balay        D = self._getCtx().D
5725808f684SSatish Balay        self.assertEqual(D.norm(), 0)
5735808f684SSatish Balay
5745808f684SSatish Balay    def testMult(self):
5755808f684SSatish Balay        x, y = self.A.createVecs()
5765808f684SSatish Balay        x.set(1)
5775808f684SSatish Balay        self.A.mult(x, y)
5785808f684SSatish Balay        self.assertTrue(y.equal(self._getCtx().D))
5795808f684SSatish Balay
5805808f684SSatish Balay    def testMultTransposeSymmKnown(self):
5815808f684SSatish Balay        x, y = self.A.createVecs()
5825808f684SSatish Balay        x.set(1)
5835808f684SSatish Balay        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
5845808f684SSatish Balay        self.A.multTranspose(x, y)
5855808f684SSatish Balay        self.assertTrue(y.equal(self._getCtx().D))
5865808f684SSatish Balay        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
5875808f684SSatish Balay        f = lambda: self.A.multTranspose(x, y)
5885808f684SSatish Balay        self.assertRaises(Exception, f)
5895808f684SSatish Balay
5905808f684SSatish Balay    def testMultTransposeNewMeth(self):
5915808f684SSatish Balay        x, y = self.A.createVecs()
5925808f684SSatish Balay        x.set(1)
5935808f684SSatish Balay        AA = self.A.getPythonContext()
5945808f684SSatish Balay        AA.multTranspose = AA.mult
5955808f684SSatish Balay        self.A.multTranspose(x, y)
5965808f684SSatish Balay        del AA.multTranspose
5975808f684SSatish Balay        self.assertTrue(y.equal(self._getCtx().D))
5985808f684SSatish Balay
599e124b1b1SStefano Zampini    def testDuplicate(self):
600e124b1b1SStefano Zampini        B = self.A.duplicate(False)
601e124b1b1SStefano Zampini        B = self.A.duplicate(True)
602e124b1b1SStefano Zampini        self.assertTrue(B.getPythonContext().D.equal(self.A.getPythonContext().D))
603e124b1b1SStefano Zampini
6045808f684SSatish Balay    def testGetDiagonal(self):
6055808f684SSatish Balay        d = self.A.createVecLeft()
6065808f684SSatish Balay        self.A.getDiagonal(d)
6075808f684SSatish Balay        self.assertTrue(d.equal(self._getCtx().D))
6085808f684SSatish Balay
6095808f684SSatish Balay    def testSetDiagonal(self):
6105808f684SSatish Balay        d = self.A.createVecLeft()
6115808f684SSatish Balay        d.setRandom()
6125808f684SSatish Balay        self.A.setDiagonal(d)
6135808f684SSatish Balay        self.assertTrue(d.equal(self._getCtx().D))
6145808f684SSatish Balay
6155808f684SSatish Balay    def testDiagonalScale(self):
6165808f684SSatish Balay        x, y = self.A.createVecs()
6175808f684SSatish Balay        x.set(2)
6185808f684SSatish Balay        y.set(3)
6195808f684SSatish Balay        old = self._getCtx().D.copy()
6205808f684SSatish Balay        self.A.diagonalScale(x, y)
6215808f684SSatish Balay        D = self._getCtx().D
6225808f684SSatish Balay        self.assertTrue(D.equal(old * 6))
6235808f684SSatish Balay
6245808f684SSatish Balay    def testCreateTranspose(self):
6255808f684SSatish Balay        A = self.A
6265808f684SSatish Balay        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
6275808f684SSatish Balay        AT = PETSc.Mat().createTranspose(A)
6285808f684SSatish Balay        x, y = A.createVecs()
6295808f684SSatish Balay        xt, yt = AT.createVecs()
6305808f684SSatish Balay        #
6315808f684SSatish Balay        y.setRandom()
6325808f684SSatish Balay        A.multTranspose(y, x)
6335808f684SSatish Balay        y.copy(xt)
6345808f684SSatish Balay        AT.mult(xt, yt)
6355808f684SSatish Balay        self.assertTrue(yt.equal(x))
6365808f684SSatish Balay        #
6375808f684SSatish Balay        x.setRandom()
6385808f684SSatish Balay        A.mult(x, y)
6395808f684SSatish Balay        x.copy(yt)
6405808f684SSatish Balay        AT.multTranspose(yt, xt)
6415808f684SSatish Balay        self.assertTrue(xt.equal(y))
6425808f684SSatish Balay        del A
6435808f684SSatish Balay
6448af18dd8SStefano Zampini    def testConvert(self):
6458af18dd8SStefano Zampini        self.assertTrue(self.A.convert(PETSc.Mat.Type.AIJ, PETSc.Mat()).equal(self.A))
6468af18dd8SStefano Zampini        self.assertTrue(self.A.convert(PETSc.Mat.Type.BAIJ, PETSc.Mat()).equal(self.A))
6478af18dd8SStefano Zampini        self.assertTrue(self.A.convert(PETSc.Mat.Type.SBAIJ, PETSc.Mat()).equal(self.A))
6488af18dd8SStefano Zampini        self.assertTrue(self.A.convert(PETSc.Mat.Type.DENSE, PETSc.Mat()).equal(self.A))
6498c2316a8SJeremy Tillay
65022fceea1SStefano Zampini    def testShift(self):
65122fceea1SStefano Zampini        old = self._getCtx().D.copy()
65222fceea1SStefano Zampini        self.A.shift(-0.5)
65322fceea1SStefano Zampini        D = self._getCtx().D
65422fceea1SStefano Zampini        self.assertTrue(D.equal(old - 0.5))
65522fceea1SStefano Zampini
65622fceea1SStefano Zampini    def testScale(self):
65722fceea1SStefano Zampini        old = self._getCtx().D.copy()
65822fceea1SStefano Zampini        self.A.scale(-0.5)
65922fceea1SStefano Zampini        D = self._getCtx().D
66022fceea1SStefano Zampini        self.assertTrue(D.equal(-0.5 * old))
66122fceea1SStefano Zampini
6629e7eb791SStefano Zampini    def testDiagonalMat(self):
6639e7eb791SStefano Zampini        D = self._getCtx().D.copy()
6649e7eb791SStefano Zampini        B = PETSc.Mat().createDiagonal(D)
6659e7eb791SStefano Zampini        self.assertTrue(self.A.equal(B))
6669e7eb791SStefano Zampini
66722fceea1SStefano Zampini
6685808f684SSatish Balay# --------------------------------------------------------------------
6695808f684SSatish Balay
6705808f684SSatish Balayif __name__ == '__main__':
6715808f684SSatish Balay    unittest.main()
672