xref: /petsc/src/snes/tutorials/ex55k.kokkos.cxx (revision d5b43468fb8780a8feea140ccd6fa3e6a50411cc)
1 #include <Kokkos_Core.hpp>
2 #include <petscdmda_kokkos.hpp>
3 
4 #include <petscdm.h>
5 #include <petscdmda.h>
6 #include <petscsnes.h>
7 #include "ex55.h"
8 
9 using DefaultMemorySpace                 = Kokkos::DefaultExecutionSpace::memory_space;
10 using ConstPetscScalarKokkosOffsetView2D = Kokkos::Experimental::OffsetView<const PetscScalar **, Kokkos::LayoutRight, DefaultMemorySpace>;
11 using PetscScalarKokkosOffsetView2D      = Kokkos::Experimental::OffsetView<PetscScalar **, Kokkos::LayoutRight, DefaultMemorySpace>;
12 
13 using PetscCountKokkosView     = Kokkos::View<PetscCount *, DefaultMemorySpace>;
14 using PetscIntKokkosView       = Kokkos::View<PetscInt *, DefaultMemorySpace>;
15 using PetscCountKokkosViewHost = Kokkos::View<PetscCount *, Kokkos::HostSpace>;
16 using PetscScalarKokkosView    = Kokkos::View<PetscScalar *, DefaultMemorySpace>;
17 using Kokkos::Iterate;
18 using Kokkos::MDRangePolicy;
19 using Kokkos::Rank;
20 
21 KOKKOS_INLINE_FUNCTION PetscErrorCode MMSSolution1(AppCtx *user, const DMDACoor2d *c, PetscScalar *u)
22 {
23   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
24   u[0] = x * (1 - x) * y * (1 - y);
25   return 0;
26 }
27 
28 KOKKOS_INLINE_FUNCTION PetscErrorCode MMSForcing1(PetscReal user_param, const DMDACoor2d *c, PetscScalar *f)
29 {
30   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
31   f[0] = 2 * x * (1 - x) + 2 * y * (1 - y) - user_param * PetscExpReal(x * (1 - x) * y * (1 - y));
32   return 0;
33 }
34 
35 PetscErrorCode FormFunctionLocalVec(DMDALocalInfo *info, Vec x, Vec f, AppCtx *user)
36 {
37   PetscReal lambda, hx, hy, hxdhy, hydhx;
38   PetscInt  xs = info->xs, ys = info->ys, xm = info->xm, ym = info->ym, mx = info->mx, my = info->my;
39   PetscReal user_param = user->param;
40 
41   ConstPetscScalarKokkosOffsetView2D xv;
42   PetscScalarKokkosOffsetView2D      fv;
43 
44   PetscFunctionBeginUser;
45   lambda = user->param;
46   hx     = 1.0 / (PetscReal)(info->mx - 1);
47   hy     = 1.0 / (PetscReal)(info->my - 1);
48   hxdhy  = hx / hy;
49   hydhx  = hy / hx;
50   /*
51      Compute function over the locally owned part of the grid
52   */
53   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da, x, &xv));
54   PetscCallCXX(DMDAVecGetKokkosOffsetViewWrite(info->da, f, &fv));
55 
56   PetscCallCXX(Kokkos::parallel_for(
57     "FormFunctionLocalVec", MDRangePolicy<Rank<2, Iterate::Right, Iterate::Right>>({ys, xs}, {ys + ym, xs + xm}), KOKKOS_LAMBDA(PetscInt j, PetscInt i) {
58       DMDACoor2d  c;
59       PetscScalar u, ue, uw, un, us, uxx, uyy, mms_solution, mms_forcing;
60 
61       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
62         c.x = i * hx;
63         c.y = j * hy;
64         MMSSolution1(user, &c, &mms_solution);
65         fv(j, i) = 2.0 * (hydhx + hxdhy) * (xv(j, i) - mms_solution);
66       } else {
67         u  = xv(j, i);
68         uw = xv(j, i - 1);
69         ue = xv(j, i + 1);
70         un = xv(j - 1, i);
71         us = xv(j + 1, i);
72 
73         /* Enforce boundary conditions at neighboring points -- setting these values causes the Jacobian to be symmetric. */
74         if (i - 1 == 0) {
75           c.x = (i - 1) * hx;
76           c.y = j * hy;
77           MMSSolution1(user, &c, &uw);
78         }
79         if (i + 1 == mx - 1) {
80           c.x = (i + 1) * hx;
81           c.y = j * hy;
82           MMSSolution1(user, &c, &ue);
83         }
84         if (j - 1 == 0) {
85           c.x = i * hx;
86           c.y = (j - 1) * hy;
87           MMSSolution1(user, &c, &un);
88         }
89         if (j + 1 == my - 1) {
90           c.x = i * hx;
91           c.y = (j + 1) * hy;
92           MMSSolution1(user, &c, &us);
93         }
94 
95         uxx         = (2.0 * u - uw - ue) * hydhx;
96         uyy         = (2.0 * u - un - us) * hxdhy;
97         mms_forcing = 0;
98         c.x         = i * hx;
99         c.y         = j * hy;
100         MMSForcing1(user_param, &c, &mms_forcing);
101         fv(j, i) = uxx + uyy - hx * hy * (lambda * PetscExpScalar(u) + mms_forcing);
102       }
103     }));
104 
105   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da, x, &xv));
106   PetscCallCXX(DMDAVecRestoreKokkosOffsetViewWrite(info->da, f, &fv));
107 
108   PetscCall(PetscLogFlops(11.0 * info->ym * info->xm));
109   PetscFunctionReturn(0);
110 }
111 
112 PetscErrorCode FormObjectiveLocalVec(DMDALocalInfo *info, Vec x, PetscReal *obj, AppCtx *user)
113 {
114   PetscInt  xs = info->xs, ys = info->ys, xm = info->xm, ym = info->ym, mx = info->mx, my = info->my;
115   PetscReal lambda, hx, hy, hxdhy, hydhx, sc, lobj = 0;
116   MPI_Comm  comm;
117 
118   ConstPetscScalarKokkosOffsetView2D xv;
119 
120   PetscFunctionBeginUser;
121   *obj = 0;
122   PetscCall(PetscObjectGetComm((PetscObject)info->da, &comm));
123   lambda = user->param;
124   hx     = 1.0 / (PetscReal)(mx - 1);
125   hy     = 1.0 / (PetscReal)(my - 1);
126   sc     = hx * hy * lambda;
127   hxdhy  = hx / hy;
128   hydhx  = hy / hx;
129   /*
130      Compute function over the locally owned part of the grid
131   */
132   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da, x, &xv));
133 
134   PetscCallCXX(Kokkos::parallel_reduce(
135     "FormObjectiveLocalVec", MDRangePolicy<Rank<2, Iterate::Right, Iterate::Right>>({ys, xs}, {ys + ym, xs + xm}),
136     KOKKOS_LAMBDA(PetscInt j, PetscInt i, PetscReal & update) {
137       PetscScalar u, ue, uw, un, us, uxux, uyuy;
138       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
139         update += PetscRealPart((hydhx + hxdhy) * xv(j, i) * xv(j, i));
140       } else {
141         u  = xv(j, i);
142         uw = xv(j, i - 1);
143         ue = xv(j, i + 1);
144         un = xv(j - 1, i);
145         us = xv(j + 1, i);
146 
147         if (i - 1 == 0) uw = 0.;
148         if (i + 1 == mx - 1) ue = 0.;
149         if (j - 1 == 0) un = 0.;
150         if (j + 1 == my - 1) us = 0.;
151 
152         /* F[u] = 1/2\int_{\omega}\nabla^2u(x)*u(x)*dx */
153 
154         uxux = u * (2. * u - ue - uw) * hydhx;
155         uyuy = u * (2. * u - un - us) * hxdhy;
156 
157         update += PetscRealPart(0.5 * (uxux + uyuy) - sc * PetscExpScalar(u));
158       }
159     },
160     lobj));
161 
162   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da, x, &xv));
163   PetscCall(PetscLogFlops(12.0 * info->ym * info->xm));
164   PetscCallMPI(MPI_Allreduce(&lobj, obj, 1, MPIU_REAL, MPIU_SUM, comm));
165   PetscFunctionReturn(0);
166 }
167 
168 PetscErrorCode FormJacobianLocalVec(DMDALocalInfo *info, Vec x, Mat jac, Mat jacpre, AppCtx *user)
169 {
170   PetscInt     i, j;
171   PetscInt     xs = info->xs, ys = info->ys, xm = info->xm, ym = info->ym, mx = info->mx, my = info->my;
172   MatStencil   col[5], row;
173   PetscScalar  lambda, hx, hy, hxdhy, hydhx, sc;
174   DM           coordDA;
175   Vec          coordinates;
176   DMDACoor2d **coords;
177 
178   PetscFunctionBeginUser;
179   lambda = user->param;
180   /* Extract coordinates */
181   PetscCall(DMGetCoordinateDM(info->da, &coordDA));
182   PetscCall(DMGetCoordinates(info->da, &coordinates));
183 
184   PetscCall(DMDAVecGetArray(coordDA, coordinates, &coords));
185   hx = xm > 1 ? PetscRealPart(coords[ys][xs + 1].x) - PetscRealPart(coords[ys][xs].x) : 1.0;
186   hy = ym > 1 ? PetscRealPart(coords[ys + 1][xs].y) - PetscRealPart(coords[ys][xs].y) : 1.0;
187   PetscCall(DMDAVecRestoreArray(coordDA, coordinates, &coords));
188 
189   hxdhy = hx / hy;
190   hydhx = hy / hx;
191   sc    = hx * hy * lambda;
192 
193   /* ----------------------------------------- */
194   /*  MatSetPreallocationCOO()                 */
195   /* ----------------------------------------- */
196   PetscCount ncoo = ((PetscCount)xm) * ((PetscCount)ym) * 5;
197   PetscInt  *coo_i, *coo_j, *ip, *jp;
198   PetscCall(PetscMalloc2(ncoo, &coo_i, ncoo, &coo_j)); /* 5-point stencil such that each row has at most 5 nonzeros */
199 
200   ip = coo_i;
201   jp = coo_j;
202   for (j = ys; j < ys + ym; j++) {
203     for (i = xs; i < xs + xm; i++) {
204       row.j = j;
205       row.i = i;
206       /* Initialize neighbors with negative indices */
207       col[0].j = col[1].j = col[3].j = col[4].j = -1;
208       /* boundary points */
209       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
210         col[2].j = row.j;
211         col[2].i = row.i;
212       } else {
213         /* interior grid points */
214         if (j - 1 != 0) {
215           col[0].j = j - 1;
216           col[0].i = i;
217         }
218 
219         if (i - 1 != 0) {
220           col[1].j = j;
221           col[1].i = i - 1;
222         }
223 
224         col[2].j = row.j;
225         col[2].i = row.i;
226 
227         if (i + 1 != mx - 1) {
228           col[3].j = j;
229           col[3].i = i + 1;
230         }
231 
232         if (j + 1 != mx - 1) {
233           col[4].j = j + 1;
234           col[4].i = i;
235         }
236       }
237       PetscCall(DMDAMapMatStencilToGlobal(info->da, 5, col, jp));
238       for (PetscInt k = 0; k < 5; k++) ip[k] = jp[2];
239       ip += 5;
240       jp += 5;
241     }
242   }
243 
244   PetscCall(MatSetPreallocationCOO(jacpre, ncoo, coo_i, coo_j));
245   PetscCall(PetscFree2(coo_i, coo_j));
246 
247   /* ----------------------------------------- */
248   /*  MatSetValuesCOO()                        */
249   /* ----------------------------------------- */
250   PetscScalarKokkosView              coo_v("coo_v", ncoo);
251   ConstPetscScalarKokkosOffsetView2D xv;
252 
253   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da, x, &xv));
254 
255   PetscCallCXX(Kokkos::parallel_for(
256     "FormFunctionLocalVec", MDRangePolicy<Rank<2, Iterate::Right, Iterate::Right>>({ys, xs}, {ys + ym, xs + xm}), KOKKOS_LAMBDA(PetscCount j, PetscCount i) {
257       PetscInt p = ((j - ys) * xm + (i - xs)) * 5;
258       /* boundary points */
259       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
260         coo_v(p + 2) = 2.0 * (hydhx + hxdhy);
261       } else {
262         /* interior grid points */
263         if (j - 1 != 0) coo_v(p + 0) = -hxdhy;
264         if (i - 1 != 0) coo_v(p + 1) = -hydhx;
265 
266         coo_v(p + 2) = 2.0 * (hydhx + hxdhy) - sc * PetscExpScalar(xv(j, i));
267 
268         if (i + 1 != mx - 1) coo_v(p + 3) = -hydhx;
269         if (j + 1 != mx - 1) coo_v(p + 4) = -hxdhy;
270       }
271     }));
272   PetscCall(MatSetValuesCOO(jacpre, coo_v.data(), INSERT_VALUES));
273   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da, x, &xv));
274   PetscFunctionReturn(0);
275 }
276