xref: /petsc/src/ts/trajectory/impls/memory/trajmemory.c (revision 6780aa0c1f91fa337b07d1829de7f2ab9d39bc1a)
1 #define PRINTWHATTODO
2 #include <petsc/private/tsimpl.h>        /*I "petscts.h"  I*/
3 #include <petscsys.h>
4 
5 extern int wrap_revolve(int* check,int* capo,int* fine,int *snaps_in,int* info,int* rank);
6 
7 typedef struct _StackElement {
8   PetscInt  stepnum;
9   Vec       X;
10   Vec       *Y;
11   PetscReal time;
12   PetscReal timeprev;
13 } *StackElement;
14 
15 typedef struct _RevolveCTX {
16   PetscBool    reverseonestep;
17   PetscInt     snaps_in;
18   PetscInt     stepsleft;
19   PetscInt     check;
20   PetscInt     oldcapo;
21   PetscInt     capo;
22   PetscInt     fine;
23   PetscInt     info;
24 } RevolveCTX;
25 
26 typedef struct _Stack {
27   PetscBool    userevolve;
28   RevolveCTX   *rctx;
29   PetscInt     top;         /* The top of the stack */
30   PetscInt     max_cps; /* The maximum stack size */
31   PetscInt     numY;
32   MPI_Comm     comm;
33   StackElement *stack;      /* The storage */
34 } Stack;
35 
36 static PetscErrorCode StackCreate(MPI_Comm,Stack *,PetscInt,PetscInt);
37 static PetscErrorCode StackDestroy(Stack*);
38 static PetscErrorCode StackPush(Stack*,StackElement);
39 static PetscErrorCode StackPop(Stack*,StackElement*);
40 static PetscErrorCode StackTop(Stack*,StackElement*);
41 
#ifdef PRINTWHATTODO
/*
   printwhattodo - Debug helper that prints a description of the action code returned by
   wrap_revolve(). Codes other than 1..5 and -1 print nothing.
*/
static void printwhattodo(PetscInt whattodo,RevolveCTX *rctx)
{
  if (whattodo == 1) {
    PetscPrintf(PETSC_COMM_WORLD,"Advance from %D to %D.\n",rctx->oldcapo,rctx->capo);
  } else if (whattodo == 2) {
    PetscPrintf(PETSC_COMM_WORLD,"Store in checkpoint number %D\n",rctx->check);
  } else if (whattodo == 3) {
    PetscPrintf(PETSC_COMM_WORLD,"First turn: Initialize adjoints and reverse first step.\n");
  } else if (whattodo == 4) {
    PetscPrintf(PETSC_COMM_WORLD,"Forward and reverse one step.\n");
  } else if (whattodo == 5) {
    PetscPrintf(PETSC_COMM_WORLD,"Restore in checkpoint number %D\n",rctx->check);
  } else if (whattodo == -1) {
    PetscPrintf(PETSC_COMM_WORLD,"Error!");
  }
}
#endif
67 
68 #undef __FUNCT__
69 #define __FUNCT__ "StackCreate"
70 static PetscErrorCode StackCreate(MPI_Comm comm,Stack *s,PetscInt size,PetscInt ny)
71 {
72   PetscErrorCode ierr;
73 
74   PetscFunctionBegin;
75   s->top         = -1;
76   s->max_cps = size;
77   s->comm        = comm;
78   s->numY        = ny;
79 
80   ierr = PetscMalloc1(s->max_cps*sizeof(StackElement),&s->stack);CHKERRQ(ierr);
81   ierr = PetscMemzero(s->stack,s->max_cps*sizeof(StackElement));CHKERRQ(ierr);
82   PetscFunctionReturn(0);
83 }
84 
85 #undef __FUNCT__
86 #define __FUNCT__ "StackDestroy"
87 static PetscErrorCode StackDestroy(Stack *s)
88 {
89   PetscInt       i;
90   PetscErrorCode ierr;
91 
92   PetscFunctionBegin;
93   if (s->top>-1) {
94     for (i=0;i<=s->top;i++) {
95       ierr = VecDestroy(&s->stack[i]->X);CHKERRQ(ierr);
96       ierr = VecDestroyVecs(s->numY,&s->stack[i]->Y);CHKERRQ(ierr);
97       ierr = PetscFree(s->stack[i]);CHKERRQ(ierr);
98     }
99   }
100   ierr = PetscFree(s->stack);CHKERRQ(ierr);
101   if (s->userevolve) {
102     ierr = PetscFree(s->rctx);CHKERRQ(ierr);
103   }
104   ierr = PetscFree(s);CHKERRQ(ierr);
105   PetscFunctionReturn(0);
106 }
107 
108 #undef __FUNCT__
109 #define __FUNCT__ "StackPush"
110 static PetscErrorCode StackPush(Stack *s,StackElement e)
111 {
112   PetscFunctionBegin;
113   if (s->top+1 >= s->max_cps) SETERRQ1(s->comm,PETSC_ERR_MEMC,"Maximum stack size (%D) exceeded",s->max_cps);
114   s->stack[++s->top] = e;
115   PetscFunctionReturn(0);
116 }
117 
118 #undef __FUNCT__
119 #define __FUNCT__ "StackPop"
120 static PetscErrorCode StackPop(Stack *s,StackElement *e)
121 {
122   PetscFunctionBegin;
123   if (s->top == -1) SETERRQ(s->comm,PETSC_ERR_MEMC,"Emptry stack");
124   *e = s->stack[s->top--];
125   PetscFunctionReturn(0);
126 }
127 
128 #undef __FUNCT__
129 #define __FUNCT__ "StackTop"
130 static PetscErrorCode StackTop(Stack *s,StackElement *e)
131 {
132   PetscFunctionBegin;
133   *e = s->stack[s->top];
134   PetscFunctionReturn(0);
135 }
136 
137 #undef __FUNCT__
138 #define __FUNCT__ "TSTrajectorySetMaxCheckpoints_Memory"
139 PetscErrorCode TSTrajectorySetMaxCheckpoints_Memory(TSTrajectory tj,PetscInt max_cps)
140 {
141   Stack      *s = (Stack*)tj->data;
142   PetscFunctionBegin;
143   s->max_cps = max_cps;
144   PetscFunctionReturn(0);
145 }
146 
147 #undef __FUNCT__
148 #define __FUNCT__ "TSTrajectorySetFromOptions_Memory"
149 PetscErrorCode TSTrajectorySetFromOptions_Memory(PetscOptions *PetscOptionsObject,TSTrajectory tj)
150 {
151   Stack     *s = (Stack*)tj->data;
152   PetscErrorCode ierr;
153 
154   PetscFunctionBegin;
155   ierr = PetscOptionsHead(PetscOptionsObject,"Memory based TS trajectory options");CHKERRQ(ierr);
156   {
157     ierr = PetscOptionsInt("-tstrajectory_max_cps","Maximum number of checkpoints","TSTrajectorySetMaxCheckpoints_Memory",s->max_cps,&s->max_cps,NULL);CHKERRQ(ierr);
158   }
159   ierr = PetscOptionsTail();CHKERRQ(ierr);
160   PetscFunctionReturn(0);
161 }
162 
163 #undef __FUNCT__
164 #define __FUNCT__ "TSTrajectorySet_Memory"
165 PetscErrorCode TSTrajectorySet_Memory(TSTrajectory tj,TS ts,PetscInt stepnum,PetscReal time,Vec X)
166 {
167   PetscInt       ns,i;
168   Vec            *Y;
169   PetscReal      timeprev;
170   StackElement   e;
171   Stack          *s = (Stack*)tj->data;
172   RevolveCTX     *rctx;
173   PetscInt       whattodo,rank;
174   PetscErrorCode ierr;
175 
176   PetscFunctionBegin;
177   if (stepnum<s->top) SETERRQ(s->comm,PETSC_ERR_MEMC,"Illegal modification of a non-top stack element");
178 
179   if (s->userevolve) {
180     rctx = s->rctx;
181     if (rctx->reverseonestep) PetscFunctionReturn(0);
182     if (rctx->stepsleft==0) { /* let the controller determine what to do next */
183       rctx->capo = stepnum;
184       rctx->oldcapo = rctx->capo;
185       whattodo = wrap_revolve(&rctx->check,&rctx->capo,&rctx->fine,&rctx->snaps_in,&rctx->info,&rank);
186 #ifdef PRINTWHATTODO
187       printwhattodo(whattodo,rctx);
188 #endif
189       if (whattodo==-1) SETERRQ(s->comm,PETSC_ERR_MEMC,"Error in the controller");
190       if (whattodo==1) {
191         rctx->stepsleft = rctx->capo-rctx->oldcapo-1;
192         PetscFunctionReturn(0); /* do not need to checkpoint */
193       }
194       if (whattodo==3 || whattodo==4) {
195         rctx->reverseonestep = PETSC_TRUE;
196         PetscFunctionReturn(0);
197       }
198       if (whattodo==5) {
199         rctx->oldcapo = rctx->capo;
200         whattodo = wrap_revolve(&rctx->check,&rctx->capo,&rctx->fine,&rctx->snaps_in,&rctx->info,&rank); /* must return 1*/
201 #ifdef PRINTWHATTODO
202         printwhattodo(whattodo,rctx);
203 #endif
204         rctx->stepsleft = rctx->capo-rctx->oldcapo;
205         PetscFunctionReturn(0);
206       }
207       if (whattodo==2) {
208         rctx->oldcapo = rctx->capo;
209         whattodo = wrap_revolve(&rctx->check,&rctx->capo,&rctx->fine,&rctx->snaps_in,&rctx->info,&rank); /* must return 1*/
210 #ifdef PRINTWHATTODO
211         printwhattodo(whattodo,rctx);
212 #endif
213         rctx->stepsleft = rctx->capo-rctx->oldcapo-1;
214       }
215     } else { /* advance s->stepsleft time steps without checkpointing */
216       rctx->stepsleft--;
217       PetscFunctionReturn(0);
218     }
219   }
220 
221   /* checkpoint to memmory */
222   if (stepnum==s->top) { /* overwrite the top checkpoint */
223     ierr = StackTop(s,&e);
224     ierr = VecCopy(X,e->X);CHKERRQ(ierr);
225     ierr = TSGetStages(ts,&ns,&Y);CHKERRQ(ierr);
226     for (i=0;i<ns;i++) {
227       ierr = VecCopy(Y[i],e->Y[i]);CHKERRQ(ierr);
228     }
229     e->stepnum  = stepnum;
230     e->time     = time;
231     ierr        = TSGetPrevTime(ts,&timeprev);CHKERRQ(ierr);
232     e->timeprev = timeprev;
233   } else {
234     ierr = PetscCalloc1(1,&e);
235     ierr = VecDuplicate(X,&e->X);CHKERRQ(ierr);
236     ierr = VecCopy(X,e->X);CHKERRQ(ierr);
237     ierr = TSGetStages(ts,&ns,&Y);CHKERRQ(ierr);
238     ierr = VecDuplicateVecs(Y[0],ns,&e->Y);CHKERRQ(ierr);
239     for (i=0;i<ns;i++) {
240       ierr = VecCopy(Y[i],e->Y[i]);CHKERRQ(ierr);
241     }
242     e->stepnum  = stepnum;
243     e->time     = time;
244     if (stepnum == 0) {
245       e->timeprev = e->time - ts->time_step; /* for consistency */
246     } else {
247       ierr        = TSGetPrevTime(ts,&timeprev);CHKERRQ(ierr);
248       e->timeprev = timeprev;
249     }
250     ierr        = StackPush(s,e);CHKERRQ(ierr);
251   }
252   PetscFunctionReturn(0);
253 }
254 
255 #undef __FUNCT__
256 #define __FUNCT__ "TSTrajectoryGet_Memory"
257 PetscErrorCode TSTrajectoryGet_Memory(TSTrajectory tj,TS ts,PetscInt stepnum,PetscReal *t)
258 {
259   Vec            *Y;
260   PetscInt       nr,i;
261   StackElement   e;
262   Stack          *s = (Stack*)tj->data;
263   RevolveCTX     *rctx;
264   PetscReal      stepsize;
265   PetscInt       whattodo,rank;
266   PetscErrorCode ierr;
267 
268   PetscFunctionBegin;
269   if (s->userevolve) rctx = s->rctx;
270   if (s->userevolve && rctx->reverseonestep) {
271     ierr = TSGetTimeStep(ts,&stepsize);CHKERRQ(ierr);
272     ierr = TSSetTimeStep(ts,-stepsize);CHKERRQ(ierr); /* go backward */
273     rctx->reverseonestep = PETSC_FALSE;
274     PetscFunctionReturn(0);
275   }
276 
277   /* restore a checkpoint */
278   ierr = StackTop(s,&e);CHKERRQ(ierr);
279   ierr = VecCopy(e->X,ts->vec_sol);CHKERRQ(ierr);
280   ierr = TSGetStages(ts,&nr,&Y);CHKERRQ(ierr);
281   for (i=0;i<nr ;i++) {
282     ierr = VecCopy(e->Y[i],Y[i]);CHKERRQ(ierr);
283   }
284   *t = e->time;
285 
286   if (e->stepnum < stepnum) { /* need recomputation */
287     rctx->capo = stepnum;
288     whattodo = wrap_revolve(&rctx->check,&rctx->capo,&rctx->fine,&rctx->snaps_in,&rctx->info,&rank);
289 #ifdef PRINTWHATTODO
290     printwhattodo(whattodo,rctx);
291 #endif
292     ierr = TSSetTimeStep(ts,(*t)-e->timeprev);CHKERRQ(ierr);
293     /* reset ts context */
294     PetscInt steps = ts->steps;
295     ts->steps      = e->stepnum;
296     ts->ptime      = e->time;
297     ts->ptime_prev = e->timeprev;
298     for (i=e->stepnum;i<stepnum;i++) { /* assume fixed step size */
299       ierr = TSTrajectorySet(ts->trajectory,ts,ts->steps,ts->ptime,ts->vec_sol);CHKERRQ(ierr);
300       ierr = TSMonitor(ts,ts->steps,ts->ptime,ts->vec_sol);CHKERRQ(ierr);
301       ierr = TSStep(ts);CHKERRQ(ierr);
302       if (ts->event) {
303         ierr = TSEventMonitor(ts);CHKERRQ(ierr);
304       }
305       if (!ts->steprollback) {
306         ierr = TSPostStep(ts);CHKERRQ(ierr);
307       }
308     }
309     /* reverseonestep must be true after the for loop */
310     ts->steps = steps;
311     ts->total_steps = stepnum;
312     ierr = TSGetTimeStep(ts,&stepsize);CHKERRQ(ierr);
313     ierr = TSSetTimeStep(ts,-stepsize);CHKERRQ(ierr); /* go backward */
314     if (stepnum-e->stepnum==1) {
315       ierr = StackPop(s,&e);CHKERRQ(ierr);
316       ierr = VecDestroy(&e->X);CHKERRQ(ierr);
317       ierr = VecDestroyVecs(s->numY,&e->Y);CHKERRQ(ierr);
318       ierr = PetscFree(e);CHKERRQ(ierr);
319     }
320     rctx->reverseonestep = PETSC_FALSE;
321   } else if (e->stepnum == stepnum) {
322     ierr = TSSetTimeStep(ts,-(*t)+e->timeprev);CHKERRQ(ierr); /* go backward */
323     ierr = StackPop(s,&e);CHKERRQ(ierr);
324     ierr = VecDestroy(&e->X);CHKERRQ(ierr);
325     ierr = VecDestroyVecs(s->numY,&e->Y);CHKERRQ(ierr);
326     ierr = PetscFree(e);CHKERRQ(ierr);
327   } else {
328     SETERRQ2(s->comm,PETSC_ERR_ARG_OUTOFRANGE,"The current step no. is %D, but the step number at top of the stack is %D",stepnum,e->stepnum);
329   }
330 
331   PetscFunctionReturn(0);
332 }
333 
334 #undef __FUNCT__
335 #define __FUNCT__ "TSTrajectoryDestroy_Memory"
336 PETSC_EXTERN PetscErrorCode TSTrajectoryDestroy_Memory(TSTrajectory tj)
337 {
338   Stack          *s = (Stack*)tj->data;
339   PetscErrorCode ierr;
340 
341   PetscFunctionBegin;
342   ierr = StackDestroy(s);CHKERRQ(ierr);
343   PetscFunctionReturn(0);
344 }
345 
346 /*MC
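      TSTRAJECTORYMEMORY - Stores each solution of the ODE/DAE in memory; when there are fewer checkpoints than time steps, stores a subset chosen by revolve and recomputes the remaining steps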
348 
349   Level: intermediate
350 
351 .seealso:  TSTrajectoryCreate(), TS, TSTrajectorySetType()
352 
353 M*/
354 #undef __FUNCT__
355 #define __FUNCT__ "TSTrajectoryCreate_Memory"
356 PETSC_EXTERN PetscErrorCode TSTrajectoryCreate_Memory(TSTrajectory tj,TS ts)
357 {
358   PetscInt       nr,maxsteps;
359   Stack          *s;
360   RevolveCTX     *rctx;
361   PetscErrorCode ierr;
362 
363   PetscFunctionBegin;
364   tj->ops->set            = TSTrajectorySet_Memory;
365   tj->ops->get            = TSTrajectoryGet_Memory;
366   tj->ops->destroy        = TSTrajectoryDestroy_Memory;
367   tj->ops->setfromoptions = TSTrajectorySetFromOptions_Memory;
368 
369   ierr = PetscCalloc1(1,&s);CHKERRQ(ierr);
370   s->max_cps = 3; /* will be provided by users */
371   ierr = TSGetStages(ts,&nr,PETSC_IGNORE);CHKERRQ(ierr);
372 
373   maxsteps = PetscMin(ts->max_steps,(PetscInt)(ceil(ts->max_time/ts->time_step)));
374   if (s->max_cps-1<maxsteps) { /* Need to use revolve */
375     s->userevolve  = PETSC_TRUE;
376     ierr = PetscCalloc1(1,&rctx);CHKERRQ(ierr);
377     s->rctx = rctx;
378     rctx->snaps_in       = s->max_cps; /* for theta methods snaps_in=2*max_cps */
379     rctx->reverseonestep = PETSC_FALSE;
380     rctx->check          = -1;
381     rctx->oldcapo        = 0;
382     rctx->capo           = 0;
383     rctx->fine           = maxsteps;
384     rctx->info           = 2;
385     ierr = StackCreate(PetscObjectComm((PetscObject)ts),s,s->max_cps,nr);CHKERRQ(ierr);
386   } else { /* Enough space for checkpointing all time steps */
387     s->userevolve = PETSC_FALSE;
388     ierr = StackCreate(PetscObjectComm((PetscObject)ts),s,ts->max_steps+1,nr);CHKERRQ(ierr);
389   }
390   tj->data = s;
391   PetscFunctionReturn(0);
392 }
393