from petsc4py import PETSc
import unittest

# --------------------------------------------------------------------


class TestDevice(unittest.TestCase):
    """Smoke tests for petsc4py's ``Device`` and ``DeviceContext`` wrappers.

    Most tests need a usable device context. When the handle obtained from
    PETSc is empty (falsy), the test is reported as *skipped* instead of
    silently passing without checking anything.
    """

    def _require(self, dctx):
        """Skip the running test if *dctx* is an empty (falsy) handle."""
        if not dctx:
            self.skipTest('PETSc DeviceContext is not available')

    def testCurrent(self):
        """Fetching the current context twice sees the same refcount."""
        dctx = PETSc.DeviceContext().getCurrent()
        self._require(dctx)
        # The test expects 2 references here; presumably one is held by
        # PETSc's global "current" context and one by this Python handle
        # (NOTE(review): confirm against PETSc's refcounting rules).
        self.assertEqual(dctx.getRefCount(), 2)
        device = dctx.getDevice()
        del device
        del dctx
        # Dropping the Python handles and re-fetching should observe the
        # same count as before.
        dctx = PETSc.DeviceContext().getCurrent()
        self.assertEqual(dctx.getRefCount(), 2)
        device = dctx.getDevice()
        del device
        del dctx

    def testDevice(self):
        """A freshly created device can be configured and queried."""
        device = PETSc.Device.create()
        device.configure()
        _ = device.getDeviceType()
        _ = device.getDeviceId()
        del device

    def testDeviceContext(self):
        """A newly created context has one reference until destroyed."""
        dctx = PETSc.DeviceContext().create()
        self._require(dctx)
        self.assertEqual(dctx.getRefCount(), 1)
        dctx.setUp()
        # Nothing has been enqueued yet, so the context should be idle.
        self.assertTrue(dctx.idle())
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)

    def testStream(self):
        """Round-tripping the stream type through get/set succeeds."""
        dctx = PETSc.DeviceContext().getCurrent()
        self._require(dctx)
        self.assertEqual(dctx.getRefCount(), 2)
        stype = dctx.getStreamType()
        dctx.setStreamType(stype)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)

    def testSetFromOptions(self):
        """A created context can be configured from the options database."""
        dctx = PETSc.DeviceContext().create()
        self._require(dctx)
        self.assertEqual(dctx.getRefCount(), 1)
        dctx.setFromOptions()
        dctx.setUp()
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)

    def testDuplicate(self):
        """A duplicate is an independent context with its own refcount."""
        dctx = PETSc.DeviceContext().getCurrent()
        self._require(dctx)
        self.assertEqual(dctx.getRefCount(), 2)
        dctx2 = dctx.duplicate()
        self.assertEqual(dctx2.getRefCount(), 1)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)
        dctx2.destroy()
        self.assertEqual(dctx2.getRefCount(), 0)

    def testWaitFor(self):
        """One context can be made to wait on another."""
        dctx = PETSc.DeviceContext().create()
        self._require(dctx)
        self.assertEqual(dctx.getRefCount(), 1)
        dctx.setUp()
        dctx2 = PETSc.DeviceContext().create()
        self.assertEqual(dctx2.getRefCount(), 1)
        dctx2.setUp()
        dctx.waitFor(dctx2)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)
        # Destroy dctx2 exactly once (the original destroyed it twice).
        dctx2.destroy()
        self.assertEqual(dctx2.getRefCount(), 0)

    def testForkJoin(self):
        """Fork into sub-contexts, join them, and finally destroy them."""
        dctx = PETSc.DeviceContext().getCurrent()
        self._require(dctx)
        self.assertEqual(dctx.getRefCount(), 2)
        jdestroy = PETSc.DeviceContext.JoinMode.DESTROY
        jtypes = [
            PETSc.DeviceContext.JoinMode.SYNC,
            PETSc.DeviceContext.JoinMode.NO_SYNC,
        ]
        for j in jtypes:
            dctxs = dctx.fork(4)
            for ctx in dctxs:
                self.assertEqual(ctx.getRefCount(), 1)
            # Join in two overlapping orders: indices (0, 2), then (3, 1).
            dctx.join(j, dctxs[0::2])
            dctx.join(j, dctxs[3::-2])
            # Non-destroying join modes must leave the forked contexts alive.
            for ctx in dctxs:
                self.assertEqual(ctx.getRefCount(), 1)
            dctx.join(jdestroy, dctxs)
            for ctx in dctxs:
                self.assertEqual(ctx.getRefCount(), 0)
        dctx.destroy()
        self.assertEqual(dctx.getRefCount(), 0)


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()