1 static char help[] = "Performs adjoint sensitivity analysis for a van der Pol like \n\ 2 equation with time dependent parameters using two approaches : \n\ 3 track : track only local sensitivities at each adjoint step \n\ 4 and accumulate them in a global array \n\ 5 global : track parameters at all timesteps together \n\ 6 Choose one of the two at runtime by -sa_method {track,global}. \n"; 7 8 /* 9 Simple example to demonstrate TSAdjoint capabilities for time dependent params 10 without integral cost terms using either a tracking or global method. 11 12 Modify the Van Der Pol Eq to : 13 [u1'] = [mu1(t)*u1] 14 [u2'] = [mu2(t)*((1-u1^2)*u2-u1)] 15 (with initial conditions & params independent) 16 17 Define uref to be solution with initail conditions (2,-2/3), mu=(1,1e3) 18 - u_ref : (1.5967,-1.02969) 19 20 Define const function as cost = 2-norm(u - u_ref); 21 22 Initialization for the adjoint TS : 23 - dcost/dy|final_time = 2*(u-u_ref)|final_time 24 - dcost/dp|final_time = 0 25 26 The tracking method only tracks local sensitivity at each time step 27 and accumulates these sensitivities in a global array. Since the structure 28 of the equations being solved at each time step does not change, the jacobian 29 wrt parameters is defined analogous to constant RHSJacobian for a liner 30 TSSolve and the size of the jacP is independent of the number of time 31 steps. Enable this mode of adjoint analysis by -sa_method track. 32 33 The global method combines the parameters at all timesteps and tracks them 34 together. Thus, the columns of the jacP matrix are filled dependent upon the 35 time step. Also, the dimensions of the jacP matrix now depend upon the number 36 of time steps. Enable this mode of adjoint analysis by -sa_method global. 37 38 Since the equations here have parameters at predefined time steps, this 39 example should be run with non adaptive time stepping solvers only. This 40 can be ensured by -ts_adapt_type none (which is the default behavior only 41 for certain TS solvers like TSCN. If using an explicit TS like TSRK, 42 please be sure to add the aforementioned option to disable adaptive 43 timestepping.) 44 */ 45 46 /* 47 Include "petscts.h" so that we can use TS solvers. Note that this file 48 automatically includes: 49 petscsys.h - base PETSc routines petscvec.h - vectors 50 petscmat.h - matrices 51 petscis.h - index sets petscksp.h - Krylov subspace methods 52 petscviewer.h - viewers petscpc.h - preconditioners 53 petscksp.h - linear solvers petscsnes.h - nonlinear solvers 54 */ 55 #include <petscts.h> 56 57 extern PetscErrorCode RHSFunction(TS ,PetscReal ,Vec ,Vec ,void *); 58 extern PetscErrorCode RHSJacobian(TS ,PetscReal ,Vec ,Mat ,Mat ,void *); 59 extern PetscErrorCode RHSJacobianP_track(TS ,PetscReal ,Vec ,Mat ,void *); 60 extern PetscErrorCode RHSJacobianP_global(TS ,PetscReal ,Vec ,Mat ,void *); 61 extern PetscErrorCode Monitor(TS ,PetscInt ,PetscReal ,Vec ,void *); 62 extern PetscErrorCode AdjointMonitor(TS ,PetscInt ,PetscReal ,Vec ,PetscInt ,Vec *, Vec *,void *); 63 64 /* 65 User-defined application context - contains data needed by the 66 application-provided call-back routines. 67 */ 68 69 typedef struct { 70 /*------------- Forward solve data structures --------------*/ 71 PetscInt max_steps; /* number of steps to run ts for */ 72 PetscReal final_time; /* final time to integrate to*/ 73 PetscReal time_step; /* ts integration time step */ 74 Vec mu1; /* time dependent params */ 75 Vec mu2; /* time dependent params */ 76 Vec U; /* solution vector */ 77 Mat A; /* Jacobian matrix */ 78 79 /*------------- Adjoint solve data structures --------------*/ 80 Mat Jacp; /* JacobianP matrix */ 81 Vec lambda; /* adjoint variable */ 82 Vec mup; /* adjoint variable */ 83 84 /*------------- Global accumation vecs for monitor based tracking --------------*/ 85 Vec sens_mu1; /* global sensitivity accumulation */ 86 Vec sens_mu2; /* global sensitivity accumulation */ 87 PetscInt adj_idx; /* to keep track of adjoint solve index */ 88 } AppCtx; 89 90 typedef enum {SA_TRACK, SA_GLOBAL} SAMethod; 91 static const char *const SAMethods[] = {"TRACK","GLOBAL","SAMethod","SA_",0}; 92 93 /* ----------------------- Explicit form of the ODE -------------------- */ 94 95 PetscErrorCode RHSFunction(TS ts,PetscReal t,Vec U,Vec F,void *ctx) 96 { 97 AppCtx *user = (AppCtx*) ctx; 98 PetscScalar *f; 99 PetscInt curr_step; 100 const PetscScalar *u; 101 const PetscScalar *mu1; 102 const PetscScalar *mu2; 103 104 PetscFunctionBeginUser; 105 PetscCall(TSGetStepNumber(ts,&curr_step)); 106 PetscCall(VecGetArrayRead(U,&u)); 107 PetscCall(VecGetArrayRead(user->mu1,&mu1)); 108 PetscCall(VecGetArrayRead(user->mu2,&mu2)); 109 PetscCall(VecGetArray(F,&f)); 110 f[0] = mu1[curr_step]*u[1]; 111 f[1] = mu2[curr_step]*((1.-u[0]*u[0])*u[1]-u[0]); 112 PetscCall(VecRestoreArrayRead(U,&u)); 113 PetscCall(VecRestoreArrayRead(user->mu1,&mu1)); 114 PetscCall(VecRestoreArrayRead(user->mu2,&mu2)); 115 PetscCall(VecRestoreArray(F,&f)); 116 PetscFunctionReturn(0); 117 } 118 119 PetscErrorCode RHSJacobian(TS ts,PetscReal t,Vec U,Mat A,Mat B,void *ctx) 120 { 121 AppCtx *user = (AppCtx*) ctx; 122 PetscInt rowcol[] = {0,1}; 123 PetscScalar J[2][2]; 124 PetscInt curr_step; 125 const PetscScalar *u; 126 const PetscScalar *mu1; 127 const PetscScalar *mu2; 128 129 PetscFunctionBeginUser; 130 PetscCall(TSGetStepNumber(ts,&curr_step)); 131 PetscCall(VecGetArrayRead(user->mu1,&mu1)); 132 PetscCall(VecGetArrayRead(user->mu2,&mu2)); 133 PetscCall(VecGetArrayRead(U,&u)); 134 J[0][0] = 0; 135 J[1][0] = -mu2[curr_step]*(2.0*u[1]*u[0]+1.); 136 J[0][1] = mu1[curr_step]; 137 J[1][1] = mu2[curr_step]*(1.0-u[0]*u[0]); 138 PetscCall(MatSetValues(A,2,rowcol,2,rowcol,&J[0][0],INSERT_VALUES)); 139 PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY)); 140 PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY)); 141 PetscCall(VecRestoreArrayRead(U,&u)); 142 PetscCall(VecRestoreArrayRead(user->mu1,&mu1)); 143 PetscCall(VecRestoreArrayRead(user->mu2,&mu2)); 144 PetscFunctionReturn(0); 145 } 146 147 /* ------------------ Jacobian wrt parameters for tracking method ------------------ */ 148 149 PetscErrorCode RHSJacobianP_track(TS ts,PetscReal t,Vec U,Mat A,void *ctx) 150 { 151 PetscInt row[] = {0,1},col[] = {0,1}; 152 PetscScalar J[2][2]; 153 const PetscScalar *u; 154 155 PetscFunctionBeginUser; 156 PetscCall(VecGetArrayRead(U,&u)); 157 J[0][0] = u[1]; 158 J[1][0] = 0; 159 J[0][1] = 0; 160 J[1][1] = (1.-u[0]*u[0])*u[1]-u[0]; 161 PetscCall(MatSetValues(A,2,row,2,col,&J[0][0],INSERT_VALUES)); 162 PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY)); 163 PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY)); 164 PetscCall(VecRestoreArrayRead(U,&u)); 165 PetscFunctionReturn(0); 166 } 167 168 /* ------------------ Jacobian wrt parameters for global method ------------------ */ 169 170 PetscErrorCode RHSJacobianP_global(TS ts,PetscReal t,Vec U,Mat A,void *ctx) 171 { 172 PetscInt row[] = {0,1},col[] = {0,1}; 173 PetscScalar J[2][2]; 174 const PetscScalar *u; 175 PetscInt curr_step; 176 177 PetscFunctionBeginUser; 178 PetscCall(TSGetStepNumber(ts,&curr_step)); 179 PetscCall(VecGetArrayRead(U,&u)); 180 J[0][0] = u[1]; 181 J[1][0] = 0; 182 J[0][1] = 0; 183 J[1][1] = (1.-u[0]*u[0])*u[1]-u[0]; 184 col[0] = (curr_step)*2; 185 col[1] = (curr_step)*2+1; 186 PetscCall(MatSetValues(A,2,row,2,col,&J[0][0],INSERT_VALUES)); 187 PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY)); 188 PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY)); 189 PetscCall(VecRestoreArrayRead(U,&u)); 190 PetscFunctionReturn(0); 191 } 192 193 /* Dump solution to console if called */ 194 PetscErrorCode Monitor(TS ts,PetscInt step,PetscReal t,Vec U,void *ctx) 195 { 196 PetscFunctionBeginUser; 197 PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n Solution at time %e is \n", t)); 198 PetscCall(VecView(U,PETSC_VIEWER_STDOUT_WORLD)); 199 PetscFunctionReturn(0); 200 } 201 202 /* Customized adjoint monitor to keep track of local 203 sensitivities by storing them in a global sensitivity array. 204 Note : This routine is only used for the tracking method. */ 205 PetscErrorCode AdjointMonitor(TS ts,PetscInt steps,PetscReal time,Vec u,PetscInt numcost,Vec *lambda, Vec *mu,void *ctx) 206 { 207 AppCtx *user = (AppCtx*) ctx; 208 PetscInt curr_step; 209 PetscScalar *sensmu1_glob; 210 PetscScalar *sensmu2_glob; 211 const PetscScalar *sensmu_loc; 212 213 PetscFunctionBeginUser; 214 PetscCall(TSGetStepNumber(ts,&curr_step)); 215 /* Note that we skip the first call to the monitor in the adjoint 216 solve since the sensitivities are already set (during 217 initialization of adjoint vectors). 218 We also note that each indvidial TSAdjointSolve calls the monitor 219 twice, once at the step it is integrating from and once at the step 220 it integrates to. Only the second call is useful for transferring 221 local sensitivities to the global array. */ 222 if (curr_step == user->adj_idx) { 223 PetscFunctionReturn(0); 224 } else { 225 PetscCall(VecGetArrayRead(*mu,&sensmu_loc)); 226 PetscCall(VecGetArray(user->sens_mu1,&sensmu1_glob)); 227 PetscCall(VecGetArray(user->sens_mu2,&sensmu2_glob)); 228 sensmu1_glob[curr_step] = sensmu_loc[0]; 229 sensmu2_glob[curr_step] = sensmu_loc[1]; 230 PetscCall(VecRestoreArray(user->sens_mu1,&sensmu1_glob)); 231 PetscCall(VecRestoreArray(user->sens_mu2,&sensmu2_glob)); 232 PetscCall(VecRestoreArrayRead(*mu,&sensmu_loc)); 233 PetscFunctionReturn(0); 234 } 235 } 236 237 int main(int argc,char **argv) 238 { 239 TS ts; 240 AppCtx user; 241 PetscScalar *x_ptr,*y_ptr,*u_ptr; 242 PetscMPIInt size; 243 PetscBool monitor = PETSC_FALSE; 244 SAMethod sa = SA_GLOBAL; 245 PetscErrorCode ierr; 246 247 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 248 Initialize program 249 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 250 PetscCall(PetscInitialize(&argc,&argv,NULL,help)); 251 PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD,&size)); 252 PetscCheck(size == 1,PETSC_COMM_WORLD,PETSC_ERR_WRONG_MPI_SIZE,"This is a uniprocessor example only!"); 253 254 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 255 Set runtime options 256 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 257 ierr = PetscOptionsBegin(PETSC_COMM_WORLD,NULL,"SA analysis options.","");PetscCall(ierr);{ 258 PetscCall(PetscOptionsGetBool(NULL,NULL,"-monitor",&monitor,NULL)); 259 PetscCall(PetscOptionsEnum("-sa_method","Sensitivity analysis method (track or global)","",SAMethods,(PetscEnum)sa,(PetscEnum*)&sa,NULL)); 260 } 261 ierr = PetscOptionsEnd();PetscCall(ierr); 262 263 user.final_time = 0.1; 264 user.max_steps = 5; 265 user.time_step = user.final_time/user.max_steps; 266 267 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 268 Create necessary matrix and vectors for forward solve. 269 Create Jacp matrix for adjoint solve. 270 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 271 PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.mu1)); 272 PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.mu2)); 273 PetscCall(VecSet(user.mu1,1.25)); 274 PetscCall(VecSet(user.mu2,1.0e2)); 275 276 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 277 For tracking method : create the global sensitivity array to 278 accumulate sensitivity with respect to parameters at each step. 279 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 280 if (sa == SA_TRACK) { 281 PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.sens_mu1)); 282 PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.sens_mu2)); 283 } 284 285 PetscCall(MatCreate(PETSC_COMM_WORLD,&user.A)); 286 PetscCall(MatSetSizes(user.A,PETSC_DECIDE,PETSC_DECIDE,2,2)); 287 PetscCall(MatSetFromOptions(user.A)); 288 PetscCall(MatSetUp(user.A)); 289 PetscCall(MatCreateVecs(user.A,&user.U,NULL)); 290 291 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 292 Note that the dimensions of the Jacp matrix depend upon the 293 sensitivity analysis method being used ! 294 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 295 PetscCall(MatCreate(PETSC_COMM_WORLD,&user.Jacp)); 296 if (sa == SA_TRACK) { 297 PetscCall(MatSetSizes(user.Jacp,PETSC_DECIDE,PETSC_DECIDE,2,2)); 298 } 299 if (sa == SA_GLOBAL) { 300 PetscCall(MatSetSizes(user.Jacp,PETSC_DECIDE,PETSC_DECIDE,2,user.max_steps*2)); 301 } 302 PetscCall(MatSetFromOptions(user.Jacp)); 303 PetscCall(MatSetUp(user.Jacp)); 304 305 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 306 Create timestepping solver context 307 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 308 PetscCall(TSCreate(PETSC_COMM_WORLD,&ts)); 309 PetscCall(TSSetEquationType(ts,TS_EQ_ODE_EXPLICIT)); 310 PetscCall(TSSetType(ts,TSCN)); 311 312 PetscCall(TSSetRHSFunction(ts,NULL,RHSFunction,&user)); 313 PetscCall(TSSetRHSJacobian(ts,user.A,user.A,RHSJacobian,&user)); 314 if (sa == SA_TRACK) { 315 PetscCall(TSSetRHSJacobianP(ts,user.Jacp,RHSJacobianP_track,&user)); 316 } 317 if (sa == SA_GLOBAL) { 318 PetscCall(TSSetRHSJacobianP(ts,user.Jacp,RHSJacobianP_global,&user)); 319 } 320 321 PetscCall(TSSetExactFinalTime(ts,TS_EXACTFINALTIME_MATCHSTEP)); 322 PetscCall(TSSetMaxTime(ts,user.final_time)); 323 PetscCall(TSSetTimeStep(ts,user.final_time/user.max_steps)); 324 325 if (monitor) { 326 PetscCall(TSMonitorSet(ts,Monitor,&user,NULL)); 327 } 328 if (sa == SA_TRACK) { 329 PetscCall(TSAdjointMonitorSet(ts,AdjointMonitor,&user,NULL)); 330 } 331 332 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 333 Set initial conditions 334 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 335 PetscCall(VecGetArray(user.U,&x_ptr)); 336 x_ptr[0] = 2.0; 337 x_ptr[1] = -2.0/3.0; 338 PetscCall(VecRestoreArray(user.U,&x_ptr)); 339 340 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 341 Save trajectory of solution so that TSAdjointSolve() may be used 342 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 343 PetscCall(TSSetSaveTrajectory(ts)); 344 345 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 346 Set runtime options 347 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 348 PetscCall(TSSetFromOptions(ts)); 349 350 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 351 Execute forward model and print solution. 352 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 353 PetscCall(TSSolve(ts,user.U)); 354 PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n Solution of forward TS :\n")); 355 PetscCall(VecView(user.U,PETSC_VIEWER_STDOUT_WORLD)); 356 PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n Forward TS solve successful! Adjoint run begins!\n")); 357 358 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 359 Adjoint model starts here! Create adjoint vectors. 360 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 361 PetscCall(MatCreateVecs(user.A,&user.lambda,NULL)); 362 PetscCall(MatCreateVecs(user.Jacp,&user.mup,NULL)); 363 364 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 365 Set initial conditions for the adjoint vector 366 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 367 PetscCall(VecGetArray(user.U,&u_ptr)); 368 PetscCall(VecGetArray(user.lambda,&y_ptr)); 369 y_ptr[0] = 2*(u_ptr[0] - 1.5967); 370 y_ptr[1] = 2*(u_ptr[1] - -(1.02969)); 371 PetscCall(VecRestoreArray(user.lambda,&y_ptr)); 372 PetscCall(VecRestoreArray(user.U,&y_ptr)); 373 PetscCall(VecSet(user.mup,0)); 374 375 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 376 Set number of cost functions. 377 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 378 PetscCall(TSSetCostGradients(ts,1,&user.lambda,&user.mup)); 379 380 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 381 The adjoint vector mup has to be reset for each adjoint step when 382 using the tracking method as we want to treat the parameters at each 383 time step one at a time and prevent accumulation of the sensitivities 384 from parameters at previous time steps. 385 This is not necessary for the global method as each time dependent 386 parameter is treated as an independent parameter. 387 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 388 if (sa == SA_TRACK) { 389 for (user.adj_idx=user.max_steps; user.adj_idx>0; user.adj_idx--) { 390 PetscCall(VecSet(user.mup,0)); 391 PetscCall(TSAdjointSetSteps(ts, 1)); 392 PetscCall(TSAdjointSolve(ts)); 393 } 394 } 395 if (sa == SA_GLOBAL) { 396 PetscCall(TSAdjointSolve(ts)); 397 } 398 399 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 400 Dispaly adjoint sensitivities wrt parameters and initial conditions 401 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 402 if (sa == SA_TRACK) { 403 PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt mu1: d[cost]/d[mu1]\n")); 404 PetscCall(VecView(user.sens_mu1,PETSC_VIEWER_STDOUT_WORLD)); 405 PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt mu2: d[cost]/d[mu2]\n")); 406 PetscCall(VecView(user.sens_mu2,PETSC_VIEWER_STDOUT_WORLD)); 407 } 408 409 if (sa == SA_GLOBAL) { 410 ierr = PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt params: d[cost]/d[p], where p refers to \n\ 411 the interlaced vector made by combining mu1,mu2\n");PetscCall(ierr); 412 PetscCall(VecView(user.mup,PETSC_VIEWER_STDOUT_WORLD)); 413 } 414 415 PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt initial conditions: d[cost]/d[u(t=0)]\n")); 416 PetscCall(VecView(user.lambda,PETSC_VIEWER_STDOUT_WORLD)); 417 418 /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 419 Free work space! 420 All PETSc objects should be destroyed when they are no longer needed. 421 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ 422 PetscCall(MatDestroy(&user.A)); 423 PetscCall(MatDestroy(&user.Jacp)); 424 PetscCall(VecDestroy(&user.U)); 425 PetscCall(VecDestroy(&user.lambda)); 426 PetscCall(VecDestroy(&user.mup)); 427 PetscCall(VecDestroy(&user.mu1)); 428 PetscCall(VecDestroy(&user.mu2)); 429 if (sa == SA_TRACK) { 430 PetscCall(VecDestroy(&user.sens_mu1)); 431 PetscCall(VecDestroy(&user.sens_mu2)); 432 } 433 PetscCall(TSDestroy(&ts)); 434 PetscCall(PetscFinalize()); 435 return(ierr); 436 } 437 438 /*TEST 439 440 test: 441 requires: !complex 442 suffix : track 443 args : -sa_method track 444 445 test: 446 requires: !complex 447 suffix : global 448 args : -sa_method global 449 450 TEST*/ 451