# xref: /petsc/src/binding/petsc4py/test/test_optdb.py (revision 552edb6364df478b294b3111f33a8f37ca096b20)
1import unittest
2from petsc4py import PETSc
3from sys import getrefcount
4import numpy as np
5
6# --------------------------------------------------------------------
7
8
9class TestOptions(unittest.TestCase):
10    PREFIX = 'myopts-'
11    OPTLIST = [
12        ('bool', True),
13        ('int', -7),
14        ('real', 5),
15        ('scalar', 3),
16        ('string', 'petsc4py'),
17    ]
18
19    def _putopts(self, opts=None, OPTLIST=None):
20        if opts is None:
21            opts = self.opts
22        if OPTLIST is None:
23            OPTLIST = self.OPTLIST
24        for k, v in OPTLIST:
25            opts[k] = v
26
27    def _delopts(self, opts=None, OPTLIST=None):
28        if opts is None:
29            opts = self.opts
30        if OPTLIST is None:
31            OPTLIST = self.OPTLIST
32        for k, _ in OPTLIST:
33            del opts[k]
34
35    def setUp(self):
36        self.opts = PETSc.Options(self.PREFIX)
37
38    def tearDown(self):
39        self.opts = None
40        PETSc.garbage_cleanup()
41
42    def testHasOpts(self):
43        self._putopts()
44        for k, _ in self.OPTLIST:
45            self.assertTrue(self.opts.hasName(k))
46            self.assertTrue(k in self.opts)
47            missing = k + '-missing'
48            self.assertFalse(self.opts.hasName(missing))
49            self.assertFalse(missing in self.opts)
50        self._delopts()
51
52    def testGetOpts(self):
53        self._putopts()
54        for k, v in self.OPTLIST:
55            getopt = getattr(self.opts, 'get' + k.title())
56            self.assertEqual(getopt(k), v)
57        self._delopts()
58
59    def testGetAll(self):
60        self._putopts()
61        allopts = self.opts.getAll()
62        self.assertTrue(isinstance(allopts, dict))
63        optlist = [(k, str(v).lower()) for (k, v) in self.OPTLIST]
64        for k, v in allopts.items():
65            self.assertTrue((k, v) in optlist)
66        self._delopts()
67
68    def testGetAllQuoted(self):
69        dct = {
70            'o0': '"0 1 2"',
71            'o1': '"a b c"',
72            'o2': '"x y z"',
73        }
74        for k in dct:
75            self.opts[k] = dct[k]
76        allopts = self.opts.getAll()
77        for k in dct:
78            self.assertEqual(allopts[k], dct[k][1:-1])
79            del self.opts[k]
80
81    def testType(self):
82        types = [
83            (bool, bool, self.opts.getBool, self.opts.getBoolArray),
84            (int, PETSc.IntType, self.opts.getInt, self.opts.getIntArray),
85            (float, PETSc.RealType, self.opts.getReal, self.opts.getRealArray),
86        ]
87        if PETSc.ScalarType is PETSc.ComplexType:
88            types.append(
89                (
90                    complex,
91                    PETSc.ScalarType,
92                    self.opts.getScalar,
93                    self.opts.getScalarArray,
94                )
95            )
96        else:
97            types.append(
98                (
99                    float,
100                    PETSc.ScalarType,
101                    self.opts.getScalar,
102                    self.opts.getScalarArray,
103                )
104            )
105        toval = (lambda x: x, lambda x: np.array(x).tolist(), lambda x: np.array(x))
106        sv = 1
107        av = (1, 0, 1)
108        defv = 0
109        defarrayv = (0, 0, 1, 0)
110        for pyt, pat, pget, pgetarray in types:
111            for tov in toval:
112                self.opts.setValue('sv', tov(sv))
113                self.opts.setValue('av', tov(av))
114
115                v = pget('sv')
116                self.assertTrue(isinstance(v, pyt))
117                self.assertEqual(v, pyt(sv))
118
119                v = pget('sv', defv)
120                self.assertTrue(isinstance(v, pyt))
121                self.assertEqual(v, pyt(sv))
122
123                v = pget('missing', defv)
124                self.assertTrue(isinstance(v, pyt))
125                self.assertEqual(v, pyt(defv))
126
127                if pgetarray is not None:
128                    arrayv = pgetarray('av')
129                    self.assertEqual(arrayv.dtype, pat)
130                    self.assertEqual(len(arrayv), len(av))
131                    for v1, v2 in zip(arrayv, av):
132                        self.assertTrue(isinstance(v1.item(), pyt))
133                        self.assertEqual(v1.item(), pyt(v2))
134
135                    arrayv = pgetarray('av', defarrayv)
136                    self.assertEqual(arrayv.dtype, pat)
137                    self.assertEqual(len(arrayv), len(av))
138                    for v1, v2 in zip(arrayv, av):
139                        self.assertTrue(isinstance(v1.item(), pyt))
140                        self.assertEqual(v1.item(), pyt(v2))
141
142                    arrayv = pgetarray('missing', defarrayv)
143                    self.assertEqual(arrayv.dtype, pat)
144                    self.assertEqual(len(arrayv), len(defarrayv))
145                    for v1, v2 in zip(arrayv, defarrayv):
146                        self.assertTrue(isinstance(v1.item(), pyt))
147                        self.assertEqual(v1.item(), pyt(v2))
148
149                self.opts.delValue('sv')
150                self.opts.delValue('av')
151
152    def testMonitor(self):
153        optlist = []
154        mon = lambda n, v: optlist.append((n, v))
155        self.opts.setMonitor(mon)
156        self.assertEqual(getrefcount(mon) - 1, 2)
157        self._putopts()
158        target = [(self.PREFIX + k, str(v).lower()) for k, v in self.OPTLIST]
159        self.assertEqual(optlist, target)
160        self.opts.cancelMonitor()
161        self.assertEqual(getrefcount(mon) - 1, 1)
162        self._delopts()
163
164
165# --------------------------------------------------------------------
166
# NOTE(review): testMonitor is removed from the suite here with no recorded
# reason (the original line is only tagged "XXX"); confirm whether it can be
# re-enabled.
delattr(TestOptions, 'testMonitor')

if __name__ == "__main__":
    unittest.main()
170    unittest.main()
171