1 #pragma once 2 3 #include <petsctao.h> 4 #include <petscsf.h> 5 #include <petscdevice.h> 6 #include <petscdevice_cupm.h> 7 8 /* 9 User-defined application context - contains data needed by the 10 application-provided call-back routines that evaluate the function, 11 gradient, and hessian. 12 */ 13 14 typedef struct _Rosenbrock { 15 PetscInt bs; // each block of bs variables is one chained multidimensional rosenbrock problem 16 PetscInt i_start, i_end; 17 PetscInt c_start, c_end; 18 PetscReal alpha; // condition parameter 19 } Rosenbrock; 20 21 typedef struct _AppCtx *AppCtx; 22 struct _AppCtx { 23 MPI_Comm comm; 24 PetscInt n; /* dimension */ 25 PetscInt n_local; 26 PetscInt n_local_comp; 27 Rosenbrock problem; 28 Vec Hvalues; /* vector for writing COO values of this MPI process */ 29 Vec gvalues; /* vector for writing gradient values of this mpi process */ 30 Vec fvector; 31 PetscSF off_process_scatter; 32 PetscSF gscatter; 33 Vec off_process_values; /* buffer for off-process values if chained */ 34 PetscBool test_lmvm; 35 PetscLogEvent event_f, event_g, event_fg; 36 }; 37 38 /* -------------- User-defined routines ---------- */ 39 40 static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjective(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2) 41 { 42 PetscScalar d = x_2 - x_1 * x_1; 43 PetscScalar e = 1.0 - x_1; 44 return alpha * d * d + e * e; 45 } 46 47 static const PetscLogDouble RosenbrockObjectiveFlops = 7.0; 48 49 static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2]) 50 { 51 PetscScalar d = x_2 - x_1 * x_1; 52 PetscScalar e = 1.0 - x_1; 53 PetscScalar g2 = alpha * d * 2.0; 54 55 g[0] = -2.0 * x_1 * g2 - 2.0 * e; 56 g[1] = g2; 57 } 58 59 static const PetscInt RosenbrockGradientFlops = 9.0; 60 61 static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjectiveGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2]) 62 { 63 PetscScalar d = x_2 - x_1 * x_1; 64 PetscScalar e = 1.0 - x_1; 65 PetscScalar ad = alpha * d; 66 PetscScalar g2 = ad * 2.0; 67 68 g[0] = -2.0 * x_1 * g2 - 2.0 * e; 69 g[1] = g2; 70 return ad * d + e * e; 71 } 72 73 static const PetscLogDouble RosenbrockObjectiveGradientFlops = 12.0; 74 75 static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockHessian(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar h[4]) 76 { 77 PetscScalar d = x_2 - x_1 * x_1; 78 PetscScalar g2 = alpha * d * 2.0; 79 PetscScalar h2 = -4.0 * alpha * x_1; 80 81 h[0] = -2.0 * (g2 + x_1 * h2) + 2.0; 82 h[1] = h[2] = h2; 83 h[3] = 2.0 * alpha; 84 } 85 86 static const PetscLogDouble RosenbrockHessianFlops = 11.0; 87 88 static PetscErrorCode AppCtxCreate(MPI_Comm comm, AppCtx *ctx) 89 { 90 AppCtx user; 91 PetscDeviceContext dctx; 92 93 PetscFunctionBegin; 94 PetscCall(PetscNew(ctx)); 95 user = *ctx; 96 user->comm = PETSC_COMM_WORLD; 97 98 /* Initialize problem parameters */ 99 user->n = 2; 100 user->problem.alpha = 99.0; 101 user->problem.bs = 2; // bs = 2 is block Rosenbrock, bs = n is chained Rosenbrock 102 user->test_lmvm = PETSC_FALSE; 103 /* Check for command line arguments to override defaults */ 104 PetscOptionsBegin(user->comm, NULL, "Rosenbrock example", NULL); 105 PetscCall(PetscOptionsInt("-n", "Rosenbrock problem size", NULL, user->n, &user->n, NULL)); 106 PetscCall(PetscOptionsInt("-bs", "Rosenbrock block size (2 <= bs <= n)", NULL, user->problem.bs, &user->problem.bs, NULL)); 107 PetscCall(PetscOptionsReal("-alpha", "Rosenbrock off-diagonal coefficient", NULL, user->problem.alpha, &user->problem.alpha, NULL)); 108 PetscCall(PetscOptionsBool("-test_lmvm", "Test LMVM solve against LMVM mult", NULL, user->test_lmvm, &user->test_lmvm, NULL)); 109 PetscOptionsEnd(); 110 PetscCheck(user->problem.bs >= 1, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " is not bigger than 1", user->problem.bs); 111 PetscCheck((user->n % user->problem.bs) == 0, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " does not divide problem size % " PetscInt_FMT, user->problem.bs, user->n); 112 PetscCall(PetscLogEventRegister("Rbock_Obj", TAO_CLASSID, &user->event_f)); 113 PetscCall(PetscLogEventRegister("Rbock_Grad", TAO_CLASSID, &user->event_g)); 114 PetscCall(PetscLogEventRegister("Rbock_ObjGrad", TAO_CLASSID, &user->event_fg)); 115 PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 116 PetscCall(PetscDeviceContextSetUp(dctx)); 117 PetscFunctionReturn(PETSC_SUCCESS); 118 } 119 120 static PetscErrorCode AppCtxDestroy(AppCtx *ctx) 121 { 122 AppCtx user; 123 124 PetscFunctionBegin; 125 user = *ctx; 126 *ctx = NULL; 127 PetscCall(VecDestroy(&user->Hvalues)); 128 PetscCall(VecDestroy(&user->gvalues)); 129 PetscCall(VecDestroy(&user->fvector)); 130 PetscCall(VecDestroy(&user->off_process_values)); 131 PetscCall(PetscSFDestroy(&user->off_process_scatter)); 132 PetscCall(PetscSFDestroy(&user->gscatter)); 133 PetscCall(PetscFree(user)); 134 PetscFunctionReturn(PETSC_SUCCESS); 135 } 136 137 static PetscErrorCode CreateHessian(AppCtx user, Mat *Hessian) 138 { 139 Mat H; 140 PetscLayout layout; 141 PetscInt i_start, i_end, n_local_comp, nnz_local; 142 PetscInt c_start, c_end; 143 PetscInt *coo_i; 144 PetscInt *coo_j; 145 PetscInt bs = user->problem.bs; 146 VecType vec_type; 147 148 PetscFunctionBegin; 149 /* Partition the optimization variables and the computations. 150 There are (bs - 1) contributions to the objective function for every (bs) 151 degrees of freedom. */ 152 PetscCall(PetscLayoutCreateFromSizes(user->comm, PETSC_DECIDE, user->n, 1, &layout)); 153 PetscCall(PetscLayoutSetUp(layout)); 154 PetscCall(PetscLayoutGetRange(layout, &i_start, &i_end)); 155 user->problem.i_start = i_start; 156 user->problem.i_end = i_end; 157 user->n_local = i_end - i_start; 158 user->problem.c_start = c_start = (i_start / bs) * (bs - 1) + (i_start % bs); 159 user->problem.c_end = c_end = (i_end / bs) * (bs - 1) + (i_end % bs); 160 user->n_local_comp = n_local_comp = c_end - c_start; 161 162 PetscCall(MatCreate(user->comm, Hessian)); 163 H = *Hessian; 164 PetscCall(MatSetLayouts(H, layout, layout)); 165 PetscCall(PetscLayoutDestroy(&layout)); 166 PetscCall(MatSetType(H, MATAIJ)); 167 PetscCall(MatSetOption(H, MAT_HERMITIAN, PETSC_TRUE)); 168 PetscCall(MatSetOption(H, MAT_SYMMETRIC, PETSC_TRUE)); 169 PetscCall(MatSetOption(H, MAT_SYMMETRY_ETERNAL, PETSC_TRUE)); 170 PetscCall(MatSetOption(H, MAT_STRUCTURALLY_SYMMETRIC, PETSC_TRUE)); 171 PetscCall(MatSetOption(H, MAT_STRUCTURAL_SYMMETRY_ETERNAL, PETSC_TRUE)); 172 PetscCall(MatSetFromOptions(H)); /* set from options so that we can change the underlying matrix type */ 173 174 nnz_local = n_local_comp * 4; 175 PetscCall(PetscMalloc2(nnz_local, &coo_i, nnz_local, &coo_j)); 176 /* Instead of having one computation thread per row of the matrix, 177 this example uses one thread per contribution to the objective 178 function. Each contribution to the objective function relates 179 two adjacent degrees of freedom, so each contribution to 180 the objective function adds a 2x2 block into the matrix. 181 We describe these 2x2 blocks in COO format. */ 182 for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 4) { 183 PetscInt i = (c / (bs - 1)) * bs + c % (bs - 1); 184 185 coo_i[k + 0] = i; 186 coo_i[k + 1] = i; 187 coo_i[k + 2] = i + 1; 188 coo_i[k + 3] = i + 1; 189 190 coo_j[k + 0] = i; 191 coo_j[k + 1] = i + 1; 192 coo_j[k + 2] = i; 193 coo_j[k + 3] = i + 1; 194 } 195 PetscCall(MatSetPreallocationCOO(H, nnz_local, coo_i, coo_j)); 196 PetscCall(PetscFree2(coo_i, coo_j)); 197 198 PetscCall(MatGetVecType(H, &vec_type)); 199 PetscCall(VecCreate(user->comm, &user->Hvalues)); 200 PetscCall(VecSetSizes(user->Hvalues, nnz_local, PETSC_DETERMINE)); 201 PetscCall(VecSetType(user->Hvalues, vec_type)); 202 203 // vector to collect contributions to the objective 204 PetscCall(VecCreate(user->comm, &user->fvector)); 205 PetscCall(VecSetSizes(user->fvector, user->n_local_comp, PETSC_DETERMINE)); 206 PetscCall(VecSetType(user->fvector, vec_type)); 207 208 { /* If we are using a device (such as a GPU), run some computations that will 209 warm up its linear algebra runtime before the problem we actually want 210 to profile */ 211 212 PetscMemType memtype; 213 const PetscScalar *a; 214 215 PetscCall(VecGetArrayReadAndMemType(user->fvector, &a, &memtype)); 216 PetscCall(VecRestoreArrayReadAndMemType(user->fvector, &a)); 217 218 if (memtype == PETSC_MEMTYPE_DEVICE) { 219 PetscLogStage warmup; 220 Mat A, AtA; 221 Vec x, b; 222 PetscInt warmup_size = 1000; 223 PetscDeviceContext dctx; 224 225 PetscCall(PetscLogStageRegister("Device Warmup", &warmup)); 226 PetscCall(PetscLogStageSetActive(warmup, PETSC_FALSE)); 227 228 PetscCall(PetscLogStagePush(warmup)); 229 PetscCall(MatCreateDenseFromVecType(PETSC_COMM_SELF, vec_type, warmup_size, warmup_size, warmup_size, warmup_size, PETSC_DEFAULT, NULL, &A)); 230 PetscCall(MatSetRandom(A, NULL)); 231 PetscCall(MatCreateVecs(A, &x, &b)); 232 PetscCall(VecSetRandom(x, NULL)); 233 234 PetscCall(MatMult(A, x, b)); 235 PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &AtA)); 236 PetscCall(MatShift(AtA, (PetscScalar)warmup_size)); 237 PetscCall(MatSetOption(AtA, MAT_SPD, PETSC_TRUE)); 238 PetscCall(MatCholeskyFactor(AtA, NULL, NULL)); 239 PetscCall(MatDestroy(&AtA)); 240 PetscCall(VecDestroy(&b)); 241 PetscCall(VecDestroy(&x)); 242 PetscCall(MatDestroy(&A)); 243 PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 244 PetscCall(PetscDeviceContextSynchronize(dctx)); 245 PetscCall(PetscLogStagePop()); 246 } 247 } 248 PetscFunctionReturn(PETSC_SUCCESS); 249 } 250 251 static PetscErrorCode CreateVectors(AppCtx user, Mat H, Vec *solution, Vec *gradient) 252 { 253 VecType vec_type; 254 PetscInt n_coo, *coo_i, i_start, i_end; 255 Vec x; 256 PetscInt n_recv; 257 PetscSFNode recv; 258 PetscLayout layout; 259 PetscInt c_start = user->problem.c_start, c_end = user->problem.c_end, bs = user->problem.bs; 260 261 PetscFunctionBegin; 262 PetscCall(MatCreateVecs(H, solution, gradient)); 263 x = *solution; 264 PetscCall(VecGetOwnershipRange(x, &i_start, &i_end)); 265 PetscCall(VecGetType(x, &vec_type)); 266 // create scatter for communicating values 267 PetscCall(VecGetLayout(x, &layout)); 268 n_recv = 0; 269 if (user->n_local_comp && i_end < user->n) { 270 PetscMPIInt rank; 271 PetscInt index; 272 273 n_recv = 1; 274 PetscCall(PetscLayoutFindOwnerIndex(layout, i_end, &rank, &index)); 275 recv.rank = rank; 276 recv.index = index; 277 } 278 PetscCall(PetscSFCreate(user->comm, &user->off_process_scatter)); 279 PetscCall(PetscSFSetGraph(user->off_process_scatter, user->n_local, n_recv, NULL, PETSC_USE_POINTER, &recv, PETSC_COPY_VALUES)); 280 PetscCall(VecCreate(user->comm, &user->off_process_values)); 281 PetscCall(VecSetSizes(user->off_process_values, 1, PETSC_DETERMINE)); 282 PetscCall(VecSetType(user->off_process_values, vec_type)); 283 PetscCall(VecZeroEntries(user->off_process_values)); 284 285 // create COO data for writing the gradient 286 n_coo = user->n_local_comp * 2; 287 PetscCall(PetscMalloc1(n_coo, &coo_i)); 288 for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 2) { 289 PetscInt i = (c / (bs - 1)) * bs + (c % (bs - 1)); 290 291 coo_i[k + 0] = i; 292 coo_i[k + 1] = i + 1; 293 } 294 PetscCall(PetscSFCreate(user->comm, &user->gscatter)); 295 PetscCall(PetscSFSetGraphLayout(user->gscatter, layout, n_coo, NULL, PETSC_USE_POINTER, coo_i)); 296 PetscCall(PetscSFSetUp(user->gscatter)); 297 PetscCall(PetscFree(coo_i)); 298 PetscCall(VecCreate(user->comm, &user->gvalues)); 299 PetscCall(VecSetSizes(user->gvalues, n_coo, PETSC_DETERMINE)); 300 PetscCall(VecSetType(user->gvalues, vec_type)); 301 PetscFunctionReturn(PETSC_SUCCESS); 302 } 303 304 #if PetscDefined(USING_CUPMCC) 305 306 #if PetscDefined(USING_NVCC) 307 typedef cudaStream_t cupmStream_t; 308 #define PetscCUPMLaunch(...) \ 309 do { \ 310 __VA_ARGS__; \ 311 PetscCallCUDA(cudaGetLastError()); \ 312 } while (0) 313 #elif PetscDefined(USING_HCC) 314 #define PetscCUPMLaunch(...) \ 315 do { \ 316 __VA_ARGS__; \ 317 PetscCallHIP(hipGetLastError()); \ 318 } while (0) 319 typedef hipStream_t cupmStream_t; 320 #endif 321 322 // x: on-process optimization variables 323 // o: buffer that contains the next optimization variable after the variables on this process 324 template <typename T> 325 PETSC_DEVICE_INLINE_DECL static void rosenbrock_for_loop(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], T &&func) noexcept 326 { 327 PetscInt idx = blockIdx.x * blockDim.x + threadIdx.x; // 1D grid 328 PetscInt num_threads = gridDim.x * blockDim.x; 329 330 for (PetscInt c = r.c_start + idx, k = idx; c < r.c_end; c += num_threads, k += num_threads) { 331 PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 332 PetscScalar x_a = x[i - r.i_start]; 333 PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 334 335 func(k, x_a, x_b); 336 } 337 return; 338 } 339 340 PETSC_KERNEL_DECL void RosenbrockObjective_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[]) 341 { 342 rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjective(r.alpha, x_a, x_b); }); 343 } 344 345 PETSC_KERNEL_DECL void RosenbrockGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[]) 346 { 347 rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); }); 348 } 349 350 PETSC_KERNEL_DECL void RosenbrockObjectiveGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[]) 351 { 352 rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); }); 353 } 354 355 PETSC_KERNEL_DECL void RosenbrockHessian_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[]) 356 { 357 rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); }); 358 } 359 360 static PetscErrorCode RosenbrockObjective_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[]) 361 { 362 PetscInt n_comp = r.c_end - r.c_start; 363 364 PetscFunctionBegin; 365 if (n_comp) PetscCUPMLaunch(RosenbrockObjective_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec)); 366 PetscCall(PetscLogGpuFlops(RosenbrockObjectiveFlops * n_comp)); 367 PetscFunctionReturn(PETSC_SUCCESS); 368 } 369 370 static PetscErrorCode RosenbrockGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[]) 371 { 372 PetscInt n_comp = r.c_end - r.c_start; 373 374 PetscFunctionBegin; 375 if (n_comp) PetscCUPMLaunch(RosenbrockGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, g)); 376 PetscCall(PetscLogGpuFlops(RosenbrockGradientFlops * n_comp)); 377 PetscFunctionReturn(PETSC_SUCCESS); 378 } 379 380 static PetscErrorCode RosenbrockObjectiveGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[]) 381 { 382 PetscInt n_comp = r.c_end - r.c_start; 383 384 PetscFunctionBegin; 385 if (n_comp) PetscCUPMLaunch(RosenbrockObjectiveGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec, g)); 386 PetscCall(PetscLogGpuFlops(RosenbrockObjectiveGradientFlops * n_comp)); 387 PetscFunctionReturn(PETSC_SUCCESS); 388 } 389 390 static PetscErrorCode RosenbrockHessian_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[]) 391 { 392 PetscInt n_comp = r.c_end - r.c_start; 393 394 PetscFunctionBegin; 395 if (n_comp) PetscCUPMLaunch(RosenbrockHessian_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, h)); 396 PetscCall(PetscLogGpuFlops(RosenbrockHessianFlops * n_comp)); 397 PetscFunctionReturn(PETSC_SUCCESS); 398 } 399 #endif 400 401 static PetscErrorCode RosenbrockObjective_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f) 402 { 403 PetscReal _f = 0.0; 404 405 PetscFunctionBegin; 406 for (PetscInt c = r.c_start; c < r.c_end; c++) { 407 PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 408 PetscScalar x_a = x[i - r.i_start]; 409 PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 410 411 _f += RosenbrockObjective(r.alpha, x_a, x_b); 412 } 413 *f = _f; 414 PetscCall(PetscLogFlops((RosenbrockObjectiveFlops + 1.0) * (r.c_end - r.c_start))); 415 PetscFunctionReturn(PETSC_SUCCESS); 416 } 417 418 static PetscErrorCode RosenbrockGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[]) 419 { 420 PetscFunctionBegin; 421 for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) { 422 PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 423 PetscScalar x_a = x[i - r.i_start]; 424 PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 425 426 RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); 427 } 428 PetscCall(PetscLogFlops(RosenbrockGradientFlops * (r.c_end - r.c_start))); 429 PetscFunctionReturn(PETSC_SUCCESS); 430 } 431 432 static PetscErrorCode RosenbrockObjectiveGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f, PetscScalar g[]) 433 { 434 PetscReal _f = 0.0; 435 436 PetscFunctionBegin; 437 for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) { 438 PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 439 PetscScalar x_a = x[i - r.i_start]; 440 PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 441 442 _f += RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); 443 } 444 *f = _f; 445 PetscCall(PetscLogFlops(RosenbrockObjectiveGradientFlops * (r.c_end - r.c_start))); 446 PetscFunctionReturn(PETSC_SUCCESS); 447 } 448 449 static PetscErrorCode RosenbrockHessian_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[]) 450 { 451 PetscFunctionBegin; 452 for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) { 453 PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 454 PetscScalar x_a = x[i - r.i_start]; 455 PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 456 457 RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); 458 } 459 PetscCall(PetscLogFlops(RosenbrockHessianFlops * (r.c_end - r.c_start))); 460 PetscFunctionReturn(PETSC_SUCCESS); 461 } 462 463 /* -------------------------------------------------------------------- */ 464 465 static PetscErrorCode FormObjective(Tao tao, Vec X, PetscReal *f, void *ptr) 466 { 467 AppCtx user = (AppCtx)ptr; 468 PetscReal f_local = 0.0; 469 const PetscScalar *x; 470 const PetscScalar *o = NULL; 471 PetscMemType memtype_x; 472 473 PetscFunctionBeginUser; 474 PetscCall(PetscLogEventBegin(user->event_f, tao, NULL, NULL, NULL)); 475 PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 476 PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 477 PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL)); 478 PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x)); 479 if (memtype_x == PETSC_MEMTYPE_HOST) { 480 PetscCall(RosenbrockObjective_Host(user->problem, x, o, &f_local)); 481 PetscCallMPI(MPI_Allreduce(&f_local, f, 1, MPI_DOUBLE, MPI_SUM, user->comm)); 482 #if PetscDefined(USING_CUPMCC) 483 } else if (memtype_x == PETSC_MEMTYPE_DEVICE) { 484 PetscScalar *_fvec; 485 PetscScalar f_scalar; 486 cupmStream_t *stream; 487 PetscDeviceContext dctx; 488 489 PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 490 PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream)); 491 PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL)); 492 PetscCall(RosenbrockObjective_Device(*stream, user->problem, x, o, _fvec)); 493 PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec)); 494 PetscCall(VecSum(user->fvector, &f_scalar)); 495 *f = PetscRealPart(f_scalar); 496 #endif 497 } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x); 498 PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 499 PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o)); 500 PetscCall(PetscLogEventEnd(user->event_f, tao, NULL, NULL, NULL)); 501 PetscFunctionReturn(PETSC_SUCCESS); 502 } 503 504 static PetscErrorCode FormGradient(Tao tao, Vec X, Vec G, void *ptr) 505 { 506 AppCtx user = (AppCtx)ptr; 507 PetscScalar *g; 508 const PetscScalar *x; 509 const PetscScalar *o = NULL; 510 PetscMemType memtype_x, memtype_g; 511 512 PetscFunctionBeginUser; 513 PetscCall(PetscLogEventBegin(user->event_g, tao, NULL, NULL, NULL)); 514 PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 515 PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 516 PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL)); 517 PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x)); 518 PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g)); 519 PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype"); 520 if (memtype_x == PETSC_MEMTYPE_HOST) { 521 PetscCall(RosenbrockGradient_Host(user->problem, x, o, g)); 522 #if PetscDefined(USING_CUPMCC) 523 } else if (memtype_x == PETSC_MEMTYPE_DEVICE) { 524 cupmStream_t *stream; 525 PetscDeviceContext dctx; 526 527 PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 528 PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream)); 529 PetscCall(RosenbrockGradient_Device(*stream, user->problem, x, o, g)); 530 #endif 531 } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x); 532 PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g)); 533 PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 534 PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o)); 535 PetscCall(VecZeroEntries(G)); 536 PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE)); 537 PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE)); 538 PetscCall(PetscLogEventEnd(user->event_g, tao, NULL, NULL, NULL)); 539 PetscFunctionReturn(PETSC_SUCCESS); 540 } 541 542 /* 543 FormObjectiveGradient - Evaluates the function, f(X), and gradient, G(X). 544 545 Input Parameters: 546 . tao - the Tao context 547 . X - input vector 548 . ptr - optional user-defined context, as set by TaoSetObjectiveGradient() 549 550 Output Parameters: 551 . G - vector containing the newly evaluated gradient 552 . f - function value 553 554 Note: 555 Some optimization methods ask for the function and the gradient evaluation 556 at the same time. Evaluating both at once may be more efficient that 557 evaluating each separately. 558 */ 559 static PetscErrorCode FormObjectiveGradient(Tao tao, Vec X, PetscReal *f, Vec G, void *ptr) 560 { 561 AppCtx user = (AppCtx)ptr; 562 PetscReal f_local = 0.0; 563 PetscScalar *g; 564 const PetscScalar *x; 565 const PetscScalar *o = NULL; 566 PetscMemType memtype_x, memtype_g; 567 568 PetscFunctionBeginUser; 569 PetscCall(PetscLogEventBegin(user->event_fg, tao, NULL, NULL, NULL)); 570 PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 571 PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 572 PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL)); 573 PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x)); 574 PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g)); 575 PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype"); 576 if (memtype_x == PETSC_MEMTYPE_HOST) { 577 PetscCall(RosenbrockObjectiveGradient_Host(user->problem, x, o, &f_local, g)); 578 PetscCallMPI(MPI_Allreduce((void *)&f_local, (void *)f, 1, MPI_DOUBLE, MPI_SUM, PETSC_COMM_WORLD)); 579 #if PetscDefined(USING_CUPMCC) 580 } else if (memtype_x == PETSC_MEMTYPE_DEVICE) { 581 PetscScalar *_fvec; 582 PetscScalar f_scalar; 583 cupmStream_t *stream; 584 PetscDeviceContext dctx; 585 586 PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 587 PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream)); 588 PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL)); 589 PetscCall(RosenbrockObjectiveGradient_Device(*stream, user->problem, x, o, _fvec, g)); 590 PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec)); 591 PetscCall(VecSum(user->fvector, &f_scalar)); 592 *f = PetscRealPart(f_scalar); 593 #endif 594 } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x); 595 596 PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g)); 597 PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 598 PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o)); 599 PetscCall(VecZeroEntries(G)); 600 PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE)); 601 PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE)); 602 PetscCall(PetscLogEventEnd(user->event_fg, tao, NULL, NULL, NULL)); 603 PetscFunctionReturn(PETSC_SUCCESS); 604 } 605 606 /* ------------------------------------------------------------------- */ 607 /* 608 FormHessian - Evaluates Hessian matrix. 609 610 Input Parameters: 611 . tao - the Tao context 612 . x - input vector 613 . ptr - optional user-defined context, as set by TaoSetHessian() 614 615 Output Parameters: 616 . H - Hessian matrix 617 618 Note: Providing the Hessian may not be necessary. Only some solvers 619 require this matrix. 620 */ 621 static PetscErrorCode FormHessian(Tao tao, Vec X, Mat H, Mat Hpre, void *ptr) 622 { 623 AppCtx user = (AppCtx)ptr; 624 PetscScalar *h; 625 const PetscScalar *x; 626 const PetscScalar *o = NULL; 627 PetscMemType memtype_x, memtype_h; 628 629 PetscFunctionBeginUser; 630 PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 631 PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 632 PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL)); 633 PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x)); 634 PetscCall(VecGetArrayWriteAndMemType(user->Hvalues, &h, &memtype_h)); 635 PetscAssert(memtype_x == memtype_h, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and hessian must have save memtype"); 636 if (memtype_x == PETSC_MEMTYPE_HOST) { 637 PetscCall(RosenbrockHessian_Host(user->problem, x, o, h)); 638 #if PetscDefined(USING_CUPMCC) 639 } else if (memtype_x == PETSC_MEMTYPE_DEVICE) { 640 cupmStream_t *stream; 641 PetscDeviceContext dctx; 642 643 PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 644 PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream)); 645 PetscCall(RosenbrockHessian_Device(*stream, user->problem, x, o, h)); 646 #endif 647 } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x); 648 649 PetscCall(MatSetValuesCOO(H, h, INSERT_VALUES)); 650 PetscCall(VecRestoreArrayWriteAndMemType(user->Hvalues, &h)); 651 652 PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 653 PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o)); 654 655 if (Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN)); 656 PetscFunctionReturn(PETSC_SUCCESS); 657 } 658 659 static PetscErrorCode TestLMVM(Tao tao) 660 { 661 KSP ksp; 662 PC pc; 663 PetscBool is_lmvm; 664 665 PetscFunctionBegin; 666 PetscCall(TaoGetKSP(tao, &ksp)); 667 if (!ksp) PetscFunctionReturn(PETSC_SUCCESS); 668 PetscCall(KSPGetPC(ksp, &pc)); 669 PetscCall(PetscObjectTypeCompare((PetscObject)pc, PCLMVM, &is_lmvm)); 670 if (is_lmvm) { 671 Mat M; 672 Vec in, out, out2; 673 PetscReal mult_solve_dist; 674 Vec x; 675 676 PetscCall(PCLMVMGetMatLMVM(pc, &M)); 677 PetscCall(TaoGetSolution(tao, &x)); 678 PetscCall(VecDuplicate(x, &in)); 679 PetscCall(VecDuplicate(x, &out)); 680 PetscCall(VecDuplicate(x, &out2)); 681 PetscCall(VecSetRandom(in, NULL)); 682 PetscCall(MatMult(M, in, out)); 683 PetscCall(MatSolve(M, out, out2)); 684 685 PetscCall(VecAXPY(out2, -1.0, in)); 686 PetscCall(VecNorm(out2, NORM_2, &mult_solve_dist)); 687 if (mult_solve_dist < 1.e-11) { 688 PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-11\n")); 689 } else if (mult_solve_dist < 1.e-6) { 690 PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-6\n")); 691 } else { 692 PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve is not small: %e\n", (double)mult_solve_dist)); 693 } 694 PetscCall(VecDestroy(&in)); 695 PetscCall(VecDestroy(&out)); 696 PetscCall(VecDestroy(&out2)); 697 } 698 PetscFunctionReturn(PETSC_SUCCESS); 699 } 700 701 static PetscErrorCode RosenbrockMain(void) 702 { 703 Vec x; /* solution vector */ 704 Vec g; /* gradient vector */ 705 Mat H; /* Hessian matrix */ 706 Tao tao; /* Tao solver context */ 707 AppCtx user; /* user-defined application context */ 708 PetscLogStage solve; 709 710 /* Initialize TAO and PETSc */ 711 PetscFunctionBegin; 712 PetscCall(PetscLogStageRegister("Rosenbrock solve", &solve)); 713 714 PetscCall(AppCtxCreate(PETSC_COMM_WORLD, &user)); 715 PetscCall(CreateHessian(user, &H)); 716 PetscCall(CreateVectors(user, H, &x, &g)); 717 718 /* The TAO code begins here */ 719 720 PetscCall(TaoCreate(user->comm, &tao)); 721 PetscCall(VecZeroEntries(x)); 722 PetscCall(TaoSetSolution(tao, x)); 723 724 /* Set routines for function, gradient, hessian evaluation */ 725 PetscCall(TaoSetObjective(tao, FormObjective, user)); 726 PetscCall(TaoSetObjectiveAndGradient(tao, g, FormObjectiveGradient, user)); 727 PetscCall(TaoSetGradient(tao, g, FormGradient, user)); 728 PetscCall(TaoSetHessian(tao, H, H, FormHessian, user)); 729 730 PetscCall(TaoSetFromOptions(tao)); 731 732 /* SOLVE THE APPLICATION */ 733 PetscCall(PetscLogStagePush(solve)); 734 PetscCall(TaoSolve(tao)); 735 PetscCall(PetscLogStagePop()); 736 737 if (user->test_lmvm) PetscCall(TestLMVM(tao)); 738 739 PetscCall(TaoDestroy(&tao)); 740 PetscCall(VecDestroy(&g)); 741 PetscCall(VecDestroy(&x)); 742 PetscCall(MatDestroy(&H)); 743 PetscCall(AppCtxDestroy(&user)); 744 PetscFunctionReturn(PETSC_SUCCESS); 745 } 746