xref: /petsc/src/binding/petsc4py/test/test_dmstag.py (revision 552edb6364df478b294b3111f33a8f37ca096b20)
1from petsc4py import PETSc
2import unittest
3
4# --------------------------------------------------------------------
5
6
class BaseTestDMStag:
    """Common DMStag checks, mixed into concrete ``unittest.TestCase`` classes.

    Subclasses supply ``SIZES`` (global grid sizes; its length is the grid
    dimension), ``DOFS``, ``BOUNDARY`` and ``NEWDOF`` (the dofs passed to
    ``createCompatibleDMStag`` in ``testMigrateVec``), and may override the
    stencil and parallel-layout defaults below.
    """

    COMM = PETSc.COMM_WORLD
    STENCIL = PETSc.DMStag.StencilType.BOX
    SWIDTH = 1
    # None means "not set explicitly": setUp only calls the setters otherwise.
    PROC_SIZES = None
    OWNERSHIP_RANGES = None

    # One DMStag location name per entry of getDof(), keyed by dimension.
    # Shared by testDof and testDMDAInterface so both address the same places.
    _STRATUM_LOCATIONS = {
        1: ('left', 'element'),
        2: ('down_left', 'left', 'element'),
        3: ('back_down_right', 'down_left', 'left', 'element'),
    }

    def setUp(self):
        """Create ``self.da`` in one call and ``self.directda`` via setters."""
        dim = len(self.SIZES)
        self.da = PETSc.DMStag().create(
            dim,
            dofs=self.DOFS,
            sizes=self.SIZES,
            boundary_types=self.BOUNDARY,
            stencil_type=self.STENCIL,
            stencil_width=self.SWIDTH,
            comm=self.COMM,
            proc_sizes=self.PROC_SIZES,
            ownership_ranges=self.OWNERSHIP_RANGES,
            setUp=True,
        )

        # Same configuration built step by step, on the same communicator so
        # the two DMs are directly comparable.
        self.directda = PETSc.DMStag().create(dim, comm=self.COMM)
        self.directda.setStencilType(self.STENCIL)
        self.directda.setStencilWidth(self.SWIDTH)
        self.directda.setBoundaryTypes(self.BOUNDARY)
        self.directda.setDof(self.DOFS)
        self.directda.setGlobalSizes(self.SIZES)
        if self.PROC_SIZES is not None:
            self.directda.setProcSizes(self.PROC_SIZES)
        if self.OWNERSHIP_RANGES is not None:
            self.directda.setOwnershipRanges(self.OWNERSHIP_RANGES)
        self.directda.setUp()

    def tearDown(self):
        """Drop both DM references, then call ``PETSc.garbage_cleanup()``."""
        self.da = None
        self.directda = None
        PETSc.garbage_cleanup()

    def testCoordinates(self):
        """Set uniform coordinates through 'stag' and 'product' coordinate DMs."""
        self.da.setCoordinateDMType('stag')
        self.da.setUniformCoordinates(0, 1, 0, 1, 0, 1)
        self.da.setUniformCoordinatesExplicit(0, 1, 0, 1, 0, 1)
        cda = self.da.getCoordinateDM()
        self.assertEqual(cda.getType(), 'stag')
        cda.destroy()

        # Feeding the coordinate vectors back through the setters must not
        # change their extreme values.
        c = self.da.getCoordinatesLocal()
        self.da.setCoordinatesLocal(c)
        gc = self.da.getCoordinatesLocal()
        self.assertEqual(c.max()[1], gc.max()[1])
        self.assertEqual(c.min()[1], gc.min()[1])

        c = self.da.getCoordinates()
        self.da.setCoordinates(c)
        gc = self.da.getCoordinates()
        self.assertEqual(c.max()[1], gc.max()[1])
        self.assertEqual(c.min()[1], gc.min()[1])

        self.directda.setCoordinateDMType('product')
        self.directda.setUniformCoordinates(0, 1, 0, 1, 0, 1)
        self.directda.setUniformCoordinatesProduct(0, 1, 0, 1, 0, 1)
        cda = self.directda.getCoordinateDM()
        self.assertEqual(cda.getType(), 'product')
        cda.destroy()

    def testGetVec(self):
        """Scatter constant values between borrowed global and local vectors."""
        vg = self.da.getGlobalVec()
        vl = self.da.getLocalVec()

        vg.set(1.0)
        self.assertEqual(vg.max()[1], 1.0)
        self.assertEqual(vg.min()[1], 1.0)
        self.da.globalToLocal(vg, vl)
        self.assertEqual(vl.max()[1], 1.0)
        # 0.0 is tolerated: presumably entries the scatter does not fill
        # (e.g. beyond a NONE boundary) -- NOTE(review): confirm.
        self.assertIn(vl.min()[1], (1.0, 0.0))

        vl.set(2.0)
        self.da.localToGlobal(vl, vg)
        self.assertEqual(vg.max()[1], 2.0)
        self.assertIn(vg.min()[1], (2.0, 0.0))

        self.da.restoreGlobalVec(vg)
        self.da.restoreLocalVec(vl)

    def testGetOther(self):
        """Local-to-global maps can be obtained from both DMs."""
        _ = self.da.getLGMap()
        _ = self.directda.getLGMap()

    def testDof(self):
        """Each entry of getDof() matches getLocationDof() at its location."""
        dim = self.da.getDim()
        dofs = self.da.getDof()
        for dof, loc in zip(dofs, self._STRATUM_LOCATIONS.get(dim, ())):
            self.assertEqual(dof, self.da.getLocationDof(loc))

    def testMigrateVec(self):
        """Migrate a global vector into a compatible DMStag with ``NEWDOF``."""
        vec = self.da.createGlobalVec()
        dmTo = self.da.createCompatibleDMStag(self.NEWDOF)
        vecTo = dmTo.createGlobalVec()
        self.da.migrateVec(vec, dmTo, vecTo)

    @unittest.skip('VecSplitToDMDA check is disabled')
    def testDMDAInterface(self):
        """Split a global vector into DMDA vectors, one per dof location.

        Previously disabled with a bare ``return`` (reported as a pass); it is
        now reported as skipped.
        """
        self.da.setCoordinateDMType('stag')
        self.da.setUniformCoordinates(0, 1, 0, 1, 0, 1)
        dim = self.da.getDim()
        dofs = self.da.getDof()
        vec = self.da.createGlobalVec()
        # Same locations as testDof. The negative count is kept from the
        # original code -- NOTE(review): confirm its meaning in the
        # DMStagVecSplitToDMDA documentation.
        for dof, loc in zip(dofs, self._STRATUM_LOCATIONS.get(dim, ())):
            da, davec = self.da.VecSplitToDMDA(vec, loc, -dof)
148
149
# Short aliases for the DM boundary types used by the test classes.
GHOSTED, PERIODIC, NONE = (
    PETSc.DM.BoundaryType.GHOSTED,
    PETSc.DM.BoundaryType.PERIODIC,
    PETSc.DM.BoundaryType.NONE,
)

# Common multiplier applied to the grid sizes in this module.
SCALE = 4
155
156
class BaseTestDMStag_1D(BaseTestDMStag):
    """One-dimensional fixture with a NONE boundary."""

    BOUNDARY = [NONE]
    SIZES = [100 * SCALE]
164
165
class BaseTestDMStag_2D(BaseTestDMStag):
    """Two-dimensional fixture with NONE boundaries in both directions."""

    BOUNDARY = [NONE] * 2
    SIZES = [9 * SCALE, 11 * SCALE]
169
170
class BaseTestDMStag_3D(BaseTestDMStag):
    """Three-dimensional fixture with NONE boundaries in all directions."""

    BOUNDARY = [NONE] * 3
    SIZES = [6 * SCALE, 7 * SCALE, 8 * SCALE]
174
175
176# --------------------------------------------------------------------
177
178
class TestDMStag_1D_W0_N11(BaseTestDMStag_1D, unittest.TestCase):
    """1D grid, stencil width 0, DOFS (1, 1) migrated to (2, 1)."""

    DOFS = (1, 1)
    NEWDOF = (2, 1)
    SWIDTH = 0
183
184
class TestDMStag_1D_W0_N21(BaseTestDMStag_1D, unittest.TestCase):
    """1D grid, stencil width 0, DOFS (2, 1) migrated to (2, 2)."""

    DOFS = (2, 1)
    NEWDOF = (2, 2)
    SWIDTH = 0
189
190
class TestDMStag_1D_W0_N12(BaseTestDMStag_1D, unittest.TestCase):
    """1D grid, stencil width 0, DOFS (1, 2) migrated to (2, 2)."""

    DOFS = (1, 2)
    NEWDOF = (2, 2)
    SWIDTH = 0
195
196
class TestDMStag_1D_W2_N11(BaseTestDMStag_1D, unittest.TestCase):
    """1D grid, stencil width 2, DOFS (1, 1) migrated to (2, 1)."""

    DOFS = (1, 1)
    NEWDOF = (2, 1)
    SWIDTH = 2
201
202
class TestDMStag_1D_W2_N21(BaseTestDMStag_1D, unittest.TestCase):
    """1D grid, stencil width 2, DOFS (2, 1) migrated to (2, 2)."""

    DOFS = (2, 1)
    NEWDOF = (2, 2)
    SWIDTH = 2
207
208
class TestDMStag_1D_W2_N12(BaseTestDMStag_1D, unittest.TestCase):
    """1D grid, stencil width 2, DOFS (1, 2) migrated to (2, 2)."""

    DOFS = (1, 2)
    NEWDOF = (2, 2)
    SWIDTH = 2
213
214
class TestDMStag_2D_W0_N112(BaseTestDMStag_2D, unittest.TestCase):
    """2D grid, stencil width 0, DOFS (1, 1, 2) migrated to (2, 2, 2)."""

    SWIDTH = 0
    NEWDOF = (2, 2, 2)
    DOFS = (1, 1, 2)
219
220
class TestDMStag_2D_W2_N112(BaseTestDMStag_2D, unittest.TestCase):
    """2D grid, stencil width 2, DOFS (1, 1, 2) migrated to (2, 2, 2)."""

    SWIDTH = 2
    NEWDOF = (2, 2, 2)
    DOFS = (1, 1, 2)
225
226
class TestDMStag_2D_PXY(BaseTestDMStag_2D, unittest.TestCase):
    """2D grid, periodic in both directions, stencil width 5."""

    BOUNDARY = (PERIODIC, PERIODIC)
    SIZES = [13 * SCALE, 17 * SCALE]
    SWIDTH = 5
    DOFS = (1, 1, 2)
    NEWDOF = (2, 2, 2)
233
234
class TestDMStag_2D_GXY(BaseTestDMStag_2D, unittest.TestCase):
    """2D grid, ghosted in both directions, stencil width 5."""

    BOUNDARY = (GHOSTED, GHOSTED)
    SIZES = [13 * SCALE, 17 * SCALE]
    SWIDTH = 5
    DOFS = (1, 1, 2)
    NEWDOF = (2, 2, 2)
241
242
class TestDMStag_3D_W0_N1123(BaseTestDMStag_3D, unittest.TestCase):
    """3D grid, stencil width 0, DOFS (1, 1, 2, 3) migrated to (2, 2, 3, 3)."""

    SWIDTH = 0
    NEWDOF = (2, 2, 3, 3)
    DOFS = (1, 1, 2, 3)
247
248
class TestDMStag_3D_W2_N1123(BaseTestDMStag_3D, unittest.TestCase):
    """3D grid, stencil width 2, DOFS (1, 1, 2, 3) migrated to (2, 2, 3, 3)."""

    SWIDTH = 2
    NEWDOF = (2, 2, 3, 3)
    DOFS = (1, 1, 2, 3)
253
254
class TestDMStag_3D_PXYZ(BaseTestDMStag_3D, unittest.TestCase):
    """3D grid, periodic in all directions, stencil width 3."""

    BOUNDARY = (PERIODIC, PERIODIC, PERIODIC)
    SIZES = [11 * SCALE, 13 * SCALE, 17 * SCALE]
    SWIDTH = 3
    DOFS = (1, 1, 2, 3)
    NEWDOF = (2, 2, 3, 3)
261
262
class TestDMStag_3D_GXYZ(BaseTestDMStag_3D, unittest.TestCase):
    """3D grid, ghosted in all directions, stencil width 3."""

    BOUNDARY = (GHOSTED, GHOSTED, GHOSTED)
    SIZES = [11 * SCALE, 13 * SCALE, 17 * SCALE]
    SWIDTH = 3
    DOFS = (1, 1, 2, 3)
    NEWDOF = (2, 2, 3, 3)
269
270
271# --------------------------------------------------------------------
272
# Parameter space swept by the generated TestDMStagCreate cases.
DIM = tuple(range(1, 4))
# Candidate values for each of the four dof entries.
DOF0 = DOF1 = DOF2 = DOF3 = tuple(range(3))
BOUNDARY_TYPE = ('none', 'ghosted', 'periodic')
STENCIL_TYPE = ('none', 'star', 'box')
STENCIL_WIDTH = tuple(range(4))
281
282
class TestDMStagCreate(unittest.TestCase):
    """Holder for the testCreateNNNNN methods attached programmatically below."""
285
286
def _dmstag_create_configs():
    """Yield one keyword dict per DMStag configuration checked by testCreate.

    A ``dim``-dimensional DMStag takes ``dim + 1`` dof counts, so only the
    first ``dim + 1`` of DOF0..DOF3 are combined for each dimension.
    (Looping over all four and slicing afterwards produced every 1D case three
    times, including an all-zero ``[0, 0]`` case.) Configurations whose dof
    counts are all zero are skipped, as are the boundary/stencil/width
    combinations excluded below. Iteration order matches the nested
    dim -> dof -> boundary -> stencil -> width loops.
    """
    import itertools

    dof_choices = (DOF0, DOF1, DOF2, DOF3)
    for dim in DIM:
        for dofs in itertools.product(*dof_choices[: dim + 1]):
            if not any(dofs):
                continue  # every dof count is zero
            for boundary in BOUNDARY_TYPE:
                if boundary == 'periodic':
                    continue  # XXX broken
                for stencil in STENCIL_TYPE:
                    if stencil == 'none' and boundary != 'none':
                        continue
                    for width in STENCIL_WIDTH:
                        # 'none' only with width 0; 'star'/'box' only with width > 0.
                        if stencil == 'none' and width > 0:
                            continue
                        if stencil in ('star', 'box') and width == 0:
                            continue
                        yield {
                            'dim': dim,
                            'dofs': list(dofs),
                            'boundary_type': boundary,
                            'stencil_type': stencil,
                            'stencil_width': width,
                        }


def _make_create_test(kargs):
    """Return a test method for the DMStag configuration described by *kargs*.

    The method builds the DMStag twice: once through ``create()`` keywords and
    once through the individual setters. It checks that the keyword-built DM
    reports the requested settings, that both DMs report identical layouts,
    and that the per-dimension results are internally consistent.
    """

    def testCreate(self):
        """Compare create() keywords with the setter API for one configuration."""
        dim = kargs['dim']
        dofs = kargs['dofs']
        boundary_type = kargs['boundary_type']
        stencil_type = kargs['stencil_type']
        stencil_width = kargs['stencil_width']
        sizes = [8 * SCALE] * dim
        boundary_types = [boundary_type] * dim

        cda = PETSc.DMStag().create(
            dim,
            dofs=dofs,
            sizes=sizes,
            boundary_types=boundary_types,
            stencil_type=stencil_type,
            stencil_width=stencil_width,
            setUp=True,
        )
        # Cleanups run LIFO (dda first, then cda), including when an
        # assertion below fails.
        self.addCleanup(cda.destroy)

        dda = PETSc.DMStag().create(dim)
        self.addCleanup(dda.destroy)
        dda.setStencilType(stencil_type)
        dda.setStencilWidth(stencil_width)
        dda.setBoundaryTypes(boundary_types)
        dda.setDof(dofs)
        dda.setGlobalSizes(sizes)
        dda.setUp()

        cdim = cda.getDim()
        cdof = cda.getDof()
        cgsizes = cda.getGlobalSizes()
        clsizes = cda.getLocalSizes()
        cboundary = cda.getBoundaryTypes()
        cstencil_type = cda.getStencilType()
        cstencil_width = cda.getStencilWidth()
        centries_per_element = cda.getEntriesPerElement()
        cstarts, csizes, cnextra = cda.getCorners()
        cisLastRank = cda.getIsLastRank()
        cisFirstRank = cda.getIsFirstRank()
        cownershipranges = cda.getOwnershipRanges()
        cprocsizes = cda.getProcSizes()

        ddim = dda.getDim()
        ddof = dda.getDof()
        dgsizes = dda.getGlobalSizes()
        dlsizes = dda.getLocalSizes()
        dboundary = dda.getBoundaryTypes()
        dstencil_type = dda.getStencilType()
        dstencil_width = dda.getStencilWidth()
        dentries_per_element = dda.getEntriesPerElement()
        dstarts, dsizes, dnextra = dda.getCorners()
        disLastRank = dda.getIsLastRank()
        disFirstRank = dda.getIsFirstRank()
        downershipranges = dda.getOwnershipRanges()
        dprocsizes = dda.getProcSizes()

        # The keyword-built DM reports exactly what was requested.
        self.assertEqual(cdim, dim)
        self.assertEqual(cdof, tuple(dofs))
        self.assertEqual(cboundary, tuple(boundary_types))
        self.assertEqual(cstencil_type, stencil_type)
        self.assertEqual(cstencil_width, stencil_width)
        self.assertEqual(cgsizes, tuple(sizes))

        # Both construction paths agree field by field.
        self.assertEqual(cdim, ddim)
        self.assertEqual(cdof, ddof)
        self.assertEqual(cgsizes, dgsizes)
        self.assertEqual(clsizes, dlsizes)
        self.assertEqual(cboundary, dboundary)
        self.assertEqual(cstencil_type, dstencil_type)
        self.assertEqual(cstencil_width, dstencil_width)
        self.assertEqual(centries_per_element, dentries_per_element)
        self.assertEqual(cstarts, dstarts)
        self.assertEqual(csizes, dsizes)
        self.assertEqual(cnextra, dnextra)
        self.assertEqual(cisLastRank, disLastRank)
        self.assertEqual(cisFirstRank, disFirstRank)
        self.assertEqual(cprocsizes, dprocsizes)
        for co, do in zip(cownershipranges, downershipranges):
            for i, j in zip(co, do):
                self.assertEqual(i, j)

        # Result lengths: dim + 1 dof counts, one entry per dimension otherwise.
        self.assertEqual(cdim + 1, len(cdof))
        self.assertEqual(cdim, len(cgsizes))
        self.assertEqual(cdim, len(clsizes))
        self.assertEqual(cdim, len(cboundary))
        self.assertEqual(cdim, len(cstarts))
        self.assertEqual(cdim, len(csizes))
        self.assertEqual(cdim, len(cnextra))
        self.assertEqual(cdim, len(cisLastRank))
        self.assertEqual(cdim, len(cisFirstRank))

        # Entries per element, weighted per dof entry as below.
        if cdim == 1:
            self.assertEqual(centries_per_element, cdof[0] + cdof[1])
        if cdim == 2:
            self.assertEqual(
                centries_per_element,
                cdof[0] + 2 * cdof[1] + cdof[2],
            )
        if cdim == 3:
            self.assertEqual(
                centries_per_element,
                cdof[0] + 3 * cdof[1] + 3 * cdof[2] + cdof[3],
            )

        for i in range(cdim):
            self.assertEqual(csizes[i], clsizes[i])
            # An extra layer appears exactly on the last rank in direction i.
            if cisLastRank[i]:
                self.assertEqual(cnextra[i], 1)
            if cnextra[i] == 1:
                self.assertTrue(cisLastRank[i])
            if cisFirstRank[i]:
                self.assertEqual(cstarts[i], 0)

        self.assertEqual(len(cprocsizes), len(cownershipranges))
        self.assertEqual(len(cprocsizes), cdim)
        for i, m in enumerate(cprocsizes):
            self.assertEqual(m, len(cownershipranges[i]))

    return testCreate


for _index, _kargs in enumerate(_dmstag_create_configs()):
    setattr(
        TestDMStagCreate,
        'testCreate%05d' % _index,
        _make_create_test(_kargs),
    )

# Keep the module namespace free of the generation helpers and loop variables.
del _index, _kargs, _dmstag_create_configs, _make_create_test
485
486# --------------------------------------------------------------------
487
if __name__ == '__main__':
    # Run the whole module directly, without an external test runner.
    unittest.main()
490
491# --------------------------------------------------------------------
492