xref: /petsc/src/tao/unconstrained/tutorials/rosenbrock4.h (revision d8e47b638cf8f604a99e9678e1df24f82d959cd7)
1 #pragma once
2 
3 #include <petsctao.h>
4 #include <petscsf.h>
5 #include <petscdevice.h>
6 #include <petscdevice_cupm.h>
7 
/*
   User-defined data structures: the Rosenbrock problem description and the
   application context holding the data needed by the application-provided
   callback routines that evaluate the function, gradient, and Hessian.
*/
13 
14 typedef struct _Rosenbrock {
15   PetscInt  bs; // each block of bs variables is one chained multidimensional rosenbrock problem
16   PetscInt  i_start, i_end;
17   PetscInt  c_start, c_end;
18   PetscReal alpha; // condition parameter
19 } Rosenbrock;
20 
21 typedef struct _AppCtx *AppCtx;
22 struct _AppCtx {
23   MPI_Comm      comm;
24   PetscInt      n; /* dimension */
25   PetscInt      n_local;
26   PetscInt      n_local_comp;
27   Rosenbrock    problem;
28   Vec           Hvalues; /* vector for writing COO values of this MPI process */
29   Vec           gvalues; /* vector for writing gradient values of this mpi process */
30   Vec           fvector;
31   PetscSF       off_process_scatter;
32   PetscSF       gscatter;
33   Vec           off_process_values; /* buffer for off-process values if chained */
34   PetscBool     test_lmvm;
35   PetscLogEvent event_f, event_g, event_fg;
36 };
37 
38 /* -------------- User-defined routines ---------- */
39 
RosenbrockObjective(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2)40 static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjective(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2)
41 {
42   PetscScalar d = x_2 - x_1 * x_1;
43   PetscScalar e = 1.0 - x_1;
44   return alpha * d * d + e * e;
45 }
46 
47 static const PetscLogDouble RosenbrockObjectiveFlops = 7.0;
48 
RosenbrockGradient(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar g[2])49 static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
50 {
51   PetscScalar d  = x_2 - x_1 * x_1;
52   PetscScalar e  = 1.0 - x_1;
53   PetscScalar g2 = alpha * d * 2.0;
54 
55   g[0] = -2.0 * x_1 * g2 - 2.0 * e;
56   g[1] = g2;
57 }
58 
59 static const PetscInt RosenbrockGradientFlops = 9.0;
60 
RosenbrockObjectiveGradient(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar g[2])61 static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjectiveGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
62 {
63   PetscScalar d  = x_2 - x_1 * x_1;
64   PetscScalar e  = 1.0 - x_1;
65   PetscScalar ad = alpha * d;
66   PetscScalar g2 = ad * 2.0;
67 
68   g[0] = -2.0 * x_1 * g2 - 2.0 * e;
69   g[1] = g2;
70   return ad * d + e * e;
71 }
72 
73 static const PetscLogDouble RosenbrockObjectiveGradientFlops = 12.0;
74 
RosenbrockHessian(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar h[4])75 static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockHessian(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar h[4])
76 {
77   PetscScalar d  = x_2 - x_1 * x_1;
78   PetscScalar g2 = alpha * d * 2.0;
79   PetscScalar h2 = -4.0 * alpha * x_1;
80 
81   h[0] = -2.0 * (g2 + x_1 * h2) + 2.0;
82   h[1] = h[2] = h2;
83   h[3]        = 2.0 * alpha;
84 }
85 
86 static const PetscLogDouble RosenbrockHessianFlops = 11.0;
87 
AppCtxCreate(MPI_Comm comm,AppCtx * ctx)88 static PetscErrorCode AppCtxCreate(MPI_Comm comm, AppCtx *ctx)
89 {
90   AppCtx             user;
91   PetscDeviceContext dctx;
92 
93   PetscFunctionBegin;
94   PetscCall(PetscNew(ctx));
95   user       = *ctx;
96   user->comm = PETSC_COMM_WORLD;
97 
98   /* Initialize problem parameters */
99   user->n             = 2;
100   user->problem.alpha = 99.0;
101   user->problem.bs    = 2; // bs = 2 is block Rosenbrock, bs = n is chained Rosenbrock
102   user->test_lmvm     = PETSC_FALSE;
103   /* Check for command line arguments to override defaults */
104   PetscOptionsBegin(user->comm, NULL, "Rosenbrock example", NULL);
105   PetscCall(PetscOptionsInt("-n", "Rosenbrock problem size", NULL, user->n, &user->n, NULL));
106   PetscCall(PetscOptionsInt("-bs", "Rosenbrock block size (2 <= bs <= n)", NULL, user->problem.bs, &user->problem.bs, NULL));
107   PetscCall(PetscOptionsReal("-alpha", "Rosenbrock off-diagonal coefficient", NULL, user->problem.alpha, &user->problem.alpha, NULL));
108   PetscCall(PetscOptionsBool("-test_lmvm", "Test LMVM solve against LMVM mult", NULL, user->test_lmvm, &user->test_lmvm, NULL));
109   PetscOptionsEnd();
110   PetscCheck(user->problem.bs >= 1, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " is not bigger than 1", user->problem.bs);
111   PetscCheck((user->n % user->problem.bs) == 0, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " does not divide problem size % " PetscInt_FMT, user->problem.bs, user->n);
112   PetscCall(PetscLogEventRegister("Rbock_Obj", TAO_CLASSID, &user->event_f));
113   PetscCall(PetscLogEventRegister("Rbock_Grad", TAO_CLASSID, &user->event_g));
114   PetscCall(PetscLogEventRegister("Rbock_ObjGrad", TAO_CLASSID, &user->event_fg));
115   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
116   PetscCall(PetscDeviceContextSetUp(dctx));
117   PetscFunctionReturn(PETSC_SUCCESS);
118 }
119 
AppCtxDestroy(AppCtx * ctx)120 static PetscErrorCode AppCtxDestroy(AppCtx *ctx)
121 {
122   AppCtx user;
123 
124   PetscFunctionBegin;
125   user = *ctx;
126   *ctx = NULL;
127   PetscCall(VecDestroy(&user->Hvalues));
128   PetscCall(VecDestroy(&user->gvalues));
129   PetscCall(VecDestroy(&user->fvector));
130   PetscCall(VecDestroy(&user->off_process_values));
131   PetscCall(PetscSFDestroy(&user->off_process_scatter));
132   PetscCall(PetscSFDestroy(&user->gscatter));
133   PetscCall(PetscFree(user));
134   PetscFunctionReturn(PETSC_SUCCESS);
135 }
136 
CreateHessian(AppCtx user,Mat * Hessian)137 static PetscErrorCode CreateHessian(AppCtx user, Mat *Hessian)
138 {
139   Mat         H;
140   PetscLayout layout;
141   PetscInt    i_start, i_end, n_local_comp, nnz_local;
142   PetscInt    c_start, c_end;
143   PetscInt   *coo_i;
144   PetscInt   *coo_j;
145   PetscInt    bs = user->problem.bs;
146   VecType     vec_type;
147 
148   PetscFunctionBegin;
149   /* Partition the optimization variables and the computations.
150      There are (bs - 1) contributions to the objective function for every (bs)
151      degrees of freedom. */
152   PetscCall(PetscLayoutCreateFromSizes(user->comm, PETSC_DECIDE, user->n, 1, &layout));
153   PetscCall(PetscLayoutSetUp(layout));
154   PetscCall(PetscLayoutGetRange(layout, &i_start, &i_end));
155   user->problem.i_start = i_start;
156   user->problem.i_end   = i_end;
157   user->n_local         = i_end - i_start;
158   user->problem.c_start = c_start = (i_start / bs) * (bs - 1) + (i_start % bs);
159   user->problem.c_end = c_end = (i_end / bs) * (bs - 1) + (i_end % bs);
160   user->n_local_comp = n_local_comp = c_end - c_start;
161 
162   PetscCall(MatCreate(user->comm, Hessian));
163   H = *Hessian;
164   PetscCall(MatSetLayouts(H, layout, layout));
165   PetscCall(PetscLayoutDestroy(&layout));
166   PetscCall(MatSetType(H, MATAIJ));
167   PetscCall(MatSetOption(H, MAT_HERMITIAN, PETSC_TRUE));
168   PetscCall(MatSetOption(H, MAT_SYMMETRIC, PETSC_TRUE));
169   PetscCall(MatSetOption(H, MAT_SYMMETRY_ETERNAL, PETSC_TRUE));
170   PetscCall(MatSetOption(H, MAT_STRUCTURALLY_SYMMETRIC, PETSC_TRUE));
171   PetscCall(MatSetOption(H, MAT_STRUCTURAL_SYMMETRY_ETERNAL, PETSC_TRUE));
172   PetscCall(MatSetFromOptions(H)); /* set from options so that we can change the underlying matrix type */
173 
174   nnz_local = n_local_comp * 4;
175   PetscCall(PetscMalloc2(nnz_local, &coo_i, nnz_local, &coo_j));
176   /* Instead of having one computation thread per row of the matrix,
177      this example uses one thread per contribution to the objective
178      function.  Each contribution to the objective function relates
179      two adjacent degrees of freedom, so each contribution to
180      the objective function adds a 2x2 block into the matrix.
181      We describe these 2x2 blocks in COO format. */
182   for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 4) {
183     PetscInt i = (c / (bs - 1)) * bs + c % (bs - 1);
184 
185     coo_i[k + 0] = i;
186     coo_i[k + 1] = i;
187     coo_i[k + 2] = i + 1;
188     coo_i[k + 3] = i + 1;
189 
190     coo_j[k + 0] = i;
191     coo_j[k + 1] = i + 1;
192     coo_j[k + 2] = i;
193     coo_j[k + 3] = i + 1;
194   }
195   PetscCall(MatSetPreallocationCOO(H, nnz_local, coo_i, coo_j));
196   PetscCall(PetscFree2(coo_i, coo_j));
197 
198   PetscCall(MatGetVecType(H, &vec_type));
199   PetscCall(VecCreate(user->comm, &user->Hvalues));
200   PetscCall(VecSetSizes(user->Hvalues, nnz_local, PETSC_DETERMINE));
201   PetscCall(VecSetType(user->Hvalues, vec_type));
202 
203   // vector to collect contributions to the objective
204   PetscCall(VecCreate(user->comm, &user->fvector));
205   PetscCall(VecSetSizes(user->fvector, user->n_local_comp, PETSC_DETERMINE));
206   PetscCall(VecSetType(user->fvector, vec_type));
207 
208   { /* If we are using a device (such as a GPU), run some computations that will
209        warm up its linear algebra runtime before the problem we actually want
210        to profile */
211 
212     PetscMemType       memtype;
213     const PetscScalar *a;
214 
215     PetscCall(VecGetArrayReadAndMemType(user->fvector, &a, &memtype));
216     PetscCall(VecRestoreArrayReadAndMemType(user->fvector, &a));
217 
218     if (memtype == PETSC_MEMTYPE_DEVICE) {
219       PetscLogStage      warmup;
220       Mat                A, AtA;
221       Vec                x, b;
222       PetscInt           warmup_size = 1000;
223       PetscDeviceContext dctx;
224 
225       PetscCall(PetscLogStageRegister("Device Warmup", &warmup));
226       PetscCall(PetscLogStageSetActive(warmup, PETSC_FALSE));
227 
228       PetscCall(PetscLogStagePush(warmup));
229       PetscCall(MatCreateDenseFromVecType(PETSC_COMM_SELF, vec_type, warmup_size, warmup_size, warmup_size, warmup_size, PETSC_DEFAULT, NULL, &A));
230       PetscCall(MatSetRandom(A, NULL));
231       PetscCall(MatCreateVecs(A, &x, &b));
232       PetscCall(VecSetRandom(x, NULL));
233 
234       PetscCall(MatMult(A, x, b));
235       PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &AtA));
236       PetscCall(MatShift(AtA, (PetscScalar)warmup_size));
237       PetscCall(MatSetOption(AtA, MAT_SPD, PETSC_TRUE));
238       PetscCall(MatCholeskyFactor(AtA, NULL, NULL));
239       PetscCall(MatDestroy(&AtA));
240       PetscCall(VecDestroy(&b));
241       PetscCall(VecDestroy(&x));
242       PetscCall(MatDestroy(&A));
243       PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
244       PetscCall(PetscDeviceContextSynchronize(dctx));
245       PetscCall(PetscLogStagePop());
246     }
247   }
248   PetscFunctionReturn(PETSC_SUCCESS);
249 }
250 
CreateVectors(AppCtx user,Mat H,Vec * solution,Vec * gradient)251 static PetscErrorCode CreateVectors(AppCtx user, Mat H, Vec *solution, Vec *gradient)
252 {
253   VecType     vec_type;
254   PetscInt    n_coo, *coo_i, i_start, i_end;
255   Vec         x;
256   PetscInt    n_recv;
257   PetscSFNode recv;
258   PetscLayout layout;
259   PetscInt    c_start = user->problem.c_start, c_end = user->problem.c_end, bs = user->problem.bs;
260 
261   PetscFunctionBegin;
262   PetscCall(MatCreateVecs(H, solution, gradient));
263   x = *solution;
264   PetscCall(VecGetOwnershipRange(x, &i_start, &i_end));
265   PetscCall(VecGetType(x, &vec_type));
266   // create scatter for communicating values
267   PetscCall(VecGetLayout(x, &layout));
268   n_recv = 0;
269   if (user->n_local_comp && i_end < user->n) {
270     PetscMPIInt rank;
271     PetscInt    index;
272 
273     n_recv = 1;
274     PetscCall(PetscLayoutFindOwnerIndex(layout, i_end, &rank, &index));
275     recv.rank  = rank;
276     recv.index = index;
277   }
278   PetscCall(PetscSFCreate(user->comm, &user->off_process_scatter));
279   PetscCall(PetscSFSetGraph(user->off_process_scatter, user->n_local, n_recv, NULL, PETSC_USE_POINTER, &recv, PETSC_COPY_VALUES));
280   PetscCall(VecCreate(user->comm, &user->off_process_values));
281   PetscCall(VecSetSizes(user->off_process_values, 1, PETSC_DETERMINE));
282   PetscCall(VecSetType(user->off_process_values, vec_type));
283   PetscCall(VecZeroEntries(user->off_process_values));
284 
285   // create COO data for writing the gradient
286   n_coo = user->n_local_comp * 2;
287   PetscCall(PetscMalloc1(n_coo, &coo_i));
288   for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 2) {
289     PetscInt i = (c / (bs - 1)) * bs + (c % (bs - 1));
290 
291     coo_i[k + 0] = i;
292     coo_i[k + 1] = i + 1;
293   }
294   PetscCall(PetscSFCreate(user->comm, &user->gscatter));
295   PetscCall(PetscSFSetGraphLayout(user->gscatter, layout, n_coo, NULL, PETSC_USE_POINTER, coo_i));
296   PetscCall(PetscSFSetUp(user->gscatter));
297   PetscCall(PetscFree(coo_i));
298   PetscCall(VecCreate(user->comm, &user->gvalues));
299   PetscCall(VecSetSizes(user->gvalues, n_coo, PETSC_DETERMINE));
300   PetscCall(VecSetType(user->gvalues, vec_type));
301   PetscFunctionReturn(PETSC_SUCCESS);
302 }
303 
304 #if PetscDefined(USING_CUPMCC)
305 
306   #if PetscDefined(USING_NVCC)
307 typedef cudaStream_t cupmStream_t;
308     #define PetscCUPMLaunch(...) \
309       do { \
310         __VA_ARGS__; \
311         PetscCallCUDA(cudaGetLastError()); \
312       } while (0)
313   #elif PetscDefined(USING_HCC)
314     #define PetscCUPMLaunch(...) \
315       do { \
316         __VA_ARGS__; \
317         PetscCallHIP(hipGetLastError()); \
318       } while (0)
319 typedef hipStream_t cupmStream_t;
320   #endif
321 
322 // x: on-process optimization variables
323 // o: buffer that contains the next optimization variable after the variables on this process
324 template <typename T>
rosenbrock_for_loop(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],T && func)325 PETSC_DEVICE_INLINE_DECL static void rosenbrock_for_loop(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], T &&func) noexcept
326 {
327   PetscInt idx         = blockIdx.x * blockDim.x + threadIdx.x; // 1D grid
328   PetscInt num_threads = gridDim.x * blockDim.x;
329 
330   for (PetscInt c = r.c_start + idx, k = idx; c < r.c_end; c += num_threads, k += num_threads) {
331     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
332     PetscScalar x_a = x[i - r.i_start];
333     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
334 
335     func(k, x_a, x_b);
336   }
337   return;
338 }
339 
RosenbrockObjective_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[])340 PETSC_KERNEL_DECL void RosenbrockObjective_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
341 {
342   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjective(r.alpha, x_a, x_b); });
343 }
344 
RosenbrockGradient_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])345 PETSC_KERNEL_DECL void RosenbrockGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
346 {
347   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); });
348 }
349 
RosenbrockObjectiveGradient_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[],PetscScalar g[])350 PETSC_KERNEL_DECL void RosenbrockObjectiveGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
351 {
352   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); });
353 }
354 
RosenbrockHessian_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])355 PETSC_KERNEL_DECL void RosenbrockHessian_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
356 {
357   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); });
358 }
359 
RosenbrockObjective_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[])360 static PetscErrorCode RosenbrockObjective_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
361 {
362   PetscInt n_comp = r.c_end - r.c_start;
363 
364   PetscFunctionBegin;
365   if (n_comp) PetscCUPMLaunch(RosenbrockObjective_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec));
366   PetscCall(PetscLogGpuFlops(RosenbrockObjectiveFlops * n_comp));
367   PetscFunctionReturn(PETSC_SUCCESS);
368 }
369 
RosenbrockGradient_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])370 static PetscErrorCode RosenbrockGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
371 {
372   PetscInt n_comp = r.c_end - r.c_start;
373 
374   PetscFunctionBegin;
375   if (n_comp) PetscCUPMLaunch(RosenbrockGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, g));
376   PetscCall(PetscLogGpuFlops(RosenbrockGradientFlops * n_comp));
377   PetscFunctionReturn(PETSC_SUCCESS);
378 }
379 
RosenbrockObjectiveGradient_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[],PetscScalar g[])380 static PetscErrorCode RosenbrockObjectiveGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
381 {
382   PetscInt n_comp = r.c_end - r.c_start;
383 
384   PetscFunctionBegin;
385   if (n_comp) PetscCUPMLaunch(RosenbrockObjectiveGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec, g));
386   PetscCall(PetscLogGpuFlops(RosenbrockObjectiveGradientFlops * n_comp));
387   PetscFunctionReturn(PETSC_SUCCESS);
388 }
389 
RosenbrockHessian_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])390 static PetscErrorCode RosenbrockHessian_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
391 {
392   PetscInt n_comp = r.c_end - r.c_start;
393 
394   PetscFunctionBegin;
395   if (n_comp) PetscCUPMLaunch(RosenbrockHessian_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, h));
396   PetscCall(PetscLogGpuFlops(RosenbrockHessianFlops * n_comp));
397   PetscFunctionReturn(PETSC_SUCCESS);
398 }
399 #endif
400 
RosenbrockObjective_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscReal * f)401 static PetscErrorCode RosenbrockObjective_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f)
402 {
403   PetscReal _f = 0.0;
404 
405   PetscFunctionBegin;
406   for (PetscInt c = r.c_start; c < r.c_end; c++) {
407     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
408     PetscScalar x_a = x[i - r.i_start];
409     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
410 
411     _f += RosenbrockObjective(r.alpha, x_a, x_b);
412   }
413   *f = _f;
414   PetscCall(PetscLogFlops((RosenbrockObjectiveFlops + 1.0) * (r.c_end - r.c_start)));
415   PetscFunctionReturn(PETSC_SUCCESS);
416 }
417 
RosenbrockGradient_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])418 static PetscErrorCode RosenbrockGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
419 {
420   PetscFunctionBegin;
421   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
422     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
423     PetscScalar x_a = x[i - r.i_start];
424     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
425 
426     RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]);
427   }
428   PetscCall(PetscLogFlops(RosenbrockGradientFlops * (r.c_end - r.c_start)));
429   PetscFunctionReturn(PETSC_SUCCESS);
430 }
431 
RosenbrockObjectiveGradient_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscReal * f,PetscScalar g[])432 static PetscErrorCode RosenbrockObjectiveGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f, PetscScalar g[])
433 {
434   PetscReal _f = 0.0;
435 
436   PetscFunctionBegin;
437   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
438     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
439     PetscScalar x_a = x[i - r.i_start];
440     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
441 
442     _f += RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]);
443   }
444   *f = _f;
445   PetscCall(PetscLogFlops(RosenbrockObjectiveGradientFlops * (r.c_end - r.c_start)));
446   PetscFunctionReturn(PETSC_SUCCESS);
447 }
448 
RosenbrockHessian_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])449 static PetscErrorCode RosenbrockHessian_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
450 {
451   PetscFunctionBegin;
452   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
453     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
454     PetscScalar x_a = x[i - r.i_start];
455     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
456 
457     RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]);
458   }
459   PetscCall(PetscLogFlops(RosenbrockHessianFlops * (r.c_end - r.c_start)));
460   PetscFunctionReturn(PETSC_SUCCESS);
461 }
462 
463 /* -------------------------------------------------------------------- */
464 
FormObjective(Tao tao,Vec X,PetscReal * f,void * ptr)465 static PetscErrorCode FormObjective(Tao tao, Vec X, PetscReal *f, void *ptr)
466 {
467   AppCtx             user    = (AppCtx)ptr;
468   PetscReal          f_local = 0.0;
469   const PetscScalar *x;
470   const PetscScalar *o = NULL;
471   PetscMemType       memtype_x;
472 
473   PetscFunctionBeginUser;
474   PetscCall(PetscLogEventBegin(user->event_f, tao, NULL, NULL, NULL));
475   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
476   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
477   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
478   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
479   if (memtype_x == PETSC_MEMTYPE_HOST) {
480     PetscCall(RosenbrockObjective_Host(user->problem, x, o, &f_local));
481     PetscCallMPI(MPIU_Allreduce(&f_local, f, 1, MPI_DOUBLE, MPI_SUM, user->comm));
482 #if PetscDefined(USING_CUPMCC)
483   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
484     PetscScalar       *_fvec;
485     PetscScalar        f_scalar;
486     cupmStream_t      *stream;
487     PetscDeviceContext dctx;
488 
489     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
490     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
491     PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
492     PetscCall(RosenbrockObjective_Device(*stream, user->problem, x, o, _fvec));
493     PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
494     PetscCall(VecSum(user->fvector, &f_scalar));
495     *f = PetscRealPart(f_scalar);
496 #endif
497   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
498   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
499   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
500   PetscCall(PetscLogEventEnd(user->event_f, tao, NULL, NULL, NULL));
501   PetscFunctionReturn(PETSC_SUCCESS);
502 }
503 
FormGradient(Tao tao,Vec X,Vec G,void * ptr)504 static PetscErrorCode FormGradient(Tao tao, Vec X, Vec G, void *ptr)
505 {
506   AppCtx             user = (AppCtx)ptr;
507   PetscScalar       *g;
508   const PetscScalar *x;
509   const PetscScalar *o = NULL;
510   PetscMemType       memtype_x, memtype_g;
511 
512   PetscFunctionBeginUser;
513   PetscCall(PetscLogEventBegin(user->event_g, tao, NULL, NULL, NULL));
514   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
515   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
516   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
517   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
518   PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
519   PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
520   if (memtype_x == PETSC_MEMTYPE_HOST) {
521     PetscCall(RosenbrockGradient_Host(user->problem, x, o, g));
522 #if PetscDefined(USING_CUPMCC)
523   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
524     cupmStream_t      *stream;
525     PetscDeviceContext dctx;
526 
527     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
528     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
529     PetscCall(RosenbrockGradient_Device(*stream, user->problem, x, o, g));
530 #endif
531   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
532   PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
533   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
534   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
535   PetscCall(VecZeroEntries(G));
536   PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
537   PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
538   PetscCall(PetscLogEventEnd(user->event_g, tao, NULL, NULL, NULL));
539   PetscFunctionReturn(PETSC_SUCCESS);
540 }
541 
542 /*
543     FormObjectiveGradient - Evaluates the function, f(X), and gradient, G(X).
544 
545     Input Parameters:
546 .   tao  - the Tao context
547 .   X    - input vector
548 .   ptr  - optional user-defined context, as set by TaoSetObjectiveGradient()
549 
550     Output Parameters:
551 .   G - vector containing the newly evaluated gradient
552 .   f - function value
553 
554     Note:
555     Some optimization methods ask for the function and the gradient evaluation
556     at the same time.  Evaluating both at once may be more efficient that
557     evaluating each separately.
558 */
FormObjectiveGradient(Tao tao,Vec X,PetscReal * f,Vec G,void * ptr)559 static PetscErrorCode FormObjectiveGradient(Tao tao, Vec X, PetscReal *f, Vec G, void *ptr)
560 {
561   AppCtx             user    = (AppCtx)ptr;
562   PetscReal          f_local = 0.0;
563   PetscScalar       *g;
564   const PetscScalar *x;
565   const PetscScalar *o = NULL;
566   PetscMemType       memtype_x, memtype_g;
567 
568   PetscFunctionBeginUser;
569   PetscCall(PetscLogEventBegin(user->event_fg, tao, NULL, NULL, NULL));
570   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
571   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
572   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
573   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
574   PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
575   PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
576   if (memtype_x == PETSC_MEMTYPE_HOST) {
577     PetscCall(RosenbrockObjectiveGradient_Host(user->problem, x, o, &f_local, g));
578     PetscCallMPI(MPIU_Allreduce((void *)&f_local, (void *)f, 1, MPI_DOUBLE, MPI_SUM, PETSC_COMM_WORLD));
579 #if PetscDefined(USING_CUPMCC)
580   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
581     PetscScalar       *_fvec;
582     PetscScalar        f_scalar;
583     cupmStream_t      *stream;
584     PetscDeviceContext dctx;
585 
586     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
587     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
588     PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
589     PetscCall(RosenbrockObjectiveGradient_Device(*stream, user->problem, x, o, _fvec, g));
590     PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
591     PetscCall(VecSum(user->fvector, &f_scalar));
592     *f = PetscRealPart(f_scalar);
593 #endif
594   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
595 
596   PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
597   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
598   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
599   PetscCall(VecZeroEntries(G));
600   PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
601   PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
602   PetscCall(PetscLogEventEnd(user->event_fg, tao, NULL, NULL, NULL));
603   PetscFunctionReturn(PETSC_SUCCESS);
604 }
605 
606 /* ------------------------------------------------------------------- */
607 /*
608    FormHessian - Evaluates Hessian matrix.
609 
610    Input Parameters:
611 .  tao   - the Tao context
612 .  x     - input vector
613 .  ptr   - optional user-defined context, as set by TaoSetHessian()
614 
615    Output Parameters:
616 .  H     - Hessian matrix
617 
618    Note:  Providing the Hessian may not be necessary.  Only some solvers
619    require this matrix.
620 */
FormHessian(Tao tao,Vec X,Mat H,Mat Hpre,void * ptr)621 static PetscErrorCode FormHessian(Tao tao, Vec X, Mat H, Mat Hpre, void *ptr)
622 {
623   AppCtx             user = (AppCtx)ptr;
624   PetscScalar       *h;
625   const PetscScalar *x;
626   const PetscScalar *o = NULL;
627   PetscMemType       memtype_x, memtype_h;
628 
629   PetscFunctionBeginUser;
630   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
631   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
632   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
633   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
634   PetscCall(VecGetArrayWriteAndMemType(user->Hvalues, &h, &memtype_h));
635   PetscAssert(memtype_x == memtype_h, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and hessian must have save memtype");
636   if (memtype_x == PETSC_MEMTYPE_HOST) {
637     PetscCall(RosenbrockHessian_Host(user->problem, x, o, h));
638 #if PetscDefined(USING_CUPMCC)
639   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
640     cupmStream_t      *stream;
641     PetscDeviceContext dctx;
642 
643     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
644     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
645     PetscCall(RosenbrockHessian_Device(*stream, user->problem, x, o, h));
646 #endif
647   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
648 
649   PetscCall(MatSetValuesCOO(H, h, INSERT_VALUES));
650   PetscCall(VecRestoreArrayWriteAndMemType(user->Hvalues, &h));
651 
652   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
653   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
654 
655   if (Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN));
656   PetscFunctionReturn(PETSC_SUCCESS);
657 }
658 
TestLMVM(Tao tao)659 static PetscErrorCode TestLMVM(Tao tao)
660 {
661   KSP       ksp;
662   PC        pc;
663   PetscBool is_lmvm;
664 
665   PetscFunctionBegin;
666   PetscCall(TaoGetKSP(tao, &ksp));
667   if (!ksp) PetscFunctionReturn(PETSC_SUCCESS);
668   PetscCall(KSPGetPC(ksp, &pc));
669   PetscCall(PetscObjectTypeCompare((PetscObject)pc, PCLMVM, &is_lmvm));
670   if (is_lmvm) {
671     Mat       M;
672     Vec       in, out, out2;
673     PetscReal mult_solve_dist;
674     Vec       x;
675 
676     PetscCall(PCLMVMGetMatLMVM(pc, &M));
677     PetscCall(TaoGetSolution(tao, &x));
678     PetscCall(VecDuplicate(x, &in));
679     PetscCall(VecDuplicate(x, &out));
680     PetscCall(VecDuplicate(x, &out2));
681     PetscCall(VecSetRandom(in, NULL));
682     PetscCall(MatMult(M, in, out));
683     PetscCall(MatSolve(M, out, out2));
684 
685     PetscCall(VecAXPY(out2, -1.0, in));
686     PetscCall(VecNorm(out2, NORM_2, &mult_solve_dist));
687     if (mult_solve_dist < 1.e-11) {
688       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-11\n"));
689     } else if (mult_solve_dist < 1.e-6) {
690       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-6\n"));
691     } else {
692       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve is not small: %e\n", (double)mult_solve_dist));
693     }
694     PetscCall(VecDestroy(&in));
695     PetscCall(VecDestroy(&out));
696     PetscCall(VecDestroy(&out2));
697   }
698   PetscFunctionReturn(PETSC_SUCCESS);
699 }
700 
RosenbrockMain(void)701 static PetscErrorCode RosenbrockMain(void)
702 {
703   Vec           x;    /* solution vector */
704   Vec           g;    /* gradient vector */
705   Mat           H;    /* Hessian matrix */
706   Tao           tao;  /* Tao solver context */
707   AppCtx        user; /* user-defined application context */
708   PetscLogStage solve;
709 
710   /* Initialize TAO and PETSc */
711   PetscFunctionBegin;
712   PetscCall(PetscLogStageRegister("Rosenbrock solve", &solve));
713 
714   PetscCall(AppCtxCreate(PETSC_COMM_WORLD, &user));
715   PetscCall(CreateHessian(user, &H));
716   PetscCall(CreateVectors(user, H, &x, &g));
717 
718   /* The TAO code begins here */
719 
720   PetscCall(TaoCreate(user->comm, &tao));
721   PetscCall(VecZeroEntries(x));
722   PetscCall(TaoSetSolution(tao, x));
723 
724   /* Set routines for function, gradient, hessian evaluation */
725   PetscCall(TaoSetObjective(tao, FormObjective, user));
726   PetscCall(TaoSetObjectiveAndGradient(tao, g, FormObjectiveGradient, user));
727   PetscCall(TaoSetGradient(tao, g, FormGradient, user));
728   PetscCall(TaoSetHessian(tao, H, H, FormHessian, user));
729 
730   PetscCall(TaoSetFromOptions(tao));
731 
732   /* SOLVE THE APPLICATION */
733   PetscCall(PetscLogStagePush(solve));
734   PetscCall(TaoSolve(tao));
735   PetscCall(PetscLogStagePop());
736 
737   if (user->test_lmvm) PetscCall(TestLMVM(tao));
738 
739   PetscCall(TaoDestroy(&tao));
740   PetscCall(VecDestroy(&g));
741   PetscCall(VecDestroy(&x));
742   PetscCall(MatDestroy(&H));
743   PetscCall(AppCtxDestroy(&user));
744   PetscFunctionReturn(PETSC_SUCCESS);
745 }
746