xref: /petsc/src/binding/petsc4py/test/test_device.py (revision 5574ef43b0ac3357bb9db30b381e3fae1c24bab1)
1from petsc4py import PETSc
2import unittest
3
4# --------------------------------------------------------------------
5
6
class TestDevice(unittest.TestCase):
    """Smoke tests for the ``PETSc.Device`` and ``PETSc.DeviceContext`` bindings.

    Tests that need a device context call ``skipTest`` when none is
    available. This way they are reported as skipped instead of silently
    passing without exercising anything.
    """

    def testCurrent(self):
        """Fetching the current context twice yields the same reference count."""
        dctx = PETSc.DeviceContext().getCurrent()
        if not dctx:
            self.skipTest('no current device context')
        # The test expects 2 references. Presumably one is held by PETSc's
        # current-context slot and one by this Python wrapper; TODO confirm.
        self.assertEqual(dctx.getRefCount(), 2)
        device = dctx.getDevice()
        del device
        del dctx
        # Re-fetch after dropping the first wrapper. The count must again be
        # 2, so releasing the first wrapper gave its reference back.
        dctx = PETSc.DeviceContext().getCurrent()
        self.assertEqual(dctx.getRefCount(), 2)
        device = dctx.getDevice()
        del device
        del dctx

    def testDevice(self):
        """A freshly created device can be configured and queried."""
        device = PETSc.Device.create()
        device.configure()
        _ = device.getDeviceType()
        _ = device.getDeviceId()
        del device

    def testDeviceContext(self):
        """Create, set up and destroy a standalone device context."""
        dctx = PETSc.DeviceContext().create()
        if not dctx:
            self.skipTest('device context could not be created')
        self.assertEqual(dctx.getRefCount(), 1)
        dctx.setUp()
        # Nothing has been enqueued, so the new context must report idle.
        self.assertTrue(dctx.idle())
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)

    def testStream(self):
        """Round-trip the stream type of the current context."""
        dctx = PETSc.DeviceContext().getCurrent()
        if not dctx:
            self.skipTest('no current device context')
        self.assertEqual(dctx.getRefCount(), 2)
        stype = dctx.getStreamType()
        # Setting the type it already has must be accepted.
        dctx.setStreamType(stype)
        # destroy() on this wrapper leaves it reporting a refcount of 0.
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)

    def testSetFromOptions(self):
        """A new context accepts setFromOptions() before setUp()."""
        dctx = PETSc.DeviceContext().create()
        if not dctx:
            self.skipTest('device context could not be created')
        self.assertEqual(dctx.getRefCount(), 1)
        dctx.setFromOptions()
        dctx.setUp()
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)

    def testDuplicate(self):
        """duplicate() returns an independent context with its own reference."""
        dctx = PETSc.DeviceContext().getCurrent()
        if not dctx:
            self.skipTest('no current device context')
        self.assertEqual(dctx.getRefCount(), 2)
        dctx2 = dctx.duplicate()
        # The duplicate is a new object, referenced only by dctx2.
        self.assertEqual(dctx2.getRefCount(), 1)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)
        dctx2.destroy()
        self.assertEqual(dctx2.getRefCount(), 0)

    def testWaitFor(self):
        """One context can wait on another; both are then destroyed."""
        dctx = PETSc.DeviceContext().create()
        if not dctx:
            self.skipTest('device context could not be created')
        self.assertEqual(dctx.getRefCount(), 1)
        dctx.setUp()
        dctx2 = PETSc.DeviceContext().create()
        self.assertEqual(dctx2.getRefCount(), 1)
        dctx2.setUp()
        dctx.waitFor(dctx2)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)
        # Destroy dctx2 exactly once. The original called destroy() twice
        # in a row with no check in between, which was a copy-paste slip.
        dctx2.destroy()
        self.assertEqual(dctx2.getRefCount(), 0)

    def testForkJoin(self):
        """Fork children off the current context and join them in every mode."""
        dctx = PETSc.DeviceContext().getCurrent()
        if not dctx:
            self.skipTest('no current device context')
        self.assertEqual(dctx.getRefCount(), 2)
        jdestroy = PETSc.DeviceContext.JoinMode.DESTROY
        jtypes = [
            PETSc.DeviceContext.JoinMode.SYNC,
            PETSc.DeviceContext.JoinMode.NO_SYNC,
        ]
        for j in jtypes:
            dctxs = dctx.fork(4)
            for ctx in dctxs:
                self.assertEqual(ctx.getRefCount(), 1)
            # Join the even-indexed children (0, 2), then the odd ones
            # walked backwards (3, 1).
            dctx.join(j, dctxs[0::2])
            dctx.join(j, dctxs[3::-2])
            # Non-destroying joins must leave every child alive.
            for ctx in dctxs:
                self.assertEqual(ctx.getRefCount(), 1)
            # A DESTROY join releases all the children.
            dctx.join(jdestroy, dctxs)
            for ctx in dctxs:
                self.assertEqual(ctx.getRefCount(), 0)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)
110
111
112# --------------------------------------------------------------------
113
114if __name__ == '__main__':
115    unittest.main()
116