xref: /petsc/src/ts/tutorials/ex20td.c (revision 2e3d3ef9bcea2ec3187269b59709989dccbaffee)
/* Text printed by -help: summarizes the problem and the two adjoint strategies. */
static char help[] =
  "Performs adjoint sensitivity analysis for a van der Pol like \n"
  "equation with time dependent parameters using two approaches :  \n"
  "track  : track only local sensitivities at each adjoint step \n"
  "         and accumulate them in a global array \n"
  "global : track parameters at all timesteps together \n"
  "Choose one of the two at runtime by -sa_method {track,global}. \n";
7 
8 /*
9    Simple example to demonstrate TSAdjoint capabilities for time dependent params
10    without integral cost terms using either a tracking or global method.
11 
   Modify the van der Pol equation to:
   [u1'] = [mu1(t)*u2]
   [u2'] = [mu2(t)*((1-u1^2)*u2-u1)]
15    (with initial conditions & params independent)
16 
   Define u_ref to be the solution with initial conditions (2,-2/3), mu=(1,1e3)
18    - u_ref : (1.5967,-1.02969)
19 
   Define the cost function as cost = (2-norm(u - u_ref))^2 at the final time;
21 
22    Initialization for the adjoint TS :
23    - dcost/dy|final_time = 2*(u-u_ref)|final_time
24    - dcost/dp|final_time = 0
25 
26    The tracking method only tracks local sensitivity at each time step
27    and accumulates these sensitivities in a global array. Since the structure
28    of the equations being solved at each time step does not change, the jacobian
   wrt parameters is defined analogous to a constant RHSJacobian for a linear
30    TSSolve and the size of the jacP is independent of the number of time
31    steps. Enable this mode of adjoint analysis by -sa_method track.
32 
33    The global method combines the parameters at all timesteps and tracks them
34    together. Thus, the columns of the jacP matrix are filled dependent upon the
35    time step. Also, the dimensions of the jacP matrix now depend upon the number
36    of time steps. Enable this mode of adjoint analysis by -sa_method global.
37 
38    Since the equations here have parameters at predefined time steps, this
39    example should be run with non adaptive time stepping solvers only. This
40    can be ensured by -ts_adapt_type none (which is the default behavior only
41    for certain TS solvers like TSCN. If using an explicit TS like TSRK,
42    please be sure to add the aforementioned option to disable adaptive
43    timestepping.)
44 */
45 
46 /*
47    Include "petscts.h" so that we can use TS solvers.  Note that this file
48    automatically includes:
49      petscsys.h    - base PETSc routines   petscvec.h  - vectors
50      petscmat.h    - matrices
51      petscis.h     - index sets            petscksp.h  - Krylov subspace methods
52      petscviewer.h - viewers               petscpc.h   - preconditioners
53      petscksp.h    - linear solvers        petscsnes.h - nonlinear solvers
54 */
55 #include <petscts.h>
56 
57 extern PetscErrorCode RHSFunction(TS ,PetscReal ,Vec ,Vec ,void *);
58 extern PetscErrorCode RHSJacobian(TS ,PetscReal ,Vec ,Mat ,Mat ,void *);
59 extern PetscErrorCode RHSJacobianP_track(TS ,PetscReal ,Vec ,Mat ,void *);
60 extern PetscErrorCode RHSJacobianP_global(TS ,PetscReal ,Vec ,Mat ,void *);
61 extern PetscErrorCode Monitor(TS ,PetscInt ,PetscReal ,Vec ,void *);
62 extern PetscErrorCode AdjointMonitor(TS ,PetscInt ,PetscReal ,Vec ,PetscInt ,Vec *, Vec *,void *);
63 
64 /*
65    User-defined application context - contains data needed by the
66    application-provided call-back routines.
67 */
68 
69 typedef struct {
70  /*------------- Forward solve data structures --------------*/
71   PetscInt  max_steps;     /* number of steps to run ts for */
72   PetscReal final_time;    /* final time to integrate to*/
73   PetscReal time_step;     /* ts integration time step */
74   Vec       mu1;           /* time dependent params */
75   Vec       mu2;           /* time dependent params */
76   Vec       U;             /* solution vector */
77   Mat       A;             /* Jacobian matrix */
78 
79   /*------------- Adjoint solve data structures --------------*/
80   Mat       Jacp;          /* JacobianP matrix */
81   Vec       lambda;        /* adjoint variable */
82   Vec       mup;           /* adjoint variable */
83 
84   /*------------- Global accumation vecs for monitor based tracking --------------*/
85   Vec       sens_mu1;      /* global sensitivity accumulation */
86   Vec       sens_mu2;      /* global sensitivity accumulation */
87   PetscInt  adj_idx;       /* to keep track of adjoint solve index */
88 } AppCtx;
89 
/* Sensitivity analysis strategy chosen with -sa_method. The string table uses the
   PetscOptionsEnum() layout: one name per enum value, then the enum type name,
   then the common prefix of the values, then a terminating null pointer. */
typedef enum {
  SA_TRACK  = 0,
  SA_GLOBAL = 1
} SAMethod;
static const char *const SAMethods[] = {"TRACK", "GLOBAL", "SAMethod", "SA_", NULL};
92 
93 /* ----------------------- Explicit form of the ODE  -------------------- */
94 
95 PetscErrorCode RHSFunction(TS ts,PetscReal t,Vec U,Vec F,void *ctx)
96 {
97   AppCtx            *user = (AppCtx*) ctx;
98   PetscScalar       *f;
99   PetscInt          curr_step;
100   const PetscScalar *u;
101   const PetscScalar *mu1;
102   const PetscScalar *mu2;
103 
104   PetscFunctionBeginUser;
105   PetscCall(TSGetStepNumber(ts,&curr_step));
106   PetscCall(VecGetArrayRead(U,&u));
107   PetscCall(VecGetArrayRead(user->mu1,&mu1));
108   PetscCall(VecGetArrayRead(user->mu2,&mu2));
109   PetscCall(VecGetArray(F,&f));
110   f[0] = mu1[curr_step]*u[1];
111   f[1] = mu2[curr_step]*((1.-u[0]*u[0])*u[1]-u[0]);
112   PetscCall(VecRestoreArrayRead(U,&u));
113   PetscCall(VecRestoreArrayRead(user->mu1,&mu1));
114   PetscCall(VecRestoreArrayRead(user->mu2,&mu2));
115   PetscCall(VecRestoreArray(F,&f));
116   PetscFunctionReturn(0);
117 }
118 
119 PetscErrorCode RHSJacobian(TS ts,PetscReal t,Vec U,Mat A,Mat B,void *ctx)
120 {
121   AppCtx            *user = (AppCtx*) ctx;
122   PetscInt          rowcol[] = {0,1};
123   PetscScalar       J[2][2];
124   PetscInt          curr_step;
125   const PetscScalar *u;
126   const PetscScalar *mu1;
127   const PetscScalar *mu2;
128 
129   PetscFunctionBeginUser;
130   PetscCall(TSGetStepNumber(ts,&curr_step));
131   PetscCall(VecGetArrayRead(user->mu1,&mu1));
132   PetscCall(VecGetArrayRead(user->mu2,&mu2));
133   PetscCall(VecGetArrayRead(U,&u));
134   J[0][0] = 0;
135   J[1][0] = -mu2[curr_step]*(2.0*u[1]*u[0]+1.);
136   J[0][1] = mu1[curr_step];
137   J[1][1] = mu2[curr_step]*(1.0-u[0]*u[0]);
138   PetscCall(MatSetValues(A,2,rowcol,2,rowcol,&J[0][0],INSERT_VALUES));
139   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
140   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
141   PetscCall(VecRestoreArrayRead(U,&u));
142   PetscCall(VecRestoreArrayRead(user->mu1,&mu1));
143   PetscCall(VecRestoreArrayRead(user->mu2,&mu2));
144   PetscFunctionReturn(0);
145 }
146 
147 /* ------------------ Jacobian wrt parameters for tracking method ------------------ */
148 
149 PetscErrorCode RHSJacobianP_track(TS ts,PetscReal t,Vec U,Mat A,void *ctx)
150 {
151   PetscInt          row[] = {0,1},col[] = {0,1};
152   PetscScalar       J[2][2];
153   const PetscScalar *u;
154 
155   PetscFunctionBeginUser;
156   PetscCall(VecGetArrayRead(U,&u));
157   J[0][0] = u[1];
158   J[1][0] = 0;
159   J[0][1] = 0;
160   J[1][1] = (1.-u[0]*u[0])*u[1]-u[0];
161   PetscCall(MatSetValues(A,2,row,2,col,&J[0][0],INSERT_VALUES));
162   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
163   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
164   PetscCall(VecRestoreArrayRead(U,&u));
165   PetscFunctionReturn(0);
166 }
167 
168 /* ------------------ Jacobian wrt parameters for global method ------------------ */
169 
170 PetscErrorCode RHSJacobianP_global(TS ts,PetscReal t,Vec U,Mat A,void *ctx)
171 {
172   PetscInt          row[] = {0,1},col[] = {0,1};
173   PetscScalar       J[2][2];
174   const PetscScalar *u;
175   PetscInt          curr_step;
176 
177   PetscFunctionBeginUser;
178   PetscCall(TSGetStepNumber(ts,&curr_step));
179   PetscCall(VecGetArrayRead(U,&u));
180   J[0][0] = u[1];
181   J[1][0] = 0;
182   J[0][1] = 0;
183   J[1][1] = (1.-u[0]*u[0])*u[1]-u[0];
184   col[0] = (curr_step)*2;
185   col[1] = (curr_step)*2+1;
186   PetscCall(MatSetValues(A,2,row,2,col,&J[0][0],INSERT_VALUES));
187   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
188   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
189   PetscCall(VecRestoreArrayRead(U,&u));
190   PetscFunctionReturn(0);
191 }
192 
193 /* Dump solution to console if called */
194 PetscErrorCode Monitor(TS ts,PetscInt step,PetscReal t,Vec U,void *ctx)
195 {
196   PetscFunctionBeginUser;
197   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n Solution at time %e is \n", t));
198   PetscCall(VecView(U,PETSC_VIEWER_STDOUT_WORLD));
199   PetscFunctionReturn(0);
200 }
201 
202 /* Customized adjoint monitor to keep track of local
203    sensitivities by storing them in a global sensitivity array.
204    Note : This routine is only used for the tracking method. */
205 PetscErrorCode AdjointMonitor(TS ts,PetscInt steps,PetscReal time,Vec u,PetscInt numcost,Vec *lambda, Vec *mu,void *ctx)
206 {
207   AppCtx            *user = (AppCtx*) ctx;
208   PetscInt          curr_step;
209   PetscScalar       *sensmu1_glob;
210   PetscScalar       *sensmu2_glob;
211   const PetscScalar *sensmu_loc;
212 
213   PetscFunctionBeginUser;
214   PetscCall(TSGetStepNumber(ts,&curr_step));
215   /* Note that we skip the first call to the monitor in the adjoint
216      solve since the sensitivities are already set (during
217      initialization of adjoint vectors).
218      We also note that each indvidial TSAdjointSolve calls the monitor
219      twice, once at the step it is integrating from and once at the step
220      it integrates to. Only the second call is useful for transferring
221      local sensitivities to the global array. */
222   if (curr_step == user->adj_idx) {
223     PetscFunctionReturn(0);
224   } else {
225     PetscCall(VecGetArrayRead(*mu,&sensmu_loc));
226     PetscCall(VecGetArray(user->sens_mu1,&sensmu1_glob));
227     PetscCall(VecGetArray(user->sens_mu2,&sensmu2_glob));
228     sensmu1_glob[curr_step] = sensmu_loc[0];
229     sensmu2_glob[curr_step] = sensmu_loc[1];
230     PetscCall(VecRestoreArray(user->sens_mu1,&sensmu1_glob));
231     PetscCall(VecRestoreArray(user->sens_mu2,&sensmu2_glob));
232     PetscCall(VecRestoreArrayRead(*mu,&sensmu_loc));
233     PetscFunctionReturn(0);
234   }
235 }
236 
237 int main(int argc,char **argv)
238 {
239   TS             ts;
240   AppCtx         user;
241   PetscScalar    *x_ptr,*y_ptr,*u_ptr;
242   PetscMPIInt    size;
243   PetscBool      monitor = PETSC_FALSE;
244   SAMethod       sa = SA_GLOBAL;
245   PetscErrorCode ierr;
246 
247   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
248      Initialize program
249      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
250   PetscCall(PetscInitialize(&argc,&argv,NULL,help));
251   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD,&size));
252   PetscCheck(size == 1,PETSC_COMM_WORLD,PETSC_ERR_WRONG_MPI_SIZE,"This is a uniprocessor example only!");
253 
254   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
255      Set runtime options
256      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
257   ierr = PetscOptionsBegin(PETSC_COMM_WORLD,NULL,"SA analysis options.","");PetscCall(ierr);{
258   PetscCall(PetscOptionsGetBool(NULL,NULL,"-monitor",&monitor,NULL));
259   PetscCall(PetscOptionsEnum("-sa_method","Sensitivity analysis method (track or global)","",SAMethods,(PetscEnum)sa,(PetscEnum*)&sa,NULL));
260   }
261   ierr = PetscOptionsEnd();PetscCall(ierr);
262 
263   user.final_time = 0.1;
264   user.max_steps  = 5;
265   user.time_step  = user.final_time/user.max_steps;
266 
267   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
268      Create necessary matrix and vectors for forward solve.
269      Create Jacp matrix for adjoint solve.
270      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
271   PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.mu1));
272   PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.mu2));
273   PetscCall(VecSet(user.mu1,1.25));
274   PetscCall(VecSet(user.mu2,1.0e2));
275 
276   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
277       For tracking method : create the global sensitivity array to
278       accumulate sensitivity with respect to parameters at each step.
279      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
280   if (sa == SA_TRACK) {
281     PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.sens_mu1));
282     PetscCall(VecCreateSeq(PETSC_COMM_WORLD,user.max_steps,&user.sens_mu2));
283   }
284 
285   PetscCall(MatCreate(PETSC_COMM_WORLD,&user.A));
286   PetscCall(MatSetSizes(user.A,PETSC_DECIDE,PETSC_DECIDE,2,2));
287   PetscCall(MatSetFromOptions(user.A));
288   PetscCall(MatSetUp(user.A));
289   PetscCall(MatCreateVecs(user.A,&user.U,NULL));
290 
291   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
292       Note that the dimensions of the Jacp matrix depend upon the
293       sensitivity analysis method being used !
294      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
295   PetscCall(MatCreate(PETSC_COMM_WORLD,&user.Jacp));
296   if (sa == SA_TRACK) {
297     PetscCall(MatSetSizes(user.Jacp,PETSC_DECIDE,PETSC_DECIDE,2,2));
298   }
299   if (sa == SA_GLOBAL) {
300     PetscCall(MatSetSizes(user.Jacp,PETSC_DECIDE,PETSC_DECIDE,2,user.max_steps*2));
301   }
302   PetscCall(MatSetFromOptions(user.Jacp));
303   PetscCall(MatSetUp(user.Jacp));
304 
305   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
306      Create timestepping solver context
307      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
308   PetscCall(TSCreate(PETSC_COMM_WORLD,&ts));
309   PetscCall(TSSetEquationType(ts,TS_EQ_ODE_EXPLICIT));
310   PetscCall(TSSetType(ts,TSCN));
311 
312   PetscCall(TSSetRHSFunction(ts,NULL,RHSFunction,&user));
313   PetscCall(TSSetRHSJacobian(ts,user.A,user.A,RHSJacobian,&user));
314   if (sa == SA_TRACK) {
315     PetscCall(TSSetRHSJacobianP(ts,user.Jacp,RHSJacobianP_track,&user));
316   }
317   if (sa == SA_GLOBAL) {
318     PetscCall(TSSetRHSJacobianP(ts,user.Jacp,RHSJacobianP_global,&user));
319   }
320 
321   PetscCall(TSSetExactFinalTime(ts,TS_EXACTFINALTIME_MATCHSTEP));
322   PetscCall(TSSetMaxTime(ts,user.final_time));
323   PetscCall(TSSetTimeStep(ts,user.final_time/user.max_steps));
324 
325   if (monitor) {
326     PetscCall(TSMonitorSet(ts,Monitor,&user,NULL));
327   }
328   if (sa == SA_TRACK) {
329     PetscCall(TSAdjointMonitorSet(ts,AdjointMonitor,&user,NULL));
330   }
331 
332   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
333      Set initial conditions
334      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
335   PetscCall(VecGetArray(user.U,&x_ptr));
336   x_ptr[0] = 2.0;
337   x_ptr[1] = -2.0/3.0;
338   PetscCall(VecRestoreArray(user.U,&x_ptr));
339 
340   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
341      Save trajectory of solution so that TSAdjointSolve() may be used
342      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
343   PetscCall(TSSetSaveTrajectory(ts));
344 
345   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
346      Set runtime options
347      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
348   PetscCall(TSSetFromOptions(ts));
349 
350   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
351      Execute forward model and print solution.
352      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
353   PetscCall(TSSolve(ts,user.U));
354   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n Solution of forward TS :\n"));
355   PetscCall(VecView(user.U,PETSC_VIEWER_STDOUT_WORLD));
356   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n Forward TS solve successfull! Adjoint run begins!\n"));
357 
358   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
359      Adjoint model starts here! Create adjoint vectors.
360      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
361   PetscCall(MatCreateVecs(user.A,&user.lambda,NULL));
362   PetscCall(MatCreateVecs(user.Jacp,&user.mup,NULL));
363 
364   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
365      Set initial conditions for the adjoint vector
366      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
367   PetscCall(VecGetArray(user.U,&u_ptr));
368   PetscCall(VecGetArray(user.lambda,&y_ptr));
369   y_ptr[0] = 2*(u_ptr[0] - 1.5967);
370   y_ptr[1] = 2*(u_ptr[1] - -(1.02969));
371   PetscCall(VecRestoreArray(user.lambda,&y_ptr));
372   PetscCall(VecRestoreArray(user.U,&y_ptr));
373   PetscCall(VecSet(user.mup,0));
374 
375   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
376      Set number of cost functions.
377      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
378   PetscCall(TSSetCostGradients(ts,1,&user.lambda,&user.mup));
379 
380   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
381      The adjoint vector mup has to be reset for each adjoint step when
382      using the tracking method as we want to treat the parameters at each
383      time step one at a time and prevent accumulation of the sensitivities
384      from parameters at previous time steps.
385      This is not necessary for the global method as each time dependent
386      parameter is treated as an independent parameter.
387    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
388   if (sa == SA_TRACK) {
389     for (user.adj_idx=user.max_steps; user.adj_idx>0; user.adj_idx--) {
390       PetscCall(VecSet(user.mup,0));
391       PetscCall(TSAdjointSetSteps(ts, 1));
392       PetscCall(TSAdjointSolve(ts));
393     }
394   }
395   if (sa == SA_GLOBAL) {
396     PetscCall(TSAdjointSolve(ts));
397   }
398 
399   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
400      Dispaly adjoint sensitivities wrt parameters and initial conditions
401      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
402   if (sa == SA_TRACK) {
403     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt  mu1: d[cost]/d[mu1]\n"));
404     PetscCall(VecView(user.sens_mu1,PETSC_VIEWER_STDOUT_WORLD));
405     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt  mu2: d[cost]/d[mu2]\n"));
406     PetscCall(VecView(user.sens_mu2,PETSC_VIEWER_STDOUT_WORLD));
407   }
408 
409   if (sa == SA_GLOBAL) {
410     ierr = PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt  params: d[cost]/d[p], where p refers to \n\
411                     the interlaced vector made by combining mu1,mu2\n");PetscCall(ierr);
412     PetscCall(VecView(user.mup,PETSC_VIEWER_STDOUT_WORLD));
413   }
414 
415   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n sensitivity wrt initial conditions: d[cost]/d[u(t=0)]\n"));
416   PetscCall(VecView(user.lambda,PETSC_VIEWER_STDOUT_WORLD));
417 
418   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
419      Free work space!
420      All PETSc objects should be destroyed when they are no longer needed.
421      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
422   PetscCall(MatDestroy(&user.A));
423   PetscCall(MatDestroy(&user.Jacp));
424   PetscCall(VecDestroy(&user.U));
425   PetscCall(VecDestroy(&user.lambda));
426   PetscCall(VecDestroy(&user.mup));
427   PetscCall(VecDestroy(&user.mu1));
428   PetscCall(VecDestroy(&user.mu2));
429   if (sa == SA_TRACK) {
430     PetscCall(VecDestroy(&user.sens_mu1));
431     PetscCall(VecDestroy(&user.sens_mu2));
432   }
433   PetscCall(TSDestroy(&ts));
434   PetscCall(PetscFinalize());
435   return(ierr);
436 }
437 
438 /*TEST
439 
440   test:
441     requires: !complex
442     suffix : track
443     args : -sa_method track
444 
445   test:
446     requires: !complex
447     suffix : global
448     args : -sa_method global
449 
450 TEST*/
451