xref: /petsc/src/binding/petsc4py/demo/legacy/bratu3d/bratu3d.py (revision 5a48edb989d3ea10d6aff6c0e26d581c18691deb)
1*55a74a43SLisandro Dalcin# ------------------------------------------------------------------------
2*55a74a43SLisandro Dalcin#
3*55a74a43SLisandro Dalcin#  Solid Fuel Ignition (SFI) problem.  This problem is modeled by the
4*55a74a43SLisandro Dalcin#  partial differential equation
5*55a74a43SLisandro Dalcin#
6*55a74a43SLisandro Dalcin#          -Laplacian(u) - lambda * exp(u) = 0,  0 < x,y,z < 1,
7*55a74a43SLisandro Dalcin#
8*55a74a43SLisandro Dalcin#  with boundary conditions
9*55a74a43SLisandro Dalcin#
10*55a74a43SLisandro Dalcin#           u = 0  for  x = 0, x = 1, y = 0, y = 1, z = 0, z = 1
11*55a74a43SLisandro Dalcin#
12*55a74a43SLisandro Dalcin#  A finite difference approximation with the usual 7-point stencil
13*55a74a43SLisandro Dalcin#  is used to discretize the boundary value problem to obtain a
14*55a74a43SLisandro Dalcin#  nonlinear system of equations. The problem is solved in a 3D
15*55a74a43SLisandro Dalcin#  rectangular domain, using distributed arrays (DAs) to partition
16*55a74a43SLisandro Dalcin#  the parallel grid.
17*55a74a43SLisandro Dalcin#
18*55a74a43SLisandro Dalcin# ------------------------------------------------------------------------
19*55a74a43SLisandro Dalcin
# Python 2 compatibility: prefer the lazy ``xrange`` as ``range`` when it
# exists. On Python 3 the name lookup raises NameError and the builtin
# ``range`` (already lazy) is kept. Only NameError is expected here, so any
# other exception is allowed to propagate instead of being swallowed.
try:
    range = xrange
except NameError:
    pass
22*55a74a43SLisandro Dalcin
23*55a74a43SLisandro Dalcinimport sys, petsc4py
24*55a74a43SLisandro Dalcinpetsc4py.init(sys.argv)
25*55a74a43SLisandro Dalcin
26*55a74a43SLisandro Dalcinfrom numpy import exp, sqrt
27*55a74a43SLisandro Dalcinfrom petsc4py import PETSc
28*55a74a43SLisandro Dalcin
class Bratu3D(object):
    """Initial guess, residual and Jacobian for the 3D Bratu (SFI) problem.

    Unknowns live on the points of a structured 3D DMDA of ``mx x my x mz``
    points. Points with index 0 or m-1 in any direction are treated as the
    boundary (x, y, z = 0 or 1), so the grid spacing in each direction is
    ``1/(m-1)``. All three callbacks use that same spacing.
    """

    def __init__(self, da, lambda_):
        # Only 3D DMDAs are supported: every loop below indexes x[i, j, k].
        assert da.getDim() == 3
        self.da = da
        self.lambda_ = lambda_
        # Ghosted work vector used to read neighbour values in the stencil.
        self.localX = da.createLocalVec()

    def _spacing(self):
        """Return the grid spacing (hx, hy, hz).

        With m points per direction and boundary points at 0 and 1 the
        spacing is 1/(m-1). ``max(m-1, 1)`` avoids a ZeroDivisionError for a
        degenerate one-point direction, where every point is a boundary
        point and the spacing never enters a stencil.
        """
        mx, my, mz = self.da.getSizes()
        return tuple(1.0/max(m - 1, 1) for m in (mx, my, mz))

    def _coefficients(self):
        """Return the stencil scaling factors (hxhyhz, hxhzdhy, hyhzdhx, hxhydhz).

        With these factors each interior residual row equals the
        finite-difference equation multiplied by the cell volume hx*hy*hz.
        """
        hx, hy, hz = self._spacing()
        return hx*hy*hz, hx*hz/hy, hy*hz/hx, hx*hy/hz

    @staticmethod
    def _on_boundary(i, j, k, mx, my, mz):
        """Return True if grid point (i, j, k) lies on a face of the grid."""
        return (i == 0 or j == 0 or k == 0 or
                i == mx-1 or j == my-1 or k == mz-1)

    def formInitGuess(self, snes, X):
        """Fill the global vector X with the initial guess.

        Boundary points get 0. Interior points get
        ``lambda/(lambda+1) * sqrt(d)``, where d is the distance from the
        point to the nearest boundary face.
        """
        x = self.da.getVecArray(X)
        mx, my, mz = self.da.getSizes()
        hx, hy, hz = self._spacing()
        lambda_ = self.lambda_
        scale = lambda_/(lambda_ + 1.0)
        # Only the locally owned part of the grid is visited.
        (xs, xe), (ys, ye), (zs, ze) = self.da.getRanges()
        for k in range(zs, ze):
            min_k = min(k, mz-k-1)*hz
            for j in range(ys, ye):
                min_j = min(j, my-j-1)*hy
                for i in range(xs, xe):
                    min_i = min(i, mx-i-1)*hx
                    if self._on_boundary(i, j, k, mx, my, mz):
                        x[i, j, k] = 0.0
                    else:
                        x[i, j, k] = scale*sqrt(min(min_i, min_j, min_k))

    def formFunction(self, snes, X, F):
        """Evaluate the nonlinear residual F(X).

        Boundary rows give F = u, which enforces u = 0. Interior rows give
        the 7-point discretisation of -Laplacian(u) - lambda*exp(u), scaled
        by hx*hy*hz.
        """
        # Scatter into the ghosted local vector so that neighbour values
        # owned by other processes are available to the stencil.
        self.da.globalToLocal(X, self.localX)
        x = self.da.getVecArray(self.localX)
        f = self.da.getVecArray(F)
        mx, my, mz = self.da.getSizes()
        # Spacing is 1/(m-1), consistent with formInitGuess. Using 1/m here
        # would silently rescale lambda by ((m-1)/m)**2.
        hxhyhz, hxhzdhy, hyhzdhx, hxhydhz = self._coefficients()
        lambda_ = self.lambda_
        (xs, xe), (ys, ye), (zs, ze) = self.da.getRanges()
        for k in range(zs, ze):
            for j in range(ys, ye):
                for i in range(xs, xe):
                    if self._on_boundary(i, j, k, mx, my, mz):
                        # Dirichlet condition u = 0.
                        f[i, j, k] = x[i, j, k]
                    else:
                        u = x[i, j, k]          # center
                        u_e = x[i+1, j, k]      # east
                        u_w = x[i-1, j, k]      # west
                        u_n = x[i, j+1, k]      # north
                        u_s = x[i, j-1, k]      # south
                        u_u = x[i, j, k+1]      # up
                        u_d = x[i, j, k-1]      # down
                        u_xx = (-u_e + 2*u - u_w)*hyhzdhx
                        u_yy = (-u_n + 2*u - u_s)*hxhzdhy
                        u_zz = (-u_u + 2*u - u_d)*hxhydhz
                        f[i, j, k] = (u_xx + u_yy + u_zz
                                      - lambda_*exp(u)*hxhyhz)

    def formJacobian(self, snes, X, J, P):
        """Assemble the Jacobian of formFunction into the preconditioner matrix P.

        Boundary rows are identity rows. Interior rows hold the 7-point
        stencil, using the same scaling as formFunction, plus the
        linearised ``-lambda*exp(u)`` term on the diagonal. If J is a
        different (e.g. matrix-free) operator it is only assembled.
        """
        self.da.globalToLocal(X, self.localX)
        x = self.da.getVecArray(self.localX)
        mx, my, mz = self.da.getSizes()
        # Same spacing as formFunction so the Jacobian matches the residual.
        hxhyhz, hxhzdhy, hyhzdhx, hxhydhz = self._coefficients()
        lambda_ = self.lambda_
        P.zeroEntries()
        row = PETSc.Mat.Stencil()
        col = PETSc.Mat.Stencil()
        (xs, xe), (ys, ye), (zs, ze) = self.da.getRanges()
        for k in range(zs, ze):
            for j in range(ys, ye):
                for i in range(xs, xe):
                    row.index = (i, j, k)
                    row.field = 0
                    if self._on_boundary(i, j, k, mx, my, mz):
                        P.setValueStencil(row, row, 1.0)
                    else:
                        u = x[i, j, k]
                        diag = (2*(hyhzdhx + hxhzdhy + hxhydhz)
                                - lambda_*exp(u)*hxhyhz)
                        for index, value in [
                            ((i, j, k-1), -hxhydhz),
                            ((i, j-1, k), -hxhzdhy),
                            ((i-1, j, k), -hyhzdhx),
                            ((i, j, k), diag),
                            ((i+1, j, k), -hyhzdhx),
                            ((i, j+1, k), -hxhzdhy),
                            ((i, j, k+1), -hxhydhz),
                        ]:
                            col.index = index
                            col.field = 0
                            P.setValueStencil(row, col, value)
        P.assemble()
        if J != P:
            J.assemble()  # matrix-free operator
        return PETSc.Mat.Structure.SAME_NONZERO_PATTERN
142*55a74a43SLisandro Dalcin
OptDB = PETSc.Options()

# Grid size: -n sets every dimension, -nx/-ny/-nz override individually.
n = OptDB.getInt('n', 16)
nx, ny, nz = [OptDB.getInt(opt_name, n) for opt_name in ('nx', 'ny', 'nz')]
lambda_ = OptDB.getReal('lambda', 6.0)

# A stencil width of 1 provides the ghost layer the 7-point stencil reads.
da = PETSc.DMDA().create([nx, ny, nz], stencil_width=1)
pde = Bratu3D(da, lambda_)

snes = PETSc.SNES().create()
F = da.createGlobalVec()
snes.setFunction(pde.formFunction, F)

# Jacobian strategy: -mf selects a matrix-free Jacobian; otherwise an
# assembled DMDA matrix is used, optionally filled by finite differences (-fd).
fd = OptDB.getBool('fd', False)
mf = OptDB.getBool('mf', False)
if not mf:
    J = da.createMat()
    snes.setJacobian(pde.formJacobian, J)
    if fd:
        snes.setUseFD()
else:
    J = None
    snes.setUseMF()

# CG is the default linear solver; setFromOptions runs afterwards so
# command-line options can still override it.
snes.getKSP().setType('cg')
snes.setFromOptions()

X = da.createGlobalVec()
pde.formInitGuess(snes, X)
snes.solve(None, X)

# Copy the solution into natural ordering, independent of the parallel layout.
U = da.createNaturalVec()
da.globalToNatural(X, U)
178*55a74a43SLisandro Dalcin
if OptDB.getBool('plot_mpl', False):

    def plot_mpl(da, U):
        """Contour-plot the z-midplane of the solution on rank 0.

        ``U`` is expected to be the natural-ordered solution vector (as
        produced by ``da.globalToNatural``). Every rank must call this
        function: it performs a scatter to rank 0 and ends with a barrier.
        """
        comm = da.getComm()
        rank = comm.getRank()
        # Gather the whole vector onto rank 0.
        scatter, U0 = PETSc.Scatter.toZero(U)
        scatter.scatter(U, U0, False, PETSc.Scatter.Mode.FORWARD)
        if rank == 0:
            try:
                # pyplot is matplotlib's supported interface; the pylab
                # namespace is discouraged.
                from matplotlib import pyplot
            except ImportError:
                PETSc.Sys.Print("matplotlib not available")
            else:
                from numpy import mgrid
                nx, ny, nz = da.sizes
                # order='f' makes solution[i, j, k] follow the DMDA indices.
                solution = U0[...].reshape(da.sizes, order='f')
                xx, yy = mgrid[0:1:1j*nx, 0:1:1j*ny]
                pyplot.contourf(xx, yy, solution[:, :, nz//2])
                pyplot.axis('equal')
                pyplot.xlabel('X')
                pyplot.ylabel('Y')
                pyplot.title('Z/2')
                pyplot.show()
        # Release the temporary gather objects explicitly instead of
        # relying on garbage collection.
        scatter.destroy()
        U0.destroy()
        comm.barrier()

    plot_mpl(da, U)
205