xref: /petsc/src/ts/tutorials/ex20td.c (revision f97672e55eacc8688507b9471cd7ec2664d7f203)
/* Help text handed to PetscInitialize() in main(). */
static char help[] =
  "Performs adjoint sensitivity analysis for a van der Pol like \n"
  "equation with time dependent parameters using two approaches :  \n"
  "track  : track only local sensitivities at each adjoint step \n"
  "         and accumulate them in a global array \n"
  "global : track parameters at all timesteps together \n"
  "Choose one of the two at runtime by -sa_method {track,global}. \n";
7 
8 /*
9    Simple example to demonstrate TSAdjoint capabilities for time dependent params
10    without integral cost terms using either a tracking or global method.
11 
12    Modify the Van Der Pol Eq to :
   [u1'] = [mu1(t)*u2]
14    [u2'] = [mu2(t)*((1-u1^2)*u2-u1)]
15    (with initial conditions & params independent)
16 
   Define uref to be solution with initial conditions (2,-2/3), mu=(1,1e3)
18    - u_ref : (1.5967,-1.02969)
19 
   Define the cost function as cost = (2-norm(u - u_ref))^2;
21 
22    Initialization for the adjoint TS :
23    - dcost/dy|final_time = 2*(u-u_ref)|final_time
24    - dcost/dp|final_time = 0
25 
26    The tracking method only tracks local sensitivity at each time step
27    and accumulates these sensitivities in a global array. Since the structure
28    of the equations being solved at each time step does not change, the jacobian
   wrt parameters is defined analogous to constant RHSJacobian for a linear
30    TSSolve and the size of the jacP is independent of the number of time
31    steps. Enable this mode of adjoint analysis by -sa_method track.
32 
33    The global method combines the parameters at all timesteps and tracks them
34    together. Thus, the columns of the jacP matrix are filled dependent upon the
35    time step. Also, the dimensions of the jacP matrix now depend upon the number
36    of time steps. Enable this mode of adjoint analysis by -sa_method global.
37 
38    Since the equations here have parameters at predefined time steps, this
39    example should be run with non adaptive time stepping solvers only. This
40    can be ensured by -ts_adapt_type none (which is the default behavior only
41    for certain TS solvers like TSCN. If using an explicit TS like TSRK,
42    please be sure to add the aforementioned option to disable adaptive
43    timestepping.)
44 */
45 
46 /*
47    Include "petscts.h" so that we can use TS solvers.  Note that this file
48    automatically includes:
49      petscsys.h    - base PETSc routines   petscvec.h  - vectors
50      petscmat.h    - matrices
51      petscis.h     - index sets            petscksp.h  - Krylov subspace methods
52      petscviewer.h - viewers               petscpc.h   - preconditioners
53      petscksp.h    - linear solvers        petscsnes.h - nonlinear solvers
54 */
55 #include <petscts.h>
56 
57 extern PetscErrorCode RHSFunction(TS ,PetscReal ,Vec ,Vec ,void *);
58 extern PetscErrorCode RHSJacobian(TS ,PetscReal ,Vec ,Mat ,Mat ,void *);
59 extern PetscErrorCode RHSJacobianP_track(TS ,PetscReal ,Vec ,Mat ,void *);
60 extern PetscErrorCode RHSJacobianP_global(TS ,PetscReal ,Vec ,Mat ,void *);
61 extern PetscErrorCode Monitor(TS ,PetscInt ,PetscReal ,Vec ,void *);
62 extern PetscErrorCode AdjointMonitor(TS ,PetscInt ,PetscReal ,Vec ,PetscInt ,Vec *, Vec *,void *);
63 
64 /*
65    User-defined application context - contains data needed by the
66    application-provided call-back routines.
67 */
68 
69 typedef struct {
70  /*------------- Forward solve data structures --------------*/
71   PetscInt  max_steps;     /* number of steps to run ts for */
72   PetscReal final_time;    /* final time to integrate to*/
73   PetscReal time_step;     /* ts integration time step */
74   Vec       mu1;           /* time dependent params */
75   Vec       mu2;           /* time dependent params */
76   Vec       U;             /* solution vector */
77   Mat       A;             /* Jacobian matrix */
78 
79   /*------------- Adjoint solve data structures --------------*/
80   Mat       Jacp;          /* JacobianP matrix */
81   Vec       lambda;        /* adjoint variable */
82   Vec       mup;           /* adjoint variable */
83 
84   /*------------- Global accumation vecs for monitor based tracking --------------*/
85   Vec       sens_mu1;      /* global sensitivity accumulation */
86   Vec       sens_mu2;      /* global sensitivity accumulation */
87   PetscInt  adj_idx;       /* to keep track of adjoint solve index */
88 } AppCtx;
89 
/* Sensitivity analysis strategy, chosen at runtime with -sa_method. */
typedef enum {
  SA_TRACK,
  SA_GLOBAL
} SAMethod;

/* String table passed to PetscOptionsEnum(): the value names in enum order,
   followed by the enum type name and the value prefix, ended by a null pointer. */
static const char *const SAMethods[] = {
  "TRACK",
  "GLOBAL",
  "SAMethod",
  "SA_",
  NULL
};
92 
93 /* ----------------------- Explicit form of the ODE  -------------------- */
94 
95 PetscErrorCode RHSFunction(TS ts,PetscReal t,Vec U,Vec F,void *ctx)
96 {
97   AppCtx            *user = (AppCtx*) ctx;
98   PetscScalar       *f;
99   PetscInt          curr_step;
100   const PetscScalar *u;
101   const PetscScalar *mu1;
102   const PetscScalar *mu2;
103 
104   PetscFunctionBeginUser;
105   PetscCall(TSGetStepNumber(ts,&curr_step));
106   PetscCall(VecGetArrayRead(U,&u));
107   PetscCall(VecGetArrayRead(user->mu1,&mu1));
108   PetscCall(VecGetArrayRead(user->mu2,&mu2));
109   PetscCall(VecGetArray(F,&f));
110   f[0] = mu1[curr_step]*u[1];
111   f[1] = mu2[curr_step]*((1.-u[0]*u[0])*u[1]-u[0]);
112   PetscCall(VecRestoreArrayRead(U,&u));
113   PetscCall(VecRestoreArrayRead(user->mu1,&mu1));
114   PetscCall(VecRestoreArrayRead(user->mu2,&mu2));
115   PetscCall(VecRestoreArray(F,&f));
116   PetscFunctionReturn(0);
117 }
118 
119 PetscErrorCode RHSJacobian(TS ts,PetscReal t,Vec U,Mat A,Mat B,void *ctx)
120 {
121   AppCtx            *user = (AppCtx*) ctx;
122   PetscInt          rowcol[] = {0,1};
123   PetscScalar       J[2][2];
124   PetscInt          curr_step;
125   const PetscScalar *u;
126   const PetscScalar *mu1;
127   const PetscScalar *mu2;
128 
129   PetscFunctionBeginUser;
130   PetscCall(TSGetStepNumber(ts,&curr_step));
131   PetscCall(VecGetArrayRead(user->mu1,&mu1));
132   PetscCall(VecGetArrayRead(user->mu2,&mu2));
133   PetscCall(VecGetArrayRead(U,&u));
134   J[0][0] = 0;
135   J[1][0] = -mu2[curr_step]*(2.0*u[1]*u[0]+1.);
136   J[0][1] = mu1[curr_step];
137   J[1][1] = mu2[curr_step]*(1.0-u[0]*u[0]);
138   PetscCall(MatSetValues(A,2,rowcol,2,rowcol,&J[0][0],INSERT_VALUES));
139   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
140   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
141   PetscCall(VecRestoreArrayRead(U,&u));
142   PetscCall(VecRestoreArrayRead(user->mu1,&mu1));
143   PetscCall(VecRestoreArrayRead(user->mu2,&mu2));
144   PetscFunctionReturn(0);
145 }
146 
147 /* ------------------ Jacobian wrt parameters for tracking method ------------------ */
148 
149 PetscErrorCode RHSJacobianP_track(TS ts,PetscReal t,Vec U,Mat A,void *ctx)
150 {
151   PetscInt          row[] = {0,1},col[] = {0,1};
152   PetscScalar       J[2][2];
153   const PetscScalar *u;
154 
155   PetscFunctionBeginUser;
156   PetscCall(VecGetArrayRead(U,&u));
157   J[0][0] = u[1];
158   J[1][0] = 0;
159   J[0][1] = 0;
160   J[1][1] = (1.-u[0]*u[0])*u[1]-u[0];
161   PetscCall(MatSetValues(A,2,row,2,col,&J[0][0],INSERT_VALUES));
162   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
163   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
164   PetscCall(VecRestoreArrayRead(U,&u));
165   PetscFunctionReturn(0);
166 }
167 
168 /* ------------------ Jacobian wrt parameters for global method ------------------ */
169 
170 PetscErrorCode RHSJacobianP_global(TS ts,PetscReal t,Vec U,Mat A,void *ctx)
171 {
172   PetscInt          row[] = {0,1},col[] = {0,1};
173   PetscScalar       J[2][2];
174   const PetscScalar *u;
175   PetscInt          curr_step;
176 
177   PetscFunctionBeginUser;
178   PetscCall(TSGetStepNumber(ts,&curr_step));
179   PetscCall(VecGetArrayRead(U,&u));
180   J[0][0] = u[1];
181   J[1][0] = 0;
182   J[0][1] = 0;
183   J[1][1] = (1.-u[0]*u[0])*u[1]-u[0];
184   col[0] = (curr_step)*2;
185   col[1] = (curr_step)*2+1;
186   PetscCall(MatSetValues(A,2,row,2,col,&J[0][0],INSERT_VALUES));
187   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
188   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
189   PetscCall(VecRestoreArrayRead(U,&u));
190   PetscFunctionReturn(0);
191 }
192 
193 /* Dump solution to console if called */
194 PetscErrorCode Monitor(TS ts,PetscInt step,PetscReal t,Vec U,void *ctx)
195 {
196   PetscFunctionBeginUser;
197   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n Solution at time %e is \n", (double)t));
198   PetscCall(VecView(U,PETSC_VIEWER_STDOUT_WORLD));
199   PetscFunctionReturn(0);
200 }
201 
202 /* Customized adjoint monitor to keep track of local
203    sensitivities by storing them in a global sensitivity array.
204    Note : This routine is only used for the tracking method. */
205 PetscErrorCode AdjointMonitor(TS ts,PetscInt steps,PetscReal time,Vec u,PetscInt numcost,Vec *lambda, Vec *mu,void *ctx)
206 {
207   AppCtx            *user = (AppCtx*) ctx;
208   PetscInt          curr_step;
209   PetscScalar       *sensmu1_glob;
210   PetscScalar       *sensmu2_glob;
211   const PetscScalar *sensmu_loc;
212 
213   PetscFunctionBeginUser;
214   PetscCall(TSGetStepNumber(ts,&curr_step));
215   /* Note that we skip the first call to the monitor in the adjoint
216      solve since the sensitivities are already set (during
217      initialization of adjoint vectors).
218      We also note that each indvidial TSAdjointSolve calls the monitor
219      twice, once at the step it is integrating from and once at the step
220      it integrates to. Only the second call is useful for transferring
221      local sensitivities to the global array. */
222   if (curr_step == user->adj_idx) {
223     PetscFunctionReturn(0);
224   } else {
225     PetscCall(VecGetArrayRead(*mu,&sensmu_loc));
226     PetscCall(VecGetArray(user->sens_mu1,&sensmu1_glob));
227     PetscCall(VecGetArray(user->sens_mu2,&sensmu2_glob));
228     sensmu1_glob[curr_step] = sensmu_loc[0];
229     sensmu2_glob[curr_step] = sensmu_loc[1];
230     PetscCall(VecRestoreArray(user->sens_mu1,&sensmu1_glob));
231     PetscCall(VecRestoreArray(user->sens_mu2,&sensmu2_glob));
232     PetscCall(VecRestoreArrayRead(*mu,&sensmu_loc));
233     PetscFunctionReturn(0);
234   }
235 }
236 
237 int main(int argc,char **argv)
238 {
239   TS             ts;
240   AppCtx         user;
241   PetscScalar    *x_ptr,*y_ptr,*u_ptr;
242   PetscMPIInt    size;
243   PetscBool      monitor = PETSC_FALSE;
244   SAMethod       sa = SA_GLOBAL;
245 
246   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
247      Initialize program
248      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
249   PetscCall(PetscInitialize(&argc,&argv,NULL,help));
250   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD,&size));
251   PetscCheck(size == 1,PETSC_COMM_WORLD,PETSC_ERR_WRONG_MPI_SIZE,"This is a uniprocessor example only!");
252 
253   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
254      Set runtime options
255      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
256   PetscOptionsBegin(PETSC_COMM_WORLD,NULL,"SA analysis options.","");{
257   PetscCall(PetscOptionsGetBool(NULL,NULL,"-monitor",&monitor,NULL));
258   PetscCall(PetscOptionsEnum("-sa_method","Sensitivity analysis method (track or global)","",SAMethods,(PetscEnum)sa,(PetscEnum*)&sa,NULL));
259   }
260   PetscOptionsEnd();
261 
262   user.final_time = 0.1;
263   user.max_steps  = 5;
264   user.time_step  = user.final_time/user.max_steps;
265 
266   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
267      Create necessary matrix and vectors for forward solve.
268      Create Jacp matrix for adjoint solve.
269      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
270   PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.mu1));
271   PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.mu2));
272   PetscCall(VecSet(user.mu1,1.25));
273   PetscCall(VecSet(user.mu2,1.0e2));
274 
275   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
276       For tracking method : create the global sensitivity array to
277       accumulate sensitivity with respect to parameters at each step.
278      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
279   if (sa == SA_TRACK) {
280     PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.sens_mu1));
281     PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.sens_mu2));
282   }
283 
284   PetscCall(MatCreate(PETSC_COMM_WORLD,&user.A));
285   PetscCall(MatSetSizes(user.A,PETSC_DECIDE,PETSC_DECIDE,2,2));
286   PetscCall(MatSetFromOptions(user.A));
287   PetscCall(MatSetUp(user.A));
288   PetscCall(MatCreateVecs(user.A,&user.U,NULL));
289 
290   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
291       Note that the dimensions of the Jacp matrix depend upon the
292       sensitivity analysis method being used !
293      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
294   PetscCall(MatCreate(PETSC_COMM_WORLD,&user.Jacp));
295   if (sa == SA_TRACK) {
296     PetscCall(MatSetSizes(user.Jacp,PETSC_DECIDE,PETSC_DECIDE,2,2));
297   }
298   if (sa == SA_GLOBAL) {
299     PetscCall(MatSetSizes(user.Jacp,PETSC_DECIDE,PETSC_DECIDE,2,user.max_steps*2));
300   }
301   PetscCall(MatSetFromOptions(user.Jacp));
302   PetscCall(MatSetUp(user.Jacp));
303 
304   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
305      Create timestepping solver context
306      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
307   PetscCall(TSCreate(PETSC_COMM_WORLD,&ts));
308   PetscCall(TSSetEquationType(ts,TS_EQ_ODE_EXPLICIT));
309   PetscCall(TSSetType(ts,TSCN));
310 
311   PetscCall(TSSetRHSFunction(ts,NULL,RHSFunction,&user));
312   PetscCall(TSSetRHSJacobian(ts,user.A,user.A,RHSJacobian,&user));
313   if (sa == SA_TRACK) {
314     PetscCall(TSSetRHSJacobianP(ts,user.Jacp,RHSJacobianP_track,&user));
315   }
316   if (sa == SA_GLOBAL) {
317     PetscCall(TSSetRHSJacobianP(ts,user.Jacp,RHSJacobianP_global,&user));
318   }
319 
320   PetscCall(TSSetExactFinalTime(ts,TS_EXACTFINALTIME_MATCHSTEP));
321   PetscCall(TSSetMaxTime(ts,user.final_time));
322   PetscCall(TSSetTimeStep(ts,user.final_time/user.max_steps));
323 
324   if (monitor) {
325     PetscCall(TSMonitorSet(ts,Monitor,&user,NULL));
326   }
327   if (sa == SA_TRACK) {
328     PetscCall(TSAdjointMonitorSet(ts,AdjointMonitor,&user,NULL));
329   }
330 
331   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
332      Set initial conditions
333      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
334   PetscCall(VecGetArray(user.U,&x_ptr));
335   x_ptr[0] = 2.0;
336   x_ptr[1] = -2.0/3.0;
337   PetscCall(VecRestoreArray(user.U,&x_ptr));
338 
339   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
340      Save trajectory of solution so that TSAdjointSolve() may be used
341      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
342   PetscCall(TSSetSaveTrajectory(ts));
343 
344   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
345      Set runtime options
346      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
347   PetscCall(TSSetFromOptions(ts));
348 
349   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
350      Execute forward model and print solution.
351      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
352   PetscCall(TSSolve(ts,user.U));
353   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n Solution of forward TS :\n"));
354   PetscCall(VecView(user.U,PETSC_VIEWER_STDOUT_WORLD));
355   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n Forward TS solve successfull! Adjoint run begins!\n"));
356 
357   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
358      Adjoint model starts here! Create adjoint vectors.
359      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
360   PetscCall(MatCreateVecs(user.A,&user.lambda,NULL));
361   PetscCall(MatCreateVecs(user.Jacp,&user.mup,NULL));
362 
363   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
364      Set initial conditions for the adjoint vector
365      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
366   PetscCall(VecGetArray(user.U,&u_ptr));
367   PetscCall(VecGetArray(user.lambda,&y_ptr));
368   y_ptr[0] = 2*(u_ptr[0] - 1.5967);
369   y_ptr[1] = 2*(u_ptr[1] - -(1.02969));
370   PetscCall(VecRestoreArray(user.lambda,&y_ptr));
371   PetscCall(VecRestoreArray(user.U,&y_ptr));
372   PetscCall(VecSet(user.mup,0));
373 
374   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
375      Set number of cost functions.
376      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
377   PetscCall(TSSetCostGradients(ts,1,&user.lambda,&user.mup));
378 
379   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
380      The adjoint vector mup has to be reset for each adjoint step when
381      using the tracking method as we want to treat the parameters at each
382      time step one at a time and prevent accumulation of the sensitivities
383      from parameters at previous time steps.
384      This is not necessary for the global method as each time dependent
385      parameter is treated as an independent parameter.
386    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
387   if (sa == SA_TRACK) {
388     for (user.adj_idx=user.max_steps; user.adj_idx>0; user.adj_idx--) {
389       PetscCall(VecSet(user.mup,0));
390       PetscCall(TSAdjointSetSteps(ts, 1));
391       PetscCall(TSAdjointSolve(ts));
392     }
393   }
394   if (sa == SA_GLOBAL) {
395     PetscCall(TSAdjointSolve(ts));
396   }
397 
398   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
399      Dispaly adjoint sensitivities wrt parameters and initial conditions
400      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
401   if (sa == SA_TRACK) {
402     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt  mu1: d[cost]/d[mu1]\n"));
403     PetscCall(VecView(user.sens_mu1,PETSC_VIEWER_STDOUT_WORLD));
404     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt  mu2: d[cost]/d[mu2]\n"));
405     PetscCall(VecView(user.sens_mu2,PETSC_VIEWER_STDOUT_WORLD));
406   }
407 
408   if (sa == SA_GLOBAL) {
409     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt  params: d[cost]/d[p], where p refers to \nthe interlaced vector made by combining mu1,mu2\n"));
410     PetscCall(VecView(user.mup,PETSC_VIEWER_STDOUT_WORLD));
411   }
412 
413   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt initial conditions: d[cost]/d[u(t=0)]\n"));
414   PetscCall(VecView(user.lambda,PETSC_VIEWER_STDOUT_WORLD));
415 
416   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
417      Free work space!
418      All PETSc objects should be destroyed when they are no longer needed.
419      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
420   PetscCall(MatDestroy(&user.A));
421   PetscCall(MatDestroy(&user.Jacp));
422   PetscCall(VecDestroy(&user.U));
423   PetscCall(VecDestroy(&user.lambda));
424   PetscCall(VecDestroy(&user.mup));
425   PetscCall(VecDestroy(&user.mu1));
426   PetscCall(VecDestroy(&user.mu2));
427   if (sa == SA_TRACK) {
428     PetscCall(VecDestroy(&user.sens_mu1));
429     PetscCall(VecDestroy(&user.sens_mu2));
430   }
431   PetscCall(TSDestroy(&ts));
432   PetscCall(PetscFinalize());
433   return(0);
434 }
435 
436 /*TEST
437 
438   test:
439     requires: !complex
440     suffix : track
441     args : -sa_method track
442 
443   test:
444     requires: !complex
445     suffix : global
446     args : -sa_method global
447 
448 TEST*/
449