1ec9796c4SHansol Suh #pragma once
2ec9796c4SHansol Suh
3ec9796c4SHansol Suh #include <petsctao.h>
4ec9796c4SHansol Suh #include <petscsf.h>
5ec9796c4SHansol Suh #include <petscdevice.h>
6ec9796c4SHansol Suh #include <petscdevice_cupm.h>
7ec9796c4SHansol Suh
8ec9796c4SHansol Suh /*
9ec9796c4SHansol Suh User-defined application context - contains data needed by the
10ec9796c4SHansol Suh application-provided call-back routines that evaluate the function,
11ec9796c4SHansol Suh gradient, and hessian.
12ec9796c4SHansol Suh */
13ec9796c4SHansol Suh
14ec9796c4SHansol Suh typedef struct _Rosenbrock {
15ec9796c4SHansol Suh PetscInt bs; // each block of bs variables is one chained multidimensional rosenbrock problem
16ec9796c4SHansol Suh PetscInt i_start, i_end;
17ec9796c4SHansol Suh PetscInt c_start, c_end;
18ec9796c4SHansol Suh PetscReal alpha; // condition parameter
19ec9796c4SHansol Suh } Rosenbrock;
20ec9796c4SHansol Suh
21ec9796c4SHansol Suh typedef struct _AppCtx *AppCtx;
22ec9796c4SHansol Suh struct _AppCtx {
23ec9796c4SHansol Suh MPI_Comm comm;
24ec9796c4SHansol Suh PetscInt n; /* dimension */
25ec9796c4SHansol Suh PetscInt n_local;
26ec9796c4SHansol Suh PetscInt n_local_comp;
27ec9796c4SHansol Suh Rosenbrock problem;
28ec9796c4SHansol Suh Vec Hvalues; /* vector for writing COO values of this MPI process */
29ec9796c4SHansol Suh Vec gvalues; /* vector for writing gradient values of this mpi process */
30ec9796c4SHansol Suh Vec fvector;
31ec9796c4SHansol Suh PetscSF off_process_scatter;
32ec9796c4SHansol Suh PetscSF gscatter;
33ec9796c4SHansol Suh Vec off_process_values; /* buffer for off-process values if chained */
34ec9796c4SHansol Suh PetscBool test_lmvm;
35ec9796c4SHansol Suh PetscLogEvent event_f, event_g, event_fg;
36ec9796c4SHansol Suh };
37ec9796c4SHansol Suh
38ec9796c4SHansol Suh /* -------------- User-defined routines ---------- */
39ec9796c4SHansol Suh
RosenbrockObjective(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2)40ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjective(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2)
41ec9796c4SHansol Suh {
42ec9796c4SHansol Suh PetscScalar d = x_2 - x_1 * x_1;
43ec9796c4SHansol Suh PetscScalar e = 1.0 - x_1;
44ec9796c4SHansol Suh return alpha * d * d + e * e;
45ec9796c4SHansol Suh }
46ec9796c4SHansol Suh
47ec9796c4SHansol Suh static const PetscLogDouble RosenbrockObjectiveFlops = 7.0;
48ec9796c4SHansol Suh
RosenbrockGradient(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar g[2])49ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
50ec9796c4SHansol Suh {
51ec9796c4SHansol Suh PetscScalar d = x_2 - x_1 * x_1;
52ec9796c4SHansol Suh PetscScalar e = 1.0 - x_1;
53ec9796c4SHansol Suh PetscScalar g2 = alpha * d * 2.0;
54ec9796c4SHansol Suh
55ec9796c4SHansol Suh g[0] = -2.0 * x_1 * g2 - 2.0 * e;
56ec9796c4SHansol Suh g[1] = g2;
57ec9796c4SHansol Suh }
58ec9796c4SHansol Suh
59ec9796c4SHansol Suh static const PetscInt RosenbrockGradientFlops = 9.0;
60ec9796c4SHansol Suh
RosenbrockObjectiveGradient(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar g[2])61ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjectiveGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
62ec9796c4SHansol Suh {
63ec9796c4SHansol Suh PetscScalar d = x_2 - x_1 * x_1;
64ec9796c4SHansol Suh PetscScalar e = 1.0 - x_1;
65ec9796c4SHansol Suh PetscScalar ad = alpha * d;
66ec9796c4SHansol Suh PetscScalar g2 = ad * 2.0;
67ec9796c4SHansol Suh
68ec9796c4SHansol Suh g[0] = -2.0 * x_1 * g2 - 2.0 * e;
69ec9796c4SHansol Suh g[1] = g2;
70ec9796c4SHansol Suh return ad * d + e * e;
71ec9796c4SHansol Suh }
72ec9796c4SHansol Suh
73ec9796c4SHansol Suh static const PetscLogDouble RosenbrockObjectiveGradientFlops = 12.0;
74ec9796c4SHansol Suh
RosenbrockHessian(PetscScalar alpha,PetscScalar x_1,PetscScalar x_2,PetscScalar h[4])75ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockHessian(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar h[4])
76ec9796c4SHansol Suh {
77ec9796c4SHansol Suh PetscScalar d = x_2 - x_1 * x_1;
78ec9796c4SHansol Suh PetscScalar g2 = alpha * d * 2.0;
79ec9796c4SHansol Suh PetscScalar h2 = -4.0 * alpha * x_1;
80ec9796c4SHansol Suh
81ec9796c4SHansol Suh h[0] = -2.0 * (g2 + x_1 * h2) + 2.0;
82ec9796c4SHansol Suh h[1] = h[2] = h2;
83ec9796c4SHansol Suh h[3] = 2.0 * alpha;
84ec9796c4SHansol Suh }
85ec9796c4SHansol Suh
86ec9796c4SHansol Suh static const PetscLogDouble RosenbrockHessianFlops = 11.0;
87ec9796c4SHansol Suh
AppCtxCreate(MPI_Comm comm,AppCtx * ctx)88ec9796c4SHansol Suh static PetscErrorCode AppCtxCreate(MPI_Comm comm, AppCtx *ctx)
89ec9796c4SHansol Suh {
90ec9796c4SHansol Suh AppCtx user;
91ec9796c4SHansol Suh PetscDeviceContext dctx;
92ec9796c4SHansol Suh
93ec9796c4SHansol Suh PetscFunctionBegin;
94ec9796c4SHansol Suh PetscCall(PetscNew(ctx));
95ec9796c4SHansol Suh user = *ctx;
96ec9796c4SHansol Suh user->comm = PETSC_COMM_WORLD;
97ec9796c4SHansol Suh
98ec9796c4SHansol Suh /* Initialize problem parameters */
99ec9796c4SHansol Suh user->n = 2;
100ec9796c4SHansol Suh user->problem.alpha = 99.0;
101ec9796c4SHansol Suh user->problem.bs = 2; // bs = 2 is block Rosenbrock, bs = n is chained Rosenbrock
102ec9796c4SHansol Suh user->test_lmvm = PETSC_FALSE;
103ec9796c4SHansol Suh /* Check for command line arguments to override defaults */
104ec9796c4SHansol Suh PetscOptionsBegin(user->comm, NULL, "Rosenbrock example", NULL);
105ec9796c4SHansol Suh PetscCall(PetscOptionsInt("-n", "Rosenbrock problem size", NULL, user->n, &user->n, NULL));
106ec9796c4SHansol Suh PetscCall(PetscOptionsInt("-bs", "Rosenbrock block size (2 <= bs <= n)", NULL, user->problem.bs, &user->problem.bs, NULL));
107ec9796c4SHansol Suh PetscCall(PetscOptionsReal("-alpha", "Rosenbrock off-diagonal coefficient", NULL, user->problem.alpha, &user->problem.alpha, NULL));
108d8b4a066SPierre Jolivet PetscCall(PetscOptionsBool("-test_lmvm", "Test LMVM solve against LMVM mult", NULL, user->test_lmvm, &user->test_lmvm, NULL));
109ec9796c4SHansol Suh PetscOptionsEnd();
110ec9796c4SHansol Suh PetscCheck(user->problem.bs >= 1, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " is not bigger than 1", user->problem.bs);
111d8b4a066SPierre Jolivet PetscCheck((user->n % user->problem.bs) == 0, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " does not divide problem size % " PetscInt_FMT, user->problem.bs, user->n);
112ec9796c4SHansol Suh PetscCall(PetscLogEventRegister("Rbock_Obj", TAO_CLASSID, &user->event_f));
113ec9796c4SHansol Suh PetscCall(PetscLogEventRegister("Rbock_Grad", TAO_CLASSID, &user->event_g));
114ec9796c4SHansol Suh PetscCall(PetscLogEventRegister("Rbock_ObjGrad", TAO_CLASSID, &user->event_fg));
115ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
116ec9796c4SHansol Suh PetscCall(PetscDeviceContextSetUp(dctx));
117ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
118ec9796c4SHansol Suh }
119ec9796c4SHansol Suh
AppCtxDestroy(AppCtx * ctx)120ec9796c4SHansol Suh static PetscErrorCode AppCtxDestroy(AppCtx *ctx)
121ec9796c4SHansol Suh {
122ec9796c4SHansol Suh AppCtx user;
123ec9796c4SHansol Suh
124ec9796c4SHansol Suh PetscFunctionBegin;
125ec9796c4SHansol Suh user = *ctx;
126ec9796c4SHansol Suh *ctx = NULL;
127ec9796c4SHansol Suh PetscCall(VecDestroy(&user->Hvalues));
128ec9796c4SHansol Suh PetscCall(VecDestroy(&user->gvalues));
129ec9796c4SHansol Suh PetscCall(VecDestroy(&user->fvector));
130ec9796c4SHansol Suh PetscCall(VecDestroy(&user->off_process_values));
131ec9796c4SHansol Suh PetscCall(PetscSFDestroy(&user->off_process_scatter));
132ec9796c4SHansol Suh PetscCall(PetscSFDestroy(&user->gscatter));
133ec9796c4SHansol Suh PetscCall(PetscFree(user));
134ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
135ec9796c4SHansol Suh }
136ec9796c4SHansol Suh
CreateHessian(AppCtx user,Mat * Hessian)137ec9796c4SHansol Suh static PetscErrorCode CreateHessian(AppCtx user, Mat *Hessian)
138ec9796c4SHansol Suh {
139ec9796c4SHansol Suh Mat H;
140ec9796c4SHansol Suh PetscLayout layout;
141ec9796c4SHansol Suh PetscInt i_start, i_end, n_local_comp, nnz_local;
142ec9796c4SHansol Suh PetscInt c_start, c_end;
143ec9796c4SHansol Suh PetscInt *coo_i;
144ec9796c4SHansol Suh PetscInt *coo_j;
145ec9796c4SHansol Suh PetscInt bs = user->problem.bs;
146ec9796c4SHansol Suh VecType vec_type;
147ec9796c4SHansol Suh
148ec9796c4SHansol Suh PetscFunctionBegin;
149ec9796c4SHansol Suh /* Partition the optimization variables and the computations.
150ec9796c4SHansol Suh There are (bs - 1) contributions to the objective function for every (bs)
151ec9796c4SHansol Suh degrees of freedom. */
152ec9796c4SHansol Suh PetscCall(PetscLayoutCreateFromSizes(user->comm, PETSC_DECIDE, user->n, 1, &layout));
153ec9796c4SHansol Suh PetscCall(PetscLayoutSetUp(layout));
154ec9796c4SHansol Suh PetscCall(PetscLayoutGetRange(layout, &i_start, &i_end));
155ec9796c4SHansol Suh user->problem.i_start = i_start;
156ec9796c4SHansol Suh user->problem.i_end = i_end;
157ec9796c4SHansol Suh user->n_local = i_end - i_start;
158ec9796c4SHansol Suh user->problem.c_start = c_start = (i_start / bs) * (bs - 1) + (i_start % bs);
159ec9796c4SHansol Suh user->problem.c_end = c_end = (i_end / bs) * (bs - 1) + (i_end % bs);
160ec9796c4SHansol Suh user->n_local_comp = n_local_comp = c_end - c_start;
161ec9796c4SHansol Suh
162ec9796c4SHansol Suh PetscCall(MatCreate(user->comm, Hessian));
163ec9796c4SHansol Suh H = *Hessian;
164ec9796c4SHansol Suh PetscCall(MatSetLayouts(H, layout, layout));
165ec9796c4SHansol Suh PetscCall(PetscLayoutDestroy(&layout));
166ec9796c4SHansol Suh PetscCall(MatSetType(H, MATAIJ));
167ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_HERMITIAN, PETSC_TRUE));
168ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_SYMMETRIC, PETSC_TRUE));
169ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_SYMMETRY_ETERNAL, PETSC_TRUE));
170ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_STRUCTURALLY_SYMMETRIC, PETSC_TRUE));
171ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_STRUCTURAL_SYMMETRY_ETERNAL, PETSC_TRUE));
172ec9796c4SHansol Suh PetscCall(MatSetFromOptions(H)); /* set from options so that we can change the underlying matrix type */
173ec9796c4SHansol Suh
174ec9796c4SHansol Suh nnz_local = n_local_comp * 4;
175ec9796c4SHansol Suh PetscCall(PetscMalloc2(nnz_local, &coo_i, nnz_local, &coo_j));
176ec9796c4SHansol Suh /* Instead of having one computation thread per row of the matrix,
177ec9796c4SHansol Suh this example uses one thread per contribution to the objective
178ec9796c4SHansol Suh function. Each contribution to the objective function relates
179ec9796c4SHansol Suh two adjacent degrees of freedom, so each contribution to
180ec9796c4SHansol Suh the objective function adds a 2x2 block into the matrix.
181ec9796c4SHansol Suh We describe these 2x2 blocks in COO format. */
182ec9796c4SHansol Suh for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 4) {
183ec9796c4SHansol Suh PetscInt i = (c / (bs - 1)) * bs + c % (bs - 1);
184ec9796c4SHansol Suh
185ec9796c4SHansol Suh coo_i[k + 0] = i;
186ec9796c4SHansol Suh coo_i[k + 1] = i;
187ec9796c4SHansol Suh coo_i[k + 2] = i + 1;
188ec9796c4SHansol Suh coo_i[k + 3] = i + 1;
189ec9796c4SHansol Suh
190ec9796c4SHansol Suh coo_j[k + 0] = i;
191ec9796c4SHansol Suh coo_j[k + 1] = i + 1;
192ec9796c4SHansol Suh coo_j[k + 2] = i;
193ec9796c4SHansol Suh coo_j[k + 3] = i + 1;
194ec9796c4SHansol Suh }
195ec9796c4SHansol Suh PetscCall(MatSetPreallocationCOO(H, nnz_local, coo_i, coo_j));
196ec9796c4SHansol Suh PetscCall(PetscFree2(coo_i, coo_j));
197ec9796c4SHansol Suh
198ec9796c4SHansol Suh PetscCall(MatGetVecType(H, &vec_type));
199ec9796c4SHansol Suh PetscCall(VecCreate(user->comm, &user->Hvalues));
200ec9796c4SHansol Suh PetscCall(VecSetSizes(user->Hvalues, nnz_local, PETSC_DETERMINE));
201ec9796c4SHansol Suh PetscCall(VecSetType(user->Hvalues, vec_type));
202ec9796c4SHansol Suh
203ec9796c4SHansol Suh // vector to collect contributions to the objective
204ec9796c4SHansol Suh PetscCall(VecCreate(user->comm, &user->fvector));
205ec9796c4SHansol Suh PetscCall(VecSetSizes(user->fvector, user->n_local_comp, PETSC_DETERMINE));
206ec9796c4SHansol Suh PetscCall(VecSetType(user->fvector, vec_type));
207ec9796c4SHansol Suh
208ec9796c4SHansol Suh { /* If we are using a device (such as a GPU), run some computations that will
209ec9796c4SHansol Suh warm up its linear algebra runtime before the problem we actually want
210ec9796c4SHansol Suh to profile */
211ec9796c4SHansol Suh
212ec9796c4SHansol Suh PetscMemType memtype;
213ec9796c4SHansol Suh const PetscScalar *a;
214ec9796c4SHansol Suh
215ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->fvector, &a, &memtype));
216ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->fvector, &a));
217ec9796c4SHansol Suh
218ec9796c4SHansol Suh if (memtype == PETSC_MEMTYPE_DEVICE) {
219ec9796c4SHansol Suh PetscLogStage warmup;
220ec9796c4SHansol Suh Mat A, AtA;
221ec9796c4SHansol Suh Vec x, b;
222ec9796c4SHansol Suh PetscInt warmup_size = 1000;
223ec9796c4SHansol Suh PetscDeviceContext dctx;
224ec9796c4SHansol Suh
225ec9796c4SHansol Suh PetscCall(PetscLogStageRegister("Device Warmup", &warmup));
226ec9796c4SHansol Suh PetscCall(PetscLogStageSetActive(warmup, PETSC_FALSE));
227ec9796c4SHansol Suh
228ec9796c4SHansol Suh PetscCall(PetscLogStagePush(warmup));
229ec9796c4SHansol Suh PetscCall(MatCreateDenseFromVecType(PETSC_COMM_SELF, vec_type, warmup_size, warmup_size, warmup_size, warmup_size, PETSC_DEFAULT, NULL, &A));
230ec9796c4SHansol Suh PetscCall(MatSetRandom(A, NULL));
231ec9796c4SHansol Suh PetscCall(MatCreateVecs(A, &x, &b));
232ec9796c4SHansol Suh PetscCall(VecSetRandom(x, NULL));
233ec9796c4SHansol Suh
234ec9796c4SHansol Suh PetscCall(MatMult(A, x, b));
235fb842aefSJose E. Roman PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &AtA));
236ec9796c4SHansol Suh PetscCall(MatShift(AtA, (PetscScalar)warmup_size));
237ec9796c4SHansol Suh PetscCall(MatSetOption(AtA, MAT_SPD, PETSC_TRUE));
238ec9796c4SHansol Suh PetscCall(MatCholeskyFactor(AtA, NULL, NULL));
239ec9796c4SHansol Suh PetscCall(MatDestroy(&AtA));
240ec9796c4SHansol Suh PetscCall(VecDestroy(&b));
241ec9796c4SHansol Suh PetscCall(VecDestroy(&x));
242ec9796c4SHansol Suh PetscCall(MatDestroy(&A));
243ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
244ec9796c4SHansol Suh PetscCall(PetscDeviceContextSynchronize(dctx));
245ec9796c4SHansol Suh PetscCall(PetscLogStagePop());
246ec9796c4SHansol Suh }
247ec9796c4SHansol Suh }
248ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
249ec9796c4SHansol Suh }
250ec9796c4SHansol Suh
CreateVectors(AppCtx user,Mat H,Vec * solution,Vec * gradient)251ec9796c4SHansol Suh static PetscErrorCode CreateVectors(AppCtx user, Mat H, Vec *solution, Vec *gradient)
252ec9796c4SHansol Suh {
253ec9796c4SHansol Suh VecType vec_type;
254ec9796c4SHansol Suh PetscInt n_coo, *coo_i, i_start, i_end;
255ec9796c4SHansol Suh Vec x;
256ec9796c4SHansol Suh PetscInt n_recv;
257ec9796c4SHansol Suh PetscSFNode recv;
258ec9796c4SHansol Suh PetscLayout layout;
259ec9796c4SHansol Suh PetscInt c_start = user->problem.c_start, c_end = user->problem.c_end, bs = user->problem.bs;
260ec9796c4SHansol Suh
261ec9796c4SHansol Suh PetscFunctionBegin;
262ec9796c4SHansol Suh PetscCall(MatCreateVecs(H, solution, gradient));
263ec9796c4SHansol Suh x = *solution;
264ec9796c4SHansol Suh PetscCall(VecGetOwnershipRange(x, &i_start, &i_end));
265ec9796c4SHansol Suh PetscCall(VecGetType(x, &vec_type));
266ec9796c4SHansol Suh // create scatter for communicating values
267ec9796c4SHansol Suh PetscCall(VecGetLayout(x, &layout));
268ec9796c4SHansol Suh n_recv = 0;
269ec9796c4SHansol Suh if (user->n_local_comp && i_end < user->n) {
270ec9796c4SHansol Suh PetscMPIInt rank;
271ec9796c4SHansol Suh PetscInt index;
272ec9796c4SHansol Suh
273ec9796c4SHansol Suh n_recv = 1;
274ec9796c4SHansol Suh PetscCall(PetscLayoutFindOwnerIndex(layout, i_end, &rank, &index));
275ec9796c4SHansol Suh recv.rank = rank;
276ec9796c4SHansol Suh recv.index = index;
277ec9796c4SHansol Suh }
278ec9796c4SHansol Suh PetscCall(PetscSFCreate(user->comm, &user->off_process_scatter));
279ec9796c4SHansol Suh PetscCall(PetscSFSetGraph(user->off_process_scatter, user->n_local, n_recv, NULL, PETSC_USE_POINTER, &recv, PETSC_COPY_VALUES));
280ec9796c4SHansol Suh PetscCall(VecCreate(user->comm, &user->off_process_values));
281ec9796c4SHansol Suh PetscCall(VecSetSizes(user->off_process_values, 1, PETSC_DETERMINE));
282ec9796c4SHansol Suh PetscCall(VecSetType(user->off_process_values, vec_type));
283ec9796c4SHansol Suh PetscCall(VecZeroEntries(user->off_process_values));
284ec9796c4SHansol Suh
285ec9796c4SHansol Suh // create COO data for writing the gradient
286ec9796c4SHansol Suh n_coo = user->n_local_comp * 2;
287ec9796c4SHansol Suh PetscCall(PetscMalloc1(n_coo, &coo_i));
288ec9796c4SHansol Suh for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 2) {
289ec9796c4SHansol Suh PetscInt i = (c / (bs - 1)) * bs + (c % (bs - 1));
290ec9796c4SHansol Suh
291ec9796c4SHansol Suh coo_i[k + 0] = i;
292ec9796c4SHansol Suh coo_i[k + 1] = i + 1;
293ec9796c4SHansol Suh }
294ec9796c4SHansol Suh PetscCall(PetscSFCreate(user->comm, &user->gscatter));
295ec9796c4SHansol Suh PetscCall(PetscSFSetGraphLayout(user->gscatter, layout, n_coo, NULL, PETSC_USE_POINTER, coo_i));
296ec9796c4SHansol Suh PetscCall(PetscSFSetUp(user->gscatter));
297ec9796c4SHansol Suh PetscCall(PetscFree(coo_i));
298ec9796c4SHansol Suh PetscCall(VecCreate(user->comm, &user->gvalues));
299ec9796c4SHansol Suh PetscCall(VecSetSizes(user->gvalues, n_coo, PETSC_DETERMINE));
300ec9796c4SHansol Suh PetscCall(VecSetType(user->gvalues, vec_type));
301ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
302ec9796c4SHansol Suh }
303ec9796c4SHansol Suh
304ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
305ec9796c4SHansol Suh
306ec9796c4SHansol Suh #if PetscDefined(USING_NVCC)
307ec9796c4SHansol Suh typedef cudaStream_t cupmStream_t;
308ec9796c4SHansol Suh #define PetscCUPMLaunch(...) \
309ec9796c4SHansol Suh do { \
310ec9796c4SHansol Suh __VA_ARGS__; \
311ec9796c4SHansol Suh PetscCallCUDA(cudaGetLastError()); \
312ec9796c4SHansol Suh } while (0)
313ec9796c4SHansol Suh #elif PetscDefined(USING_HCC)
314ec9796c4SHansol Suh #define PetscCUPMLaunch(...) \
315ec9796c4SHansol Suh do { \
316ec9796c4SHansol Suh __VA_ARGS__; \
317ec9796c4SHansol Suh PetscCallHIP(hipGetLastError()); \
318ec9796c4SHansol Suh } while (0)
319ec9796c4SHansol Suh typedef hipStream_t cupmStream_t;
320ec9796c4SHansol Suh #endif
321ec9796c4SHansol Suh
322ec9796c4SHansol Suh // x: on-process optimization variables
323ec9796c4SHansol Suh // o: buffer that contains the next optimization variable after the variables on this process
324ec9796c4SHansol Suh template <typename T>
rosenbrock_for_loop(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],T && func)325ec9796c4SHansol Suh PETSC_DEVICE_INLINE_DECL static void rosenbrock_for_loop(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], T &&func) noexcept
326ec9796c4SHansol Suh {
327ec9796c4SHansol Suh PetscInt idx = blockIdx.x * blockDim.x + threadIdx.x; // 1D grid
328ec9796c4SHansol Suh PetscInt num_threads = gridDim.x * blockDim.x;
329ec9796c4SHansol Suh
330ec9796c4SHansol Suh for (PetscInt c = r.c_start + idx, k = idx; c < r.c_end; c += num_threads, k += num_threads) {
331ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
332ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start];
333ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
334ec9796c4SHansol Suh
335ec9796c4SHansol Suh func(k, x_a, x_b);
336ec9796c4SHansol Suh }
337ec9796c4SHansol Suh return;
338ec9796c4SHansol Suh }
339ec9796c4SHansol Suh
RosenbrockObjective_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[])340ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockObjective_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
341ec9796c4SHansol Suh {
342ec9796c4SHansol Suh rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjective(r.alpha, x_a, x_b); });
343ec9796c4SHansol Suh }
344ec9796c4SHansol Suh
RosenbrockGradient_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])345ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
346ec9796c4SHansol Suh {
347ec9796c4SHansol Suh rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); });
348ec9796c4SHansol Suh }
349ec9796c4SHansol Suh
RosenbrockObjectiveGradient_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[],PetscScalar g[])350ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockObjectiveGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
351ec9796c4SHansol Suh {
352ec9796c4SHansol Suh rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); });
353ec9796c4SHansol Suh }
354ec9796c4SHansol Suh
RosenbrockHessian_Kernel(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])355ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockHessian_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
356ec9796c4SHansol Suh {
357ec9796c4SHansol Suh rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); });
358ec9796c4SHansol Suh }
359ec9796c4SHansol Suh
RosenbrockObjective_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[])360ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjective_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
361ec9796c4SHansol Suh {
362ec9796c4SHansol Suh PetscInt n_comp = r.c_end - r.c_start;
363ec9796c4SHansol Suh
364ec9796c4SHansol Suh PetscFunctionBegin;
365ec9796c4SHansol Suh if (n_comp) PetscCUPMLaunch(RosenbrockObjective_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec));
366ec9796c4SHansol Suh PetscCall(PetscLogGpuFlops(RosenbrockObjectiveFlops * n_comp));
367ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
368ec9796c4SHansol Suh }
369ec9796c4SHansol Suh
RosenbrockGradient_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])370ec9796c4SHansol Suh static PetscErrorCode RosenbrockGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
371ec9796c4SHansol Suh {
372ec9796c4SHansol Suh PetscInt n_comp = r.c_end - r.c_start;
373ec9796c4SHansol Suh
374ec9796c4SHansol Suh PetscFunctionBegin;
375ec9796c4SHansol Suh if (n_comp) PetscCUPMLaunch(RosenbrockGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, g));
376ec9796c4SHansol Suh PetscCall(PetscLogGpuFlops(RosenbrockGradientFlops * n_comp));
377ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
378ec9796c4SHansol Suh }
379ec9796c4SHansol Suh
RosenbrockObjectiveGradient_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar f_vec[],PetscScalar g[])380ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjectiveGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
381ec9796c4SHansol Suh {
382ec9796c4SHansol Suh PetscInt n_comp = r.c_end - r.c_start;
383ec9796c4SHansol Suh
384ec9796c4SHansol Suh PetscFunctionBegin;
385ec9796c4SHansol Suh if (n_comp) PetscCUPMLaunch(RosenbrockObjectiveGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec, g));
386ec9796c4SHansol Suh PetscCall(PetscLogGpuFlops(RosenbrockObjectiveGradientFlops * n_comp));
387ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
388ec9796c4SHansol Suh }
389ec9796c4SHansol Suh
RosenbrockHessian_Device(cupmStream_t stream,Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])390ec9796c4SHansol Suh static PetscErrorCode RosenbrockHessian_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
391ec9796c4SHansol Suh {
392ec9796c4SHansol Suh PetscInt n_comp = r.c_end - r.c_start;
393ec9796c4SHansol Suh
394ec9796c4SHansol Suh PetscFunctionBegin;
395ec9796c4SHansol Suh if (n_comp) PetscCUPMLaunch(RosenbrockHessian_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, h));
396ec9796c4SHansol Suh PetscCall(PetscLogGpuFlops(RosenbrockHessianFlops * n_comp));
397ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
398ec9796c4SHansol Suh }
399ec9796c4SHansol Suh #endif
400ec9796c4SHansol Suh
RosenbrockObjective_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscReal * f)401ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjective_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f)
402ec9796c4SHansol Suh {
403ec9796c4SHansol Suh PetscReal _f = 0.0;
404ec9796c4SHansol Suh
405ec9796c4SHansol Suh PetscFunctionBegin;
406ec9796c4SHansol Suh for (PetscInt c = r.c_start; c < r.c_end; c++) {
407ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
408ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start];
409ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
410ec9796c4SHansol Suh
411ec9796c4SHansol Suh _f += RosenbrockObjective(r.alpha, x_a, x_b);
412ec9796c4SHansol Suh }
413ec9796c4SHansol Suh *f = _f;
414ec9796c4SHansol Suh PetscCall(PetscLogFlops((RosenbrockObjectiveFlops + 1.0) * (r.c_end - r.c_start)));
415ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
416ec9796c4SHansol Suh }
417ec9796c4SHansol Suh
RosenbrockGradient_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar g[])418ec9796c4SHansol Suh static PetscErrorCode RosenbrockGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
419ec9796c4SHansol Suh {
420ec9796c4SHansol Suh PetscFunctionBegin;
421ec9796c4SHansol Suh for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
422ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
423ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start];
424ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
425ec9796c4SHansol Suh
426ec9796c4SHansol Suh RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]);
427ec9796c4SHansol Suh }
428ec9796c4SHansol Suh PetscCall(PetscLogFlops(RosenbrockGradientFlops * (r.c_end - r.c_start)));
429ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
430ec9796c4SHansol Suh }
431ec9796c4SHansol Suh
RosenbrockObjectiveGradient_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscReal * f,PetscScalar g[])432ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjectiveGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f, PetscScalar g[])
433ec9796c4SHansol Suh {
434ec9796c4SHansol Suh PetscReal _f = 0.0;
435ec9796c4SHansol Suh
436ec9796c4SHansol Suh PetscFunctionBegin;
437ec9796c4SHansol Suh for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
438ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
439ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start];
440ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
441ec9796c4SHansol Suh
442ec9796c4SHansol Suh _f += RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]);
443ec9796c4SHansol Suh }
444ec9796c4SHansol Suh *f = _f;
445ec9796c4SHansol Suh PetscCall(PetscLogFlops(RosenbrockObjectiveGradientFlops * (r.c_end - r.c_start)));
446ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
447ec9796c4SHansol Suh }
448ec9796c4SHansol Suh
RosenbrockHessian_Host(Rosenbrock r,const PetscScalar x[],const PetscScalar o[],PetscScalar h[])449ec9796c4SHansol Suh static PetscErrorCode RosenbrockHessian_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
450ec9796c4SHansol Suh {
451ec9796c4SHansol Suh PetscFunctionBegin;
452ec9796c4SHansol Suh for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
453ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
454ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start];
455ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
456ec9796c4SHansol Suh
457ec9796c4SHansol Suh RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]);
458ec9796c4SHansol Suh }
459ec9796c4SHansol Suh PetscCall(PetscLogFlops(RosenbrockHessianFlops * (r.c_end - r.c_start)));
460ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
461ec9796c4SHansol Suh }
462ec9796c4SHansol Suh
463ec9796c4SHansol Suh /* -------------------------------------------------------------------- */
464ec9796c4SHansol Suh
FormObjective(Tao tao,Vec X,PetscReal * f,void * ptr)465ec9796c4SHansol Suh static PetscErrorCode FormObjective(Tao tao, Vec X, PetscReal *f, void *ptr)
466ec9796c4SHansol Suh {
467ec9796c4SHansol Suh AppCtx user = (AppCtx)ptr;
468ec9796c4SHansol Suh PetscReal f_local = 0.0;
469ec9796c4SHansol Suh const PetscScalar *x;
470ec9796c4SHansol Suh const PetscScalar *o = NULL;
471ec9796c4SHansol Suh PetscMemType memtype_x;
472ec9796c4SHansol Suh
473ec9796c4SHansol Suh PetscFunctionBeginUser;
474ec9796c4SHansol Suh PetscCall(PetscLogEventBegin(user->event_f, tao, NULL, NULL, NULL));
475ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
476ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
477ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
478ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
479ec9796c4SHansol Suh if (memtype_x == PETSC_MEMTYPE_HOST) {
480ec9796c4SHansol Suh PetscCall(RosenbrockObjective_Host(user->problem, x, o, &f_local));
481*6a210b70SBarry Smith PetscCallMPI(MPIU_Allreduce(&f_local, f, 1, MPI_DOUBLE, MPI_SUM, user->comm));
482ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
483ec9796c4SHansol Suh } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
484ec9796c4SHansol Suh PetscScalar *_fvec;
485ec9796c4SHansol Suh PetscScalar f_scalar;
486ec9796c4SHansol Suh cupmStream_t *stream;
487ec9796c4SHansol Suh PetscDeviceContext dctx;
488ec9796c4SHansol Suh
489ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
490ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
491ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
492ec9796c4SHansol Suh PetscCall(RosenbrockObjective_Device(*stream, user->problem, x, o, _fvec));
493ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
494ec9796c4SHansol Suh PetscCall(VecSum(user->fvector, &f_scalar));
495ec9796c4SHansol Suh *f = PetscRealPart(f_scalar);
496ec9796c4SHansol Suh #endif
497d8b4a066SPierre Jolivet } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
498ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(X, &x));
499ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
500ec9796c4SHansol Suh PetscCall(PetscLogEventEnd(user->event_f, tao, NULL, NULL, NULL));
501ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
502ec9796c4SHansol Suh }
503ec9796c4SHansol Suh
FormGradient(Tao tao,Vec X,Vec G,void * ptr)504ec9796c4SHansol Suh static PetscErrorCode FormGradient(Tao tao, Vec X, Vec G, void *ptr)
505ec9796c4SHansol Suh {
506ec9796c4SHansol Suh AppCtx user = (AppCtx)ptr;
507ec9796c4SHansol Suh PetscScalar *g;
508ec9796c4SHansol Suh const PetscScalar *x;
509ec9796c4SHansol Suh const PetscScalar *o = NULL;
510ec9796c4SHansol Suh PetscMemType memtype_x, memtype_g;
511ec9796c4SHansol Suh
512ec9796c4SHansol Suh PetscFunctionBeginUser;
513ec9796c4SHansol Suh PetscCall(PetscLogEventBegin(user->event_g, tao, NULL, NULL, NULL));
514ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
515ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
516ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
517ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
518ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
519ec9796c4SHansol Suh PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
520ec9796c4SHansol Suh if (memtype_x == PETSC_MEMTYPE_HOST) {
521ec9796c4SHansol Suh PetscCall(RosenbrockGradient_Host(user->problem, x, o, g));
522ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
523ec9796c4SHansol Suh } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
524ec9796c4SHansol Suh cupmStream_t *stream;
525ec9796c4SHansol Suh PetscDeviceContext dctx;
526ec9796c4SHansol Suh
527ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
528ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
529ec9796c4SHansol Suh PetscCall(RosenbrockGradient_Device(*stream, user->problem, x, o, g));
530ec9796c4SHansol Suh #endif
531d8b4a066SPierre Jolivet } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
532ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
533ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(X, &x));
534ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
535ec9796c4SHansol Suh PetscCall(VecZeroEntries(G));
536ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
537ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
538ec9796c4SHansol Suh PetscCall(PetscLogEventEnd(user->event_g, tao, NULL, NULL, NULL));
539ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
540ec9796c4SHansol Suh }
541ec9796c4SHansol Suh
542ec9796c4SHansol Suh /*
543ec9796c4SHansol Suh FormObjectiveGradient - Evaluates the function, f(X), and gradient, G(X).
544ec9796c4SHansol Suh
545ec9796c4SHansol Suh Input Parameters:
546ec9796c4SHansol Suh . tao - the Tao context
547ec9796c4SHansol Suh . X - input vector
548ec9796c4SHansol Suh . ptr - optional user-defined context, as set by TaoSetObjectiveGradient()
549ec9796c4SHansol Suh
550ec9796c4SHansol Suh Output Parameters:
551ec9796c4SHansol Suh . G - vector containing the newly evaluated gradient
552ec9796c4SHansol Suh . f - function value
553ec9796c4SHansol Suh
554ec9796c4SHansol Suh Note:
555ec9796c4SHansol Suh Some optimization methods ask for the function and the gradient evaluation
556ec9796c4SHansol Suh at the same time. Evaluating both at once may be more efficient that
557ec9796c4SHansol Suh evaluating each separately.
558ec9796c4SHansol Suh */
FormObjectiveGradient(Tao tao,Vec X,PetscReal * f,Vec G,void * ptr)559ec9796c4SHansol Suh static PetscErrorCode FormObjectiveGradient(Tao tao, Vec X, PetscReal *f, Vec G, void *ptr)
560ec9796c4SHansol Suh {
561ec9796c4SHansol Suh AppCtx user = (AppCtx)ptr;
562ec9796c4SHansol Suh PetscReal f_local = 0.0;
563ec9796c4SHansol Suh PetscScalar *g;
564ec9796c4SHansol Suh const PetscScalar *x;
565ec9796c4SHansol Suh const PetscScalar *o = NULL;
566ec9796c4SHansol Suh PetscMemType memtype_x, memtype_g;
567ec9796c4SHansol Suh
568ec9796c4SHansol Suh PetscFunctionBeginUser;
569ec9796c4SHansol Suh PetscCall(PetscLogEventBegin(user->event_fg, tao, NULL, NULL, NULL));
570ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
571ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
572ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
573ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
574ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
575ec9796c4SHansol Suh PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
576ec9796c4SHansol Suh if (memtype_x == PETSC_MEMTYPE_HOST) {
577ec9796c4SHansol Suh PetscCall(RosenbrockObjectiveGradient_Host(user->problem, x, o, &f_local, g));
578*6a210b70SBarry Smith PetscCallMPI(MPIU_Allreduce((void *)&f_local, (void *)f, 1, MPI_DOUBLE, MPI_SUM, PETSC_COMM_WORLD));
579ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
580ec9796c4SHansol Suh } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
581ec9796c4SHansol Suh PetscScalar *_fvec;
582ec9796c4SHansol Suh PetscScalar f_scalar;
583ec9796c4SHansol Suh cupmStream_t *stream;
584ec9796c4SHansol Suh PetscDeviceContext dctx;
585ec9796c4SHansol Suh
586ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
587ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
588ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
589ec9796c4SHansol Suh PetscCall(RosenbrockObjectiveGradient_Device(*stream, user->problem, x, o, _fvec, g));
590ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
591ec9796c4SHansol Suh PetscCall(VecSum(user->fvector, &f_scalar));
592ec9796c4SHansol Suh *f = PetscRealPart(f_scalar);
593ec9796c4SHansol Suh #endif
594d8b4a066SPierre Jolivet } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
595ec9796c4SHansol Suh
596ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
597ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(X, &x));
598ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
599ec9796c4SHansol Suh PetscCall(VecZeroEntries(G));
600ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
601ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
602ec9796c4SHansol Suh PetscCall(PetscLogEventEnd(user->event_fg, tao, NULL, NULL, NULL));
603ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
604ec9796c4SHansol Suh }
605ec9796c4SHansol Suh
606ec9796c4SHansol Suh /* ------------------------------------------------------------------- */
607ec9796c4SHansol Suh /*
608ec9796c4SHansol Suh FormHessian - Evaluates Hessian matrix.
609ec9796c4SHansol Suh
610ec9796c4SHansol Suh Input Parameters:
611ec9796c4SHansol Suh . tao - the Tao context
612ec9796c4SHansol Suh . x - input vector
613ec9796c4SHansol Suh . ptr - optional user-defined context, as set by TaoSetHessian()
614ec9796c4SHansol Suh
615ec9796c4SHansol Suh Output Parameters:
616ec9796c4SHansol Suh . H - Hessian matrix
617ec9796c4SHansol Suh
618ec9796c4SHansol Suh Note: Providing the Hessian may not be necessary. Only some solvers
619ec9796c4SHansol Suh require this matrix.
620ec9796c4SHansol Suh */
FormHessian(Tao tao,Vec X,Mat H,Mat Hpre,void * ptr)621ec9796c4SHansol Suh static PetscErrorCode FormHessian(Tao tao, Vec X, Mat H, Mat Hpre, void *ptr)
622ec9796c4SHansol Suh {
623ec9796c4SHansol Suh AppCtx user = (AppCtx)ptr;
624ec9796c4SHansol Suh PetscScalar *h;
625ec9796c4SHansol Suh const PetscScalar *x;
626ec9796c4SHansol Suh const PetscScalar *o = NULL;
627ec9796c4SHansol Suh PetscMemType memtype_x, memtype_h;
628ec9796c4SHansol Suh
629ec9796c4SHansol Suh PetscFunctionBeginUser;
630ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
631ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
632ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
633ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
634ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->Hvalues, &h, &memtype_h));
635ec9796c4SHansol Suh PetscAssert(memtype_x == memtype_h, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and hessian must have save memtype");
636ec9796c4SHansol Suh if (memtype_x == PETSC_MEMTYPE_HOST) {
637ec9796c4SHansol Suh PetscCall(RosenbrockHessian_Host(user->problem, x, o, h));
638ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
639ec9796c4SHansol Suh } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
640ec9796c4SHansol Suh cupmStream_t *stream;
641ec9796c4SHansol Suh PetscDeviceContext dctx;
642ec9796c4SHansol Suh
643ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
644ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
645ec9796c4SHansol Suh PetscCall(RosenbrockHessian_Device(*stream, user->problem, x, o, h));
646ec9796c4SHansol Suh #endif
647d8b4a066SPierre Jolivet } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
648ec9796c4SHansol Suh
649ec9796c4SHansol Suh PetscCall(MatSetValuesCOO(H, h, INSERT_VALUES));
650ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->Hvalues, &h));
651ec9796c4SHansol Suh
652ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(X, &x));
653ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
654ec9796c4SHansol Suh
655ec9796c4SHansol Suh if (Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN));
656ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
657ec9796c4SHansol Suh }
658ec9796c4SHansol Suh
TestLMVM(Tao tao)659ec9796c4SHansol Suh static PetscErrorCode TestLMVM(Tao tao)
660ec9796c4SHansol Suh {
661ec9796c4SHansol Suh KSP ksp;
662ec9796c4SHansol Suh PC pc;
663ec9796c4SHansol Suh PetscBool is_lmvm;
664ec9796c4SHansol Suh
665ec9796c4SHansol Suh PetscFunctionBegin;
666ec9796c4SHansol Suh PetscCall(TaoGetKSP(tao, &ksp));
667ec9796c4SHansol Suh if (!ksp) PetscFunctionReturn(PETSC_SUCCESS);
668ec9796c4SHansol Suh PetscCall(KSPGetPC(ksp, &pc));
669ec9796c4SHansol Suh PetscCall(PetscObjectTypeCompare((PetscObject)pc, PCLMVM, &is_lmvm));
670ec9796c4SHansol Suh if (is_lmvm) {
671ec9796c4SHansol Suh Mat M;
672ec9796c4SHansol Suh Vec in, out, out2;
673ec9796c4SHansol Suh PetscReal mult_solve_dist;
674ec9796c4SHansol Suh Vec x;
675ec9796c4SHansol Suh
676ec9796c4SHansol Suh PetscCall(PCLMVMGetMatLMVM(pc, &M));
677ec9796c4SHansol Suh PetscCall(TaoGetSolution(tao, &x));
678ec9796c4SHansol Suh PetscCall(VecDuplicate(x, &in));
679ec9796c4SHansol Suh PetscCall(VecDuplicate(x, &out));
680ec9796c4SHansol Suh PetscCall(VecDuplicate(x, &out2));
681ec9796c4SHansol Suh PetscCall(VecSetRandom(in, NULL));
682ec9796c4SHansol Suh PetscCall(MatMult(M, in, out));
683ec9796c4SHansol Suh PetscCall(MatSolve(M, out, out2));
684ec9796c4SHansol Suh
685ec9796c4SHansol Suh PetscCall(VecAXPY(out2, -1.0, in));
686ec9796c4SHansol Suh PetscCall(VecNorm(out2, NORM_2, &mult_solve_dist));
687ec9796c4SHansol Suh if (mult_solve_dist < 1.e-11) {
688ec9796c4SHansol Suh PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-11\n"));
689ec9796c4SHansol Suh } else if (mult_solve_dist < 1.e-6) {
690ec9796c4SHansol Suh PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-6\n"));
691ec9796c4SHansol Suh } else {
692ec9796c4SHansol Suh PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve is not small: %e\n", (double)mult_solve_dist));
693ec9796c4SHansol Suh }
694ec9796c4SHansol Suh PetscCall(VecDestroy(&in));
695ec9796c4SHansol Suh PetscCall(VecDestroy(&out));
696ec9796c4SHansol Suh PetscCall(VecDestroy(&out2));
697ec9796c4SHansol Suh }
698ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
699ec9796c4SHansol Suh }
700ec9796c4SHansol Suh
RosenbrockMain(void)701ec9796c4SHansol Suh static PetscErrorCode RosenbrockMain(void)
702ec9796c4SHansol Suh {
703ec9796c4SHansol Suh Vec x; /* solution vector */
704ec9796c4SHansol Suh Vec g; /* gradient vector */
705ec9796c4SHansol Suh Mat H; /* Hessian matrix */
706ec9796c4SHansol Suh Tao tao; /* Tao solver context */
707ec9796c4SHansol Suh AppCtx user; /* user-defined application context */
708ec9796c4SHansol Suh PetscLogStage solve;
709ec9796c4SHansol Suh
710ec9796c4SHansol Suh /* Initialize TAO and PETSc */
711ec9796c4SHansol Suh PetscFunctionBegin;
712ec9796c4SHansol Suh PetscCall(PetscLogStageRegister("Rosenbrock solve", &solve));
713ec9796c4SHansol Suh
714ec9796c4SHansol Suh PetscCall(AppCtxCreate(PETSC_COMM_WORLD, &user));
715ec9796c4SHansol Suh PetscCall(CreateHessian(user, &H));
716ec9796c4SHansol Suh PetscCall(CreateVectors(user, H, &x, &g));
717ec9796c4SHansol Suh
718ec9796c4SHansol Suh /* The TAO code begins here */
719ec9796c4SHansol Suh
720ec9796c4SHansol Suh PetscCall(TaoCreate(user->comm, &tao));
721ec9796c4SHansol Suh PetscCall(VecZeroEntries(x));
722ec9796c4SHansol Suh PetscCall(TaoSetSolution(tao, x));
723ec9796c4SHansol Suh
724ec9796c4SHansol Suh /* Set routines for function, gradient, hessian evaluation */
725ec9796c4SHansol Suh PetscCall(TaoSetObjective(tao, FormObjective, user));
726ec9796c4SHansol Suh PetscCall(TaoSetObjectiveAndGradient(tao, g, FormObjectiveGradient, user));
727ec9796c4SHansol Suh PetscCall(TaoSetGradient(tao, g, FormGradient, user));
728ec9796c4SHansol Suh PetscCall(TaoSetHessian(tao, H, H, FormHessian, user));
729ec9796c4SHansol Suh
730ec9796c4SHansol Suh PetscCall(TaoSetFromOptions(tao));
731ec9796c4SHansol Suh
732ec9796c4SHansol Suh /* SOLVE THE APPLICATION */
733ec9796c4SHansol Suh PetscCall(PetscLogStagePush(solve));
734ec9796c4SHansol Suh PetscCall(TaoSolve(tao));
735ec9796c4SHansol Suh PetscCall(PetscLogStagePop());
736ec9796c4SHansol Suh
737ec9796c4SHansol Suh if (user->test_lmvm) PetscCall(TestLMVM(tao));
738ec9796c4SHansol Suh
739ec9796c4SHansol Suh PetscCall(TaoDestroy(&tao));
740ec9796c4SHansol Suh PetscCall(VecDestroy(&g));
741ec9796c4SHansol Suh PetscCall(VecDestroy(&x));
742ec9796c4SHansol Suh PetscCall(MatDestroy(&H));
743ec9796c4SHansol Suh PetscCall(AppCtxDestroy(&user));
744ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS);
745ec9796c4SHansol Suh }
746