xref: /petsc/src/ts/impls/implicit/irk/irk.c (revision 0baf8eba40dbc839082666f9f7396a225d6f663c)
1 /*
2   Code for timestepping with implicit Runge-Kutta method
3 
4   Notes:
5   The general system is written as
6 
7   F(t,U,Udot) = 0
8 
9 */
10 #include <petsc/private/tsimpl.h> /*I   "petscts.h"   I*/
11 #include <petscdm.h>
12 #include <petscdt.h>
13 
14 static TSIRKType         TSIRKDefault = TSIRKGAUSS;
15 static PetscBool         TSIRKRegisterAllCalled;
16 static PetscBool         TSIRKPackageInitialized;
17 static PetscFunctionList TSIRKList;
18 
19 struct _IRKTableau {
20   PetscReal   *A, *b, *c;
21   PetscScalar *A_inv, *A_inv_rowsum, *I_s;
22   PetscReal   *binterp; /* Dense output formula */
23 };
24 
25 typedef struct _IRKTableau *IRKTableau;
26 
27 typedef struct {
28   char        *method_name;
29   PetscInt     order;   /* Classical approximation order of the method */
30   PetscInt     nstages; /* Number of stages */
31   PetscBool    stiffly_accurate;
32   PetscInt     pinterp; /* Interpolation order */
33   IRKTableau   tableau;
34   Vec          U0;    /* Backup vector */
35   Vec          Z;     /* Combined stage vector */
36   Vec         *Y;     /* States computed during the step */
37   Vec          Ydot;  /* Work vector holding time derivatives during residual evaluation */
38   Vec          U;     /* U is used to compute Ydot = shift(Y-U) */
39   Vec         *YdotI; /* Work vectors to hold the residual evaluation */
40   Mat          TJ;    /* KAIJ matrix for the Jacobian of the combined system */
41   PetscScalar *work;  /* Scalar work */
42   TSStepStatus status;
43   PetscBool    rebuild_completion;
44   PetscReal    ccfl;
45 } TS_IRK;
46 
47 /*@C
48   TSIRKTableauCreate - create the tableau for `TSIRK` and provide the entries
49 
50   Not Collective
51 
52   Input Parameters:
53 + ts           - timestepping context
54 . nstages      - number of stages, this is the dimension of the matrices below
55 . A            - stage coefficients (dimension nstages*nstages, row-major)
56 . b            - step completion table (dimension nstages)
57 . c            - abscissa (dimension nstages)
58 . binterp      - coefficients of the interpolation formula (dimension nstages)
59 . A_inv        - inverse of A (dimension nstages*nstages, row-major)
60 . A_inv_rowsum - row sum of the inverse of A (dimension nstages)
61 - I_s          - identity matrix (dimension nstages*nstages)
62 
63   Level: advanced
64 
65 .seealso: [](ch_ts), `TSIRK`, `TSIRKRegister()`
66 @*/
67 PetscErrorCode TSIRKTableauCreate(TS ts, PetscInt nstages, const PetscReal *A, const PetscReal *b, const PetscReal *c, const PetscReal *binterp, const PetscScalar *A_inv, const PetscScalar *A_inv_rowsum, const PetscScalar *I_s)
68 {
69   TS_IRK    *irk = (TS_IRK *)ts->data;
70   IRKTableau tab = irk->tableau;
71 
72   PetscFunctionBegin;
73   irk->order = nstages;
74   PetscCall(PetscMalloc3(PetscSqr(nstages), &tab->A, PetscSqr(nstages), &tab->A_inv, PetscSqr(nstages), &tab->I_s));
75   PetscCall(PetscMalloc4(nstages, &tab->b, nstages, &tab->c, nstages, &tab->binterp, nstages, &tab->A_inv_rowsum));
76   PetscCall(PetscArraycpy(tab->A, A, PetscSqr(nstages)));
77   PetscCall(PetscArraycpy(tab->b, b, nstages));
78   PetscCall(PetscArraycpy(tab->c, c, nstages));
79   /* optional coefficient arrays */
80   if (binterp) PetscCall(PetscArraycpy(tab->binterp, binterp, nstages));
81   if (A_inv) PetscCall(PetscArraycpy(tab->A_inv, A_inv, PetscSqr(nstages)));
82   if (A_inv_rowsum) PetscCall(PetscArraycpy(tab->A_inv_rowsum, A_inv_rowsum, nstages));
83   if (I_s) PetscCall(PetscArraycpy(tab->I_s, I_s, PetscSqr(nstages)));
84   PetscFunctionReturn(PETSC_SUCCESS);
85 }
86 
87 /* Arrays should be freed with PetscFree3(A,b,c) */
88 static PetscErrorCode TSIRKCreate_Gauss(TS ts)
89 {
90   PetscInt     nstages;
91   PetscReal   *gauss_A_real, *gauss_b, *b, *gauss_c;
92   PetscScalar *gauss_A, *gauss_A_inv, *gauss_A_inv_rowsum, *I_s;
93   PetscScalar *G0, *G1;
94   PetscInt     i, j;
95   Mat          G0mat, G1mat, Amat;
96 
97   PetscFunctionBegin;
98   PetscCall(TSIRKGetNumStages(ts, &nstages));
99   PetscCall(PetscMalloc3(PetscSqr(nstages), &gauss_A_real, nstages, &gauss_b, nstages, &gauss_c));
100   PetscCall(PetscMalloc4(PetscSqr(nstages), &gauss_A, PetscSqr(nstages), &gauss_A_inv, nstages, &gauss_A_inv_rowsum, PetscSqr(nstages), &I_s));
101   PetscCall(PetscMalloc3(nstages, &b, PetscSqr(nstages), &G0, PetscSqr(nstages), &G1));
102   PetscCall(PetscDTGaussQuadrature(nstages, 0., 1., gauss_c, b));
103   for (i = 0; i < nstages; i++) gauss_b[i] = b[i]; /* copy to possibly-complex array */
104 
105   /* A^T = G0^{-1} G1 */
106   for (i = 0; i < nstages; i++) {
107     for (j = 0; j < nstages; j++) {
108       G0[i * nstages + j] = PetscPowRealInt(gauss_c[i], j);
109       G1[i * nstages + j] = PetscPowRealInt(gauss_c[i], j + 1) / (j + 1);
110     }
111   }
112   /* The arrays above are row-aligned, but we create dense matrices as the transpose */
113   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G0, &G0mat));
114   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G1, &G1mat));
115   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, gauss_A, &Amat));
116   PetscCall(MatLUFactor(G0mat, NULL, NULL, NULL));
117   PetscCall(MatMatSolve(G0mat, G1mat, Amat));
118   PetscCall(MatTranspose(Amat, MAT_INPLACE_MATRIX, &Amat));
119   for (i = 0; i < nstages; i++)
120     for (j = 0; j < nstages; j++) gauss_A_real[i * nstages + j] = PetscRealPart(gauss_A[i * nstages + j]);
121 
122   PetscCall(MatDestroy(&G0mat));
123   PetscCall(MatDestroy(&G1mat));
124   PetscCall(MatDestroy(&Amat));
125   PetscCall(PetscFree3(b, G0, G1));
126 
127   { /* Invert A */
128     /* PETSc does not provide a routine to calculate the inverse of a general matrix.
129      * To get the inverse of A, we form a sequential BAIJ matrix from it, consisting of a single block with block size
130      * equal to the dimension of A, and then use MatInvertBlockDiagonal(). */
131     Mat                A_baij;
132     PetscInt           idxm[1] = {0}, idxn[1] = {0};
133     const PetscScalar *A_inv;
134 
135     PetscCall(MatCreateSeqBAIJ(PETSC_COMM_SELF, nstages, nstages, nstages, 1, NULL, &A_baij));
136     PetscCall(MatSetOption(A_baij, MAT_ROW_ORIENTED, PETSC_FALSE));
137     PetscCall(MatSetValuesBlocked(A_baij, 1, idxm, 1, idxn, gauss_A, INSERT_VALUES));
138     PetscCall(MatAssemblyBegin(A_baij, MAT_FINAL_ASSEMBLY));
139     PetscCall(MatAssemblyEnd(A_baij, MAT_FINAL_ASSEMBLY));
140     PetscCall(MatInvertBlockDiagonal(A_baij, &A_inv));
141     PetscCall(PetscMemcpy(gauss_A_inv, A_inv, nstages * nstages * sizeof(PetscScalar)));
142     PetscCall(MatDestroy(&A_baij));
143   }
144 
145   /* Compute row sums A_inv_rowsum and identity I_s */
146   for (i = 0; i < nstages; i++) {
147     gauss_A_inv_rowsum[i] = 0;
148     for (j = 0; j < nstages; j++) {
149       gauss_A_inv_rowsum[i] += gauss_A_inv[i + nstages * j];
150       I_s[i + nstages * j] = 1. * (i == j);
151     }
152   }
153   PetscCall(TSIRKTableauCreate(ts, nstages, gauss_A_real, gauss_b, gauss_c, NULL, gauss_A_inv, gauss_A_inv_rowsum, I_s));
154   PetscCall(PetscFree3(gauss_A_real, gauss_b, gauss_c));
155   PetscCall(PetscFree4(gauss_A, gauss_A_inv, gauss_A_inv_rowsum, I_s));
156   PetscFunctionReturn(PETSC_SUCCESS);
157 }
158 
159 /*@C
160   TSIRKRegister -  adds a `TSIRK` implementation
161 
162   Not Collective, No Fortran Support
163 
164   Input Parameters:
165 + sname    - name of user-defined IRK scheme
166 - function - function to create method context
167 
168   Level: advanced
169 
170   Note:
171   `TSIRKRegister()` may be called multiple times to add several user-defined families.
172 
173   Example Usage:
174 .vb
175    TSIRKRegister("my_scheme", MySchemeCreate);
176 .ve
177 
178   Then, your scheme can be chosen with the procedural interface via
179 $     TSIRKSetType(ts, "my_scheme")
180   or at runtime via the option
181 $     -ts_irk_type my_scheme
182 
183 .seealso: [](ch_ts), `TSIRK`, `TSIRKRegisterAll()`
184 @*/
185 PetscErrorCode TSIRKRegister(const char sname[], PetscErrorCode (*function)(TS))
186 {
187   PetscFunctionBegin;
188   PetscCall(TSIRKInitializePackage());
189   PetscCall(PetscFunctionListAdd(&TSIRKList, sname, function));
190   PetscFunctionReturn(PETSC_SUCCESS);
191 }
192 
193 /*@C
194   TSIRKRegisterAll - Registers all of the implicit Runge-Kutta methods in `TSIRK`
195 
196   Not Collective, but should be called by all processes which will need the schemes to be registered
197 
198   Level: advanced
199 
200 .seealso: [](ch_ts), `TSIRK`, `TSIRKRegisterDestroy()`
201 @*/
202 PetscErrorCode TSIRKRegisterAll(void)
203 {
204   PetscFunctionBegin;
205   if (TSIRKRegisterAllCalled) PetscFunctionReturn(PETSC_SUCCESS);
206   TSIRKRegisterAllCalled = PETSC_TRUE;
207 
208   PetscCall(TSIRKRegister(TSIRKGAUSS, TSIRKCreate_Gauss));
209   PetscFunctionReturn(PETSC_SUCCESS);
210 }
211 
212 /*@C
213   TSIRKRegisterDestroy - Frees the list of schemes that were registered by `TSIRKRegister()`.
214 
215   Not Collective
216 
217   Level: advanced
218 
219 .seealso: [](ch_ts), `TSIRK`, `TSIRKRegister()`, `TSIRKRegisterAll()`
220 @*/
221 PetscErrorCode TSIRKRegisterDestroy(void)
222 {
223   PetscFunctionBegin;
224   TSIRKRegisterAllCalled = PETSC_FALSE;
225   PetscFunctionReturn(PETSC_SUCCESS);
226 }
227 
228 /*@C
229   TSIRKInitializePackage - This function initializes everything in the `TSIRK` package. It is called
230   from `TSInitializePackage()`.
231 
232   Level: developer
233 
234 .seealso: [](ch_ts), `TSIRK`, `PetscInitialize()`, `TSIRKFinalizePackage()`, `TSInitializePackage()`
235 @*/
236 PetscErrorCode TSIRKInitializePackage(void)
237 {
238   PetscFunctionBegin;
239   if (TSIRKPackageInitialized) PetscFunctionReturn(PETSC_SUCCESS);
240   TSIRKPackageInitialized = PETSC_TRUE;
241   PetscCall(TSIRKRegisterAll());
242   PetscCall(PetscRegisterFinalize(TSIRKFinalizePackage));
243   PetscFunctionReturn(PETSC_SUCCESS);
244 }
245 
246 /*@C
247   TSIRKFinalizePackage - This function destroys everything in the `TSIRK` package. It is
248   called from `PetscFinalize()`.
249 
250   Level: developer
251 
252 .seealso: [](ch_ts), `TSIRK`, `PetscFinalize()`, `TSInitializePackage()`
253 @*/
254 PetscErrorCode TSIRKFinalizePackage(void)
255 {
256   PetscFunctionBegin;
257   PetscCall(PetscFunctionListDestroy(&TSIRKList));
258   TSIRKPackageInitialized = PETSC_FALSE;
259   PetscFunctionReturn(PETSC_SUCCESS);
260 }
261 
262 /*
263  This function can be called before or after ts->vec_sol has been updated.
264 */
265 static PetscErrorCode TSEvaluateStep_IRK(TS ts, PetscInt order, Vec U, PetscBool *done)
266 {
267   TS_IRK      *irk   = (TS_IRK *)ts->data;
268   IRKTableau   tab   = irk->tableau;
269   Vec         *YdotI = irk->YdotI;
270   PetscScalar *w     = irk->work;
271   PetscReal    h;
272   PetscInt     j;
273 
274   PetscFunctionBegin;
275   switch (irk->status) {
276   case TS_STEP_INCOMPLETE:
277   case TS_STEP_PENDING:
278     h = ts->time_step;
279     break;
280   case TS_STEP_COMPLETE:
281     h = ts->ptime - ts->ptime_prev;
282     break;
283   default:
284     SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
285   }
286 
287   PetscCall(VecCopy(ts->vec_sol, U));
288   for (j = 0; j < irk->nstages; j++) w[j] = h * tab->b[j];
289   PetscCall(VecMAXPY(U, irk->nstages, w, YdotI));
290   PetscFunctionReturn(PETSC_SUCCESS);
291 }
292 
293 static PetscErrorCode TSRollBack_IRK(TS ts)
294 {
295   TS_IRK *irk = (TS_IRK *)ts->data;
296 
297   PetscFunctionBegin;
298   PetscCall(VecCopy(irk->U0, ts->vec_sol));
299   PetscFunctionReturn(PETSC_SUCCESS);
300 }
301 
302 static PetscErrorCode TSStep_IRK(TS ts)
303 {
304   TS_IRK        *irk   = (TS_IRK *)ts->data;
305   IRKTableau     tab   = irk->tableau;
306   PetscScalar   *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
307   const PetscInt nstages = irk->nstages;
308   SNES           snes;
309   PetscInt       i, j, its, lits, bs;
310   TSAdapt        adapt;
311   PetscInt       rejections     = 0;
312   PetscBool      accept         = PETSC_TRUE;
313   PetscReal      next_time_step = ts->time_step;
314 
315   PetscFunctionBegin;
316   if (!ts->steprollback) PetscCall(VecCopy(ts->vec_sol, irk->U0));
317   PetscCall(VecGetBlockSize(ts->vec_sol, &bs));
318   for (i = 0; i < nstages; i++) PetscCall(VecStrideScatter(ts->vec_sol, i * bs, irk->Z, INSERT_VALUES));
319 
320   irk->status = TS_STEP_INCOMPLETE;
321   while (!ts->reason && irk->status != TS_STEP_COMPLETE) {
322     PetscCall(VecCopy(ts->vec_sol, irk->U));
323     PetscCall(TSGetSNES(ts, &snes));
324     PetscCall(SNESSolve(snes, NULL, irk->Z));
325     PetscCall(SNESGetIterationNumber(snes, &its));
326     PetscCall(SNESGetLinearSolveIterations(snes, &lits));
327     ts->snes_its += its;
328     ts->ksp_its += lits;
329     PetscCall(VecStrideGatherAll(irk->Z, irk->Y, INSERT_VALUES));
330     for (i = 0; i < nstages; i++) {
331       PetscCall(VecZeroEntries(irk->YdotI[i]));
332       for (j = 0; j < nstages; j++) PetscCall(VecAXPY(irk->YdotI[i], A_inv[i + j * nstages] / ts->time_step, irk->Y[j]));
333       PetscCall(VecAXPY(irk->YdotI[i], -A_inv_rowsum[i] / ts->time_step, irk->U));
334     }
335     irk->status = TS_STEP_INCOMPLETE;
336     PetscCall(TSEvaluateStep_IRK(ts, irk->order, ts->vec_sol, NULL));
337     irk->status = TS_STEP_PENDING;
338     PetscCall(TSGetAdapt(ts, &adapt));
339     PetscCall(TSAdaptChoose(adapt, ts, ts->time_step, NULL, &next_time_step, &accept));
340     irk->status = accept ? TS_STEP_COMPLETE : TS_STEP_INCOMPLETE;
341     if (!accept) {
342       PetscCall(TSRollBack_IRK(ts));
343       ts->time_step = next_time_step;
344       goto reject_step;
345     }
346 
347     ts->ptime += ts->time_step;
348     ts->time_step = next_time_step;
349     break;
350   reject_step:
351     ts->reject++;
352     accept = PETSC_FALSE;
353     if (!ts->reason && ++rejections > ts->max_reject && ts->max_reject >= 0) {
354       ts->reason = TS_DIVERGED_STEP_REJECTED;
355       PetscCall(PetscInfo(ts, "Step=%" PetscInt_FMT ", step rejections %" PetscInt_FMT " greater than current TS allowed, stopping solve\n", ts->steps, rejections));
356     }
357   }
358   PetscFunctionReturn(PETSC_SUCCESS);
359 }
360 
361 static PetscErrorCode TSInterpolate_IRK(TS ts, PetscReal itime, Vec U)
362 {
363   TS_IRK          *irk     = (TS_IRK *)ts->data;
364   PetscInt         nstages = irk->nstages, pinterp = irk->pinterp, i, j;
365   PetscReal        h;
366   PetscReal        tt, t;
367   PetscScalar     *bt;
368   const PetscReal *B = irk->tableau->binterp;
369 
370   PetscFunctionBegin;
371   PetscCheck(B, PetscObjectComm((PetscObject)ts), PETSC_ERR_SUP, "TSIRK %s does not have an interpolation formula", irk->method_name);
372   switch (irk->status) {
373   case TS_STEP_INCOMPLETE:
374   case TS_STEP_PENDING:
375     h = ts->time_step;
376     t = (itime - ts->ptime) / h;
377     break;
378   case TS_STEP_COMPLETE:
379     h = ts->ptime - ts->ptime_prev;
380     t = (itime - ts->ptime) / h + 1; /* In the interval [0,1] */
381     break;
382   default:
383     SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
384   }
385   PetscCall(PetscMalloc1(nstages, &bt));
386   for (i = 0; i < nstages; i++) bt[i] = 0;
387   for (j = 0, tt = t; j < pinterp; j++, tt *= t) {
388     for (i = 0; i < nstages; i++) bt[i] += h * B[i * pinterp + j] * tt;
389   }
390   PetscCall(VecMAXPY(U, nstages, bt, irk->YdotI));
391   PetscFunctionReturn(PETSC_SUCCESS);
392 }
393 
394 static PetscErrorCode TSIRKTableauReset(TS ts)
395 {
396   TS_IRK    *irk = (TS_IRK *)ts->data;
397   IRKTableau tab = irk->tableau;
398 
399   PetscFunctionBegin;
400   if (!tab) PetscFunctionReturn(PETSC_SUCCESS);
401   PetscCall(PetscFree3(tab->A, tab->A_inv, tab->I_s));
402   PetscCall(PetscFree4(tab->b, tab->c, tab->binterp, tab->A_inv_rowsum));
403   PetscFunctionReturn(PETSC_SUCCESS);
404 }
405 
406 static PetscErrorCode TSReset_IRK(TS ts)
407 {
408   TS_IRK *irk = (TS_IRK *)ts->data;
409 
410   PetscFunctionBegin;
411   PetscCall(TSIRKTableauReset(ts));
412   if (irk->tableau) PetscCall(PetscFree(irk->tableau));
413   if (irk->method_name) PetscCall(PetscFree(irk->method_name));
414   if (irk->work) PetscCall(PetscFree(irk->work));
415   PetscCall(VecDestroyVecs(irk->nstages, &irk->Y));
416   PetscCall(VecDestroyVecs(irk->nstages, &irk->YdotI));
417   PetscCall(VecDestroy(&irk->Ydot));
418   PetscCall(VecDestroy(&irk->Z));
419   PetscCall(VecDestroy(&irk->U));
420   PetscCall(VecDestroy(&irk->U0));
421   PetscCall(MatDestroy(&irk->TJ));
422   PetscFunctionReturn(PETSC_SUCCESS);
423 }
424 
425 static PetscErrorCode TSIRKGetVecs(TS ts, DM dm, Vec *U)
426 {
427   TS_IRK *irk = (TS_IRK *)ts->data;
428 
429   PetscFunctionBegin;
430   if (U) {
431     if (dm && dm != ts->dm) {
432       PetscCall(DMGetNamedGlobalVector(dm, "TSIRK_U", U));
433     } else *U = irk->U;
434   }
435   PetscFunctionReturn(PETSC_SUCCESS);
436 }
437 
438 static PetscErrorCode TSIRKRestoreVecs(TS ts, DM dm, Vec *U)
439 {
440   PetscFunctionBegin;
441   if (U) {
442     if (dm && dm != ts->dm) PetscCall(DMRestoreNamedGlobalVector(dm, "TSIRK_U", U));
443   }
444   PetscFunctionReturn(PETSC_SUCCESS);
445 }
446 
447 /*
448   This defines the nonlinear equations that is to be solved with SNES
449     G[e\otimes t + C*dt, Z, Zdot] = 0
450     Zdot = (In \otimes S)*Z - (In \otimes Se) U
451   where S = 1/(dt*A)
452 */
453 static PetscErrorCode SNESTSFormFunction_IRK(SNES snes, Vec ZC, Vec FC, TS ts)
454 {
455   TS_IRK            *irk     = (TS_IRK *)ts->data;
456   IRKTableau         tab     = irk->tableau;
457   const PetscInt     nstages = irk->nstages;
458   const PetscReal   *c       = tab->c;
459   const PetscScalar *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
460   DM                 dm, dmsave;
461   Vec                U, *YdotI = irk->YdotI, Ydot = irk->Ydot, *Y = irk->Y;
462   PetscReal          h = ts->time_step;
463   PetscInt           i, j;
464 
465   PetscFunctionBegin;
466   PetscCall(SNESGetDM(snes, &dm));
467   PetscCall(TSIRKGetVecs(ts, dm, &U));
468   PetscCall(VecStrideGatherAll(ZC, Y, INSERT_VALUES));
469   dmsave = ts->dm;
470   ts->dm = dm;
471   for (i = 0; i < nstages; i++) {
472     PetscCall(VecZeroEntries(Ydot));
473     for (j = 0; j < nstages; j++) PetscCall(VecAXPY(Ydot, A_inv[j * nstages + i] / h, Y[j]));
474     PetscCall(VecAXPY(Ydot, -A_inv_rowsum[i] / h, U)); /* Ydot = (S \otimes In)*Z - (Se \otimes In) U */
475     PetscCall(TSComputeIFunction(ts, ts->ptime + ts->time_step * c[i], Y[i], Ydot, YdotI[i], PETSC_FALSE));
476   }
477   PetscCall(VecStrideScatterAll(YdotI, FC, INSERT_VALUES));
478   ts->dm = dmsave;
479   PetscCall(TSIRKRestoreVecs(ts, dm, &U));
480   PetscFunctionReturn(PETSC_SUCCESS);
481 }
482 
483 /*
484    For explicit ODE, the Jacobian is
485      JC = I_n \otimes S - J \otimes I_s
486    For DAE, the Jacobian is
487      JC = M_n \otimes S - J \otimes I_s
488 */
489 static PetscErrorCode SNESTSFormJacobian_IRK(SNES snes, Vec ZC, Mat JC, Mat JCpre, TS ts)
490 {
491   TS_IRK          *irk     = (TS_IRK *)ts->data;
492   IRKTableau       tab     = irk->tableau;
493   const PetscInt   nstages = irk->nstages;
494   const PetscReal *c       = tab->c;
495   DM               dm, dmsave;
496   Vec             *Y = irk->Y, Ydot = irk->Ydot;
497   Mat              J;
498   PetscScalar     *S;
499   PetscInt         i, j, bs;
500 
501   PetscFunctionBegin;
502   PetscCall(SNESGetDM(snes, &dm));
503   /* irk->Ydot has already been computed in SNESTSFormFunction_IRK (SNES guarantees this) */
504   dmsave = ts->dm;
505   ts->dm = dm;
506   PetscCall(VecGetBlockSize(Y[nstages - 1], &bs));
507   if (ts->equation_type <= TS_EQ_ODE_EXPLICIT) { /* Support explicit formulas only */
508     PetscCall(VecStrideGather(ZC, (nstages - 1) * bs, Y[nstages - 1], INSERT_VALUES));
509     PetscCall(MatKAIJGetAIJ(JC, &J));
510     PetscCall(TSComputeIJacobian(ts, ts->ptime + ts->time_step * c[nstages - 1], Y[nstages - 1], Ydot, 0, J, J, PETSC_FALSE));
511     PetscCall(MatKAIJGetS(JC, NULL, NULL, &S));
512     for (i = 0; i < nstages; i++)
513       for (j = 0; j < nstages; j++) S[i + nstages * j] = tab->A_inv[i + nstages * j] / ts->time_step;
514     PetscCall(MatKAIJRestoreS(JC, &S));
515   } else SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_SUP, "TSIRK %s does not support implicit formula", irk->method_name); /* TODO: need the mass matrix for DAE  */
516   ts->dm = dmsave;
517   PetscFunctionReturn(PETSC_SUCCESS);
518 }
519 
520 static PetscErrorCode DMCoarsenHook_TSIRK(DM fine, DM coarse, void *ctx)
521 {
522   PetscFunctionBegin;
523   PetscFunctionReturn(PETSC_SUCCESS);
524 }
525 
526 static PetscErrorCode DMRestrictHook_TSIRK(DM fine, Mat restrct, Vec rscale, Mat inject, DM coarse, void *ctx)
527 {
528   TS  ts = (TS)ctx;
529   Vec U, U_c;
530 
531   PetscFunctionBegin;
532   PetscCall(TSIRKGetVecs(ts, fine, &U));
533   PetscCall(TSIRKGetVecs(ts, coarse, &U_c));
534   PetscCall(MatRestrict(restrct, U, U_c));
535   PetscCall(VecPointwiseMult(U_c, rscale, U_c));
536   PetscCall(TSIRKRestoreVecs(ts, fine, &U));
537   PetscCall(TSIRKRestoreVecs(ts, coarse, &U_c));
538   PetscFunctionReturn(PETSC_SUCCESS);
539 }
540 
541 static PetscErrorCode DMSubDomainHook_TSIRK(DM dm, DM subdm, void *ctx)
542 {
543   PetscFunctionBegin;
544   PetscFunctionReturn(PETSC_SUCCESS);
545 }
546 
547 static PetscErrorCode DMSubDomainRestrictHook_TSIRK(DM dm, VecScatter gscat, VecScatter lscat, DM subdm, void *ctx)
548 {
549   TS  ts = (TS)ctx;
550   Vec U, U_c;
551 
552   PetscFunctionBegin;
553   PetscCall(TSIRKGetVecs(ts, dm, &U));
554   PetscCall(TSIRKGetVecs(ts, subdm, &U_c));
555 
556   PetscCall(VecScatterBegin(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD));
557   PetscCall(VecScatterEnd(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD));
558 
559   PetscCall(TSIRKRestoreVecs(ts, dm, &U));
560   PetscCall(TSIRKRestoreVecs(ts, subdm, &U_c));
561   PetscFunctionReturn(PETSC_SUCCESS);
562 }
563 
564 static PetscErrorCode TSSetUp_IRK(TS ts)
565 {
566   TS_IRK        *irk = (TS_IRK *)ts->data;
567   IRKTableau     tab = irk->tableau;
568   DM             dm;
569   Mat            J;
570   Vec            R;
571   const PetscInt nstages = irk->nstages;
572   PetscInt       vsize, bs;
573 
574   PetscFunctionBegin;
575   if (!irk->work) PetscCall(PetscMalloc1(irk->nstages, &irk->work));
576   if (!irk->Y) PetscCall(VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->Y));
577   if (!irk->YdotI) PetscCall(VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->YdotI));
578   if (!irk->Ydot) PetscCall(VecDuplicate(ts->vec_sol, &irk->Ydot));
579   if (!irk->U) PetscCall(VecDuplicate(ts->vec_sol, &irk->U));
580   if (!irk->U0) PetscCall(VecDuplicate(ts->vec_sol, &irk->U0));
581   if (!irk->Z) {
582     PetscCall(VecCreate(PetscObjectComm((PetscObject)ts->vec_sol), &irk->Z));
583     PetscCall(VecGetSize(ts->vec_sol, &vsize));
584     PetscCall(VecSetSizes(irk->Z, PETSC_DECIDE, vsize * irk->nstages));
585     PetscCall(VecGetBlockSize(ts->vec_sol, &bs));
586     PetscCall(VecSetBlockSize(irk->Z, irk->nstages * bs));
587     PetscCall(VecSetFromOptions(irk->Z));
588   }
589   PetscCall(TSGetDM(ts, &dm));
590   PetscCall(DMCoarsenHookAdd(dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts));
591   PetscCall(DMSubDomainHookAdd(dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts));
592 
593   PetscCall(TSGetSNES(ts, &ts->snes));
594   PetscCall(VecDuplicate(irk->Z, &R));
595   PetscCall(SNESSetFunction(ts->snes, R, SNESTSFormFunction, ts));
596   PetscCall(TSGetIJacobian(ts, &J, NULL, NULL, NULL));
597   if (!irk->TJ) {
598     /* Create the KAIJ matrix for solving the stages */
599     PetscCall(MatCreateKAIJ(J, nstages, nstages, tab->A_inv, tab->I_s, &irk->TJ));
600   }
601   PetscCall(SNESSetJacobian(ts->snes, irk->TJ, irk->TJ, SNESTSFormJacobian, ts));
602   PetscCall(VecDestroy(&R));
603   PetscFunctionReturn(PETSC_SUCCESS);
604 }
605 
606 static PetscErrorCode TSSetFromOptions_IRK(TS ts, PetscOptionItems PetscOptionsObject)
607 {
608   TS_IRK *irk        = (TS_IRK *)ts->data;
609   char    tname[256] = TSIRKGAUSS;
610 
611   PetscFunctionBegin;
612   PetscOptionsHeadBegin(PetscOptionsObject, "IRK ODE solver options");
613   {
614     PetscBool flg1, flg2;
615     PetscCall(PetscOptionsInt("-ts_irk_nstages", "Stages of the IRK method", "TSIRKSetNumStages", irk->nstages, &irk->nstages, &flg1));
616     PetscCall(PetscOptionsFList("-ts_irk_type", "Type of IRK method", "TSIRKSetType", TSIRKList, irk->method_name[0] ? irk->method_name : tname, tname, sizeof(tname), &flg2));
617     if (flg1 || flg2 || !irk->method_name[0]) { /* Create the method tableau after nstages or method is set */
618       PetscCall(TSIRKSetType(ts, tname));
619     }
620   }
621   PetscOptionsHeadEnd();
622   PetscFunctionReturn(PETSC_SUCCESS);
623 }
624 
625 static PetscErrorCode TSView_IRK(TS ts, PetscViewer viewer)
626 {
627   TS_IRK   *irk = (TS_IRK *)ts->data;
628   PetscBool iascii;
629 
630   PetscFunctionBegin;
631   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
632   if (iascii) {
633     IRKTableau tab = irk->tableau;
634     TSIRKType  irktype;
635     char       buf[512];
636 
637     PetscCall(TSIRKGetType(ts, &irktype));
638     PetscCall(PetscViewerASCIIPrintf(viewer, "  IRK type %s\n", irktype));
639     PetscCall(PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", irk->nstages, tab->c));
640     PetscCall(PetscViewerASCIIPrintf(viewer, "  Abscissa       c = %s\n", buf));
641     PetscCall(PetscViewerASCIIPrintf(viewer, "Stiffly accurate: %s\n", irk->stiffly_accurate ? "yes" : "no"));
642     PetscCall(PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", PetscSqr(irk->nstages), tab->A));
643     PetscCall(PetscViewerASCIIPrintf(viewer, "  A coefficients       A = %s\n", buf));
644   }
645   PetscFunctionReturn(PETSC_SUCCESS);
646 }
647 
648 static PetscErrorCode TSLoad_IRK(TS ts, PetscViewer viewer)
649 {
650   SNES    snes;
651   TSAdapt adapt;
652 
653   PetscFunctionBegin;
654   PetscCall(TSGetAdapt(ts, &adapt));
655   PetscCall(TSAdaptLoad(adapt, viewer));
656   PetscCall(TSGetSNES(ts, &snes));
657   PetscCall(SNESLoad(snes, viewer));
658   /* function and Jacobian context for SNES when used with TS is always ts object */
659   PetscCall(SNESSetFunction(snes, NULL, NULL, ts));
660   PetscCall(SNESSetJacobian(snes, NULL, NULL, NULL, ts));
661   PetscFunctionReturn(PETSC_SUCCESS);
662 }
663 
664 /*@
665   TSIRKSetType - Set the type of `TSIRK` scheme to use
666 
667   Logically Collective
668 
669   Input Parameters:
670 + ts      - timestepping context
671 - irktype - type of `TSIRK` scheme
672 
673   Options Database Key:
674 . -ts_irk_type <gauss> - set irk type
675 
676   Level: intermediate
677 
678 .seealso: [](ch_ts), `TSIRKGetType()`, `TSIRK`, `TSIRKType`, `TSIRKGAUSS`
679 @*/
680 PetscErrorCode TSIRKSetType(TS ts, TSIRKType irktype)
681 {
682   PetscFunctionBegin;
683   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
684   PetscAssertPointer(irktype, 2);
685   PetscTryMethod(ts, "TSIRKSetType_C", (TS, TSIRKType), (ts, irktype));
686   PetscFunctionReturn(PETSC_SUCCESS);
687 }
688 
689 /*@
690   TSIRKGetType - Get the type of `TSIRK` IMEX scheme being used
691 
692   Logically Collective
693 
694   Input Parameter:
695 . ts - timestepping context
696 
697   Output Parameter:
698 . irktype - type of `TSIRK` IMEX scheme
699 
700   Level: intermediate
701 
702 .seealso: [](ch_ts), `TSIRK`, `TSIRKType`, `TSIRKGAUSS`
703 @*/
704 PetscErrorCode TSIRKGetType(TS ts, TSIRKType *irktype)
705 {
706   PetscFunctionBegin;
707   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
708   PetscUseMethod(ts, "TSIRKGetType_C", (TS, TSIRKType *), (ts, irktype));
709   PetscFunctionReturn(PETSC_SUCCESS);
710 }
711 
712 /*@
713   TSIRKSetNumStages - Set the number of stages of `TSIRK` scheme to use
714 
715   Logically Collective
716 
717   Input Parameters:
718 + ts      - timestepping context
719 - nstages - number of stages of `TSIRK` scheme
720 
721   Options Database Key:
722 . -ts_irk_nstages <int> - set number of stages
723 
724   Level: intermediate
725 
726 .seealso: [](ch_ts), `TSIRKGetNumStages()`, `TSIRK`
727 @*/
728 PetscErrorCode TSIRKSetNumStages(TS ts, PetscInt nstages)
729 {
730   PetscFunctionBegin;
731   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
732   PetscTryMethod(ts, "TSIRKSetNumStages_C", (TS, PetscInt), (ts, nstages));
733   PetscFunctionReturn(PETSC_SUCCESS);
734 }
735 
736 /*@
737   TSIRKGetNumStages - Get the number of stages of `TSIRK` scheme
738 
739   Logically Collective
740 
741   Input Parameters:
742 + ts      - timestepping context
743 - nstages - number of stages of `TSIRK` scheme
744 
745   Level: intermediate
746 
747 .seealso: [](ch_ts), `TSIRKSetNumStages()`, `TSIRK`
748 @*/
749 PetscErrorCode TSIRKGetNumStages(TS ts, PetscInt *nstages)
750 {
751   PetscFunctionBegin;
752   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
753   PetscAssertPointer(nstages, 2);
754   PetscTryMethod(ts, "TSIRKGetNumStages_C", (TS, PetscInt *), (ts, nstages));
755   PetscFunctionReturn(PETSC_SUCCESS);
756 }
757 
758 static PetscErrorCode TSIRKGetType_IRK(TS ts, TSIRKType *irktype)
759 {
760   TS_IRK *irk = (TS_IRK *)ts->data;
761 
762   PetscFunctionBegin;
763   *irktype = irk->method_name;
764   PetscFunctionReturn(PETSC_SUCCESS);
765 }
766 
767 static PetscErrorCode TSIRKSetType_IRK(TS ts, TSIRKType irktype)
768 {
769   TS_IRK *irk = (TS_IRK *)ts->data;
770   PetscErrorCode (*irkcreate)(TS);
771 
772   PetscFunctionBegin;
773   if (irk->method_name) {
774     PetscCall(PetscFree(irk->method_name));
775     PetscCall(TSIRKTableauReset(ts));
776   }
777   PetscCall(PetscFunctionListFind(TSIRKList, irktype, &irkcreate));
778   PetscCheck(irkcreate, PetscObjectComm((PetscObject)ts), PETSC_ERR_ARG_UNKNOWN_TYPE, "Unknown TSIRK type \"%s\" given", irktype);
779   PetscCall((*irkcreate)(ts));
780   PetscCall(PetscStrallocpy(irktype, &irk->method_name));
781   PetscFunctionReturn(PETSC_SUCCESS);
782 }
783 
784 static PetscErrorCode TSIRKSetNumStages_IRK(TS ts, PetscInt nstages)
785 {
786   TS_IRK *irk = (TS_IRK *)ts->data;
787 
788   PetscFunctionBegin;
789   PetscCheck(nstages > 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "input argument, %" PetscInt_FMT ", out of range", nstages);
790   irk->nstages = nstages;
791   PetscFunctionReturn(PETSC_SUCCESS);
792 }
793 
794 static PetscErrorCode TSIRKGetNumStages_IRK(TS ts, PetscInt *nstages)
795 {
796   TS_IRK *irk = (TS_IRK *)ts->data;
797 
798   PetscFunctionBegin;
799   PetscAssertPointer(nstages, 2);
800   *nstages = irk->nstages;
801   PetscFunctionReturn(PETSC_SUCCESS);
802 }
803 
804 static PetscErrorCode TSDestroy_IRK(TS ts)
805 {
806   PetscFunctionBegin;
807   PetscCall(TSReset_IRK(ts));
808   if (ts->dm) {
809     PetscCall(DMCoarsenHookRemove(ts->dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts));
810     PetscCall(DMSubDomainHookRemove(ts->dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts));
811   }
812   PetscCall(PetscFree(ts->data));
813   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", NULL));
814   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", NULL));
815   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", NULL));
816   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", NULL));
817   PetscFunctionReturn(PETSC_SUCCESS);
818 }
819 
820 /*MC
821       TSIRK - ODE and DAE solver using Implicit Runge-Kutta schemes
822 
823   Level: beginner
824 
825   Notes:
826   `TSIRK` uses the sparse Kronecker product matrix implementation of `MATKAIJ` to achieve good arithmetic intensity.
827 
828   Gauss-Legrendre methods are currently supported. These are A-stable symplectic methods with an arbitrary number of stages. The order of accuracy is 2s
829   when using s stages. The default method uses three stages and thus has an order of six. The number of stages (thus order) can be set with
830   -ts_irk_nstages or `TSIRKSetNumStages()`.
831 
832 .seealso: [](ch_ts), `TSCreate()`, `TS`, `TSSetType()`, `TSIRKSetType()`, `TSIRKGetType()`, `TSIRKGAUSS`, `TSIRKRegister()`, `TSIRKSetNumStages()`, `TSType`
833 M*/
834 PETSC_EXTERN PetscErrorCode TSCreate_IRK(TS ts)
835 {
836   TS_IRK *irk;
837 
838   PetscFunctionBegin;
839   PetscCall(TSIRKInitializePackage());
840 
841   ts->ops->reset          = TSReset_IRK;
842   ts->ops->destroy        = TSDestroy_IRK;
843   ts->ops->view           = TSView_IRK;
844   ts->ops->load           = TSLoad_IRK;
845   ts->ops->setup          = TSSetUp_IRK;
846   ts->ops->step           = TSStep_IRK;
847   ts->ops->interpolate    = TSInterpolate_IRK;
848   ts->ops->evaluatestep   = TSEvaluateStep_IRK;
849   ts->ops->rollback       = TSRollBack_IRK;
850   ts->ops->setfromoptions = TSSetFromOptions_IRK;
851   ts->ops->snesfunction   = SNESTSFormFunction_IRK;
852   ts->ops->snesjacobian   = SNESTSFormJacobian_IRK;
853 
854   ts->usessnes = PETSC_TRUE;
855 
856   PetscCall(PetscNew(&irk));
857   ts->data = (void *)irk;
858 
859   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", TSIRKSetType_IRK));
860   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", TSIRKGetType_IRK));
861   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", TSIRKSetNumStages_IRK));
862   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", TSIRKGetNumStages_IRK));
863   /* 3-stage IRK_Gauss is the default */
864   PetscCall(PetscNew(&irk->tableau));
865   irk->nstages = 3;
866   PetscCall(TSIRKSetType(ts, TSIRKDefault));
867   PetscFunctionReturn(PETSC_SUCCESS);
868 }
869