xref: /petsc/src/binding/petsc4py/test/test_dmshell.py (revision bef158480efac06de457f7a665168877ab3c2fd7) !
1from petsc4py import PETSc
2import unittest
3import numpy as np
4
5
class TestDMShell(unittest.TestCase):
    """Tests for the user-configurable hooks of ``PETSc.DMShell``.

    Each test attaches a template object (vector, matrix, scatter) or Python
    callbacks to a fresh shell DM, then checks what the matching DM operation
    produces.
    """

    COMM = PETSc.COMM_WORLD

    def setUp(self):
        self.dm = PETSc.DMShell().create(comm=self.COMM)

    def tearDown(self):
        self.dm = None

    # -- shared callbacks for the *ToLocal / *ToGlobal tests ------------------

    @staticmethod
    def _begin_insert_or_add(dm, ivec, mode, ovec):
        """Begin-phase callback: copy (INSERT) or accumulate (ADD) ivec into ovec."""
        if mode == PETSc.InsertMode.INSERT_VALUES:
            ovec[...] = ivec[...]
        elif mode == PETSc.InsertMode.ADD_VALUES:
            ovec[...] += ivec[...]

    @staticmethod
    def _end_noop(dm, ivec, mode, ovec):
        """End-phase callback: nothing to do, the begin phase did all the work."""

    # -- vectors ---------------------------------------------------------------

    def testSetGlobalVector(self):
        """createGlobalVector() mirrors the template given to setGlobalVector()."""
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        self.dm.setGlobalVector(vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(vec.getSizes(), gvec.getSizes())
        self.assertEqual(vec.comm, gvec.comm)

    def testSetCreateGlobalVector(self):
        """createGlobalVector() delegates to the setCreateGlobalVector() callback."""
        def create_vec(dm):
            v = PETSc.Vec().create(comm=dm.comm)
            v.setSizes((10, None))
            v.setUp()
            return v
        self.dm.setCreateGlobalVector(create_vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(gvec.comm, self.dm.comm)
        self.assertEqual(gvec.getLocalSize(), 10)

    def testSetLocalVector(self):
        """createLocalVector() mirrors the COMM_SELF template from setLocalVector()."""
        # Rank-dependent size so each process registers a differently sized
        # template when run in parallel.
        vec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        vec.setSizes((1 + 10*self.COMM.rank, None))
        vec.setUp()
        self.dm.setLocalVector(vec)
        lvec = self.dm.createLocalVector()
        self.assertEqual(vec.getSizes(), lvec.getSizes())
        # A COMM_SELF vector's local and global sizes coincide.
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetCreateLocalVector(self):
        """createLocalVector() delegates to the setCreateLocalVector() callback."""
        def create_vec(dm):
            v = PETSc.Vec().create(comm=PETSc.COMM_SELF)
            v.setSizes((1 + 10*dm.comm.rank, None))
            v.setUp()
            return v
        self.dm.setCreateLocalVector(create_vec)
        lvec = self.dm.createLocalVector()
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lsize, 1 + 10*self.dm.comm.rank)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    # -- matrices --------------------------------------------------------------

    def testSetMatrix(self):
        """createMatrix() returns a matrix sized like the setMatrix() template."""
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (2, None)))
        mat.setUp()
        mat.assemble()
        self.dm.setMatrix(mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), mat.getSizes())

    def testSetCreateMatrix(self):
        """createMatrix() delegates to the setCreateMatrix() callback."""
        def create_mat(dm):
            # Build on the DM's communicator, like the vector factories above.
            mat = PETSc.Mat().create(comm=dm.comm)
            mat.setSizes(((10, None), (2, None)))
            mat.setUp()
            return mat
        self.dm.setCreateMatrix(create_mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), create_mat(self.dm).getSizes())

    # -- transfers via Python callbacks ----------------------------------------

    def testGlobalToLocal(self):
        """globalToLocal() runs the callbacks registered with setGlobalToLocal()."""
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.dm.comm.rank + 1
        ovec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        ovec.setSizes((10, None))
        ovec.setUp()
        self.dm.setGlobalToLocal(self._begin_insert_or_add, self._end_noop)
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(vec.getArray(), ovec.getArray()))
        # ADD on top of the previous INSERT leaves ovec == 2 * vec.
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        self.assertTrue(np.allclose(2*vec.getArray(), ovec.getArray()))

    def testLocalToGlobal(self):
        """localToGlobal() runs the callbacks registered with setLocalToGlobal()."""
        vec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.dm.comm.rank + 1
        ovec = PETSc.Vec().create(comm=self.COMM)
        ovec.setSizes((10, None))
        ovec.setUp()
        self.dm.setLocalToGlobal(self._begin_insert_or_add, self._end_noop)
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(vec.getArray(), ovec.getArray()))
        # ADD on top of the previous INSERT leaves ovec == 2 * vec.
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        self.assertTrue(np.allclose(2*vec.getArray(), ovec.getArray()))

    def testLocalToLocal(self):
        """localToLocal() runs the callbacks registered with setLocalToLocal()."""
        vec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.dm.comm.rank + 1
        ovec = vec.duplicate()
        self.dm.setLocalToLocal(self._begin_insert_or_add, self._end_noop)
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(vec.getArray(), ovec.getArray()))
        # ADD on top of the previous INSERT leaves ovec == 2 * vec.
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        self.assertTrue(np.allclose(2*vec.getArray(), ovec.getArray()))

    # -- transfers via VecScatter objects --------------------------------------

    def testGlobalToLocalVecScatter(self):
        """globalToLocal() applies the scatter given to setGlobalToLocalVecScatter()."""
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        # Rank-dependent values so a misplaced entry would be detected in parallel.
        vec[...] = self.dm.comm.rank + 1
        sct, ovec = PETSc.Scatter.toAll(vec)
        self.dm.setGlobalToLocalVecScatter(sct)
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        # toAll gathers the entire global vector into ovec on every rank, so
        # only the slice owned by this rank lines up with vec's local array.
        rstart, rend = vec.getOwnershipRange()
        self.assertTrue(
            np.allclose(vec.getArray(), ovec.getArray()[rstart:rend]))

    def testLocalToGlobalVecScatter(self):
        """localToGlobal() accepts the scatter given to setLocalToGlobalVecScatter().

        Smoke test: only checks that the transfer completes.
        """
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        # The toAll scatter maps vec -> ovec, so vec goes in the "local" slot.
        sct, ovec = PETSc.Scatter.toAll(vec)
        self.dm.setLocalToGlobalVecScatter(sct)
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)

    def testLocalToLocalVecScatter(self):
        """localToLocal() accepts the scatter given to setLocalToLocalVecScatter().

        Smoke test: only checks that the transfer completes.
        """
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        sct, ovec = PETSc.Scatter.toAll(vec)
        self.dm.setLocalToLocalVecScatter(sct)
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)

    # -- hierarchy -------------------------------------------------------------

    def testCoarsenRefine(self):
        """coarsen()/refine() return the DMs produced by the registered callbacks."""
        cdm = PETSc.DMShell().create(comm=self.COMM)
        def coarsen(dm, comm):
            return cdm
        def refine(dm, comm):
            return self.dm
        cdm.setRefine(refine)
        self.dm.setCoarsen(coarsen)
        coarsened = self.dm.coarsen()
        self.assertEqual(coarsened, cdm)
        refined = coarsened.refine()
        self.assertEqual(refined, self.dm)

    def testCreateInterpolation(self):
        """createInterpolation() returns the (matrix, vector) pair from the callback."""
        mat = PETSc.Mat().create()
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()
        vec = PETSc.Vec().create()
        vec.setSizes((10, None))
        vec.setUp()
        def create_interp(dm, dmf):
            return mat, vec
        self.dm.setCreateInterpolation(create_interp)
        m, v = self.dm.createInterpolation(self.dm)
        self.assertEqual(m, mat)
        self.assertEqual(v, vec)

    def testCreateInjection(self):
        """createInjection() returns the matrix produced by the callback."""
        mat = PETSc.Mat().create()
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()
        def create_inject(dm, dmf):
            return mat
        self.dm.setCreateInjection(create_inject)
        m = self.dm.createInjection(self.dm)
        self.assertEqual(m, mat)
211
212
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main(module="__main__")
215