"""Tests for petsc4py's Python-implemented TAO solver type."""

import unittest
from petsc4py import PETSc
from sys import getrefcount
import gc

# --------------------------------------------------------------------


class Objective:
    """Objective f(x) = (x[0] - 1)**2 + (x[1] - 2)**2, minimised at (1, 2)."""

    def __call__(self, tao, x):
        return (x[0] - 1.0)**2 + (x[1] - 2.0)**2


class Gradient:
    """Gradient of ``Objective``, written into ``g`` in place."""

    def __call__(self, tao, x, g):
        g[0] = 2.0*(x[0] - 1.0)
        g[1] = 2.0*(x[1] - 2.0)
        g.assemble()


class MyTao:
    """Python TAO context that counts how often each hook is invoked.

    ``log`` maps a hook name (e.g. ``'solve'``) to its call count; the tests
    inspect it to check which hooks TAO dispatched to.
    """

    def __init__(self):
        self.log = {}

    def _log(self, method):
        # Increment the call counter for ``method``, starting from zero.
        self.log.setdefault(method, 0)
        self.log[method] += 1

    def create(self, tao):
        self._log('create')
        # Placeholder so destroy() always has a Vec to release, even when
        # setUp() never ran (e.g. a test that never solves).
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        self._log('setFromOptions')

    def setUp(self, tao):
        self._log('setUp')
        # Replace the placeholder with a work vector shaped like the solution.
        self.testvec = tao.getSolution().duplicate()

    def solve(self, tao):
        # Custom solve that performs no iterations; tests assign
        # ``ctx.solve = None`` to fall back to the default solve loop.
        self._log('solve')

    def step(self, tao, x, g, s):
        """Compute the search direction ``s = -g``, where ``g`` is grad f(x).

        ``g`` is overwritten with the gradient at ``x``; ``s`` is overwritten
        with its negation.
        """
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        self._log('preStep')

    def postStep(self, tao):
        self._log('postStep')

    def monitor(self, tao):
        self._log('monitor')


class TestTaoPython(unittest.TestCase):
    """Exercise a Python-type TAO backed by a ``MyTao`` context."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # Expected references: the TAO object, the local ``ctx``, and
        # getrefcount's own argument.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertEqual(ctx.log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        ctx = self.tao.getPythonContext()
        # An extra reference here (e.g. a monitor still installed) means a
        # test leaked a reference to the context.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        self.assertEqual(ctx.log['destroy'], 1)
        # After destroy() only ``ctx`` and getrefcount's argument remain.
        self.assertEqual(getrefcount(ctx), 2)

    def testGetType(self):
        ctx = self.tao.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        # The bound method ``ctx.monitor`` holds a reference to ``ctx``; it
        # must be removed via cancelMonitor() before tearDown() checks the
        # context's refcount.
        tao.setMonitor(ctx.monitor)
        try:
            tao.setFromOptions()
            tao.setMaximumIterations(3)
            tao.setSolution(x)

            # Call the solve method of MyTao (performs no iterations)
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 0)

            # Call the default solve method and use step of MyTao
            ctx.solve = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y1)

            # Call the default solve method with the default step method
            ctx.step = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y2)

            # MyTao.step is expected to reproduce the default step exactly.
            self.assertTrue(y1.equal(y2))
            # Counts accumulate over the two default-solve runs above.
            self.assertEqual(ctx.log['monitor'], 2*(n+1))
            self.assertEqual(ctx.log['preStep'], 2*n)
            self.assertEqual(ctx.log['postStep'], 2*n)
            self.assertEqual(ctx.log['solve'], 1)
            self.assertEqual(ctx.log['setUp'], 1)
            self.assertEqual(ctx.log['setFromOptions'], 1)
            # Only the second solve used MyTao.step.
            self.assertEqual(ctx.log['step'], n)
        finally:
            # Always drop the monitor's reference to ``ctx``; otherwise a
            # failed assertion above would make tearDown() fail its refcount
            # check before destroying the TAO.
            tao.cancelMonitor()

# --------------------------------------------------------------------

import numpy
if numpy.iscomplexobj(PETSc.ScalarType()):
    # Not run for complex-scalar PETSc builds.
    del TestTaoPython

if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------