xref: /petsc/src/ts/impls/implicit/irk/irk.c (revision a69119a591a03a9d906b29c0a4e9802e4d7c9795)
1 /*
2   Code for timestepping with implicit Runge-Kutta method
3 
4   Notes:
5   The general system is written as
6 
7   F(t,U,Udot) = 0
8 
9 */
10 #include <petsc/private/tsimpl.h> /*I   "petscts.h"   I*/
11 #include <petscdm.h>
12 #include <petscdt.h>
13 
14 static TSIRKType         TSIRKDefault = TSIRKGAUSS;
15 static PetscBool         TSIRKRegisterAllCalled;
16 static PetscBool         TSIRKPackageInitialized;
17 static PetscFunctionList TSIRKList;
18 
19 struct _IRKTableau {
20   PetscReal   *A, *b, *c;
21   PetscScalar *A_inv, *A_inv_rowsum, *I_s;
22   PetscReal   *binterp; /* Dense output formula */
23 };
24 
25 typedef struct _IRKTableau *IRKTableau;
26 
27 typedef struct {
28   char        *method_name;
29   PetscInt     order;   /* Classical approximation order of the method */
30   PetscInt     nstages; /* Number of stages */
31   PetscBool    stiffly_accurate;
32   PetscInt     pinterp; /* Interpolation order */
33   IRKTableau   tableau;
34   Vec          U0;    /* Backup vector */
35   Vec          Z;     /* Combined stage vector */
36   Vec         *Y;     /* States computed during the step */
37   Vec          Ydot;  /* Work vector holding time derivatives during residual evaluation */
38   Vec          U;     /* U is used to compute Ydot = shift(Y-U) */
39   Vec         *YdotI; /* Work vectors to hold the residual evaluation */
40   Mat          TJ;    /* KAIJ matrix for the Jacobian of the combined system */
41   PetscScalar *work;  /* Scalar work */
42   TSStepStatus status;
43   PetscBool    rebuild_completion;
44   PetscReal    ccfl;
45 } TS_IRK;
46 
47 /*@C
48    TSIRKTableauCreate - create the tableau for TSIRK and provide the entries
49 
50    Not Collective
51 
52    Input Parameters:
53 +  ts - timestepping context
54 .  nstages - number of stages, this is the dimension of the matrices below
55 .  A - stage coefficients (dimension nstages*nstages, row-major)
56 .  b - step completion table (dimension nstages)
57 .  c - abscissa (dimension nstages)
58 .  binterp - coefficients of the interpolation formula (dimension nstages)
59 .  A_inv - inverse of A (dimension nstages*nstages, row-major)
60 .  A_inv_rowsum - row sum of the inverse of A (dimension nstages)
61 -  I_s - identity matrix (dimension nstages*nstages)
62 
63    Level: advanced
64 
65 .seealso: `TSIRK`, `TSIRKRegister()`
66 @*/
67 PetscErrorCode TSIRKTableauCreate(TS ts, PetscInt nstages, const PetscReal *A, const PetscReal *b, const PetscReal *c, const PetscReal *binterp, const PetscScalar *A_inv, const PetscScalar *A_inv_rowsum, const PetscScalar *I_s) {
68   TS_IRK    *irk = (TS_IRK *)ts->data;
69   IRKTableau tab = irk->tableau;
70 
71   PetscFunctionBegin;
72   irk->order = nstages;
73   PetscCall(PetscMalloc3(PetscSqr(nstages), &tab->A, PetscSqr(nstages), &tab->A_inv, PetscSqr(nstages), &tab->I_s));
74   PetscCall(PetscMalloc4(nstages, &tab->b, nstages, &tab->c, nstages, &tab->binterp, nstages, &tab->A_inv_rowsum));
75   PetscCall(PetscArraycpy(tab->A, A, PetscSqr(nstages)));
76   PetscCall(PetscArraycpy(tab->b, b, nstages));
77   PetscCall(PetscArraycpy(tab->c, c, nstages));
78   /* optional coefficient arrays */
79   if (binterp) PetscCall(PetscArraycpy(tab->binterp, binterp, nstages));
80   if (A_inv) PetscCall(PetscArraycpy(tab->A_inv, A_inv, PetscSqr(nstages)));
81   if (A_inv_rowsum) PetscCall(PetscArraycpy(tab->A_inv_rowsum, A_inv_rowsum, nstages));
82   if (I_s) PetscCall(PetscArraycpy(tab->I_s, I_s, PetscSqr(nstages)));
83   PetscFunctionReturn(0);
84 }
85 
86 /* Arrays should be freed with PetscFree3(A,b,c) */
87 static PetscErrorCode TSIRKCreate_Gauss(TS ts) {
88   PetscInt     nstages;
89   PetscReal   *gauss_A_real, *gauss_b, *b, *gauss_c;
90   PetscScalar *gauss_A, *gauss_A_inv, *gauss_A_inv_rowsum, *I_s;
91   PetscScalar *G0, *G1;
92   PetscInt     i, j;
93   Mat          G0mat, G1mat, Amat;
94 
95   PetscFunctionBegin;
96   PetscCall(TSIRKGetNumStages(ts, &nstages));
97   PetscCall(PetscMalloc3(PetscSqr(nstages), &gauss_A_real, nstages, &gauss_b, nstages, &gauss_c));
98   PetscCall(PetscMalloc4(PetscSqr(nstages), &gauss_A, PetscSqr(nstages), &gauss_A_inv, nstages, &gauss_A_inv_rowsum, PetscSqr(nstages), &I_s));
99   PetscCall(PetscMalloc3(nstages, &b, PetscSqr(nstages), &G0, PetscSqr(nstages), &G1));
100   PetscCall(PetscDTGaussQuadrature(nstages, 0., 1., gauss_c, b));
101   for (i = 0; i < nstages; i++) gauss_b[i] = b[i]; /* copy to possibly-complex array */
102 
103   /* A^T = G0^{-1} G1 */
104   for (i = 0; i < nstages; i++) {
105     for (j = 0; j < nstages; j++) {
106       G0[i * nstages + j] = PetscPowRealInt(gauss_c[i], j);
107       G1[i * nstages + j] = PetscPowRealInt(gauss_c[i], j + 1) / (j + 1);
108     }
109   }
110   /* The arrays above are row-aligned, but we create dense matrices as the transpose */
111   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G0, &G0mat));
112   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G1, &G1mat));
113   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, gauss_A, &Amat));
114   PetscCall(MatLUFactor(G0mat, NULL, NULL, NULL));
115   PetscCall(MatMatSolve(G0mat, G1mat, Amat));
116   PetscCall(MatTranspose(Amat, MAT_INPLACE_MATRIX, &Amat));
117   for (i = 0; i < nstages; i++)
118     for (j = 0; j < nstages; j++) gauss_A_real[i * nstages + j] = PetscRealPart(gauss_A[i * nstages + j]);
119 
120   PetscCall(MatDestroy(&G0mat));
121   PetscCall(MatDestroy(&G1mat));
122   PetscCall(MatDestroy(&Amat));
123   PetscCall(PetscFree3(b, G0, G1));
124 
125   { /* Invert A */
126     /* PETSc does not provide a routine to calculate the inverse of a general matrix.
127      * To get the inverse of A, we form a sequential BAIJ matrix from it, consisting of a single block with block size
128      * equal to the dimension of A, and then use MatInvertBlockDiagonal(). */
129     Mat                A_baij;
130     PetscInt           idxm[1] = {0}, idxn[1] = {0};
131     const PetscScalar *A_inv;
132 
133     PetscCall(MatCreateSeqBAIJ(PETSC_COMM_SELF, nstages, nstages, nstages, 1, NULL, &A_baij));
134     PetscCall(MatSetOption(A_baij, MAT_ROW_ORIENTED, PETSC_FALSE));
135     PetscCall(MatSetValuesBlocked(A_baij, 1, idxm, 1, idxn, gauss_A, INSERT_VALUES));
136     PetscCall(MatAssemblyBegin(A_baij, MAT_FINAL_ASSEMBLY));
137     PetscCall(MatAssemblyEnd(A_baij, MAT_FINAL_ASSEMBLY));
138     PetscCall(MatInvertBlockDiagonal(A_baij, &A_inv));
139     PetscCall(PetscMemcpy(gauss_A_inv, A_inv, nstages * nstages * sizeof(PetscScalar)));
140     PetscCall(MatDestroy(&A_baij));
141   }
142 
143   /* Compute row sums A_inv_rowsum and identity I_s */
144   for (i = 0; i < nstages; i++) {
145     gauss_A_inv_rowsum[i] = 0;
146     for (j = 0; j < nstages; j++) {
147       gauss_A_inv_rowsum[i] += gauss_A_inv[i + nstages * j];
148       I_s[i + nstages * j] = 1. * (i == j);
149     }
150   }
151   PetscCall(TSIRKTableauCreate(ts, nstages, gauss_A_real, gauss_b, gauss_c, NULL, gauss_A_inv, gauss_A_inv_rowsum, I_s));
152   PetscCall(PetscFree3(gauss_A_real, gauss_b, gauss_c));
153   PetscCall(PetscFree4(gauss_A, gauss_A_inv, gauss_A_inv_rowsum, I_s));
154   PetscFunctionReturn(0);
155 }
156 
157 /*@C
158    TSIRKRegister -  adds a TSIRK implementation
159 
160    Not Collective
161 
162    Input Parameters:
163 +  sname - name of user-defined IRK scheme
164 -  function - function to create method context
165 
166    Notes:
167    TSIRKRegister() may be called multiple times to add several user-defined families.
168 
169    Sample usage:
170 .vb
171    TSIRKRegister("my_scheme",MySchemeCreate);
172 .ve
173 
174    Then, your scheme can be chosen with the procedural interface via
175 $     TSIRKSetType(ts,"my_scheme")
176    or at runtime via the option
177 $     -ts_irk_type my_scheme
178 
179    Level: advanced
180 
181 .seealso: `TSIRKRegisterAll()`
182 @*/
183 PetscErrorCode TSIRKRegister(const char sname[], PetscErrorCode (*function)(TS)) {
184   PetscFunctionBegin;
185   PetscCall(TSIRKInitializePackage());
186   PetscCall(PetscFunctionListAdd(&TSIRKList, sname, function));
187   PetscFunctionReturn(0);
188 }
189 
190 /*@C
191   TSIRKRegisterAll - Registers all of the implicit Runge-Kutta methods in TSIRK
192 
193   Not Collective, but should be called by all processes which will need the schemes to be registered
194 
195   Level: advanced
196 
197 .seealso: `TSIRKRegisterDestroy()`
198 @*/
199 PetscErrorCode TSIRKRegisterAll(void) {
200   PetscFunctionBegin;
201   if (TSIRKRegisterAllCalled) PetscFunctionReturn(0);
202   TSIRKRegisterAllCalled = PETSC_TRUE;
203 
204   PetscCall(TSIRKRegister(TSIRKGAUSS, TSIRKCreate_Gauss));
205   PetscFunctionReturn(0);
206 }
207 
208 /*@C
209    TSIRKRegisterDestroy - Frees the list of schemes that were registered by TSIRKRegister().
210 
211    Not Collective
212 
213    Level: advanced
214 
215 .seealso: `TSIRKRegister()`, `TSIRKRegisterAll()`
216 @*/
217 PetscErrorCode TSIRKRegisterDestroy(void) {
218   PetscFunctionBegin;
219   TSIRKRegisterAllCalled = PETSC_FALSE;
220   PetscFunctionReturn(0);
221 }
222 
223 /*@C
224   TSIRKInitializePackage - This function initializes everything in the TSIRK package. It is called
225   from TSInitializePackage().
226 
227   Level: developer
228 
229 .seealso: `PetscInitialize()`
230 @*/
231 PetscErrorCode TSIRKInitializePackage(void) {
232   PetscFunctionBegin;
233   if (TSIRKPackageInitialized) PetscFunctionReturn(0);
234   TSIRKPackageInitialized = PETSC_TRUE;
235   PetscCall(TSIRKRegisterAll());
236   PetscCall(PetscRegisterFinalize(TSIRKFinalizePackage));
237   PetscFunctionReturn(0);
238 }
239 
240 /*@C
241   TSIRKFinalizePackage - This function destroys everything in the TSIRK package. It is
242   called from PetscFinalize().
243 
244   Level: developer
245 
246 .seealso: `PetscFinalize()`
247 @*/
248 PetscErrorCode TSIRKFinalizePackage(void) {
249   PetscFunctionBegin;
250   PetscCall(PetscFunctionListDestroy(&TSIRKList));
251   TSIRKPackageInitialized = PETSC_FALSE;
252   PetscFunctionReturn(0);
253 }
254 
255 /*
256  This function can be called before or after ts->vec_sol has been updated.
257 */
258 static PetscErrorCode TSEvaluateStep_IRK(TS ts, PetscInt order, Vec U, PetscBool *done) {
259   TS_IRK      *irk   = (TS_IRK *)ts->data;
260   IRKTableau   tab   = irk->tableau;
261   Vec         *YdotI = irk->YdotI;
262   PetscScalar *w     = irk->work;
263   PetscReal    h;
264   PetscInt     j;
265 
266   PetscFunctionBegin;
267   switch (irk->status) {
268   case TS_STEP_INCOMPLETE:
269   case TS_STEP_PENDING: h = ts->time_step; break;
270   case TS_STEP_COMPLETE: h = ts->ptime - ts->ptime_prev; break;
271   default: SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
272   }
273 
274   PetscCall(VecCopy(ts->vec_sol, U));
275   for (j = 0; j < irk->nstages; j++) w[j] = h * tab->b[j];
276   PetscCall(VecMAXPY(U, irk->nstages, w, YdotI));
277   PetscFunctionReturn(0);
278 }
279 
280 static PetscErrorCode TSRollBack_IRK(TS ts) {
281   TS_IRK *irk = (TS_IRK *)ts->data;
282 
283   PetscFunctionBegin;
284   PetscCall(VecCopy(irk->U0, ts->vec_sol));
285   PetscFunctionReturn(0);
286 }
287 
288 static PetscErrorCode TSStep_IRK(TS ts) {
289   TS_IRK        *irk   = (TS_IRK *)ts->data;
290   IRKTableau     tab   = irk->tableau;
291   PetscScalar   *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
292   const PetscInt nstages = irk->nstages;
293   SNES           snes;
294   PetscInt       i, j, its, lits, bs;
295   TSAdapt        adapt;
296   PetscInt       rejections     = 0;
297   PetscBool      accept         = PETSC_TRUE;
298   PetscReal      next_time_step = ts->time_step;
299 
300   PetscFunctionBegin;
301   if (!ts->steprollback) PetscCall(VecCopy(ts->vec_sol, irk->U0));
302   PetscCall(VecGetBlockSize(ts->vec_sol, &bs));
303   for (i = 0; i < nstages; i++) PetscCall(VecStrideScatter(ts->vec_sol, i * bs, irk->Z, INSERT_VALUES));
304 
305   irk->status = TS_STEP_INCOMPLETE;
306   while (!ts->reason && irk->status != TS_STEP_COMPLETE) {
307     PetscCall(VecCopy(ts->vec_sol, irk->U));
308     PetscCall(TSGetSNES(ts, &snes));
309     PetscCall(SNESSolve(snes, NULL, irk->Z));
310     PetscCall(SNESGetIterationNumber(snes, &its));
311     PetscCall(SNESGetLinearSolveIterations(snes, &lits));
312     ts->snes_its += its;
313     ts->ksp_its += lits;
314     PetscCall(VecStrideGatherAll(irk->Z, irk->Y, INSERT_VALUES));
315     for (i = 0; i < nstages; i++) {
316       PetscCall(VecZeroEntries(irk->YdotI[i]));
317       for (j = 0; j < nstages; j++) PetscCall(VecAXPY(irk->YdotI[i], A_inv[i + j * nstages] / ts->time_step, irk->Y[j]));
318       PetscCall(VecAXPY(irk->YdotI[i], -A_inv_rowsum[i] / ts->time_step, irk->U));
319     }
320     irk->status = TS_STEP_INCOMPLETE;
321     PetscCall(TSEvaluateStep_IRK(ts, irk->order, ts->vec_sol, NULL));
322     irk->status = TS_STEP_PENDING;
323     PetscCall(TSGetAdapt(ts, &adapt));
324     PetscCall(TSAdaptChoose(adapt, ts, ts->time_step, NULL, &next_time_step, &accept));
325     irk->status = accept ? TS_STEP_COMPLETE : TS_STEP_INCOMPLETE;
326     if (!accept) {
327       PetscCall(TSRollBack_IRK(ts));
328       ts->time_step = next_time_step;
329       goto reject_step;
330     }
331 
332     ts->ptime += ts->time_step;
333     ts->time_step = next_time_step;
334     break;
335   reject_step:
336     ts->reject++;
337     accept = PETSC_FALSE;
338     if (!ts->reason && ++rejections > ts->max_reject && ts->max_reject >= 0) {
339       ts->reason = TS_DIVERGED_STEP_REJECTED;
340       PetscCall(PetscInfo(ts, "Step=%" PetscInt_FMT ", step rejections %" PetscInt_FMT " greater than current TS allowed, stopping solve\n", ts->steps, rejections));
341     }
342   }
343   PetscFunctionReturn(0);
344 }
345 
346 static PetscErrorCode TSInterpolate_IRK(TS ts, PetscReal itime, Vec U) {
347   TS_IRK          *irk     = (TS_IRK *)ts->data;
348   PetscInt         nstages = irk->nstages, pinterp = irk->pinterp, i, j;
349   PetscReal        h;
350   PetscReal        tt, t;
351   PetscScalar     *bt;
352   const PetscReal *B = irk->tableau->binterp;
353 
354   PetscFunctionBegin;
355   PetscCheck(B, PetscObjectComm((PetscObject)ts), PETSC_ERR_SUP, "TSIRK %s does not have an interpolation formula", irk->method_name);
356   switch (irk->status) {
357   case TS_STEP_INCOMPLETE:
358   case TS_STEP_PENDING:
359     h = ts->time_step;
360     t = (itime - ts->ptime) / h;
361     break;
362   case TS_STEP_COMPLETE:
363     h = ts->ptime - ts->ptime_prev;
364     t = (itime - ts->ptime) / h + 1; /* In the interval [0,1] */
365     break;
366   default: SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
367   }
368   PetscCall(PetscMalloc1(nstages, &bt));
369   for (i = 0; i < nstages; i++) bt[i] = 0;
370   for (j = 0, tt = t; j < pinterp; j++, tt *= t) {
371     for (i = 0; i < nstages; i++) bt[i] += h * B[i * pinterp + j] * tt;
372   }
373   PetscCall(VecMAXPY(U, nstages, bt, irk->YdotI));
374   PetscFunctionReturn(0);
375 }
376 
377 static PetscErrorCode TSIRKTableauReset(TS ts) {
378   TS_IRK    *irk = (TS_IRK *)ts->data;
379   IRKTableau tab = irk->tableau;
380 
381   PetscFunctionBegin;
382   if (!tab) PetscFunctionReturn(0);
383   PetscCall(PetscFree3(tab->A, tab->A_inv, tab->I_s));
384   PetscCall(PetscFree4(tab->b, tab->c, tab->binterp, tab->A_inv_rowsum));
385   PetscFunctionReturn(0);
386 }
387 
388 static PetscErrorCode TSReset_IRK(TS ts) {
389   TS_IRK *irk = (TS_IRK *)ts->data;
390 
391   PetscFunctionBegin;
392   PetscCall(TSIRKTableauReset(ts));
393   if (irk->tableau) PetscCall(PetscFree(irk->tableau));
394   if (irk->method_name) PetscCall(PetscFree(irk->method_name));
395   if (irk->work) PetscCall(PetscFree(irk->work));
396   PetscCall(VecDestroyVecs(irk->nstages, &irk->Y));
397   PetscCall(VecDestroyVecs(irk->nstages, &irk->YdotI));
398   PetscCall(VecDestroy(&irk->Ydot));
399   PetscCall(VecDestroy(&irk->Z));
400   PetscCall(VecDestroy(&irk->U));
401   PetscCall(VecDestroy(&irk->U0));
402   PetscCall(MatDestroy(&irk->TJ));
403   PetscFunctionReturn(0);
404 }
405 
406 static PetscErrorCode TSIRKGetVecs(TS ts, DM dm, Vec *U) {
407   TS_IRK *irk = (TS_IRK *)ts->data;
408 
409   PetscFunctionBegin;
410   if (U) {
411     if (dm && dm != ts->dm) {
412       PetscCall(DMGetNamedGlobalVector(dm, "TSIRK_U", U));
413     } else *U = irk->U;
414   }
415   PetscFunctionReturn(0);
416 }
417 
418 static PetscErrorCode TSIRKRestoreVecs(TS ts, DM dm, Vec *U) {
419   PetscFunctionBegin;
420   if (U) {
421     if (dm && dm != ts->dm) PetscCall(DMRestoreNamedGlobalVector(dm, "TSIRK_U", U));
422   }
423   PetscFunctionReturn(0);
424 }
425 
426 /*
427   This defines the nonlinear equations that is to be solved with SNES
428     G[e\otimes t + C*dt, Z, Zdot] = 0
429     Zdot = (In \otimes S)*Z - (In \otimes Se) U
430   where S = 1/(dt*A)
431 */
432 static PetscErrorCode SNESTSFormFunction_IRK(SNES snes, Vec ZC, Vec FC, TS ts) {
433   TS_IRK            *irk     = (TS_IRK *)ts->data;
434   IRKTableau         tab     = irk->tableau;
435   const PetscInt     nstages = irk->nstages;
436   const PetscReal   *c       = tab->c;
437   const PetscScalar *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
438   DM                 dm, dmsave;
439   Vec                U, *YdotI = irk->YdotI, Ydot = irk->Ydot, *Y = irk->Y;
440   PetscReal          h = ts->time_step;
441   PetscInt           i, j;
442 
443   PetscFunctionBegin;
444   PetscCall(SNESGetDM(snes, &dm));
445   PetscCall(TSIRKGetVecs(ts, dm, &U));
446   PetscCall(VecStrideGatherAll(ZC, Y, INSERT_VALUES));
447   dmsave = ts->dm;
448   ts->dm = dm;
449   for (i = 0; i < nstages; i++) {
450     PetscCall(VecZeroEntries(Ydot));
451     for (j = 0; j < nstages; j++) PetscCall(VecAXPY(Ydot, A_inv[j * nstages + i] / h, Y[j]));
452     PetscCall(VecAXPY(Ydot, -A_inv_rowsum[i] / h, U)); /* Ydot = (S \otimes In)*Z - (Se \otimes In) U */
453     PetscCall(TSComputeIFunction(ts, ts->ptime + ts->time_step * c[i], Y[i], Ydot, YdotI[i], PETSC_FALSE));
454   }
455   PetscCall(VecStrideScatterAll(YdotI, FC, INSERT_VALUES));
456   ts->dm = dmsave;
457   PetscCall(TSIRKRestoreVecs(ts, dm, &U));
458   PetscFunctionReturn(0);
459 }
460 
461 /*
462    For explicit ODE, the Jacobian is
463      JC = I_n \otimes S - J \otimes I_s
464    For DAE, the Jacobian is
465      JC = M_n \otimes S - J \otimes I_s
466 */
467 static PetscErrorCode SNESTSFormJacobian_IRK(SNES snes, Vec ZC, Mat JC, Mat JCpre, TS ts) {
468   TS_IRK          *irk     = (TS_IRK *)ts->data;
469   IRKTableau       tab     = irk->tableau;
470   const PetscInt   nstages = irk->nstages;
471   const PetscReal *c       = tab->c;
472   DM               dm, dmsave;
473   Vec             *Y = irk->Y, Ydot = irk->Ydot;
474   Mat              J;
475   PetscScalar     *S;
476   PetscInt         i, j, bs;
477 
478   PetscFunctionBegin;
479   PetscCall(SNESGetDM(snes, &dm));
480   /* irk->Ydot has already been computed in SNESTSFormFunction_IRK (SNES guarantees this) */
481   dmsave = ts->dm;
482   ts->dm = dm;
483   PetscCall(VecGetBlockSize(Y[nstages - 1], &bs));
484   if (ts->equation_type <= TS_EQ_ODE_EXPLICIT) { /* Support explicit formulas only */
485     PetscCall(VecStrideGather(ZC, (nstages - 1) * bs, Y[nstages - 1], INSERT_VALUES));
486     PetscCall(MatKAIJGetAIJ(JC, &J));
487     PetscCall(TSComputeIJacobian(ts, ts->ptime + ts->time_step * c[nstages - 1], Y[nstages - 1], Ydot, 0, J, J, PETSC_FALSE));
488     PetscCall(MatKAIJGetS(JC, NULL, NULL, &S));
489     for (i = 0; i < nstages; i++)
490       for (j = 0; j < nstages; j++) S[i + nstages * j] = tab->A_inv[i + nstages * j] / ts->time_step;
491     PetscCall(MatKAIJRestoreS(JC, &S));
492   } else SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_SUP, "TSIRK %s does not support implicit formula", irk->method_name); /* TODO: need the mass matrix for DAE  */
493   ts->dm = dmsave;
494   PetscFunctionReturn(0);
495 }
496 
497 static PetscErrorCode DMCoarsenHook_TSIRK(DM fine, DM coarse, void *ctx) {
498   PetscFunctionBegin;
499   PetscFunctionReturn(0);
500 }
501 
502 static PetscErrorCode DMRestrictHook_TSIRK(DM fine, Mat restrct, Vec rscale, Mat inject, DM coarse, void *ctx) {
503   TS  ts = (TS)ctx;
504   Vec U, U_c;
505 
506   PetscFunctionBegin;
507   PetscCall(TSIRKGetVecs(ts, fine, &U));
508   PetscCall(TSIRKGetVecs(ts, coarse, &U_c));
509   PetscCall(MatRestrict(restrct, U, U_c));
510   PetscCall(VecPointwiseMult(U_c, rscale, U_c));
511   PetscCall(TSIRKRestoreVecs(ts, fine, &U));
512   PetscCall(TSIRKRestoreVecs(ts, coarse, &U_c));
513   PetscFunctionReturn(0);
514 }
515 
516 static PetscErrorCode DMSubDomainHook_TSIRK(DM dm, DM subdm, void *ctx) {
517   PetscFunctionBegin;
518   PetscFunctionReturn(0);
519 }
520 
521 static PetscErrorCode DMSubDomainRestrictHook_TSIRK(DM dm, VecScatter gscat, VecScatter lscat, DM subdm, void *ctx) {
522   TS  ts = (TS)ctx;
523   Vec U, U_c;
524 
525   PetscFunctionBegin;
526   PetscCall(TSIRKGetVecs(ts, dm, &U));
527   PetscCall(TSIRKGetVecs(ts, subdm, &U_c));
528 
529   PetscCall(VecScatterBegin(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD));
530   PetscCall(VecScatterEnd(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD));
531 
532   PetscCall(TSIRKRestoreVecs(ts, dm, &U));
533   PetscCall(TSIRKRestoreVecs(ts, subdm, &U_c));
534   PetscFunctionReturn(0);
535 }
536 
537 static PetscErrorCode TSSetUp_IRK(TS ts) {
538   TS_IRK        *irk = (TS_IRK *)ts->data;
539   IRKTableau     tab = irk->tableau;
540   DM             dm;
541   Mat            J;
542   Vec            R;
543   const PetscInt nstages = irk->nstages;
544   PetscInt       vsize, bs;
545 
546   PetscFunctionBegin;
547   if (!irk->work) PetscCall(PetscMalloc1(irk->nstages, &irk->work));
548   if (!irk->Y) PetscCall(VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->Y));
549   if (!irk->YdotI) PetscCall(VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->YdotI));
550   if (!irk->Ydot) PetscCall(VecDuplicate(ts->vec_sol, &irk->Ydot));
551   if (!irk->U) PetscCall(VecDuplicate(ts->vec_sol, &irk->U));
552   if (!irk->U0) PetscCall(VecDuplicate(ts->vec_sol, &irk->U0));
553   if (!irk->Z) {
554     PetscCall(VecCreate(PetscObjectComm((PetscObject)ts->vec_sol), &irk->Z));
555     PetscCall(VecGetSize(ts->vec_sol, &vsize));
556     PetscCall(VecSetSizes(irk->Z, PETSC_DECIDE, vsize * irk->nstages));
557     PetscCall(VecGetBlockSize(ts->vec_sol, &bs));
558     PetscCall(VecSetBlockSize(irk->Z, irk->nstages * bs));
559     PetscCall(VecSetFromOptions(irk->Z));
560   }
561   PetscCall(TSGetDM(ts, &dm));
562   PetscCall(DMCoarsenHookAdd(dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts));
563   PetscCall(DMSubDomainHookAdd(dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts));
564 
565   PetscCall(TSGetSNES(ts, &ts->snes));
566   PetscCall(VecDuplicate(irk->Z, &R));
567   PetscCall(SNESSetFunction(ts->snes, R, SNESTSFormFunction, ts));
568   PetscCall(TSGetIJacobian(ts, &J, NULL, NULL, NULL));
569   if (!irk->TJ) {
570     /* Create the KAIJ matrix for solving the stages */
571     PetscCall(MatCreateKAIJ(J, nstages, nstages, tab->A_inv, tab->I_s, &irk->TJ));
572   }
573   PetscCall(SNESSetJacobian(ts->snes, irk->TJ, irk->TJ, SNESTSFormJacobian, ts));
574   PetscCall(VecDestroy(&R));
575   PetscFunctionReturn(0);
576 }
577 
578 static PetscErrorCode TSSetFromOptions_IRK(TS ts, PetscOptionItems *PetscOptionsObject) {
579   TS_IRK *irk        = (TS_IRK *)ts->data;
580   char    tname[256] = TSIRKGAUSS;
581 
582   PetscFunctionBegin;
583   PetscOptionsHeadBegin(PetscOptionsObject, "IRK ODE solver options");
584   {
585     PetscBool flg1, flg2;
586     PetscCall(PetscOptionsInt("-ts_irk_nstages", "Stages of the IRK method", "TSIRKSetNumStages", irk->nstages, &irk->nstages, &flg1));
587     PetscCall(PetscOptionsFList("-ts_irk_type", "Type of IRK method", "TSIRKSetType", TSIRKList, irk->method_name[0] ? irk->method_name : tname, tname, sizeof(tname), &flg2));
588     if (flg1 || flg2 || !irk->method_name[0]) { /* Create the method tableau after nstages or method is set */
589       PetscCall(TSIRKSetType(ts, tname));
590     }
591   }
592   PetscOptionsHeadEnd();
593   PetscFunctionReturn(0);
594 }
595 
596 static PetscErrorCode TSView_IRK(TS ts, PetscViewer viewer) {
597   TS_IRK   *irk = (TS_IRK *)ts->data;
598   PetscBool iascii;
599 
600   PetscFunctionBegin;
601   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
602   if (iascii) {
603     IRKTableau tab = irk->tableau;
604     TSIRKType  irktype;
605     char       buf[512];
606 
607     PetscCall(TSIRKGetType(ts, &irktype));
608     PetscCall(PetscViewerASCIIPrintf(viewer, "  IRK type %s\n", irktype));
609     PetscCall(PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", irk->nstages, tab->c));
610     PetscCall(PetscViewerASCIIPrintf(viewer, "  Abscissa       c = %s\n", buf));
611     PetscCall(PetscViewerASCIIPrintf(viewer, "Stiffly accurate: %s\n", irk->stiffly_accurate ? "yes" : "no"));
612     PetscCall(PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", PetscSqr(irk->nstages), tab->A));
613     PetscCall(PetscViewerASCIIPrintf(viewer, "  A coefficients       A = %s\n", buf));
614   }
615   PetscFunctionReturn(0);
616 }
617 
618 static PetscErrorCode TSLoad_IRK(TS ts, PetscViewer viewer) {
619   SNES    snes;
620   TSAdapt adapt;
621 
622   PetscFunctionBegin;
623   PetscCall(TSGetAdapt(ts, &adapt));
624   PetscCall(TSAdaptLoad(adapt, viewer));
625   PetscCall(TSGetSNES(ts, &snes));
626   PetscCall(SNESLoad(snes, viewer));
627   /* function and Jacobian context for SNES when used with TS is always ts object */
628   PetscCall(SNESSetFunction(snes, NULL, NULL, ts));
629   PetscCall(SNESSetJacobian(snes, NULL, NULL, NULL, ts));
630   PetscFunctionReturn(0);
631 }
632 
633 /*@C
634   TSIRKSetType - Set the type of IRK scheme
635 
636   Logically collective
637 
638   Input Parameters:
639 +  ts - timestepping context
640 -  irktype - type of IRK scheme
641 
642   Options Database:
643 .  -ts_irk_type <gauss> - set irk type
644 
645   Level: intermediate
646 
647 .seealso: `TSIRKGetType()`, `TSIRK`, `TSIRKType`, `TSIRKGAUSS`
648 @*/
649 PetscErrorCode TSIRKSetType(TS ts, TSIRKType irktype) {
650   PetscFunctionBegin;
651   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
652   PetscValidCharPointer(irktype, 2);
653   PetscTryMethod(ts, "TSIRKSetType_C", (TS, TSIRKType), (ts, irktype));
654   PetscFunctionReturn(0);
655 }
656 
657 /*@C
658   TSIRKGetType - Get the type of IRK IMEX scheme
659 
660   Logically collective
661 
662   Input Parameter:
663 .  ts - timestepping context
664 
665   Output Parameter:
666 .  irktype - type of IRK-IMEX scheme
667 
668   Level: intermediate
669 
670 .seealso: `TSIRKGetType()`
671 @*/
672 PetscErrorCode TSIRKGetType(TS ts, TSIRKType *irktype) {
673   PetscFunctionBegin;
674   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
675   PetscUseMethod(ts, "TSIRKGetType_C", (TS, TSIRKType *), (ts, irktype));
676   PetscFunctionReturn(0);
677 }
678 
679 /*@C
680   TSIRKSetNumStages - Set the number of stages of IRK scheme
681 
682   Logically collective
683 
684   Input Parameters:
685 +  ts - timestepping context
686 -  nstages - number of stages of IRK scheme
687 
688   Options Database:
689 .  -ts_irk_nstages <int> - set number of stages
690 
691   Level: intermediate
692 
693 .seealso: `TSIRKGetNumStages()`, `TSIRK`
694 @*/
695 PetscErrorCode TSIRKSetNumStages(TS ts, PetscInt nstages) {
696   PetscFunctionBegin;
697   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
698   PetscTryMethod(ts, "TSIRKSetNumStages_C", (TS, PetscInt), (ts, nstages));
699   PetscFunctionReturn(0);
700 }
701 
702 /*@C
703   TSIRKGetNumStages - Get the number of stages of IRK scheme
704 
705   Logically collective
706 
707   Input Parameters:
708 +  ts - timestepping context
709 -  nstages - number of stages of IRK scheme
710 
711   Level: intermediate
712 
713 .seealso: `TSIRKSetNumStages()`, `TSIRK`
714 @*/
715 PetscErrorCode TSIRKGetNumStages(TS ts, PetscInt *nstages) {
716   PetscFunctionBegin;
717   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
718   PetscValidIntPointer(nstages, 2);
719   PetscTryMethod(ts, "TSIRKGetNumStages_C", (TS, PetscInt *), (ts, nstages));
720   PetscFunctionReturn(0);
721 }
722 
723 static PetscErrorCode TSIRKGetType_IRK(TS ts, TSIRKType *irktype) {
724   TS_IRK *irk = (TS_IRK *)ts->data;
725 
726   PetscFunctionBegin;
727   *irktype = irk->method_name;
728   PetscFunctionReturn(0);
729 }
730 
731 static PetscErrorCode TSIRKSetType_IRK(TS ts, TSIRKType irktype) {
732   TS_IRK *irk = (TS_IRK *)ts->data;
733   PetscErrorCode (*irkcreate)(TS);
734 
735   PetscFunctionBegin;
736   if (irk->method_name) {
737     PetscCall(PetscFree(irk->method_name));
738     PetscCall(TSIRKTableauReset(ts));
739   }
740   PetscCall(PetscFunctionListFind(TSIRKList, irktype, &irkcreate));
741   PetscCheck(irkcreate, PETSC_COMM_SELF, PETSC_ERR_ARG_UNKNOWN_TYPE, "Unknown TSIRK type \"%s\" given", irktype);
742   PetscCall((*irkcreate)(ts));
743   PetscCall(PetscStrallocpy(irktype, &irk->method_name));
744   PetscFunctionReturn(0);
745 }
746 
747 static PetscErrorCode TSIRKSetNumStages_IRK(TS ts, PetscInt nstages) {
748   TS_IRK *irk = (TS_IRK *)ts->data;
749 
750   PetscFunctionBegin;
751   PetscCheck(nstages > 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "input argument, %" PetscInt_FMT ", out of range", nstages);
752   irk->nstages = nstages;
753   PetscFunctionReturn(0);
754 }
755 
756 static PetscErrorCode TSIRKGetNumStages_IRK(TS ts, PetscInt *nstages) {
757   TS_IRK *irk = (TS_IRK *)ts->data;
758 
759   PetscFunctionBegin;
760   PetscValidIntPointer(nstages, 2);
761   *nstages = irk->nstages;
762   PetscFunctionReturn(0);
763 }
764 
765 static PetscErrorCode TSDestroy_IRK(TS ts) {
766   PetscFunctionBegin;
767   PetscCall(TSReset_IRK(ts));
768   if (ts->dm) {
769     PetscCall(DMCoarsenHookRemove(ts->dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts));
770     PetscCall(DMSubDomainHookRemove(ts->dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts));
771   }
772   PetscCall(PetscFree(ts->data));
773   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", NULL));
774   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", NULL));
775   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", NULL));
776   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", NULL));
777   PetscFunctionReturn(0);
778 }
779 
780 /*MC
781       TSIRK - ODE and DAE solver using Implicit Runge-Kutta schemes
782 
783   Notes:
784 
785   TSIRK uses the sparse Kronecker product matrix implementation of MATKAIJ to achieve good arithmetic intensity.
786 
787   Gauss-Legrendre methods are currently supported. These are A-stable symplectic methods with an arbitrary number of stages. The order of accuracy is 2s when using s stages. The default method uses three stages and thus has an order of six. The number of stages (thus order) can be set with -ts_irk_nstages or TSIRKSetNumStages().
788 
789   Level: beginner
790 
791 .seealso: `TSCreate()`, `TS`, `TSSetType()`, `TSIRKSetType()`, `TSIRKGetType()`, `TSIRKGAUSS`, `TSIRKRegister()`, `TSIRKSetNumStages()`
792 
793 M*/
794 PETSC_EXTERN PetscErrorCode TSCreate_IRK(TS ts) {
795   TS_IRK *irk;
796 
797   PetscFunctionBegin;
798   PetscCall(TSIRKInitializePackage());
799 
800   ts->ops->reset          = TSReset_IRK;
801   ts->ops->destroy        = TSDestroy_IRK;
802   ts->ops->view           = TSView_IRK;
803   ts->ops->load           = TSLoad_IRK;
804   ts->ops->setup          = TSSetUp_IRK;
805   ts->ops->step           = TSStep_IRK;
806   ts->ops->interpolate    = TSInterpolate_IRK;
807   ts->ops->evaluatestep   = TSEvaluateStep_IRK;
808   ts->ops->rollback       = TSRollBack_IRK;
809   ts->ops->setfromoptions = TSSetFromOptions_IRK;
810   ts->ops->snesfunction   = SNESTSFormFunction_IRK;
811   ts->ops->snesjacobian   = SNESTSFormJacobian_IRK;
812 
813   ts->usessnes = PETSC_TRUE;
814 
815   PetscCall(PetscNewLog(ts, &irk));
816   ts->data = (void *)irk;
817 
818   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", TSIRKSetType_IRK));
819   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", TSIRKGetType_IRK));
820   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", TSIRKSetNumStages_IRK));
821   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", TSIRKGetNumStages_IRK));
822   /* 3-stage IRK_Gauss is the default */
823   PetscCall(PetscNew(&irk->tableau));
824   irk->nstages = 3;
825   PetscCall(TSIRKSetType(ts, TSIRKDefault));
826   PetscFunctionReturn(0);
827 }
828