xref: /petsc/src/binding/petsc4py/demo/legacy/ode/heat.py (revision bcee047adeeb73090d7e36cc71e39fc287cdbb97)
1# Solves Heat equation on a periodic domain, using raw VecScatter
2from __future__ import division
3import sys, petsc4py
4petsc4py.init(sys.argv)
5
6from petsc4py import PETSc
7from mpi4py import MPI
8import numpy
9
class Heat(object):
    """1-D heat equation u_t = u_xx on the periodic unit interval.

    The interval is discretized with ``N`` points of spacing ``h = 1/N``; the
    points are distributed over the ranks of ``comm``.  Neighbouring values
    are gathered with a raw VecScatter into a sequential work vector of
    length ``n + 2`` (the owned points plus one ghost on each side, with
    indices wrapped modulo ``N`` for periodicity).

    The semi-discrete system is written in implicit form, scaled by the
    volume element ``h``::

        F(t, u, udot)_i = h*udot_i + (2*u_i - u_{i-1} - u_{i+1}) / h = 0

    so that :meth:`evalFunction` / :meth:`evalJacobian` can be handed to
    ``TS.setIFunction`` / ``TS.setIJacobian``.
    """

    def __init__(self, comm, N, verbose=False):
        """Create the parallel layout, matrix, vectors and scatters.

        :param comm: MPI communicator (mpi4py-style: ``size``, ``rank``,
            ``exscan`` and ``barrier`` are used).
        :param N: global number of grid points.
        :param verbose: if true, print the parallel layout and the contents
            of each rank's ghosted local vector (formerly a hard-coded
            ``if False:`` block).
        """
        self.comm = comm
        self.N = N              # global problem size
        self.h = 1/N            # grid spacing on unit interval
        # Owned part of the global problem: the first N % size ranks own one extra point.
        self.n = N // comm.size + int(comm.rank < (N % comm.size))
        # Offset of the first owned point; the exclusive scan yields nothing useful on rank 0.
        self.start = comm.exscan(self.n)
        if comm.rank == 0:
            self.start = 0
        # Owned global indices plus one ghost on each side, wrapped for periodicity.
        gindices = numpy.arange(self.start-1, self.start+self.n+1, dtype=PETSc.IntType) % N
        self.mat = PETSc.Mat().create(comm=comm)
        size = (self.n, self.N) # local and global sizes
        self.mat.setSizes((size, size))
        self.mat.setFromOptions()
        # Conservative preallocation for 3 "local" columns and one non-local.
        self.mat.setPreallocationNNZ((3, 1))

        # Allow matrix insertion using local indices [0:n+2]
        lgmap = PETSc.LGMap().create(gindices, comm=comm)
        self.mat.setLGMap(lgmap, lgmap)

        # Global and local (ghosted, sequential) vectors
        self.gvec = self.mat.createVecRight()
        self.lvec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        self.lvec.setSizes(self.n+2)
        self.lvec.setUp()
        # Configure scatter from global to local (owned values plus both ghosts)
        isg = PETSc.IS().createGeneral(gindices, comm=comm)
        self.g2l = PETSc.Scatter().create(self.gvec, isg, self.lvec, None)

        # Scatter gathering the whole solution onto rank 0 (used by monitor)
        self.tozero, self.zvec = PETSc.Scatter.toZero(self.gvec)
        self.history = []       # list of (step, time, gathered solution as list)
        self._ts_type = None    # TS type name, recorded by monitor() for plot titles

        if verbose:             # Print some diagnostics
            print('[%d] local size %d, global size %d, starting offset %d' % (comm.rank, self.n, self.N, self.start))
            self.gvec.setArray(numpy.arange(self.start, self.start+self.n))
            self.gvec.view()
            self.g2l.scatter(self.gvec, self.lvec, PETSc.InsertMode.INSERT)
            # Print rank by rank so the output is not interleaved.
            for rank in range(comm.size):
                if rank == comm.rank:
                    print('Contents of local Vec on rank %d' % rank)
                    self.lvec.view()
                comm.barrier()

    def evalSolution(self, t, x):
        """Fill ``x`` with the initial condition (only ``t == 0`` is supported).

        The initial state is 1 where ``|coord - 0.5| < 0.1`` and 0 elsewhere,
        with ``coord = i / N`` for the owned global indices ``i``.
        """
        assert t == 0.0, "only for t=0.0"
        coord = numpy.arange(self.start, self.start+self.n) / self.N
        x.setArray((numpy.abs(coord-0.5) < 0.1) * 1.0)

    def evalFunction(self, ts, t, x, xdot, f):
        """Evaluate the implicit residual ``F(t, x, xdot)`` into ``f``."""
        # Gather owned values plus ghosts; lvec is a work vector.
        self.g2l.scatter(x, self.lvec, PETSc.InsertMode.INSERT)
        h = self.h
        with self.lvec as u, xdot as udot:
            # u[1:-1] are owned values, u[:-2] / u[2:] the left / right neighbours.
            f.setArray(udot*h + 2*u[1:-1]/h - u[:-2]/h - u[2:]/h) # Scale equation by volume element

    def evalJacobian(self, ts, t, x, xdot, a, A, B):
        """Assemble ``dF/dx + a * dF/dxdot`` into ``B`` (and finalize ``A``).

        ``a`` is the shift supplied by TS.  Every owned row gets the stencil
        ``[-1/h, a*h + 2/h, -1/h]``, inserted with local (ghosted) indices.
        Returns ``True`` (same nonzero pattern).
        """
        h = self.h
        for i in range(self.n):
            lidx = i + 1    # skip the left ghost at local index 0
            B.setValuesLocal([lidx], [lidx-1, lidx, lidx+1], [-1/h, a*h+2/h, -1/h])
        B.assemble()
        if A != B: A.assemble() # If operator is different from preconditioning matrix
        return True # same nonzero pattern

    def monitor(self, ts, i, t, x):
        """TS monitor: record a gathered snapshot, thinned in steps and time.

        A snapshot is stored only once both 4 steps and 1e-4 time units have
        passed since the previous one.  The solution is gathered with the
        to-zero scatter, so only rank 0 receives the full vector.
        """
        # Remember the integrator type so plotHistory() needs no global `ts`.
        self._ts_type = ts.getType()
        if self.history:
            lasti, lastt, _ = self.history[-1]
            if i < lasti + 4 or t < lastt + 1e-4:
                return
        self.tozero.scatter(x, self.zvec, PETSc.InsertMode.INSERT)
        self.history.append((i, t, self.zvec[:].tolist()))

    def plotHistory(self, ts=None):
        """Plot the recorded snapshots and save them to ``heat-history.png``.

        Intended for rank 0, the only rank holding gathered data.

        :param ts: optional TS whose type is shown in the title; by default
            the type recorded by :meth:`monitor` is used.

        Enables ``text.usetex``, so a LaTeX installation is needed.  If
        matplotlib is missing, a message is printed and nothing is plotted.
        """
        try:
            from matplotlib import pylab, rcParams
        except ImportError:
            print("matplotlib not available")
            return
        rcParams.update({'text.usetex': True, 'figure.figsize': (10, 6)})
        if ts is not None:
            ts_type = ts.getType()
        else:
            ts_type = getattr(self, '_ts_type', None)
        if ts_type:
            pylab.title('Heat: TS \\texttt{%s}' % ts_type)
        else:
            pylab.title('Heat')
        x = numpy.arange(self.N) / self.N
        for i, t, u in self.history:
            pylab.plot(x, u, label='step=%d t=%8.2g' % (i, t))
        pylab.xlabel('$x$')
        pylab.ylabel('$u$')
        pylab.legend(loc='upper right')
        pylab.savefig('heat-history.png')
93
# Driver: integrate the heat equation with a PETSc TS and report statistics.
# Script-specific run-time options (all standard PETSc options also apply):
#   -n <int>              global number of grid points (default 100)
#   -use_gamg             use PETSc's algebraic multigrid preconditioner
#   -plot_history <bool>  save heat-history.png on rank 0 (default true)
OptDB = PETSc.Options()
ode = Heat(MPI.COMM_WORLD, OptDB.getInt('n', 100))

x = ode.gvec.duplicate()        # solution vector

ts = PETSc.TS().create(comm=ode.comm)
ts.setType(ts.Type.ROSW)        # Rosenbrock-W. ARKIMEX is a nonlinearly implicit alternative.

ts.setIFunction(ode.evalFunction, ode.gvec)
ts.setIJacobian(ode.evalJacobian, ode.mat)

ts.setMonitor(ode.monitor)

ts.setTime(0.0)
ts.setTimeStep(ode.h**2)
ts.setMaxTime(1)
ts.setMaxSteps(100)
ts.setExactFinalTime(PETSc.TS.ExactFinalTime.INTERPOLATE)
ts.setMaxSNESFailures(-1)       # allow an unlimited number of failures (step will be rejected and retried)

snes = ts.getSNES()             # Nonlinear solver
snes.setTolerances(max_it=10)   # Stop nonlinear solve after 10 iterations (TS will retry with shorter step)
ksp = snes.getKSP()             # Linear solver
ksp.setType(ksp.Type.CG)        # Conjugate gradients (the shifted Jacobian is symmetric)
pc = ksp.getPC()                # Preconditioner
if OptDB.getBool('use_gamg', False):  # Configure algebraic multigrid programmatically (opt-in)
    pc.setType(pc.Type.GAMG)    # PETSc's native AMG implementation, mostly based on smoothed aggregation
    OptDB['mg_coarse_pc_type'] = 'svd' # more specific multigrid options
    OptDB['mg_levels_pc_type'] = 'sor'

ts.setFromOptions()             # Apply run-time options, e.g. -ts_adapt_monitor -ts_type arkimex -snes_converged_reason
ode.evalSolution(0.0, x)
ts.solve(x)
if ode.comm.rank == 0:
    print('steps %d (%d rejected, %d SNES fails), nonlinear its %d, linear its %d'
          % (ts.getStepNumber(), ts.getStepRejections(), ts.getSNESFailures(),
             ts.getSNESIterations(), ts.getKSPIterations()))

# Only rank 0 holds the gathered snapshots.
if OptDB.getBool('plot_history', True) and ode.comm.rank == 0:
    ode.plotHistory()
135