xref: /petsc/src/binding/petsc4py/test/test_mat_fact.py (revision bef158480efac06de457f7a665168877ab3c2fd7)
1from petsc4py import PETSc
2import unittest
3
4import numpy as N
5
def mkmat(n, mtype, opts):
    """Create an empty n-by-n matrix on ``PETSc.COMM_SELF``.

    The matrix is given type *mtype* and set up.  Each option in *opts* is
    then enabled, and the matrix is returned.
    """
    mat = PETSc.Mat().create(PETSc.COMM_SELF)
    mat.setSizes([n, n])
    mat.setType(mtype)
    mat.setUp()
    for option in opts:
        mat.setOption(option, True)
    return mat
14
def mksys_diag(n, mtype, opts):
    """Build the diagonal test system ``A x = b``.

    ``A = diag(1, 2, ..., n)``, ``b`` is all ones, and ``x`` holds the exact
    solution ``x[i] = 1 / (i + 1)``.  All three objects are assembled and
    returned as ``(A, x, b)``.
    """
    mat = mkmat(n, mtype, opts)
    sol, rhs = mat.createVecs()
    for row in range(n):
        diag = row + 1
        mat[row, row] = diag
        sol[row] = 1.0 / diag
        rhs[row] = 1
    for obj in (mat, sol, rhs):
        obj.assemble()
    return mat, sol, rhs
26
def mksys_poi2(n, mtype, opts):
    """Build the 1-D Poisson (tridiagonal ``[-1, 2, -1]``) test system.

    Row ``i`` couples to columns ``i-1``, ``i`` and ``i+1`` where they exist.
    The exact solution is ``x[i] = i + 1``, and ``b`` is computed as ``A x``,
    so the returned triple ``(A, x, b)`` is consistent by construction.
    """
    A = mkmat(n, mtype, opts)
    x, b = A.createVecs()
    for i in range(n):
        cols, vals = [], []
        for j, v in ((i - 1, -1), (i, 2), (i + 1, -1)):
            # Clip the stencil at the boundaries.  This also keeps n == 1
            # in range (the old first-row branch always wrote column 1).
            if 0 <= j < n:
                cols.append(j)
                vals.append(v)
        A[i, cols] = vals
        x[i] = i + 1
    A.assemble()
    x.assemble()
    # b is fully overwritten by the product, so it needs no prior fill.
    A.mult(x, b)
    return A, x, b
48
class BaseTestMatFactor(object):
    """Mixin that builds a 10x10 linear system before each factorization test.

    Concrete subclasses set ``MKSYS`` (a system builder wrapped in
    ``staticmethod``), ``MTYPE`` (the matrix type passed to it) and,
    optionally, ``MOPTS`` (matrix options to enable).
    """

    MKSYS = None
    MTYPE = None
    MOPTS = ()

    def setUp(self):
        # Operator A, exact solution x and right-hand side b.
        self.A, self.x, self.b = self.MKSYS(10, self.MTYPE, self.MOPTS)

    def tearDown(self):
        # Clear the factored state on A before releasing the objects.
        self.A.setUnfactored()
        for attr in ('A', 'x', 'b'):
            getattr(self, attr).destroy()
            setattr(self, attr, None)
66
class BaseTestMatFactorLU(BaseTestMatFactor):
    """Mixin: in-place LU factorization of ``A`` followed by a solve."""

    def testFactorLU(self):
        # Nested-dissection ordering, then reorder to avoid zero diagonals.
        r, c = self.A.getOrdering("nd")
        self.A.reorderForNonzeroDiagonal(r, c)
        self.A.factorLU(r, c, {'zeropivot': 1e-5})
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        # x becomes the error against the exact solution.
        x.axpy(-1, self.x)
        # assertLess reports the actual norm on failure.
        self.assertLess(x.norm(), 1e-3)
77
class BaseTestMatFactorILU(BaseTestMatFactor):
    """Mixin: in-place ILU(0) factorization of ``A`` followed by a solve."""

    def testFactorILU(self):
        r, c = self.A.getOrdering("natural")
        # Zero fill-in levels.
        self.A.factorILU(r, c, {'levels': 0})
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        # x becomes the error against the exact solution.
        x.axpy(-1, self.x)
        # assertLess reports the actual norm on failure.
        self.assertLess(x.norm(), 1e-3)
87
88## class BaseTestMatFactorILUDT(BaseTestMatFactor):
89##
90##     def testFactorILUDT(self):
91##         r, c = self.A.getOrdering("natural")
92##         self.A = self.A.factorILUDT(r,c)
93##         x = self.x.duplicate()
94##         self.A.solve(self.b, x)
95##         x.axpy(-1, self.x)
96##         self.assertTrue(x.norm() < 1e-3)
97##
class BaseTestMatFactorChol(BaseTestMatFactor):
    """Mixin: in-place Cholesky factorization of ``A`` followed by a solve."""

    def testFactorChol(self):
        # Cholesky takes only the row permutation; the column one is unused.
        r, _ = self.A.getOrdering("natural")
        self.A.factorCholesky(r)
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        # x becomes the error against the exact solution.
        x.axpy(-1, self.x)
        # assertLess reports the actual norm on failure.
        self.assertLess(x.norm(), 1e-3)
107
class BaseTestMatFactorICC(BaseTestMatFactor):
    """Mixin: in-place incomplete Cholesky (ICC) of ``A`` followed by a solve."""

    def testFactorICC(self):
        # ICC takes only the row permutation; the column one is unused.
        r, _ = self.A.getOrdering("natural")
        self.A.factorICC(r)
        x = self.x.duplicate()
        self.A.solve(self.b, x)
        # x becomes the error against the exact solution.
        x.axpy(-1, self.x)
        # assertLess reports the actual norm on failure.
        self.assertLess(x.norm(), 1e-3)
117
118
119# --------------------------------------------------------------------
120
class TestMatFactorA1(BaseTestMatFactorLU, BaseTestMatFactorChol,
                      unittest.TestCase):
    """LU and Cholesky on the diagonal system, SEQDENSE storage."""

    MTYPE = PETSc.Mat.Type.SEQDENSE
    MKSYS = staticmethod(mksys_diag)
126
class TestMatFactorA2(BaseTestMatFactorLU, BaseTestMatFactorChol,
                      unittest.TestCase):
    """LU and Cholesky on the 1-D Poisson system, SEQDENSE storage."""

    MTYPE = PETSc.Mat.Type.SEQDENSE
    MKSYS = staticmethod(mksys_poi2)
132
133# ---
134
class TestMatFactorB1(BaseTestMatFactorLU,
                      BaseTestMatFactorILU,
                      ## BaseTestMatFactorILUDT,
                      unittest.TestCase):
    """LU and ILU(0) on the diagonal system, SEQAIJ storage."""

    MTYPE = PETSc.Mat.Type.SEQAIJ
    MKSYS = staticmethod(mksys_diag)
141
class TestMatFactorB2(BaseTestMatFactorLU,
                      BaseTestMatFactorILU,
                      ## BaseTestMatFactorILUDT,
                      unittest.TestCase):
    """LU and ILU(0) on the 1-D Poisson system, SEQAIJ storage."""

    MTYPE = PETSc.Mat.Type.SEQAIJ
    MKSYS = staticmethod(mksys_poi2)
148
149# ---
150
class TestMatFactorC1(BaseTestMatFactorLU, BaseTestMatFactorILU,
                      unittest.TestCase):
    """LU and ILU(0) on the diagonal system, SEQBAIJ storage."""

    MTYPE = PETSc.Mat.Type.SEQBAIJ
    MKSYS = staticmethod(mksys_diag)
156
class TestMatFactorC2(BaseTestMatFactorLU, BaseTestMatFactorILU,
                      unittest.TestCase):
    """LU and ILU(0) on the 1-D Poisson system, SEQBAIJ storage."""

    MTYPE = PETSc.Mat.Type.SEQBAIJ
    MKSYS = staticmethod(mksys_poi2)
162
163# ---
164
class TestMatFactorD1(BaseTestMatFactorChol,
                      BaseTestMatFactorICC,
                      unittest.TestCase):
    """Cholesky and ICC on the diagonal system, SEQSBAIJ storage."""

    MKSYS = staticmethod(mksys_diag)
    MTYPE = PETSc.Mat.Type.SEQSBAIJ
    # Tuple, like the base-class default, so the shared class attribute is
    # immutable.  The option presumably makes SBAIJ silently drop
    # lower-triangle insertions -- confirm against the PETSc MatSetOption docs.
    MOPTS = (PETSc.Mat.Option.IGNORE_LOWER_TRIANGULAR,)
171
class TestMatFactorD2(BaseTestMatFactorChol,
                      BaseTestMatFactorICC,
                      unittest.TestCase):
    """Cholesky and ICC on the 1-D Poisson system, SEQSBAIJ storage."""

    MKSYS = staticmethod(mksys_poi2)
    MTYPE = PETSc.Mat.Type.SEQSBAIJ
    # Tuple, like the base-class default, so the shared class attribute is
    # immutable.  mksys_poi2 writes the sub-diagonal (column i-1); the option
    # presumably makes SBAIJ drop those lower-triangle entries instead of
    # erroring -- confirm against the PETSc MatSetOption docs.
    MOPTS = (PETSc.Mat.Option.IGNORE_LOWER_TRIANGULAR,)
178
179# --------------------------------------------------------------------
180
if __name__ == '__main__':
    # Running this module directly executes every TestCase defined above.
    unittest.main()
183