"""Tests for the petsc4py ``PETSc.TAO`` "python" type.

A pure-Python context (``MyTao``) is attached to a TAO solver.  The tests
check which of its hooks are invoked, how often, and that the reference
count of the context is balanced across create/destroy.
"""

import unittest
from petsc4py import PETSc
from sys import getrefcount

# --------------------------------------------------------------------


class Objective:
    """Objective f(x) = (x[0] - 1)**2 + (x[1] - 2)**2, minimized at (1, 2)."""

    def __call__(self, tao, x):
        return (x[0] - 1.0)**2 + (x[1] - 2.0)**2


class Gradient:
    """Gradient of ``Objective``, written into ``g`` in place."""

    def __call__(self, tao, x, g):
        g[0] = 2.0*(x[0] - 1.0)
        g[1] = 2.0*(x[1] - 2.0)
        # Entries were set individually; assemble before the Vec is used.
        g.assemble()


class MyTao:
    """Python context for a TAO of type "python".

    Every hook records its invocation in ``self.log`` (a dict mapping the
    hook name to its call count) so the tests can check which hooks the
    solver actually called.
    """

    def __init__(self):
        self.log = {}

    def _log(self, method):
        # Increment the call counter for ``method``, starting from 0.
        self.log.setdefault(method, 0)
        self.log[method] += 1

    def create(self, tao):
        self._log('create')
        # Empty placeholder; replaced by a real work vector in setUp().
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        self._log('setFromOptions')

    def setUp(self, tao):
        self._log('setUp')
        # Work vector with the same layout as the solution vector.
        self.testvec = tao.getSolution().duplicate()

    def solve(self, tao):
        # Deliberately performs no iterations: testSolve expects the first
        # solve (which goes through this hook) to report zero iterations.
        self._log('solve')

    def step(self, tao, x, g, s):
        # Steepest-descent direction: s = -grad f(x).
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        self._log('preStep')

    def postStep(self, tao):
        self._log('postStep')

    def monitor(self, tao):
        self._log('monitor')


class TestTaoPython(unittest.TestCase):
    """Exercise the TAO "python" type with a ``MyTao`` context."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # References: the local ``ctx``, getrefcount's own argument, and
        # presumably the one held by the TAO (tearDown checks that this
        # last one is released on destroy).
        self.assertEqual(getrefcount(ctx), 3)
        self.assertEqual(ctx.log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        ctx = self.tao.getPythonContext()
        # The test body must not have left any extra reference to ctx.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)
        # Only the local and getrefcount's argument remain.
        self.assertEqual(getrefcount(ctx), 2)

    def testGetType(self):
        ctx = self.tao.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setMonitor(ctx.monitor)
        # ``ctx.monitor`` is a bound method and therefore references ctx.
        # Always cancel it, even if an assertion below fails; otherwise
        # tearDown's refcount check fails too (masking the real error) and
        # returns before the TAO is destroyed.
        try:
            tao.setFromOptions()
            tao.setMaximumIterations(3)
            tao.setSolution(x)

            # Call the solve method of MyTao.
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 0)

            # Call the default solve method and use the step of MyTao
            # (clearing the hook selects the default solve).
            ctx.solve = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y1)

            # Call the default solve method with the default step method.
            ctx.step = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y2)

            # MyTao.step and the default step must reach the same iterate.
            self.assertTrue(y1.equal(y2))
            # monitor/preStep/postStep only fire in the two default solves.
            self.assertEqual(ctx.log['monitor'], 2*(n+1))
            self.assertEqual(ctx.log['preStep'], 2*n)
            self.assertEqual(ctx.log['postStep'], 2*n)
            self.assertEqual(ctx.log['solve'], 1)
            self.assertEqual(ctx.log['setUp'], 1)
            self.assertEqual(ctx.log['setFromOptions'], 1)
            # MyTao.step was used in exactly one of the solves.
            self.assertEqual(ctx.log['step'], n)
        finally:
            tao.cancelMonitor()

# --------------------------------------------------------------------


import numpy
if numpy.iscomplexobj(PETSc.ScalarType()):
    # These tests are only run for real-scalar PETSc builds.
    del TestTaoPython

if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------