# xref: /petsc/src/binding/petsc4py/test/test_device.py (revision fe972ab94a72e07fc8dae81dd2ded1f8f457b5bc)
1from petsc4py import PETSc
2import unittest
3
4# --------------------------------------------------------------------
5
6
class TestDevice(unittest.TestCase):
    """Smoke tests for the ``PETSc.Device`` and ``PETSc.DeviceContext`` bindings.

    Most tests track ``getRefCount()`` to check how the Python wrappers hold
    references to the underlying PETSc objects: a context obtained from
    ``create()`` reports 1, the context returned by ``getCurrent()`` reports 2,
    and a wrapper reports 0 once ``destroy()`` has been called on it.
    """

    def testCurrent(self):
        """Look up the current context twice, dropping the wrappers in between."""
        dctx = PETSc.DeviceContext().getCurrent()
        # Count is 2 -- presumably one reference kept by PETSc for the current
        # context plus the one taken by this wrapper (TODO confirm).
        self.assertEqual(dctx.getRefCount(), 2)
        device = dctx.getDevice()
        del device
        del dctx
        # Dropping the wrappers must release their references: a fresh lookup
        # sees exactly the same count again rather than an ever-growing one.
        dctx = PETSc.DeviceContext().getCurrent()
        self.assertEqual(dctx.getRefCount(), 2)
        device = dctx.getDevice()
        del device
        del dctx

    def testDevice(self):
        """Create and configure a device, then query its type and id."""
        device = PETSc.Device.create()
        device.configure()
        # Results are only exercised, not checked: they are platform-dependent.
        _ = device.getDeviceType()
        _ = device.getDeviceId()
        del device

    def testDeviceContext(self):
        """Create, set up, and destroy a standalone device context."""
        dctx = PETSc.DeviceContext().create()
        self.assertEqual(dctx.getRefCount(), 1)
        dctx.setUp()
        # Nothing has been enqueued on the new context, so it must be idle.
        self.assertTrue(dctx.idle())
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)

    def testStream(self):
        """Round-trip the stream type of the current context."""
        dctx = PETSc.DeviceContext().getCurrent()
        self.assertEqual(dctx.getRefCount(), 2)
        stype = dctx.getStreamType()
        dctx.setStreamType(stype)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)

    def testSetFromOptions(self):
        """Configure a fresh context from the options database."""
        dctx = PETSc.DeviceContext().create()
        self.assertEqual(dctx.getRefCount(), 1)
        dctx.setFromOptions()
        dctx.setUp()
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)

    def testDuplicate(self):
        """Duplicating the current context yields an independently owned one."""
        dctx = PETSc.DeviceContext().getCurrent()
        self.assertEqual(dctx.getRefCount(), 2)
        dctx2 = dctx.duplicate()
        # The duplicate is a new object with a single reference of its own.
        self.assertEqual(dctx2.getRefCount(), 1)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)
        dctx2.destroy()
        self.assertEqual(dctx2.getRefCount(), 0)

    def testWaitFor(self):
        """Make one context wait on another, then tear both down."""
        dctx = PETSc.DeviceContext().create()
        self.assertEqual(dctx.getRefCount(), 1)
        dctx.setUp()
        dctx2 = PETSc.DeviceContext().create()
        self.assertEqual(dctx2.getRefCount(), 1)
        dctx2.setUp()
        dctx.waitFor(dctx2)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)
        dctx2.destroy()
        self.assertEqual(dctx2.getRefCount(), 0)
        # Destroying an already-destroyed wrapper must be a harmless no-op.
        # Checked explicitly instead of via a bare duplicated call, which was
        # indistinguishable from a copy-paste slip.
        dctx2.destroy()
        self.assertEqual(dctx2.getRefCount(), 0)

    def testForkJoin(self):
        """Fork child contexts and join them back with every join mode."""
        dctx = PETSc.DeviceContext().getCurrent()
        self.assertEqual(dctx.getRefCount(), 2)
        jdestroy = PETSc.DeviceContext.JoinMode.DESTROY
        jtypes = [
            PETSc.DeviceContext.JoinMode.SYNC,
            PETSc.DeviceContext.JoinMode.NO_SYNC,
        ]
        for j in jtypes:
            dctxs = dctx.fork(4)
            for ctx in dctxs:
                self.assertEqual(ctx.getRefCount(), 1)
            # Join the children in two interleaved batches: [0::2] picks
            # children 0 and 2, [3::-2] picks children 3 and 1 (in reverse).
            dctx.join(j, dctxs[0::2])
            dctx.join(j, dctxs[3::-2])
            # Non-destroying joins leave every child alive.
            for ctx in dctxs:
                self.assertEqual(ctx.getRefCount(), 1)
            # A DESTROY join releases all children at once.
            dctx.join(jdestroy, dctxs)
            for ctx in dctxs:
                self.assertEqual(ctx.getRefCount(), 0)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)
96
97
98# --------------------------------------------------------------------
99
if __name__ == "__main__":
    # Run as a script rather than collected by a test runner:
    # discover and execute every TestCase defined in this module.
    unittest.main()
102