"""Tests for a TAO solver implemented in Python via ``TAO.createPython``."""

import gc
import unittest
from sys import getrefcount

import numpy
from petsc4py import PETSc

# --------------------------------------------------------------------


class Objective:
    """Objective callback f(x) = (x0 - 1)**2 + (x1 - 2)**2 (minimum at (1, 2))."""

    def __call__(self, tao, x):
        return (x[0] - 1.0)**2 + (x[1] - 2.0)**2


class Gradient:
    """Gradient callback: writes (2*(x0 - 1), 2*(x1 - 2)) into ``g``."""

    def __call__(self, tao, x, g):
        g[0] = 2.0*(x[0] - 1.0)
        g[1] = 2.0*(x[1] - 2.0)
        g.assemble()


class MyTao:
    """Python context for a TAO of Python type that records every hook call.

    Each hook increments ``self.log[<hook name>]``, so the test can check
    how many times each hook was invoked.
    """

    def __init__(self):
        self.log = {}

    def _log(self, method):
        # Increment the call counter for `method`, starting from zero.
        self.log.setdefault(method, 0)
        self.log[method] += 1

    def create(self, tao):
        self._log('create')
        # Empty placeholder handle; replaced by a real vector in setUp().
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        self._log('setFromOptions')

    def setUp(self, tao):
        self._log('setUp')
        # Work vector with the same layout as the current solution vector.
        self.testvec = tao.getSolution().duplicate()

    def solve(self, tao):
        # Custom solve that only logs: testSolve expects 0 iterations from it.
        self._log('solve')

    def step(self, tao, x, g, s):
        # Steepest-descent step direction: s = -grad f(x).
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        self._log('preStep')

    def postStep(self, tao):
        self._log('postStep')

    def monitor(self, tao):
        self._log('monitor')


class TestTaoPython(unittest.TestCase):

    # NOTE(review): the exact getrefcount() values below are CPython-specific
    # and may differ on other interpreters or future CPython versions.

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # 3 = the reference held through the TAO + local `ctx` +
        # getrefcount()'s own argument.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertEqual(ctx.log['create'], 1)

    def tearDown(self):
        ctx = self.tao.getPythonContext()
        # One more than in setUp: presumably the bound method `ctx.monitor`
        # registered via tao.setMonitor() in testSolve.
        self.assertEqual(getrefcount(ctx), 4)
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        self.assertEqual(ctx.log['destroy'], 1)
        # After destruction only local `ctx` and getrefcount()'s argument remain.
        self.assertEqual(getrefcount(ctx), 2)

    def testSolve(self):
        """Run the custom solve, then the default solve with custom/default step."""
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setMonitor(ctx.monitor)
        tao.setFromOptions()
        tao.setMaximumIterations(3)
        tao.setSolution(x)

        # Call the solve method of MyTao (a no-op): no iterations are taken.
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 0)

        # Call the default solve method and use the step method of MyTao.
        # Setting a hook to None selects the default implementation.
        ctx.solve = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y1)

        # Call the default solve method with the default step method.
        ctx.step = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y2)

        # MyTao.step and the default step must produce the same iterates.
        self.assertTrue(y1.equal(y2))
        # Hook counts accumulated over all three solves (n == 3 here).
        self.assertEqual(ctx.log['monitor'], 2*(n+1))
        self.assertEqual(ctx.log['preStep'], 2*n)
        self.assertEqual(ctx.log['postStep'], 2*n)
        self.assertEqual(ctx.log['solve'], 1)
        self.assertEqual(ctx.log['setUp'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        # Only the second solve used MyTao.step.
        self.assertEqual(ctx.log['step'], n)

# --------------------------------------------------------------------


# The test is not collected in complex-scalar PETSc builds.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestTaoPython

if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------