xref: /petsc/src/binding/petsc4py/demo/legacy/ode/heat.py (revision 226f8a8a5081bc6ad7227cd631662400f0d6e2a0)
1# Solves Heat equation on a periodic domain, using raw VecScatter
2import sys
3import petsc4py
4
5petsc4py.init(sys.argv)
6
7from petsc4py import PETSc
8from mpi4py import MPI
9import numpy
10
11
class Heat:
    """1D heat equation u_t = u_xx on the periodic unit interval.

    The interval is discretised with ``N`` points of spacing ``h = 1/N`` and a
    3-point Laplacian stencil; rows are block-distributed over ``comm``.
    ``evalFunction``/``evalJacobian`` are the IFunction/IJacobian callbacks for
    ``PETSc.TS``, and ``monitor`` is a TS monitor that records a thinned
    solution history which ``plotHistory`` can render on rank 0.
    """

    def __init__(self, comm, N, verbose=False):
        """Set up the parallel layout, Jacobian matrix, work vectors and scatters.

        Parameters
        ----------
        comm : mpi4py communicator over which the problem is distributed.
        N : int, global number of grid points.
        verbose : bool, optional
            If true, print per-rank layout diagnostics and the result of one
            global-to-local scatter (collective on ``comm``).
        """
        self.comm = comm
        self.N = N  # global problem size
        self.h = 1 / N  # grid spacing on unit interval
        # Owned part of the global problem: the first N % size ranks own one extra row.
        self.n = N // comm.size + int(comm.rank < (N % comm.size))
        # exscan yields nothing meaningful on rank 0, whose offset is 0 by definition.
        self.start = comm.exscan(self.n)
        if comm.rank == 0:
            self.start = 0
        # Global indices of the owned rows plus one ghost on each side, wrapped
        # modulo N for periodicity. Local index i corresponds to gindices[i].
        gindices = (
            numpy.arange(self.start - 1, self.start + self.n + 1, dtype=PETSc.IntType)
            % N
        )
        self.mat = PETSc.Mat().create(comm=comm)
        size = (self.n, self.N)  # local and global sizes
        self.mat.setSizes((size, size))
        self.mat.setFromOptions()
        # Conservative preallocation: up to 3 nonzeros per row in the diagonal
        # block and up to 2 in the off-diagonal block (a rank owning a single
        # row has both of its neighbours on other ranks).
        self.mat.setPreallocationNNZ((3, 2))

        # Allow matrix insertion using local indices [0:n+2]
        lgmap = PETSc.LGMap().create(gindices, comm=comm)
        self.mat.setLGMap(lgmap, lgmap)

        # Global and local vectors
        self.gvec = self.mat.createVecRight()
        self.lvec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        self.lvec.setSizes(self.n + 2)
        self.lvec.setUp()
        # Scatter from the distributed vector into the ghosted sequential work
        # vector: lvec[i] <- gvec[gindices[i]].
        isg = PETSc.IS().createGeneral(gindices, comm=comm)
        self.g2l = PETSc.Scatter().create(self.gvec, isg, self.lvec, None)

        # Gathers the full solution onto rank 0 for monitoring/plotting.
        self.tozero, self.zvec = PETSc.Scatter.toZero(self.gvec)
        self.history = []  # (step, time, gathered solution as a list)
        self.tstype = None  # TS type seen by monitor(); used for the plot title

        if verbose:
            self._printDiagnostics()

    def _printDiagnostics(self):
        """Print each rank's layout and the scattered local vector (collective).

        Overwrites ``gvec`` with its own global indices to make the scatter visible.
        """
        comm = self.comm
        print(
            '[%d] local size %d, global size %d, starting offset %d'
            % (comm.rank, self.n, self.N, self.start)
        )
        self.gvec.setArray(numpy.arange(self.start, self.start + self.n))
        self.gvec.view()
        self.g2l.scatter(self.gvec, self.lvec, PETSc.InsertMode.INSERT)
        # Ranks take turns, separated by barriers, to limit interleaved output.
        for rank in range(comm.size):
            if rank == comm.rank:
                print('Contents of local Vec on rank %d' % rank)
                self.lvec.view()
            comm.barrier()

    def evalSolution(self, t, x):
        """Fill ``x`` with the initial condition: 1 where |x - 0.5| < 0.1, else 0.

        Raises ValueError for any ``t`` other than 0.
        """
        if t != 0.0:
            raise ValueError('Only for t=0')
        coord = numpy.arange(self.start, self.start + self.n) / self.N
        x.setArray((numpy.abs(coord - 0.5) < 0.1) * 1.0)

    def evalFunction(self, ts, t, x, xdot, f):
        """TS IFunction: f_i = h*udot_i + (2*u_i - u_{i-1} - u_{i+1}) / h.

        This is the semi-discrete heat equation scaled by the volume element h.
        """
        # Fetch owned values plus periodic ghosts; lvec is a work vector and
        # lvec[1:-1] are the owned entries.
        self.g2l.scatter(x, self.lvec, PETSc.InsertMode.INSERT)
        h = self.h
        with self.lvec as u, xdot as udot:
            f.setArray(udot * h + 2 * u[1:-1] / h - u[:-2] / h - u[2:] / h)

    def evalJacobian(self, ts, t, x, xdot, a, A, B):
        """TS IJacobian: B = a*dF/d(udot) + dF/du = tridiag(-1/h, a*h + 2/h, -1/h).

        ``a`` is the shift supplied by TS. The residual is linear, so ``x`` and
        ``xdot`` are not needed.
        """
        h = self.h
        diag = a * h + 2 / h
        off = -1 / h
        for i in range(self.n):
            lidx = i + 1  # local index of owned row i (local index 0 is the left ghost)
            B.setValuesLocal([lidx], [lidx - 1, lidx, lidx + 1], [off, diag, off])
        B.assemble()
        if A != B:
            A.assemble()  # If operator is different from matrix used to construct the preconditioner

    def monitor(self, ts, i, t, x):
        """TS monitor: append (step, time, gathered solution) to ``self.history``.

        A sample is kept only if it is at least 4 steps and 1e-4 time units
        after the previously kept one. Every rank appends an entry, but only
        rank 0's entries hold the gathered solution values. Collective: all
        ranks take part in the gather.
        """
        self.tstype = ts.getType()
        if self.history:
            lasti, lastt, _ = self.history[-1]
            if i < lasti + 4 or t < lastt + 1e-4:
                return
        self.tozero.scatter(x, self.zvec, PETSc.InsertMode.INSERT)
        self.history.append((i, t, self.zvec[:].tolist()))

    def plotHistory(self, filename='heat-history.png'):
        """Plot the recorded solution history and save it to ``filename``.

        Only meaningful on rank 0, where the gathered solution lives. Does
        nothing if matplotlib is unavailable. Uses LaTeX text rendering
        (``text.usetex``).
        """
        try:
            from matplotlib import pyplot, rcParams
        except ImportError:
            return
        rcParams.update({'text.usetex': True, 'figure.figsize': (10, 6)})
        tstype = self.tstype if self.tstype is not None else 'unknown'
        pyplot.title('Heat: TS \\texttt{%s}' % tstype)
        x = numpy.arange(self.N) / self.N
        for i, t, u in self.history:
            pyplot.plot(x, u, label='step=%d t=%8.2g' % (i, t))
        pyplot.xlabel('$x$')
        pyplot.ylabel('$u$')
        pyplot.legend(loc='upper right')
        pyplot.savefig(filename)
115
116
if __name__ == '__main__':
    # Run-time options: -n <int> (grid size), -use_gamg, -plot_history <bool>,
    # plus any standard PETSc TS/SNES/KSP/PC options.
    OptDB = PETSc.Options()
    ode = Heat(MPI.COMM_WORLD, OptDB.getInt('n', 100))

    x = ode.gvec.duplicate()  # solution vector, laid out like the residual

    ts = PETSc.TS().create(comm=ode.comm)
    ts.setType(ts.Type.ROSW)  # Rosenbrock-W. ARKIMEX is a nonlinearly implicit alternative.

    ts.setIFunction(ode.evalFunction, ode.gvec)
    ts.setIJacobian(ode.evalJacobian, ode.mat)

    ts.setMonitor(ode.monitor)

    ts.setTime(0.0)
    ts.setTimeStep(ode.h**2)  # initial step size
    ts.setMaxTime(1)
    ts.setMaxSteps(100)
    ts.setExactFinalTime(PETSc.TS.ExactFinalTime.INTERPOLATE)
    # Allow an unlimited number of SNES failures (the step is rejected and retried).
    ts.setMaxSNESFailures(-1)

    snes = ts.getSNES()  # Nonlinear solver
    # Stop the nonlinear solve after 10 iterations (TS will retry with a shorter step).
    snes.setTolerances(max_it=10)
    ksp = snes.getKSP()  # Linear solver
    ksp.setType(ksp.Type.CG)  # Conjugate gradients
    if OptDB.getBool('use_gamg', False):
        # Algebraic multigrid: PETSc's native AMG, mostly based on smoothed
        # aggregation. The same setup is available purely at run time via
        # -pc_type gamg -mg_coarse_pc_type svd -mg_levels_pc_type sor.
        pc = ksp.getPC()  # Preconditioner
        pc.setType(pc.Type.GAMG)
        OptDB['mg_coarse_pc_type'] = 'svd'
        OptDB['mg_levels_pc_type'] = 'sor'

    # Apply run-time options, e.g. -ts_adapt_monitor -ts_type arkimex -snes_converged_reason
    ts.setFromOptions()
    ode.evalSolution(0.0, x)
    ts.solve(x)
    if ode.comm.rank == 0:
        print(
            'steps %d (%d rejected, %d SNES fails), nonlinear its %d, linear its %d'
            % (
                ts.getStepNumber(),
                ts.getStepRejections(),
                ts.getSNESFailures(),
                ts.getSNESIterations(),
                ts.getKSPIterations(),
            )
        )

    # Only rank 0 holds the gathered history.
    if OptDB.getBool('plot_history', True) and ode.comm.rank == 0:
        ode.plotHistory()
171