# xref: /petsc/src/binding/petsc4py/test/test_tao.py (revision 6327175cd4e6f53ea7dcf4f5513d67577e46142b)
# --------------------------------------------------------------------

from math import sqrt
import unittest

import numpy
from petsc4py import PETSc


# --------------------------------------------------------------------
class Objective:
    """Objective callback: (x0 - 2)^2 + (x1 - 2)^2 - 2 * (x0 + x1)."""

    def __call__(self, tao, x):
        """Return the objective value at ``x``.

        ``tao`` is part of the callback signature but is not used here;
        only ``x[0]`` and ``x[1]`` are read.
        """
        return (x[0] - 2.0) ** 2 + (x[1] - 2.0) ** 2 - 2.0 * (x[0] + x[1])


class Gradient:
    """Gradient callback matching ``Objective``."""

    def __call__(self, tao, x, g):
        """Write the gradient at ``x`` into ``g`` and assemble ``g``.

        ``tao`` is part of the callback signature but is not used here.
        Nothing is returned; the result is stored in ``g[0]`` and ``g[1]``.
        """
        g[0] = 2.0 * (x[0] - 2.0) - 2.0
        g[1] = 2.0 * (x[1] - 2.0) - 2.0
        g.assemble()


class EqConstraints:
    """Equality-constraint callback: c0 = x0^2 + x1 - 2."""

    def __call__(self, tao, x, c):
        """Write the constraint value at ``x`` into ``c[0]`` and assemble ``c``.

        ``tao`` is part of the callback signature but is not used here.
        """
        c[0] = x[0] ** 2 + x[1] - 2.0
        c.assemble()


class EqJacobian:
    """Jacobian callback for ``EqConstraints``: the row [2 * x0, 1]."""

    def __call__(self, tao, x, J, P):
        """Fill and assemble ``P``; assemble ``J`` too when it differs from ``P``.

        ``tao`` is part of the callback signature but is not used here.
        Entries are only written into ``P``; ``J`` is never written to, it
        is merely assembled when ``J != P``.
        """
        P[0, 0] = 2.0 * x[0]
        P[0, 1] = 1.0
        P.assemble()
        if J != P:
            J.assemble()


class InEqConstraints:
    """Inequality-constraint callback: c0 = x1 - x0^2."""

    def __call__(self, tao, x, c):
        """Write the constraint value at ``x`` into ``c[0]`` and assemble ``c``.

        ``tao`` is part of the callback signature but is not used here.
        """
        c[0] = x[1] - x[0] ** 2
        c.assemble()


class InEqJacobian:
    """Jacobian callback for ``InEqConstraints``: the row [-2 * x0, 1]."""

    def __call__(self, tao, x, J, P):
        """Fill and assemble ``P``; assemble ``J`` too when it differs from ``P``.

        ``tao`` is part of the callback signature but is not used here.
        Entries are only written into ``P``; ``J`` is never written to, it
        is merely assembled when ``J != P``.
        """
        P[0, 0] = -2.0 * x[0]
        P[0, 1] = 1.0
        P.assemble()
        if J != P:
            J.assemble()


class BaseTestTAO:
    """TAO test cases shared by the concrete ``unittest.TestCase`` classes.

    The class itself does not derive from ``unittest.TestCase``; subclasses
    combine it with ``TestCase`` and set ``COMM``.
    """

    # Communicator passed to ``PETSc.TAO().create``; set by subclasses.
    COMM = None

    def setUp(self):
        # Every test starts from a freshly created solver.
        self.tao = PETSc.TAO().create(comm=self.COMM)

    def tearDown(self):
        # Drop our reference before asking PETSc to clean up its garbage.
        self.tao = None
        PETSc.garbage_cleanup()

    def testSetRoutinesToNone(self):
        # Every callback setter must accept ``None``.
        tao = self.tao
        objective, gradient, objgrad = None, None, None
        constraint, varbounds = None, None
        hessian, jacobian = None, None
        tao.setObjective(objective)
        tao.setGradient(gradient, None)
        tao.setVariableBounds(varbounds)
        tao.setObjectiveGradient(objgrad, None)
        tao.setConstraints(constraint)
        tao.setHessian(hessian)
        tao.setJacobian(jacobian)

    def testGetVecsAndMats(self):
        # On a solver with nothing set, every retrieved object must be falsy.
        tao = self.tao
        x = tao.getSolution()
        (g, _) = tao.getGradient()
        low, up = tao.getVariableBounds()
        r = None  # tao.getConstraintVec()
        H, HP = None, None  # tao.getHessianMat()
        J, JP = None, None  # tao.getJacobianMat()
        for o in [
            x,
            g,
            r,
            low,
            up,
            H,
            HP,
            J,
            JP,
        ]:
            self.assertFalse(o)

    def testGetKSP(self):
        ksp = self.tao.getKSP()
        self.assertFalse(ksp)

    def testEqualityConstraints(self):
        # Only exercised on a single-rank communicator; otherwise the test
        # returns without asserting anything.
        if self.tao.getComm().Get_size() > 1:
            return
        tao = self.tao

        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        c = PETSc.Vec().create(tao.getComm())
        c.setSizes(1)
        c.setType(x.getType())
        J = PETSc.Mat().create(tao.getComm())
        J.setSizes([1, 2])
        J.setType(PETSc.Mat.Type.DENSE)
        J.setUp()

        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setEqualityConstraints(EqConstraints(), c)
        # The same matrix is passed as Jacobian and as preconditioner.
        tao.setJacobianEquality(EqJacobian(), J, J)
        tao.setSolution(x)
        tao.setType(PETSc.TAO.Type.ALMM)
        tao.setALMMType(PETSc.TAO.ALMMType.PHR)
        tao.setTolerances(gatol=1.0e-4)
        tao.setFromOptions()
        tao.solve()
        self.assertEqual(tao.getALMMType(), PETSc.TAO.ALMMType.PHR)
        # The constraint residual and the solution are checked to 4 places.
        self.assertAlmostEqual(abs(x[0] ** 2 + x[1] - 2.0), 0.0, places=4)
        self.assertAlmostEqual(x[0], 0.7351392590499015014254200465, places=4)
        self.assertAlmostEqual(x[1], 1.4595702698035618134357683666, places=4)
        self.assertIsNotNone(tao.getObjective())

        # Re-evaluate the registered callback (callable, args, kwargs) and
        # compare against the constraint vector returned by the solver.
        c, g = tao.getEqualityConstraints()
        c_eval = c.copy()
        g[0](tao, x, c_eval, *g[1], **g[2])
        self.assertTrue(c.equal(c_eval))

        J, Jpre, Jg = tao.getJacobianEquality()
        Jg[0](tao, x, J, Jpre, *Jg[1], **Jg[2])
        self.assertTrue(J.equal(Jpre))

    # NOTE: the method name is misspelt ("Inequlity"); it is kept as is so
    # the test id does not change.
    def testInequlityConstraints(self):
        # Only exercised on a single-rank communicator; otherwise the test
        # returns without asserting anything.
        if self.tao.getComm().Get_size() > 1:
            return
        tao = self.tao

        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        c = PETSc.Vec().create(tao.getComm())
        c.setSizes(1)
        c.setType(x.getType())
        J = PETSc.Mat().create(tao.getComm())
        J.setSizes([1, 2])
        J.setType(PETSc.Mat.Type.DENSE)
        J.setUp()

        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setInequalityConstraints(InEqConstraints(), c)
        # The same matrix is passed as Jacobian and as preconditioner.
        tao.setJacobianInequality(InEqJacobian(), J, J)
        tao.setSolution(x)
        tao.setType(PETSc.TAO.Type.ALMM)
        tao.setALMMType(PETSc.TAO.ALMMType.CLASSIC)
        tao.setTolerances(gatol=1.0e-4)
        tao.setFromOptions()
        tao.solve()

        self.assertEqual(tao.getALMMType(), PETSc.TAO.ALMMType.CLASSIC)
        # The inequality may be violated by at most the tolerance 1e-4.
        self.assertGreaterEqual(x[1] - x[0] ** 2, -1.0e-4)
        self.assertAlmostEqual(x[0], 0.5 + sqrt(7) / 2, places=4)
        self.assertAlmostEqual(x[1], 2 + sqrt(7) / 2, places=4)

        # Re-evaluate the registered callback (callable, args, kwargs) and
        # compare against the constraint vector returned by the solver.
        c, h = tao.getInequalityConstraints()
        c_eval = c.copy()
        h[0](tao, x, c_eval, *h[1], **h[2])
        self.assertTrue(c.equal(c_eval))

        J, Jpre, Jh = tao.getJacobianInequality()
        Jh[0](tao, x, J, Jpre, *Jh[1], **Jh[2])
        self.assertTrue(J.equal(Jpre))

    def testBNCG(self):
        # Only exercised on a single-rank communicator; otherwise the test
        # returns without asserting anything.
        if self.tao.getComm().Get_size() > 1:
            return
        tao = self.tao

        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        # Bounds: 0 <= x_i <= 2.
        xl = PETSc.Vec().create(tao.getComm())
        xl.setType('standard')
        xl.setSizes(2)
        xl.set(0.0)
        xu = PETSc.Vec().create(tao.getComm())
        xu.setType('standard')
        xu.setSizes(2)
        xu.set(2.0)
        tao.setVariableBounds((xl, xu))
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setSolution(x)
        tao.setType(PETSc.TAO.Type.BNCG)
        tao.setTolerances(gatol=1.0e-4)
        ls = tao.getLineSearch()
        ls.setType(PETSc.TAOLineSearch.Type.UNIT)
        tao.setFromOptions()
        tao.solve()
        self.assertAlmostEqual(x[0], 2.0, places=4)
        self.assertAlmostEqual(x[1], 2.0, places=4)

    def templateBQNLS(self, lmvm_setup):
        # Shared body of the BQNLS tests. ``lmvm_setup`` selects how the
        # matrix ``H`` is built and handed to the LMVM matrix:
        # 'dense', 'ksp' or 'diagonal'.
        # Only exercised on a single-rank communicator; otherwise the test
        # returns without asserting anything.
        if self.tao.getComm().Get_size() > 1:
            return
        tao = self.tao

        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        # Bounds: 0 <= x_i <= 2.
        xl = PETSc.Vec().create(tao.getComm())
        xl.setType('standard')
        xl.setSizes(2)
        xl.set(0.0)
        xu = PETSc.Vec().create(tao.getComm())
        xu.setType('standard')
        xu.setSizes(2)
        xu.set(2.0)
        tao.setVariableBounds((xl, xu))
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setSolution(x)
        tao.setType(PETSc.TAO.Type.BQNLS)
        tao.setTolerances(gatol=1.0e-4)

        # Build H = diag(2, 2), either as a dense or as a diagonal matrix.
        H = PETSc.Mat()
        if lmvm_setup in ('dense', 'ksp'):
            H.createDense((2, 2), comm=tao.getComm())
            H[0, 0] = 2
            H[0, 1] = 0
            H[1, 0] = 0
            H[1, 1] = 2
            H.assemble()
        elif lmvm_setup == 'diagonal':
            H_vec = PETSc.Vec().createSeq(2)
            H_vec[0] = 2
            H_vec[1] = 2
            H_vec.assemble()
            H.createDiagonal(H_vec)
            H.assemble()

        # Hand H to the LMVM matrix, directly as J0 or through a CG KSP.
        if lmvm_setup in ('dense', 'diagonal'):
            tao.getLMVMMat().setLMVMJ0(H)
        elif lmvm_setup == 'ksp':
            lmvm_ksp = PETSc.KSP().create(tao.getComm())
            lmvm_ksp.setType(PETSc.KSP.Type.CG)
            lmvm_ksp.setOperators(H)
            tao.getLMVMMat().setLMVMJ0KSP(lmvm_ksp)

        tao.setFromOptions()
        tao.solve()
        # The iteration count is only pinned for the dense setup.
        if lmvm_setup == 'dense':
            self.assertEqual(tao.getIterationNumber(), 1)
        self.assertAlmostEqual(x[0], 2.0, places=4)
        self.assertAlmostEqual(x[1], 2.0, places=4)

        # What was installed must be what the LMVM matrix reports back.
        if lmvm_setup in ('dense', 'diagonal'):
            self.assertTrue(tao.getLMVMMat().getLMVMJ0().equal(H))
        elif lmvm_setup == 'ksp':
            self.assertEqual(
                tao.getLMVMMat().getLMVMJ0KSP().getType(), PETSc.KSP.Type.CG
            )

    def testBQNLS_dense(self):
        self.templateBQNLS('dense')

    def testBQNLS_ksp(self):
        self.templateBQNLS('ksp')

    def testBQNLS_diagonal(self):
        self.templateBQNLS('diagonal')


# --------------------------------------------------------------------


class TestTAOSelf(BaseTestTAO, unittest.TestCase):
    """Run the ``BaseTestTAO`` cases with ``PETSc.COMM_SELF``."""

    COMM = PETSc.COMM_SELF


class TestTAOWorld(BaseTestTAO, unittest.TestCase):
    """Run the ``BaseTestTAO`` cases with ``PETSc.COMM_WORLD``."""

    COMM = PETSc.COMM_WORLD


# --------------------------------------------------------------------


# With a complex scalar type the test classes are removed from the module
# namespace, so none of the tests above are defined.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del BaseTestTAO
    del TestTAOSelf
    del TestTAOWorld

if __name__ == '__main__':
    unittest.main()