xref: /petsc/src/binding/petsc4py/test/test_mat_fact.py (revision d5b43468fb8780a8feea140ccd6fa3e6a50411cc)
1from petsc4py import PETSc
2import unittest
3
4import numpy as N
5
def mkmat(n, mtype, opts):
    """Return an n-by-n matrix of type *mtype* on ``PETSc.COMM_SELF``.

    The matrix is set up but holds no values yet.  Each option in *opts*
    is then switched on (``setOption(option, True)``).
    """
    mat = PETSc.Mat().create(PETSc.COMM_SELF)
    mat.setSizes([n, n])
    mat.setType(mtype)
    mat.setUp()
    for option in opts:
        mat.setOption(option, True)
    return mat
14
def mksys_diag(n, mtype, opts):
    """Build the diagonal system A = diag(1, 2, ..., n).

    The right-hand side b is all ones, so the exact solution is
    x[i] = 1 / (i + 1).  *mtype* and *opts* are forwarded to ``mkmat``.

    Returns the assembled triple (A, x, b).
    """
    A = mkmat(n, mtype, opts)
    x, b = A.createVecs()
    for row, diag in enumerate(range(1, n + 1)):
        A[row, row] = diag
        x[row] = 1.0 / diag
        b[row] = 1
    for obj in (A, x, b):
        obj.assemble()
    return A, x, b
26
def mksys_poi2(n, mtype, opts):
    """Build the tridiagonal system with stencil (-1, 2, -1).

    The exact solution is x[i] = i + 1.  The right-hand side is computed
    as b = A x, so the triple is consistent by construction.

    Parameters
    ----------
    n : int
        Number of rows/columns.  n == 1 yields the 1x1 matrix [2].
    mtype, opts :
        Forwarded to ``mkmat``.

    Returns
    -------
    (A, x, b) : the assembled matrix, exact solution and right-hand side.
    """
    A = mkmat(n, mtype, opts)
    x, b = A.createVecs()
    for i in range(n):
        # Keep only the stencil entries whose column lies inside [0, n).
        # Filtering, rather than special-casing the first and last rows,
        # also covers n == 1, where row 0 has no neighbours at all.
        cols, vals = [], []
        for j, v in ((i - 1, -1), (i, 2), (i + 1, -1)):
            if 0 <= j < n:
                cols.append(j)
                vals.append(v)
        A[i, cols] = vals
        x[i] = i + 1
        b[i] = 0
    A.assemble()
    x.assemble()
    b.assemble()
    A.mult(x, b)  # b <- A x
    return A, x, b
48
class BaseTestMatFactor(object):
    """Fixture mixin that builds a linear system (A, x, b) before each test.

    Concrete test classes set:
      MKSYS -- callable (n, mtype, opts) -> (A, x, b)
      MTYPE -- matrix type forwarded to MKSYS
      MOPTS -- iterable of matrix options forwarded to MKSYS
    """

    MKSYS = None
    MTYPE = None
    MOPTS = ()

    def setUp(self):
        # Every test works on a 10x10 system.
        self.A, self.x, self.b = self.MKSYS(10, self.MTYPE, self.MOPTS)

    def tearDown(self):
        # Tests factor A in place; mark it unfactored before destroying it.
        self.A.setUnfactored()
        for attr in ('A', 'x', 'b'):
            getattr(self, attr).destroy()
            setattr(self, attr, None)
        PETSc.garbage_cleanup()
67
class BaseTestMatFactorLU(BaseTestMatFactor):
    """Mixin: in-place LU factorization of A, checked by solving A x = b."""

    def testFactorLU(self):
        # "nd" (nested dissection) ordering, then adjusted so that the
        # permuted matrix has no zeros on its diagonal.
        r, c = self.A.getOrdering("nd")
        self.A.reorderForNonzeroDiagonal(r, c)
        self.A.factorLU(r, c, {'zeropivot': 1e-5})
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        x.axpy(-1, self.x)  # x <- x - x_exact
        # assertLess reports both values on failure, unlike assertTrue.
        self.assertLess(x.norm(), 1e-3)
78
class BaseTestMatFactorILU(BaseTestMatFactor):
    """Mixin: in-place ILU(0) factorization of A, checked by solving A x = b."""

    def testFactorILU(self):
        r, c = self.A.getOrdering("natural")
        # levels=0: no fill beyond the nonzero pattern of A.
        self.A.factorILU(r, c, {'levels': 0})
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        x.axpy(-1, self.x)  # x <- x - x_exact
        # assertLess reports both values on failure, unlike assertTrue.
        self.assertLess(x.norm(), 1e-3)
88
89## class BaseTestMatFactorILUDT(BaseTestMatFactor):
90##
91##     def testFactorILUDT(self):
92##         r, c = self.A.getOrdering("natural")
93##         self.A = self.A.factorILUDT(r,c)
94##         x = self.x.duplicate()
95##         self.A.solve(self.b, x)
96##         x.axpy(-1, self.x)
97##         self.assertTrue(x.norm() < 1e-3)
98##
class BaseTestMatFactorChol(BaseTestMatFactor):
    """Mixin: in-place Cholesky factorization of A, checked by solving A x = b."""

    def testFactorChol(self):
        # Cholesky takes a single permutation; the column ordering returned
        # alongside it is not needed.
        r, _ = self.A.getOrdering("natural")
        self.A.factorCholesky(r)
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        x.axpy(-1, self.x)  # x <- x - x_exact
        # assertLess reports both values on failure, unlike assertTrue.
        self.assertLess(x.norm(), 1e-3)
108
class BaseTestMatFactorICC(BaseTestMatFactor):
    """Mixin: in-place incomplete Cholesky (ICC) of A, checked by solving A x = b."""

    def testFactorICC(self):
        # ICC takes a single permutation; the column ordering returned
        # alongside it is not needed.
        r, _ = self.A.getOrdering("natural")
        self.A.factorICC(r)
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        x.axpy(-1, self.x)  # x <- x - x_exact
        # assertLess reports both values on failure, unlike assertTrue.
        self.assertLess(x.norm(), 1e-3)
118
119
120# --------------------------------------------------------------------
121
class TestMatFactorA1(BaseTestMatFactorLU, BaseTestMatFactorChol,
                      unittest.TestCase):
    """LU and Cholesky on the diagonal system stored as SEQDENSE."""
    MTYPE = PETSc.Mat.Type.SEQDENSE
    MKSYS = staticmethod(mksys_diag)
127
class TestMatFactorA2(BaseTestMatFactorLU, BaseTestMatFactorChol,
                      unittest.TestCase):
    """LU and Cholesky on the (-1, 2, -1) tridiagonal system as SEQDENSE."""
    MTYPE = PETSc.Mat.Type.SEQDENSE
    MKSYS = staticmethod(mksys_poi2)
133
134# ---
135
class TestMatFactorB1(BaseTestMatFactorLU, BaseTestMatFactorILU,
                      # BaseTestMatFactorILUDT is currently disabled.
                      unittest.TestCase):
    """LU and ILU on the diagonal system stored as SEQAIJ."""
    MTYPE = PETSc.Mat.Type.SEQAIJ
    MKSYS = staticmethod(mksys_diag)
142
class TestMatFactorB2(BaseTestMatFactorLU, BaseTestMatFactorILU,
                      # BaseTestMatFactorILUDT is currently disabled.
                      unittest.TestCase):
    """LU and ILU on the (-1, 2, -1) tridiagonal system as SEQAIJ."""
    MTYPE = PETSc.Mat.Type.SEQAIJ
    MKSYS = staticmethod(mksys_poi2)
149
150# ---
151
class TestMatFactorC1(BaseTestMatFactorLU, BaseTestMatFactorILU,
                      unittest.TestCase):
    """LU and ILU on the diagonal system stored as SEQBAIJ."""
    MTYPE = PETSc.Mat.Type.SEQBAIJ
    MKSYS = staticmethod(mksys_diag)
157
class TestMatFactorC2(BaseTestMatFactorLU, BaseTestMatFactorILU,
                      unittest.TestCase):
    """LU and ILU on the (-1, 2, -1) tridiagonal system as SEQBAIJ."""
    MTYPE = PETSc.Mat.Type.SEQBAIJ
    MKSYS = staticmethod(mksys_poi2)
163
164# ---
165
class TestMatFactorD1(BaseTestMatFactorChol,
                      BaseTestMatFactorICC,
                      unittest.TestCase):
    """Cholesky and ICC on the diagonal system stored as SEQSBAIJ.

    IGNORE_LOWER_TRIANGULAR is switched on so that any entries inserted
    below the diagonal are ignored by the symmetric storage format.
    """
    MKSYS = staticmethod(mksys_diag)
    MTYPE = PETSc.Mat.Type.SEQSBAIJ
    # Immutable tuple, matching the BaseTestMatFactor.MOPTS default.
    MOPTS = (PETSc.Mat.Option.IGNORE_LOWER_TRIANGULAR,)
172
class TestMatFactorD2(BaseTestMatFactorChol,
                      BaseTestMatFactorICC,
                      unittest.TestCase):
    """Cholesky and ICC on the (-1, 2, -1) tridiagonal system as SEQSBAIJ.

    IGNORE_LOWER_TRIANGULAR is switched on so that the sub-diagonal
    entries inserted by mksys_poi2 are ignored by the symmetric format.
    """
    MKSYS = staticmethod(mksys_poi2)
    MTYPE = PETSc.Mat.Type.SEQSBAIJ
    # Immutable tuple, matching the BaseTestMatFactor.MOPTS default.
    MOPTS = (PETSc.Mat.Option.IGNORE_LOWER_TRIANGULAR,)
179
180# --------------------------------------------------------------------
181
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()
184