"""Tests for a petsc4py TAO whose implementation is a Python object."""

import unittest
from petsc4py import PETSc
from sys import getrefcount
import numpy


# --------------------------------------------------------------------
class Objective:
    """Objective f(x) = (x[0] - 1)**2 + (x[1] - 2)**2, minimized at (1, 2)."""

    def __call__(self, tao, x):
        return (x[0] - 1.0) ** 2 + (x[1] - 2.0) ** 2


class Gradient:
    """Gradient of ``Objective``: writes (2*(x[0]-1), 2*(x[1]-2)) into ``g``."""

    def __call__(self, tao, x, g):
        g[0] = 2.0 * (x[0] - 1.0)
        g[1] = 2.0 * (x[1] - 2.0)
        # Finalize the entries set above before the vector is used.
        g.assemble()


class MyTao:
    """Python context for a TAO created with ``createPython``.

    Every hook records how many times it was invoked in ``self.log``
    (hook name -> call count) so the tests can check which callbacks
    the TAO dispatched to this object.
    """

    def __init__(self):
        self.log = {}

    def _log(self, method):
        """Increment the call counter for hook ``method``."""
        self.log.setdefault(method, 0)
        self.log[method] += 1

    def create(self, tao):
        self._log('create')
        # Bare Vec handle; replaced by a real vector in setUp().
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        self._log('setFromOptions')

    def setUp(self, tao):
        self._log('setUp')
        self.testvec = tao.getSolution().duplicate()

    def solve(self, tao):
        # Custom solve that performs no iterations; testSolve expects
        # getIterationNumber() == 0 while this hook is active.
        self._log('solve')

    def step(self, tao, x, g, s):
        """Compute the gradient ``g`` at ``x`` and set the direction s = -g."""
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        self._log('preStep')

    def postStep(self, tao):
        self._log('postStep')

    def monitor(self, tao):
        self._log('monitor')


class TestTaoPython(unittest.TestCase):
    """Exercise a TAO whose implementation is delegated to ``MyTao``."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        # Only the TAO should reference the context; getrefcount() also
        # counts its own argument, hence 2.
        self.assertEqual(getrefcount(self._getCtx()), 2)
        self.assertEqual(self._getCtx().log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        # No extra reference to the context may have leaked from the test.
        self.assertEqual(getrefcount(self._getCtx()), 2)
        self.assertNotIn('destroy', self._getCtx().log)
        ctx = self._getCtx()
        self.tao.destroy()
        self.tao = None
        # NOTE(review): presumably flushes deferred PETSc object destruction
        # so the destroy hook has run before the check below.
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)

    def testGetType(self):
        ctx = self.tao.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setMonitor(ctx.monitor)
        tao.setFromOptions()
        tao.setMaximumIterations(3)

        def _update(tao, it, cnt):
            # cnt is a 0-d numpy array: `+=` mutates it in place, so the
            # caller's cnt_up sees every increment.
            cnt += 1

        cnt_up = numpy.array(0)
        tao.setUpdate(_update, (cnt_up,))
        tao.setSolution(x)

        # Call the solve method of MyTao
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 0)

        # Call the default solve method and use step of MyTao
        ctx.solve = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertIn(n, [2, 3])
        x.copy(y1)

        # Call the default solve method with the default step method
        ctx.step = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertIn(n, [2, 3])
        x.copy(y2)

        # MyTao.step (s = -g) and the default step must reach the same point.
        self.assertTrue(y1.equal(y2))
        self.assertEqual(ctx.log['monitor'], 2 * (n + 1))
        self.assertEqual(ctx.log['preStep'], 2 * n)
        self.assertEqual(ctx.log['postStep'], 2 * n)
        self.assertEqual(ctx.log['solve'], 1)
        self.assertEqual(ctx.log['setUp'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        self.assertEqual(ctx.log['step'], n)
        self.assertEqual(cnt_up, 2 * n)
        # Drop the monitor's bound-method reference to ctx so the refcount
        # check in tearDown() sees only the TAO's reference.
        tao.cancelMonitor()

    def _getCtx(self):
        """Return the Python context attached to ``self.tao``."""
        return self.tao.getPythonContext()


# --------------------------------------------------------------------

# With complex PETSc scalars the test case is removed from the module,
# so the test runner does not collect it.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestTaoPython

if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------