1 #pragma once
2
3 #include <petsctao.h>
4 #include <petscsf.h>
5 #include <petscdevice.h>
6 #include <petscdevice_cupm.h>
7
8 /*
9 User-defined application context - contains data needed by the
10 application-provided call-back routines that evaluate the function,
11 gradient, and hessian.
12 */
13
/* Parameters of the (block/chained) Rosenbrock problem, plus the ranges of
   variables and objective contributions owned by this MPI process. */
typedef struct _Rosenbrock {
  PetscInt  bs; // each block of bs variables is one chained multidimensional rosenbrock problem
  PetscInt  i_start, i_end; // [i_start, i_end): optimization variables owned by this process
  PetscInt  c_start, c_end; // [c_start, c_end): objective contributions computed by this process
  PetscReal alpha; // condition parameter
} Rosenbrock;
20
typedef struct _AppCtx *AppCtx;
/* User-defined application context shared by the objective, gradient, and
   Hessian callbacks. */
struct _AppCtx {
  MPI_Comm      comm;
  PetscInt      n;            /* dimension */
  PetscInt      n_local;      /* number of locally owned optimization variables */
  PetscInt      n_local_comp; /* number of objective contributions computed on this process */
  Rosenbrock    problem;      /* problem parameters and local ranges */
  Vec           Hvalues;      /* vector for writing COO values of this MPI process */
  Vec           gvalues;      /* vector for writing gradient values of this mpi process */
  Vec           fvector;      /* per-contribution objective values; summed (VecSum) to get f */
  PetscSF       off_process_scatter; /* fetches the first variable owned by the next process */
  PetscSF       gscatter;            /* scatter-adds per-contribution gradients into the global gradient */
  Vec           off_process_values;  /* buffer for off-process values if chained */
  PetscBool     test_lmvm;           /* if true, compare LMVM MatMult against MatSolve after the solve */
  PetscLogEvent event_f, event_g, event_fg; /* log events for objective, gradient, and fused evaluation */
};
37
38 /* -------------- User-defined routines ---------- */
39
RosenbrockObjective(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2)40 static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjective(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2)
41 {
42 PetscScalar d = x_2 - x_1 * x_1;
43 PetscScalar e = 1.0 - x_1;
44 return alpha * d * d + e * e;
45 }
46
47 static const PetscLogDouble RosenbrockObjectiveFlops = 7.0;
48
RosenbrockGradient(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar g[2])49 static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
50 {
51 PetscScalar d = x_2 - x_1 * x_1;
52 PetscScalar e = 1.0 - x_1;
53 PetscScalar g2 = alpha * d * 2.0;
54
55 g[0] = -2.0 * x_1 * g2 - 2.0 * e;
56 g[1] = g2;
57 }
58
59 static const PetscInt RosenbrockGradientFlops = 9.0;
60
RosenbrockObjectiveGradient(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar g[2])61 static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjectiveGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
62 {
63 PetscScalar d = x_2 - x_1 * x_1;
64 PetscScalar e = 1.0 - x_1;
65 PetscScalar ad = alpha * d;
66 PetscScalar g2 = ad * 2.0;
67
68 g[0] = -2.0 * x_1 * g2 - 2.0 * e;
69 g[1] = g2;
70 return ad * d + e * e;
71 }
72
73 static const PetscLogDouble RosenbrockObjectiveGradientFlops = 12.0;
74
RosenbrockHessian(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar h[4])75 static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockHessian(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar h[4])
76 {
77 PetscScalar d = x_2 - x_1 * x_1;
78 PetscScalar g2 = alpha * d * 2.0;
79 PetscScalar h2 = -4.0 * alpha * x_1;
80
81 h[0] = -2.0 * (g2 + x_1 * h2) + 2.0;
82 h[1] = h[2] = h2;
83 h[3] = 2.0 * alpha;
84 }
85
86 static const PetscLogDouble RosenbrockHessianFlops = 11.0;
87
AppCtxCreate(MPI_Comm comm,AppCtx * ctx)88 static PetscErrorCode AppCtxCreate(MPI_Comm comm, AppCtx *ctx)
89 {
90 AppCtx user;
91 PetscDeviceContext dctx;
92
93 PetscFunctionBegin;
94 PetscCall(PetscNew(ctx));
95 user = *ctx;
96 user->comm = PETSC_COMM_WORLD;
97
98 /* Initialize problem parameters */
99 user->n = 2;
100 user->problem.alpha = 99.0;
101 user->problem.bs = 2; // bs = 2 is block Rosenbrock, bs = n is chained Rosenbrock
102 user->test_lmvm = PETSC_FALSE;
103 /* Check for command line arguments to override defaults */
104 PetscOptionsBegin(user->comm, NULL, "Rosenbrock example", NULL);
105 PetscCall(PetscOptionsInt("-n", "Rosenbrock problem size", NULL, user->n, &user->n, NULL));
106 PetscCall(PetscOptionsInt("-bs", "Rosenbrock block size (2 <= bs <= n)", NULL, user->problem.bs, &user->problem.bs, NULL));
107 PetscCall(PetscOptionsReal("-alpha", "Rosenbrock off-diagonal coefficient", NULL, user->problem.alpha, &user->problem.alpha, NULL));
108 PetscCall(PetscOptionsBool("-test_lmvm", "Test LMVM solve against LMVM mult", NULL, user->test_lmvm, &user->test_lmvm, NULL));
109 PetscOptionsEnd();
110 PetscCheck(user->problem.bs >= 1, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " is not bigger than 1", user->problem.bs);
111 PetscCheck((user->n % user->problem.bs) == 0, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " does not divide problem size % " PetscInt_FMT, user->problem.bs, user->n);
112 PetscCall(PetscLogEventRegister("Rbock_Obj", TAO_CLASSID, &user->event_f));
113 PetscCall(PetscLogEventRegister("Rbock_Grad", TAO_CLASSID, &user->event_g));
114 PetscCall(PetscLogEventRegister("Rbock_ObjGrad", TAO_CLASSID, &user->event_fg));
115 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
116 PetscCall(PetscDeviceContextSetUp(dctx));
117 PetscFunctionReturn(PETSC_SUCCESS);
118 }
119
/* Free all resources held by the application context and the context itself.
   Sets *ctx to NULL so the caller cannot accidentally reuse it. */
static PetscErrorCode AppCtxDestroy(AppCtx *ctx)
{
  AppCtx user;

  PetscFunctionBegin;
  user = *ctx;
  *ctx = NULL;
  PetscCall(VecDestroy(&user->Hvalues));
  PetscCall(VecDestroy(&user->gvalues));
  PetscCall(VecDestroy(&user->fvector));
  PetscCall(VecDestroy(&user->off_process_values));
  PetscCall(PetscSFDestroy(&user->off_process_scatter));
  PetscCall(PetscSFDestroy(&user->gscatter));
  PetscCall(PetscFree(user));
  PetscFunctionReturn(PETSC_SUCCESS);
}
136
/* Create the Hessian matrix, describe its nonzero pattern in COO (coordinate)
   format, and create the work vectors whose sizes depend on the local number
   of objective contributions.  If the vectors turn out to be device-resident,
   also run a throwaway dense factorization to warm up the GPU linear-algebra
   runtime so it does not pollute later profiling. */
static PetscErrorCode CreateHessian(AppCtx user, Mat *Hessian)
{
  Mat         H;
  PetscLayout layout;
  PetscInt    i_start, i_end, n_local_comp, nnz_local;
  PetscInt    c_start, c_end;
  PetscInt   *coo_i;
  PetscInt   *coo_j;
  PetscInt    bs = user->problem.bs;
  VecType     vec_type;

  PetscFunctionBegin;
  /* Partition the optimization variables and the computations.
     There are (bs - 1) contributions to the objective function for every (bs)
     degrees of freedom. */
  PetscCall(PetscLayoutCreateFromSizes(user->comm, PETSC_DECIDE, user->n, 1, &layout));
  PetscCall(PetscLayoutSetUp(layout));
  PetscCall(PetscLayoutGetRange(layout, &i_start, &i_end));
  user->problem.i_start = i_start;
  user->problem.i_end   = i_end;
  user->n_local         = i_end - i_start;
  // map the local variable range [i_start, i_end) to the contribution range [c_start, c_end)
  user->problem.c_start = c_start = (i_start / bs) * (bs - 1) + (i_start % bs);
  user->problem.c_end = c_end = (i_end / bs) * (bs - 1) + (i_end % bs);
  user->n_local_comp = n_local_comp = c_end - c_start;

  PetscCall(MatCreate(user->comm, Hessian));
  H = *Hessian;
  PetscCall(MatSetLayouts(H, layout, layout));
  PetscCall(PetscLayoutDestroy(&layout));
  PetscCall(MatSetType(H, MATAIJ));
  // declare the (fixed) symmetry of the Hessian so solvers can exploit it
  PetscCall(MatSetOption(H, MAT_HERMITIAN, PETSC_TRUE));
  PetscCall(MatSetOption(H, MAT_SYMMETRIC, PETSC_TRUE));
  PetscCall(MatSetOption(H, MAT_SYMMETRY_ETERNAL, PETSC_TRUE));
  PetscCall(MatSetOption(H, MAT_STRUCTURALLY_SYMMETRIC, PETSC_TRUE));
  PetscCall(MatSetOption(H, MAT_STRUCTURAL_SYMMETRY_ETERNAL, PETSC_TRUE));
  PetscCall(MatSetFromOptions(H)); /* set from options so that we can change the underlying matrix type */

  nnz_local = n_local_comp * 4; // each contribution writes a full 2x2 block
  PetscCall(PetscMalloc2(nnz_local, &coo_i, nnz_local, &coo_j));
  /* Instead of having one computation thread per row of the matrix,
     this example uses one thread per contribution to the objective
     function. Each contribution to the objective function relates
     two adjacent degrees of freedom, so each contribution to
     the objective function adds a 2x2 block into the matrix.
     We describe these 2x2 blocks in COO format. */
  for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 4) {
    PetscInt i = (c / (bs - 1)) * bs + c % (bs - 1); // first variable of contribution c

    coo_i[k + 0] = i;
    coo_i[k + 1] = i;
    coo_i[k + 2] = i + 1;
    coo_i[k + 3] = i + 1;

    coo_j[k + 0] = i;
    coo_j[k + 1] = i + 1;
    coo_j[k + 2] = i;
    coo_j[k + 3] = i + 1;
  }
  PetscCall(MatSetPreallocationCOO(H, nnz_local, coo_i, coo_j));
  PetscCall(PetscFree2(coo_i, coo_j));

  // Hvalues uses the matrix's native vector type so MatSetValuesCOO can read it in place
  PetscCall(MatGetVecType(H, &vec_type));
  PetscCall(VecCreate(user->comm, &user->Hvalues));
  PetscCall(VecSetSizes(user->Hvalues, nnz_local, PETSC_DETERMINE));
  PetscCall(VecSetType(user->Hvalues, vec_type));

  // vector to collect contributions to the objective
  PetscCall(VecCreate(user->comm, &user->fvector));
  PetscCall(VecSetSizes(user->fvector, user->n_local_comp, PETSC_DETERMINE));
  PetscCall(VecSetType(user->fvector, vec_type));

  { /* If we are using a device (such as a GPU), run some computations that will
       warm up its linear algebra runtime before the problem we actually want
       to profile */

    PetscMemType       memtype;
    const PetscScalar *a;

    // probe where fvector's data lives to detect device execution
    PetscCall(VecGetArrayReadAndMemType(user->fvector, &a, &memtype));
    PetscCall(VecRestoreArrayReadAndMemType(user->fvector, &a));

    if (memtype == PETSC_MEMTYPE_DEVICE) {
      PetscLogStage      warmup;
      Mat                A, AtA;
      Vec                x, b;
      PetscInt           warmup_size = 1000;
      PetscDeviceContext dctx;

      // register an inactive log stage so the warm-up does not appear in profiling output
      PetscCall(PetscLogStageRegister("Device Warmup", &warmup));
      PetscCall(PetscLogStageSetActive(warmup, PETSC_FALSE));

      PetscCall(PetscLogStagePush(warmup));
      PetscCall(MatCreateDenseFromVecType(PETSC_COMM_SELF, vec_type, warmup_size, warmup_size, warmup_size, warmup_size, PETSC_DEFAULT, NULL, &A));
      PetscCall(MatSetRandom(A, NULL));
      PetscCall(MatCreateVecs(A, &x, &b));
      PetscCall(VecSetRandom(x, NULL));

      PetscCall(MatMult(A, x, b));
      PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &AtA));
      // diagonal shift keeps A^T A positive definite so the Cholesky factorization succeeds
      PetscCall(MatShift(AtA, (PetscScalar)warmup_size));
      PetscCall(MatSetOption(AtA, MAT_SPD, PETSC_TRUE));
      PetscCall(MatCholeskyFactor(AtA, NULL, NULL));
      PetscCall(MatDestroy(&AtA));
      PetscCall(VecDestroy(&b));
      PetscCall(VecDestroy(&x));
      PetscCall(MatDestroy(&A));
      PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
      PetscCall(PetscDeviceContextSynchronize(dctx));
      PetscCall(PetscLogStagePop());
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
250
/* Create the solution and gradient vectors compatible with the Hessian, and
   the communication patterns used during evaluation:
   - off_process_scatter fetches the one variable owned by the next process
     (needed by this process's last contribution when the chain crosses ranks);
   - gscatter maps the per-contribution gradient entries (two per contribution)
     onto global gradient indices, used later with SCATTER_REVERSE/ADD_VALUES. */
static PetscErrorCode CreateVectors(AppCtx user, Mat H, Vec *solution, Vec *gradient)
{
  VecType     vec_type;
  PetscInt    n_coo, *coo_i, i_start, i_end;
  Vec         x;
  PetscInt    n_recv;
  PetscSFNode recv;
  PetscLayout layout;
  PetscInt    c_start = user->problem.c_start, c_end = user->problem.c_end, bs = user->problem.bs;

  PetscFunctionBegin;
  PetscCall(MatCreateVecs(H, solution, gradient));
  x = *solution;
  PetscCall(VecGetOwnershipRange(x, &i_start, &i_end));
  PetscCall(VecGetType(x, &vec_type));
  // create scatter for communicating values
  PetscCall(VecGetLayout(x, &layout));
  n_recv = 0;
  if (user->n_local_comp && i_end < user->n) {
    // this process has work and is not last: it needs variable i_end from the next owner
    PetscMPIInt rank;
    PetscInt    index;

    n_recv = 1;
    PetscCall(PetscLayoutFindOwnerIndex(layout, i_end, &rank, &index));
    recv.rank  = rank;
    recv.index = index;
  }
  PetscCall(PetscSFCreate(user->comm, &user->off_process_scatter));
  PetscCall(PetscSFSetGraph(user->off_process_scatter, user->n_local, n_recv, NULL, PETSC_USE_POINTER, &recv, PETSC_COPY_VALUES));
  PetscCall(VecCreate(user->comm, &user->off_process_values));
  PetscCall(VecSetSizes(user->off_process_values, 1, PETSC_DETERMINE));
  PetscCall(VecSetType(user->off_process_values, vec_type));
  PetscCall(VecZeroEntries(user->off_process_values)); // o[0] stays defined even when nothing is received

  // create COO data for writing the gradient
  n_coo = user->n_local_comp * 2;
  PetscCall(PetscMalloc1(n_coo, &coo_i));
  for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 2) {
    PetscInt i = (c / (bs - 1)) * bs + (c % (bs - 1)); // contribution c touches variables i and i + 1

    coo_i[k + 0] = i;
    coo_i[k + 1] = i + 1;
  }
  PetscCall(PetscSFCreate(user->comm, &user->gscatter));
  PetscCall(PetscSFSetGraphLayout(user->gscatter, layout, n_coo, NULL, PETSC_USE_POINTER, coo_i));
  PetscCall(PetscSFSetUp(user->gscatter));
  PetscCall(PetscFree(coo_i));
  PetscCall(VecCreate(user->comm, &user->gvalues));
  PetscCall(VecSetSizes(user->gvalues, n_coo, PETSC_DETERMINE));
  PetscCall(VecSetType(user->gvalues, vec_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
303
#if PetscDefined(USING_CUPMCC)

/* Compiling with a CUDA or HIP compiler: define a backend-agnostic stream type
   and a kernel-launch macro that checks for launch errors immediately after
   the launch expression. */
#if PetscDefined(USING_NVCC)
typedef cudaStream_t cupmStream_t;
// launch a kernel, then surface any launch error through PETSc error handling (CUDA backend)
#define PetscCUPMLaunch(...) \
  do { \
    __VA_ARGS__; \
    PetscCallCUDA(cudaGetLastError()); \
  } while (0)
#elif PetscDefined(USING_HCC)
// launch a kernel, then surface any launch error through PETSc error handling (HIP backend)
#define PetscCUPMLaunch(...) \
  do { \
    __VA_ARGS__; \
    PetscCallHIP(hipGetLastError()); \
  } while (0)
typedef hipStream_t cupmStream_t;
#endif
321
322 // x: on-process optimization variables
323 // o: buffer that contains the next optimization variable after the variables on this process
// x: on-process optimization variables
// o: buffer that contains the next optimization variable after the variables on this process
/* Grid-stride loop over this process's objective contributions
   [r.c_start, r.c_end).  For each contribution, calls func(k, x_a, x_b), where
   k is the zero-based local contribution index and (x_a, x_b) are the two
   variables the contribution couples. */
template <typename T>
PETSC_DEVICE_INLINE_DECL static void rosenbrock_for_loop(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], T &&func) noexcept
{
  PetscInt idx         = blockIdx.x * blockDim.x + threadIdx.x; // 1D grid
  PetscInt num_threads = gridDim.x * blockDim.x;                // total threads = loop stride

  for (PetscInt c = r.c_start + idx, k = idx; c < r.c_end; c += num_threads, k += num_threads) {
    PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); // first variable of contribution c
    PetscScalar x_a = x[i - r.i_start];
    // the last local contribution may read the off-process value from o[0]
    PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];

    func(k, x_a, x_b);
  }
  return;
}
339
// Kernel: store each contribution's objective value at f_vec[k]
PETSC_KERNEL_DECL void RosenbrockObjective_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
{
  rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjective(r.alpha, x_a, x_b); });
}

// Kernel: write each contribution's 2-entry gradient at g[2 * k]
PETSC_KERNEL_DECL void RosenbrockGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
{
  rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); });
}

// Kernel: fused objective value (f_vec[k]) and gradient (g[2 * k]) in one pass
PETSC_KERNEL_DECL void RosenbrockObjectiveGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
{
  rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); });
}

// Kernel: write each contribution's 2x2 Hessian block at h[4 * k]
PETSC_KERNEL_DECL void RosenbrockHessian_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
{
  rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); });
}
359
/* Launch RosenbrockObjective_Kernel on stream and log its flops.
   Launch configuration: 256-thread blocks, one thread per contribution;
   skipped entirely when this process has no local work. */
static PetscErrorCode RosenbrockObjective_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
{
  PetscInt n_comp = r.c_end - r.c_start;

  PetscFunctionBegin;
  if (n_comp) PetscCUPMLaunch(RosenbrockObjective_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec));
  PetscCall(PetscLogGpuFlops(RosenbrockObjectiveFlops * n_comp));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Launch RosenbrockGradient_Kernel on stream and log its flops (same launch
   configuration as above). */
static PetscErrorCode RosenbrockGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
{
  PetscInt n_comp = r.c_end - r.c_start;

  PetscFunctionBegin;
  if (n_comp) PetscCUPMLaunch(RosenbrockGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, g));
  PetscCall(PetscLogGpuFlops(RosenbrockGradientFlops * n_comp));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Launch the fused objective+gradient kernel on stream and log its flops. */
static PetscErrorCode RosenbrockObjectiveGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
{
  PetscInt n_comp = r.c_end - r.c_start;

  PetscFunctionBegin;
  if (n_comp) PetscCUPMLaunch(RosenbrockObjectiveGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec, g));
  PetscCall(PetscLogGpuFlops(RosenbrockObjectiveGradientFlops * n_comp));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Launch RosenbrockHessian_Kernel on stream and log its flops. */
static PetscErrorCode RosenbrockHessian_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
{
  PetscInt n_comp = r.c_end - r.c_start;

  PetscFunctionBegin;
  if (n_comp) PetscCUPMLaunch(RosenbrockHessian_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, h));
  PetscCall(PetscLogGpuFlops(RosenbrockHessianFlops * n_comp));
  PetscFunctionReturn(PETSC_SUCCESS);
}
399 #endif
400
RosenbrockObjective_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscReal * f)401 static PetscErrorCode RosenbrockObjective_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f)
402 {
403 PetscReal _f = 0.0;
404
405 PetscFunctionBegin;
406 for (PetscInt c = r.c_start; c < r.c_end; c++) {
407 PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
408 PetscScalar x_a = x[i - r.i_start];
409 PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
410
411 _f += RosenbrockObjective(r.alpha, x_a, x_b);
412 }
413 *f = _f;
414 PetscCall(PetscLogFlops((RosenbrockObjectiveFlops + 1.0) * (r.c_end - r.c_start)));
415 PetscFunctionReturn(PETSC_SUCCESS);
416 }
417
RosenbrockGradient_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])418 static PetscErrorCode RosenbrockGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
419 {
420 PetscFunctionBegin;
421 for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
422 PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
423 PetscScalar x_a = x[i - r.i_start];
424 PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
425
426 RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]);
427 }
428 PetscCall(PetscLogFlops(RosenbrockGradientFlops * (r.c_end - r.c_start)));
429 PetscFunctionReturn(PETSC_SUCCESS);
430 }
431
RosenbrockObjectiveGradient_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscReal * f,PetscScalar g[])432 static PetscErrorCode RosenbrockObjectiveGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f, PetscScalar g[])
433 {
434 PetscReal _f = 0.0;
435
436 PetscFunctionBegin;
437 for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
438 PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
439 PetscScalar x_a = x[i - r.i_start];
440 PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
441
442 _f += RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]);
443 }
444 *f = _f;
445 PetscCall(PetscLogFlops(RosenbrockObjectiveGradientFlops * (r.c_end - r.c_start)));
446 PetscFunctionReturn(PETSC_SUCCESS);
447 }
448
RosenbrockHessian_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])449 static PetscErrorCode RosenbrockHessian_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
450 {
451 PetscFunctionBegin;
452 for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
453 PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
454 PetscScalar x_a = x[i - r.i_start];
455 PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
456
457 RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]);
458 }
459 PetscCall(PetscLogFlops(RosenbrockHessianFlops * (r.c_end - r.c_start)));
460 PetscFunctionReturn(PETSC_SUCCESS);
461 }
462
463 /* -------------------------------------------------------------------- */
464
FormObjective(Tao tao,Vec X,PetscReal * f,void * ptr)465 static PetscErrorCode FormObjective(Tao tao, Vec X, PetscReal *f, void *ptr)
466 {
467 AppCtx user = (AppCtx)ptr;
468 PetscReal f_local = 0.0;
469 const PetscScalar *x;
470 const PetscScalar *o = NULL;
471 PetscMemType memtype_x;
472
473 PetscFunctionBeginUser;
474 PetscCall(PetscLogEventBegin(user->event_f, tao, NULL, NULL, NULL));
475 PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
476 PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
477 PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
478 PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
479 if (memtype_x == PETSC_MEMTYPE_HOST) {
480 PetscCall(RosenbrockObjective_Host(user->problem, x, o, &f_local));
481 PetscCallMPI(MPIU_Allreduce(&f_local, f, 1, MPI_DOUBLE, MPI_SUM, user->comm));
482 #if PetscDefined(USING_CUPMCC)
483 } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
484 PetscScalar *_fvec;
485 PetscScalar f_scalar;
486 cupmStream_t *stream;
487 PetscDeviceContext dctx;
488
489 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
490 PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
491 PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
492 PetscCall(RosenbrockObjective_Device(*stream, user->problem, x, o, _fvec));
493 PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
494 PetscCall(VecSum(user->fvector, &f_scalar));
495 *f = PetscRealPart(f_scalar);
496 #endif
497 } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
498 PetscCall(VecRestoreArrayReadAndMemType(X, &x));
499 PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
500 PetscCall(PetscLogEventEnd(user->event_f, tao, NULL, NULL, NULL));
501 PetscFunctionReturn(PETSC_SUCCESS);
502 }
503
FormGradient(Tao tao,Vec X,Vec G,void * ptr)504 static PetscErrorCode FormGradient(Tao tao, Vec X, Vec G, void *ptr)
505 {
506 AppCtx user = (AppCtx)ptr;
507 PetscScalar *g;
508 const PetscScalar *x;
509 const PetscScalar *o = NULL;
510 PetscMemType memtype_x, memtype_g;
511
512 PetscFunctionBeginUser;
513 PetscCall(PetscLogEventBegin(user->event_g, tao, NULL, NULL, NULL));
514 PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
515 PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
516 PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
517 PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
518 PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
519 PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
520 if (memtype_x == PETSC_MEMTYPE_HOST) {
521 PetscCall(RosenbrockGradient_Host(user->problem, x, o, g));
522 #if PetscDefined(USING_CUPMCC)
523 } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
524 cupmStream_t *stream;
525 PetscDeviceContext dctx;
526
527 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
528 PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
529 PetscCall(RosenbrockGradient_Device(*stream, user->problem, x, o, g));
530 #endif
531 } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
532 PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
533 PetscCall(VecRestoreArrayReadAndMemType(X, &x));
534 PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
535 PetscCall(VecZeroEntries(G));
536 PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
537 PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
538 PetscCall(PetscLogEventEnd(user->event_g, tao, NULL, NULL, NULL));
539 PetscFunctionReturn(PETSC_SUCCESS);
540 }
541
542 /*
543 FormObjectiveGradient - Evaluates the function, f(X), and gradient, G(X).
544
545 Input Parameters:
546 . tao - the Tao context
547 . X - input vector
548 . ptr - optional user-defined context, as set by TaoSetObjectiveGradient()
549
550 Output Parameters:
551 . G - vector containing the newly evaluated gradient
552 . f - function value
553
554 Note:
555 Some optimization methods ask for the function and the gradient evaluation
  at the same time. Evaluating both at once may be more efficient than
557 evaluating each separately.
558 */
FormObjectiveGradient(Tao tao,Vec X,PetscReal * f,Vec G,void * ptr)559 static PetscErrorCode FormObjectiveGradient(Tao tao, Vec X, PetscReal *f, Vec G, void *ptr)
560 {
561 AppCtx user = (AppCtx)ptr;
562 PetscReal f_local = 0.0;
563 PetscScalar *g;
564 const PetscScalar *x;
565 const PetscScalar *o = NULL;
566 PetscMemType memtype_x, memtype_g;
567
568 PetscFunctionBeginUser;
569 PetscCall(PetscLogEventBegin(user->event_fg, tao, NULL, NULL, NULL));
570 PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
571 PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
572 PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
573 PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
574 PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
575 PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
576 if (memtype_x == PETSC_MEMTYPE_HOST) {
577 PetscCall(RosenbrockObjectiveGradient_Host(user->problem, x, o, &f_local, g));
578 PetscCallMPI(MPIU_Allreduce((void *)&f_local, (void *)f, 1, MPI_DOUBLE, MPI_SUM, PETSC_COMM_WORLD));
579 #if PetscDefined(USING_CUPMCC)
580 } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
581 PetscScalar *_fvec;
582 PetscScalar f_scalar;
583 cupmStream_t *stream;
584 PetscDeviceContext dctx;
585
586 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
587 PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
588 PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
589 PetscCall(RosenbrockObjectiveGradient_Device(*stream, user->problem, x, o, _fvec, g));
590 PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
591 PetscCall(VecSum(user->fvector, &f_scalar));
592 *f = PetscRealPart(f_scalar);
593 #endif
594 } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
595
596 PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
597 PetscCall(VecRestoreArrayReadAndMemType(X, &x));
598 PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
599 PetscCall(VecZeroEntries(G));
600 PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
601 PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
602 PetscCall(PetscLogEventEnd(user->event_fg, tao, NULL, NULL, NULL));
603 PetscFunctionReturn(PETSC_SUCCESS);
604 }
605
606 /* ------------------------------------------------------------------- */
607 /*
608 FormHessian - Evaluates Hessian matrix.
609
610 Input Parameters:
611 . tao - the Tao context
612 . x - input vector
613 . ptr - optional user-defined context, as set by TaoSetHessian()
614
615 Output Parameters:
616 . H - Hessian matrix
617
618 Note: Providing the Hessian may not be necessary. Only some solvers
619 require this matrix.
620 */
FormHessian(Tao tao,Vec X,Mat H,Mat Hpre,void * ptr)621 static PetscErrorCode FormHessian(Tao tao, Vec X, Mat H, Mat Hpre, void *ptr)
622 {
623 AppCtx user = (AppCtx)ptr;
624 PetscScalar *h;
625 const PetscScalar *x;
626 const PetscScalar *o = NULL;
627 PetscMemType memtype_x, memtype_h;
628
629 PetscFunctionBeginUser;
630 PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
631 PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
632 PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
633 PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
634 PetscCall(VecGetArrayWriteAndMemType(user->Hvalues, &h, &memtype_h));
635 PetscAssert(memtype_x == memtype_h, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and hessian must have save memtype");
636 if (memtype_x == PETSC_MEMTYPE_HOST) {
637 PetscCall(RosenbrockHessian_Host(user->problem, x, o, h));
638 #if PetscDefined(USING_CUPMCC)
639 } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
640 cupmStream_t *stream;
641 PetscDeviceContext dctx;
642
643 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
644 PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
645 PetscCall(RosenbrockHessian_Device(*stream, user->problem, x, o, h));
646 #endif
647 } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
648
649 PetscCall(MatSetValuesCOO(H, h, INSERT_VALUES));
650 PetscCall(VecRestoreArrayWriteAndMemType(user->Hvalues, &h));
651
652 PetscCall(VecRestoreArrayReadAndMemType(X, &x));
653 PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
654
655 if (Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN));
656 PetscFunctionReturn(PETSC_SUCCESS);
657 }
658
/* If the Tao KSP preconditioner is PCLMVM, verify that MatSolve acts as the
   inverse of MatMult on the LMVM approximation: out2 = M^{-1}(M in) should
   reproduce in, so the norm of (out2 - in) is printed in coarse buckets. */
static PetscErrorCode TestLMVM(Tao tao)
{
  KSP       ksp;
  PC        pc;
  PetscBool is_lmvm;

  PetscFunctionBegin;
  PetscCall(TaoGetKSP(tao, &ksp));
  if (!ksp) PetscFunctionReturn(PETSC_SUCCESS); // nothing to test when the Tao type has no KSP
  PetscCall(KSPGetPC(ksp, &pc));
  PetscCall(PetscObjectTypeCompare((PetscObject)pc, PCLMVM, &is_lmvm));
  if (is_lmvm) {
    Mat       M;
    Vec       in, out, out2;
    PetscReal mult_solve_dist;
    Vec       x;

    PetscCall(PCLMVMGetMatLMVM(pc, &M));
    PetscCall(TaoGetSolution(tao, &x));
    PetscCall(VecDuplicate(x, &in));
    PetscCall(VecDuplicate(x, &out));
    PetscCall(VecDuplicate(x, &out2));
    PetscCall(VecSetRandom(in, NULL));
    PetscCall(MatMult(M, in, out));
    PetscCall(MatSolve(M, out, out2));

    // out2 - in should be ~0 up to rounding if MatSolve inverts MatMult
    PetscCall(VecAXPY(out2, -1.0, in));
    PetscCall(VecNorm(out2, NORM_2, &mult_solve_dist));
    if (mult_solve_dist < 1.e-11) {
      PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-11\n"));
    } else if (mult_solve_dist < 1.e-6) {
      PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-6\n"));
    } else {
      PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve is not small: %e\n", (double)mult_solve_dist));
    }
    PetscCall(VecDestroy(&in));
    PetscCall(VecDestroy(&out));
    PetscCall(VecDestroy(&out2));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
700
/* Driver: build the problem (context, Hessian, vectors), configure Tao with
   the objective/gradient/Hessian callbacks, run the solve inside its own log
   stage, optionally test the LMVM preconditioner, and clean everything up. */
static PetscErrorCode RosenbrockMain(void)
{
  Vec           x;    /* solution vector */
  Vec           g;    /* gradient vector */
  Mat           H;    /* Hessian matrix */
  Tao           tao;  /* Tao solver context */
  AppCtx        user; /* user-defined application context */
  PetscLogStage solve;

  /* Initialize TAO and PETSc */
  PetscFunctionBegin;
  PetscCall(PetscLogStageRegister("Rosenbrock solve", &solve));

  PetscCall(AppCtxCreate(PETSC_COMM_WORLD, &user));
  PetscCall(CreateHessian(user, &H));
  PetscCall(CreateVectors(user, H, &x, &g));

  /* The TAO code begins here */

  PetscCall(TaoCreate(user->comm, &tao));
  PetscCall(VecZeroEntries(x)); /* initial guess: the zero vector */
  PetscCall(TaoSetSolution(tao, x));

  /* Set routines for function, gradient, hessian evaluation */
  PetscCall(TaoSetObjective(tao, FormObjective, user));
  PetscCall(TaoSetObjectiveAndGradient(tao, g, FormObjectiveGradient, user));
  PetscCall(TaoSetGradient(tao, g, FormGradient, user));
  PetscCall(TaoSetHessian(tao, H, H, FormHessian, user));

  PetscCall(TaoSetFromOptions(tao));

  /* SOLVE THE APPLICATION */
  PetscCall(PetscLogStagePush(solve)); /* isolate the solve in its own profiling stage */
  PetscCall(TaoSolve(tao));
  PetscCall(PetscLogStagePop());

  if (user->test_lmvm) PetscCall(TestLMVM(tao));

  PetscCall(TaoDestroy(&tao));
  PetscCall(VecDestroy(&g));
  PetscCall(VecDestroy(&x));
  PetscCall(MatDestroy(&H));
  PetscCall(AppCtxDestroy(&user));
  PetscFunctionReturn(PETSC_SUCCESS);
}
746