xref: /petsc/src/ts/tutorials/ex23fwdadj.c (revision 76d901e46dda72c1afe96306c7cb4731c47d4e87)
/* Help text printed by -help; describes what this tutorial exercises. */
static char help[] = "A toy example for testing forward and adjoint sensitivity analysis of an implicit ODE with a parameterized mass matrix.\n";
2 
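/*
  This example solves the simple ODE
    c x' = b x, x(0) = a,
  whose analytical solution is x(T) = a*exp(b/c*T). It computes the derivative of x(T)
  with respect to c (the default, -der 1) or with respect to b (-der 2). The derivative
  is obtained twice, by forward and by adjoint sensitivity analysis, and can be compared
  against the analytical values dx(T)/dc = -a*T*b/c^2*exp(b/c*T) and dx(T)/db = a*T/c*exp(b/c*T).
*/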
9 
10 #include <petscts.h>
11 
12 typedef struct _n_User *User;
13 struct _n_User {
14   PetscReal a;
15   PetscReal b;
16   PetscReal c;
17   /* Sensitivity analysis support */
18   PetscInt  steps;
19   PetscReal ftime;
20   Mat       Jac;                    /* Jacobian matrix */
21   Mat       Jacp;                   /* JacobianP matrix */
22   Vec       x;
23   Mat       sp;                     /* forward sensitivity variables */
24   Vec       lambda[1];              /* adjoint sensitivity variables */
25   Vec       mup[1];                 /* adjoint sensitivity variables */
26   PetscInt  der;
27 };
28 
29 static PetscErrorCode IFunction(TS ts,PetscReal t,Vec X,Vec Xdot,Vec F,void *ctx)
30 {
31   User              user = (User)ctx;
32   const PetscScalar *x,*xdot;
33   PetscScalar       *f;
34 
35   PetscFunctionBeginUser;
36   PetscCall(VecGetArrayRead(X,&x));
37   PetscCall(VecGetArrayRead(Xdot,&xdot));
38   PetscCall(VecGetArrayWrite(F,&f));
39   f[0] = user->c*xdot[0] - user->b*x[0];
40   PetscCall(VecRestoreArrayRead(X,&x));
41   PetscCall(VecRestoreArrayRead(Xdot,&xdot));
42   PetscCall(VecRestoreArrayWrite(F,&f));
43   PetscFunctionReturn(0);
44 }
45 
46 static PetscErrorCode IJacobian(TS ts,PetscReal t,Vec X,Vec Xdot,PetscReal a,Mat A,Mat B,void *ctx)
47 {
48   User              user     = (User)ctx;
49   PetscInt          rowcol[] = {0};
50   PetscScalar       J[1][1];
51   const PetscScalar *x;
52 
53   PetscFunctionBeginUser;
54   PetscCall(VecGetArrayRead(X,&x));
55   J[0][0] = user->c*a - user->b*1.0;
56   PetscCall(MatSetValues(B,1,rowcol,1,rowcol,&J[0][0],INSERT_VALUES));
57   PetscCall(VecRestoreArrayRead(X,&x));
58 
59   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
60   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
61   if (A != B) {
62     PetscCall(MatAssemblyBegin(B,MAT_FINAL_ASSEMBLY));
63     PetscCall(MatAssemblyEnd(B,MAT_FINAL_ASSEMBLY));
64   }
65   PetscFunctionReturn(0);
66 }
67 
68 static PetscErrorCode IJacobianP(TS ts,PetscReal t,Vec X,Vec Xdot,PetscReal shift,Mat A,void *ctx)
69 {
70   User              user = (User)ctx;
71   PetscInt          row[] = {0},col[]={0};
72   PetscScalar       J[1][1];
73   const PetscScalar *x,*xdot;
74   PetscReal         dt;
75 
76   PetscFunctionBeginUser;
77   PetscCall(VecGetArrayRead(X,&x));
78   PetscCall(VecGetArrayRead(Xdot,&xdot));
79   PetscCall(TSGetTimeStep(ts,&dt));
80   if (user->der == 1) J[0][0] = xdot[0];
81   if (user->der == 2) J[0][0] = -x[0];
82   PetscCall(MatSetValues(A,1,row,1,col,&J[0][0],INSERT_VALUES));
83   PetscCall(VecRestoreArrayRead(X,&x));
84 
85   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
86   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
87   PetscFunctionReturn(0);
88 }
89 
90 int main(int argc,char **argv)
91 {
92   TS             ts;
93   PetscScalar    *x_ptr;
94   PetscMPIInt    size;
95   struct _n_User user;
96   PetscInt       rows,cols;
97 
98   PetscFunctionBeginUser;
99   PetscCall(PetscInitialize(&argc,&argv,NULL,help));
100 
101   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD,&size));
102   PetscCheck(size == 1,PETSC_COMM_WORLD,PETSC_ERR_WRONG_MPI_SIZE,"This is a uniprocessor example only!");
103 
104   user.a           = 2.0;
105   user.b           = 4.0;
106   user.c           = 3.0;
107   user.steps       = 0;
108   user.ftime       = 1.0;
109   user.der         = 1;
110   PetscCall(PetscOptionsGetInt(NULL,NULL,"-der",&user.der,NULL));
111 
112   rows = 1;
113   cols = 1;
114   PetscCall(MatCreate(PETSC_COMM_WORLD,&user.Jac));
115   PetscCall(MatSetSizes(user.Jac,PETSC_DECIDE,PETSC_DECIDE,1,1));
116   PetscCall(MatSetFromOptions(user.Jac));
117   PetscCall(MatSetUp(user.Jac));
118   PetscCall(MatCreateVecs(user.Jac,&user.x,NULL));
119 
120   PetscCall(TSCreate(PETSC_COMM_WORLD,&ts));
121   PetscCall(TSSetType(ts,TSBEULER));
122   PetscCall(TSSetIFunction(ts,NULL,IFunction,&user));
123   PetscCall(TSSetIJacobian(ts,user.Jac,user.Jac,IJacobian,&user));
124   PetscCall(TSSetExactFinalTime(ts,TS_EXACTFINALTIME_MATCHSTEP));
125   PetscCall(TSSetMaxTime(ts,user.ftime));
126 
127   PetscCall(VecGetArrayWrite(user.x,&x_ptr));
128   x_ptr[0] = user.a;
129   PetscCall(VecRestoreArrayWrite(user.x,&x_ptr));
130   PetscCall(TSSetTimeStep(ts,0.001));
131 
132   /* Set up forward sensitivity */
133   PetscCall(MatCreate(PETSC_COMM_WORLD,&user.Jacp));
134   PetscCall(MatSetSizes(user.Jacp,PETSC_DECIDE,PETSC_DECIDE,rows,cols));
135   PetscCall(MatSetFromOptions(user.Jacp));
136   PetscCall(MatSetUp(user.Jacp));
137   PetscCall(MatCreateDense(PETSC_COMM_WORLD,PETSC_DECIDE,PETSC_DECIDE,rows,cols,NULL,&user.sp));
138   PetscCall(MatZeroEntries(user.sp));
139   PetscCall(TSForwardSetSensitivities(ts,cols,user.sp));
140   PetscCall(TSSetIJacobianP(ts,user.Jacp,IJacobianP,&user));
141 
142   PetscCall(TSSetSaveTrajectory(ts));
143   PetscCall(TSSetFromOptions(ts));
144 
145   PetscCall(TSSolve(ts,user.x));
146   PetscCall(TSGetSolveTime(ts,&user.ftime));
147   PetscCall(TSGetStepNumber(ts,&user.steps));
148   PetscCall(VecGetArray(user.x,&x_ptr));
149   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n ode solution %g\n",(double)PetscRealPart(x_ptr[0])));
150   PetscCall(VecRestoreArray(user.x,&x_ptr));
151   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n analytical solution %g\n",(double)(user.a*PetscExpReal(user.b/user.c*user.ftime))));
152 
153   if (user.der == 1) {
154     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n analytical derivative w.r.t. c %g\n",(double)(-user.a*user.ftime*user.b/(user.c*user.c)*PetscExpReal(user.b/user.c*user.ftime))));
155   }
156   if (user.der == 2) {
157     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n analytical derivative w.r.t. b %g\n",(double)(user.a*user.ftime/user.c*PetscExpReal(user.b/user.c*user.ftime))));
158   }
159   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n forward sensitivity:\n"));
160   PetscCall(MatView(user.sp,PETSC_VIEWER_STDOUT_WORLD));
161 
162   PetscCall(MatCreateVecs(user.Jac,&user.lambda[0],NULL));
163   /* Set initial conditions for the adjoint integration */
164   PetscCall(VecGetArrayWrite(user.lambda[0],&x_ptr));
165   x_ptr[0] = 1.0;
166   PetscCall(VecRestoreArrayWrite(user.lambda[0],&x_ptr));
167   PetscCall(MatCreateVecs(user.Jacp,&user.mup[0],NULL));
168   PetscCall(VecGetArrayWrite(user.mup[0],&x_ptr));
169   x_ptr[0] = 0.0;
170   PetscCall(VecRestoreArrayWrite(user.mup[0],&x_ptr));
171 
172   PetscCall(TSSetCostGradients(ts,1,user.lambda,user.mup));
173   PetscCall(TSAdjointSolve(ts));
174 
175   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n adjoint sensitivity:\n"));
176   PetscCall(VecView(user.mup[0],PETSC_VIEWER_STDOUT_WORLD));
177 
178   PetscCall(MatDestroy(&user.Jac));
179   PetscCall(MatDestroy(&user.sp));
180   PetscCall(MatDestroy(&user.Jacp));
181   PetscCall(VecDestroy(&user.x));
182   PetscCall(VecDestroy(&user.lambda[0]));
183   PetscCall(VecDestroy(&user.mup[0]));
184   PetscCall(TSDestroy(&ts));
185 
186   PetscCall(PetscFinalize());
187   return 0;
188 }
189 
190 /*TEST
191 
192     test:
193       args: -ts_type beuler
194 
195     test:
196       suffix: 2
197       args: -ts_type cn
198       output_file: output/ex23fwdadj_1.out
199 
200 TEST*/
201