xref: /petsc/src/binding/petsc4py/test/test_mat_fact.py (revision 552edb6364df478b294b3111f33a8f37ca096b20)
1from petsc4py import PETSc
2import unittest
3
4
def mkmat(n, mtype, opts):
    """Create an ``n`` x ``n`` matrix of type ``mtype`` on ``PETSc.COMM_SELF``.

    The matrix is set up first; afterwards every option in ``opts`` is
    switched on.  The new matrix is returned unassembled.
    """
    mat = PETSc.Mat().create(PETSc.COMM_SELF)
    mat.setSizes([n, n])
    mat.setType(mtype)
    mat.setUp()
    for option in opts:
        mat.setOption(option, True)
    return mat
13
14
def mksys_diag(n, mtype, opts):
    """Build the diagonal system ``A x = b`` with ``A = diag(1, ..., n)``.

    The exact solution is ``x[i] = 1 / (i + 1)`` and the right-hand side
    is all ones.  Returns the assembled ``(A, x, b)``.
    """
    A = mkmat(n, mtype, opts)
    x, b = A.createVecs()
    for row in range(n):
        d = row + 1
        A[row, row] = d
        x[row] = 1.0 / d
        b[row] = 1
    for obj in (A, x, b):
        obj.assemble()
    return A, x, b
26
27
def poi2_row(i, n):
    """Return ``(cols, vals)`` for row ``i`` of the n x n ``[-1, 2, -1]`` stencil.

    Neighbour columns outside ``[0, n)`` are dropped.  As a result, the first
    and last rows have two entries, and a 1 x 1 system has only the diagonal.
    """
    stencil = ((i - 1, -1), (i, 2), (i + 1, -1))
    inside = [(j, v) for j, v in stencil if 0 <= j < n]
    cols = [j for j, _ in inside]
    vals = [v for _, v in inside]
    return cols, vals


def mksys_poi2(n, mtype, opts):
    """Build the tridiagonal ``[-1, 2, -1]`` system ``A x = b``.

    The exact solution is ``x[i] = i + 1``.  ``b`` is computed as ``A x``,
    so ``x`` solves the system exactly.  Returns the assembled ``(A, x, b)``.
    """
    A = mkmat(n, mtype, opts)
    x, b = A.createVecs()
    for i in range(n):
        # Bounds-filtered stencil: never writes a column index >= n (n == 1).
        cols, vals = poi2_row(i, n)
        A[i, cols] = vals
        x[i] = i + 1
        b[i] = 0
    A.assemble()
    x.assemble()
    b.assemble()
    # Overwrite b with A @ x so that x is the exact solution of A x = b.
    A.mult(x, b)
    return A, x, b
49
50
class BaseTestMatFactor:
    """Fixture mixin that builds a linear system ``A x = b`` before each test.

    Subclasses provide ``MKSYS``, a callable ``(n, mtype, opts) -> (A, x, b)``.
    They also set ``MTYPE``, the matrix type, and optionally ``MOPTS``, the
    options passed through to the builder.
    """

    MKSYS = None
    MTYPE = None
    MOPTS = ()

    def setUp(self):
        # Every test works on a 10 x 10 system.
        self.A, self.x, self.b = self.MKSYS(10, self.MTYPE, self.MOPTS)

    def tearDown(self):
        # The test methods factor A in place; reset that state before
        # releasing the matrix.
        self.A.setUnfactored()
        for attr in ('A', 'x', 'b'):
            getattr(self, attr).destroy()
            setattr(self, attr, None)
        PETSc.garbage_cleanup()
71
72
class BaseTestMatFactorLU(BaseTestMatFactor):
    """Mixin: in-place LU factorization (``'nd'`` ordering), then a solve."""

    def testFactorLU(self):
        r, c = self.A.getOrdering('nd')
        self.A.reorderForNonzeroDiagonal(r, c)
        self.A.factorLU(r, c, {'zeropivot': 1e-5})
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        # Turn x into the error against the known exact solution.
        x.axpy(-1, self.x)
        # assertLess reports the offending norm on failure; assertTrue did not.
        self.assertLess(x.norm(), 1e-3)
82
83
class BaseTestMatFactorILU(BaseTestMatFactor):
    """Mixin: in-place ILU(0) factorization (natural ordering), then a solve."""

    def testFactorILU(self):
        r, c = self.A.getOrdering('natural')
        self.A.factorILU(r, c, {'levels': 0})
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        # Turn x into the error against the known exact solution.
        x.axpy(-1, self.x)
        # assertLess reports the offending norm on failure; assertTrue did not.
        self.assertLess(x.norm(), 1e-3)
92
93
94## class BaseTestMatFactorILUDT(BaseTestMatFactor):
95##
96##     def testFactorILUDT(self):
97##         r, c = self.A.getOrdering("natural")
98##         self.A = self.A.factorILUDT(r,c)
99##         x = self.x.duplicate()
100##         self.A.solve(self.b, x)
101##         x.axpy(-1, self.x)
102##         self.assertTrue(x.norm() < 1e-3)
103##
class BaseTestMatFactorChol(BaseTestMatFactor):
    """Mixin: in-place Cholesky factorization (natural ordering), then a solve."""

    def testFactorChol(self):
        # Cholesky takes a single (row) permutation; the column one is unused.
        r, _ = self.A.getOrdering('natural')
        self.A.factorCholesky(r)
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        # Turn x into the error against the known exact solution.
        x.axpy(-1, self.x)
        # assertLess reports the offending norm on failure; assertTrue did not.
        self.assertLess(x.norm(), 1e-3)
112
113
class BaseTestMatFactorICC(BaseTestMatFactor):
    """Mixin: in-place incomplete Cholesky (natural ordering), then a solve."""

    def testFactorICC(self):
        # ICC takes a single (row) permutation; the column one is unused.
        r, _ = self.A.getOrdering('natural')
        self.A.factorICC(r)
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        # Turn x into the error against the known exact solution.
        x.axpy(-1, self.x)
        # assertLess reports the offending norm on failure; assertTrue did not.
        self.assertLess(x.norm(), 1e-3)
122
123
124# --------------------------------------------------------------------
125
126
class TestMatFactorA1(
    BaseTestMatFactorLU,
    BaseTestMatFactorChol,
    unittest.TestCase,
):
    """LU and Cholesky on the diagonal system, SEQDENSE storage."""

    MTYPE = PETSc.Mat.Type.SEQDENSE
    MKSYS = staticmethod(mksys_diag)
130
131
class TestMatFactorA2(
    BaseTestMatFactorLU,
    BaseTestMatFactorChol,
    unittest.TestCase,
):
    """LU and Cholesky on the tridiagonal system, SEQDENSE storage."""

    MTYPE = PETSc.Mat.Type.SEQDENSE
    MKSYS = staticmethod(mksys_poi2)
135
136
137# ---
138
139
class TestMatFactorB1(BaseTestMatFactorLU, BaseTestMatFactorILU, unittest.TestCase):
    """LU and ILU on the diagonal system, SEQAIJ storage."""

    # BaseTestMatFactorILUDT is currently disabled and not mixed in here.
    MTYPE = PETSc.Mat.Type.SEQAIJ
    MKSYS = staticmethod(mksys_diag)
148
149
class TestMatFactorB2(BaseTestMatFactorLU, BaseTestMatFactorILU, unittest.TestCase):
    """LU and ILU on the tridiagonal system, SEQAIJ storage."""

    # BaseTestMatFactorILUDT is currently disabled and not mixed in here.
    MTYPE = PETSc.Mat.Type.SEQAIJ
    MKSYS = staticmethod(mksys_poi2)
158
159
160# ---
161
162
class TestMatFactorC1(
    BaseTestMatFactorLU,
    BaseTestMatFactorILU,
    unittest.TestCase,
):
    """LU and ILU on the diagonal system, SEQBAIJ storage."""

    MTYPE = PETSc.Mat.Type.SEQBAIJ
    MKSYS = staticmethod(mksys_diag)
166
167
class TestMatFactorC2(
    BaseTestMatFactorLU,
    BaseTestMatFactorILU,
    unittest.TestCase,
):
    """LU and ILU on the tridiagonal system, SEQBAIJ storage."""

    MTYPE = PETSc.Mat.Type.SEQBAIJ
    MKSYS = staticmethod(mksys_poi2)
171
172
173# ---
174
175
class TestMatFactorD1(BaseTestMatFactorChol, BaseTestMatFactorICC, unittest.TestCase):
    """Cholesky and ICC on the diagonal system, SEQSBAIJ storage."""

    MKSYS = staticmethod(mksys_diag)
    MTYPE = PETSc.Mat.Type.SEQSBAIJ
    # Immutable tuple, matching the ``MOPTS = ()`` default of the fixture base.
    # A class-level list would be a single mutable object shared by every test.
    # The option presumably lets builders insert below-diagonal entries into
    # the symmetric SBAIJ format without error.
    MOPTS = (PETSc.Mat.Option.IGNORE_LOWER_TRIANGULAR,)
180
181
class TestMatFactorD2(BaseTestMatFactorChol, BaseTestMatFactorICC, unittest.TestCase):
    """Cholesky and ICC on the tridiagonal system, SEQSBAIJ storage."""

    MKSYS = staticmethod(mksys_poi2)
    MTYPE = PETSc.Mat.Type.SEQSBAIJ
    # Immutable tuple, matching the ``MOPTS = ()`` default of the fixture base.
    # A class-level list would be a single mutable object shared by every test.
    # mksys_poi2 inserts below-diagonal entries; this option presumably makes
    # the symmetric SBAIJ format ignore them instead of erroring.
    MOPTS = (PETSc.Mat.Option.IGNORE_LOWER_TRIANGULAR,)
186
187
188# --------------------------------------------------------------------
189
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()
192