xref: /petsc/src/ts/trajectory/impls/memory/trajmemory.c (revision dd6eb164c85a5ce51aadd4dcb0bbe037d7ed934a)
/* Define PRINTWHATTODO (uncomment the line below or compile with -DPRINTWHATTODO) to print every
   action requested by the revolve checkpointing controller. The trace goes through PetscPrintf()
   on PETSC_COMM_WORLD at every controller call, so it is a debugging aid and is off by default. */
/* #define PRINTWHATTODO */
2 #include <petsc/private/tsimpl.h>        /*I "petscts.h"  I*/
3 #include <petscsys.h>
4 
/*
   wrap_revolve - entry point to the revolve checkpointing controller (defined outside this file).

   Every argument is passed by pointer as int (not PetscInt). The return value is the next action;
   the codes handled in this file are 1 (advance), 2 (store a checkpoint), 3 (first turn),
   4 (forward and reverse one step), 5 (restore a checkpoint) and -1 (error).
*/
extern int wrap_revolve(int *check,int *capo,int *fine,int *snaps_in,int *info,int *rank);
6 
7 typedef struct _StackElement {
8   PetscInt  stepnum;
9   Vec       X;
10   Vec       *Y;
11   PetscReal time;
12   PetscReal timeprev;
13 } *StackElement;
14 
15 typedef struct _Stack {
16    PetscBool    usecontroller;
17    PetscBool    reverseonestep;
18    PetscInt     stepsleft;
19    PetscInt     check;
20    PetscInt     oldcapo;
21    PetscInt     capo;
22    PetscInt     fine;
23    PetscInt     info;
24    PetscInt     top;         /* The top of the stack */
25    PetscInt     maxelements; /* The maximum stack size */
26    PetscInt     numY;
27    MPI_Comm     comm;
28    StackElement *stack;      /* The storage */
29  } Stack;
30 
31 static PetscErrorCode StackCreate(MPI_Comm,Stack *,PetscInt,PetscInt);
32 static PetscErrorCode StackDestroy(Stack*);
33 static PetscErrorCode StackPush(Stack*,StackElement);
34 static PetscErrorCode StackPop(Stack*,StackElement*);
35 static PetscErrorCode StackTop(Stack*,StackElement*);
36 
#ifdef PRINTWHATTODO
/*
   printwhattodo - Debugging aid: prints on PETSC_COMM_WORLD a description of an action code
   returned by wrap_revolve(), using the controller state stored in s.
   Codes other than 1..5 and -1 print nothing.
*/
static void printwhattodo(PetscInt whattodo,Stack *s)
{
  switch (whattodo) {
    case 1:
      PetscPrintf(PETSC_COMM_WORLD,"Advance from %D to %D.\n",s->oldcapo,s->capo);
      break;
    case 2:
      PetscPrintf(PETSC_COMM_WORLD,"Store in checkpoint number %D\n",s->check);
      break;
    case 3:
      PetscPrintf(PETSC_COMM_WORLD,"First turn: Initialize adjoints and reverse first step.\n");
      break;
    case 4:
      PetscPrintf(PETSC_COMM_WORLD,"Forward and reverse one step.\n");
      break;
    case 5:
      PetscPrintf(PETSC_COMM_WORLD,"Restore in checkpoint number %D\n",s->check);
      break;
    case -1:
      /* terminate the line so following output does not run onto it */
      PetscPrintf(PETSC_COMM_WORLD,"Error!\n");
      break;
    default:
      break;
  }
}
#endif
62 
63 #undef __FUNCT__
64 #define __FUNCT__ "StackCreate"
65 static PetscErrorCode StackCreate(MPI_Comm comm,Stack *s,PetscInt size,PetscInt ny)
66 {
67   PetscErrorCode ierr;
68 
69   PetscFunctionBegin;
70   s->top         = -1;
71   s->maxelements = size;
72   s->comm        = comm;
73   s->numY        = ny;
74 
75   ierr = PetscMalloc1(s->maxelements*sizeof(StackElement),&s->stack);CHKERRQ(ierr);
76   ierr = PetscMemzero(s->stack,s->maxelements*sizeof(StackElement));CHKERRQ(ierr);
77   PetscFunctionReturn(0);
78 }
79 
80 #undef __FUNCT__
81 #define __FUNCT__ "StackDestroy"
82 static PetscErrorCode StackDestroy(Stack *s)
83 {
84   PetscInt       i;
85   PetscErrorCode ierr;
86 
87   PetscFunctionBegin;
88   if (s->top>-1) {
89     for (i=0;i<=s->top;i++) {
90       ierr = VecDestroy(&s->stack[i]->X);CHKERRQ(ierr);
91       ierr = VecDestroyVecs(s->numY,&s->stack[i]->Y);CHKERRQ(ierr);
92       ierr = PetscFree(s->stack[i]);CHKERRQ(ierr);
93     }
94   }
95   ierr = PetscFree(s->stack);CHKERRQ(ierr);
96   ierr = PetscFree(s);CHKERRQ(ierr);
97   PetscFunctionReturn(0);
98 }
99 
100 #undef __FUNCT__
101 #define __FUNCT__ "StackPush"
102 static PetscErrorCode StackPush(Stack *s,StackElement e)
103 {
104   PetscFunctionBegin;
105   if (s->top+1 >= s->maxelements) SETERRQ1(s->comm,PETSC_ERR_MEMC,"Maximum stack size (%D) exceeded",s->maxelements);
106   s->stack[++s->top] = e;
107   PetscFunctionReturn(0);
108 }
109 
110 #undef __FUNCT__
111 #define __FUNCT__ "StackPop"
112 static PetscErrorCode StackPop(Stack *s,StackElement *e)
113 {
114   PetscFunctionBegin;
115   if (s->top == -1) SETERRQ(s->comm,PETSC_ERR_MEMC,"Emptry stack");
116   *e = s->stack[s->top--];
117   PetscFunctionReturn(0);
118 }
119 
120 #undef __FUNCT__
121 #define __FUNCT__ "StackTop"
122 static PetscErrorCode StackTop(Stack *s,StackElement *e)
123 {
124   PetscFunctionBegin;
125   *e = s->stack[s->top];
126   PetscFunctionReturn(0);
127 }
128 
129 #undef __FUNCT__
130 #define __FUNCT__ "TSTrajectorySet_Memory"
131 PetscErrorCode TSTrajectorySet_Memory(TSTrajectory tj,TS ts,PetscInt stepnum,PetscReal time,Vec X)
132 {
133   PetscInt       ns,i;
134   Vec            *Y;
135   PetscReal      timeprev;
136   StackElement   e;
137   Stack          *s = (Stack*)tj->data;
138   PetscInt       whattodo,rank;
139   PetscErrorCode ierr;
140 
141   PetscFunctionBegin;
142   if (stepnum<s->top) SETERRQ(s->comm,PETSC_ERR_MEMC,"Illegal modification of a non-top stack element");
143 
144   if (s->usecontroller) {
145     if (s->reverseonestep) PetscFunctionReturn(0);
146     if (s->stepsleft==0) { /* let the controller determine what to do next */
147       s->capo = stepnum;
148       s->oldcapo = s->capo;
149       whattodo = wrap_revolve(&s->check,&s->capo,&s->fine,&s->maxelements,&s->info,&rank);
150 #ifdef PRINTWHATTODO
151       printwhattodo(whattodo,s);
152 #endif
153       if (whattodo==-1) SETERRQ(s->comm,PETSC_ERR_MEMC,"Error in the controller");
154       if (whattodo==1) {
155         s->stepsleft = s->capo-s->oldcapo-1;
156         PetscFunctionReturn(0); /* do not need to checkpoint */
157       }
158       if (whattodo==3 || whattodo==4) {
159         s->reverseonestep = PETSC_TRUE;
160         PetscFunctionReturn(0);
161       }
162       if (whattodo==5) {
163         s->oldcapo = s->capo;
164         whattodo = wrap_revolve(&s->check,&s->capo,&s->fine,&s->maxelements,&s->info,&rank); /* must return 1*/
165 #ifdef PRINTWHATTODO
166         printwhattodo(whattodo,s);
167 #endif
168         s->stepsleft = s->capo-s->oldcapo;
169         PetscFunctionReturn(0);
170       }
171       if (whattodo==2) {
172         s->oldcapo = s->capo;
173         whattodo = wrap_revolve(&s->check,&s->capo,&s->fine,&s->maxelements,&s->info,&rank); /* must return 1*/
174 #ifdef PRINTWHATTODO
175         printwhattodo(whattodo,s);
176 #endif
177         s->stepsleft = s->capo-s->oldcapo-1;
178       }
179     } else { /* advance s->stepsleft time steps without checkpointing */
180       s->stepsleft--;
181       PetscFunctionReturn(0);
182     }
183   }
184 
185   /* checkpoint to memmory */
186   if (stepnum==s->top) { /* overwrite the top checkpoint */
187     ierr = StackTop(s,&e);
188     ierr = VecCopy(X,e->X);CHKERRQ(ierr);
189     ierr = TSGetStages(ts,&ns,&Y);CHKERRQ(ierr);
190     for (i=0;i<ns;i++) {
191       ierr = VecCopy(Y[i],e->Y[i]);CHKERRQ(ierr);
192     }
193     e->stepnum  = stepnum;
194     e->time     = time;
195     ierr        = TSGetPrevTime(ts,&timeprev);CHKERRQ(ierr);
196     e->timeprev = timeprev;
197   } else {
198     ierr = PetscCalloc1(1,&e);
199     ierr = VecDuplicate(X,&e->X);CHKERRQ(ierr);
200     ierr = VecCopy(X,e->X);CHKERRQ(ierr);
201     ierr = TSGetStages(ts,&ns,&Y);CHKERRQ(ierr);
202     ierr = VecDuplicateVecs(Y[0],ns,&e->Y);CHKERRQ(ierr);
203     for (i=0;i<ns;i++) {
204       ierr = VecCopy(Y[i],e->Y[i]);CHKERRQ(ierr);
205     }
206     e->stepnum  = stepnum;
207     e->time     = time;
208     if (stepnum == 0) {
209       e->timeprev = e->time - ts->time_step; /* for consistency */
210     } else {
211       ierr        = TSGetPrevTime(ts,&timeprev);CHKERRQ(ierr);
212       e->timeprev = timeprev;
213     }
214     ierr        = StackPush(s,e);CHKERRQ(ierr);
215   }
216   PetscFunctionReturn(0);
217 }
218 
219 #undef __FUNCT__
220 #define __FUNCT__ "TSTrajectoryGet_Memory"
221 PetscErrorCode TSTrajectoryGet_Memory(TSTrajectory tj,TS ts,PetscInt stepnum,PetscReal *t)
222 {
223   Vec            *Y;
224   PetscInt       nr,i;
225   StackElement   e;
226   Stack          *s = (Stack*)tj->data;
227   PetscReal      stepsize;
228   PetscInt       whattodo,rank;
229   PetscErrorCode ierr;
230 
231   PetscFunctionBegin;
232   if (s->usecontroller && s->reverseonestep) {
233     ierr = TSGetTimeStep(ts,&stepsize);CHKERRQ(ierr);
234     ierr = TSSetTimeStep(ts,-stepsize);CHKERRQ(ierr); /* go backward */
235     s->reverseonestep = PETSC_FALSE;
236     PetscFunctionReturn(0);
237   }
238 
239   /* restore a checkpoint */
240   ierr = StackTop(s,&e);CHKERRQ(ierr);
241   ierr = VecCopy(e->X,ts->vec_sol);CHKERRQ(ierr);
242   ierr = TSGetStages(ts,&nr,&Y);CHKERRQ(ierr);
243   for (i=0;i<nr ;i++) {
244     ierr = VecCopy(e->Y[i],Y[i]);CHKERRQ(ierr);
245   }
246   *t = e->time;
247 
248   if (e->stepnum < stepnum) { /* need recomputation */
249     s->capo = stepnum;
250     whattodo = wrap_revolve(&s->check,&s->capo,&s->fine,&s->maxelements,&s->info,&rank);
251 #ifdef PRINTWHATTODO
252     printwhattodo(whattodo,s);
253 #endif
254     ierr = TSSetTimeStep(ts,(*t)-e->timeprev);CHKERRQ(ierr);
255     /* reset ts context */
256     PetscInt steps = ts->steps;
257     ts->steps      = e->stepnum;
258     ts->ptime      = e->time;
259     ts->ptime_prev = e->timeprev;
260     for (i=e->stepnum;i<stepnum;i++) { /* assume fixed step size */
261       ierr = TSTrajectorySet(ts->trajectory,ts,ts->steps,ts->ptime,ts->vec_sol);CHKERRQ(ierr);
262       ierr = TSMonitor(ts,ts->steps,ts->ptime,ts->vec_sol);CHKERRQ(ierr);
263       ierr = TSStep(ts);CHKERRQ(ierr);
264       if (ts->event) {
265         ierr = TSEventMonitor(ts);CHKERRQ(ierr);
266       }
267       if (!ts->steprollback) {
268         ierr = TSPostStep(ts);CHKERRQ(ierr);
269       }
270     }
271     /* reverseonestep must be true after the for loop */
272     ts->steps = steps;
273     ts->total_steps = stepnum;
274     ierr = TSGetTimeStep(ts,&stepsize);CHKERRQ(ierr);
275     ierr = TSSetTimeStep(ts,-stepsize);CHKERRQ(ierr); /* go backward */
276     if (stepnum-e->stepnum==1) {
277       ierr = StackPop(s,&e);CHKERRQ(ierr);
278       ierr = VecDestroy(&e->X);CHKERRQ(ierr);
279       ierr = VecDestroyVecs(s->numY,&e->Y);CHKERRQ(ierr);
280       ierr = PetscFree(e);CHKERRQ(ierr);
281     }
282     s->reverseonestep = PETSC_FALSE;
283   } else if (e->stepnum == stepnum) {
284     ierr = TSSetTimeStep(ts,-(*t)+e->timeprev);CHKERRQ(ierr); /* go backward */
285     ierr = StackPop(s,&e);CHKERRQ(ierr);
286     ierr = VecDestroy(&e->X);CHKERRQ(ierr);
287     ierr = VecDestroyVecs(s->numY,&e->Y);CHKERRQ(ierr);
288     ierr = PetscFree(e);CHKERRQ(ierr);
289   } else {
290     SETERRQ2(s->comm,PETSC_ERR_ARG_OUTOFRANGE,"The current step no. is %D, but the step number at top of the stack is %D",stepnum,e->stepnum);
291   }
292 
293   PetscFunctionReturn(0);
294 }
295 
296 #undef __FUNCT__
297 #define __FUNCT__ "TSTrajectoryDestroy_Memory"
298 PETSC_EXTERN PetscErrorCode TSTrajectoryDestroy_Memory(TSTrajectory tj)
299 {
300   Stack          *s = (Stack*)tj->data;
301   PetscErrorCode ierr;
302 
303   PetscFunctionBegin;
304   ierr = StackDestroy(s);CHKERRQ(ierr);
305   PetscFunctionReturn(0);
306 }
307 
308 /*MC
      TSTRAJECTORYMEMORY - Stores the solution of the ODE/DAE in memory, using revolve checkpointing when not every time step can be kept
310 
311   Level: intermediate
312 
313 .seealso:  TSTrajectoryCreate(), TS, TSTrajectorySetType()
314 
315 M*/
316 #undef __FUNCT__
317 #define __FUNCT__ "TSTrajectoryCreate_Memory"
318 PETSC_EXTERN PetscErrorCode TSTrajectoryCreate_Memory(TSTrajectory tj,TS ts)
319 {
320   PetscInt       nr,maxsteps;
321   Stack          *s;
322   PetscErrorCode ierr;
323 
324   PetscFunctionBegin;
325   tj->ops->set     = TSTrajectorySet_Memory;
326   tj->ops->get     = TSTrajectoryGet_Memory;
327   tj->ops->destroy = TSTrajectoryDestroy_Memory;
328 
329   ierr = PetscCalloc1(1,&s);
330   s->maxelements = 3; /* will be provided by users */
331   ierr = TSGetStages(ts,&nr,PETSC_IGNORE);CHKERRQ(ierr);
332 
333   maxsteps = PetscMin(ts->max_steps,(PetscInt)(ceil(ts->max_time/ts->time_step)));
334   if (s->maxelements-1<maxsteps) { /* Need to use a controller */
335     s->usecontroller  = PETSC_TRUE;
336     s->reverseonestep = PETSC_FALSE;
337     s->check          = -1;
338     s->oldcapo        = 0;
339     s->capo           = 0;
340     s->fine           = maxsteps;
341     s->info           = 2;
342     ierr = StackCreate(PetscObjectComm((PetscObject)ts),s,s->maxelements,nr);CHKERRQ(ierr);
343   } else { /* Enough space for checkpointing all time steps */
344     s->usecontroller = PETSC_FALSE;
345     ierr = StackCreate(PetscObjectComm((PetscObject)ts),s,ts->max_steps+1,nr);CHKERRQ(ierr);
346   }
347   tj->data = s;
348   PetscFunctionReturn(0);
349 }
350