xref: /petsc/src/ts/tutorials/ex20td.c (revision 607e733f3db3ee7f6f605a13295c517df8dbb9c9)
/* Help text shown by -help; describes the two -sa_method modes implemented below. */
static char help[] = "Performs adjoint sensitivity analysis for a van der Pol like \n\
equation with time dependent parameters using two approaches :  \n\
track  : track only local sensitivities at each adjoint step \n\
         and accumulate them in a global array \n\
global : track parameters at all timesteps together \n\
Choose one of the two at runtime by -sa_method {track,global}. \n";
7 
8 /*
9    Simple example to demonstrate TSAdjoint capabilities for time dependent params
10    without integral cost terms using either a tracking or global method.
11 
12    Modify the Van Der Pol Eq to :
13    [u1'] = [mu1(t)*u1]
14    [u2'] = [mu2(t)*((1-u1^2)*u2-u1)]
15    (with initial conditions & params independent)
16 
17    Define uref to be solution with initial conditions (2,-2/3), mu=(1,1e3)
18    - u_ref : (1.5967,-1.02969)
19 
   Define cost function as cost = 2-norm(u - u_ref);
21 
22    Initialization for the adjoint TS :
23    - dcost/dy|final_time = 2*(u-u_ref)|final_time
24    - dcost/dp|final_time = 0
25 
26    The tracking method only tracks local sensitivity at each time step
27    and accumulates these sensitivities in a global array. Since the structure
28    of the equations being solved at each time step does not change, the jacobian
   wrt parameters is defined analogous to constant RHSJacobian for a linear
30    TSSolve and the size of the jacP is independent of the number of time
31    steps. Enable this mode of adjoint analysis by -sa_method track.
32 
33    The global method combines the parameters at all timesteps and tracks them
34    together. Thus, the columns of the jacP matrix are filled dependent upon the
35    time step. Also, the dimensions of the jacP matrix now depend upon the number
36    of time steps. Enable this mode of adjoint analysis by -sa_method global.
37 
38    Since the equations here have parameters at predefined time steps, this
39    example should be run with non adaptive time stepping solvers only. This
40    can be ensured by -ts_adapt_type none (which is the default behavior only
41    for certain TS solvers like TSCN. If using an explicit TS like TSRK,
42    please be sure to add the aforementioned option to disable adaptive
43    timestepping.)
44 */
45 
46 /*
47    Include "petscts.h" so that we can use TS solvers.  Note that this file
48    automatically includes:
49      petscsys.h    - base PETSc routines   petscvec.h  - vectors
50      petscmat.h    - matrices
51      petscis.h     - index sets            petscksp.h  - Krylov subspace methods
52      petscviewer.h - viewers               petscpc.h   - preconditioners
53      petscksp.h    - linear solvers        petscsnes.h - nonlinear solvers
54 */
55 #include <petscts.h>
56 
57 extern PetscErrorCode RHSFunction(TS, PetscReal, Vec, Vec, void *);
58 extern PetscErrorCode RHSJacobian(TS, PetscReal, Vec, Mat, Mat, void *);
59 extern PetscErrorCode RHSJacobianP_track(TS, PetscReal, Vec, Mat, void *);
60 extern PetscErrorCode RHSJacobianP_global(TS, PetscReal, Vec, Mat, void *);
61 extern PetscErrorCode Monitor(TS, PetscInt, PetscReal, Vec, void *);
62 extern PetscErrorCode AdjointMonitor(TS, PetscInt, PetscReal, Vec, PetscInt, Vec *, Vec *, void *);
63 
64 /*
65    User-defined application context - contains data needed by the
66    application-provided call-back routines.
67 */
68 
/* User-defined application context passed to all TS callbacks. */
typedef struct {
  /*------------- Forward solve data structures --------------*/
  PetscInt  max_steps;  /* number of steps to run ts for */
  PetscReal final_time; /* final time to integrate to */
  PetscReal time_step;  /* ts integration time step */
  Vec       mu1;        /* time dependent params; one entry per time step */
  Vec       mu2;        /* time dependent params; one entry per time step */
  Vec       U;          /* solution vector */
  Mat       A;          /* Jacobian matrix */

  /*------------- Adjoint solve data structures --------------*/
  Mat Jacp;   /* JacobianP matrix (Jacobian wrt parameters) */
  Vec lambda; /* adjoint variable (sensitivity wrt initial conditions) */
  Vec mup;    /* adjoint variable (sensitivity wrt parameters) */

  /*------------- Global accumulation vecs for monitor based tracking --------------*/
  Vec      sens_mu1; /* global sensitivity accumulation */
  Vec      sens_mu2; /* global sensitivity accumulation */
  PetscInt adj_idx;  /* to keep track of adjoint solve index */
} AppCtx;
89 
/* Sensitivity-analysis method selected at runtime via -sa_method. */
typedef enum {
  SA_TRACK,
  SA_GLOBAL
} SAMethod;
/* String list consumed by PetscOptionsEnum: option values, then the enum type
   name, the value prefix, and a terminating null entry. */
static const char *const SAMethods[] = {"TRACK", "GLOBAL", "SAMethod", "SA_", 0};
95 
96 /* ----------------------- Explicit form of the ODE  -------------------- */
97 
98 PetscErrorCode RHSFunction(TS ts, PetscReal t, Vec U, Vec F, PetscCtx ctx)
99 {
100   AppCtx            *user = (AppCtx *)ctx;
101   PetscScalar       *f;
102   PetscInt           curr_step;
103   const PetscScalar *u;
104   const PetscScalar *mu1;
105   const PetscScalar *mu2;
106 
107   PetscFunctionBeginUser;
108   PetscCall(TSGetStepNumber(ts, &curr_step));
109   PetscCall(VecGetArrayRead(U, &u));
110   PetscCall(VecGetArrayRead(user->mu1, &mu1));
111   PetscCall(VecGetArrayRead(user->mu2, &mu2));
112   PetscCall(VecGetArray(F, &f));
113   f[0] = mu1[curr_step] * u[1];
114   f[1] = mu2[curr_step] * ((1. - u[0] * u[0]) * u[1] - u[0]);
115   PetscCall(VecRestoreArrayRead(U, &u));
116   PetscCall(VecRestoreArrayRead(user->mu1, &mu1));
117   PetscCall(VecRestoreArrayRead(user->mu2, &mu2));
118   PetscCall(VecRestoreArray(F, &f));
119   PetscFunctionReturn(PETSC_SUCCESS);
120 }
121 
122 PetscErrorCode RHSJacobian(TS ts, PetscReal t, Vec U, Mat A, Mat B, PetscCtx ctx)
123 {
124   AppCtx            *user     = (AppCtx *)ctx;
125   PetscInt           rowcol[] = {0, 1};
126   PetscScalar        J[2][2];
127   PetscInt           curr_step;
128   const PetscScalar *u;
129   const PetscScalar *mu1;
130   const PetscScalar *mu2;
131 
132   PetscFunctionBeginUser;
133   PetscCall(TSGetStepNumber(ts, &curr_step));
134   PetscCall(VecGetArrayRead(user->mu1, &mu1));
135   PetscCall(VecGetArrayRead(user->mu2, &mu2));
136   PetscCall(VecGetArrayRead(U, &u));
137   J[0][0] = 0;
138   J[1][0] = -mu2[curr_step] * (2.0 * u[1] * u[0] + 1.);
139   J[0][1] = mu1[curr_step];
140   J[1][1] = mu2[curr_step] * (1.0 - u[0] * u[0]);
141   PetscCall(MatSetValues(A, 2, rowcol, 2, rowcol, &J[0][0], INSERT_VALUES));
142   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
143   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
144   PetscCall(VecRestoreArrayRead(U, &u));
145   PetscCall(VecRestoreArrayRead(user->mu1, &mu1));
146   PetscCall(VecRestoreArrayRead(user->mu2, &mu2));
147   PetscFunctionReturn(PETSC_SUCCESS);
148 }
149 
150 /* ------------------ Jacobian wrt parameters for tracking method ------------------ */
151 
152 PetscErrorCode RHSJacobianP_track(TS ts, PetscReal t, Vec U, Mat A, PetscCtx ctx)
153 {
154   PetscInt           row[] = {0, 1}, col[] = {0, 1};
155   PetscScalar        J[2][2];
156   const PetscScalar *u;
157 
158   PetscFunctionBeginUser;
159   PetscCall(VecGetArrayRead(U, &u));
160   J[0][0] = u[1];
161   J[1][0] = 0;
162   J[0][1] = 0;
163   J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0];
164   PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES));
165   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
166   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
167   PetscCall(VecRestoreArrayRead(U, &u));
168   PetscFunctionReturn(PETSC_SUCCESS);
169 }
170 
171 /* ------------------ Jacobian wrt parameters for global method ------------------ */
172 
173 PetscErrorCode RHSJacobianP_global(TS ts, PetscReal t, Vec U, Mat A, PetscCtx ctx)
174 {
175   PetscInt           row[] = {0, 1}, col[] = {0, 1};
176   PetscScalar        J[2][2];
177   const PetscScalar *u;
178   PetscInt           curr_step;
179 
180   PetscFunctionBeginUser;
181   PetscCall(TSGetStepNumber(ts, &curr_step));
182   PetscCall(VecGetArrayRead(U, &u));
183   J[0][0] = u[1];
184   J[1][0] = 0;
185   J[0][1] = 0;
186   J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0];
187   col[0]  = curr_step * 2;
188   col[1]  = curr_step * 2 + 1;
189   PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES));
190   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
191   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
192   PetscCall(VecRestoreArrayRead(U, &u));
193   PetscFunctionReturn(PETSC_SUCCESS);
194 }
195 
196 /* Dump solution to console if called */
197 PetscErrorCode Monitor(TS ts, PetscInt step, PetscReal t, Vec U, PetscCtx ctx)
198 {
199   PetscFunctionBeginUser;
200   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution at time %e is \n", (double)t));
201   PetscCall(VecView(U, PETSC_VIEWER_STDOUT_WORLD));
202   PetscFunctionReturn(PETSC_SUCCESS);
203 }
204 
205 /* Customized adjoint monitor to keep track of local
206    sensitivities by storing them in a global sensitivity array.
207    Note : This routine is only used for the tracking method. */
208 PetscErrorCode AdjointMonitor(TS ts, PetscInt steps, PetscReal time, Vec u, PetscInt numcost, Vec *lambda, Vec *mu, PetscCtx ctx)
209 {
210   AppCtx            *user = (AppCtx *)ctx;
211   PetscInt           curr_step;
212   PetscScalar       *sensmu1_glob;
213   PetscScalar       *sensmu2_glob;
214   const PetscScalar *sensmu_loc;
215 
216   PetscFunctionBeginUser;
217   PetscCall(TSGetStepNumber(ts, &curr_step));
218   /* Note that we skip the first call to the monitor in the adjoint
219      solve since the sensitivities are already set (during
220      initialization of adjoint vectors).
221      We also note that each indvidial TSAdjointSolve calls the monitor
222      twice, once at the step it is integrating from and once at the step
223      it integrates to. Only the second call is useful for transferring
224      local sensitivities to the global array. */
225   if (curr_step == user->adj_idx) {
226     PetscFunctionReturn(PETSC_SUCCESS);
227   } else {
228     PetscCall(VecGetArrayRead(*mu, &sensmu_loc));
229     PetscCall(VecGetArray(user->sens_mu1, &sensmu1_glob));
230     PetscCall(VecGetArray(user->sens_mu2, &sensmu2_glob));
231     sensmu1_glob[curr_step] = sensmu_loc[0];
232     sensmu2_glob[curr_step] = sensmu_loc[1];
233     PetscCall(VecRestoreArray(user->sens_mu1, &sensmu1_glob));
234     PetscCall(VecRestoreArray(user->sens_mu2, &sensmu2_glob));
235     PetscCall(VecRestoreArrayRead(*mu, &sensmu_loc));
236     PetscFunctionReturn(PETSC_SUCCESS);
237   }
238 }
239 
240 int main(int argc, char **argv)
241 {
242   TS           ts;
243   AppCtx       user;
244   PetscScalar *x_ptr, *y_ptr, *u_ptr;
245   PetscMPIInt  size;
246   PetscBool    monitor = PETSC_FALSE;
247   SAMethod     sa      = SA_GLOBAL;
248 
249   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
250      Initialize program
251      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
252   PetscFunctionBeginUser;
253   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
254   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
255   PetscCheck(size == 1, PETSC_COMM_WORLD, PETSC_ERR_WRONG_MPI_SIZE, "This is a uniprocessor example only!");
256 
257   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
258      Set runtime options
259      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
260   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "SA analysis options.", "");
261   {
262     PetscCall(PetscOptionsGetBool(NULL, NULL, "-monitor", &monitor, NULL));
263     PetscCall(PetscOptionsEnum("-sa_method", "Sensitivity analysis method (track or global)", "", SAMethods, (PetscEnum)sa, (PetscEnum *)&sa, NULL));
264   }
265   PetscOptionsEnd();
266 
267   user.final_time = 0.1;
268   user.max_steps  = 5;
269   user.time_step  = user.final_time / user.max_steps;
270 
271   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
272      Create necessary matrix and vectors for forward solve.
273      Create Jacp matrix for adjoint solve.
274      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
275   PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu1));
276   PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu2));
277   PetscCall(VecSet(user.mu1, 1.25));
278   PetscCall(VecSet(user.mu2, 1.0e2));
279 
280   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
281       For tracking method : create the global sensitivity array to
282       accumulate sensitivity with respect to parameters at each step.
283      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
284   if (sa == SA_TRACK) {
285     PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu1));
286     PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu2));
287   }
288 
289   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.A));
290   PetscCall(MatSetSizes(user.A, PETSC_DECIDE, PETSC_DECIDE, 2, 2));
291   PetscCall(MatSetFromOptions(user.A));
292   PetscCall(MatSetUp(user.A));
293   PetscCall(MatCreateVecs(user.A, &user.U, NULL));
294 
295   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
296       Note that the dimensions of the Jacp matrix depend upon the
297       sensitivity analysis method being used !
298      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
299   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jacp));
300   if (sa == SA_TRACK) PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, 2));
301   if (sa == SA_GLOBAL) PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, user.max_steps * 2));
302   PetscCall(MatSetFromOptions(user.Jacp));
303   PetscCall(MatSetUp(user.Jacp));
304 
305   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
306      Create timestepping solver context
307      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
308   PetscCall(TSCreate(PETSC_COMM_WORLD, &ts));
309   PetscCall(TSSetEquationType(ts, TS_EQ_ODE_EXPLICIT));
310   PetscCall(TSSetType(ts, TSCN));
311 
312   PetscCall(TSSetRHSFunction(ts, NULL, RHSFunction, &user));
313   PetscCall(TSSetRHSJacobian(ts, user.A, user.A, RHSJacobian, &user));
314   if (sa == SA_TRACK) PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_track, &user));
315   if (sa == SA_GLOBAL) PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_global, &user));
316 
317   PetscCall(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_MATCHSTEP));
318   PetscCall(TSSetMaxTime(ts, user.final_time));
319   PetscCall(TSSetTimeStep(ts, user.final_time / user.max_steps));
320 
321   if (monitor) PetscCall(TSMonitorSet(ts, Monitor, &user, NULL));
322   if (sa == SA_TRACK) PetscCall(TSAdjointMonitorSet(ts, AdjointMonitor, &user, NULL));
323 
324   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
325      Set initial conditions
326      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
327   PetscCall(VecGetArray(user.U, &x_ptr));
328   x_ptr[0] = 2.0;
329   x_ptr[1] = -2.0 / 3.0;
330   PetscCall(VecRestoreArray(user.U, &x_ptr));
331 
332   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
333      Save trajectory of solution so that TSAdjointSolve() may be used
334      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
335   PetscCall(TSSetSaveTrajectory(ts));
336 
337   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
338      Set runtime options
339      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
340   PetscCall(TSSetFromOptions(ts));
341 
342   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
343      Execute forward model and print solution.
344      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
345   PetscCall(TSSolve(ts, user.U));
346   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution of forward TS :\n"));
347   PetscCall(VecView(user.U, PETSC_VIEWER_STDOUT_WORLD));
348   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Forward TS solve successful! Adjoint run begins!\n"));
349 
350   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
351      Adjoint model starts here! Create adjoint vectors.
352      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
353   PetscCall(MatCreateVecs(user.A, &user.lambda, NULL));
354   PetscCall(MatCreateVecs(user.Jacp, &user.mup, NULL));
355 
356   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
357      Set initial conditions for the adjoint vector
358      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
359   PetscCall(VecGetArray(user.U, &u_ptr));
360   PetscCall(VecGetArray(user.lambda, &y_ptr));
361   y_ptr[0] = 2 * (u_ptr[0] - 1.5967);
362   y_ptr[1] = 2 * (u_ptr[1] - -(1.02969));
363   PetscCall(VecRestoreArray(user.lambda, &y_ptr));
364   PetscCall(VecRestoreArray(user.U, &y_ptr));
365   PetscCall(VecSet(user.mup, 0));
366 
367   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
368      Set number of cost functions.
369      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
370   PetscCall(TSSetCostGradients(ts, 1, &user.lambda, &user.mup));
371 
372   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
373      The adjoint vector mup has to be reset for each adjoint step when
374      using the tracking method as we want to treat the parameters at each
375      time step one at a time and prevent accumulation of the sensitivities
376      from parameters at previous time steps.
377      This is not necessary for the global method as each time dependent
378      parameter is treated as an independent parameter.
379    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
380   if (sa == SA_TRACK) {
381     for (user.adj_idx = user.max_steps; user.adj_idx > 0; user.adj_idx--) {
382       PetscCall(VecSet(user.mup, 0));
383       PetscCall(TSAdjointSetSteps(ts, 1));
384       PetscCall(TSAdjointSolve(ts));
385     }
386   }
387   if (sa == SA_GLOBAL) PetscCall(TSAdjointSolve(ts));
388 
389   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
390      Display adjoint sensitivities wrt parameters and initial conditions
391      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
392   if (sa == SA_TRACK) {
393     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt  mu1: d[cost]/d[mu1]\n"));
394     PetscCall(VecView(user.sens_mu1, PETSC_VIEWER_STDOUT_WORLD));
395     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt  mu2: d[cost]/d[mu2]\n"));
396     PetscCall(VecView(user.sens_mu2, PETSC_VIEWER_STDOUT_WORLD));
397   }
398 
399   if (sa == SA_GLOBAL) {
400     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt  params: d[cost]/d[p], where p refers to \nthe interlaced vector made by combining mu1,mu2\n"));
401     PetscCall(VecView(user.mup, PETSC_VIEWER_STDOUT_WORLD));
402   }
403 
404   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt initial conditions: d[cost]/d[u(t=0)]\n"));
405   PetscCall(VecView(user.lambda, PETSC_VIEWER_STDOUT_WORLD));
406 
407   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
408      Free work space!
409      All PETSc objects should be destroyed when they are no longer needed.
410      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
411   PetscCall(MatDestroy(&user.A));
412   PetscCall(MatDestroy(&user.Jacp));
413   PetscCall(VecDestroy(&user.U));
414   PetscCall(VecDestroy(&user.lambda));
415   PetscCall(VecDestroy(&user.mup));
416   PetscCall(VecDestroy(&user.mu1));
417   PetscCall(VecDestroy(&user.mu2));
418   if (sa == SA_TRACK) {
419     PetscCall(VecDestroy(&user.sens_mu1));
420     PetscCall(VecDestroy(&user.sens_mu2));
421   }
422   PetscCall(TSDestroy(&ts));
423   PetscCall(PetscFinalize());
424   return 0;
425 }
426 
427 /*TEST
428 
429   test:
430     requires: !complex
431     suffix : track
432     args : -sa_method track
433 
434   test:
435     requires: !complex
436     suffix : global
437     args : -sa_method global
438 
439 TEST*/
440