xref: /petsc/src/binding/petsc4py/test/test_dmshell.py (revision d5b43468fb8780a8feea140ccd6fa3e6a50411cc)
1from petsc4py import PETSc
2import unittest
3import numpy as np
4
5
class TestDMShell(unittest.TestCase):
    """Tests for the petsc4py ``DMShell`` wrapper.

    Each test starts from a fresh ``DMShell`` created on ``COMM`` in
    ``setUp`` and checks that the objects or Python callbacks registered on
    it are what the generic ``DM`` methods (``createGlobalVector``,
    ``globalToLocal``, ``coarsen``, ...) hand back or invoke.
    """

    # Communicator of the shell DM and of the "global" objects attached to
    # it; the local vectors in these tests are created on COMM_SELF.
    COMM = PETSc.COMM_WORLD

    def setUp(self):
        self.dm = PETSc.DMShell().create(comm=self.COMM)

    def tearDown(self):
        self.dm.destroy()
        self.dm = None
        # NOTE(review): presumably frees PETSc objects whose Python wrappers
        # were collected during the test; confirm against petsc4py docs.
        PETSc.garbage_cleanup()

    # -- helpers (not collected as tests: names do not start with "test") --

    @staticmethod
    def _create_vec(comm, local_size):
        """Return a set-up ``Vec`` on ``comm`` with ``local_size`` local entries.

        The global size is left for PETSc to compute (``None``).
        """
        vec = PETSc.Vec().create(comm=comm)
        vec.setSizes((local_size, None))
        vec.setUp()
        return vec

    @staticmethod
    def _copy_begin(dm, ivec, mode, ovec):
        """Begin-callback for shell transfers: copy or accumulate ``ivec``.

        INSERT_VALUES overwrites the local entries of ``ovec`` with those of
        ``ivec``; ADD_VALUES adds them. Any other mode leaves ``ovec`` alone.
        """
        if mode == PETSc.InsertMode.INSERT_VALUES:
            ovec[...] = ivec[...]
        elif mode == PETSc.InsertMode.ADD_VALUES:
            ovec[...] += ivec[...]

    @staticmethod
    def _noop_end(dm, ivec, mode, ovec):
        """End-callback for shell transfers: the begin phase did all the work."""

    def _check_transfer(self, set_callbacks, transfer, vec, ovec):
        """Check a callback-driven transfer in INSERT and ADD modes.

        ``set_callbacks`` is the ``DMShell`` setter that registers the
        ``(begin, end)`` pair and ``transfer`` the matching ``DM`` method;
        ``vec`` and ``ovec`` must have the same local size.
        """
        vec[...] = self.dm.comm.rank + 1
        set_callbacks(self._copy_begin, self._noop_end)
        transfer(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(vec.getArray(), ovec.getArray()))
        # ovec now equals vec, so adding vec once more must double it.
        transfer(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        self.assertTrue(np.allclose(2*vec.getArray(), ovec.getArray()))

    def _scatter_to_all(self):
        """Build a gather-to-all scatter plus the data needed to check it.

        Returns ``(vec, sct, ovec, expected)``: ``vec`` lives on ``COMM`` and
        holds its own global indices, ``sct``/``ovec`` come from
        ``PETSc.Scatter.toAll(vec)``, and ``expected`` is the full array that
        ``ovec`` must hold on every process after a forward scatter. Using
        the gathered array (not ``vec``'s local part) keeps the comparison
        shape-correct on any number of processes.
        """
        vec = self._create_vec(self.COMM, 10)
        start, end = vec.getOwnershipRange()
        vec[...] = np.arange(start, end)
        sct, ovec = PETSc.Scatter.toAll(vec)
        return vec, sct, ovec, np.arange(vec.getSize())

    # -- vectors and matrices -------------------------------------------

    def testSetGlobalVector(self):
        vec = self._create_vec(self.COMM, 10)
        self.dm.setGlobalVector(vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(vec.getSizes(), gvec.getSizes())
        self.assertEqual(vec.comm, gvec.comm)

    def testSetCreateGlobalVector(self):
        def create_vec(dm):
            return self._create_vec(dm.comm, 10)
        self.dm.setCreateGlobalVector(create_vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(gvec.comm, self.dm.comm)
        self.assertEqual(gvec.getLocalSize(), 10)

    def testSetLocalVector(self):
        # The local size differs from rank to rank.
        vec = self._create_vec(PETSc.COMM_SELF, 1 + 10*self.COMM.rank)
        self.dm.setLocalVector(vec)
        lvec = self.dm.createLocalVector()
        self.assertEqual(vec.getSizes(), lvec.getSizes())
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetCreateLocalVector(self):
        def create_vec(dm):
            return self._create_vec(PETSc.COMM_SELF, 1 + 10*dm.comm.rank)
        self.dm.setCreateLocalVector(create_vec)
        lvec = self.dm.createLocalVector()
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lsize, 1 + 10*self.dm.comm.rank)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetMatrix(self):
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (2, None)))
        mat.setUp()
        mat.assemble()
        self.dm.setMatrix(mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), mat.getSizes())

    def testSetCreateMatrix(self):
        def create_mat(dm):
            # Build on the DM's communicator, like the vector callbacks do.
            mat = PETSc.Mat().create(comm=dm.comm)
            mat.setSizes(((10, None), (2, None)))
            mat.setUp()
            return mat
        self.dm.setCreateMatrix(create_mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), create_mat(self.dm).getSizes())

    # -- callback-driven transfers --------------------------------------

    def testGlobalToLocal(self):
        vec = self._create_vec(self.COMM, 10)
        ovec = self._create_vec(PETSc.COMM_SELF, 10)
        self._check_transfer(self.dm.setGlobalToLocal, self.dm.globalToLocal,
                             vec, ovec)

    def testLocalToGlobal(self):
        vec = self._create_vec(PETSc.COMM_SELF, 10)
        ovec = self._create_vec(self.COMM, 10)
        self._check_transfer(self.dm.setLocalToGlobal, self.dm.localToGlobal,
                             vec, ovec)

    def testLocalToLocal(self):
        vec = self._create_vec(PETSc.COMM_SELF, 10)
        ovec = vec.duplicate()
        self._check_transfer(self.dm.setLocalToLocal, self.dm.localToLocal,
                             vec, ovec)

    # -- VecScatter-driven transfers ------------------------------------
    # Previously this class defined testGlobalToLocalVecScatter twice, so
    # the asserting version was silently shadowed; the two scatter tests
    # below also asserted nothing.

    def testGlobalToLocalVecScatter(self):
        vec, sct, ovec, expected = self._scatter_to_all()
        self.dm.setGlobalToLocalVecScatter(sct)
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(ovec.getArray(), expected))

    def testLocalToGlobalVecScatter(self):
        vec, sct, ovec, expected = self._scatter_to_all()
        self.dm.setLocalToGlobalVecScatter(sct)
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(ovec.getArray(), expected))

    def testLocalToLocalVecScatter(self):
        vec, sct, ovec, expected = self._scatter_to_all()
        self.dm.setLocalToLocalVecScatter(sct)
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(ovec.getArray(), expected))

    # -- hierarchy and inter-grid operators -----------------------------

    def testCoarsenRefine(self):
        cdm = PETSc.DMShell().create(comm=self.COMM)
        def coarsen(dm, comm):
            return cdm
        def refine(dm, comm):
            return self.dm
        try:
            cdm.setRefine(refine)
            self.dm.setCoarsen(coarsen)
            coarsened = self.dm.coarsen()
            self.assertEqual(coarsened, cdm)
            refined = coarsened.refine()
            self.assertEqual(refined, self.dm)
        finally:
            # cdm is created here, not in setUp, so tearDown won't free it.
            cdm.destroy()

    def testCreateInterpolation(self):
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()
        vec = self._create_vec(self.COMM, 10)
        def create_interp(dm, dmf):
            return mat, vec
        self.dm.setCreateInterpolation(create_interp)
        m, v = self.dm.createInterpolation(self.dm)
        self.assertEqual(m, mat)
        self.assertEqual(v, vec)

    def testCreateInjection(self):
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()
        def create_inject(dm, dmf):
            return mat
        self.dm.setCreateInjection(create_inject)
        m = self.dm.createInjection(self.dm)
        self.assertEqual(m, mat)
213
214
def _run_from_command_line():
    """Run this module's test cases through unittest's command-line runner."""
    unittest.main()


if __name__ == '__main__':
    _run_from_command_line()
217