xref: /petsc/src/ts/tutorials/ex23fwdadj.c (revision a336c15037c72f93cd561f5a5e11e93175f2efd9)
/* Program description passed to PetscInitialize() as the help message. */
static char help[] = "A toy example for testing forward and adjoint sensitivity analysis of an implicit ODE with a parametrized mass matrix.\n";
2 
3 /*
4   This example solves the simple ODE
5     c x' = b x, x(0) = a,
6   whose analytical solution is x(T)=a*exp(b/c*T), and calculates the derivative of x(T) w.r.t. c (by default) or w.r.t. b (can be enabled with command line option -der 2).
7 
8 */
9 
10 #include <petscts.h>
11 
12 typedef struct _n_User *User;
13 struct _n_User {
14   PetscReal a;
15   PetscReal b;
16   PetscReal c;
17   /* Sensitivity analysis support */
18   PetscInt  steps;
19   PetscReal ftime;
20   Mat       Jac;  /* Jacobian matrix */
21   Mat       Jacp; /* JacobianP matrix */
22   Vec       x;
23   Mat       sp;        /* forward sensitivity variables */
24   Vec       lambda[1]; /* adjoint sensitivity variables */
25   Vec       mup[1];    /* adjoint sensitivity variables */
26   PetscInt  der;
27 };
28 
29 static PetscErrorCode IFunction(TS ts, PetscReal t, Vec X, Vec Xdot, Vec F, PetscCtx ctx)
30 {
31   User               user = (User)ctx;
32   const PetscScalar *x, *xdot;
33   PetscScalar       *f;
34 
35   PetscFunctionBeginUser;
36   PetscCall(VecGetArrayRead(X, &x));
37   PetscCall(VecGetArrayRead(Xdot, &xdot));
38   PetscCall(VecGetArrayWrite(F, &f));
39   f[0] = user->c * xdot[0] - user->b * x[0];
40   PetscCall(VecRestoreArrayRead(X, &x));
41   PetscCall(VecRestoreArrayRead(Xdot, &xdot));
42   PetscCall(VecRestoreArrayWrite(F, &f));
43   PetscFunctionReturn(PETSC_SUCCESS);
44 }
45 
46 static PetscErrorCode IJacobian(TS ts, PetscReal t, Vec X, Vec Xdot, PetscReal a, Mat A, Mat B, PetscCtx ctx)
47 {
48   User               user     = (User)ctx;
49   PetscInt           rowcol[] = {0};
50   PetscScalar        J[1][1];
51   const PetscScalar *x;
52 
53   PetscFunctionBeginUser;
54   PetscCall(VecGetArrayRead(X, &x));
55   J[0][0] = user->c * a - user->b * 1.0;
56   PetscCall(MatSetValues(B, 1, rowcol, 1, rowcol, &J[0][0], INSERT_VALUES));
57   PetscCall(VecRestoreArrayRead(X, &x));
58 
59   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
60   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
61   if (A != B) {
62     PetscCall(MatAssemblyBegin(B, MAT_FINAL_ASSEMBLY));
63     PetscCall(MatAssemblyEnd(B, MAT_FINAL_ASSEMBLY));
64   }
65   PetscFunctionReturn(PETSC_SUCCESS);
66 }
67 
68 static PetscErrorCode IJacobianP(TS ts, PetscReal t, Vec X, Vec Xdot, PetscReal shift, Mat A, PetscCtx ctx)
69 {
70   User               user  = (User)ctx;
71   PetscInt           row[] = {0}, col[] = {0};
72   PetscScalar        J[1][1];
73   const PetscScalar *x, *xdot;
74   PetscReal          dt;
75 
76   PetscFunctionBeginUser;
77   PetscCall(VecGetArrayRead(X, &x));
78   PetscCall(VecGetArrayRead(Xdot, &xdot));
79   PetscCall(TSGetTimeStep(ts, &dt));
80   if (user->der == 1) J[0][0] = xdot[0];
81   if (user->der == 2) J[0][0] = -x[0];
82   PetscCall(MatSetValues(A, 1, row, 1, col, &J[0][0], INSERT_VALUES));
83   PetscCall(VecRestoreArrayRead(X, &x));
84 
85   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
86   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
90 int main(int argc, char **argv)
91 {
92   TS             ts;
93   PetscScalar   *x_ptr;
94   PetscMPIInt    size;
95   struct _n_User user;
96   PetscInt       rows, cols;
97 
98   PetscFunctionBeginUser;
99   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
100 
101   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
102   PetscCheck(size == 1, PETSC_COMM_WORLD, PETSC_ERR_WRONG_MPI_SIZE, "This is a uniprocessor example only!");
103 
104   user.a     = 2.0;
105   user.b     = 4.0;
106   user.c     = 3.0;
107   user.steps = 0;
108   user.ftime = 1.0;
109   user.der   = 1;
110   PetscCall(PetscOptionsGetInt(NULL, NULL, "-der", &user.der, NULL));
111 
112   rows = 1;
113   cols = 1;
114   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jac));
115   PetscCall(MatSetSizes(user.Jac, PETSC_DECIDE, PETSC_DECIDE, 1, 1));
116   PetscCall(MatSetFromOptions(user.Jac));
117   PetscCall(MatSetUp(user.Jac));
118   PetscCall(MatCreateVecs(user.Jac, &user.x, NULL));
119 
120   PetscCall(TSCreate(PETSC_COMM_WORLD, &ts));
121   PetscCall(TSSetType(ts, TSBEULER));
122   PetscCall(TSSetIFunction(ts, NULL, IFunction, &user));
123   PetscCall(TSSetIJacobian(ts, user.Jac, user.Jac, IJacobian, &user));
124   PetscCall(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_MATCHSTEP));
125   PetscCall(TSSetMaxTime(ts, user.ftime));
126 
127   PetscCall(VecGetArrayWrite(user.x, &x_ptr));
128   x_ptr[0] = user.a;
129   PetscCall(VecRestoreArrayWrite(user.x, &x_ptr));
130   PetscCall(TSSetTimeStep(ts, 0.001));
131 
132   /* Set up forward sensitivity */
133   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jacp));
134   PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, rows, cols));
135   PetscCall(MatSetFromOptions(user.Jacp));
136   PetscCall(MatSetUp(user.Jacp));
137   PetscCall(MatCreateDense(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, rows, cols, NULL, &user.sp));
138   PetscCall(MatZeroEntries(user.sp));
139   PetscCall(TSForwardSetSensitivities(ts, cols, user.sp));
140   PetscCall(TSSetIJacobianP(ts, user.Jacp, IJacobianP, &user));
141 
142   PetscCall(TSSetSaveTrajectory(ts));
143   PetscCall(TSSetFromOptions(ts));
144 
145   PetscCall(TSSolve(ts, user.x));
146   PetscCall(TSGetSolveTime(ts, &user.ftime));
147   PetscCall(TSGetStepNumber(ts, &user.steps));
148   PetscCall(VecGetArray(user.x, &x_ptr));
149   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n ode solution %g\n", (double)PetscRealPart(x_ptr[0])));
150   PetscCall(VecRestoreArray(user.x, &x_ptr));
151   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical solution %g\n", (double)(user.a * PetscExpReal(user.b / user.c * user.ftime))));
152 
153   if (user.der == 1) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical derivative w.r.t. c %g\n", (double)(-user.a * user.ftime * user.b / (user.c * user.c) * PetscExpReal(user.b / user.c * user.ftime))));
154   if (user.der == 2) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical derivative w.r.t. b %g\n", (double)(user.a * user.ftime / user.c * PetscExpReal(user.b / user.c * user.ftime))));
155   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n forward sensitivity:\n"));
156   PetscCall(MatView(user.sp, PETSC_VIEWER_STDOUT_WORLD));
157 
158   PetscCall(MatCreateVecs(user.Jac, &user.lambda[0], NULL));
159   /* Set initial conditions for the adjoint integration */
160   PetscCall(VecGetArrayWrite(user.lambda[0], &x_ptr));
161   x_ptr[0] = 1.0;
162   PetscCall(VecRestoreArrayWrite(user.lambda[0], &x_ptr));
163   PetscCall(MatCreateVecs(user.Jacp, &user.mup[0], NULL));
164   PetscCall(VecGetArrayWrite(user.mup[0], &x_ptr));
165   x_ptr[0] = 0.0;
166   PetscCall(VecRestoreArrayWrite(user.mup[0], &x_ptr));
167 
168   PetscCall(TSSetCostGradients(ts, 1, user.lambda, user.mup));
169   PetscCall(TSAdjointSolve(ts));
170 
171   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n adjoint sensitivity:\n"));
172   PetscCall(VecView(user.mup[0], PETSC_VIEWER_STDOUT_WORLD));
173 
174   PetscCall(MatDestroy(&user.Jac));
175   PetscCall(MatDestroy(&user.sp));
176   PetscCall(MatDestroy(&user.Jacp));
177   PetscCall(VecDestroy(&user.x));
178   PetscCall(VecDestroy(&user.lambda[0]));
179   PetscCall(VecDestroy(&user.mup[0]));
180   PetscCall(TSDestroy(&ts));
181 
182   PetscCall(PetscFinalize());
183   return 0;
184 }
185 
186 /*TEST
187 
188     test:
189       args: -ts_type beuler
190 
191     test:
192       suffix: 2
193       args: -ts_type cn
194       output_file: output/ex23fwdadj_1.out
195 
196 TEST*/
197