xref: /petsc/src/binding/petsc4py/demo/legacy/ode/heat.py (revision 226f8a8a5081bc6ad7227cd631662400f0d6e2a0)
# Solves the heat equation on a periodic domain, using a raw VecScatter
269777137SStefano Zampiniimport sys
369777137SStefano Zampiniimport petsc4py
469777137SStefano Zampini
555a74a43SLisandro Dalcinpetsc4py.init(sys.argv)
655a74a43SLisandro Dalcin
755a74a43SLisandro Dalcinfrom petsc4py import PETSc
855a74a43SLisandro Dalcinfrom mpi4py import MPI
955a74a43SLisandro Dalcinimport numpy
1055a74a43SLisandro Dalcin
1169777137SStefano Zampini
class Heat:
    """Semi-discrete heat equation on a periodic unit interval.

    The N grid values are split into contiguous slices, one per rank of
    ``comm``. Each rank also keeps a local work vector holding its slice plus
    one ghost value on each side (wrapped periodically), filled from the
    global vector through a raw PETSc scatter.

    ``evalFunction``, ``evalJacobian`` and ``monitor`` have the signatures
    expected by ``TS.setIFunction``, ``TS.setIJacobian`` and ``TS.setMonitor``.
    """

    def __init__(self, comm, N, diagnostics=False):
        """Set up the matrix, vectors and scatters for a problem of size N.

        Parameters:
            comm: communicator providing ``rank``, ``size``, ``exscan`` and
                ``barrier`` (an mpi4py communicator in this demo).
            N: global number of grid values; must be positive.
            diagnostics: when true, print the partitioning and the contents
                of the ghosted local vector on every rank. This overwrites
                the contents of ``self.gvec`` and ``self.lvec``.

        Raises:
            ValueError: if ``N`` is not positive.
        """
        # Reject bad sizes up front; otherwise 1 / N or the index arithmetic
        # below fails with a far less helpful error.
        if N <= 0:
            raise ValueError('N must be a positive integer, got %r' % (N,))
        self.comm = comm
        self.N = N  # global problem size
        self.h = 1 / N  # grid spacing on unit interval
        # Owned part of the global problem: the first N % size ranks take one
        # extra value each.
        self.n = N // comm.size + int(comm.rank < (N % comm.size))
        # Exclusive prefix sum of the owned sizes is this rank's starting
        # offset; rank 0 has no predecessors, so its offset is set explicitly.
        self.start = comm.exscan(self.n)
        if comm.rank == 0:
            self.start = 0
        # Global indices of the owned values plus one ghost on each side,
        # wrapped modulo N for periodicity.
        gindices = (
            numpy.arange(self.start - 1, self.start + self.n + 1, dtype=PETSc.IntType)
            % N
        )
        self.mat = PETSc.Mat().create(comm=comm)
        size = (self.n, self.N)  # local and global sizes
        self.mat.setSizes((size, size))
        self.mat.setFromOptions()
        # Conservative preallocation for 3 "local" columns and one non-local
        self.mat.setPreallocationNNZ((3, 1))

        # Allow matrix insertion using local indices [0:n+2]
        lgmap = PETSc.LGMap().create(list(gindices), comm=comm)
        self.mat.setLGMap(lgmap, lgmap)

        # Global and local vectors
        self.gvec = self.mat.createVecRight()
        self.lvec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        self.lvec.setSizes(self.n + 2)
        self.lvec.setUp()
        # Configure scatter from global to local
        isg = PETSc.IS().createGeneral(list(gindices), comm=comm)
        self.g2l = PETSc.Scatter().create(self.gvec, isg, self.lvec, None)

        # Scatter used by monitor() to collect the solution in self.zvec
        self.tozero, self.zvec = PETSc.Scatter.toZero(self.gvec)
        self.history = []  # (step, time, values) tuples recorded by monitor()
        self.ts_type = None  # TS type seen by monitor(); used for the plot title

        if diagnostics:
            self._printDiagnostics()

    def _printDiagnostics(self):
        """Print the partitioning and the ghosted local vector of each rank."""
        comm = self.comm
        print(
            '[%d] local size %d, global size %d, starting offset %d'
            % (comm.rank, self.n, self.N, self.start)
        )
        # Fill the global vector with its own global indices so the ghost
        # values in the local vector are easy to check by eye.
        self.gvec.setArray(numpy.arange(self.start, self.start + self.n))
        self.gvec.view()
        self.g2l.scatter(self.gvec, self.lvec, PETSc.InsertMode.INSERT)
        # Take turns, one rank per barrier, so the output is not interleaved.
        for rank in range(comm.size):
            if rank == comm.rank:
                print('Contents of local Vec on rank %d' % rank)
                self.lvec.view()
            comm.barrier()

    def evalSolution(self, t, x):
        """Write the initial condition into ``x``.

        The profile is 1 where ``|coord - 0.5| < 0.1`` and 0 elsewhere, with
        ``coord = index / N`` for the owned indices.

        Raises:
            ValueError: if ``t`` is not 0; only the initial state is known.
        """
        if t != 0.0:
            raise ValueError('Only for t=0')
        coord = numpy.arange(self.start, self.start + self.n) / self.N
        x.setArray((numpy.abs(coord - 0.5) < 0.1) * 1.0)

    def evalFunction(self, ts, t, x, xdot, f):
        """Evaluate the implicit residual into ``f``.

        For each owned value i the residual is
        ``h * udot_i + (2 * u_i - u_{i-1} - u_{i+1}) / h``,
        i.e. the equation scaled by the volume element ``h``.
        """
        # lvec is a work vector: owned values of x plus the two ghost values
        self.g2l.scatter(x, self.lvec, PETSc.InsertMode.INSERT)
        h = self.h
        with self.lvec as u, xdot as udot:
            # u[1:-1] are the owned values, u[:-2] / u[2:] their left / right
            # neighbours.
            f.setArray(udot * h + 2 * u[1:-1] / h - u[:-2] / h - u[2:] / h)

    def evalJacobian(self, ts, t, x, xdot, a, A, B):
        """Assemble the shifted Jacobian of ``evalFunction`` into ``B``.

        Each owned row gets the stencil ``[-1/h, a*h + 2/h, -1/h]``, where
        ``a`` is the shift applied to the ``xdot`` dependence. Values are
        inserted into ``B`` only; ``A`` is merely assembled when it is a
        different matrix.
        """
        h = self.h
        # The stencil is the same for every row, so build it once.
        stencil = [-1 / h, a * h + 2 / h, -1 / h]
        for i in range(self.n):
            lidx = i + 1  # local index; slot 0 is the left ghost value
            B.setValuesLocal([lidx], [lidx - 1, lidx, lidx + 1], stencil)
        B.assemble()
        if A != B:
            # If operator is different from matrix used to construct the
            # preconditioner
            A.assemble()

    def monitor(self, ts, i, t, x):
        """Record ``(step, time, values)`` snapshots in ``self.history``.

        After the first snapshot, a new one is kept only when at least 4
        steps and at least 1e-4 time units have passed since the last one.
        """
        if self.history:
            lasti, lastt, _ = self.history[-1]
            if i < lasti + 4 or t < lastt + 1e-4:
                return
        self.tozero.scatter(x, self.zvec, PETSc.InsertMode.INSERT)
        xx = self.zvec[:].tolist()
        self.history.append((i, t, xx))
        # Remember the integrator type so plotHistory() does not have to reach
        # for a module-level ``ts`` variable.
        if ts is not None:
            self.ts_type = ts.getType()

    def plotHistory(self):
        """Plot every recorded snapshot and save it as ``heat-history.png``.

        Does nothing when matplotlib is not installed. The snapshots are
        plotted against ``numpy.arange(N) / N``, so each one must hold all N
        values; the demo calls this on rank 0 only (presumably the only rank
        whose ``zvec`` holds the full solution).
        """
        try:
            from matplotlib import pylab, rcParams
        except ImportError:
            # Plotting is optional; skip it quietly.
            return
        rcParams.update({'text.usetex': True, 'figure.figsize': (10, 6)})
        # The type is recorded by monitor(); fall back to a placeholder when
        # no snapshot was ever taken through a TS.
        ts_type = getattr(self, 'ts_type', None)
        if ts_type is None:
            ts_type = 'unknown'
        pylab.title('Heat: TS \\texttt{%s}' % ts_type)
        x = numpy.arange(self.N) / self.N
        for i, t, u in self.history:
            pylab.plot(x, u, label='step=%d t=%8.2g' % (i, t))
        pylab.xlabel('$x$')
        pylab.ylabel('$u$')
        pylab.legend(loc='upper right')
        pylab.savefig('heat-history.png')
11555a74a43SLisandro Dalcin
11669777137SStefano Zampini
# --- Problem setup --------------------------------------------------------
OptDB = PETSc.Options()
problem_size = OptDB.getInt('n', 100)  # global number of grid values (-n)
ode = Heat(MPI.COMM_WORLD, problem_size)

# Solution vector and a spare vector, both laid out like ode.gvec
x = ode.gvec.duplicate()
f = ode.gvec.duplicate()

# --- Time stepper ---------------------------------------------------------
ts = PETSc.TS().create(comm=ode.comm)
# Rosenbrock-W. ARKIMEX is a nonlinearly implicit alternative.
ts.setType(PETSc.TS.Type.ROSW)

# Residual, Jacobian and monitor callbacks all come from the Heat object
ts.setIFunction(ode.evalFunction, ode.gvec)
ts.setIJacobian(ode.evalJacobian, ode.mat)
ts.setMonitor(ode.monitor)

# Integrate from t=0 with an initial step of h^2, stopping at t=1 or after
# 100 steps, whichever comes first
ts.setTime(0.0)
ts.setTimeStep(ode.h**2)
ts.setMaxTime(1)
ts.setMaxSteps(100)
ts.setExactFinalTime(PETSc.TS.ExactFinalTime.INTERPOLATE)
# allow an unlimited number of failures (step will be rejected and retried)
ts.setMaxSNESFailures(-1)

# --- Nested solvers -------------------------------------------------------
snes = ts.getSNES()  # Nonlinear solver
# Stop nonlinear solve after 10 iterations (TS will retry with shorter step)
snes.setTolerances(max_it=10)
ksp = snes.getKSP()  # Linear solver
ksp.setType(PETSc.KSP.Type.CG)  # Conjugate gradients
pc = ksp.getPC()  # Preconditioner
use_gamg = False  # Configure algebraic multigrid, could use run-time options instead
if use_gamg:
    # PETSc's native AMG implementation, mostly based on smoothed aggregation
    pc.setType(PETSc.PC.Type.GAMG)
    # more specific multigrid options
    OptDB['mg_coarse_pc_type'] = 'svd'
    OptDB['mg_levels_pc_type'] = 'sor'

# --- Solve ----------------------------------------------------------------
# Apply run-time options, e.g. -ts_adapt_monitor -ts_type arkimex -snes_converged_reason
ts.setFromOptions()
ode.evalSolution(0.0, x)
ts.solve(x)

# --- Report ---------------------------------------------------------------
if ode.comm.rank == 0:
    counters = (
        ts.getStepNumber(),
        ts.getStepRejections(),
        ts.getSNESFailures(),
        ts.getSNESIterations(),
        ts.getKSPIterations(),
    )
    print(
        'steps %d (%d rejected, %d SNES fails), nonlinear its %d, linear its %d'
        % counters
    )

# The option is queried on every rank; only rank 0 does the plotting
plot_requested = OptDB.getBool('plot_history', True)
if plot_requested:
    if ode.comm.rank == 0:
        ode.plotHistory()
171