xref: /petsc/src/ts/tutorials/ex23fwdadj.c (revision 030f984af8d8bb4c203755d35bded3c05b3d83ce)
/* Help text handed to PetscInitialize(); shown to the user when the program is run with -help. */
static char help[] = "A toy example for testing forward and adjoint sensitivity analysis of an implicit ODE with a parameterized mass matrix.\n";
2 
3 /*
4   This example solves the simple ODE
5     c x' = b x, x(0) = a,
6   whose analytical solution is x(T)=a*exp(b/c*T), and calculates the derivative of x(T) w.r.t. c (by default) or w.r.t. b (can be enabled with command line option -der 2).
7 
8 */
9 
10 #include <petscts.h>
11 
12 typedef struct _n_User *User;
13 struct _n_User {
14   PetscReal a;
15   PetscReal b;
16   PetscReal c;
17   /* Sensitivity analysis support */
18   PetscInt  steps;
19   PetscReal ftime;
20   Mat       Jac;                    /* Jacobian matrix */
21   Mat       Jacp;                   /* JacobianP matrix */
22   Vec       x;
23   Mat       sp;                     /* forward sensitivity variables */
24   Vec       lambda[1];              /* adjoint sensitivity variables */
25   Vec       mup[1];                 /* adjoint sensitivity variables */
26   PetscInt  der;
27 };
28 
29 static PetscErrorCode IFunction(TS ts,PetscReal t,Vec X,Vec Xdot,Vec F,void *ctx)
30 {
31   PetscErrorCode    ierr;
32   User              user = (User)ctx;
33   const PetscScalar *x,*xdot;
34   PetscScalar       *f;
35 
36   PetscFunctionBeginUser;
37   ierr = VecGetArrayRead(X,&x);CHKERRQ(ierr);
38   ierr = VecGetArrayRead(Xdot,&xdot);CHKERRQ(ierr);
39   ierr = VecGetArrayWrite(F,&f);CHKERRQ(ierr);
40   f[0] = user->c*xdot[0] - user->b*x[0];
41   ierr = VecRestoreArrayRead(X,&x);CHKERRQ(ierr);
42   ierr = VecRestoreArrayRead(Xdot,&xdot);CHKERRQ(ierr);
43   ierr = VecRestoreArrayWrite(F,&f);CHKERRQ(ierr);
44   PetscFunctionReturn(0);
45 }
46 
47 static PetscErrorCode IJacobian(TS ts,PetscReal t,Vec X,Vec Xdot,PetscReal a,Mat A,Mat B,void *ctx)
48 {
49   PetscErrorCode    ierr;
50   User              user     = (User)ctx;
51   PetscInt          rowcol[] = {0};
52   PetscScalar       J[1][1];
53   const PetscScalar *x;
54 
55   PetscFunctionBeginUser;
56   ierr    = VecGetArrayRead(X,&x);CHKERRQ(ierr);
57   J[0][0] = user->c*a - user->b*1.0;
58   ierr    = MatSetValues(B,1,rowcol,1,rowcol,&J[0][0],INSERT_VALUES);CHKERRQ(ierr);
59   ierr    = VecRestoreArrayRead(X,&x);CHKERRQ(ierr);
60 
61   ierr = MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
62   ierr = MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
63   if (A != B) {
64     ierr = MatAssemblyBegin(B,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
65     ierr = MatAssemblyEnd(B,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
66   }
67   PetscFunctionReturn(0);
68 }
69 
70 static PetscErrorCode IJacobianP(TS ts,PetscReal t,Vec X,Vec Xdot,PetscReal shift,Mat A,void *ctx)
71 {
72   User              user = (User)ctx;
73   PetscInt          row[] = {0},col[]={0};
74   PetscScalar       J[1][1];
75   const PetscScalar *x,*xdot;
76   PetscReal         dt;
77   PetscErrorCode    ierr;
78 
79   PetscFunctionBeginUser;
80   ierr    = VecGetArrayRead(X,&x);CHKERRQ(ierr);
81   ierr    = VecGetArrayRead(Xdot,&xdot);CHKERRQ(ierr);
82   ierr    = TSGetTimeStep(ts,&dt);CHKERRQ(ierr);
83   if (user->der == 1) J[0][0] = xdot[0];
84   if (user->der == 2) J[0][0] = -x[0];
85   ierr    = MatSetValues(A,1,row,1,col,&J[0][0],INSERT_VALUES);CHKERRQ(ierr);
86   ierr    = VecRestoreArrayRead(X,&x);CHKERRQ(ierr);
87 
88   ierr = MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
89   ierr = MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
90   PetscFunctionReturn(0);
91 }
92 
93 int main(int argc,char **argv)
94 {
95   TS             ts;
96   PetscScalar    *x_ptr;
97   PetscMPIInt    size;
98   struct _n_User user;
99   PetscInt       rows,cols;
100   PetscErrorCode ierr;
101 
102   ierr = PetscInitialize(&argc,&argv,NULL,help);if (ierr) return ierr;
103 
104   ierr = MPI_Comm_size(PETSC_COMM_WORLD,&size);CHKERRMPI(ierr);
105   if (size != 1) SETERRQ(PETSC_COMM_WORLD,PETSC_ERR_WRONG_MPI_SIZE,"This is a uniprocessor example only!");
106 
107   user.a           = 2.0;
108   user.b           = 4.0;
109   user.c           = 3.0;
110   user.steps       = 0;
111   user.ftime       = 1.0;
112   user.der         = 1;
113   ierr = PetscOptionsGetInt(NULL,NULL,"-der",&user.der,NULL);CHKERRQ(ierr);
114 
115   rows = 1;
116   cols = 1;
117   ierr = MatCreate(PETSC_COMM_WORLD,&user.Jac);CHKERRQ(ierr);
118   ierr = MatSetSizes(user.Jac,PETSC_DECIDE,PETSC_DECIDE,1,1);CHKERRQ(ierr);
119   ierr = MatSetFromOptions(user.Jac);CHKERRQ(ierr);
120   ierr = MatSetUp(user.Jac);CHKERRQ(ierr);
121   ierr = MatCreateVecs(user.Jac,&user.x,NULL);CHKERRQ(ierr);
122 
123   ierr = TSCreate(PETSC_COMM_WORLD,&ts);CHKERRQ(ierr);
124   ierr = TSSetType(ts,TSBEULER);CHKERRQ(ierr);
125   ierr = TSSetIFunction(ts,NULL,IFunction,&user);CHKERRQ(ierr);
126   ierr = TSSetIJacobian(ts,user.Jac,user.Jac,IJacobian,&user);CHKERRQ(ierr);
127   ierr = TSSetExactFinalTime(ts,TS_EXACTFINALTIME_MATCHSTEP);CHKERRQ(ierr);
128   ierr = TSSetMaxTime(ts,user.ftime);CHKERRQ(ierr);
129 
130   ierr = VecGetArrayWrite(user.x,&x_ptr);CHKERRQ(ierr);
131   x_ptr[0] = user.a;
132   ierr = VecRestoreArrayWrite(user.x,&x_ptr);CHKERRQ(ierr);
133   ierr = TSSetTimeStep(ts,0.001);CHKERRQ(ierr);
134 
135   /* Set up forward sensitivity */
136   ierr = MatCreate(PETSC_COMM_WORLD,&user.Jacp);CHKERRQ(ierr);
137   ierr = MatSetSizes(user.Jacp,PETSC_DECIDE,PETSC_DECIDE,rows,cols);CHKERRQ(ierr);
138   ierr = MatSetFromOptions(user.Jacp);CHKERRQ(ierr);
139   ierr = MatSetUp(user.Jacp);CHKERRQ(ierr);
140   ierr = MatCreateDense(PETSC_COMM_WORLD,PETSC_DECIDE,PETSC_DECIDE,rows,cols,NULL,&user.sp);CHKERRQ(ierr);
141   ierr = MatZeroEntries(user.sp);CHKERRQ(ierr);
142   ierr = TSForwardSetSensitivities(ts,cols,user.sp);CHKERRQ(ierr);
143   ierr = TSSetIJacobianP(ts,user.Jacp,IJacobianP,&user);CHKERRQ(ierr);
144 
145   ierr = TSSetSaveTrajectory(ts);CHKERRQ(ierr);
146   ierr = TSSetFromOptions(ts);CHKERRQ(ierr);
147 
148   ierr = TSSolve(ts,user.x);CHKERRQ(ierr);
149   ierr = TSGetSolveTime(ts,&user.ftime);CHKERRQ(ierr);
150   ierr = TSGetStepNumber(ts,&user.steps);CHKERRQ(ierr);
151   ierr = VecGetArray(user.x,&x_ptr);CHKERRQ(ierr);
152   ierr = PetscPrintf(PETSC_COMM_WORLD,"\n ode solution %g\n",(double)PetscRealPart(x_ptr[0]));CHKERRQ(ierr);
153   ierr = VecRestoreArray(user.x,&x_ptr);CHKERRQ(ierr);
154   ierr = PetscPrintf(PETSC_COMM_WORLD,"\n analytical solution %g\n",(double)user.a*PetscExpReal(user.b/user.c*user.ftime));CHKERRQ(ierr);
155 
156   if (user.der == 1) {
157     ierr = PetscPrintf(PETSC_COMM_WORLD,"\n analytical derivative w.r.t. c %g\n",(double)-user.a*user.ftime*user.b/(user.c*user.c)*PetscExpReal(user.b/user.c*user.ftime));CHKERRQ(ierr);
158   }
159   if (user.der == 2) {
160     ierr = PetscPrintf(PETSC_COMM_WORLD,"\n analytical derivative w.r.t. b %g\n",user.a*user.ftime/user.c*PetscExpReal(user.b/user.c*user.ftime));CHKERRQ(ierr);
161   }
162   ierr = PetscPrintf(PETSC_COMM_WORLD,"\n forward sensitivity:\n");CHKERRQ(ierr);
163   ierr = MatView(user.sp,PETSC_VIEWER_STDOUT_WORLD);CHKERRQ(ierr);
164 
165   ierr = MatCreateVecs(user.Jac,&user.lambda[0],NULL);CHKERRQ(ierr);
166   /* Set initial conditions for the adjoint integration */
167   ierr = VecGetArrayWrite(user.lambda[0],&x_ptr);CHKERRQ(ierr);
168   x_ptr[0] = 1.0;
169   ierr = VecRestoreArrayWrite(user.lambda[0],&x_ptr);CHKERRQ(ierr);
170   ierr = MatCreateVecs(user.Jacp,&user.mup[0],NULL);CHKERRQ(ierr);
171   ierr = VecGetArrayWrite(user.mup[0],&x_ptr);CHKERRQ(ierr);
172   x_ptr[0] = 0.0;
173   ierr = VecRestoreArrayWrite(user.mup[0],&x_ptr);CHKERRQ(ierr);
174 
175   ierr = TSSetCostGradients(ts,1,user.lambda,user.mup);CHKERRQ(ierr);
176   ierr = TSAdjointSolve(ts);CHKERRQ(ierr);
177 
178   ierr = PetscPrintf(PETSC_COMM_WORLD,"\n adjoint sensitivity:\n");CHKERRQ(ierr);
179   ierr = VecView(user.mup[0],PETSC_VIEWER_STDOUT_WORLD);CHKERRQ(ierr);
180 
181   ierr = MatDestroy(&user.Jac);CHKERRQ(ierr);
182   ierr = MatDestroy(&user.sp);CHKERRQ(ierr);
183   ierr = MatDestroy(&user.Jacp);CHKERRQ(ierr);
184   ierr = VecDestroy(&user.x);CHKERRQ(ierr);
185   ierr = VecDestroy(&user.lambda[0]);CHKERRQ(ierr);
186   ierr = VecDestroy(&user.mup[0]);CHKERRQ(ierr);
187   ierr = TSDestroy(&ts);CHKERRQ(ierr);
188 
189   ierr = PetscFinalize();
190   return(ierr);
191 }
192 
193 /*TEST
194 
195     test:
196       args: -ts_type beuler
197 
198     test:
199       suffix: 2
200       args: -ts_type cn
201       output_file: output/ex23fwdadj_1.out
202 
203 TEST*/
204