xref: /petsc/src/binding/petsc4py/test/test_dmshell.py (revision 6dd63270497ad23dcf16ae500a87ff2b2a0b7474)
1from petsc4py import PETSc
2import unittest
3import numpy as np
4
5
class TestDMShell(unittest.TestCase):
    """Tests for ``PETSc.DMShell`` and its user-callback interface.

    Covers template vectors/matrices, creation callbacks, global/local
    transfer callbacks, VecScatter-based transfers, coarsen/refine hooks,
    and interpolation/injection callbacks.
    """

    # Communicator of the shell DM and of every parallel object the tests
    # build; kept as a class attribute so a subclass can substitute another.
    COMM = PETSc.COMM_WORLD

    def setUp(self):
        self.dm = PETSc.DMShell().create(comm=self.COMM)

    def tearDown(self):
        self.dm.destroy()
        self.dm = None
        # Clean up unused PETSc objects before the next test runs.
        PETSc.garbage_cleanup()

    # -- helpers ----------------------------------------------------------

    @staticmethod
    def _new_vec(comm, local_size=10):
        """Return a set-up vector on ``comm`` with ``local_size`` entries."""
        vec = PETSc.Vec().create(comm=comm)
        vec.setSizes((local_size, None))
        vec.setUp()
        return vec

    @staticmethod
    def _new_mat(comm, local_rows, local_cols):
        """Return a set-up, unassembled matrix on ``comm``."""
        mat = PETSc.Mat().create(comm=comm)
        mat.setSizes(((local_rows, None), (local_cols, None)))
        mat.setUp()
        return mat

    @staticmethod
    def _transfer_begin(dm, ivec, mode, ovec):
        """Begin-phase transfer callback: copy or accumulate into ``ovec``.

        INSERT_VALUES overwrites ``ovec`` with ``ivec``, ADD_VALUES adds
        ``ivec`` to ``ovec``, and any other mode leaves ``ovec`` unchanged.
        """
        if mode == PETSc.InsertMode.INSERT_VALUES:
            ovec[...] = ivec[...]
        elif mode == PETSc.InsertMode.ADD_VALUES:
            ovec[...] += ivec[...]

    @staticmethod
    def _transfer_end(dm, ivec, mode, ovec):
        """End-phase transfer callback: a no-op; begin does all the work."""

    def _check_transfer(self, transfer, vec, ovec):
        """Run ``transfer(vec, ovec, addv=...)`` with INSERT, then ADD.

        After the insert ``ovec`` must equal ``vec``; after the add it must
        equal ``2 * vec``. Both vectors need the same local length.
        """
        transfer(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(vec.getArray(), ovec.getArray()))
        transfer(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        self.assertTrue(np.allclose(2 * vec.getArray(), ovec.getArray()))

    def _make_to_all_scatter(self):
        """Return ``(vec, sct, ovec)`` for the VecScatter-based tests.

        ``vec`` is a parallel all-ones vector with 10 local entries;
        ``sct`` and ``ovec`` are the scatter and target vector returned by
        ``PETSc.Scatter.toAll(vec)``.
        """
        vec = self._new_vec(self.COMM)
        vec.set(1.0)
        sct, ovec = PETSc.Scatter.toAll(vec)
        return vec, sct, ovec

    def _check_to_all_result(self, vec, ovec):
        """Assert ``ovec`` received every entry of the all-ones ``vec``."""
        self.assertEqual(ovec.getSize(), vec.getSize())
        self.assertTrue(np.allclose(ovec.getArray(), 1.0))

    # -- template objects and creation callbacks --------------------------

    def testSetGlobalVector(self):
        vec = self._new_vec(self.COMM)
        self.dm.setGlobalVector(vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(vec.getSizes(), gvec.getSizes())
        self.assertEqual(vec.comm, gvec.comm)

    def testSetCreateGlobalVector(self):
        def create_vec(dm):
            return self._new_vec(dm.comm)

        self.dm.setCreateGlobalVector(create_vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(gvec.comm, self.dm.comm)
        self.assertEqual(gvec.getLocalSize(), 10)

    def testSetLocalVector(self):
        # Sequential local vector whose length differs from rank to rank.
        vec = self._new_vec(PETSc.COMM_SELF, 1 + 10 * self.COMM.rank)
        self.dm.setLocalVector(vec)
        lvec = self.dm.createLocalVector()
        self.assertEqual(vec.getSizes(), lvec.getSizes())
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetCreateLocalVector(self):
        def create_vec(dm):
            return self._new_vec(PETSc.COMM_SELF, 1 + 10 * dm.comm.rank)

        self.dm.setCreateLocalVector(create_vec)
        lvec = self.dm.createLocalVector()
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lsize, 1 + 10 * self.dm.comm.rank)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetMatrix(self):
        mat = self._new_mat(self.COMM, 10, 2)
        mat.assemble()
        self.dm.setMatrix(mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), mat.getSizes())

    def testSetCreateMatrix(self):
        def create_mat(dm):
            # Build on the DM's communicator, like the vector callbacks.
            return self._new_mat(dm.comm, 10, 2)

        self.dm.setCreateMatrix(create_mat)
        nmat = self.dm.createMatrix()
        # Global sizes are the local sizes summed over all ranks.
        nprocs = self.dm.comm.size
        expected = ((10, 10 * nprocs), (2, 2 * nprocs))
        self.assertEqual(nmat.getSizes(), expected)

    # -- user transfer callbacks ------------------------------------------

    def testGlobalToLocal(self):
        vec = self._new_vec(self.COMM)
        vec[...] = self.dm.comm.rank + 1
        ovec = self._new_vec(PETSc.COMM_SELF)
        self.dm.setGlobalToLocal(self._transfer_begin, self._transfer_end)
        self._check_transfer(self.dm.globalToLocal, vec, ovec)

    def testLocalToGlobal(self):
        vec = self._new_vec(PETSc.COMM_SELF)
        vec[...] = self.dm.comm.rank + 1
        ovec = self._new_vec(self.COMM)
        self.dm.setLocalToGlobal(self._transfer_begin, self._transfer_end)
        self._check_transfer(self.dm.localToGlobal, vec, ovec)

    def testLocalToLocal(self):
        vec = self._new_vec(PETSc.COMM_SELF)
        vec[...] = self.dm.comm.rank + 1
        ovec = vec.duplicate()
        self.dm.setLocalToLocal(self._transfer_begin, self._transfer_end)
        self._check_transfer(self.dm.localToLocal, vec, ovec)

    # -- VecScatter-based transfers ---------------------------------------
    # Each test installs the same to-all scatter and only relies on it
    # mapping ``vec`` onto ``ovec``; the result is then checked explicitly.

    def testGlobalToLocalVecScatter(self):
        vec, sct, ovec = self._make_to_all_scatter()
        self.dm.setGlobalToLocalVecScatter(sct)
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self._check_to_all_result(vec, ovec)

    def testLocalToGlobalVecScatter(self):
        vec, sct, ovec = self._make_to_all_scatter()
        self.dm.setLocalToGlobalVecScatter(sct)
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self._check_to_all_result(vec, ovec)

    def testLocalToLocalVecScatter(self):
        vec, sct, ovec = self._make_to_all_scatter()
        self.dm.setLocalToLocalVecScatter(sct)
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self._check_to_all_result(vec, ovec)

    # -- hierarchy and inter-grid operators -------------------------------

    def testCoarsenRefine(self):
        cdm = PETSc.DMShell().create(comm=self.COMM)

        def coarsen(dm, comm):
            return cdm

        def refine(dm, comm):
            return self.dm

        # Coarsening the fine DM must yield cdm; refining cdm the fine DM.
        cdm.setRefine(refine)
        self.dm.setCoarsen(coarsen)
        coarsened = self.dm.coarsen()
        self.assertEqual(coarsened, cdm)
        refined = coarsened.refine()
        self.assertEqual(refined, self.dm)

    def testCreateInterpolation(self):
        mat = self._new_mat(self.COMM, 10, 10)
        vec = self._new_vec(self.COMM)

        def create_interp(dm, dmf):
            return mat, vec

        self.dm.setCreateInterpolation(create_interp)
        m, v = self.dm.createInterpolation(self.dm)
        self.assertEqual(m, mat)
        self.assertEqual(v, vec)

    def testCreateInjection(self):
        mat = self._new_mat(self.COMM, 10, 10)

        def create_inject(dm, dmf):
            return mat

        self.dm.setCreateInjection(create_inject)
        m = self.dm.createInjection(self.dm)
        self.assertEqual(m, mat)
217
218
if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()
221