xref: /petsc/src/tao/unconstrained/tutorials/rosenbrock4.h (revision d8e47b638cf8f604a99e9678e1df24f82d959cd7)
1ec9796c4SHansol Suh #pragma once
2ec9796c4SHansol Suh 
3ec9796c4SHansol Suh #include <petsctao.h>
4ec9796c4SHansol Suh #include <petscsf.h>
5ec9796c4SHansol Suh #include <petscdevice.h>
6ec9796c4SHansol Suh #include <petscdevice_cupm.h>
7ec9796c4SHansol Suh 
8ec9796c4SHansol Suh /*
   User-defined problem description and application context - contain the data
   needed by the application-provided callback routines that evaluate the
   objective function, gradient, and Hessian.
12ec9796c4SHansol Suh */
13ec9796c4SHansol Suh 
14ec9796c4SHansol Suh typedef struct _Rosenbrock {
15ec9796c4SHansol Suh   PetscInt  bs; // each block of bs variables is one chained multidimensional rosenbrock problem
16ec9796c4SHansol Suh   PetscInt  i_start, i_end;
17ec9796c4SHansol Suh   PetscInt  c_start, c_end;
18ec9796c4SHansol Suh   PetscReal alpha; // condition parameter
19ec9796c4SHansol Suh } Rosenbrock;
20ec9796c4SHansol Suh 
21ec9796c4SHansol Suh typedef struct _AppCtx *AppCtx;
22ec9796c4SHansol Suh struct _AppCtx {
23ec9796c4SHansol Suh   MPI_Comm      comm;
24ec9796c4SHansol Suh   PetscInt      n; /* dimension */
25ec9796c4SHansol Suh   PetscInt      n_local;
26ec9796c4SHansol Suh   PetscInt      n_local_comp;
27ec9796c4SHansol Suh   Rosenbrock    problem;
28ec9796c4SHansol Suh   Vec           Hvalues; /* vector for writing COO values of this MPI process */
29ec9796c4SHansol Suh   Vec           gvalues; /* vector for writing gradient values of this mpi process */
30ec9796c4SHansol Suh   Vec           fvector;
31ec9796c4SHansol Suh   PetscSF       off_process_scatter;
32ec9796c4SHansol Suh   PetscSF       gscatter;
33ec9796c4SHansol Suh   Vec           off_process_values; /* buffer for off-process values if chained */
34ec9796c4SHansol Suh   PetscBool     test_lmvm;
35ec9796c4SHansol Suh   PetscLogEvent event_f, event_g, event_fg;
36ec9796c4SHansol Suh };
37ec9796c4SHansol Suh 
38ec9796c4SHansol Suh /* -------------- User-defined routines ---------- */
39ec9796c4SHansol Suh 
RosenbrockObjective(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2)40ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjective(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2)
41ec9796c4SHansol Suh {
42ec9796c4SHansol Suh   PetscScalar d = x_2 - x_1 * x_1;
43ec9796c4SHansol Suh   PetscScalar e = 1.0 - x_1;
44ec9796c4SHansol Suh   return alpha * d * d + e * e;
45ec9796c4SHansol Suh }
46ec9796c4SHansol Suh 
47ec9796c4SHansol Suh static const PetscLogDouble RosenbrockObjectiveFlops = 7.0;
48ec9796c4SHansol Suh 
RosenbrockGradient(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar g[2])49ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
50ec9796c4SHansol Suh {
51ec9796c4SHansol Suh   PetscScalar d  = x_2 - x_1 * x_1;
52ec9796c4SHansol Suh   PetscScalar e  = 1.0 - x_1;
53ec9796c4SHansol Suh   PetscScalar g2 = alpha * d * 2.0;
54ec9796c4SHansol Suh 
55ec9796c4SHansol Suh   g[0] = -2.0 * x_1 * g2 - 2.0 * e;
56ec9796c4SHansol Suh   g[1] = g2;
57ec9796c4SHansol Suh }
58ec9796c4SHansol Suh 
59ec9796c4SHansol Suh static const PetscInt RosenbrockGradientFlops = 9.0;
60ec9796c4SHansol Suh 
RosenbrockObjectiveGradient(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar g[2])61ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjectiveGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
62ec9796c4SHansol Suh {
63ec9796c4SHansol Suh   PetscScalar d  = x_2 - x_1 * x_1;
64ec9796c4SHansol Suh   PetscScalar e  = 1.0 - x_1;
65ec9796c4SHansol Suh   PetscScalar ad = alpha * d;
66ec9796c4SHansol Suh   PetscScalar g2 = ad * 2.0;
67ec9796c4SHansol Suh 
68ec9796c4SHansol Suh   g[0] = -2.0 * x_1 * g2 - 2.0 * e;
69ec9796c4SHansol Suh   g[1] = g2;
70ec9796c4SHansol Suh   return ad * d + e * e;
71ec9796c4SHansol Suh }
72ec9796c4SHansol Suh 
73ec9796c4SHansol Suh static const PetscLogDouble RosenbrockObjectiveGradientFlops = 12.0;
74ec9796c4SHansol Suh 
RosenbrockHessian(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar h[4])75ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockHessian(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar h[4])
76ec9796c4SHansol Suh {
77ec9796c4SHansol Suh   PetscScalar d  = x_2 - x_1 * x_1;
78ec9796c4SHansol Suh   PetscScalar g2 = alpha * d * 2.0;
79ec9796c4SHansol Suh   PetscScalar h2 = -4.0 * alpha * x_1;
80ec9796c4SHansol Suh 
81ec9796c4SHansol Suh   h[0] = -2.0 * (g2 + x_1 * h2) + 2.0;
82ec9796c4SHansol Suh   h[1] = h[2] = h2;
83ec9796c4SHansol Suh   h[3]        = 2.0 * alpha;
84ec9796c4SHansol Suh }
85ec9796c4SHansol Suh 
86ec9796c4SHansol Suh static const PetscLogDouble RosenbrockHessianFlops = 11.0;
87ec9796c4SHansol Suh 
AppCtxCreate(MPI_Comm comm,AppCtx * ctx)88ec9796c4SHansol Suh static PetscErrorCode AppCtxCreate(MPI_Comm comm, AppCtx *ctx)
89ec9796c4SHansol Suh {
90ec9796c4SHansol Suh   AppCtx             user;
91ec9796c4SHansol Suh   PetscDeviceContext dctx;
92ec9796c4SHansol Suh 
93ec9796c4SHansol Suh   PetscFunctionBegin;
94ec9796c4SHansol Suh   PetscCall(PetscNew(ctx));
95ec9796c4SHansol Suh   user       = *ctx;
96ec9796c4SHansol Suh   user->comm = PETSC_COMM_WORLD;
97ec9796c4SHansol Suh 
98ec9796c4SHansol Suh   /* Initialize problem parameters */
99ec9796c4SHansol Suh   user->n             = 2;
100ec9796c4SHansol Suh   user->problem.alpha = 99.0;
101ec9796c4SHansol Suh   user->problem.bs    = 2; // bs = 2 is block Rosenbrock, bs = n is chained Rosenbrock
102ec9796c4SHansol Suh   user->test_lmvm     = PETSC_FALSE;
103ec9796c4SHansol Suh   /* Check for command line arguments to override defaults */
104ec9796c4SHansol Suh   PetscOptionsBegin(user->comm, NULL, "Rosenbrock example", NULL);
105ec9796c4SHansol Suh   PetscCall(PetscOptionsInt("-n", "Rosenbrock problem size", NULL, user->n, &user->n, NULL));
106ec9796c4SHansol Suh   PetscCall(PetscOptionsInt("-bs", "Rosenbrock block size (2 <= bs <= n)", NULL, user->problem.bs, &user->problem.bs, NULL));
107ec9796c4SHansol Suh   PetscCall(PetscOptionsReal("-alpha", "Rosenbrock off-diagonal coefficient", NULL, user->problem.alpha, &user->problem.alpha, NULL));
108d8b4a066SPierre Jolivet   PetscCall(PetscOptionsBool("-test_lmvm", "Test LMVM solve against LMVM mult", NULL, user->test_lmvm, &user->test_lmvm, NULL));
109ec9796c4SHansol Suh   PetscOptionsEnd();
110ec9796c4SHansol Suh   PetscCheck(user->problem.bs >= 1, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " is not bigger than 1", user->problem.bs);
111d8b4a066SPierre Jolivet   PetscCheck((user->n % user->problem.bs) == 0, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " does not divide problem size % " PetscInt_FMT, user->problem.bs, user->n);
112ec9796c4SHansol Suh   PetscCall(PetscLogEventRegister("Rbock_Obj", TAO_CLASSID, &user->event_f));
113ec9796c4SHansol Suh   PetscCall(PetscLogEventRegister("Rbock_Grad", TAO_CLASSID, &user->event_g));
114ec9796c4SHansol Suh   PetscCall(PetscLogEventRegister("Rbock_ObjGrad", TAO_CLASSID, &user->event_fg));
115ec9796c4SHansol Suh   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
116ec9796c4SHansol Suh   PetscCall(PetscDeviceContextSetUp(dctx));
117ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
118ec9796c4SHansol Suh }
119ec9796c4SHansol Suh 
AppCtxDestroy(AppCtx * ctx)120ec9796c4SHansol Suh static PetscErrorCode AppCtxDestroy(AppCtx *ctx)
121ec9796c4SHansol Suh {
122ec9796c4SHansol Suh   AppCtx user;
123ec9796c4SHansol Suh 
124ec9796c4SHansol Suh   PetscFunctionBegin;
125ec9796c4SHansol Suh   user = *ctx;
126ec9796c4SHansol Suh   *ctx = NULL;
127ec9796c4SHansol Suh   PetscCall(VecDestroy(&user->Hvalues));
128ec9796c4SHansol Suh   PetscCall(VecDestroy(&user->gvalues));
129ec9796c4SHansol Suh   PetscCall(VecDestroy(&user->fvector));
130ec9796c4SHansol Suh   PetscCall(VecDestroy(&user->off_process_values));
131ec9796c4SHansol Suh   PetscCall(PetscSFDestroy(&user->off_process_scatter));
132ec9796c4SHansol Suh   PetscCall(PetscSFDestroy(&user->gscatter));
133ec9796c4SHansol Suh   PetscCall(PetscFree(user));
134ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
135ec9796c4SHansol Suh }
136ec9796c4SHansol Suh 
CreateHessian(AppCtx user,Mat * Hessian)137ec9796c4SHansol Suh static PetscErrorCode CreateHessian(AppCtx user, Mat *Hessian)
138ec9796c4SHansol Suh {
139ec9796c4SHansol Suh   Mat         H;
140ec9796c4SHansol Suh   PetscLayout layout;
141ec9796c4SHansol Suh   PetscInt    i_start, i_end, n_local_comp, nnz_local;
142ec9796c4SHansol Suh   PetscInt    c_start, c_end;
143ec9796c4SHansol Suh   PetscInt   *coo_i;
144ec9796c4SHansol Suh   PetscInt   *coo_j;
145ec9796c4SHansol Suh   PetscInt    bs = user->problem.bs;
146ec9796c4SHansol Suh   VecType     vec_type;
147ec9796c4SHansol Suh 
148ec9796c4SHansol Suh   PetscFunctionBegin;
149ec9796c4SHansol Suh   /* Partition the optimization variables and the computations.
150ec9796c4SHansol Suh      There are (bs - 1) contributions to the objective function for every (bs)
151ec9796c4SHansol Suh      degrees of freedom. */
152ec9796c4SHansol Suh   PetscCall(PetscLayoutCreateFromSizes(user->comm, PETSC_DECIDE, user->n, 1, &layout));
153ec9796c4SHansol Suh   PetscCall(PetscLayoutSetUp(layout));
154ec9796c4SHansol Suh   PetscCall(PetscLayoutGetRange(layout, &i_start, &i_end));
155ec9796c4SHansol Suh   user->problem.i_start = i_start;
156ec9796c4SHansol Suh   user->problem.i_end   = i_end;
157ec9796c4SHansol Suh   user->n_local         = i_end - i_start;
158ec9796c4SHansol Suh   user->problem.c_start = c_start = (i_start / bs) * (bs - 1) + (i_start % bs);
159ec9796c4SHansol Suh   user->problem.c_end = c_end = (i_end / bs) * (bs - 1) + (i_end % bs);
160ec9796c4SHansol Suh   user->n_local_comp = n_local_comp = c_end - c_start;
161ec9796c4SHansol Suh 
162ec9796c4SHansol Suh   PetscCall(MatCreate(user->comm, Hessian));
163ec9796c4SHansol Suh   H = *Hessian;
164ec9796c4SHansol Suh   PetscCall(MatSetLayouts(H, layout, layout));
165ec9796c4SHansol Suh   PetscCall(PetscLayoutDestroy(&layout));
166ec9796c4SHansol Suh   PetscCall(MatSetType(H, MATAIJ));
167ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_HERMITIAN, PETSC_TRUE));
168ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_SYMMETRIC, PETSC_TRUE));
169ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_SYMMETRY_ETERNAL, PETSC_TRUE));
170ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_STRUCTURALLY_SYMMETRIC, PETSC_TRUE));
171ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_STRUCTURAL_SYMMETRY_ETERNAL, PETSC_TRUE));
172ec9796c4SHansol Suh   PetscCall(MatSetFromOptions(H)); /* set from options so that we can change the underlying matrix type */
173ec9796c4SHansol Suh 
174ec9796c4SHansol Suh   nnz_local = n_local_comp * 4;
175ec9796c4SHansol Suh   PetscCall(PetscMalloc2(nnz_local, &coo_i, nnz_local, &coo_j));
176ec9796c4SHansol Suh   /* Instead of having one computation thread per row of the matrix,
177ec9796c4SHansol Suh      this example uses one thread per contribution to the objective
178ec9796c4SHansol Suh      function.  Each contribution to the objective function relates
179ec9796c4SHansol Suh      two adjacent degrees of freedom, so each contribution to
180ec9796c4SHansol Suh      the objective function adds a 2x2 block into the matrix.
181ec9796c4SHansol Suh      We describe these 2x2 blocks in COO format. */
182ec9796c4SHansol Suh   for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 4) {
183ec9796c4SHansol Suh     PetscInt i = (c / (bs - 1)) * bs + c % (bs - 1);
184ec9796c4SHansol Suh 
185ec9796c4SHansol Suh     coo_i[k + 0] = i;
186ec9796c4SHansol Suh     coo_i[k + 1] = i;
187ec9796c4SHansol Suh     coo_i[k + 2] = i + 1;
188ec9796c4SHansol Suh     coo_i[k + 3] = i + 1;
189ec9796c4SHansol Suh 
190ec9796c4SHansol Suh     coo_j[k + 0] = i;
191ec9796c4SHansol Suh     coo_j[k + 1] = i + 1;
192ec9796c4SHansol Suh     coo_j[k + 2] = i;
193ec9796c4SHansol Suh     coo_j[k + 3] = i + 1;
194ec9796c4SHansol Suh   }
195ec9796c4SHansol Suh   PetscCall(MatSetPreallocationCOO(H, nnz_local, coo_i, coo_j));
196ec9796c4SHansol Suh   PetscCall(PetscFree2(coo_i, coo_j));
197ec9796c4SHansol Suh 
198ec9796c4SHansol Suh   PetscCall(MatGetVecType(H, &vec_type));
199ec9796c4SHansol Suh   PetscCall(VecCreate(user->comm, &user->Hvalues));
200ec9796c4SHansol Suh   PetscCall(VecSetSizes(user->Hvalues, nnz_local, PETSC_DETERMINE));
201ec9796c4SHansol Suh   PetscCall(VecSetType(user->Hvalues, vec_type));
202ec9796c4SHansol Suh 
203ec9796c4SHansol Suh   // vector to collect contributions to the objective
204ec9796c4SHansol Suh   PetscCall(VecCreate(user->comm, &user->fvector));
205ec9796c4SHansol Suh   PetscCall(VecSetSizes(user->fvector, user->n_local_comp, PETSC_DETERMINE));
206ec9796c4SHansol Suh   PetscCall(VecSetType(user->fvector, vec_type));
207ec9796c4SHansol Suh 
208ec9796c4SHansol Suh   { /* If we are using a device (such as a GPU), run some computations that will
209ec9796c4SHansol Suh        warm up its linear algebra runtime before the problem we actually want
210ec9796c4SHansol Suh        to profile */
211ec9796c4SHansol Suh 
212ec9796c4SHansol Suh     PetscMemType       memtype;
213ec9796c4SHansol Suh     const PetscScalar *a;
214ec9796c4SHansol Suh 
215ec9796c4SHansol Suh     PetscCall(VecGetArrayReadAndMemType(user->fvector, &a, &memtype));
216ec9796c4SHansol Suh     PetscCall(VecRestoreArrayReadAndMemType(user->fvector, &a));
217ec9796c4SHansol Suh 
218ec9796c4SHansol Suh     if (memtype == PETSC_MEMTYPE_DEVICE) {
219ec9796c4SHansol Suh       PetscLogStage      warmup;
220ec9796c4SHansol Suh       Mat                A, AtA;
221ec9796c4SHansol Suh       Vec                x, b;
222ec9796c4SHansol Suh       PetscInt           warmup_size = 1000;
223ec9796c4SHansol Suh       PetscDeviceContext dctx;
224ec9796c4SHansol Suh 
225ec9796c4SHansol Suh       PetscCall(PetscLogStageRegister("Device Warmup", &warmup));
226ec9796c4SHansol Suh       PetscCall(PetscLogStageSetActive(warmup, PETSC_FALSE));
227ec9796c4SHansol Suh 
228ec9796c4SHansol Suh       PetscCall(PetscLogStagePush(warmup));
229ec9796c4SHansol Suh       PetscCall(MatCreateDenseFromVecType(PETSC_COMM_SELF, vec_type, warmup_size, warmup_size, warmup_size, warmup_size, PETSC_DEFAULT, NULL, &A));
230ec9796c4SHansol Suh       PetscCall(MatSetRandom(A, NULL));
231ec9796c4SHansol Suh       PetscCall(MatCreateVecs(A, &x, &b));
232ec9796c4SHansol Suh       PetscCall(VecSetRandom(x, NULL));
233ec9796c4SHansol Suh 
234ec9796c4SHansol Suh       PetscCall(MatMult(A, x, b));
235fb842aefSJose E. Roman       PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &AtA));
236ec9796c4SHansol Suh       PetscCall(MatShift(AtA, (PetscScalar)warmup_size));
237ec9796c4SHansol Suh       PetscCall(MatSetOption(AtA, MAT_SPD, PETSC_TRUE));
238ec9796c4SHansol Suh       PetscCall(MatCholeskyFactor(AtA, NULL, NULL));
239ec9796c4SHansol Suh       PetscCall(MatDestroy(&AtA));
240ec9796c4SHansol Suh       PetscCall(VecDestroy(&b));
241ec9796c4SHansol Suh       PetscCall(VecDestroy(&x));
242ec9796c4SHansol Suh       PetscCall(MatDestroy(&A));
243ec9796c4SHansol Suh       PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
244ec9796c4SHansol Suh       PetscCall(PetscDeviceContextSynchronize(dctx));
245ec9796c4SHansol Suh       PetscCall(PetscLogStagePop());
246ec9796c4SHansol Suh     }
247ec9796c4SHansol Suh   }
248ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
249ec9796c4SHansol Suh }
250ec9796c4SHansol Suh 
CreateVectors(AppCtx user,Mat H,Vec * solution,Vec * gradient)251ec9796c4SHansol Suh static PetscErrorCode CreateVectors(AppCtx user, Mat H, Vec *solution, Vec *gradient)
252ec9796c4SHansol Suh {
253ec9796c4SHansol Suh   VecType     vec_type;
254ec9796c4SHansol Suh   PetscInt    n_coo, *coo_i, i_start, i_end;
255ec9796c4SHansol Suh   Vec         x;
256ec9796c4SHansol Suh   PetscInt    n_recv;
257ec9796c4SHansol Suh   PetscSFNode recv;
258ec9796c4SHansol Suh   PetscLayout layout;
259ec9796c4SHansol Suh   PetscInt    c_start = user->problem.c_start, c_end = user->problem.c_end, bs = user->problem.bs;
260ec9796c4SHansol Suh 
261ec9796c4SHansol Suh   PetscFunctionBegin;
262ec9796c4SHansol Suh   PetscCall(MatCreateVecs(H, solution, gradient));
263ec9796c4SHansol Suh   x = *solution;
264ec9796c4SHansol Suh   PetscCall(VecGetOwnershipRange(x, &i_start, &i_end));
265ec9796c4SHansol Suh   PetscCall(VecGetType(x, &vec_type));
266ec9796c4SHansol Suh   // create scatter for communicating values
267ec9796c4SHansol Suh   PetscCall(VecGetLayout(x, &layout));
268ec9796c4SHansol Suh   n_recv = 0;
269ec9796c4SHansol Suh   if (user->n_local_comp && i_end < user->n) {
270ec9796c4SHansol Suh     PetscMPIInt rank;
271ec9796c4SHansol Suh     PetscInt    index;
272ec9796c4SHansol Suh 
273ec9796c4SHansol Suh     n_recv = 1;
274ec9796c4SHansol Suh     PetscCall(PetscLayoutFindOwnerIndex(layout, i_end, &rank, &index));
275ec9796c4SHansol Suh     recv.rank  = rank;
276ec9796c4SHansol Suh     recv.index = index;
277ec9796c4SHansol Suh   }
278ec9796c4SHansol Suh   PetscCall(PetscSFCreate(user->comm, &user->off_process_scatter));
279ec9796c4SHansol Suh   PetscCall(PetscSFSetGraph(user->off_process_scatter, user->n_local, n_recv, NULL, PETSC_USE_POINTER, &recv, PETSC_COPY_VALUES));
280ec9796c4SHansol Suh   PetscCall(VecCreate(user->comm, &user->off_process_values));
281ec9796c4SHansol Suh   PetscCall(VecSetSizes(user->off_process_values, 1, PETSC_DETERMINE));
282ec9796c4SHansol Suh   PetscCall(VecSetType(user->off_process_values, vec_type));
283ec9796c4SHansol Suh   PetscCall(VecZeroEntries(user->off_process_values));
284ec9796c4SHansol Suh 
285ec9796c4SHansol Suh   // create COO data for writing the gradient
286ec9796c4SHansol Suh   n_coo = user->n_local_comp * 2;
287ec9796c4SHansol Suh   PetscCall(PetscMalloc1(n_coo, &coo_i));
288ec9796c4SHansol Suh   for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 2) {
289ec9796c4SHansol Suh     PetscInt i = (c / (bs - 1)) * bs + (c % (bs - 1));
290ec9796c4SHansol Suh 
291ec9796c4SHansol Suh     coo_i[k + 0] = i;
292ec9796c4SHansol Suh     coo_i[k + 1] = i + 1;
293ec9796c4SHansol Suh   }
294ec9796c4SHansol Suh   PetscCall(PetscSFCreate(user->comm, &user->gscatter));
295ec9796c4SHansol Suh   PetscCall(PetscSFSetGraphLayout(user->gscatter, layout, n_coo, NULL, PETSC_USE_POINTER, coo_i));
296ec9796c4SHansol Suh   PetscCall(PetscSFSetUp(user->gscatter));
297ec9796c4SHansol Suh   PetscCall(PetscFree(coo_i));
298ec9796c4SHansol Suh   PetscCall(VecCreate(user->comm, &user->gvalues));
299ec9796c4SHansol Suh   PetscCall(VecSetSizes(user->gvalues, n_coo, PETSC_DETERMINE));
300ec9796c4SHansol Suh   PetscCall(VecSetType(user->gvalues, vec_type));
301ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
302ec9796c4SHansol Suh }
303ec9796c4SHansol Suh 
304ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
305ec9796c4SHansol Suh 
306ec9796c4SHansol Suh   #if PetscDefined(USING_NVCC)
307ec9796c4SHansol Suh typedef cudaStream_t cupmStream_t;
308ec9796c4SHansol Suh     #define PetscCUPMLaunch(...) \
309ec9796c4SHansol Suh       do { \
310ec9796c4SHansol Suh         __VA_ARGS__; \
311ec9796c4SHansol Suh         PetscCallCUDA(cudaGetLastError()); \
312ec9796c4SHansol Suh       } while (0)
313ec9796c4SHansol Suh   #elif PetscDefined(USING_HCC)
314ec9796c4SHansol Suh     #define PetscCUPMLaunch(...) \
315ec9796c4SHansol Suh       do { \
316ec9796c4SHansol Suh         __VA_ARGS__; \
317ec9796c4SHansol Suh         PetscCallHIP(hipGetLastError()); \
318ec9796c4SHansol Suh       } while (0)
319ec9796c4SHansol Suh typedef hipStream_t cupmStream_t;
320ec9796c4SHansol Suh   #endif
321ec9796c4SHansol Suh 
322ec9796c4SHansol Suh // x: on-process optimization variables
323ec9796c4SHansol Suh // o: buffer that contains the next optimization variable after the variables on this process
324ec9796c4SHansol Suh template <typename T>
rosenbrock_for_loop(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],T && func)325ec9796c4SHansol Suh PETSC_DEVICE_INLINE_DECL static void rosenbrock_for_loop(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], T &&func) noexcept
326ec9796c4SHansol Suh {
327ec9796c4SHansol Suh   PetscInt idx         = blockIdx.x * blockDim.x + threadIdx.x; // 1D grid
328ec9796c4SHansol Suh   PetscInt num_threads = gridDim.x * blockDim.x;
329ec9796c4SHansol Suh 
330ec9796c4SHansol Suh   for (PetscInt c = r.c_start + idx, k = idx; c < r.c_end; c += num_threads, k += num_threads) {
331ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
332ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
333ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
334ec9796c4SHansol Suh 
335ec9796c4SHansol Suh     func(k, x_a, x_b);
336ec9796c4SHansol Suh   }
337ec9796c4SHansol Suh   return;
338ec9796c4SHansol Suh }
339ec9796c4SHansol Suh 
RosenbrockObjective_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[])340ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockObjective_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
341ec9796c4SHansol Suh {
342ec9796c4SHansol Suh   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjective(r.alpha, x_a, x_b); });
343ec9796c4SHansol Suh }
344ec9796c4SHansol Suh 
RosenbrockGradient_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])345ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
346ec9796c4SHansol Suh {
347ec9796c4SHansol Suh   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); });
348ec9796c4SHansol Suh }
349ec9796c4SHansol Suh 
RosenbrockObjectiveGradient_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[],PetscScalar g[])350ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockObjectiveGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
351ec9796c4SHansol Suh {
352ec9796c4SHansol Suh   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); });
353ec9796c4SHansol Suh }
354ec9796c4SHansol Suh 
RosenbrockHessian_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])355ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockHessian_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
356ec9796c4SHansol Suh {
357ec9796c4SHansol Suh   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); });
358ec9796c4SHansol Suh }
359ec9796c4SHansol Suh 
RosenbrockObjective_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[])360ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjective_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
361ec9796c4SHansol Suh {
362ec9796c4SHansol Suh   PetscInt n_comp = r.c_end - r.c_start;
363ec9796c4SHansol Suh 
364ec9796c4SHansol Suh   PetscFunctionBegin;
365ec9796c4SHansol Suh   if (n_comp) PetscCUPMLaunch(RosenbrockObjective_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec));
366ec9796c4SHansol Suh   PetscCall(PetscLogGpuFlops(RosenbrockObjectiveFlops * n_comp));
367ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
368ec9796c4SHansol Suh }
369ec9796c4SHansol Suh 
RosenbrockGradient_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])370ec9796c4SHansol Suh static PetscErrorCode RosenbrockGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
371ec9796c4SHansol Suh {
372ec9796c4SHansol Suh   PetscInt n_comp = r.c_end - r.c_start;
373ec9796c4SHansol Suh 
374ec9796c4SHansol Suh   PetscFunctionBegin;
375ec9796c4SHansol Suh   if (n_comp) PetscCUPMLaunch(RosenbrockGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, g));
376ec9796c4SHansol Suh   PetscCall(PetscLogGpuFlops(RosenbrockGradientFlops * n_comp));
377ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
378ec9796c4SHansol Suh }
379ec9796c4SHansol Suh 
RosenbrockObjectiveGradient_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[],PetscScalar g[])380ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjectiveGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
381ec9796c4SHansol Suh {
382ec9796c4SHansol Suh   PetscInt n_comp = r.c_end - r.c_start;
383ec9796c4SHansol Suh 
384ec9796c4SHansol Suh   PetscFunctionBegin;
385ec9796c4SHansol Suh   if (n_comp) PetscCUPMLaunch(RosenbrockObjectiveGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec, g));
386ec9796c4SHansol Suh   PetscCall(PetscLogGpuFlops(RosenbrockObjectiveGradientFlops * n_comp));
387ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
388ec9796c4SHansol Suh }
389ec9796c4SHansol Suh 
RosenbrockHessian_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])390ec9796c4SHansol Suh static PetscErrorCode RosenbrockHessian_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
391ec9796c4SHansol Suh {
392ec9796c4SHansol Suh   PetscInt n_comp = r.c_end - r.c_start;
393ec9796c4SHansol Suh 
394ec9796c4SHansol Suh   PetscFunctionBegin;
395ec9796c4SHansol Suh   if (n_comp) PetscCUPMLaunch(RosenbrockHessian_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, h));
396ec9796c4SHansol Suh   PetscCall(PetscLogGpuFlops(RosenbrockHessianFlops * n_comp));
397ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
398ec9796c4SHansol Suh }
399ec9796c4SHansol Suh #endif
400ec9796c4SHansol Suh 
RosenbrockObjective_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscReal * f)401ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjective_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f)
402ec9796c4SHansol Suh {
403ec9796c4SHansol Suh   PetscReal _f = 0.0;
404ec9796c4SHansol Suh 
405ec9796c4SHansol Suh   PetscFunctionBegin;
406ec9796c4SHansol Suh   for (PetscInt c = r.c_start; c < r.c_end; c++) {
407ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
408ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
409ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
410ec9796c4SHansol Suh 
411ec9796c4SHansol Suh     _f += RosenbrockObjective(r.alpha, x_a, x_b);
412ec9796c4SHansol Suh   }
413ec9796c4SHansol Suh   *f = _f;
414ec9796c4SHansol Suh   PetscCall(PetscLogFlops((RosenbrockObjectiveFlops + 1.0) * (r.c_end - r.c_start)));
415ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
416ec9796c4SHansol Suh }
417ec9796c4SHansol Suh 
RosenbrockGradient_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])418ec9796c4SHansol Suh static PetscErrorCode RosenbrockGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
419ec9796c4SHansol Suh {
420ec9796c4SHansol Suh   PetscFunctionBegin;
421ec9796c4SHansol Suh   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
422ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
423ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
424ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
425ec9796c4SHansol Suh 
426ec9796c4SHansol Suh     RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]);
427ec9796c4SHansol Suh   }
428ec9796c4SHansol Suh   PetscCall(PetscLogFlops(RosenbrockGradientFlops * (r.c_end - r.c_start)));
429ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
430ec9796c4SHansol Suh }
431ec9796c4SHansol Suh 
RosenbrockObjectiveGradient_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscReal * f,PetscScalar g[])432ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjectiveGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f, PetscScalar g[])
433ec9796c4SHansol Suh {
434ec9796c4SHansol Suh   PetscReal _f = 0.0;
435ec9796c4SHansol Suh 
436ec9796c4SHansol Suh   PetscFunctionBegin;
437ec9796c4SHansol Suh   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
438ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
439ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
440ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
441ec9796c4SHansol Suh 
442ec9796c4SHansol Suh     _f += RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]);
443ec9796c4SHansol Suh   }
444ec9796c4SHansol Suh   *f = _f;
445ec9796c4SHansol Suh   PetscCall(PetscLogFlops(RosenbrockObjectiveGradientFlops * (r.c_end - r.c_start)));
446ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
447ec9796c4SHansol Suh }
448ec9796c4SHansol Suh 
RosenbrockHessian_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])449ec9796c4SHansol Suh static PetscErrorCode RosenbrockHessian_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
450ec9796c4SHansol Suh {
451ec9796c4SHansol Suh   PetscFunctionBegin;
452ec9796c4SHansol Suh   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
453ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
454ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
455ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
456ec9796c4SHansol Suh 
457ec9796c4SHansol Suh     RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]);
458ec9796c4SHansol Suh   }
459ec9796c4SHansol Suh   PetscCall(PetscLogFlops(RosenbrockHessianFlops * (r.c_end - r.c_start)));
460ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
461ec9796c4SHansol Suh }
462ec9796c4SHansol Suh 
463ec9796c4SHansol Suh /* -------------------------------------------------------------------- */
464ec9796c4SHansol Suh 
FormObjective(Tao tao,Vec X,PetscReal * f,void * ptr)465ec9796c4SHansol Suh static PetscErrorCode FormObjective(Tao tao, Vec X, PetscReal *f, void *ptr)
466ec9796c4SHansol Suh {
467ec9796c4SHansol Suh   AppCtx             user    = (AppCtx)ptr;
468ec9796c4SHansol Suh   PetscReal          f_local = 0.0;
469ec9796c4SHansol Suh   const PetscScalar *x;
470ec9796c4SHansol Suh   const PetscScalar *o = NULL;
471ec9796c4SHansol Suh   PetscMemType       memtype_x;
472ec9796c4SHansol Suh 
473ec9796c4SHansol Suh   PetscFunctionBeginUser;
474ec9796c4SHansol Suh   PetscCall(PetscLogEventBegin(user->event_f, tao, NULL, NULL, NULL));
475ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
476ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
477ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
478ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
479ec9796c4SHansol Suh   if (memtype_x == PETSC_MEMTYPE_HOST) {
480ec9796c4SHansol Suh     PetscCall(RosenbrockObjective_Host(user->problem, x, o, &f_local));
481*6a210b70SBarry Smith     PetscCallMPI(MPIU_Allreduce(&f_local, f, 1, MPI_DOUBLE, MPI_SUM, user->comm));
482ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
483ec9796c4SHansol Suh   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
484ec9796c4SHansol Suh     PetscScalar       *_fvec;
485ec9796c4SHansol Suh     PetscScalar        f_scalar;
486ec9796c4SHansol Suh     cupmStream_t      *stream;
487ec9796c4SHansol Suh     PetscDeviceContext dctx;
488ec9796c4SHansol Suh 
489ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
490ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
491ec9796c4SHansol Suh     PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
492ec9796c4SHansol Suh     PetscCall(RosenbrockObjective_Device(*stream, user->problem, x, o, _fvec));
493ec9796c4SHansol Suh     PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
494ec9796c4SHansol Suh     PetscCall(VecSum(user->fvector, &f_scalar));
495ec9796c4SHansol Suh     *f = PetscRealPart(f_scalar);
496ec9796c4SHansol Suh #endif
497d8b4a066SPierre Jolivet   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
498ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
499ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
500ec9796c4SHansol Suh   PetscCall(PetscLogEventEnd(user->event_f, tao, NULL, NULL, NULL));
501ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
502ec9796c4SHansol Suh }
503ec9796c4SHansol Suh 
FormGradient(Tao tao,Vec X,Vec G,void * ptr)504ec9796c4SHansol Suh static PetscErrorCode FormGradient(Tao tao, Vec X, Vec G, void *ptr)
505ec9796c4SHansol Suh {
506ec9796c4SHansol Suh   AppCtx             user = (AppCtx)ptr;
507ec9796c4SHansol Suh   PetscScalar       *g;
508ec9796c4SHansol Suh   const PetscScalar *x;
509ec9796c4SHansol Suh   const PetscScalar *o = NULL;
510ec9796c4SHansol Suh   PetscMemType       memtype_x, memtype_g;
511ec9796c4SHansol Suh 
512ec9796c4SHansol Suh   PetscFunctionBeginUser;
513ec9796c4SHansol Suh   PetscCall(PetscLogEventBegin(user->event_g, tao, NULL, NULL, NULL));
514ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
515ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
516ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
517ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
518ec9796c4SHansol Suh   PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
519ec9796c4SHansol Suh   PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
520ec9796c4SHansol Suh   if (memtype_x == PETSC_MEMTYPE_HOST) {
521ec9796c4SHansol Suh     PetscCall(RosenbrockGradient_Host(user->problem, x, o, g));
522ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
523ec9796c4SHansol Suh   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
524ec9796c4SHansol Suh     cupmStream_t      *stream;
525ec9796c4SHansol Suh     PetscDeviceContext dctx;
526ec9796c4SHansol Suh 
527ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
528ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
529ec9796c4SHansol Suh     PetscCall(RosenbrockGradient_Device(*stream, user->problem, x, o, g));
530ec9796c4SHansol Suh #endif
531d8b4a066SPierre Jolivet   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
532ec9796c4SHansol Suh   PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
533ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
534ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
535ec9796c4SHansol Suh   PetscCall(VecZeroEntries(G));
536ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
537ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
538ec9796c4SHansol Suh   PetscCall(PetscLogEventEnd(user->event_g, tao, NULL, NULL, NULL));
539ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
540ec9796c4SHansol Suh }
541ec9796c4SHansol Suh 
542ec9796c4SHansol Suh /*
543ec9796c4SHansol Suh     FormObjectiveGradient - Evaluates the function, f(X), and gradient, G(X).
544ec9796c4SHansol Suh 
545ec9796c4SHansol Suh     Input Parameters:
546ec9796c4SHansol Suh .   tao  - the Tao context
547ec9796c4SHansol Suh .   X    - input vector
548ec9796c4SHansol Suh .   ptr  - optional user-defined context, as set by TaoSetObjectiveGradient()
549ec9796c4SHansol Suh 
550ec9796c4SHansol Suh     Output Parameters:
551ec9796c4SHansol Suh .   G - vector containing the newly evaluated gradient
552ec9796c4SHansol Suh .   f - function value
553ec9796c4SHansol Suh 
554ec9796c4SHansol Suh     Note:
555ec9796c4SHansol Suh     Some optimization methods ask for the function and the gradient evaluation
556ec9796c4SHansol Suh     at the same time.  Evaluating both at once may be more efficient that
557ec9796c4SHansol Suh     evaluating each separately.
558ec9796c4SHansol Suh */
FormObjectiveGradient(Tao tao,Vec X,PetscReal * f,Vec G,void * ptr)559ec9796c4SHansol Suh static PetscErrorCode FormObjectiveGradient(Tao tao, Vec X, PetscReal *f, Vec G, void *ptr)
560ec9796c4SHansol Suh {
561ec9796c4SHansol Suh   AppCtx             user    = (AppCtx)ptr;
562ec9796c4SHansol Suh   PetscReal          f_local = 0.0;
563ec9796c4SHansol Suh   PetscScalar       *g;
564ec9796c4SHansol Suh   const PetscScalar *x;
565ec9796c4SHansol Suh   const PetscScalar *o = NULL;
566ec9796c4SHansol Suh   PetscMemType       memtype_x, memtype_g;
567ec9796c4SHansol Suh 
568ec9796c4SHansol Suh   PetscFunctionBeginUser;
569ec9796c4SHansol Suh   PetscCall(PetscLogEventBegin(user->event_fg, tao, NULL, NULL, NULL));
570ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
571ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
572ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
573ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
574ec9796c4SHansol Suh   PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
575ec9796c4SHansol Suh   PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
576ec9796c4SHansol Suh   if (memtype_x == PETSC_MEMTYPE_HOST) {
577ec9796c4SHansol Suh     PetscCall(RosenbrockObjectiveGradient_Host(user->problem, x, o, &f_local, g));
578*6a210b70SBarry Smith     PetscCallMPI(MPIU_Allreduce((void *)&f_local, (void *)f, 1, MPI_DOUBLE, MPI_SUM, PETSC_COMM_WORLD));
579ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
580ec9796c4SHansol Suh   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
581ec9796c4SHansol Suh     PetscScalar       *_fvec;
582ec9796c4SHansol Suh     PetscScalar        f_scalar;
583ec9796c4SHansol Suh     cupmStream_t      *stream;
584ec9796c4SHansol Suh     PetscDeviceContext dctx;
585ec9796c4SHansol Suh 
586ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
587ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
588ec9796c4SHansol Suh     PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
589ec9796c4SHansol Suh     PetscCall(RosenbrockObjectiveGradient_Device(*stream, user->problem, x, o, _fvec, g));
590ec9796c4SHansol Suh     PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
591ec9796c4SHansol Suh     PetscCall(VecSum(user->fvector, &f_scalar));
592ec9796c4SHansol Suh     *f = PetscRealPart(f_scalar);
593ec9796c4SHansol Suh #endif
594d8b4a066SPierre Jolivet   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
595ec9796c4SHansol Suh 
596ec9796c4SHansol Suh   PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
597ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
598ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
599ec9796c4SHansol Suh   PetscCall(VecZeroEntries(G));
600ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
601ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
602ec9796c4SHansol Suh   PetscCall(PetscLogEventEnd(user->event_fg, tao, NULL, NULL, NULL));
603ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
604ec9796c4SHansol Suh }
605ec9796c4SHansol Suh 
606ec9796c4SHansol Suh /* ------------------------------------------------------------------- */
607ec9796c4SHansol Suh /*
608ec9796c4SHansol Suh    FormHessian - Evaluates Hessian matrix.
609ec9796c4SHansol Suh 
610ec9796c4SHansol Suh    Input Parameters:
611ec9796c4SHansol Suh .  tao   - the Tao context
612ec9796c4SHansol Suh .  x     - input vector
613ec9796c4SHansol Suh .  ptr   - optional user-defined context, as set by TaoSetHessian()
614ec9796c4SHansol Suh 
615ec9796c4SHansol Suh    Output Parameters:
616ec9796c4SHansol Suh .  H     - Hessian matrix
617ec9796c4SHansol Suh 
618ec9796c4SHansol Suh    Note:  Providing the Hessian may not be necessary.  Only some solvers
619ec9796c4SHansol Suh    require this matrix.
620ec9796c4SHansol Suh */
FormHessian(Tao tao,Vec X,Mat H,Mat Hpre,void * ptr)621ec9796c4SHansol Suh static PetscErrorCode FormHessian(Tao tao, Vec X, Mat H, Mat Hpre, void *ptr)
622ec9796c4SHansol Suh {
623ec9796c4SHansol Suh   AppCtx             user = (AppCtx)ptr;
624ec9796c4SHansol Suh   PetscScalar       *h;
625ec9796c4SHansol Suh   const PetscScalar *x;
626ec9796c4SHansol Suh   const PetscScalar *o = NULL;
627ec9796c4SHansol Suh   PetscMemType       memtype_x, memtype_h;
628ec9796c4SHansol Suh 
629ec9796c4SHansol Suh   PetscFunctionBeginUser;
630ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
631ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
632ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
633ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
634ec9796c4SHansol Suh   PetscCall(VecGetArrayWriteAndMemType(user->Hvalues, &h, &memtype_h));
635ec9796c4SHansol Suh   PetscAssert(memtype_x == memtype_h, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and hessian must have save memtype");
636ec9796c4SHansol Suh   if (memtype_x == PETSC_MEMTYPE_HOST) {
637ec9796c4SHansol Suh     PetscCall(RosenbrockHessian_Host(user->problem, x, o, h));
638ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
639ec9796c4SHansol Suh   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
640ec9796c4SHansol Suh     cupmStream_t      *stream;
641ec9796c4SHansol Suh     PetscDeviceContext dctx;
642ec9796c4SHansol Suh 
643ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
644ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
645ec9796c4SHansol Suh     PetscCall(RosenbrockHessian_Device(*stream, user->problem, x, o, h));
646ec9796c4SHansol Suh #endif
647d8b4a066SPierre Jolivet   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
648ec9796c4SHansol Suh 
649ec9796c4SHansol Suh   PetscCall(MatSetValuesCOO(H, h, INSERT_VALUES));
650ec9796c4SHansol Suh   PetscCall(VecRestoreArrayWriteAndMemType(user->Hvalues, &h));
651ec9796c4SHansol Suh 
652ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
653ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
654ec9796c4SHansol Suh 
655ec9796c4SHansol Suh   if (Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN));
656ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
657ec9796c4SHansol Suh }
658ec9796c4SHansol Suh 
TestLMVM(Tao tao)659ec9796c4SHansol Suh static PetscErrorCode TestLMVM(Tao tao)
660ec9796c4SHansol Suh {
661ec9796c4SHansol Suh   KSP       ksp;
662ec9796c4SHansol Suh   PC        pc;
663ec9796c4SHansol Suh   PetscBool is_lmvm;
664ec9796c4SHansol Suh 
665ec9796c4SHansol Suh   PetscFunctionBegin;
666ec9796c4SHansol Suh   PetscCall(TaoGetKSP(tao, &ksp));
667ec9796c4SHansol Suh   if (!ksp) PetscFunctionReturn(PETSC_SUCCESS);
668ec9796c4SHansol Suh   PetscCall(KSPGetPC(ksp, &pc));
669ec9796c4SHansol Suh   PetscCall(PetscObjectTypeCompare((PetscObject)pc, PCLMVM, &is_lmvm));
670ec9796c4SHansol Suh   if (is_lmvm) {
671ec9796c4SHansol Suh     Mat       M;
672ec9796c4SHansol Suh     Vec       in, out, out2;
673ec9796c4SHansol Suh     PetscReal mult_solve_dist;
674ec9796c4SHansol Suh     Vec       x;
675ec9796c4SHansol Suh 
676ec9796c4SHansol Suh     PetscCall(PCLMVMGetMatLMVM(pc, &M));
677ec9796c4SHansol Suh     PetscCall(TaoGetSolution(tao, &x));
678ec9796c4SHansol Suh     PetscCall(VecDuplicate(x, &in));
679ec9796c4SHansol Suh     PetscCall(VecDuplicate(x, &out));
680ec9796c4SHansol Suh     PetscCall(VecDuplicate(x, &out2));
681ec9796c4SHansol Suh     PetscCall(VecSetRandom(in, NULL));
682ec9796c4SHansol Suh     PetscCall(MatMult(M, in, out));
683ec9796c4SHansol Suh     PetscCall(MatSolve(M, out, out2));
684ec9796c4SHansol Suh 
685ec9796c4SHansol Suh     PetscCall(VecAXPY(out2, -1.0, in));
686ec9796c4SHansol Suh     PetscCall(VecNorm(out2, NORM_2, &mult_solve_dist));
687ec9796c4SHansol Suh     if (mult_solve_dist < 1.e-11) {
688ec9796c4SHansol Suh       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-11\n"));
689ec9796c4SHansol Suh     } else if (mult_solve_dist < 1.e-6) {
690ec9796c4SHansol Suh       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-6\n"));
691ec9796c4SHansol Suh     } else {
692ec9796c4SHansol Suh       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve is not small: %e\n", (double)mult_solve_dist));
693ec9796c4SHansol Suh     }
694ec9796c4SHansol Suh     PetscCall(VecDestroy(&in));
695ec9796c4SHansol Suh     PetscCall(VecDestroy(&out));
696ec9796c4SHansol Suh     PetscCall(VecDestroy(&out2));
697ec9796c4SHansol Suh   }
698ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
699ec9796c4SHansol Suh }
700ec9796c4SHansol Suh 
RosenbrockMain(void)701ec9796c4SHansol Suh static PetscErrorCode RosenbrockMain(void)
702ec9796c4SHansol Suh {
703ec9796c4SHansol Suh   Vec           x;    /* solution vector */
704ec9796c4SHansol Suh   Vec           g;    /* gradient vector */
705ec9796c4SHansol Suh   Mat           H;    /* Hessian matrix */
706ec9796c4SHansol Suh   Tao           tao;  /* Tao solver context */
707ec9796c4SHansol Suh   AppCtx        user; /* user-defined application context */
708ec9796c4SHansol Suh   PetscLogStage solve;
709ec9796c4SHansol Suh 
710ec9796c4SHansol Suh   /* Initialize TAO and PETSc */
711ec9796c4SHansol Suh   PetscFunctionBegin;
712ec9796c4SHansol Suh   PetscCall(PetscLogStageRegister("Rosenbrock solve", &solve));
713ec9796c4SHansol Suh 
714ec9796c4SHansol Suh   PetscCall(AppCtxCreate(PETSC_COMM_WORLD, &user));
715ec9796c4SHansol Suh   PetscCall(CreateHessian(user, &H));
716ec9796c4SHansol Suh   PetscCall(CreateVectors(user, H, &x, &g));
717ec9796c4SHansol Suh 
718ec9796c4SHansol Suh   /* The TAO code begins here */
719ec9796c4SHansol Suh 
720ec9796c4SHansol Suh   PetscCall(TaoCreate(user->comm, &tao));
721ec9796c4SHansol Suh   PetscCall(VecZeroEntries(x));
722ec9796c4SHansol Suh   PetscCall(TaoSetSolution(tao, x));
723ec9796c4SHansol Suh 
724ec9796c4SHansol Suh   /* Set routines for function, gradient, hessian evaluation */
725ec9796c4SHansol Suh   PetscCall(TaoSetObjective(tao, FormObjective, user));
726ec9796c4SHansol Suh   PetscCall(TaoSetObjectiveAndGradient(tao, g, FormObjectiveGradient, user));
727ec9796c4SHansol Suh   PetscCall(TaoSetGradient(tao, g, FormGradient, user));
728ec9796c4SHansol Suh   PetscCall(TaoSetHessian(tao, H, H, FormHessian, user));
729ec9796c4SHansol Suh 
730ec9796c4SHansol Suh   PetscCall(TaoSetFromOptions(tao));
731ec9796c4SHansol Suh 
732ec9796c4SHansol Suh   /* SOLVE THE APPLICATION */
733ec9796c4SHansol Suh   PetscCall(PetscLogStagePush(solve));
734ec9796c4SHansol Suh   PetscCall(TaoSolve(tao));
735ec9796c4SHansol Suh   PetscCall(PetscLogStagePop());
736ec9796c4SHansol Suh 
737ec9796c4SHansol Suh   if (user->test_lmvm) PetscCall(TestLMVM(tao));
738ec9796c4SHansol Suh 
739ec9796c4SHansol Suh   PetscCall(TaoDestroy(&tao));
740ec9796c4SHansol Suh   PetscCall(VecDestroy(&g));
741ec9796c4SHansol Suh   PetscCall(VecDestroy(&x));
742ec9796c4SHansol Suh   PetscCall(MatDestroy(&H));
743ec9796c4SHansol Suh   PetscCall(AppCtxDestroy(&user));
744ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
745ec9796c4SHansol Suh }
746