# xref: /petsc/src/binding/petsc4py/test/test_dmshell.py (revision 2ff79c18c26c94ed8cb599682f680f231dca6444)
1from petsc4py import PETSc
2import unittest
3import numpy as np
4
5
class TestDMShell(unittest.TestCase):
    """Tests for the petsc4py ``DMShell`` setters and user callbacks.

    Every test works on a fresh ``DMShell`` created on ``COMM`` in ``setUp``:
    it installs a vector, matrix, scatter, or Python callback on the shell,
    then checks that the matching ``DM`` operation returns or applies it.
    """

    COMM = PETSc.COMM_WORLD

    def setUp(self):
        self.dm = PETSc.DMShell().create(comm=self.COMM)

    def tearDown(self):
        self.dm.destroy()
        self.dm = None
        # Destroy PETSc objects whose teardown was deferred by Python's
        # garbage collector (e.g. objects captured by test callbacks).
        PETSc.garbage_cleanup()

    # ------------------------------------------------------------------
    # Shared helpers (not collected as tests: no ``test`` prefix)
    # ------------------------------------------------------------------

    @staticmethod
    def _copy_begin(dm, ivec, mode, ovec):
        """Begin-phase callback: copy (INSERT) or add (ADD) ivec into ovec."""
        if mode == PETSc.InsertMode.INSERT_VALUES:
            ovec[...] = ivec[...]
        elif mode == PETSc.InsertMode.ADD_VALUES:
            ovec[...] += ivec[...]

    @staticmethod
    def _noop_end(dm, ivec, mode, ovec):
        """End-phase callback: nothing to do, ``_copy_begin`` did the work."""

    def _createScatterToAll(self):
        """Return ``(vec, sct, ovec)`` where ``sct`` gathers ``vec`` into ``ovec``.

        ``vec`` lives on ``COMM`` with 10 local entries, all set to
        ``rank + 1`` so a successful scatter is distinguishable from an
        untouched output; ``sct`` and ``ovec`` come from
        ``PETSc.Scatter.toAll``.
        """
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.COMM.rank + 1
        sct, ovec = PETSc.Scatter.toAll(vec)
        return vec, sct, ovec

    def _assertGathered(self, vec, ovec):
        """Assert ``ovec`` holds this rank's ``vec`` entries at their global offsets."""
        start, end = vec.getOwnershipRange()
        self.assertTrue(np.allclose(ovec.getArray()[start:end], vec.getArray()))

    # ------------------------------------------------------------------
    # Vector and matrix creation
    # ------------------------------------------------------------------

    def testSetGlobalVector(self):
        """``createGlobalVector`` mirrors the template set by ``setGlobalVector``."""
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        self.dm.setGlobalVector(vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(vec.getSizes(), gvec.getSizes())
        self.assertEqual(vec.comm, gvec.comm)

    def testSetCreateGlobalVector(self):
        """``createGlobalVector`` delegates to the installed creation callback."""

        def create_vec(dm):
            v = PETSc.Vec().create(comm=dm.comm)
            v.setSizes((10, None))
            v.setUp()
            return v

        self.dm.setCreateGlobalVector(create_vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(gvec.comm, self.dm.comm)
        self.assertEqual(gvec.getLocalSize(), 10)

    def testSetLocalVector(self):
        """``createLocalVector`` mirrors the sequential template vector."""
        # Rank-dependent size checks that each rank keeps its own template.
        vec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        vec.setSizes((1 + 10 * self.COMM.rank, None))
        vec.setUp()
        self.dm.setLocalVector(vec)
        lvec = self.dm.createLocalVector()
        self.assertEqual(vec.getSizes(), lvec.getSizes())
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetCreateLocalVector(self):
        """``createLocalVector`` delegates to the installed creation callback."""

        def create_vec(dm):
            v = PETSc.Vec().create(comm=PETSc.COMM_SELF)
            v.setSizes((1 + 10 * dm.comm.rank, None))
            v.setUp()
            return v

        self.dm.setCreateLocalVector(create_vec)
        lvec = self.dm.createLocalVector()
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lsize, 1 + 10 * self.dm.comm.rank)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetMatrix(self):
        """``createMatrix`` mirrors the template set by ``setMatrix``."""
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (2, None)))
        mat.setUp()
        mat.assemble()
        self.dm.setMatrix(mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), mat.getSizes())

    def testSetCreateMatrix(self):
        """``createMatrix`` delegates to the installed creation callback."""

        def create_mat(dm):
            # Build on the DM's own communicator, like the vector-creation
            # callbacks above, instead of ignoring the ``dm`` argument.
            mat = PETSc.Mat().create(comm=dm.comm)
            mat.setSizes(((10, None), (2, None)))
            mat.setUp()
            return mat

        self.dm.setCreateMatrix(create_mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), create_mat(self.dm).getSizes())

    def testSetCreateFieldDecomposition(self):
        """Optional names/DMs from the callback are passed through or ``None``."""

        def create_field_decomposition_only_is(dm):
            return None, [PETSc.IS().createStride(1) for _ in range(2)], None

        def create_field_decomposition_only_is_names(dm):
            _, ises, _ = create_field_decomposition_only_is(dm)
            names = [f'f_{i}' for i in range(len(ises))]
            return names, ises, None

        def create_field_decomposition_only_is_dms(dm):
            _, ises, _ = create_field_decomposition_only_is(dm)
            dms = [dm.clone() for _ in range(len(ises))]
            # No names: return None explicitly rather than the throwaway ``_``.
            return None, ises, dms

        def create_field_decomposition_only_full(dm):
            names, ises, _ = create_field_decomposition_only_is_names(dm)
            dms = [dm.clone() for _ in range(len(ises))]
            return names, ises, dms

        tests = [
            (create_field_decomposition_only_is, False, False),
            (create_field_decomposition_only_is_names, True, False),
            (create_field_decomposition_only_is_dms, False, True),
            (create_field_decomposition_only_full, True, True),
        ]
        for test_f, has_names, has_dms in tests:
            with self.subTest(callback=test_f.__name__):
                self.dm.setCreateFieldDecomposition(test_f)
                names, ises, dms = self.dm.createFieldDecomposition()
                self.assertEqual(len(names), len(ises))
                self.assertEqual(len(dms), len(ises))
                if has_names:
                    checknames = [f'f_{i}' for i in range(len(ises))]
                    self.assertEqual(names, checknames)
                else:
                    for n in names:
                        self.assertIsNone(n)
                for sub in dms:
                    if has_dms:
                        self.assertIs(type(sub), PETSc.DM)
                    else:
                        self.assertIsNone(sub)

    # ------------------------------------------------------------------
    # Python begin/end transfer callbacks
    # ------------------------------------------------------------------

    def testGlobalToLocal(self):
        """``globalToLocal`` runs the user callbacks for INSERT and ADD modes."""
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.dm.comm.rank + 1
        ovec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        ovec.setSizes((10, None))
        ovec.setUp()
        self.dm.setGlobalToLocal(self._copy_begin, self._noop_end)
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(vec.getArray(), ovec.getArray()))
        # ADD on top of the inserted copy doubles the values.
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        self.assertTrue(np.allclose(2 * vec.getArray(), ovec.getArray()))

    def testLocalToGlobal(self):
        """``localToGlobal`` runs the user callbacks for INSERT and ADD modes."""
        vec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.dm.comm.rank + 1
        ovec = PETSc.Vec().create(comm=self.COMM)
        ovec.setSizes((10, None))
        ovec.setUp()
        self.dm.setLocalToGlobal(self._copy_begin, self._noop_end)
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(vec.getArray(), ovec.getArray()))
        # ADD on top of the inserted copy doubles the values.
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        self.assertTrue(np.allclose(2 * vec.getArray(), ovec.getArray()))

    def testLocalToLocal(self):
        """``localToLocal`` runs the user callbacks for INSERT and ADD modes."""
        vec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.dm.comm.rank + 1
        ovec = vec.duplicate()
        self.dm.setLocalToLocal(self._copy_begin, self._noop_end)
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self.assertTrue(np.allclose(vec.getArray(), ovec.getArray()))
        # ADD on top of the inserted copy doubles the values.
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        self.assertTrue(np.allclose(2 * vec.getArray(), ovec.getArray()))

    # ------------------------------------------------------------------
    # VecScatter-based transfers
    # ------------------------------------------------------------------

    def testGlobalToLocalVecScatter(self):
        """``globalToLocal`` applies the scatter set by ``setGlobalToLocalVecScatter``."""
        vec, sct, ovec = self._createScatterToAll()
        self.dm.setGlobalToLocalVecScatter(sct)
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self._assertGathered(vec, ovec)

    def testLocalToGlobalVecScatter(self):
        """``localToGlobal`` applies the scatter set by ``setLocalToGlobalVecScatter``."""
        vec, sct, ovec = self._createScatterToAll()
        self.dm.setLocalToGlobalVecScatter(sct)
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self._assertGathered(vec, ovec)

    def testLocalToLocalVecScatter(self):
        """``localToLocal`` applies the scatter set by ``setLocalToLocalVecScatter``."""
        vec, sct, ovec = self._createScatterToAll()
        self.dm.setLocalToLocalVecScatter(sct)
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        self._assertGathered(vec, ovec)

    # ------------------------------------------------------------------
    # Hierarchy and transfer operators
    # ------------------------------------------------------------------

    def testCoarsenRefine(self):
        """``coarsen``/``refine`` return the DMs produced by the callbacks."""
        cdm = PETSc.DMShell().create(comm=self.COMM)

        def coarsen(dm, comm):
            return cdm

        def refine(dm, comm):
            return self.dm

        cdm.setRefine(refine)
        self.dm.setCoarsen(coarsen)
        coarsened = self.dm.coarsen()
        self.assertEqual(coarsened, cdm)
        refined = coarsened.refine()
        self.assertEqual(refined, self.dm)

    def testCreateInterpolation(self):
        """``createInterpolation`` returns the callback's (matrix, vector) pair."""
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()

        def create_interp(dm, dmf):
            return mat, vec

        self.dm.setCreateInterpolation(create_interp)
        m, v = self.dm.createInterpolation(self.dm)
        self.assertEqual(m, mat)
        self.assertEqual(v, vec)

    def testCreateInjection(self):
        """``createInjection`` returns the callback's matrix."""
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()

        def create_inject(dm, dmf):
            return mat

        self.dm.setCreateInjection(create_inject)
        m = self.dm.createInjection(self.dm)
        self.assertEqual(m, mat)
259
260
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()
263