xref: /petsc/src/ts/tutorials/ex20td.c (revision 750b007cd8d816cecd9de99077bb0a703b4cf61a)
1 static char help[] = "Performs adjoint sensitivity analysis for a van der Pol like \n\
2 equation with time dependent parameters using two approaches :  \n\
3 track  : track only local sensitivities at each adjoint step \n\
4          and accumulate them in a global array \n\
5 global : track parameters at all timesteps together \n\
6 Choose one of the two at runtime by -sa_method {track,global}. \n";
7 
8 /*
9    Simple example to demonstrate TSAdjoint capabilities for time dependent params
10    without integral cost terms using either a tracking or global method.
11 
12    Modify the Van Der Pol Eq to :
13    [u1'] = [mu1(t)*u1]
14    [u2'] = [mu2(t)*((1-u1^2)*u2-u1)]
15    (with initial conditions & params independent)
16 
17    Define uref to be solution with initial conditions (2,-2/3), mu=(1,1e3)
18    - u_ref : (1.5967,-1.02969)
19 
20    Define cost function as cost = 2-norm(u - u_ref);
21 
22    Initialization for the adjoint TS :
23    - dcost/dy|final_time = 2*(u-u_ref)|final_time
24    - dcost/dp|final_time = 0
25 
26    The tracking method only tracks local sensitivity at each time step
27    and accumulates these sensitivities in a global array. Since the structure
28    of the equations being solved at each time step does not change, the jacobian
29    wrt parameters is defined analogous to constant RHSJacobian for a liner
30    TSSolve and the size of the jacP is independent of the number of time
31    steps. Enable this mode of adjoint analysis by -sa_method track.
32 
33    The global method combines the parameters at all timesteps and tracks them
34    together. Thus, the columns of the jacP matrix are filled dependent upon the
35    time step. Also, the dimensions of the jacP matrix now depend upon the number
36    of time steps. Enable this mode of adjoint analysis by -sa_method global.
37 
38    Since the equations here have parameters at predefined time steps, this
39    example should be run with non adaptive time stepping solvers only. This
40    can be ensured by -ts_adapt_type none (which is the default behavior only
41    for certain TS solvers like TSCN. If using an explicit TS like TSRK,
42    please be sure to add the aforementioned option to disable adaptive
43    timestepping.)
44 */
45 
46 /*
47    Include "petscts.h" so that we can use TS solvers.  Note that this file
48    automatically includes:
49      petscsys.h    - base PETSc routines   petscvec.h  - vectors
50      petscmat.h    - matrices
51      petscis.h     - index sets            petscksp.h  - Krylov subspace methods
52      petscviewer.h - viewers               petscpc.h   - preconditioners
53      petscksp.h    - linear solvers        petscsnes.h - nonlinear solvers
54 */
55 #include <petscts.h>
56 
57 extern PetscErrorCode RHSFunction(TS, PetscReal, Vec, Vec, void *);
58 extern PetscErrorCode RHSJacobian(TS, PetscReal, Vec, Mat, Mat, void *);
59 extern PetscErrorCode RHSJacobianP_track(TS, PetscReal, Vec, Mat, void *);
60 extern PetscErrorCode RHSJacobianP_global(TS, PetscReal, Vec, Mat, void *);
61 extern PetscErrorCode Monitor(TS, PetscInt, PetscReal, Vec, void *);
62 extern PetscErrorCode AdjointMonitor(TS, PetscInt, PetscReal, Vec, PetscInt, Vec *, Vec *, void *);
63 
/*
   User-defined application context - contains data needed by the
   application-provided call-back routines.
*/

typedef struct {
  /*------------- Forward solve data structures --------------*/
  PetscInt  max_steps;  /* number of steps to run ts for */
  PetscReal final_time; /* final time to integrate to */
  PetscReal time_step;  /* ts integration time step */
  Vec       mu1;        /* time dependent parameter mu1, one entry per time step */
  Vec       mu2;        /* time dependent parameter mu2, one entry per time step */
  Vec       U;          /* solution vector */
  Mat       A;          /* Jacobian matrix wrt the state U */
  /*------------- Adjoint solve data structures --------------*/
  Mat Jacp;   /* Jacobian matrix wrt the parameters (size depends on -sa_method) */
  Vec lambda; /* adjoint variable: sensitivity of cost wrt initial conditions */
  Vec mup;    /* adjoint variable: sensitivity of cost wrt parameters */

  /*------------- Global accumulation vecs for monitor based tracking --------------*/
  Vec      sens_mu1; /* global sensitivity accumulation for mu1 (tracking method only) */
  Vec      sens_mu2; /* global sensitivity accumulation for mu2 (tracking method only) */
  PetscInt adj_idx;  /* step index the current one-step TSAdjointSolve started from */
} AppCtx;

/* Runtime-selectable sensitivity analysis method (-sa_method {track,global}) */
typedef enum {
  SA_TRACK,
  SA_GLOBAL
} SAMethod;
/* Name table for PetscOptionsEnum: values, enum type name, option prefix, terminator */
static const char *const SAMethods[] = {"TRACK", "GLOBAL", "SAMethod", "SA_", 0};
95 
96 /* ----------------------- Explicit form of the ODE  -------------------- */
97 
98 PetscErrorCode RHSFunction(TS ts, PetscReal t, Vec U, Vec F, void *ctx) {
99   AppCtx            *user = (AppCtx *)ctx;
100   PetscScalar       *f;
101   PetscInt           curr_step;
102   const PetscScalar *u;
103   const PetscScalar *mu1;
104   const PetscScalar *mu2;
105 
106   PetscFunctionBeginUser;
107   PetscCall(TSGetStepNumber(ts, &curr_step));
108   PetscCall(VecGetArrayRead(U, &u));
109   PetscCall(VecGetArrayRead(user->mu1, &mu1));
110   PetscCall(VecGetArrayRead(user->mu2, &mu2));
111   PetscCall(VecGetArray(F, &f));
112   f[0] = mu1[curr_step] * u[1];
113   f[1] = mu2[curr_step] * ((1. - u[0] * u[0]) * u[1] - u[0]);
114   PetscCall(VecRestoreArrayRead(U, &u));
115   PetscCall(VecRestoreArrayRead(user->mu1, &mu1));
116   PetscCall(VecRestoreArrayRead(user->mu2, &mu2));
117   PetscCall(VecRestoreArray(F, &f));
118   PetscFunctionReturn(0);
119 }
120 
121 PetscErrorCode RHSJacobian(TS ts, PetscReal t, Vec U, Mat A, Mat B, void *ctx) {
122   AppCtx            *user     = (AppCtx *)ctx;
123   PetscInt           rowcol[] = {0, 1};
124   PetscScalar        J[2][2];
125   PetscInt           curr_step;
126   const PetscScalar *u;
127   const PetscScalar *mu1;
128   const PetscScalar *mu2;
129 
130   PetscFunctionBeginUser;
131   PetscCall(TSGetStepNumber(ts, &curr_step));
132   PetscCall(VecGetArrayRead(user->mu1, &mu1));
133   PetscCall(VecGetArrayRead(user->mu2, &mu2));
134   PetscCall(VecGetArrayRead(U, &u));
135   J[0][0] = 0;
136   J[1][0] = -mu2[curr_step] * (2.0 * u[1] * u[0] + 1.);
137   J[0][1] = mu1[curr_step];
138   J[1][1] = mu2[curr_step] * (1.0 - u[0] * u[0]);
139   PetscCall(MatSetValues(A, 2, rowcol, 2, rowcol, &J[0][0], INSERT_VALUES));
140   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
141   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
142   PetscCall(VecRestoreArrayRead(U, &u));
143   PetscCall(VecRestoreArrayRead(user->mu1, &mu1));
144   PetscCall(VecRestoreArrayRead(user->mu2, &mu2));
145   PetscFunctionReturn(0);
146 }
147 
148 /* ------------------ Jacobian wrt parameters for tracking method ------------------ */
149 
150 PetscErrorCode RHSJacobianP_track(TS ts, PetscReal t, Vec U, Mat A, void *ctx) {
151   PetscInt           row[] = {0, 1}, col[] = {0, 1};
152   PetscScalar        J[2][2];
153   const PetscScalar *u;
154 
155   PetscFunctionBeginUser;
156   PetscCall(VecGetArrayRead(U, &u));
157   J[0][0] = u[1];
158   J[1][0] = 0;
159   J[0][1] = 0;
160   J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0];
161   PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES));
162   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
163   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
164   PetscCall(VecRestoreArrayRead(U, &u));
165   PetscFunctionReturn(0);
166 }
167 
168 /* ------------------ Jacobian wrt parameters for global method ------------------ */
169 
170 PetscErrorCode RHSJacobianP_global(TS ts, PetscReal t, Vec U, Mat A, void *ctx) {
171   PetscInt           row[] = {0, 1}, col[] = {0, 1};
172   PetscScalar        J[2][2];
173   const PetscScalar *u;
174   PetscInt           curr_step;
175 
176   PetscFunctionBeginUser;
177   PetscCall(TSGetStepNumber(ts, &curr_step));
178   PetscCall(VecGetArrayRead(U, &u));
179   J[0][0] = u[1];
180   J[1][0] = 0;
181   J[0][1] = 0;
182   J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0];
183   col[0]  = (curr_step)*2;
184   col[1]  = (curr_step)*2 + 1;
185   PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES));
186   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
187   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
188   PetscCall(VecRestoreArrayRead(U, &u));
189   PetscFunctionReturn(0);
190 }
191 
192 /* Dump solution to console if called */
/* Forward-solve monitor: dumps the current solution vector to stdout
   each time the TS reports a step (enabled with -monitor). */
PetscErrorCode Monitor(TS ts, PetscInt step, PetscReal t, Vec U, void *ctx) {
  PetscFunctionBeginUser;
  PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution at time %e is \n", (double)t));
  PetscCall(VecView(U, PETSC_VIEWER_STDOUT_WORLD));
  PetscFunctionReturn(0);
}
199 
200 /* Customized adjoint monitor to keep track of local
201    sensitivities by storing them in a global sensitivity array.
202    Note : This routine is only used for the tracking method. */
203 PetscErrorCode AdjointMonitor(TS ts, PetscInt steps, PetscReal time, Vec u, PetscInt numcost, Vec *lambda, Vec *mu, void *ctx) {
204   AppCtx            *user = (AppCtx *)ctx;
205   PetscInt           curr_step;
206   PetscScalar       *sensmu1_glob;
207   PetscScalar       *sensmu2_glob;
208   const PetscScalar *sensmu_loc;
209 
210   PetscFunctionBeginUser;
211   PetscCall(TSGetStepNumber(ts, &curr_step));
212   /* Note that we skip the first call to the monitor in the adjoint
213      solve since the sensitivities are already set (during
214      initialization of adjoint vectors).
215      We also note that each indvidial TSAdjointSolve calls the monitor
216      twice, once at the step it is integrating from and once at the step
217      it integrates to. Only the second call is useful for transferring
218      local sensitivities to the global array. */
219   if (curr_step == user->adj_idx) {
220     PetscFunctionReturn(0);
221   } else {
222     PetscCall(VecGetArrayRead(*mu, &sensmu_loc));
223     PetscCall(VecGetArray(user->sens_mu1, &sensmu1_glob));
224     PetscCall(VecGetArray(user->sens_mu2, &sensmu2_glob));
225     sensmu1_glob[curr_step] = sensmu_loc[0];
226     sensmu2_glob[curr_step] = sensmu_loc[1];
227     PetscCall(VecRestoreArray(user->sens_mu1, &sensmu1_glob));
228     PetscCall(VecRestoreArray(user->sens_mu2, &sensmu2_glob));
229     PetscCall(VecRestoreArrayRead(*mu, &sensmu_loc));
230     PetscFunctionReturn(0);
231   }
232 }
233 
234 int main(int argc, char **argv) {
235   TS           ts;
236   AppCtx       user;
237   PetscScalar *x_ptr, *y_ptr, *u_ptr;
238   PetscMPIInt  size;
239   PetscBool    monitor = PETSC_FALSE;
240   SAMethod     sa      = SA_GLOBAL;
241 
242   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
243      Initialize program
244      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
245   PetscFunctionBeginUser;
246   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
247   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
248   PetscCheck(size == 1, PETSC_COMM_WORLD, PETSC_ERR_WRONG_MPI_SIZE, "This is a uniprocessor example only!");
249 
250   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
251      Set runtime options
252      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
253   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "SA analysis options.", "");
254   {
255     PetscCall(PetscOptionsGetBool(NULL, NULL, "-monitor", &monitor, NULL));
256     PetscCall(PetscOptionsEnum("-sa_method", "Sensitivity analysis method (track or global)", "", SAMethods, (PetscEnum)sa, (PetscEnum *)&sa, NULL));
257   }
258   PetscOptionsEnd();
259 
260   user.final_time = 0.1;
261   user.max_steps  = 5;
262   user.time_step  = user.final_time / user.max_steps;
263 
264   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
265      Create necessary matrix and vectors for forward solve.
266      Create Jacp matrix for adjoint solve.
267      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
268   PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu1));
269   PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu2));
270   PetscCall(VecSet(user.mu1, 1.25));
271   PetscCall(VecSet(user.mu2, 1.0e2));
272 
273   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
274       For tracking method : create the global sensitivity array to
275       accumulate sensitivity with respect to parameters at each step.
276      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
277   if (sa == SA_TRACK) {
278     PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu1));
279     PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu2));
280   }
281 
282   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.A));
283   PetscCall(MatSetSizes(user.A, PETSC_DECIDE, PETSC_DECIDE, 2, 2));
284   PetscCall(MatSetFromOptions(user.A));
285   PetscCall(MatSetUp(user.A));
286   PetscCall(MatCreateVecs(user.A, &user.U, NULL));
287 
288   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
289       Note that the dimensions of the Jacp matrix depend upon the
290       sensitivity analysis method being used !
291      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
292   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jacp));
293   if (sa == SA_TRACK) PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, 2));
294   if (sa == SA_GLOBAL) PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, user.max_steps * 2));
295   PetscCall(MatSetFromOptions(user.Jacp));
296   PetscCall(MatSetUp(user.Jacp));
297 
298   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
299      Create timestepping solver context
300      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
301   PetscCall(TSCreate(PETSC_COMM_WORLD, &ts));
302   PetscCall(TSSetEquationType(ts, TS_EQ_ODE_EXPLICIT));
303   PetscCall(TSSetType(ts, TSCN));
304 
305   PetscCall(TSSetRHSFunction(ts, NULL, RHSFunction, &user));
306   PetscCall(TSSetRHSJacobian(ts, user.A, user.A, RHSJacobian, &user));
307   if (sa == SA_TRACK) PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_track, &user));
308   if (sa == SA_GLOBAL) PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_global, &user));
309 
310   PetscCall(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_MATCHSTEP));
311   PetscCall(TSSetMaxTime(ts, user.final_time));
312   PetscCall(TSSetTimeStep(ts, user.final_time / user.max_steps));
313 
314   if (monitor) PetscCall(TSMonitorSet(ts, Monitor, &user, NULL));
315   if (sa == SA_TRACK) PetscCall(TSAdjointMonitorSet(ts, AdjointMonitor, &user, NULL));
316 
317   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
318      Set initial conditions
319      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
320   PetscCall(VecGetArray(user.U, &x_ptr));
321   x_ptr[0] = 2.0;
322   x_ptr[1] = -2.0 / 3.0;
323   PetscCall(VecRestoreArray(user.U, &x_ptr));
324 
325   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
326      Save trajectory of solution so that TSAdjointSolve() may be used
327      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
328   PetscCall(TSSetSaveTrajectory(ts));
329 
330   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
331      Set runtime options
332      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
333   PetscCall(TSSetFromOptions(ts));
334 
335   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
336      Execute forward model and print solution.
337      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
338   PetscCall(TSSolve(ts, user.U));
339   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution of forward TS :\n"));
340   PetscCall(VecView(user.U, PETSC_VIEWER_STDOUT_WORLD));
341   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Forward TS solve successful! Adjoint run begins!\n"));
342 
343   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
344      Adjoint model starts here! Create adjoint vectors.
345      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
346   PetscCall(MatCreateVecs(user.A, &user.lambda, NULL));
347   PetscCall(MatCreateVecs(user.Jacp, &user.mup, NULL));
348 
349   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
350      Set initial conditions for the adjoint vector
351      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
352   PetscCall(VecGetArray(user.U, &u_ptr));
353   PetscCall(VecGetArray(user.lambda, &y_ptr));
354   y_ptr[0] = 2 * (u_ptr[0] - 1.5967);
355   y_ptr[1] = 2 * (u_ptr[1] - -(1.02969));
356   PetscCall(VecRestoreArray(user.lambda, &y_ptr));
357   PetscCall(VecRestoreArray(user.U, &y_ptr));
358   PetscCall(VecSet(user.mup, 0));
359 
360   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
361      Set number of cost functions.
362      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
363   PetscCall(TSSetCostGradients(ts, 1, &user.lambda, &user.mup));
364 
365   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
366      The adjoint vector mup has to be reset for each adjoint step when
367      using the tracking method as we want to treat the parameters at each
368      time step one at a time and prevent accumulation of the sensitivities
369      from parameters at previous time steps.
370      This is not necessary for the global method as each time dependent
371      parameter is treated as an independent parameter.
372    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
373   if (sa == SA_TRACK) {
374     for (user.adj_idx = user.max_steps; user.adj_idx > 0; user.adj_idx--) {
375       PetscCall(VecSet(user.mup, 0));
376       PetscCall(TSAdjointSetSteps(ts, 1));
377       PetscCall(TSAdjointSolve(ts));
378     }
379   }
380   if (sa == SA_GLOBAL) PetscCall(TSAdjointSolve(ts));
381 
382   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
383      Dispaly adjoint sensitivities wrt parameters and initial conditions
384      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
385   if (sa == SA_TRACK) {
386     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt  mu1: d[cost]/d[mu1]\n"));
387     PetscCall(VecView(user.sens_mu1, PETSC_VIEWER_STDOUT_WORLD));
388     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt  mu2: d[cost]/d[mu2]\n"));
389     PetscCall(VecView(user.sens_mu2, PETSC_VIEWER_STDOUT_WORLD));
390   }
391 
392   if (sa == SA_GLOBAL) {
393     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt  params: d[cost]/d[p], where p refers to \nthe interlaced vector made by combining mu1,mu2\n"));
394     PetscCall(VecView(user.mup, PETSC_VIEWER_STDOUT_WORLD));
395   }
396 
397   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt initial conditions: d[cost]/d[u(t=0)]\n"));
398   PetscCall(VecView(user.lambda, PETSC_VIEWER_STDOUT_WORLD));
399 
400   /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
401      Free work space!
402      All PETSc objects should be destroyed when they are no longer needed.
403      - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
404   PetscCall(MatDestroy(&user.A));
405   PetscCall(MatDestroy(&user.Jacp));
406   PetscCall(VecDestroy(&user.U));
407   PetscCall(VecDestroy(&user.lambda));
408   PetscCall(VecDestroy(&user.mup));
409   PetscCall(VecDestroy(&user.mu1));
410   PetscCall(VecDestroy(&user.mu2));
411   if (sa == SA_TRACK) {
412     PetscCall(VecDestroy(&user.sens_mu1));
413     PetscCall(VecDestroy(&user.sens_mu2));
414   }
415   PetscCall(TSDestroy(&ts));
416   PetscCall(PetscFinalize());
417   return (0);
418 }
419 
420 /*TEST
421 
422   test:
423     requires: !complex
424     suffix : track
425     args : -sa_method track
426 
427   test:
428     requires: !complex
429     suffix : global
430     args : -sa_method global
431 
432 TEST*/
433