# xref: /petsc/src/binding/petsc4py/demo/legacy/taosolve/chwirut.py (revision 5a48edb989d3ea10d6aff6c0e26d581c18691deb)
1import sys, petsc4py
2petsc4py.init(sys.argv)
3
4import numpy as np
5from petsc4py import PETSc
6
class Chwirut(object):

    """
    Finds the nonlinear least-squares solution to the model

        y = exp(-b1*x)/(b2+b3*x)  +  e

    Synthetic observations (x, y) are generated when the object is
    constructed; the remaining methods are the callbacks and helpers
    a TAO least-squares solver needs (work vectors, initial guess,
    residual evaluation) plus an optional plot of the fitted curve.
    """

    def __init__(self, beta=(0.2, 0.12, 0.08), nobservations=100, seed=456):
        """
        Generate the synthetic data set.

        Parameters
        ----------
        beta : sequence of three floats
            Coefficients (b1, b2, b3) used to generate the observations.
        nobservations : int
            Number of (x, y) samples to generate.
        seed : int
            Seed of the private random stream.  The default, 456, yields
            the same data as seeding NumPy's global generator with 456.

        Raises
        ------
        ValueError
            If ``beta`` does not contain exactly three coefficients.
        """
        NPARAMETERS = 3
        beta = tuple(beta)
        if len(beta) != NPARAMETERS:
            raise ValueError(
                "beta must hold exactly %d coefficients, got %d"
                % (NPARAMETERS, len(beta)))
        b1, b2, b3 = beta

        # A private RandomState keeps the data reproducible without
        # reseeding NumPy's process-wide generator as a side effect.
        rng = np.random.RandomState(seed)
        x = rng.rand(nobservations)
        # Noise is uniform on [0, 1), i.e. it is not zero-mean.
        # NOTE(review): the fit is therefore not expected to recover
        # ``beta`` exactly -- confirm this is intended for the demo.
        e = rng.rand(nobservations)

        y = np.exp(-b1*x)/(b2 + b3*x) + e

        self.NOBSERVATIONS = nobservations
        self.NPARAMETERS = NPARAMETERS
        self.x = x
        self.y = y

    def createVecs(self):
        """
        Create the PETSc work vectors on ``PETSc.COMM_SELF``.

        Returns
        -------
        (X, F)
            ``X`` is sized for the NPARAMETERS unknowns and ``F`` for the
            NOBSERVATIONS residuals.  Only the sizes are set here; the
            caller is expected to finish configuring the vectors (for
            example with ``setFromOptions()``) before using them.
        """
        X = PETSc.Vec().create(PETSc.COMM_SELF)
        X.setSizes(self.NPARAMETERS)
        F = PETSc.Vec().create(PETSc.COMM_SELF)
        F.setSizes(self.NOBSERVATIONS)
        return X, F

    def formInitialGuess(self, X):
        """Write the starting point (0.15, 0.08, 0.05) into ``X``."""
        X[0] = 0.15
        X[1] = 0.08
        X[2] = 0.05
        # PETSc requires an assembly call after inserting values by index.
        X.assemble()

    def formResidual(self, tao, X, F):
        """
        Residual callback: F = y - exp(-b1*x)/(b2 + b3*x).

        ``tao`` is the solver passed in by the caller and is not used.
        The three coefficients are read from ``X.array`` and the residual
        is stored through ``F.array``.
        """
        x, y = self.x, self.y
        b1, b2, b3 = X.array
        F.array = y - np.exp(-b1*x)/(b2 + b3*x)

    def plotSolution(self, X):
        """
        Plot the observations (red dots) and the fitted curve (blue line).

        The coefficients are read from ``X.array``.  If matplotlib cannot
        be imported a warning is issued and nothing is plotted.
        """
        try:
            from matplotlib import pyplot
        except ImportError:
            # The plot was requested explicitly, so report why it is
            # skipped instead of returning silently.
            import warnings
            warnings.warn("matplotlib is not available; skipping plot")
            return
        b1, b2, b3 = X.array
        x, y = self.x, self.y
        u = np.linspace(x.min(), x.max(), 100)
        v = np.exp(-b1*u)/(b2+b3*u)
        pyplot.plot(x, y, 'ro')
        pyplot.plot(u, v, 'b-')
        pyplot.show()
59
# Problem data and the PETSc work vectors (unknowns x, residuals f).
user = Chwirut()

x, f = user.createVecs()
x.setFromOptions()
f.setFromOptions()

# POUNDERS solver driven by the residual callback; setFromOptions() is
# called last so command-line options are applied after the type is set.
tao = PETSc.TAO().create(PETSc.COMM_SELF)
tao.setType(PETSc.TAO.Type.POUNDERS)
tao.setResidual(user.formResidual, f)
tao.setFromOptions()

# Solve starting from the problem's initial guess; x holds the result.
user.formInitialGuess(x)
tao.solve(x)

# Plot only when the 'plot' option is set (default: off).
OptDB = PETSc.Options()
plot = OptDB.getBool('plot', False)
if plot:
    user.plotSolution(x)

# Release the PETSc objects explicitly.
x.destroy()
f.destroy()
tao.destroy()
82