1 static char help[] = "Performs adjoint sensitivity analysis for a van der Pol like \n\ 2 equation with time dependent parameters using two approaches : \n\ 3 track : track only local sensitivities at each adjoint step \n\ 4 and accumulate them in a global array \n\ 5 global : track parameters at all timesteps together \n\ 6 Choose one of the two at runtime by -sa_method {track,global}. \n"; 7 8 /* 9 Simple example to demonstrate TSAdjoint capabilities for time dependent params 10 without integral cost terms using either a tracking or global method. 11 12 Modify the Van Der Pol Eq to : 13 [u1'] = [mu1(t)*u1] 14 [u2'] = [mu2(t)*((1-u1^2)*u2-u1)] 15 (with initial conditions & params independent) 16 17 Define uref to be solution with initial conditions (2,-2/3), mu=(1,1e3) 18 - u_ref : (1.5967,-1.02969) 19 20 Define const function as cost = 2-norm(u - u_ref); 21 22 Initialization for the adjoint TS : 23 - dcost/dy|final_time = 2*(u-u_ref)|final_time 24 - dcost/dp|final_time = 0 25 26 The tracking method only tracks local sensitivity at each time step 27 and accumulates these sensitivities in a global array. Since the structure 28 of the equations being solved at each time step does not change, the jacobian 29 wrt parameters is defined analogous to constant RHSJacobian for a liner 30 TSSolve and the size of the jacP is independent of the number of time 31 steps. Enable this mode of adjoint analysis by -sa_method track. 32 33 The global method combines the parameters at all timesteps and tracks them 34 together. Thus, the columns of the jacP matrix are filled dependent upon the 35 time step. Also, the dimensions of the jacP matrix now depend upon the number 36 of time steps. Enable this mode of adjoint analysis by -sa_method global. 37 38 Since the equations here have parameters at predefined time steps, this 39 example should be run with non adaptive time stepping solvers only. This 40 can be ensured by -ts_adapt_type none (which is the default behavior only 41 for certain TS solvers like TSCN. If using an explicit TS like TSRK, 42 please be sure to add the aforementioned option to disable adaptive 43 timestepping.) 44 */ 45 46 /* 47 Include "petscts.h" so that we can use TS solvers. Note that this file 48 automatically includes: 49 petscsys.h - base PETSc routines petscvec.h - vectors 50 petscmat.h - matrices 51 petscis.h - index sets petscksp.h - Krylov subspace methods 52 petscviewer.h - viewers petscpc.h - preconditioners 53 petscksp.h - linear solvers petscsnes.h - nonlinear solvers 54 */ 55 #include <petscts.h> 56 57 extern PetscErrorCode RHSFunction(TS, PetscReal, Vec, Vec, void *); 58 extern PetscErrorCode RHSJacobian(TS, PetscReal, Vec, Mat, Mat, void *); 59 extern PetscErrorCode RHSJacobianP_track(TS, PetscReal, Vec, Mat, void *); 60 extern PetscErrorCode RHSJacobianP_global(TS, PetscReal, Vec, Mat, void *); 61 extern PetscErrorCode Monitor(TS, PetscInt, PetscReal, Vec, void *); 62 extern PetscErrorCode AdjointMonitor(TS, PetscInt, PetscReal, Vec, PetscInt, Vec *, Vec *, void *); 63 64 /* 65 User-defined application context - contains data needed by the 66 application-provided call-back routines. 67 */ 68 69 typedef struct { 70 /*------------- Forward solve data structures --------------*/ 71 PetscInt max_steps; /* number of steps to run ts for */ 72 PetscReal final_time; /* final time to integrate to*/ 73 PetscReal time_step; /* ts integration time step */ 74 Vec mu1; /* time dependent params */ 75 Vec mu2; /* time dependent params */ 76 Vec U; /* solution vector */ 77 Mat A; /* Jacobian matrix */ 78 79 /*------------- Adjoint solve data structures --------------*/ 80 Mat Jacp; /* JacobianP matrix */ 81 Vec lambda; /* adjoint variable */ 82 Vec mup; /* adjoint variable */ 83 84 /*------------- Global accumation vecs for monitor based tracking --------------*/ 85 Vec sens_mu1; /* global sensitivity accumulation */ 86 Vec sens_mu2; /* global sensitivity accumulation */ 87 PetscInt adj_idx; /* to keep track of adjoint solve index */ 88 } AppCtx; 89 90 typedef enum { 91 SA_TRACK, 92 SA_GLOBAL 93 } SAMethod; 94 static const char *const SAMethods[] = {"TRACK", "GLOBAL", "SAMethod", "SA_", 0}; 95 96 /* ----------------------- Explicit form of the ODE -------------------- */ 97 98 PetscErrorCode RHSFunction(TS ts, PetscReal t, Vec U, Vec F, PetscCtx ctx) 99 { 100 AppCtx *user = (AppCtx *)ctx; 101 PetscScalar *f; 102 PetscInt curr_step; 103 const PetscScalar *u; 104 const PetscScalar *mu1; 105 const PetscScalar *mu2; 106 107 PetscFunctionBeginUser; 108 PetscCall(TSGetStepNumber(ts, &curr_step)); 109 PetscCall(VecGetArrayRead(U, &u)); 110 PetscCall(VecGetArrayRead(user->mu1, &mu1)); 111 PetscCall(VecGetArrayRead(user->mu2, &mu2)); 112 PetscCall(VecGetArray(F, &f)); 113 f[0] = mu1[curr_step] * u[1]; 114 f[1] = mu2[curr_step] * ((1. - u[0] * u[0]) * u[1] - u[0]); 115 PetscCall(VecRestoreArrayRead(U, &u)); 116 PetscCall(VecRestoreArrayRead(user->mu1, &mu1)); 117 PetscCall(VecRestoreArrayRead(user->mu2, &mu2)); 118 PetscCall(VecRestoreArray(F, &f)); 119 PetscFunctionReturn(PETSC_SUCCESS); 120 } 121 122 PetscErrorCode RHSJacobian(TS ts, PetscReal t, Vec U, Mat A, Mat B, PetscCtx ctx) 123 { 124 AppCtx *user = (AppCtx *)ctx; 125 PetscInt rowcol[] = {0, 1}; 126 PetscScalar J[2][2]; 127 PetscInt curr_step; 128 const PetscScalar *u; 129 const PetscScalar *mu1; 130 const PetscScalar *mu2; 131 132 PetscFunctionBeginUser; 133 PetscCall(TSGetStepNumber(ts, &curr_step)); 134 PetscCall(VecGetArrayRead(user->mu1, &mu1)); 135 PetscCall(VecGetArrayRead(user->mu2, &mu2)); 136 PetscCall(VecGetArrayRead(U, &u)); 137 J[0][0] = 0; 138 J[1][0] = -mu2[curr_step] * (2.0 * u[1] * u[0] + 1.); 139 J[0][1] = mu1[curr_step]; 140 J[1][1] = mu2[curr_step] * (1.0 - u[0] * u[0]); 141 PetscCall(MatSetValues(A, 2, rowcol, 2, rowcol, &J[0][0], INSERT_VALUES)); 142 PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY)); 143 PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY)); 144 PetscCall(VecRestoreArrayRead(U, &u)); 145 PetscCall(VecRestoreArrayRead(user->mu1, &mu1)); 146 PetscCall(VecRestoreArrayRead(user->mu2, &mu2)); 147 PetscFunctionReturn(PETSC_SUCCESS); 148 } 149 150 /* ------------------ Jacobian wrt parameters for tracking method ------------------ */ 151 152 PetscErrorCode RHSJacobianP_track(TS ts, PetscReal t, Vec U, Mat A, PetscCtx ctx) 153 { 154 PetscInt row[] = {0, 1}, col[] = {0, 1}; 155 PetscScalar J[2][2]; 156 const PetscScalar *u; 157 158 PetscFunctionBeginUser; 159 PetscCall(VecGetArrayRead(U, &u)); 160 J[0][0] = u[1]; 161 J[1][0] = 0; 162 J[0][1] = 0; 163 J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0]; 164 PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES)); 165 PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY)); 166 PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY)); 167 PetscCall(VecRestoreArrayRead(U, &u)); 168 PetscFunctionReturn(PETSC_SUCCESS); 169 } 170 171 /* ------------------ Jacobian wrt parameters for global method ------------------ */ 172 173 PetscErrorCode RHSJacobianP_global(TS ts, PetscReal t, Vec U, Mat A, PetscCtx ctx) 174 { 175 PetscInt row[] = {0, 1}, col[] = {0, 1}; 176 PetscScalar J[2][2]; 177 const PetscScalar *u; 178 PetscInt curr_step; 179 180 PetscFunctionBeginUser; 181 PetscCall(TSGetStepNumber(ts, &curr_step)); 182 PetscCall(VecGetArrayRead(U, &u)); 183 J[0][0] = u[1]; 184 J[1][0] = 0; 185 J[0][1] = 0; 186 J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0]; 187 col[0] = curr_step * 2; 188 col[1] = curr_step * 2 + 1; 189 PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES)); 190 PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY)); 191 PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY)); 192 PetscCall(VecRestoreArrayRead(U, &u)); 193 PetscFunctionReturn(PETSC_SUCCESS); 194 } 195 196 /* Dump solution to console if called */ 197 PetscErrorCode Monitor(TS ts, PetscInt step, PetscReal t, Vec U, PetscCtx ctx) 198 { 199 PetscFunctionBeginUser; 200 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution at time %e is \n", (double)t)); 201 PetscCall(VecView(U, PETSC_VIEWER_STDOUT_WORLD)); 202 PetscFunctionReturn(PETSC_SUCCESS); 203 } 204 205 /* Customized adjoint monitor to keep track of local 206 sensitivities by storing them in a global sensitivity array. 207 Note : This routine is only used for the tracking method. */ 208 PetscErrorCode AdjointMonitor(TS ts, PetscInt steps, PetscReal time, Vec u, PetscInt numcost, Vec *lambda, Vec *mu, PetscCtx ctx) 209 { 210 AppCtx *user = (AppCtx *)ctx; 211 PetscInt curr_step; 212 PetscScalar *sensmu1_glob; 213 PetscScalar *sensmu2_glob; 214 const PetscScalar *sensmu_loc; 215 216 PetscFunctionBeginUser; 217 PetscCall(TSGetStepNumber(ts, &curr_step)); 218 /* Note that we skip the first call to the monitor in the adjoint 219 solve since the sensitivities are already set (during 220 initialization of adjoint vectors). 221 We also note that each indvidial TSAdjointSolve calls the monitor 222 twice, once at the step it is integrating from and once at the step 223 it integrates to. Only the second call is useful for transferring 224 local sensitivities to the global array. */ 225 if (curr_step == user->adj_idx) { 226 PetscFunctionReturn(PETSC_SUCCESS); 227 } else { 228 PetscCall(VecGetArrayRead(*mu, &sensmu_loc)); 229 PetscCall(VecGetArray(user->sens_mu1, &sensmu1_glob)); 230 PetscCall(VecGetArray(user->sens_mu2, &sensmu2_glob)); 231 sensmu1_glob[curr_step] = sensmu_loc[0]; 232 sensmu2_glob[curr_step] = sensmu_loc[1]; 233 PetscCall(VecRestoreArray(user->sens_mu1, &sensmu1_glob)); 234 PetscCall(VecRestoreArray(user->sens_mu2, &sensmu2_glob)); 235 PetscCall(VecRestoreArrayRead(*mu, &sensmu_loc)); 236 PetscFunctionReturn(PETSC_SUCCESS); 237 } 238 } 239 240 int main(int argc, char **argv) 241 { 242 TS ts; 243 AppCtx user; 244 PetscScalar *x_ptr, *y_ptr, *u_ptr; 245 PetscMPIInt size; 246 PetscBool monitor = PETSC_FALSE; 247 SAMethod sa = SA_GLOBAL; 248 249 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 250 Initialize program 251 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 252 PetscFunctionBeginUser; 253 PetscCall(PetscInitialize(&argc, &argv, NULL, help)); 254 PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size)); 255 PetscCheck(size == 1, PETSC_COMM_WORLD, PETSC_ERR_WRONG_MPI_SIZE, "This is a uniprocessor example only!"); 256 257 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 258 Set runtime options 259 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 260 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "SA analysis options.", ""); 261 { 262 PetscCall(PetscOptionsGetBool(NULL, NULL, "-monitor", &monitor, NULL)); 263 PetscCall(PetscOptionsEnum("-sa_method", "Sensitivity analysis method (track or global)", "", SAMethods, (PetscEnum)sa, (PetscEnum *)&sa, NULL)); 264 } 265 PetscOptionsEnd(); 266 267 user.final_time = 0.1; 268 user.max_steps = 5; 269 user.time_step = user.final_time / user.max_steps; 270 271 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 272 Create necessary matrix and vectors for forward solve. 273 Create Jacp matrix for adjoint solve. 274 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 275 PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu1)); 276 PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu2)); 277 PetscCall(VecSet(user.mu1, 1.25)); 278 PetscCall(VecSet(user.mu2, 1.0e2)); 279 280 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 281 For tracking method : create the global sensitivity array to 282 accumulate sensitivity with respect to parameters at each step. 283 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 284 if (sa == SA_TRACK) { 285 PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu1)); 286 PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu2)); 287 } 288 289 PetscCall(MatCreate(PETSC_COMM_WORLD, &user.A)); 290 PetscCall(MatSetSizes(user.A, PETSC_DECIDE, PETSC_DECIDE, 2, 2)); 291 PetscCall(MatSetFromOptions(user.A)); 292 PetscCall(MatSetUp(user.A)); 293 PetscCall(MatCreateVecs(user.A, &user.U, NULL)); 294 295 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 296 Note that the dimensions of the Jacp matrix depend upon the 297 sensitivity analysis method being used ! 298 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 299 PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jacp)); 300 if (sa == SA_TRACK) PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, 2)); 301 if (sa == SA_GLOBAL) PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, user.max_steps * 2)); 302 PetscCall(MatSetFromOptions(user.Jacp)); 303 PetscCall(MatSetUp(user.Jacp)); 304 305 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 306 Create timestepping solver context 307 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 308 PetscCall(TSCreate(PETSC_COMM_WORLD, &ts)); 309 PetscCall(TSSetEquationType(ts, TS_EQ_ODE_EXPLICIT)); 310 PetscCall(TSSetType(ts, TSCN)); 311 312 PetscCall(TSSetRHSFunction(ts, NULL, RHSFunction, &user)); 313 PetscCall(TSSetRHSJacobian(ts, user.A, user.A, RHSJacobian, &user)); 314 if (sa == SA_TRACK) PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_track, &user)); 315 if (sa == SA_GLOBAL) PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_global, &user)); 316 317 PetscCall(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_MATCHSTEP)); 318 PetscCall(TSSetMaxTime(ts, user.final_time)); 319 PetscCall(TSSetTimeStep(ts, user.final_time / user.max_steps)); 320 321 if (monitor) PetscCall(TSMonitorSet(ts, Monitor, &user, NULL)); 322 if (sa == SA_TRACK) PetscCall(TSAdjointMonitorSet(ts, AdjointMonitor, &user, NULL)); 323 324 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 325 Set initial conditions 326 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 327 PetscCall(VecGetArray(user.U, &x_ptr)); 328 x_ptr[0] = 2.0; 329 x_ptr[1] = -2.0 / 3.0; 330 PetscCall(VecRestoreArray(user.U, &x_ptr)); 331 332 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 333 Save trajectory of solution so that TSAdjointSolve() may be used 334 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 335 PetscCall(TSSetSaveTrajectory(ts)); 336 337 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 338 Set runtime options 339 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 340 PetscCall(TSSetFromOptions(ts)); 341 342 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 343 Execute forward model and print solution. 344 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 345 PetscCall(TSSolve(ts, user.U)); 346 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution of forward TS :\n")); 347 PetscCall(VecView(user.U, PETSC_VIEWER_STDOUT_WORLD)); 348 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Forward TS solve successful! Adjoint run begins!\n")); 349 350 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 351 Adjoint model starts here! Create adjoint vectors. 352 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 353 PetscCall(MatCreateVecs(user.A, &user.lambda, NULL)); 354 PetscCall(MatCreateVecs(user.Jacp, &user.mup, NULL)); 355 356 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 357 Set initial conditions for the adjoint vector 358 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 359 PetscCall(VecGetArray(user.U, &u_ptr)); 360 PetscCall(VecGetArray(user.lambda, &y_ptr)); 361 y_ptr[0] = 2 * (u_ptr[0] - 1.5967); 362 y_ptr[1] = 2 * (u_ptr[1] - -(1.02969)); 363 PetscCall(VecRestoreArray(user.lambda, &y_ptr)); 364 PetscCall(VecRestoreArray(user.U, &y_ptr)); 365 PetscCall(VecSet(user.mup, 0)); 366 367 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 368 Set number of cost functions. 369 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 370 PetscCall(TSSetCostGradients(ts, 1, &user.lambda, &user.mup)); 371 372 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 373 The adjoint vector mup has to be reset for each adjoint step when 374 using the tracking method as we want to treat the parameters at each 375 time step one at a time and prevent accumulation of the sensitivities 376 from parameters at previous time steps. 377 This is not necessary for the global method as each time dependent 378 parameter is treated as an independent parameter. 379 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 380 if (sa == SA_TRACK) { 381 for (user.adj_idx = user.max_steps; user.adj_idx > 0; user.adj_idx--) { 382 PetscCall(VecSet(user.mup, 0)); 383 PetscCall(TSAdjointSetSteps(ts, 1)); 384 PetscCall(TSAdjointSolve(ts)); 385 } 386 } 387 if (sa == SA_GLOBAL) PetscCall(TSAdjointSolve(ts)); 388 389 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 390 Display adjoint sensitivities wrt parameters and initial conditions 391 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 392 if (sa == SA_TRACK) { 393 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt mu1: d[cost]/d[mu1]\n")); 394 PetscCall(VecView(user.sens_mu1, PETSC_VIEWER_STDOUT_WORLD)); 395 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt mu2: d[cost]/d[mu2]\n")); 396 PetscCall(VecView(user.sens_mu2, PETSC_VIEWER_STDOUT_WORLD)); 397 } 398 399 if (sa == SA_GLOBAL) { 400 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt params: d[cost]/d[p], where p refers to \nthe interlaced vector made by combining mu1,mu2\n")); 401 PetscCall(VecView(user.mup, PETSC_VIEWER_STDOUT_WORLD)); 402 } 403 404 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt initial conditions: d[cost]/d[u(t=0)]\n")); 405 PetscCall(VecView(user.lambda, PETSC_VIEWER_STDOUT_WORLD)); 406 407 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 408 Free work space! 409 All PETSc objects should be destroyed when they are no longer needed. 410 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 411 PetscCall(MatDestroy(&user.A)); 412 PetscCall(MatDestroy(&user.Jacp)); 413 PetscCall(VecDestroy(&user.U)); 414 PetscCall(VecDestroy(&user.lambda)); 415 PetscCall(VecDestroy(&user.mup)); 416 PetscCall(VecDestroy(&user.mu1)); 417 PetscCall(VecDestroy(&user.mu2)); 418 if (sa == SA_TRACK) { 419 PetscCall(VecDestroy(&user.sens_mu1)); 420 PetscCall(VecDestroy(&user.sens_mu2)); 421 } 422 PetscCall(TSDestroy(&ts)); 423 PetscCall(PetscFinalize()); 424 return 0; 425 } 426 427 /*TEST 428 429 test: 430 requires: !complex 431 suffix : track 432 args : -sa_method track 433 434 test: 435 requires: !complex 436 suffix : global 437 args : -sa_method global 438 439 TEST*/ 440