# xref: /petsc/src/binding/petsc4py/test/test_nsp.py (revision aaa8cc7d2a5c3913edcbb923e20f154fe9c4aa65)
1import unittest
2from petsc4py import PETSc
3import numpy as N
4from sys import getrefcount
5
6# --------------------------------------------------------------------
7
def allclose(seq1, seq2):
    """Return True when two sequences agree element-wise within 1e-5.

    Both arguments are iterated once and compared pairwise using the
    absolute difference of corresponding elements.

    Returns False when the sequences have different lengths, or when any
    pairwise difference exceeds 1e-5 or is NaN.
    """
    left, right = list(seq1), list(seq2)
    # zip() stops at the shorter input, so without this check a truncated
    # result would compare equal to its own prefix.
    if len(left) != len(right):
        return False
    # Written as "<=" instead of negating ">" so that a NaN difference
    # (which fails every ordering comparison) is reported as a mismatch.
    return all(abs(a - b) <= 1e-5 for a, b in zip(left, right))
13
14
class TestNullSpace(unittest.TestCase):
    """Tests for PETSc.NullSpace built from two explicit basis vectors.

    The fixture basis is (1, 2, 0) and (2, -1, 0), each normalized. The two
    vectors are mutually orthogonal and both have a zero third component,
    so removing the null space from a 3-vector is expected to zero its
    first two entries and leave the third untouched.
    """

    def setUp(self):
        # Two sequential 3-vectors forming the null-space basis.
        u1 = PETSc.Vec().createSeq(3)
        u2 = PETSc.Vec().createSeq(3)
        u1[0], u1[1], u1[2] = [1, 2, 0]
        u1.normalize()
        u2[0], u2[1], u2[2] = [2, -1, 0]
        u2.normalize()
        basis = [u1, u2]
        # NOTE(review): the leading False is presumably the "constant" flag
        # (constant vector not part of the null space) - confirm against
        # the petsc4py NullSpace.create signature.
        nullsp = PETSc.NullSpace().create(False, basis, comm=PETSc.COMM_SELF)
        self.basis = basis
        self.nullsp = nullsp

    def tearDown(self):
        # Drop the fixture references before asking PETSc to clean up.
        self.basis = None
        self.nullsp = None
        PETSc.garbage_cleanup()

    def _remove(self):
        # Helper: returns (v, w) where v is the untouched vector (7, 8, 9)
        # and w is a copy of it after self.nullsp.remove() was applied.
        v = PETSc.Vec().createSeq(3)
        v[0], v[1], v[2] = [7, 8, 9]
        w = v.copy()
        self.nullsp.remove(w)
        return (v, w)

    def testRemove(self):
        v, w = self._remove()
        # The source vector must be unchanged; only the copy is projected.
        self.assertTrue(allclose(v.array, [7, 8, 9]))
        self.assertTrue(allclose(w.array, [0, 0, 9]))
        del v, w

    def testRemoveInplace(self):
        v, w = self._remove()
        # Applying remove() to the original must reproduce the copy's result.
        self.nullsp.remove(v)
        self.assertTrue(v.equal(w))
        del v, w

    def testRemoveWithFunction(self):
        # A user callback installed with setFunction() replaces the vector
        # contents with (1, 2, 3) whenever remove() is called.
        def myremove(nsp, vec):
            vec.setArray([1, 2, 3])

        self.nullsp.setFunction(myremove)
        v, w = self._remove()
        self.assertTrue(allclose(v.array, [7, 8, 9]))
        self.assertTrue(allclose(w.array, [1, 2, 3]))
        self.nullsp.remove(v)
        self.assertTrue(allclose(v.array, [1, 2, 3]))
        # Clearing the callback must restore the basis-projection behaviour,
        # which is exactly what testRemove checks.
        self.nullsp.setFunction(None)
        self.testRemove()

    def testGetSetFunction(self):
        def rem(nsp, vec):
            vec.set(0)

        self.nullsp.setFunction(rem)
        # getrefcount() counts its own argument, hence the "-1". Two
        # references are expected: the local name and the one kept by the
        # null space object.
        self.assertEqual(getrefcount(rem) - 1, 2)
        dct = self.nullsp.getDict()
        self.assertIsNotNone(dct)
        self.assertEqual(getrefcount(dct) - 1, 2)
        # The callback is stored under '__function__' as a 3-tuple; only the
        # first element (the callable) is checked here.
        fun, _args, _kwargs = dct['__function__']
        self.assertIs(fun, rem)
        self.nullsp.setFunction(None)
        # Rebinding `fun` here also releases the extra reference to `rem`
        # taken by the unpacking above, so this must stay before the
        # refcount check below.
        fun = dct.get('__function__')
        self.assertEqual(getrefcount(rem) - 1, 1)
        self.assertIsNone(fun)
78
79# --------------------------------------------------------------------
80
# Run the test suite only when this file is executed as a script,
# not when it is imported by a test collector.
if __name__ == "__main__":
    unittest.main()
83
84# --------------------------------------------------------------------
85