xref: /petsc/src/ts/impls/implicit/irk/irk.c (revision 792fecdfe9134cce4d631112660ddd34f063bc17)
1 /*
2   Code for timestepping with implicit Runge-Kutta method
3 
4   Notes:
5   The general system is written as
6 
7   F(t,U,Udot) = 0
8 
9 */
10 #include <petsc/private/tsimpl.h>                /*I   "petscts.h"   I*/
11 #include <petscdm.h>
12 #include <petscdt.h>
13 
14 static TSIRKType         TSIRKDefault = TSIRKGAUSS;
15 static PetscBool         TSIRKRegisterAllCalled;
16 static PetscBool         TSIRKPackageInitialized;
17 static PetscFunctionList TSIRKList;
18 
19 struct _IRKTableau{
20   PetscReal   *A,*b,*c;
21   PetscScalar *A_inv,*A_inv_rowsum,*I_s;
22   PetscReal   *binterp;   /* Dense output formula */
23 };
24 
25 typedef struct _IRKTableau *IRKTableau;
26 
27 typedef struct {
28   char         *method_name;
29   PetscInt     order;            /* Classical approximation order of the method */
30   PetscInt     nstages;          /* Number of stages */
31   PetscBool    stiffly_accurate;
32   PetscInt     pinterp;          /* Interpolation order */
33   IRKTableau   tableau;
34   Vec          U0;               /* Backup vector */
35   Vec          Z;                /* Combined stage vector */
36   Vec          *Y;               /* States computed during the step */
37   Vec          Ydot;             /* Work vector holding time derivatives during residual evaluation */
38   Vec          U;                /* U is used to compute Ydot = shift(Y-U) */
39   Vec          *YdotI;           /* Work vectors to hold the residual evaluation */
40   Mat          TJ;               /* KAIJ matrix for the Jacobian of the combined system */
41   PetscScalar  *work;            /* Scalar work */
42   TSStepStatus status;
43   PetscBool    rebuild_completion;
44   PetscReal    ccfl;
45 } TS_IRK;
46 
47 /*@C
48    TSIRKTableauCreate - create the tableau for TSIRK and provide the entries
49 
50    Not Collective
51 
52    Input Parameters:
53 +  ts - timestepping context
54 .  nstages - number of stages, this is the dimension of the matrices below
55 .  A - stage coefficients (dimension nstages*nstages, row-major)
56 .  b - step completion table (dimension nstages)
57 .  c - abscissa (dimension nstages)
58 .  binterp - coefficients of the interpolation formula (dimension nstages)
59 .  A_inv - inverse of A (dimension nstages*nstages, row-major)
60 .  A_inv_rowsum - row sum of the inverse of A (dimension nstages)
61 -  I_s - identity matrix (dimension nstages*nstages)
62 
63    Level: advanced
64 
65 .seealso: `TSIRK`, `TSIRKRegister()`
66 @*/
67 PetscErrorCode TSIRKTableauCreate(TS ts,PetscInt nstages,const PetscReal *A,const PetscReal *b,const PetscReal *c,const PetscReal *binterp,const PetscScalar *A_inv,const PetscScalar *A_inv_rowsum,const PetscScalar *I_s)
68 {
69   TS_IRK         *irk = (TS_IRK*)ts->data;
70   IRKTableau     tab = irk->tableau;
71 
72   PetscFunctionBegin;
73   irk->order = nstages;
74   PetscCall(PetscMalloc3(PetscSqr(nstages),&tab->A,PetscSqr(nstages),&tab->A_inv,PetscSqr(nstages),&tab->I_s));
75   PetscCall(PetscMalloc4(nstages,&tab->b,nstages,&tab->c,nstages,&tab->binterp,nstages,&tab->A_inv_rowsum));
76   PetscCall(PetscArraycpy(tab->A,A,PetscSqr(nstages)));
77   PetscCall(PetscArraycpy(tab->b,b,nstages));
78   PetscCall(PetscArraycpy(tab->c,c,nstages));
79   /* optional coefficient arrays */
80   if (binterp) PetscCall(PetscArraycpy(tab->binterp,binterp,nstages));
81   if (A_inv) PetscCall(PetscArraycpy(tab->A_inv,A_inv,PetscSqr(nstages)));
82   if (A_inv_rowsum) PetscCall(PetscArraycpy(tab->A_inv_rowsum,A_inv_rowsum,nstages));
83   if (I_s) PetscCall(PetscArraycpy(tab->I_s,I_s,PetscSqr(nstages)));
84   PetscFunctionReturn(0);
85 }
86 
87 /* Arrays should be freed with PetscFree3(A,b,c) */
88 static PetscErrorCode TSIRKCreate_Gauss(TS ts)
89 {
90   PetscInt       nstages;
91   PetscReal      *gauss_A_real,*gauss_b,*b,*gauss_c;
92   PetscScalar    *gauss_A,*gauss_A_inv,*gauss_A_inv_rowsum,*I_s;
93   PetscScalar    *G0,*G1;
94   PetscInt       i,j;
95   Mat            G0mat,G1mat,Amat;
96 
97   PetscFunctionBegin;
98   PetscCall(TSIRKGetNumStages(ts,&nstages));
99   PetscCall(PetscMalloc3(PetscSqr(nstages),&gauss_A_real,nstages,&gauss_b,nstages,&gauss_c));
100   PetscCall(PetscMalloc4(PetscSqr(nstages),&gauss_A,PetscSqr(nstages),&gauss_A_inv,nstages,&gauss_A_inv_rowsum,PetscSqr(nstages),&I_s));
101   PetscCall(PetscMalloc3(nstages,&b,PetscSqr(nstages),&G0,PetscSqr(nstages),&G1));
102   PetscCall(PetscDTGaussQuadrature(nstages,0.,1.,gauss_c,b));
103   for (i=0; i<nstages; i++) gauss_b[i] = b[i]; /* copy to possibly-complex array */
104 
105   /* A^T = G0^{-1} G1 */
106   for (i=0; i<nstages; i++) {
107     for (j=0; j<nstages; j++) {
108       G0[i*nstages+j] = PetscPowRealInt(gauss_c[i],j);
109       G1[i*nstages+j] = PetscPowRealInt(gauss_c[i],j+1)/(j+1);
110     }
111   }
112   /* The arrays above are row-aligned, but we create dense matrices as the transpose */
113   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF,nstages,nstages,G0,&G0mat));
114   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF,nstages,nstages,G1,&G1mat));
115   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF,nstages,nstages,gauss_A,&Amat));
116   PetscCall(MatLUFactor(G0mat,NULL,NULL,NULL));
117   PetscCall(MatMatSolve(G0mat,G1mat,Amat));
118   PetscCall(MatTranspose(Amat,MAT_INPLACE_MATRIX,&Amat));
119   for (i=0; i<nstages; i++)
120     for (j=0; j<nstages; j++)
121       gauss_A_real[i*nstages+j] = PetscRealPart(gauss_A[i*nstages+j]);
122 
123   PetscCall(MatDestroy(&G0mat));
124   PetscCall(MatDestroy(&G1mat));
125   PetscCall(MatDestroy(&Amat));
126   PetscCall(PetscFree3(b,G0,G1));
127 
128   {/* Invert A */
129     /* PETSc does not provide a routine to calculate the inverse of a general matrix.
130      * To get the inverse of A, we form a sequential BAIJ matrix from it, consisting of a single block with block size
131      * equal to the dimension of A, and then use MatInvertBlockDiagonal(). */
132     Mat               A_baij;
133     PetscInt          idxm[1]={0},idxn[1]={0};
134     const PetscScalar *A_inv;
135 
136     PetscCall(MatCreateSeqBAIJ(PETSC_COMM_SELF,nstages,nstages,nstages,1,NULL,&A_baij));
137     PetscCall(MatSetOption(A_baij,MAT_ROW_ORIENTED,PETSC_FALSE));
138     PetscCall(MatSetValuesBlocked(A_baij,1,idxm,1,idxn,gauss_A,INSERT_VALUES));
139     PetscCall(MatAssemblyBegin(A_baij,MAT_FINAL_ASSEMBLY));
140     PetscCall(MatAssemblyEnd(A_baij,MAT_FINAL_ASSEMBLY));
141     PetscCall(MatInvertBlockDiagonal(A_baij,&A_inv));
142     PetscCall(PetscMemcpy(gauss_A_inv,A_inv,nstages*nstages*sizeof(PetscScalar)));
143     PetscCall(MatDestroy(&A_baij));
144   }
145 
146   /* Compute row sums A_inv_rowsum and identity I_s */
147   for (i=0; i<nstages; i++) {
148     gauss_A_inv_rowsum[i] = 0;
149     for (j=0; j<nstages; j++) {
150       gauss_A_inv_rowsum[i] += gauss_A_inv[i+nstages*j];
151       I_s[i+nstages*j] = 1.*(i == j);
152     }
153   }
154   PetscCall(TSIRKTableauCreate(ts,nstages,gauss_A_real,gauss_b,gauss_c,NULL,gauss_A_inv,gauss_A_inv_rowsum,I_s));
155   PetscCall(PetscFree3(gauss_A_real,gauss_b,gauss_c));
156   PetscCall(PetscFree4(gauss_A,gauss_A_inv,gauss_A_inv_rowsum,I_s));
157   PetscFunctionReturn(0);
158 }
159 
160 /*@C
161    TSIRKRegister -  adds a TSIRK implementation
162 
163    Not Collective
164 
165    Input Parameters:
166 +  sname - name of user-defined IRK scheme
167 -  function - function to create method context
168 
169    Notes:
170    TSIRKRegister() may be called multiple times to add several user-defined families.
171 
172    Sample usage:
173 .vb
174    TSIRKRegister("my_scheme",MySchemeCreate);
175 .ve
176 
177    Then, your scheme can be chosen with the procedural interface via
178 $     TSIRKSetType(ts,"my_scheme")
179    or at runtime via the option
180 $     -ts_irk_type my_scheme
181 
182    Level: advanced
183 
184 .seealso: `TSIRKRegisterAll()`
185 @*/
186 PetscErrorCode TSIRKRegister(const char sname[],PetscErrorCode (*function)(TS))
187 {
188   PetscFunctionBegin;
189   PetscCall(TSIRKInitializePackage());
190   PetscCall(PetscFunctionListAdd(&TSIRKList,sname,function));
191   PetscFunctionReturn(0);
192 }
193 
194 /*@C
195   TSIRKRegisterAll - Registers all of the implicit Runge-Kutta methods in TSIRK
196 
197   Not Collective, but should be called by all processes which will need the schemes to be registered
198 
199   Level: advanced
200 
201 .seealso: `TSIRKRegisterDestroy()`
202 @*/
203 PetscErrorCode TSIRKRegisterAll(void)
204 {
205   PetscFunctionBegin;
206   if (TSIRKRegisterAllCalled) PetscFunctionReturn(0);
207   TSIRKRegisterAllCalled = PETSC_TRUE;
208 
209   PetscCall(TSIRKRegister(TSIRKGAUSS,TSIRKCreate_Gauss));
210   PetscFunctionReturn(0);
211 }
212 
213 /*@C
214    TSIRKRegisterDestroy - Frees the list of schemes that were registered by TSIRKRegister().
215 
216    Not Collective
217 
218    Level: advanced
219 
220 .seealso: `TSIRKRegister()`, `TSIRKRegisterAll()`
221 @*/
222 PetscErrorCode TSIRKRegisterDestroy(void)
223 {
224   PetscFunctionBegin;
225   TSIRKRegisterAllCalled = PETSC_FALSE;
226   PetscFunctionReturn(0);
227 }
228 
229 /*@C
230   TSIRKInitializePackage - This function initializes everything in the TSIRK package. It is called
231   from TSInitializePackage().
232 
233   Level: developer
234 
235 .seealso: `PetscInitialize()`
236 @*/
237 PetscErrorCode TSIRKInitializePackage(void)
238 {
239   PetscFunctionBegin;
240   if (TSIRKPackageInitialized) PetscFunctionReturn(0);
241   TSIRKPackageInitialized = PETSC_TRUE;
242   PetscCall(TSIRKRegisterAll());
243   PetscCall(PetscRegisterFinalize(TSIRKFinalizePackage));
244   PetscFunctionReturn(0);
245 }
246 
247 /*@C
248   TSIRKFinalizePackage - This function destroys everything in the TSIRK package. It is
249   called from PetscFinalize().
250 
251   Level: developer
252 
253 .seealso: `PetscFinalize()`
254 @*/
255 PetscErrorCode TSIRKFinalizePackage(void)
256 {
257   PetscFunctionBegin;
258   PetscCall(PetscFunctionListDestroy(&TSIRKList));
259   TSIRKPackageInitialized = PETSC_FALSE;
260   PetscFunctionReturn(0);
261 }
262 
263 /*
264  This function can be called before or after ts->vec_sol has been updated.
265 */
266 static PetscErrorCode TSEvaluateStep_IRK(TS ts,PetscInt order,Vec U,PetscBool *done)
267 {
268   TS_IRK         *irk = (TS_IRK*)ts->data;
269   IRKTableau     tab = irk->tableau;
270   Vec            *YdotI = irk->YdotI;
271   PetscScalar    *w = irk->work;
272   PetscReal      h;
273   PetscInt       j;
274 
275   PetscFunctionBegin;
276   switch (irk->status) {
277   case TS_STEP_INCOMPLETE:
278   case TS_STEP_PENDING:
279     h = ts->time_step; break;
280   case TS_STEP_COMPLETE:
281     h = ts->ptime - ts->ptime_prev; break;
282   default: SETERRQ(PetscObjectComm((PetscObject)ts),PETSC_ERR_PLIB,"Invalid TSStepStatus");
283   }
284 
285   PetscCall(VecCopy(ts->vec_sol,U));
286   for (j=0; j<irk->nstages; j++) w[j] = h*tab->b[j];
287   PetscCall(VecMAXPY(U,irk->nstages,w,YdotI));
288   PetscFunctionReturn(0);
289 }
290 
291 static PetscErrorCode TSRollBack_IRK(TS ts)
292 {
293   TS_IRK         *irk = (TS_IRK*)ts->data;
294 
295   PetscFunctionBegin;
296   PetscCall(VecCopy(irk->U0,ts->vec_sol));
297   PetscFunctionReturn(0);
298 }
299 
300 static PetscErrorCode TSStep_IRK(TS ts)
301 {
302   TS_IRK          *irk = (TS_IRK*)ts->data;
303   IRKTableau      tab = irk->tableau;
304   PetscScalar     *A_inv = tab->A_inv,*A_inv_rowsum = tab->A_inv_rowsum;
305   const PetscInt  nstages = irk->nstages;
306   SNES            snes;
307   PetscInt        i,j,its,lits,bs;
308   TSAdapt         adapt;
309   PetscInt        rejections = 0;
310   PetscBool       accept = PETSC_TRUE;
311   PetscReal       next_time_step = ts->time_step;
312 
313   PetscFunctionBegin;
314   if (!ts->steprollback) {
315     PetscCall(VecCopy(ts->vec_sol,irk->U0));
316   }
317   PetscCall(VecGetBlockSize(ts->vec_sol,&bs));
318   for (i=0; i<nstages; i++) {
319     PetscCall(VecStrideScatter(ts->vec_sol,i*bs,irk->Z,INSERT_VALUES));
320   }
321 
322   irk->status = TS_STEP_INCOMPLETE;
323   while (!ts->reason && irk->status != TS_STEP_COMPLETE) {
324     PetscCall(VecCopy(ts->vec_sol,irk->U));
325     PetscCall(TSGetSNES(ts,&snes));
326     PetscCall(SNESSolve(snes,NULL,irk->Z));
327     PetscCall(SNESGetIterationNumber(snes,&its));
328     PetscCall(SNESGetLinearSolveIterations(snes,&lits));
329     ts->snes_its += its; ts->ksp_its += lits;
330     PetscCall(VecStrideGatherAll(irk->Z,irk->Y,INSERT_VALUES));
331     for (i=0; i<nstages; i++) {
332       PetscCall(VecZeroEntries(irk->YdotI[i]));
333       for (j=0; j<nstages; j++) {
334         PetscCall(VecAXPY(irk->YdotI[i],A_inv[i+j*nstages]/ts->time_step,irk->Y[j]));
335       }
336       PetscCall(VecAXPY(irk->YdotI[i],-A_inv_rowsum[i]/ts->time_step,irk->U));
337     }
338     irk->status = TS_STEP_INCOMPLETE;
339     PetscCall(TSEvaluateStep_IRK(ts,irk->order,ts->vec_sol,NULL));
340     irk->status = TS_STEP_PENDING;
341     PetscCall(TSGetAdapt(ts,&adapt));
342     PetscCall(TSAdaptChoose(adapt,ts,ts->time_step,NULL,&next_time_step,&accept));
343     irk->status = accept ? TS_STEP_COMPLETE : TS_STEP_INCOMPLETE;
344     if (!accept) {
345       PetscCall(TSRollBack_IRK(ts));
346       ts->time_step = next_time_step;
347       goto reject_step;
348     }
349 
350     ts->ptime += ts->time_step;
351     ts->time_step = next_time_step;
352     break;
353   reject_step:
354     ts->reject++; accept = PETSC_FALSE;
355     if (!ts->reason && ++rejections > ts->max_reject && ts->max_reject >= 0) {
356       ts->reason = TS_DIVERGED_STEP_REJECTED;
357       PetscCall(PetscInfo(ts,"Step=%" PetscInt_FMT ", step rejections %" PetscInt_FMT " greater than current TS allowed, stopping solve\n",ts->steps,rejections));
358     }
359   }
360   PetscFunctionReturn(0);
361 }
362 
363 static PetscErrorCode TSInterpolate_IRK(TS ts,PetscReal itime,Vec U)
364 {
365   TS_IRK          *irk = (TS_IRK*)ts->data;
366   PetscInt        nstages = irk->nstages,pinterp = irk->pinterp,i,j;
367   PetscReal       h;
368   PetscReal       tt,t;
369   PetscScalar     *bt;
370   const PetscReal *B = irk->tableau->binterp;
371 
372   PetscFunctionBegin;
373   PetscCheck(B,PetscObjectComm((PetscObject)ts),PETSC_ERR_SUP,"TSIRK %s does not have an interpolation formula",irk->method_name);
374   switch (irk->status) {
375   case TS_STEP_INCOMPLETE:
376   case TS_STEP_PENDING:
377     h = ts->time_step;
378     t = (itime - ts->ptime)/h;
379     break;
380   case TS_STEP_COMPLETE:
381     h = ts->ptime - ts->ptime_prev;
382     t = (itime - ts->ptime)/h + 1; /* In the interval [0,1] */
383     break;
384   default: SETERRQ(PetscObjectComm((PetscObject)ts),PETSC_ERR_PLIB,"Invalid TSStepStatus");
385   }
386   PetscCall(PetscMalloc1(nstages,&bt));
387   for (i=0; i<nstages; i++) bt[i] = 0;
388   for (j=0,tt=t; j<pinterp; j++,tt*=t) {
389     for (i=0; i<nstages; i++) {
390       bt[i] += h * B[i*pinterp+j] * tt;
391     }
392   }
393   PetscCall(VecMAXPY(U,nstages,bt,irk->YdotI));
394   PetscFunctionReturn(0);
395 }
396 
397 static PetscErrorCode TSIRKTableauReset(TS ts)
398 {
399   TS_IRK         *irk = (TS_IRK*)ts->data;
400   IRKTableau     tab = irk->tableau;
401 
402   PetscFunctionBegin;
403   if (!tab) PetscFunctionReturn(0);
404   PetscCall(PetscFree3(tab->A,tab->A_inv,tab->I_s));
405   PetscCall(PetscFree4(tab->b,tab->c,tab->binterp,tab->A_inv_rowsum));
406   PetscFunctionReturn(0);
407 }
408 
409 static PetscErrorCode TSReset_IRK(TS ts)
410 {
411   TS_IRK         *irk = (TS_IRK*)ts->data;
412 
413   PetscFunctionBegin;
414   PetscCall(TSIRKTableauReset(ts));
415   if (irk->tableau) PetscCall(PetscFree(irk->tableau));
416   if (irk->method_name) PetscCall(PetscFree(irk->method_name));
417   if (irk->work) PetscCall(PetscFree(irk->work));
418   PetscCall(VecDestroyVecs(irk->nstages,&irk->Y));
419   PetscCall(VecDestroyVecs(irk->nstages,&irk->YdotI));
420   PetscCall(VecDestroy(&irk->Ydot));
421   PetscCall(VecDestroy(&irk->Z));
422   PetscCall(VecDestroy(&irk->U));
423   PetscCall(VecDestroy(&irk->U0));
424   PetscCall(MatDestroy(&irk->TJ));
425   PetscFunctionReturn(0);
426 }
427 
428 static PetscErrorCode TSIRKGetVecs(TS ts,DM dm,Vec *U)
429 {
430   TS_IRK         *irk = (TS_IRK*)ts->data;
431 
432   PetscFunctionBegin;
433   if (U) {
434     if (dm && dm != ts->dm) {
435       PetscCall(DMGetNamedGlobalVector(dm,"TSIRK_U",U));
436     } else *U = irk->U;
437   }
438   PetscFunctionReturn(0);
439 }
440 
441 static PetscErrorCode TSIRKRestoreVecs(TS ts,DM dm,Vec *U)
442 {
443   PetscFunctionBegin;
444   if (U) {
445     if (dm && dm != ts->dm) {
446       PetscCall(DMRestoreNamedGlobalVector(dm,"TSIRK_U",U));
447     }
448   }
449   PetscFunctionReturn(0);
450 }
451 
452 /*
453   This defines the nonlinear equations that is to be solved with SNES
454     G[e\otimes t + C*dt, Z, Zdot] = 0
455     Zdot = (In \otimes S)*Z - (In \otimes Se) U
456   where S = 1/(dt*A)
457 */
458 static PetscErrorCode SNESTSFormFunction_IRK(SNES snes,Vec ZC,Vec FC,TS ts)
459 {
460   TS_IRK            *irk = (TS_IRK*)ts->data;
461   IRKTableau        tab  = irk->tableau;
462   const PetscInt    nstages = irk->nstages;
463   const PetscReal   *c = tab->c;
464   const PetscScalar *A_inv = tab->A_inv,*A_inv_rowsum = tab->A_inv_rowsum;
465   DM                dm,dmsave;
466   Vec               U,*YdotI = irk->YdotI,Ydot = irk->Ydot,*Y = irk->Y;
467   PetscReal         h = ts->time_step;
468   PetscInt          i,j;
469 
470   PetscFunctionBegin;
471   PetscCall(SNESGetDM(snes,&dm));
472   PetscCall(TSIRKGetVecs(ts,dm,&U));
473   PetscCall(VecStrideGatherAll(ZC,Y,INSERT_VALUES));
474   dmsave = ts->dm;
475   ts->dm = dm;
476   for (i=0; i<nstages; i++) {
477     PetscCall(VecZeroEntries(Ydot));
478     for (j=0; j<nstages; j++) {
479       PetscCall(VecAXPY(Ydot,A_inv[j*nstages+i]/h,Y[j]));
480     }
481     PetscCall(VecAXPY(Ydot,-A_inv_rowsum[i]/h,U)); /* Ydot = (S \otimes In)*Z - (Se \otimes In) U */
482     PetscCall(TSComputeIFunction(ts,ts->ptime+ts->time_step*c[i],Y[i],Ydot,YdotI[i],PETSC_FALSE));
483   }
484   PetscCall(VecStrideScatterAll(YdotI,FC,INSERT_VALUES));
485   ts->dm = dmsave;
486   PetscCall(TSIRKRestoreVecs(ts,dm,&U));
487   PetscFunctionReturn(0);
488 }
489 
490 /*
491    For explicit ODE, the Jacobian is
492      JC = I_n \otimes S - J \otimes I_s
493    For DAE, the Jacobian is
494      JC = M_n \otimes S - J \otimes I_s
495 */
496 static PetscErrorCode SNESTSFormJacobian_IRK(SNES snes,Vec ZC,Mat JC,Mat JCpre,TS ts)
497 {
498   TS_IRK          *irk = (TS_IRK*)ts->data;
499   IRKTableau      tab  = irk->tableau;
500   const PetscInt  nstages = irk->nstages;
501   const PetscReal *c = tab->c;
502   DM              dm,dmsave;
503   Vec             *Y = irk->Y,Ydot = irk->Ydot;
504   Mat             J;
505   PetscScalar     *S;
506   PetscInt        i,j,bs;
507 
508   PetscFunctionBegin;
509   PetscCall(SNESGetDM(snes,&dm));
510   /* irk->Ydot has already been computed in SNESTSFormFunction_IRK (SNES guarantees this) */
511   dmsave = ts->dm;
512   ts->dm = dm;
513   PetscCall(VecGetBlockSize(Y[nstages-1],&bs));
514   if (ts->equation_type <= TS_EQ_ODE_EXPLICIT) { /* Support explicit formulas only */
515     PetscCall(VecStrideGather(ZC,(nstages-1)*bs,Y[nstages-1],INSERT_VALUES));
516     PetscCall(MatKAIJGetAIJ(JC,&J));
517     PetscCall(TSComputeIJacobian(ts,ts->ptime+ts->time_step*c[nstages-1],Y[nstages-1],Ydot,0,J,J,PETSC_FALSE));
518     PetscCall(MatKAIJGetS(JC,NULL,NULL,&S));
519     for (i=0; i<nstages; i++)
520       for (j=0; j<nstages; j++)
521         S[i+nstages*j] = tab->A_inv[i+nstages*j]/ts->time_step;
522     PetscCall(MatKAIJRestoreS(JC,&S));
523   } else SETERRQ(PetscObjectComm((PetscObject)ts),PETSC_ERR_SUP,"TSIRK %s does not support implicit formula",irk->method_name); /* TODO: need the mass matrix for DAE  */
524   ts->dm = dmsave;
525   PetscFunctionReturn(0);
526 }
527 
528 static PetscErrorCode DMCoarsenHook_TSIRK(DM fine,DM coarse,void *ctx)
529 {
530   PetscFunctionBegin;
531   PetscFunctionReturn(0);
532 }
533 
534 static PetscErrorCode DMRestrictHook_TSIRK(DM fine,Mat restrct,Vec rscale,Mat inject,DM coarse,void *ctx)
535 {
536   TS             ts = (TS)ctx;
537   Vec            U,U_c;
538 
539   PetscFunctionBegin;
540   PetscCall(TSIRKGetVecs(ts,fine,&U));
541   PetscCall(TSIRKGetVecs(ts,coarse,&U_c));
542   PetscCall(MatRestrict(restrct,U,U_c));
543   PetscCall(VecPointwiseMult(U_c,rscale,U_c));
544   PetscCall(TSIRKRestoreVecs(ts,fine,&U));
545   PetscCall(TSIRKRestoreVecs(ts,coarse,&U_c));
546   PetscFunctionReturn(0);
547 }
548 
549 static PetscErrorCode DMSubDomainHook_TSIRK(DM dm,DM subdm,void *ctx)
550 {
551   PetscFunctionBegin;
552   PetscFunctionReturn(0);
553 }
554 
555 static PetscErrorCode DMSubDomainRestrictHook_TSIRK(DM dm,VecScatter gscat,VecScatter lscat,DM subdm,void *ctx)
556 {
557   TS             ts = (TS)ctx;
558   Vec            U,U_c;
559 
560   PetscFunctionBegin;
561   PetscCall(TSIRKGetVecs(ts,dm,&U));
562   PetscCall(TSIRKGetVecs(ts,subdm,&U_c));
563 
564   PetscCall(VecScatterBegin(gscat,U,U_c,INSERT_VALUES,SCATTER_FORWARD));
565   PetscCall(VecScatterEnd(gscat,U,U_c,INSERT_VALUES,SCATTER_FORWARD));
566 
567   PetscCall(TSIRKRestoreVecs(ts,dm,&U));
568   PetscCall(TSIRKRestoreVecs(ts,subdm,&U_c));
569   PetscFunctionReturn(0);
570 }
571 
572 static PetscErrorCode TSSetUp_IRK(TS ts)
573 {
574   TS_IRK         *irk = (TS_IRK*)ts->data;
575   IRKTableau     tab = irk->tableau;
576   DM             dm;
577   Mat            J;
578   Vec            R;
579   const PetscInt nstages = irk->nstages;
580   PetscInt       vsize,bs;
581 
582   PetscFunctionBegin;
583   if (!irk->work) {
584     PetscCall(PetscMalloc1(irk->nstages,&irk->work));
585   }
586   if (!irk->Y) {
587     PetscCall(VecDuplicateVecs(ts->vec_sol,irk->nstages,&irk->Y));
588   }
589   if (!irk->YdotI) {
590     PetscCall(VecDuplicateVecs(ts->vec_sol,irk->nstages,&irk->YdotI));
591   }
592   if (!irk->Ydot) {
593     PetscCall(VecDuplicate(ts->vec_sol,&irk->Ydot));
594   }
595   if (!irk->U) {
596     PetscCall(VecDuplicate(ts->vec_sol,&irk->U));
597   }
598   if (!irk->U0) {
599     PetscCall(VecDuplicate(ts->vec_sol,&irk->U0));
600   }
601   if (!irk->Z) {
602     PetscCall(VecCreate(PetscObjectComm((PetscObject)ts->vec_sol),&irk->Z));
603     PetscCall(VecGetSize(ts->vec_sol,&vsize));
604     PetscCall(VecSetSizes(irk->Z,PETSC_DECIDE,vsize*irk->nstages));
605     PetscCall(VecGetBlockSize(ts->vec_sol,&bs));
606     PetscCall(VecSetBlockSize(irk->Z,irk->nstages*bs));
607     PetscCall(VecSetFromOptions(irk->Z));
608   }
609   PetscCall(TSGetDM(ts,&dm));
610   PetscCall(DMCoarsenHookAdd(dm,DMCoarsenHook_TSIRK,DMRestrictHook_TSIRK,ts));
611   PetscCall(DMSubDomainHookAdd(dm,DMSubDomainHook_TSIRK,DMSubDomainRestrictHook_TSIRK,ts));
612 
613   PetscCall(TSGetSNES(ts,&ts->snes));
614   PetscCall(VecDuplicate(irk->Z,&R));
615   PetscCall(SNESSetFunction(ts->snes,R,SNESTSFormFunction,ts));
616   PetscCall(TSGetIJacobian(ts,&J,NULL,NULL,NULL));
617   if (!irk->TJ) {
618     /* Create the KAIJ matrix for solving the stages */
619     PetscCall(MatCreateKAIJ(J,nstages,nstages,tab->A_inv,tab->I_s,&irk->TJ));
620   }
621   PetscCall(SNESSetJacobian(ts->snes,irk->TJ,irk->TJ,SNESTSFormJacobian,ts));
622   PetscCall(VecDestroy(&R));
623   PetscFunctionReturn(0);
624 }
625 
626 static PetscErrorCode TSSetFromOptions_IRK(PetscOptionItems *PetscOptionsObject,TS ts)
627 {
628   TS_IRK         *irk = (TS_IRK*)ts->data;
629   char           tname[256] = TSIRKGAUSS;
630 
631   PetscFunctionBegin;
632   PetscOptionsHeadBegin(PetscOptionsObject,"IRK ODE solver options");
633   {
634     PetscBool flg1,flg2;
635     PetscCall(PetscOptionsInt("-ts_irk_nstages","Stages of the IRK method","TSIRKSetNumStages",irk->nstages,&irk->nstages,&flg1));
636     PetscCall(PetscOptionsFList("-ts_irk_type","Type of IRK method","TSIRKSetType",TSIRKList,irk->method_name[0] ? irk->method_name : tname,tname,sizeof(tname),&flg2));
637     if (flg1 ||flg2 || !irk->method_name[0]) { /* Create the method tableau after nstages or method is set */
638       PetscCall(TSIRKSetType(ts,tname));
639     }
640   }
641   PetscOptionsHeadEnd();
642   PetscFunctionReturn(0);
643 }
644 
645 static PetscErrorCode TSView_IRK(TS ts,PetscViewer viewer)
646 {
647   TS_IRK         *irk = (TS_IRK*)ts->data;
648   PetscBool      iascii;
649 
650   PetscFunctionBegin;
651   PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&iascii));
652   if (iascii) {
653     IRKTableau    tab = irk->tableau;
654     TSIRKType irktype;
655     char          buf[512];
656 
657     PetscCall(TSIRKGetType(ts,&irktype));
658     PetscCall(PetscViewerASCIIPrintf(viewer,"  IRK type %s\n",irktype));
659     PetscCall(PetscFormatRealArray(buf,sizeof(buf),"% 8.6f",irk->nstages,tab->c));
660     PetscCall(PetscViewerASCIIPrintf(viewer,"  Abscissa       c = %s\n",buf));
661     PetscCall(PetscViewerASCIIPrintf(viewer,"Stiffly accurate: %s\n",irk->stiffly_accurate ? "yes" : "no"));
662     PetscCall(PetscFormatRealArray(buf,sizeof(buf),"% 8.6f",PetscSqr(irk->nstages),tab->A));
663     PetscCall(PetscViewerASCIIPrintf(viewer,"  A coefficients       A = %s\n",buf));
664   }
665   PetscFunctionReturn(0);
666 }
667 
668 static PetscErrorCode TSLoad_IRK(TS ts,PetscViewer viewer)
669 {
670   SNES           snes;
671   TSAdapt        adapt;
672 
673   PetscFunctionBegin;
674   PetscCall(TSGetAdapt(ts,&adapt));
675   PetscCall(TSAdaptLoad(adapt,viewer));
676   PetscCall(TSGetSNES(ts,&snes));
677   PetscCall(SNESLoad(snes,viewer));
678   /* function and Jacobian context for SNES when used with TS is always ts object */
679   PetscCall(SNESSetFunction(snes,NULL,NULL,ts));
680   PetscCall(SNESSetJacobian(snes,NULL,NULL,NULL,ts));
681   PetscFunctionReturn(0);
682 }
683 
684 /*@C
685   TSIRKSetType - Set the type of IRK scheme
686 
687   Logically collective
688 
689   Input Parameters:
690 +  ts - timestepping context
691 -  irktype - type of IRK scheme
692 
693   Options Database:
694 .  -ts_irk_type <gauss> - set irk type
695 
696   Level: intermediate
697 
698 .seealso: `TSIRKGetType()`, `TSIRK`, `TSIRKType`, `TSIRKGAUSS`
699 @*/
700 PetscErrorCode TSIRKSetType(TS ts,TSIRKType irktype)
701 {
702   PetscFunctionBegin;
703   PetscValidHeaderSpecific(ts,TS_CLASSID,1);
704   PetscValidCharPointer(irktype,2);
705   PetscTryMethod(ts,"TSIRKSetType_C",(TS,TSIRKType),(ts,irktype));
706   PetscFunctionReturn(0);
707 }
708 
709 /*@C
710   TSIRKGetType - Get the type of IRK IMEX scheme
711 
712   Logically collective
713 
714   Input Parameter:
715 .  ts - timestepping context
716 
717   Output Parameter:
718 .  irktype - type of IRK-IMEX scheme
719 
720   Level: intermediate
721 
722 .seealso: `TSIRKGetType()`
723 @*/
724 PetscErrorCode TSIRKGetType(TS ts,TSIRKType *irktype)
725 {
726   PetscFunctionBegin;
727   PetscValidHeaderSpecific(ts,TS_CLASSID,1);
728   PetscUseMethod(ts,"TSIRKGetType_C",(TS,TSIRKType*),(ts,irktype));
729   PetscFunctionReturn(0);
730 }
731 
732 /*@C
733   TSIRKSetNumStages - Set the number of stages of IRK scheme
734 
735   Logically collective
736 
737   Input Parameters:
738 +  ts - timestepping context
739 -  nstages - number of stages of IRK scheme
740 
741   Options Database:
742 .  -ts_irk_nstages <int> - set number of stages
743 
744   Level: intermediate
745 
746 .seealso: `TSIRKGetNumStages()`, `TSIRK`
747 @*/
748 PetscErrorCode TSIRKSetNumStages(TS ts,PetscInt nstages)
749 {
750   PetscFunctionBegin;
751   PetscValidHeaderSpecific(ts,TS_CLASSID,1);
752   PetscTryMethod(ts,"TSIRKSetNumStages_C",(TS,PetscInt),(ts,nstages));
753   PetscFunctionReturn(0);
754 }
755 
756 /*@C
757   TSIRKGetNumStages - Get the number of stages of IRK scheme
758 
759   Logically collective
760 
761   Input Parameters:
762 +  ts - timestepping context
763 -  nstages - number of stages of IRK scheme
764 
765   Level: intermediate
766 
767 .seealso: `TSIRKSetNumStages()`, `TSIRK`
768 @*/
769 PetscErrorCode TSIRKGetNumStages(TS ts,PetscInt *nstages)
770 {
771   PetscFunctionBegin;
772   PetscValidHeaderSpecific(ts,TS_CLASSID,1);
773   PetscValidIntPointer(nstages,2);
774   PetscTryMethod(ts,"TSIRKGetNumStages_C",(TS,PetscInt*),(ts,nstages));
775   PetscFunctionReturn(0);
776 }
777 
778 static PetscErrorCode TSIRKGetType_IRK(TS ts,TSIRKType *irktype)
779 {
780   TS_IRK *irk = (TS_IRK*)ts->data;
781 
782   PetscFunctionBegin;
783   *irktype = irk->method_name;
784   PetscFunctionReturn(0);
785 }
786 
787 static PetscErrorCode TSIRKSetType_IRK(TS ts,TSIRKType irktype)
788 {
789   TS_IRK         *irk = (TS_IRK*)ts->data;
790   PetscErrorCode (*irkcreate)(TS);
791 
792   PetscFunctionBegin;
793   if (irk->method_name) {
794     PetscCall(PetscFree(irk->method_name));
795     PetscCall(TSIRKTableauReset(ts));
796   }
797   PetscCall(PetscFunctionListFind(TSIRKList,irktype,&irkcreate));
798   PetscCheck(irkcreate,PETSC_COMM_SELF,PETSC_ERR_ARG_UNKNOWN_TYPE,"Unknown TSIRK type \"%s\" given",irktype);
799   PetscCall((*irkcreate)(ts));
800   PetscCall(PetscStrallocpy(irktype,&irk->method_name));
801   PetscFunctionReturn(0);
802 }
803 
804 static PetscErrorCode TSIRKSetNumStages_IRK(TS ts,PetscInt nstages)
805 {
806   TS_IRK *irk = (TS_IRK*)ts->data;
807 
808   PetscFunctionBegin;
809   PetscCheck(nstages>0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"input argument, %" PetscInt_FMT ", out of range",nstages);
810   irk->nstages = nstages;
811   PetscFunctionReturn(0);
812 }
813 
814 static PetscErrorCode TSIRKGetNumStages_IRK(TS ts,PetscInt *nstages)
815 {
816   TS_IRK *irk = (TS_IRK*)ts->data;
817 
818   PetscFunctionBegin;
819   PetscValidIntPointer(nstages,2);
820   *nstages = irk->nstages;
821   PetscFunctionReturn(0);
822 }
823 
824 static PetscErrorCode TSDestroy_IRK(TS ts)
825 {
826   PetscFunctionBegin;
827   PetscCall(TSReset_IRK(ts));
828   if (ts->dm) {
829     PetscCall(DMCoarsenHookRemove(ts->dm,DMCoarsenHook_TSIRK,DMRestrictHook_TSIRK,ts));
830     PetscCall(DMSubDomainHookRemove(ts->dm,DMSubDomainHook_TSIRK,DMSubDomainRestrictHook_TSIRK,ts));
831   }
832   PetscCall(PetscFree(ts->data));
833   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKSetType_C",NULL));
834   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKGetType_C",NULL));
835   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKSetNumStages_C",NULL));
836   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKGetNumStages_C",NULL));
837   PetscFunctionReturn(0);
838 }
839 
840 /*MC
841       TSIRK - ODE and DAE solver using Implicit Runge-Kutta schemes
842 
843   Notes:
844 
845   TSIRK uses the sparse Kronecker product matrix implementation of MATKAIJ to achieve good arithmetic intensity.
846 
847   Gauss-Legrendre methods are currently supported. These are A-stable symplectic methods with an arbitrary number of stages. The order of accuracy is 2s when using s stages. The default method uses three stages and thus has an order of six. The number of stages (thus order) can be set with -ts_irk_nstages or TSIRKSetNumStages().
848 
849   Level: beginner
850 
851 .seealso: `TSCreate()`, `TS`, `TSSetType()`, `TSIRKSetType()`, `TSIRKGetType()`, `TSIRKGAUSS`, `TSIRKRegister()`, `TSIRKSetNumStages()`
852 
853 M*/
854 PETSC_EXTERN PetscErrorCode TSCreate_IRK(TS ts)
855 {
856   TS_IRK         *irk;
857 
858   PetscFunctionBegin;
859   PetscCall(TSIRKInitializePackage());
860 
861   ts->ops->reset          = TSReset_IRK;
862   ts->ops->destroy        = TSDestroy_IRK;
863   ts->ops->view           = TSView_IRK;
864   ts->ops->load           = TSLoad_IRK;
865   ts->ops->setup          = TSSetUp_IRK;
866   ts->ops->step           = TSStep_IRK;
867   ts->ops->interpolate    = TSInterpolate_IRK;
868   ts->ops->evaluatestep   = TSEvaluateStep_IRK;
869   ts->ops->rollback       = TSRollBack_IRK;
870   ts->ops->setfromoptions = TSSetFromOptions_IRK;
871   ts->ops->snesfunction   = SNESTSFormFunction_IRK;
872   ts->ops->snesjacobian   = SNESTSFormJacobian_IRK;
873 
874   ts->usessnes = PETSC_TRUE;
875 
876   PetscCall(PetscNewLog(ts,&irk));
877   ts->data = (void*)irk;
878 
879   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKSetType_C",TSIRKSetType_IRK));
880   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKGetType_C",TSIRKGetType_IRK));
881   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKSetNumStages_C",TSIRKSetNumStages_IRK));
882   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKGetNumStages_C",TSIRKGetNumStages_IRK));
883   /* 3-stage IRK_Gauss is the default */
884   PetscCall(PetscNew(&irk->tableau));
885   irk->nstages = 3;
886   PetscCall(TSIRKSetType(ts,TSIRKDefault));
887   PetscFunctionReturn(0);
888 }
889