xref: /petsc/src/ts/tutorials/ex23fwdadj.c (revision daad07d386296cdcbb87925ef5f1432ee4a24ec4)
/* Help text shown with -help; describes the purpose of this tutorial. */
static char help[] = "A toy example for testing forward and adjoint sensitivity analysis of an implicit ODE with a parameterized mass matrix.\n";
2 
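/*
  This example solves the scalar ODE
    c x' = b x,  x(0) = a,
  whose analytical solution is x(T) = a*exp(b*T/c); the coefficient c acts as a
  parameterized (1x1) mass matrix. The derivative of x(T) with respect to c
  (default, -der 1) or with respect to b (-der 2) is computed with both forward
  and adjoint sensitivity analysis and printed next to its analytical value.
*/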
9 
10 #include <petscts.h>
11 
12 typedef struct _n_User *User;
13 struct _n_User {
14   PetscReal a;
15   PetscReal b;
16   PetscReal c;
17   /* Sensitivity analysis support */
18   PetscInt  steps;
19   PetscReal ftime;
20   Mat       Jac;                    /* Jacobian matrix */
21   Mat       Jacp;                   /* JacobianP matrix */
22   Vec       x;
23   Mat       sp;                     /* forward sensitivity variables */
24   Vec       lambda[1];              /* adjoint sensitivity variables */
25   Vec       mup[1];                 /* adjoint sensitivity variables */
26   PetscInt  der;
27 };
28 
29 static PetscErrorCode IFunction(TS ts,PetscReal t,Vec X,Vec Xdot,Vec F,void *ctx)
30 {
31   User              user = (User)ctx;
32   const PetscScalar *x,*xdot;
33   PetscScalar       *f;
34 
35   PetscFunctionBeginUser;
36   PetscCall(VecGetArrayRead(X,&x));
37   PetscCall(VecGetArrayRead(Xdot,&xdot));
38   PetscCall(VecGetArrayWrite(F,&f));
39   f[0] = user->c*xdot[0] - user->b*x[0];
40   PetscCall(VecRestoreArrayRead(X,&x));
41   PetscCall(VecRestoreArrayRead(Xdot,&xdot));
42   PetscCall(VecRestoreArrayWrite(F,&f));
43   PetscFunctionReturn(0);
44 }
45 
46 static PetscErrorCode IJacobian(TS ts,PetscReal t,Vec X,Vec Xdot,PetscReal a,Mat A,Mat B,void *ctx)
47 {
48   User              user     = (User)ctx;
49   PetscInt          rowcol[] = {0};
50   PetscScalar       J[1][1];
51   const PetscScalar *x;
52 
53   PetscFunctionBeginUser;
54   PetscCall(VecGetArrayRead(X,&x));
55   J[0][0] = user->c*a - user->b*1.0;
56   PetscCall(MatSetValues(B,1,rowcol,1,rowcol,&J[0][0],INSERT_VALUES));
57   PetscCall(VecRestoreArrayRead(X,&x));
58 
59   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
60   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
61   if (A != B) {
62     PetscCall(MatAssemblyBegin(B,MAT_FINAL_ASSEMBLY));
63     PetscCall(MatAssemblyEnd(B,MAT_FINAL_ASSEMBLY));
64   }
65   PetscFunctionReturn(0);
66 }
67 
68 static PetscErrorCode IJacobianP(TS ts,PetscReal t,Vec X,Vec Xdot,PetscReal shift,Mat A,void *ctx)
69 {
70   User              user = (User)ctx;
71   PetscInt          row[] = {0},col[]={0};
72   PetscScalar       J[1][1];
73   const PetscScalar *x,*xdot;
74   PetscReal         dt;
75 
76   PetscFunctionBeginUser;
77   PetscCall(VecGetArrayRead(X,&x));
78   PetscCall(VecGetArrayRead(Xdot,&xdot));
79   PetscCall(TSGetTimeStep(ts,&dt));
80   if (user->der == 1) J[0][0] = xdot[0];
81   if (user->der == 2) J[0][0] = -x[0];
82   PetscCall(MatSetValues(A,1,row,1,col,&J[0][0],INSERT_VALUES));
83   PetscCall(VecRestoreArrayRead(X,&x));
84 
85   PetscCall(MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY));
86   PetscCall(MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY));
87   PetscFunctionReturn(0);
88 }
89 
90 int main(int argc,char **argv)
91 {
92   TS             ts;
93   PetscScalar    *x_ptr;
94   PetscMPIInt    size;
95   struct _n_User user;
96   PetscInt       rows,cols;
97 
98   PetscCall(PetscInitialize(&argc,&argv,NULL,help));
99 
100   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD,&size));
101   PetscCheck(size == 1,PETSC_COMM_WORLD,PETSC_ERR_WRONG_MPI_SIZE,"This is a uniprocessor example only!");
102 
103   user.a           = 2.0;
104   user.b           = 4.0;
105   user.c           = 3.0;
106   user.steps       = 0;
107   user.ftime       = 1.0;
108   user.der         = 1;
109   PetscCall(PetscOptionsGetInt(NULL,NULL,"-der",&user.der,NULL));
110 
111   rows = 1;
112   cols = 1;
113   PetscCall(MatCreate(PETSC_COMM_WORLD,&user.Jac));
114   PetscCall(MatSetSizes(user.Jac,PETSC_DECIDE,PETSC_DECIDE,1,1));
115   PetscCall(MatSetFromOptions(user.Jac));
116   PetscCall(MatSetUp(user.Jac));
117   PetscCall(MatCreateVecs(user.Jac,&user.x,NULL));
118 
119   PetscCall(TSCreate(PETSC_COMM_WORLD,&ts));
120   PetscCall(TSSetType(ts,TSBEULER));
121   PetscCall(TSSetIFunction(ts,NULL,IFunction,&user));
122   PetscCall(TSSetIJacobian(ts,user.Jac,user.Jac,IJacobian,&user));
123   PetscCall(TSSetExactFinalTime(ts,TS_EXACTFINALTIME_MATCHSTEP));
124   PetscCall(TSSetMaxTime(ts,user.ftime));
125 
126   PetscCall(VecGetArrayWrite(user.x,&x_ptr));
127   x_ptr[0] = user.a;
128   PetscCall(VecRestoreArrayWrite(user.x,&x_ptr));
129   PetscCall(TSSetTimeStep(ts,0.001));
130 
131   /* Set up forward sensitivity */
132   PetscCall(MatCreate(PETSC_COMM_WORLD,&user.Jacp));
133   PetscCall(MatSetSizes(user.Jacp,PETSC_DECIDE,PETSC_DECIDE,rows,cols));
134   PetscCall(MatSetFromOptions(user.Jacp));
135   PetscCall(MatSetUp(user.Jacp));
136   PetscCall(MatCreateDense(PETSC_COMM_WORLD,PETSC_DECIDE,PETSC_DECIDE,rows,cols,NULL,&user.sp));
137   PetscCall(MatZeroEntries(user.sp));
138   PetscCall(TSForwardSetSensitivities(ts,cols,user.sp));
139   PetscCall(TSSetIJacobianP(ts,user.Jacp,IJacobianP,&user));
140 
141   PetscCall(TSSetSaveTrajectory(ts));
142   PetscCall(TSSetFromOptions(ts));
143 
144   PetscCall(TSSolve(ts,user.x));
145   PetscCall(TSGetSolveTime(ts,&user.ftime));
146   PetscCall(TSGetStepNumber(ts,&user.steps));
147   PetscCall(VecGetArray(user.x,&x_ptr));
148   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n ode solution %g\n",(double)PetscRealPart(x_ptr[0])));
149   PetscCall(VecRestoreArray(user.x,&x_ptr));
150   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n analytical solution %g\n",(double)(user.a*PetscExpReal(user.b/user.c*user.ftime))));
151 
152   if (user.der == 1) {
153     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n analytical derivative w.r.t. c %g\n",(double)(-user.a*user.ftime*user.b/(user.c*user.c)*PetscExpReal(user.b/user.c*user.ftime))));
154   }
155   if (user.der == 2) {
156     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n analytical derivative w.r.t. b %g\n",(double)(user.a*user.ftime/user.c*PetscExpReal(user.b/user.c*user.ftime))));
157   }
158   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n forward sensitivity:\n"));
159   PetscCall(MatView(user.sp,PETSC_VIEWER_STDOUT_WORLD));
160 
161   PetscCall(MatCreateVecs(user.Jac,&user.lambda[0],NULL));
162   /* Set initial conditions for the adjoint integration */
163   PetscCall(VecGetArrayWrite(user.lambda[0],&x_ptr));
164   x_ptr[0] = 1.0;
165   PetscCall(VecRestoreArrayWrite(user.lambda[0],&x_ptr));
166   PetscCall(MatCreateVecs(user.Jacp,&user.mup[0],NULL));
167   PetscCall(VecGetArrayWrite(user.mup[0],&x_ptr));
168   x_ptr[0] = 0.0;
169   PetscCall(VecRestoreArrayWrite(user.mup[0],&x_ptr));
170 
171   PetscCall(TSSetCostGradients(ts,1,user.lambda,user.mup));
172   PetscCall(TSAdjointSolve(ts));
173 
174   PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n adjoint sensitivity:\n"));
175   PetscCall(VecView(user.mup[0],PETSC_VIEWER_STDOUT_WORLD));
176 
177   PetscCall(MatDestroy(&user.Jac));
178   PetscCall(MatDestroy(&user.sp));
179   PetscCall(MatDestroy(&user.Jacp));
180   PetscCall(VecDestroy(&user.x));
181   PetscCall(VecDestroy(&user.lambda[0]));
182   PetscCall(VecDestroy(&user.mup[0]));
183   PetscCall(TSDestroy(&ts));
184 
185   PetscCall(PetscFinalize());
186   return 0;
187 }
188 
189 /*TEST
190 
191     test:
192       args: -ts_type beuler
193 
194     test:
195       suffix: 2
196       args: -ts_type cn
197       output_file: output/ex23fwdadj_1.out
198 
199 TEST*/
200