xref: /petsc/src/binding/petsc4py/test/test_device.py (revision 199ce8e42107e7d44b661c4e8cb4581aba436073)
1*199ce8e4SStefano Zampinifrom petsc4py import PETSc
2*199ce8e4SStefano Zampiniimport unittest
3*199ce8e4SStefano Zampini
4*199ce8e4SStefano Zampini# --------------------------------------------------------------------
5*199ce8e4SStefano Zampini
6*199ce8e4SStefano Zampini
class TestDevice(unittest.TestCase):
    """Exercise the PETSc.Device and PETSc.DeviceContext Python bindings.

    Most checks pin the value returned by ``getRefCount()`` at specific
    points: 2 for a handle obtained through ``getCurrent()``, 1 for a handle
    that was created, duplicated or forked, and 0 once ``destroy()`` (or a
    DESTROY join) has been applied to the handle.
    """

    def _assert_refcount(self, expected, *contexts):
        """Assert that every object in *contexts* reports *expected* references."""
        for context in contexts:
            self.assertEqual(context.getRefCount(), expected)

    def testCurrent(self):
        # Fetch the current context, grab its device, then drop both Python
        # names. The sequence runs twice so the second pass verifies that the
        # reported reference count is 2 again after the first pass.
        for _ in range(2):
            current = PETSc.DeviceContext().getCurrent()
            self._assert_refcount(2, current)
            dev = current.getDevice()
            del dev
            del current

    def testDevice(self):
        # Smoke test: create and configure a device, then query its type and
        # id. The returned values are not inspected.
        dev = PETSc.Device.create()
        dev.configure()
        dev.getDeviceType()
        dev.getDeviceId()
        del dev

    def testDeviceContext(self):
        # Create -> set up -> idle check -> destroy on a fresh context.
        fresh = PETSc.DeviceContext().create()
        self._assert_refcount(1, fresh)
        fresh.setUp()
        self.assertTrue(fresh.idle())
        fresh.destroy()
        self._assert_refcount(0, fresh)

    def testStream(self):
        # Read the stream type of the current context and write the same
        # value back, then destroy this handle.
        current = PETSc.DeviceContext().getCurrent()
        self._assert_refcount(2, current)
        stream_type = current.getStreamType()
        current.setStreamType(stream_type)
        current.destroy()
        self._assert_refcount(0, current)

    def testSetFromOptions(self):
        # setFromOptions() is applied before setUp() on a fresh context.
        fresh = PETSc.DeviceContext().create()
        self._assert_refcount(1, fresh)
        fresh.setFromOptions()
        fresh.setUp()
        fresh.destroy()
        self._assert_refcount(0, fresh)

    def testDuplicate(self):
        # The duplicate starts with a single reference even though the
        # handle it was duplicated from reports two.
        current = PETSc.DeviceContext().getCurrent()
        self._assert_refcount(2, current)
        clone = current.duplicate()
        self._assert_refcount(1, clone)
        current.destroy()
        self._assert_refcount(0, current)
        clone.destroy()
        self._assert_refcount(0, clone)

    def testWaitFor(self):
        # Two independently created contexts; the first waits on the second.
        waiter = PETSc.DeviceContext().create()
        self._assert_refcount(1, waiter)
        waiter.setUp()
        awaited = PETSc.DeviceContext().create()
        self._assert_refcount(1, awaited)
        awaited.setUp()
        waiter.waitFor(awaited)
        waiter.destroy()
        self._assert_refcount(0, waiter)
        # destroy() is deliberately issued twice on the same handle, exactly
        # as this test has always done; the check below still expects 0.
        # NOTE(review): presumably this guards against a second destroy()
        # misbehaving -- confirm before collapsing it to a single call.
        for _ in range(2):
            awaited.destroy()
        self._assert_refcount(0, awaited)

    def testForkJoin(self):
        # Fork four children off the current context, join them back in two
        # halves with a non-destroying mode, then join all of them with
        # DESTROY. Repeated for each non-destroying join mode.
        root = PETSc.DeviceContext().getCurrent()
        self._assert_refcount(2, root)
        modes = PETSc.DeviceContext.JoinMode
        for mode in (modes.SYNC, modes.NO_SYNC):
            children = root.fork(4)
            self._assert_refcount(1, *children)
            # Even-indexed children first, then indices 3 and 1 in that order.
            root.join(mode, children[0::2])
            root.join(mode, children[3::-2])
            # A non-destroying join leaves every child with one reference.
            self._assert_refcount(1, *children)
            root.join(modes.DESTROY, children)
            self._assert_refcount(0, *children)
        root.destroy()
        self._assert_refcount(0, root)
96*199ce8e4SStefano Zampini
97*199ce8e4SStefano Zampini
98*199ce8e4SStefano Zampini# --------------------------------------------------------------------
99*199ce8e4SStefano Zampini
if __name__ == '__main__':
    # Executed as a script: hand control to the unittest command-line runner.
    unittest.main()
102