"""Tests for petsc4py TAO solvers backed by a Python context object."""
import unittest
from petsc4py import PETSc
from sys import getrefcount
import numpy


# --------------------------------------------------------------------
class Objective:
    """Objective f(x) = (x[0] - 1)^2 + (x[1] - 2)^2, minimized at (1, 2)."""

    def __call__(self, tao, x):
        return (x[0] - 1.0) ** 2 + (x[1] - 2.0) ** 2


class Gradient:
    """Gradient of ``Objective``, written into the vector ``g``."""

    def __call__(self, tao, x, g):
        g[0] = 2.0 * (x[0] - 1.0)
        g[1] = 2.0 * (x[1] - 2.0)
        g.assemble()


class MyTao:
    """Python TAO context that counts how often each hook is invoked.

    Counts are kept in ``self.log``, keyed by hook name.
    """

    def __init__(self):
        self.log = {}

    def _log(self, method):
        # Increment the call counter for ``method``.
        self.log.setdefault(method, 0)
        self.log[method] += 1

    def create(self, tao):
        self._log('create')
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        self._log('setFromOptions')

    def setUp(self, tao):
        self._log('setUp')
        self.testvec = tao.getSolution().duplicate()

    def solve(self, tao):
        # Custom solve that performs no iterations; the test later
        # shadows it with ``None`` to exercise the builtin solve loop.
        self._log('solve')

    def step(self, tao, x, g, s):
        # Steepest-descent direction: s = -grad f(x).
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        self._log('preStep')

    def postStep(self, tao):
        self._log('postStep')

    def monitor(self, tao):
        self._log('monitor')


class TestTaoPython(unittest.TestCase):
    """Exercise the hooks of a Python-typed TAO using ``MyTao``."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        # The TAO should hold the only long-lived reference to the context
        # (getrefcount counts its own argument as one more).
        self.assertEqual(getrefcount(self._getCtx()), 2)
        self.assertEqual(self._getCtx().log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        self.assertEqual(getrefcount(self._getCtx()), 2)
        self.assertNotIn('destroy', self._getCtx().log)
        # Keep our own reference so the log can be inspected after destroy.
        ctx = self._getCtx()
        self.tao.destroy()
        self.tao = None
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)

    def testGetType(self):
        ctx = self.tao.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setMonitor(ctx.monitor)
        tao.setFromOptions()
        tao.setMaximumIterations(3)

        def _update(tao, it, cnt):
            # ``cnt`` is a 0-d numpy array, so ``+=`` updates it in place
            # and the count is visible to the caller.
            cnt += 1

        cnt_up = numpy.array(0)
        tao.setUpdate(_update, (cnt_up,))
        tao.setSolution(x)

        # Call the solve method of MyTAO
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 0)

        # Call the default solve method and use step of MyTAO
        ctx.solve = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertGreater(tao.getConvergedReason(), 0)
        self.assertIn(n, [2, 3])
        self.assertAlmostEqual(x[0], 1.0)
        self.assertAlmostEqual(x[1], 2.0)

        # Call the default solve method with the default step method
        ctx.step = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertGreater(tao.getConvergedReason(), 0)
        self.assertIn(n, [2, 3])
        self.assertAlmostEqual(x[0], 1.0)
        self.assertAlmostEqual(x[1], 2.0)

        # Counts accumulate over the two builtin solves above; ``step`` was
        # only used by the first of them.
        self.assertTrue(y1.equal(y2))
        self.assertEqual(ctx.log['monitor'], 2 * (n + 1))
        self.assertEqual(ctx.log['preStep'], 2 * n)
        self.assertEqual(ctx.log['postStep'], 2 * n)
        self.assertEqual(ctx.log['solve'], 1)
        self.assertEqual(ctx.log['setUp'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        self.assertEqual(ctx.log['step'], n)
        self.assertEqual(cnt_up, 2 * n)
        tao.cancelMonitor()

    def _getCtx(self):
        return self.tao.getPythonContext()


class MyGradientDescent:
    """Python TAO context implementing its own gradient-descent loop.

    Each iteration moves along -grad f using a UNIT line search with an
    initial step length of 0.2. The builtin step hooks must never be called.
    """

    def __init__(self):
        self._ls = None

    def create(self, tao):
        self._ls = PETSc.TAOLineSearch().create(comm=PETSc.COMM_SELF)
        self._ls.useTAORoutine(tao)
        self._ls.setType(PETSc.TAOLineSearch.Type.UNIT)
        self._ls.setInitialStepLength(0.2)

    def destroy(self, tao):
        # Guard against destroy running without a prior successful create.
        if self._ls is not None:
            self._ls.destroy()
            self._ls = None

    def setUp(self, tao):
        pass

    def solve(self, tao):
        x = tao.getSolution()
        gradient = tao.getGradient()[0]
        search_direction = gradient.copy()
        for it in range(tao.getMaximumIterations()):
            tao.setIterationNumber(it)

            # search_direction = -gradient
            tao.computeGradient(x, gradient)
            gradient.copy(search_direction)
            search_direction.scale(-1)

            # x = x + .2 search_direction (the step length is not needed)
            f, _, reason = self._ls.apply(x, gradient, search_direction)

            tao.monitor(f=f, res=gradient.norm())

            if reason < 0:
                raise RuntimeError('LS failed.')

            if tao.checkConverged() > 0:
                break

    def step(self, tao, x, g, s):
        raise RuntimeError('Should only be called by builtin solve.')

    def preStep(self, tao):
        raise RuntimeError('Should only be called by builtin solve.')

    def postStep(self, tao):
        raise RuntimeError('Should only be called by builtin solve.')


class TestTaoPythonOptimiser(unittest.TestCase):
    """Run a full optimisation with the Python-implemented solve loop."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyGradientDescent(), comm=PETSc.COMM_SELF)

    def tearDown(self):
        self.tao.destroy()
        self.tao = None
        # Same cleanup as TestTaoPython.tearDown.
        PETSc.garbage_cleanup()

    def testSolve(self):
        tao = self.tao

        # Options are set under a test-specific prefix so they do not
        # collide with other tests' options.
        opts = PETSc.Options('test_tao_python_optimiser_')
        opts['tao_max_it'] = 100
        opts['tao_gatol'] = 1e-6

        tao.setOptionsPrefix('test_tao_python_optimiser_')
        tao.setFromOptions()

        x = PETSc.Vec().createSeq(2, comm=tao.getComm())
        x.set(0.5)

        tao.setSolution(x)
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), x.copy())

        tao.solve()

        self.assertEqual(tao.getMaximumIterations(), 100)
        self.assertAlmostEqual(tao.getTolerances()[0], 1e-6)
        self.assertGreater(tao.getIterationNumber(), 0)
        self.assertGreater(tao.getConvergedReason(), 0)
        self.assertAlmostEqual(x[0], 1.0, places=5)
        self.assertAlmostEqual(x[1], 2.0, places=5)
        self.assertGreater(tao.getObjectiveValue(), 0)
        self.assertAlmostEqual(tao.getObjectiveValue(), 0, places=5)


# --------------------------------------------------------------------

# These tests use real-valued arithmetic only; skip them in complex builds.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestTaoPython
    del TestTaoPythonOptimiser

if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------