xref: /petsc/src/snes/tutorials/ex55k.kokkos.cxx (revision 1511cd715a1f0c8d257549c5ebe5cee9c6feed4d)
1 #include <Kokkos_Core.hpp>
2 #include <petscdmda_kokkos.hpp>
3 
4 #include <petscdm.h>
5 #include <petscdmda.h>
6 #include <petscsnes.h>
7 #include "ex55.h"
8 
9 using DefaultMemorySpace                 = Kokkos::DefaultExecutionSpace::memory_space;
10 using ConstPetscScalarKokkosOffsetView2D = Kokkos::Experimental::OffsetView<const PetscScalar **, Kokkos::LayoutRight, DefaultMemorySpace>;
11 using PetscScalarKokkosOffsetView2D      = Kokkos::Experimental::OffsetView<PetscScalar **, Kokkos::LayoutRight, DefaultMemorySpace>;
12 
13 using PetscCountKokkosView     = Kokkos::View<PetscCount *, DefaultMemorySpace>;
14 using PetscIntKokkosView       = Kokkos::View<PetscInt *, DefaultMemorySpace>;
15 using PetscCountKokkosViewHost = Kokkos::View<PetscCount *, Kokkos::HostSpace>;
16 using PetscScalarKokkosView    = Kokkos::View<PetscScalar *, DefaultMemorySpace>;
17 using Kokkos::Iterate;
18 using Kokkos::MDRangePolicy;
19 using Kokkos::Rank;
20 
21 KOKKOS_INLINE_FUNCTION PetscErrorCode MMSSolution1(AppCtx *user, const DMDACoor2d *c, PetscScalar *u) {
22   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
23   u[0] = x * (1 - x) * y * (1 - y);
24   return 0;
25 }
26 
27 KOKKOS_INLINE_FUNCTION PetscErrorCode MMSForcing1(PetscReal user_param, const DMDACoor2d *c, PetscScalar *f) {
28   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
29   f[0] = 2 * x * (1 - x) + 2 * y * (1 - y) - user_param * PetscExpReal(x * (1 - x) * y * (1 - y));
30   return 0;
31 }
32 
33 PetscErrorCode FormFunctionLocalVec(DMDALocalInfo *info, Vec x, Vec f, AppCtx *user) {
34   PetscReal lambda, hx, hy, hxdhy, hydhx;
35   PetscInt  xs = info->xs, ys = info->ys, xm = info->xm, ym = info->ym, mx = info->mx, my = info->my;
36   PetscReal user_param = user->param;
37 
38   ConstPetscScalarKokkosOffsetView2D xv;
39   PetscScalarKokkosOffsetView2D      fv;
40 
41   PetscFunctionBeginUser;
42   lambda = user->param;
43   hx     = 1.0 / (PetscReal)(info->mx - 1);
44   hy     = 1.0 / (PetscReal)(info->my - 1);
45   hxdhy  = hx / hy;
46   hydhx  = hy / hx;
47   /*
48      Compute function over the locally owned part of the grid
49   */
50   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da, x, &xv));
51   PetscCallCXX(DMDAVecGetKokkosOffsetViewWrite(info->da, f, &fv));
52 
53   PetscCallCXX(Kokkos::parallel_for(
54     "FormFunctionLocalVec", MDRangePolicy<Rank<2, Iterate::Right, Iterate::Right>>({ys, xs}, {ys + ym, xs + xm}), KOKKOS_LAMBDA(PetscInt j, PetscInt i) {
55       DMDACoor2d  c;
56       PetscScalar u, ue, uw, un, us, uxx, uyy, mms_solution, mms_forcing;
57 
58       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
59         c.x = i * hx;
60         c.y = j * hy;
61         MMSSolution1(user, &c, &mms_solution);
62         fv(j, i) = 2.0 * (hydhx + hxdhy) * (xv(j, i) - mms_solution);
63       } else {
64         u  = xv(j, i);
65         uw = xv(j, i - 1);
66         ue = xv(j, i + 1);
67         un = xv(j - 1, i);
68         us = xv(j + 1, i);
69 
70         /* Enforce boundary conditions at neighboring points -- setting these values causes the Jacobian to be symmetric. */
71         if (i - 1 == 0) {
72           c.x = (i - 1) * hx;
73           c.y = j * hy;
74           MMSSolution1(user, &c, &uw);
75         }
76         if (i + 1 == mx - 1) {
77           c.x = (i + 1) * hx;
78           c.y = j * hy;
79           MMSSolution1(user, &c, &ue);
80         }
81         if (j - 1 == 0) {
82           c.x = i * hx;
83           c.y = (j - 1) * hy;
84           MMSSolution1(user, &c, &un);
85         }
86         if (j + 1 == my - 1) {
87           c.x = i * hx;
88           c.y = (j + 1) * hy;
89           MMSSolution1(user, &c, &us);
90         }
91 
92         uxx         = (2.0 * u - uw - ue) * hydhx;
93         uyy         = (2.0 * u - un - us) * hxdhy;
94         mms_forcing = 0;
95         c.x         = i * hx;
96         c.y         = j * hy;
97         MMSForcing1(user_param, &c, &mms_forcing);
98         fv(j, i) = uxx + uyy - hx * hy * (lambda * PetscExpScalar(u) + mms_forcing);
99       }
100     }));
101 
102   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da, x, &xv));
103   PetscCallCXX(DMDAVecRestoreKokkosOffsetViewWrite(info->da, f, &fv));
104 
105   PetscCall(PetscLogFlops(11.0 * info->ym * info->xm));
106   PetscFunctionReturn(0);
107 }
108 
109 PetscErrorCode FormObjectiveLocalVec(DMDALocalInfo *info, Vec x, PetscReal *obj, AppCtx *user) {
110   PetscInt  xs = info->xs, ys = info->ys, xm = info->xm, ym = info->ym, mx = info->mx, my = info->my;
111   PetscReal lambda, hx, hy, hxdhy, hydhx, sc, lobj = 0;
112   MPI_Comm  comm;
113 
114   ConstPetscScalarKokkosOffsetView2D xv;
115 
116   PetscFunctionBeginUser;
117   *obj = 0;
118   PetscCall(PetscObjectGetComm((PetscObject)info->da, &comm));
119   lambda = user->param;
120   hx     = 1.0 / (PetscReal)(mx - 1);
121   hy     = 1.0 / (PetscReal)(my - 1);
122   sc     = hx * hy * lambda;
123   hxdhy  = hx / hy;
124   hydhx  = hy / hx;
125   /*
126      Compute function over the locally owned part of the grid
127   */
128   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da, x, &xv));
129 
130   PetscCallCXX(Kokkos::parallel_reduce(
131     "FormObjectiveLocalVec", MDRangePolicy<Rank<2, Iterate::Right, Iterate::Right>>({ys, xs}, {ys + ym, xs + xm}),
132     KOKKOS_LAMBDA(PetscInt j, PetscInt i, PetscReal & update) {
133       PetscScalar u, ue, uw, un, us, uxux, uyuy;
134       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
135         update += PetscRealPart((hydhx + hxdhy) * xv(j, i) * xv(j, i));
136       } else {
137         u  = xv(j, i);
138         uw = xv(j, i - 1);
139         ue = xv(j, i + 1);
140         un = xv(j - 1, i);
141         us = xv(j + 1, i);
142 
143         if (i - 1 == 0) uw = 0.;
144         if (i + 1 == mx - 1) ue = 0.;
145         if (j - 1 == 0) un = 0.;
146         if (j + 1 == my - 1) us = 0.;
147 
148         /* F[u] = 1/2\int_{\omega}\nabla^2u(x)*u(x)*dx */
149 
150         uxux = u * (2. * u - ue - uw) * hydhx;
151         uyuy = u * (2. * u - un - us) * hxdhy;
152 
153         update += PetscRealPart(0.5 * (uxux + uyuy) - sc * PetscExpScalar(u));
154       }
155     },
156     lobj));
157 
158   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da, x, &xv));
159   PetscCall(PetscLogFlops(12.0 * info->ym * info->xm));
160   PetscCallMPI(MPI_Allreduce(&lobj, obj, 1, MPIU_REAL, MPIU_SUM, comm));
161   PetscFunctionReturn(0);
162 }
163 
164 PetscErrorCode FormJacobianLocalVec(DMDALocalInfo *info, Vec x, Mat jac, Mat jacpre, AppCtx *user) {
165   PetscInt     i, j;
166   PetscInt     xs = info->xs, ys = info->ys, xm = info->xm, ym = info->ym, mx = info->mx, my = info->my;
167   MatStencil   col[5], row;
168   PetscScalar  lambda, hx, hy, hxdhy, hydhx, sc;
169   DM           coordDA;
170   Vec          coordinates;
171   DMDACoor2d **coords;
172 
173   PetscFunctionBeginUser;
174   lambda = user->param;
175   /* Extract coordinates */
176   PetscCall(DMGetCoordinateDM(info->da, &coordDA));
177   PetscCall(DMGetCoordinates(info->da, &coordinates));
178 
179   PetscCall(DMDAVecGetArray(coordDA, coordinates, &coords));
180   hx = xm > 1 ? PetscRealPart(coords[ys][xs + 1].x) - PetscRealPart(coords[ys][xs].x) : 1.0;
181   hy = ym > 1 ? PetscRealPart(coords[ys + 1][xs].y) - PetscRealPart(coords[ys][xs].y) : 1.0;
182   PetscCall(DMDAVecRestoreArray(coordDA, coordinates, &coords));
183 
184   hxdhy = hx / hy;
185   hydhx = hy / hx;
186   sc    = hx * hy * lambda;
187 
188   /* ----------------------------------------- */
189   /*  MatSetPreallocationCOO()                 */
190   /* ----------------------------------------- */
191   PetscCount ncoo = ((PetscCount)xm) * ((PetscCount)ym) * 5;
192   PetscInt  *coo_i, *coo_j, *ip, *jp;
193   PetscCall(PetscMalloc2(ncoo, &coo_i, ncoo, &coo_j)); /* 5-point stencil such that each row has at most 5 nonzeros */
194 
195   ip = coo_i;
196   jp = coo_j;
197   for (j = ys; j < ys + ym; j++) {
198     for (i = xs; i < xs + xm; i++) {
199       row.j    = j;
200       row.i    = i;
201       /* Initialize neighbors with negative indices */
202       col[0].j = col[1].j = col[3].j = col[4].j = -1;
203       /* boundary points */
204       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
205         col[2].j = row.j;
206         col[2].i = row.i;
207       } else {
208         /* interior grid points */
209         if (j - 1 != 0) {
210           col[0].j = j - 1;
211           col[0].i = i;
212         }
213 
214         if (i - 1 != 0) {
215           col[1].j = j;
216           col[1].i = i - 1;
217         }
218 
219         col[2].j = row.j;
220         col[2].i = row.i;
221 
222         if (i + 1 != mx - 1) {
223           col[3].j = j;
224           col[3].i = i + 1;
225         }
226 
227         if (j + 1 != mx - 1) {
228           col[4].j = j + 1;
229           col[4].i = i;
230         }
231       }
232       PetscCall(DMDAMapMatStencilToGlobal(info->da, 5, col, jp));
233       for (PetscInt k = 0; k < 5; k++) ip[k] = jp[2];
234       ip += 5;
235       jp += 5;
236     }
237   }
238 
239   PetscCall(MatSetPreallocationCOO(jacpre, ncoo, coo_i, coo_j));
240   PetscCall(PetscFree2(coo_i, coo_j));
241 
242   /* ----------------------------------------- */
243   /*  MatSetValuesCOO()                        */
244   /* ----------------------------------------- */
245   PetscScalarKokkosView              coo_v("coo_v", ncoo);
246   ConstPetscScalarKokkosOffsetView2D xv;
247 
248   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da, x, &xv));
249 
250   PetscCallCXX(Kokkos::parallel_for(
251     "FormFunctionLocalVec", MDRangePolicy<Rank<2, Iterate::Right, Iterate::Right>>({ys, xs}, {ys + ym, xs + xm}), KOKKOS_LAMBDA(PetscCount j, PetscCount i) {
252       PetscInt p = ((j - ys) * xm + (i - xs)) * 5;
253       /* boundary points */
254       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
255         coo_v(p + 2) = 2.0 * (hydhx + hxdhy);
256       } else {
257         /* interior grid points */
258         if (j - 1 != 0) coo_v(p + 0) = -hxdhy;
259         if (i - 1 != 0) coo_v(p + 1) = -hydhx;
260 
261         coo_v(p + 2) = 2.0 * (hydhx + hxdhy) - sc * PetscExpScalar(xv(j, i));
262 
263         if (i + 1 != mx - 1) coo_v(p + 3) = -hydhx;
264         if (j + 1 != mx - 1) coo_v(p + 4) = -hxdhy;
265       }
266     }));
267   PetscCall(MatSetValuesCOO(jacpre, coo_v.data(), INSERT_VALUES));
268   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da, x, &xv));
269   PetscFunctionReturn(0);
270 }
271