1 static char help[] = "Performs adjoint sensitivity analysis for a van der Pol like \n\ 2 equation with time dependent parameters using two approaches : \n\ 3 track : track only local sensitivities at each adjoint step \n\ 4 and accumulate them in a global array \n\ 5 global : track parameters at all timesteps together \n\ 6 Choose one of the two at runtime by -sa_method {track,global}. \n"; 7 8 /* 9 Simple example to demonstrate TSAdjoint capabilities for time dependent params 10 without integral cost terms using either a tracking or global method. 11 12 Modify the Van Der Pol Eq to : 13 [u1'] = [mu1(t)*u1] 14 [u2'] = [mu2(t)*((1-u1^2)*u2-u1)] 15 (with initial conditions & params independent) 16 17 Define uref to be solution with initail conditions (2,-2/3), mu=(1,1e3) 18 - u_ref : (1.5967,-1.02969) 19 20 Define const function as cost = 2-norm(u - u_ref); 21 22 Initialization for the adjoint TS : 23 - dcost/dy|final_time = 2*(u-u_ref)|final_time 24 - dcost/dp|final_time = 0 25 26 The tracking method only tracks local sensitivity at each time step 27 and accumulates these sensitivities in a global array. Since the structure 28 of the equations being solved at each time step does not change, the jacobian 29 wrt parameters is defined analogous to constant RHSJacobian for a liner 30 TSSolve and the size of the jacP is independent of the number of time 31 steps. Enable this mode of adjoint analysis by -sa_method track. 32 33 The global method combines the parameters at all timesteps and tracks them 34 together. Thus, the columns of the jacP matrix are filled dependent upon the 35 time step. Also, the dimensions of the jacP matrix now depend upon the number 36 of time steps. Enable this mode of adjoint analysis by -sa_method global. 37 38 Since the equations here have parameters at predefined time steps, this 39 example should be run with non adaptive time stepping solvers only. This 40 can be ensured by -ts_adapt_type none (which is the default behavior only 41 for certain TS solvers like TSCN. If using an explicit TS like TSRK, 42 please be sure to add the aforementioned option to disable adaptive 43 timestepping.) 44 */ 45 46 /* 47 Include "petscts.h" so that we can use TS solvers. Note that this file 48 automatically includes: 49 petscsys.h - base PETSc routines petscvec.h - vectors 50 petscmat.h - matrices 51 petscis.h - index sets petscksp.h - Krylov subspace methods 52 petscviewer.h - viewers petscpc.h - preconditioners 53 petscksp.h - linear solvers petscsnes.h - nonlinear solvers 54 */ 55 #include <petscts.h> 56 57 extern PetscErrorCode RHSFunction(TS, PetscReal, Vec, Vec, void *); 58 extern PetscErrorCode RHSJacobian(TS, PetscReal, Vec, Mat, Mat, void *); 59 extern PetscErrorCode RHSJacobianP_track(TS, PetscReal, Vec, Mat, void *); 60 extern PetscErrorCode RHSJacobianP_global(TS, PetscReal, Vec, Mat, void *); 61 extern PetscErrorCode Monitor(TS, PetscInt, PetscReal, Vec, void *); 62 extern PetscErrorCode AdjointMonitor(TS, PetscInt, PetscReal, Vec, PetscInt, Vec *, Vec *, void *); 63 64 /* 65 User-defined application context - contains data needed by the 66 application-provided call-back routines. 67 */ 68 69 typedef struct { 70 /*------------- Forward solve data structures --------------*/ 71 PetscInt max_steps; /* number of steps to run ts for */ 72 PetscReal final_time; /* final time to integrate to*/ 73 PetscReal time_step; /* ts integration time step */ 74 Vec mu1; /* time dependent params */ 75 Vec mu2; /* time dependent params */ 76 Vec U; /* solution vector */ 77 Mat A; /* Jacobian matrix */ 78 79 /*------------- Adjoint solve data structures --------------*/ 80 Mat Jacp; /* JacobianP matrix */ 81 Vec lambda; /* adjoint variable */ 82 Vec mup; /* adjoint variable */ 83 84 /*------------- Global accumation vecs for monitor based tracking --------------*/ 85 Vec sens_mu1; /* global sensitivity accumulation */ 86 Vec sens_mu2; /* global sensitivity accumulation */ 87 PetscInt adj_idx; /* to keep track of adjoint solve index */ 88 } AppCtx; 89 90 typedef enum { 91 SA_TRACK, 92 SA_GLOBAL 93 } SAMethod; 94 static const char *const SAMethods[] = {"TRACK", "GLOBAL", "SAMethod", "SA_", 0}; 95 96 /* ----------------------- Explicit form of the ODE -------------------- */ 97 98 PetscErrorCode RHSFunction(TS ts, PetscReal t, Vec U, Vec F, void *ctx) { 99 AppCtx *user = (AppCtx *)ctx; 100 PetscScalar *f; 101 PetscInt curr_step; 102 const PetscScalar *u; 103 const PetscScalar *mu1; 104 const PetscScalar *mu2; 105 106 PetscFunctionBeginUser; 107 PetscCall(TSGetStepNumber(ts, &curr_step)); 108 PetscCall(VecGetArrayRead(U, &u)); 109 PetscCall(VecGetArrayRead(user->mu1, &mu1)); 110 PetscCall(VecGetArrayRead(user->mu2, &mu2)); 111 PetscCall(VecGetArray(F, &f)); 112 f[0] = mu1[curr_step] * u[1]; 113 f[1] = mu2[curr_step] * ((1. - u[0] * u[0]) * u[1] - u[0]); 114 PetscCall(VecRestoreArrayRead(U, &u)); 115 PetscCall(VecRestoreArrayRead(user->mu1, &mu1)); 116 PetscCall(VecRestoreArrayRead(user->mu2, &mu2)); 117 PetscCall(VecRestoreArray(F, &f)); 118 PetscFunctionReturn(0); 119 } 120 121 PetscErrorCode RHSJacobian(TS ts, PetscReal t, Vec U, Mat A, Mat B, void *ctx) { 122 AppCtx *user = (AppCtx *)ctx; 123 PetscInt rowcol[] = {0, 1}; 124 PetscScalar J[2][2]; 125 PetscInt curr_step; 126 const PetscScalar *u; 127 const PetscScalar *mu1; 128 const PetscScalar *mu2; 129 130 PetscFunctionBeginUser; 131 PetscCall(TSGetStepNumber(ts, &curr_step)); 132 PetscCall(VecGetArrayRead(user->mu1, &mu1)); 133 PetscCall(VecGetArrayRead(user->mu2, &mu2)); 134 PetscCall(VecGetArrayRead(U, &u)); 135 J[0][0] = 0; 136 J[1][0] = -mu2[curr_step] * (2.0 * u[1] * u[0] + 1.); 137 J[0][1] = mu1[curr_step]; 138 J[1][1] = mu2[curr_step] * (1.0 - u[0] * u[0]); 139 PetscCall(MatSetValues(A, 2, rowcol, 2, rowcol, &J[0][0], INSERT_VALUES)); 140 PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY)); 141 PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY)); 142 PetscCall(VecRestoreArrayRead(U, &u)); 143 PetscCall(VecRestoreArrayRead(user->mu1, &mu1)); 144 PetscCall(VecRestoreArrayRead(user->mu2, &mu2)); 145 PetscFunctionReturn(0); 146 } 147 148 /* ------------------ Jacobian wrt parameters for tracking method ------------------ */ 149 150 PetscErrorCode RHSJacobianP_track(TS ts, PetscReal t, Vec U, Mat A, void *ctx) { 151 PetscInt row[] = {0, 1}, col[] = {0, 1}; 152 PetscScalar J[2][2]; 153 const PetscScalar *u; 154 155 PetscFunctionBeginUser; 156 PetscCall(VecGetArrayRead(U, &u)); 157 J[0][0] = u[1]; 158 J[1][0] = 0; 159 J[0][1] = 0; 160 J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0]; 161 PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES)); 162 PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY)); 163 PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY)); 164 PetscCall(VecRestoreArrayRead(U, &u)); 165 PetscFunctionReturn(0); 166 } 167 168 /* ------------------ Jacobian wrt parameters for global method ------------------ */ 169 170 PetscErrorCode RHSJacobianP_global(TS ts, PetscReal t, Vec U, Mat A, void *ctx) { 171 PetscInt row[] = {0, 1}, col[] = {0, 1}; 172 PetscScalar J[2][2]; 173 const PetscScalar *u; 174 PetscInt curr_step; 175 176 PetscFunctionBeginUser; 177 PetscCall(TSGetStepNumber(ts, &curr_step)); 178 PetscCall(VecGetArrayRead(U, &u)); 179 J[0][0] = u[1]; 180 J[1][0] = 0; 181 J[0][1] = 0; 182 J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0]; 183 col[0] = (curr_step)*2; 184 col[1] = (curr_step)*2 + 1; 185 PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES)); 186 PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY)); 187 PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY)); 188 PetscCall(VecRestoreArrayRead(U, &u)); 189 PetscFunctionReturn(0); 190 } 191 192 /* Dump solution to console if called */ 193 PetscErrorCode Monitor(TS ts, PetscInt step, PetscReal t, Vec U, void *ctx) { 194 PetscFunctionBeginUser; 195 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution at time %e is \n", (double)t)); 196 PetscCall(VecView(U, PETSC_VIEWER_STDOUT_WORLD)); 197 PetscFunctionReturn(0); 198 } 199 200 /* Customized adjoint monitor to keep track of local 201 sensitivities by storing them in a global sensitivity array. 202 Note : This routine is only used for the tracking method. */ 203 PetscErrorCode AdjointMonitor(TS ts, PetscInt steps, PetscReal time, Vec u, PetscInt numcost, Vec *lambda, Vec *mu, void *ctx) { 204 AppCtx *user = (AppCtx *)ctx; 205 PetscInt curr_step; 206 PetscScalar *sensmu1_glob; 207 PetscScalar *sensmu2_glob; 208 const PetscScalar *sensmu_loc; 209 210 PetscFunctionBeginUser; 211 PetscCall(TSGetStepNumber(ts, &curr_step)); 212 /* Note that we skip the first call to the monitor in the adjoint 213 solve since the sensitivities are already set (during 214 initialization of adjoint vectors). 215 We also note that each indvidial TSAdjointSolve calls the monitor 216 twice, once at the step it is integrating from and once at the step 217 it integrates to. Only the second call is useful for transferring 218 local sensitivities to the global array. */ 219 if (curr_step == user->adj_idx) { 220 PetscFunctionReturn(0); 221 } else { 222 PetscCall(VecGetArrayRead(*mu, &sensmu_loc)); 223 PetscCall(VecGetArray(user->sens_mu1, &sensmu1_glob)); 224 PetscCall(VecGetArray(user->sens_mu2, &sensmu2_glob)); 225 sensmu1_glob[curr_step] = sensmu_loc[0]; 226 sensmu2_glob[curr_step] = sensmu_loc[1]; 227 PetscCall(VecRestoreArray(user->sens_mu1, &sensmu1_glob)); 228 PetscCall(VecRestoreArray(user->sens_mu2, &sensmu2_glob)); 229 PetscCall(VecRestoreArrayRead(*mu, &sensmu_loc)); 230 PetscFunctionReturn(0); 231 } 232 } 233 234 int main(int argc, char **argv) { 235 TS ts; 236 AppCtx user; 237 PetscScalar *x_ptr, *y_ptr, *u_ptr; 238 PetscMPIInt size; 239 PetscBool monitor = PETSC_FALSE; 240 SAMethod sa = SA_GLOBAL; 241 242 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 243 Initialize program 244 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 245 PetscFunctionBeginUser; 246 PetscCall(PetscInitialize(&argc, &argv, NULL, help)); 247 PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size)); 248 PetscCheck(size == 1, PETSC_COMM_WORLD, PETSC_ERR_WRONG_MPI_SIZE, "This is a uniprocessor example only!"); 249 250 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 251 Set runtime options 252 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 253 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "SA analysis options.", ""); 254 { 255 PetscCall(PetscOptionsGetBool(NULL, NULL, "-monitor", &monitor, NULL)); 256 PetscCall(PetscOptionsEnum("-sa_method", "Sensitivity analysis method (track or global)", "", SAMethods, (PetscEnum)sa, (PetscEnum *)&sa, NULL)); 257 } 258 PetscOptionsEnd(); 259 260 user.final_time = 0.1; 261 user.max_steps = 5; 262 user.time_step = user.final_time / user.max_steps; 263 264 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 265 Create necessary matrix and vectors for forward solve. 266 Create Jacp matrix for adjoint solve. 267 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 268 PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu1)); 269 PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu2)); 270 PetscCall(VecSet(user.mu1, 1.25)); 271 PetscCall(VecSet(user.mu2, 1.0e2)); 272 273 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 274 For tracking method : create the global sensitivity array to 275 accumulate sensitivity with respect to parameters at each step. 276 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 277 if (sa == SA_TRACK) { 278 PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu1)); 279 PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu2)); 280 } 281 282 PetscCall(MatCreate(PETSC_COMM_WORLD, &user.A)); 283 PetscCall(MatSetSizes(user.A, PETSC_DECIDE, PETSC_DECIDE, 2, 2)); 284 PetscCall(MatSetFromOptions(user.A)); 285 PetscCall(MatSetUp(user.A)); 286 PetscCall(MatCreateVecs(user.A, &user.U, NULL)); 287 288 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 289 Note that the dimensions of the Jacp matrix depend upon the 290 sensitivity analysis method being used ! 291 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 292 PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jacp)); 293 if (sa == SA_TRACK) { PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, 2)); } 294 if (sa == SA_GLOBAL) { PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, user.max_steps * 2)); } 295 PetscCall(MatSetFromOptions(user.Jacp)); 296 PetscCall(MatSetUp(user.Jacp)); 297 298 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 299 Create timestepping solver context 300 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 301 PetscCall(TSCreate(PETSC_COMM_WORLD, &ts)); 302 PetscCall(TSSetEquationType(ts, TS_EQ_ODE_EXPLICIT)); 303 PetscCall(TSSetType(ts, TSCN)); 304 305 PetscCall(TSSetRHSFunction(ts, NULL, RHSFunction, &user)); 306 PetscCall(TSSetRHSJacobian(ts, user.A, user.A, RHSJacobian, &user)); 307 if (sa == SA_TRACK) { PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_track, &user)); } 308 if (sa == SA_GLOBAL) { PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_global, &user)); } 309 310 PetscCall(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_MATCHSTEP)); 311 PetscCall(TSSetMaxTime(ts, user.final_time)); 312 PetscCall(TSSetTimeStep(ts, user.final_time / user.max_steps)); 313 314 if (monitor) { PetscCall(TSMonitorSet(ts, Monitor, &user, NULL)); } 315 if (sa == SA_TRACK) { PetscCall(TSAdjointMonitorSet(ts, AdjointMonitor, &user, NULL)); } 316 317 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 318 Set initial conditions 319 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 320 PetscCall(VecGetArray(user.U, &x_ptr)); 321 x_ptr[0] = 2.0; 322 x_ptr[1] = -2.0 / 3.0; 323 PetscCall(VecRestoreArray(user.U, &x_ptr)); 324 325 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 326 Save trajectory of solution so that TSAdjointSolve() may be used 327 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 328 PetscCall(TSSetSaveTrajectory(ts)); 329 330 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 331 Set runtime options 332 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 333 PetscCall(TSSetFromOptions(ts)); 334 335 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 336 Execute forward model and print solution. 337 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 338 PetscCall(TSSolve(ts, user.U)); 339 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution of forward TS :\n")); 340 PetscCall(VecView(user.U, PETSC_VIEWER_STDOUT_WORLD)); 341 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Forward TS solve successful! Adjoint run begins!\n")); 342 343 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 344 Adjoint model starts here! Create adjoint vectors. 345 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 346 PetscCall(MatCreateVecs(user.A, &user.lambda, NULL)); 347 PetscCall(MatCreateVecs(user.Jacp, &user.mup, NULL)); 348 349 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 350 Set initial conditions for the adjoint vector 351 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 352 PetscCall(VecGetArray(user.U, &u_ptr)); 353 PetscCall(VecGetArray(user.lambda, &y_ptr)); 354 y_ptr[0] = 2 * (u_ptr[0] - 1.5967); 355 y_ptr[1] = 2 * (u_ptr[1] - -(1.02969)); 356 PetscCall(VecRestoreArray(user.lambda, &y_ptr)); 357 PetscCall(VecRestoreArray(user.U, &y_ptr)); 358 PetscCall(VecSet(user.mup, 0)); 359 360 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 361 Set number of cost functions. 362 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 363 PetscCall(TSSetCostGradients(ts, 1, &user.lambda, &user.mup)); 364 365 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 366 The adjoint vector mup has to be reset for each adjoint step when 367 using the tracking method as we want to treat the parameters at each 368 time step one at a time and prevent accumulation of the sensitivities 369 from parameters at previous time steps. 370 This is not necessary for the global method as each time dependent 371 parameter is treated as an independent parameter. 372 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 373 if (sa == SA_TRACK) { 374 for (user.adj_idx = user.max_steps; user.adj_idx > 0; user.adj_idx--) { 375 PetscCall(VecSet(user.mup, 0)); 376 PetscCall(TSAdjointSetSteps(ts, 1)); 377 PetscCall(TSAdjointSolve(ts)); 378 } 379 } 380 if (sa == SA_GLOBAL) { PetscCall(TSAdjointSolve(ts)); } 381 382 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 383 Dispaly adjoint sensitivities wrt parameters and initial conditions 384 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 385 if (sa == SA_TRACK) { 386 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt mu1: d[cost]/d[mu1]\n")); 387 PetscCall(VecView(user.sens_mu1, PETSC_VIEWER_STDOUT_WORLD)); 388 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt mu2: d[cost]/d[mu2]\n")); 389 PetscCall(VecView(user.sens_mu2, PETSC_VIEWER_STDOUT_WORLD)); 390 } 391 392 if (sa == SA_GLOBAL) { 393 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt params: d[cost]/d[p], where p refers to \nthe interlaced vector made by combining mu1,mu2\n")); 394 PetscCall(VecView(user.mup, PETSC_VIEWER_STDOUT_WORLD)); 395 } 396 397 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt initial conditions: d[cost]/d[u(t=0)]\n")); 398 PetscCall(VecView(user.lambda, PETSC_VIEWER_STDOUT_WORLD)); 399 400 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 401 Free work space! 402 All PETSc objects should be destroyed when they are no longer needed. 403 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 404 PetscCall(MatDestroy(&user.A)); 405 PetscCall(MatDestroy(&user.Jacp)); 406 PetscCall(VecDestroy(&user.U)); 407 PetscCall(VecDestroy(&user.lambda)); 408 PetscCall(VecDestroy(&user.mup)); 409 PetscCall(VecDestroy(&user.mu1)); 410 PetscCall(VecDestroy(&user.mu2)); 411 if (sa == SA_TRACK) { 412 PetscCall(VecDestroy(&user.sens_mu1)); 413 PetscCall(VecDestroy(&user.sens_mu2)); 414 } 415 PetscCall(TSDestroy(&ts)); 416 PetscCall(PetscFinalize()); 417 return (0); 418 } 419 420 /*TEST 421 422 test: 423 requires: !complex 424 suffix : track 425 args : -sa_method track 426 427 test: 428 requires: !complex 429 suffix : global 430 args : -sa_method global 431 432 TEST*/ 433