xref: /petsc/src/ts/impls/implicit/irk/irk.c (revision 6a5217c03994f2d95bb2e6dbd8bed42381aeb015)
/*
  Code for time stepping with implicit Runge-Kutta (IRK) methods

  Notes:
  The general system is written as

  F(t,U,Udot) = 0

*/
10 #include <petsc/private/tsimpl.h>                /*I   "petscts.h"   I*/
11 #include <petscdm.h>
12 #include <petscdt.h>
13 
14 static TSIRKType         TSIRKDefault = TSIRKGAUSS;
15 static PetscBool         TSIRKRegisterAllCalled;
16 static PetscBool         TSIRKPackageInitialized;
17 static PetscFunctionList TSIRKList;
18 
19 struct _IRKTableau{
20   PetscReal   *A,*b,*c;
21   PetscScalar *A_inv,*A_inv_rowsum,*I_s;
22   PetscReal   *binterp;   /* Dense output formula */
23 };
24 
25 typedef struct _IRKTableau *IRKTableau;
26 
27 typedef struct {
28   char         *method_name;
29   PetscInt     order;            /* Classical approximation order of the method */
30   PetscInt     nstages;          /* Number of stages */
31   PetscBool    stiffly_accurate;
32   PetscInt     pinterp;          /* Interpolation order */
33   IRKTableau   tableau;
34   Vec          U0;               /* Backup vector */
35   Vec          Z;                /* Combined stage vector */
36   Vec          *Y;               /* States computed during the step */
37   Vec          Ydot;             /* Work vector holding time derivatives during residual evaluation */
38   Vec          U;                /* U is used to compute Ydot = shift(Y-U) */
39   Vec          *YdotI;           /* Work vectors to hold the residual evaluation */
40   Mat          TJ;               /* KAIJ matrix for the Jacobian of the combined system */
41   PetscScalar  *work;            /* Scalar work */
42   TSStepStatus status;
43   PetscBool    rebuild_completion;
44   PetscReal    ccfl;
45 } TS_IRK;
46 
47 /*@C
48    TSIRKTableauCreate - create the tableau for TSIRK and provide the entries
49 
50    Not Collective
51 
52    Input Parameters:
53 +  ts - timestepping context
54 .  nstages - number of stages, this is the dimension of the matrices below
55 .  A - stage coefficients (dimension nstages*nstages, row-major)
56 .  b - step completion table (dimension nstages)
57 .  c - abscissa (dimension nstages)
58 .  binterp - coefficients of the interpolation formula (dimension nstages)
59 .  A_inv - inverse of A (dimension nstages*nstages, row-major)
60 .  A_inv_rowsum - row sum of the inverse of A (dimension nstages)
61 -  I_s - identity matrix (dimension nstages*nstages)
62 
63    Level: advanced
64 
65 .seealso: TSIRK, TSIRKRegister()
66 @*/
67 PetscErrorCode TSIRKTableauCreate(TS ts,PetscInt nstages,const PetscReal *A,const PetscReal *b,const PetscReal *c,const PetscReal *binterp,const PetscScalar *A_inv,const PetscScalar *A_inv_rowsum,const PetscScalar *I_s)
68 {
69   TS_IRK         *irk = (TS_IRK*)ts->data;
70   IRKTableau     tab = irk->tableau;
71 
72   PetscFunctionBegin;
73   irk->order = nstages;
74   PetscCall(PetscMalloc3(PetscSqr(nstages),&tab->A,PetscSqr(nstages),&tab->A_inv,PetscSqr(nstages),&tab->I_s));
75   PetscCall(PetscMalloc4(nstages,&tab->b,nstages,&tab->c,nstages,&tab->binterp,nstages,&tab->A_inv_rowsum));
76   PetscCall(PetscArraycpy(tab->A,A,PetscSqr(nstages)));
77   PetscCall(PetscArraycpy(tab->b,b,nstages));
78   PetscCall(PetscArraycpy(tab->c,c,nstages));
79   /* optional coefficient arrays */
80   if (binterp) {
81     PetscCall(PetscArraycpy(tab->binterp,binterp,nstages));
82   }
83   if (A_inv) {
84     PetscCall(PetscArraycpy(tab->A_inv,A_inv,PetscSqr(nstages)));
85   }
86   if (A_inv_rowsum) {
87     PetscCall(PetscArraycpy(tab->A_inv_rowsum,A_inv_rowsum,nstages));
88   }
89   if (I_s) {
90     PetscCall(PetscArraycpy(tab->I_s,I_s,PetscSqr(nstages)));
91   }
92   PetscFunctionReturn(0);
93 }
94 
95 /* Arrays should be freed with PetscFree3(A,b,c) */
96 static PetscErrorCode TSIRKCreate_Gauss(TS ts)
97 {
98   PetscInt       nstages;
99   PetscReal      *gauss_A_real,*gauss_b,*b,*gauss_c;
100   PetscScalar    *gauss_A,*gauss_A_inv,*gauss_A_inv_rowsum,*I_s;
101   PetscScalar    *G0,*G1;
102   PetscInt       i,j;
103   Mat            G0mat,G1mat,Amat;
104 
105   PetscFunctionBegin;
106   PetscCall(TSIRKGetNumStages(ts,&nstages));
107   PetscCall(PetscMalloc3(PetscSqr(nstages),&gauss_A_real,nstages,&gauss_b,nstages,&gauss_c));
108   PetscCall(PetscMalloc4(PetscSqr(nstages),&gauss_A,PetscSqr(nstages),&gauss_A_inv,nstages,&gauss_A_inv_rowsum,PetscSqr(nstages),&I_s));
109   PetscCall(PetscMalloc3(nstages,&b,PetscSqr(nstages),&G0,PetscSqr(nstages),&G1));
110   PetscCall(PetscDTGaussQuadrature(nstages,0.,1.,gauss_c,b));
111   for (i=0; i<nstages; i++) gauss_b[i] = b[i]; /* copy to possibly-complex array */
112 
113   /* A^T = G0^{-1} G1 */
114   for (i=0; i<nstages; i++) {
115     for (j=0; j<nstages; j++) {
116       G0[i*nstages+j] = PetscPowRealInt(gauss_c[i],j);
117       G1[i*nstages+j] = PetscPowRealInt(gauss_c[i],j+1)/(j+1);
118     }
119   }
120   /* The arrays above are row-aligned, but we create dense matrices as the transpose */
121   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF,nstages,nstages,G0,&G0mat));
122   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF,nstages,nstages,G1,&G1mat));
123   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF,nstages,nstages,gauss_A,&Amat));
124   PetscCall(MatLUFactor(G0mat,NULL,NULL,NULL));
125   PetscCall(MatMatSolve(G0mat,G1mat,Amat));
126   PetscCall(MatTranspose(Amat,MAT_INPLACE_MATRIX,&Amat));
127   for (i=0; i<nstages; i++)
128     for (j=0; j<nstages; j++)
129       gauss_A_real[i*nstages+j] = PetscRealPart(gauss_A[i*nstages+j]);
130 
131   PetscCall(MatDestroy(&G0mat));
132   PetscCall(MatDestroy(&G1mat));
133   PetscCall(MatDestroy(&Amat));
134   PetscCall(PetscFree3(b,G0,G1));
135 
136   {/* Invert A */
137     /* PETSc does not provide a routine to calculate the inverse of a general matrix.
138      * To get the inverse of A, we form a sequential BAIJ matrix from it, consisting of a single block with block size
139      * equal to the dimension of A, and then use MatInvertBlockDiagonal(). */
140     Mat               A_baij;
141     PetscInt          idxm[1]={0},idxn[1]={0};
142     const PetscScalar *A_inv;
143 
144     PetscCall(MatCreateSeqBAIJ(PETSC_COMM_SELF,nstages,nstages,nstages,1,NULL,&A_baij));
145     PetscCall(MatSetOption(A_baij,MAT_ROW_ORIENTED,PETSC_FALSE));
146     PetscCall(MatSetValuesBlocked(A_baij,1,idxm,1,idxn,gauss_A,INSERT_VALUES));
147     PetscCall(MatAssemblyBegin(A_baij,MAT_FINAL_ASSEMBLY));
148     PetscCall(MatAssemblyEnd(A_baij,MAT_FINAL_ASSEMBLY));
149     PetscCall(MatInvertBlockDiagonal(A_baij,&A_inv));
150     PetscCall(PetscMemcpy(gauss_A_inv,A_inv,nstages*nstages*sizeof(PetscScalar)));
151     PetscCall(MatDestroy(&A_baij));
152   }
153 
154   /* Compute row sums A_inv_rowsum and identity I_s */
155   for (i=0; i<nstages; i++) {
156     gauss_A_inv_rowsum[i] = 0;
157     for (j=0; j<nstages; j++) {
158       gauss_A_inv_rowsum[i] += gauss_A_inv[i+nstages*j];
159       I_s[i+nstages*j] = 1.*(i == j);
160     }
161   }
162   PetscCall(TSIRKTableauCreate(ts,nstages,gauss_A_real,gauss_b,gauss_c,NULL,gauss_A_inv,gauss_A_inv_rowsum,I_s));
163   PetscCall(PetscFree3(gauss_A_real,gauss_b,gauss_c));
164   PetscCall(PetscFree4(gauss_A,gauss_A_inv,gauss_A_inv_rowsum,I_s));
165   PetscFunctionReturn(0);
166 }
167 
168 /*@C
169    TSIRKRegister -  adds a TSIRK implementation
170 
171    Not Collective
172 
173    Input Parameters:
174 +  sname - name of user-defined IRK scheme
175 -  function - function to create method context
176 
177    Notes:
178    TSIRKRegister() may be called multiple times to add several user-defined families.
179 
180    Sample usage:
181 .vb
182    TSIRKRegister("my_scheme",MySchemeCreate);
183 .ve
184 
185    Then, your scheme can be chosen with the procedural interface via
186 $     TSIRKSetType(ts,"my_scheme")
187    or at runtime via the option
188 $     -ts_irk_type my_scheme
189 
190    Level: advanced
191 
192 .seealso: TSIRKRegisterAll()
193 @*/
194 PetscErrorCode TSIRKRegister(const char sname[],PetscErrorCode (*function)(TS))
195 {
196   PetscFunctionBegin;
197   PetscCall(TSIRKInitializePackage());
198   PetscCall(PetscFunctionListAdd(&TSIRKList,sname,function));
199   PetscFunctionReturn(0);
200 }
201 
202 /*@C
203   TSIRKRegisterAll - Registers all of the implicit Runge-Kutta methods in TSIRK
204 
205   Not Collective, but should be called by all processes which will need the schemes to be registered
206 
207   Level: advanced
208 
209 .seealso:  TSIRKRegisterDestroy()
210 @*/
211 PetscErrorCode TSIRKRegisterAll(void)
212 {
213   PetscFunctionBegin;
214   if (TSIRKRegisterAllCalled) PetscFunctionReturn(0);
215   TSIRKRegisterAllCalled = PETSC_TRUE;
216 
217   PetscCall(TSIRKRegister(TSIRKGAUSS,TSIRKCreate_Gauss));
218   PetscFunctionReturn(0);
219 }
220 
221 /*@C
222    TSIRKRegisterDestroy - Frees the list of schemes that were registered by TSIRKRegister().
223 
224    Not Collective
225 
226    Level: advanced
227 
228 .seealso: TSIRKRegister(), TSIRKRegisterAll()
229 @*/
230 PetscErrorCode TSIRKRegisterDestroy(void)
231 {
232   PetscFunctionBegin;
233   TSIRKRegisterAllCalled = PETSC_FALSE;
234   PetscFunctionReturn(0);
235 }
236 
237 /*@C
238   TSIRKInitializePackage - This function initializes everything in the TSIRK package. It is called
239   from TSInitializePackage().
240 
241   Level: developer
242 
243 .seealso: PetscInitialize()
244 @*/
245 PetscErrorCode TSIRKInitializePackage(void)
246 {
247   PetscFunctionBegin;
248   if (TSIRKPackageInitialized) PetscFunctionReturn(0);
249   TSIRKPackageInitialized = PETSC_TRUE;
250   PetscCall(TSIRKRegisterAll());
251   PetscCall(PetscRegisterFinalize(TSIRKFinalizePackage));
252   PetscFunctionReturn(0);
253 }
254 
255 /*@C
256   TSIRKFinalizePackage - This function destroys everything in the TSIRK package. It is
257   called from PetscFinalize().
258 
259   Level: developer
260 
261 .seealso: PetscFinalize()
262 @*/
263 PetscErrorCode TSIRKFinalizePackage(void)
264 {
265   PetscFunctionBegin;
266   PetscCall(PetscFunctionListDestroy(&TSIRKList));
267   TSIRKPackageInitialized = PETSC_FALSE;
268   PetscFunctionReturn(0);
269 }
270 
271 /*
272  This function can be called before or after ts->vec_sol has been updated.
273 */
274 static PetscErrorCode TSEvaluateStep_IRK(TS ts,PetscInt order,Vec U,PetscBool *done)
275 {
276   TS_IRK         *irk = (TS_IRK*)ts->data;
277   IRKTableau     tab = irk->tableau;
278   Vec            *YdotI = irk->YdotI;
279   PetscScalar    *w = irk->work;
280   PetscReal      h;
281   PetscInt       j;
282 
283   PetscFunctionBegin;
284   switch (irk->status) {
285   case TS_STEP_INCOMPLETE:
286   case TS_STEP_PENDING:
287     h = ts->time_step; break;
288   case TS_STEP_COMPLETE:
289     h = ts->ptime - ts->ptime_prev; break;
290   default: SETERRQ(PetscObjectComm((PetscObject)ts),PETSC_ERR_PLIB,"Invalid TSStepStatus");
291   }
292 
293   PetscCall(VecCopy(ts->vec_sol,U));
294   for (j=0; j<irk->nstages; j++) w[j] = h*tab->b[j];
295   PetscCall(VecMAXPY(U,irk->nstages,w,YdotI));
296   PetscFunctionReturn(0);
297 }
298 
299 static PetscErrorCode TSRollBack_IRK(TS ts)
300 {
301   TS_IRK         *irk = (TS_IRK*)ts->data;
302 
303   PetscFunctionBegin;
304   PetscCall(VecCopy(irk->U0,ts->vec_sol));
305   PetscFunctionReturn(0);
306 }
307 
308 static PetscErrorCode TSStep_IRK(TS ts)
309 {
310   TS_IRK          *irk = (TS_IRK*)ts->data;
311   IRKTableau      tab = irk->tableau;
312   PetscScalar     *A_inv = tab->A_inv,*A_inv_rowsum = tab->A_inv_rowsum;
313   const PetscInt  nstages = irk->nstages;
314   SNES            snes;
315   PetscInt        i,j,its,lits,bs;
316   TSAdapt         adapt;
317   PetscInt        rejections = 0;
318   PetscBool       accept = PETSC_TRUE;
319   PetscReal       next_time_step = ts->time_step;
320 
321   PetscFunctionBegin;
322   if (!ts->steprollback) {
323     PetscCall(VecCopy(ts->vec_sol,irk->U0));
324   }
325   PetscCall(VecGetBlockSize(ts->vec_sol,&bs));
326   for (i=0; i<nstages; i++) {
327     PetscCall(VecStrideScatter(ts->vec_sol,i*bs,irk->Z,INSERT_VALUES));
328   }
329 
330   irk->status = TS_STEP_INCOMPLETE;
331   while (!ts->reason && irk->status != TS_STEP_COMPLETE) {
332     PetscCall(VecCopy(ts->vec_sol,irk->U));
333     PetscCall(TSGetSNES(ts,&snes));
334     PetscCall(SNESSolve(snes,NULL,irk->Z));
335     PetscCall(SNESGetIterationNumber(snes,&its));
336     PetscCall(SNESGetLinearSolveIterations(snes,&lits));
337     ts->snes_its += its; ts->ksp_its += lits;
338     PetscCall(VecStrideGatherAll(irk->Z,irk->Y,INSERT_VALUES));
339     for (i=0; i<nstages; i++) {
340       PetscCall(VecZeroEntries(irk->YdotI[i]));
341       for (j=0; j<nstages; j++) {
342         PetscCall(VecAXPY(irk->YdotI[i],A_inv[i+j*nstages]/ts->time_step,irk->Y[j]));
343       }
344       PetscCall(VecAXPY(irk->YdotI[i],-A_inv_rowsum[i]/ts->time_step,irk->U));
345     }
346     irk->status = TS_STEP_INCOMPLETE;
347     PetscCall(TSEvaluateStep_IRK(ts,irk->order,ts->vec_sol,NULL));
348     irk->status = TS_STEP_PENDING;
349     PetscCall(TSGetAdapt(ts,&adapt));
350     PetscCall(TSAdaptChoose(adapt,ts,ts->time_step,NULL,&next_time_step,&accept));
351     irk->status = accept ? TS_STEP_COMPLETE : TS_STEP_INCOMPLETE;
352     if (!accept) {
353       PetscCall(TSRollBack_IRK(ts));
354       ts->time_step = next_time_step;
355       goto reject_step;
356     }
357 
358     ts->ptime += ts->time_step;
359     ts->time_step = next_time_step;
360     break;
361   reject_step:
362     ts->reject++; accept = PETSC_FALSE;
363     if (!ts->reason && ++rejections > ts->max_reject && ts->max_reject >= 0) {
364       ts->reason = TS_DIVERGED_STEP_REJECTED;
365       PetscCall(PetscInfo(ts,"Step=%D, step rejections %D greater than current TS allowed, stopping solve\n",ts->steps,rejections));
366     }
367   }
368   PetscFunctionReturn(0);
369 }
370 
371 static PetscErrorCode TSInterpolate_IRK(TS ts,PetscReal itime,Vec U)
372 {
373   TS_IRK          *irk = (TS_IRK*)ts->data;
374   PetscInt        nstages = irk->nstages,pinterp = irk->pinterp,i,j;
375   PetscReal       h;
376   PetscReal       tt,t;
377   PetscScalar     *bt;
378   const PetscReal *B = irk->tableau->binterp;
379 
380   PetscFunctionBegin;
381   PetscCheck(B,PetscObjectComm((PetscObject)ts),PETSC_ERR_SUP,"TSIRK %s does not have an interpolation formula",irk->method_name);
382   switch (irk->status) {
383   case TS_STEP_INCOMPLETE:
384   case TS_STEP_PENDING:
385     h = ts->time_step;
386     t = (itime - ts->ptime)/h;
387     break;
388   case TS_STEP_COMPLETE:
389     h = ts->ptime - ts->ptime_prev;
390     t = (itime - ts->ptime)/h + 1; /* In the interval [0,1] */
391     break;
392   default: SETERRQ(PetscObjectComm((PetscObject)ts),PETSC_ERR_PLIB,"Invalid TSStepStatus");
393   }
394   PetscCall(PetscMalloc1(nstages,&bt));
395   for (i=0; i<nstages; i++) bt[i] = 0;
396   for (j=0,tt=t; j<pinterp; j++,tt*=t) {
397     for (i=0; i<nstages; i++) {
398       bt[i] += h * B[i*pinterp+j] * tt;
399     }
400   }
401   PetscCall(VecMAXPY(U,nstages,bt,irk->YdotI));
402   PetscFunctionReturn(0);
403 }
404 
405 static PetscErrorCode TSIRKTableauReset(TS ts)
406 {
407   TS_IRK         *irk = (TS_IRK*)ts->data;
408   IRKTableau     tab = irk->tableau;
409 
410   PetscFunctionBegin;
411   if (!tab) PetscFunctionReturn(0);
412   PetscCall(PetscFree3(tab->A,tab->A_inv,tab->I_s));
413   PetscCall(PetscFree4(tab->b,tab->c,tab->binterp,tab->A_inv_rowsum));
414   PetscFunctionReturn(0);
415 }
416 
417 static PetscErrorCode TSReset_IRK(TS ts)
418 {
419   TS_IRK         *irk = (TS_IRK*)ts->data;
420 
421   PetscFunctionBegin;
422   PetscCall(TSIRKTableauReset(ts));
423   if (irk->tableau) {
424     PetscCall(PetscFree(irk->tableau));
425   }
426   if (irk->method_name) {
427     PetscCall(PetscFree(irk->method_name));
428   }
429   if (irk->work) {
430     PetscCall(PetscFree(irk->work));
431   }
432   PetscCall(VecDestroyVecs(irk->nstages,&irk->Y));
433   PetscCall(VecDestroyVecs(irk->nstages,&irk->YdotI));
434   PetscCall(VecDestroy(&irk->Ydot));
435   PetscCall(VecDestroy(&irk->Z));
436   PetscCall(VecDestroy(&irk->U));
437   PetscCall(VecDestroy(&irk->U0));
438   PetscCall(MatDestroy(&irk->TJ));
439   PetscFunctionReturn(0);
440 }
441 
442 static PetscErrorCode TSIRKGetVecs(TS ts,DM dm,Vec *U)
443 {
444   TS_IRK         *irk = (TS_IRK*)ts->data;
445 
446   PetscFunctionBegin;
447   if (U) {
448     if (dm && dm != ts->dm) {
449       PetscCall(DMGetNamedGlobalVector(dm,"TSIRK_U",U));
450     } else *U = irk->U;
451   }
452   PetscFunctionReturn(0);
453 }
454 
455 static PetscErrorCode TSIRKRestoreVecs(TS ts,DM dm,Vec *U)
456 {
457   PetscFunctionBegin;
458   if (U) {
459     if (dm && dm != ts->dm) {
460       PetscCall(DMRestoreNamedGlobalVector(dm,"TSIRK_U",U));
461     }
462   }
463   PetscFunctionReturn(0);
464 }
465 
466 /*
467   This defines the nonlinear equations that is to be solved with SNES
468     G[e\otimes t + C*dt, Z, Zdot] = 0
469     Zdot = (In \otimes S)*Z - (In \otimes Se) U
470   where S = 1/(dt*A)
471 */
472 static PetscErrorCode SNESTSFormFunction_IRK(SNES snes,Vec ZC,Vec FC,TS ts)
473 {
474   TS_IRK            *irk = (TS_IRK*)ts->data;
475   IRKTableau        tab  = irk->tableau;
476   const PetscInt    nstages = irk->nstages;
477   const PetscReal   *c = tab->c;
478   const PetscScalar *A_inv = tab->A_inv,*A_inv_rowsum = tab->A_inv_rowsum;
479   DM                dm,dmsave;
480   Vec               U,*YdotI = irk->YdotI,Ydot = irk->Ydot,*Y = irk->Y;
481   PetscReal         h = ts->time_step;
482   PetscInt          i,j;
483 
484   PetscFunctionBegin;
485   PetscCall(SNESGetDM(snes,&dm));
486   PetscCall(TSIRKGetVecs(ts,dm,&U));
487   PetscCall(VecStrideGatherAll(ZC,Y,INSERT_VALUES));
488   dmsave = ts->dm;
489   ts->dm = dm;
490   for (i=0; i<nstages; i++) {
491     PetscCall(VecZeroEntries(Ydot));
492     for (j=0; j<nstages; j++) {
493       PetscCall(VecAXPY(Ydot,A_inv[j*nstages+i]/h,Y[j]));
494     }
495     PetscCall(VecAXPY(Ydot,-A_inv_rowsum[i]/h,U)); /* Ydot = (S \otimes In)*Z - (Se \otimes In) U */
496     PetscCall(TSComputeIFunction(ts,ts->ptime+ts->time_step*c[i],Y[i],Ydot,YdotI[i],PETSC_FALSE));
497   }
498   PetscCall(VecStrideScatterAll(YdotI,FC,INSERT_VALUES));
499   ts->dm = dmsave;
500   PetscCall(TSIRKRestoreVecs(ts,dm,&U));
501   PetscFunctionReturn(0);
502 }
503 
504 /*
505    For explicit ODE, the Jacobian is
506      JC = I_n \otimes S - J \otimes I_s
507    For DAE, the Jacobian is
508      JC = M_n \otimes S - J \otimes I_s
509 */
510 static PetscErrorCode SNESTSFormJacobian_IRK(SNES snes,Vec ZC,Mat JC,Mat JCpre,TS ts)
511 {
512   TS_IRK          *irk = (TS_IRK*)ts->data;
513   IRKTableau      tab  = irk->tableau;
514   const PetscInt  nstages = irk->nstages;
515   const PetscReal *c = tab->c;
516   DM              dm,dmsave;
517   Vec             *Y = irk->Y,Ydot = irk->Ydot;
518   Mat             J;
519   PetscScalar     *S;
520   PetscInt        i,j,bs;
521 
522   PetscFunctionBegin;
523   PetscCall(SNESGetDM(snes,&dm));
524   /* irk->Ydot has already been computed in SNESTSFormFunction_IRK (SNES guarantees this) */
525   dmsave = ts->dm;
526   ts->dm = dm;
527   PetscCall(VecGetBlockSize(Y[nstages-1],&bs));
528   if (ts->equation_type <= TS_EQ_ODE_EXPLICIT) { /* Support explicit formulas only */
529     PetscCall(VecStrideGather(ZC,(nstages-1)*bs,Y[nstages-1],INSERT_VALUES));
530     PetscCall(MatKAIJGetAIJ(JC,&J));
531     PetscCall(TSComputeIJacobian(ts,ts->ptime+ts->time_step*c[nstages-1],Y[nstages-1],Ydot,0,J,J,PETSC_FALSE));
532     PetscCall(MatKAIJGetS(JC,NULL,NULL,&S));
533     for (i=0; i<nstages; i++)
534       for (j=0; j<nstages; j++)
535         S[i+nstages*j] = tab->A_inv[i+nstages*j]/ts->time_step;
536     PetscCall(MatKAIJRestoreS(JC,&S));
537   } else SETERRQ(PetscObjectComm((PetscObject)ts),PETSC_ERR_SUP,"TSIRK %s does not support implicit formula",irk->method_name); /* TODO: need the mass matrix for DAE  */
538   ts->dm = dmsave;
539   PetscFunctionReturn(0);
540 }
541 
542 static PetscErrorCode DMCoarsenHook_TSIRK(DM fine,DM coarse,void *ctx)
543 {
544   PetscFunctionBegin;
545   PetscFunctionReturn(0);
546 }
547 
548 static PetscErrorCode DMRestrictHook_TSIRK(DM fine,Mat restrct,Vec rscale,Mat inject,DM coarse,void *ctx)
549 {
550   TS             ts = (TS)ctx;
551   Vec            U,U_c;
552 
553   PetscFunctionBegin;
554   PetscCall(TSIRKGetVecs(ts,fine,&U));
555   PetscCall(TSIRKGetVecs(ts,coarse,&U_c));
556   PetscCall(MatRestrict(restrct,U,U_c));
557   PetscCall(VecPointwiseMult(U_c,rscale,U_c));
558   PetscCall(TSIRKRestoreVecs(ts,fine,&U));
559   PetscCall(TSIRKRestoreVecs(ts,coarse,&U_c));
560   PetscFunctionReturn(0);
561 }
562 
563 static PetscErrorCode DMSubDomainHook_TSIRK(DM dm,DM subdm,void *ctx)
564 {
565   PetscFunctionBegin;
566   PetscFunctionReturn(0);
567 }
568 
569 static PetscErrorCode DMSubDomainRestrictHook_TSIRK(DM dm,VecScatter gscat,VecScatter lscat,DM subdm,void *ctx)
570 {
571   TS             ts = (TS)ctx;
572   Vec            U,U_c;
573 
574   PetscFunctionBegin;
575   PetscCall(TSIRKGetVecs(ts,dm,&U));
576   PetscCall(TSIRKGetVecs(ts,subdm,&U_c));
577 
578   PetscCall(VecScatterBegin(gscat,U,U_c,INSERT_VALUES,SCATTER_FORWARD));
579   PetscCall(VecScatterEnd(gscat,U,U_c,INSERT_VALUES,SCATTER_FORWARD));
580 
581   PetscCall(TSIRKRestoreVecs(ts,dm,&U));
582   PetscCall(TSIRKRestoreVecs(ts,subdm,&U_c));
583   PetscFunctionReturn(0);
584 }
585 
586 static PetscErrorCode TSSetUp_IRK(TS ts)
587 {
588   TS_IRK         *irk = (TS_IRK*)ts->data;
589   IRKTableau     tab = irk->tableau;
590   DM             dm;
591   Mat            J;
592   Vec            R;
593   const PetscInt nstages = irk->nstages;
594   PetscInt       vsize,bs;
595 
596   PetscFunctionBegin;
597   if (!irk->work) {
598     PetscCall(PetscMalloc1(irk->nstages,&irk->work));
599   }
600   if (!irk->Y) {
601     PetscCall(VecDuplicateVecs(ts->vec_sol,irk->nstages,&irk->Y));
602   }
603   if (!irk->YdotI) {
604     PetscCall(VecDuplicateVecs(ts->vec_sol,irk->nstages,&irk->YdotI));
605   }
606   if (!irk->Ydot) {
607     PetscCall(VecDuplicate(ts->vec_sol,&irk->Ydot));
608   }
609   if (!irk->U) {
610     PetscCall(VecDuplicate(ts->vec_sol,&irk->U));
611   }
612   if (!irk->U0) {
613     PetscCall(VecDuplicate(ts->vec_sol,&irk->U0));
614   }
615   if (!irk->Z) {
616     PetscCall(VecCreate(PetscObjectComm((PetscObject)ts->vec_sol),&irk->Z));
617     PetscCall(VecGetSize(ts->vec_sol,&vsize));
618     PetscCall(VecSetSizes(irk->Z,PETSC_DECIDE,vsize*irk->nstages));
619     PetscCall(VecGetBlockSize(ts->vec_sol,&bs));
620     PetscCall(VecSetBlockSize(irk->Z,irk->nstages*bs));
621     PetscCall(VecSetFromOptions(irk->Z));
622   }
623   PetscCall(TSGetDM(ts,&dm));
624   PetscCall(DMCoarsenHookAdd(dm,DMCoarsenHook_TSIRK,DMRestrictHook_TSIRK,ts));
625   PetscCall(DMSubDomainHookAdd(dm,DMSubDomainHook_TSIRK,DMSubDomainRestrictHook_TSIRK,ts));
626 
627   PetscCall(TSGetSNES(ts,&ts->snes));
628   PetscCall(VecDuplicate(irk->Z,&R));
629   PetscCall(SNESSetFunction(ts->snes,R,SNESTSFormFunction,ts));
630   PetscCall(TSGetIJacobian(ts,&J,NULL,NULL,NULL));
631   if (!irk->TJ) {
632     /* Create the KAIJ matrix for solving the stages */
633     PetscCall(MatCreateKAIJ(J,nstages,nstages,tab->A_inv,tab->I_s,&irk->TJ));
634   }
635   PetscCall(SNESSetJacobian(ts->snes,irk->TJ,irk->TJ,SNESTSFormJacobian,ts));
636   PetscCall(VecDestroy(&R));
637   PetscFunctionReturn(0);
638 }
639 
640 static PetscErrorCode TSSetFromOptions_IRK(PetscOptionItems *PetscOptionsObject,TS ts)
641 {
642   TS_IRK         *irk = (TS_IRK*)ts->data;
643   char           tname[256] = TSIRKGAUSS;
644 
645   PetscFunctionBegin;
646   PetscCall(PetscOptionsHead(PetscOptionsObject,"IRK ODE solver options"));
647   {
648     PetscBool flg1,flg2;
649     PetscCall(PetscOptionsInt("-ts_irk_nstages","Stages of the IRK method","TSIRKSetNumStages",irk->nstages,&irk->nstages,&flg1));
650     PetscCall(PetscOptionsFList("-ts_irk_type","Type of IRK method","TSIRKSetType",TSIRKList,irk->method_name[0] ? irk->method_name : tname,tname,sizeof(tname),&flg2));
651     if (flg1 ||flg2 || !irk->method_name[0]) { /* Create the method tableau after nstages or method is set */
652       PetscCall(TSIRKSetType(ts,tname));
653     }
654   }
655   PetscCall(PetscOptionsTail());
656   PetscFunctionReturn(0);
657 }
658 
659 static PetscErrorCode TSView_IRK(TS ts,PetscViewer viewer)
660 {
661   TS_IRK         *irk = (TS_IRK*)ts->data;
662   PetscBool      iascii;
663 
664   PetscFunctionBegin;
665   PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&iascii));
666   if (iascii) {
667     IRKTableau    tab = irk->tableau;
668     TSIRKType irktype;
669     char          buf[512];
670 
671     PetscCall(TSIRKGetType(ts,&irktype));
672     PetscCall(PetscViewerASCIIPrintf(viewer,"  IRK type %s\n",irktype));
673     PetscCall(PetscFormatRealArray(buf,sizeof(buf),"% 8.6f",irk->nstages,tab->c));
674     PetscCall(PetscViewerASCIIPrintf(viewer,"  Abscissa       c = %s\n",buf));
675     PetscCall(PetscViewerASCIIPrintf(viewer,"Stiffly accurate: %s\n",irk->stiffly_accurate ? "yes" : "no"));
676     PetscCall(PetscFormatRealArray(buf,sizeof(buf),"% 8.6f",PetscSqr(irk->nstages),tab->A));
677     PetscCall(PetscViewerASCIIPrintf(viewer,"  A coefficients       A = %s\n",buf));
678   }
679   PetscFunctionReturn(0);
680 }
681 
682 static PetscErrorCode TSLoad_IRK(TS ts,PetscViewer viewer)
683 {
684   SNES           snes;
685   TSAdapt        adapt;
686 
687   PetscFunctionBegin;
688   PetscCall(TSGetAdapt(ts,&adapt));
689   PetscCall(TSAdaptLoad(adapt,viewer));
690   PetscCall(TSGetSNES(ts,&snes));
691   PetscCall(SNESLoad(snes,viewer));
692   /* function and Jacobian context for SNES when used with TS is always ts object */
693   PetscCall(SNESSetFunction(snes,NULL,NULL,ts));
694   PetscCall(SNESSetJacobian(snes,NULL,NULL,NULL,ts));
695   PetscFunctionReturn(0);
696 }
697 
698 /*@C
699   TSIRKSetType - Set the type of IRK scheme
700 
701   Logically collective
702 
703   Input Parameters:
704 +  ts - timestepping context
705 -  irktype - type of IRK scheme
706 
707   Options Database:
708 .  -ts_irk_type <gauss> - set irk type
709 
710   Level: intermediate
711 
712 .seealso: TSIRKGetType(), TSIRK, TSIRKType, TSIRKGAUSS
713 @*/
714 PetscErrorCode TSIRKSetType(TS ts,TSIRKType irktype)
715 {
716   PetscFunctionBegin;
717   PetscValidHeaderSpecific(ts,TS_CLASSID,1);
718   PetscValidCharPointer(irktype,2);
719   PetscTryMethod(ts,"TSIRKSetType_C",(TS,TSIRKType),(ts,irktype));
720   PetscFunctionReturn(0);
721 }
722 
723 /*@C
724   TSIRKGetType - Get the type of IRK IMEX scheme
725 
726   Logically collective
727 
728   Input Parameter:
729 .  ts - timestepping context
730 
731   Output Parameter:
732 .  irktype - type of IRK-IMEX scheme
733 
734   Level: intermediate
735 
736 .seealso: TSIRKGetType()
737 @*/
738 PetscErrorCode TSIRKGetType(TS ts,TSIRKType *irktype)
739 {
740   PetscFunctionBegin;
741   PetscValidHeaderSpecific(ts,TS_CLASSID,1);
742   PetscUseMethod(ts,"TSIRKGetType_C",(TS,TSIRKType*),(ts,irktype));
743   PetscFunctionReturn(0);
744 }
745 
746 /*@C
747   TSIRKSetNumStages - Set the number of stages of IRK scheme
748 
749   Logically collective
750 
751   Input Parameters:
752 +  ts - timestepping context
753 -  nstages - number of stages of IRK scheme
754 
755   Options Database:
756 .  -ts_irk_nstages <int> - set number of stages
757 
758   Level: intermediate
759 
760 .seealso: TSIRKGetNumStages(), TSIRK
761 @*/
762 PetscErrorCode TSIRKSetNumStages(TS ts,PetscInt nstages)
763 {
764   PetscFunctionBegin;
765   PetscValidHeaderSpecific(ts,TS_CLASSID,1);
766   PetscTryMethod(ts,"TSIRKSetNumStages_C",(TS,PetscInt),(ts,nstages));
767   PetscFunctionReturn(0);
768 }
769 
770 /*@C
771   TSIRKGetNumStages - Get the number of stages of IRK scheme
772 
773   Logically collective
774 
775   Input Parameters:
776 +  ts - timestepping context
777 -  nstages - number of stages of IRK scheme
778 
779   Level: intermediate
780 
781 .seealso: TSIRKSetNumStages(), TSIRK
782 @*/
783 PetscErrorCode TSIRKGetNumStages(TS ts,PetscInt *nstages)
784 {
785   PetscFunctionBegin;
786   PetscValidHeaderSpecific(ts,TS_CLASSID,1);
787   PetscValidIntPointer(nstages,2);
788   PetscTryMethod(ts,"TSIRKGetNumStages_C",(TS,PetscInt*),(ts,nstages));
789   PetscFunctionReturn(0);
790 }
791 
792 static PetscErrorCode TSIRKGetType_IRK(TS ts,TSIRKType *irktype)
793 {
794   TS_IRK *irk = (TS_IRK*)ts->data;
795 
796   PetscFunctionBegin;
797   *irktype = irk->method_name;
798   PetscFunctionReturn(0);
799 }
800 
801 static PetscErrorCode TSIRKSetType_IRK(TS ts,TSIRKType irktype)
802 {
803   TS_IRK         *irk = (TS_IRK*)ts->data;
804   PetscErrorCode (*irkcreate)(TS);
805 
806   PetscFunctionBegin;
807   if (irk->method_name) {
808     PetscCall(PetscFree(irk->method_name));
809     PetscCall(TSIRKTableauReset(ts));
810   }
811   PetscCall(PetscFunctionListFind(TSIRKList,irktype,&irkcreate));
812   PetscCheck(irkcreate,PETSC_COMM_SELF,PETSC_ERR_ARG_UNKNOWN_TYPE,"Unknown TSIRK type \"%s\" given",irktype);
813   PetscCall((*irkcreate)(ts));
814   PetscCall(PetscStrallocpy(irktype,&irk->method_name));
815   PetscFunctionReturn(0);
816 }
817 
818 static PetscErrorCode TSIRKSetNumStages_IRK(TS ts,PetscInt nstages)
819 {
820   TS_IRK *irk = (TS_IRK*)ts->data;
821 
822   PetscFunctionBegin;
823   PetscCheck(nstages>0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"input argument, %d, out of range",nstages);
824   irk->nstages = nstages;
825   PetscFunctionReturn(0);
826 }
827 
828 static PetscErrorCode TSIRKGetNumStages_IRK(TS ts,PetscInt *nstages)
829 {
830   TS_IRK *irk = (TS_IRK*)ts->data;
831 
832   PetscFunctionBegin;
833   PetscValidIntPointer(nstages,2);
834   *nstages = irk->nstages;
835   PetscFunctionReturn(0);
836 }
837 
838 static PetscErrorCode TSDestroy_IRK(TS ts)
839 {
840   PetscFunctionBegin;
841   PetscCall(TSReset_IRK(ts));
842   if (ts->dm) {
843     PetscCall(DMCoarsenHookRemove(ts->dm,DMCoarsenHook_TSIRK,DMRestrictHook_TSIRK,ts));
844     PetscCall(DMSubDomainHookRemove(ts->dm,DMSubDomainHook_TSIRK,DMSubDomainRestrictHook_TSIRK,ts));
845   }
846   PetscCall(PetscFree(ts->data));
847   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKSetType_C",NULL));
848   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKGetType_C",NULL));
849   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKSetNumStages_C",NULL));
850   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKGetNumStages_C",NULL));
851   PetscFunctionReturn(0);
852 }
853 
854 /*MC
855       TSIRK - ODE and DAE solver using Implicit Runge-Kutta schemes
856 
857   Notes:
858 
859   TSIRK uses the sparse Kronecker product matrix implementation of MATKAIJ to achieve good arithmetic intensity.
860 
861   Gauss-Legrendre methods are currently supported. These are A-stable symplectic methods with an arbitrary number of stages. The order of accuracy is 2s when using s stages. The default method uses three stages and thus has an order of six. The number of stages (thus order) can be set with -ts_irk_nstages or TSIRKSetNumStages().
862 
863   Level: beginner
864 
865 .seealso:  TSCreate(), TS, TSSetType(), TSIRKSetType(), TSIRKGetType(), TSIRKGAUSS, TSIRKRegister(), TSIRKSetNumStages()
866 
867 M*/
868 PETSC_EXTERN PetscErrorCode TSCreate_IRK(TS ts)
869 {
870   TS_IRK         *irk;
871 
872   PetscFunctionBegin;
873   PetscCall(TSIRKInitializePackage());
874 
875   ts->ops->reset          = TSReset_IRK;
876   ts->ops->destroy        = TSDestroy_IRK;
877   ts->ops->view           = TSView_IRK;
878   ts->ops->load           = TSLoad_IRK;
879   ts->ops->setup          = TSSetUp_IRK;
880   ts->ops->step           = TSStep_IRK;
881   ts->ops->interpolate    = TSInterpolate_IRK;
882   ts->ops->evaluatestep   = TSEvaluateStep_IRK;
883   ts->ops->rollback       = TSRollBack_IRK;
884   ts->ops->setfromoptions = TSSetFromOptions_IRK;
885   ts->ops->snesfunction   = SNESTSFormFunction_IRK;
886   ts->ops->snesjacobian   = SNESTSFormJacobian_IRK;
887 
888   ts->usessnes = PETSC_TRUE;
889 
890   PetscCall(PetscNewLog(ts,&irk));
891   ts->data = (void*)irk;
892 
893   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKSetType_C",TSIRKSetType_IRK));
894   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKGetType_C",TSIRKGetType_IRK));
895   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKSetNumStages_C",TSIRKSetNumStages_IRK));
896   PetscCall(PetscObjectComposeFunction((PetscObject)ts,"TSIRKGetNumStages_C",TSIRKGetNumStages_IRK));
897   /* 3-stage IRK_Gauss is the default */
898   PetscCall(PetscNew(&irk->tableau));
899   irk->nstages = 3;
900   PetscCall(TSIRKSetType(ts,TSIRKDefault));
901   PetscFunctionReturn(0);
902 }
903