xref: /petsc/src/ts/impls/implicit/irk/irk.c (revision 174dc0c8cee294b82b85e4dd3b331b29396264fc)
1 /*
2   Code for timestepping with implicit Runge-Kutta method
3 
4   Notes:
5   The general system is written as
6 
7   F(t,U,Udot) = 0
8 
9 */
10 #include <petsc/private/tsimpl.h> /*I   "petscts.h"   I*/
11 #include <petscdm.h>
12 #include <petscdt.h>
13 
/* Scheme selected by TSCreate_IRK() for a freshly created TSIRK object. */
static TSIRKType         TSIRKDefault = TSIRKGAUSS;
/* Set once TSIRKRegisterAll() has registered the built-in schemes. */
static PetscBool         TSIRKRegisterAllCalled;
/* Guards TSIRKInitializePackage() against running more than once; cleared by TSIRKFinalizePackage(). */
static PetscBool         TSIRKPackageInitialized;
/* Registry mapping scheme names to their creation routines (filled by TSIRKRegister()). */
static PetscFunctionList TSIRKList;
18 
/* Coefficients of an implicit Runge-Kutta scheme together with the derived quantities the solver uses. */
struct _IRKTableau {
  PetscReal   *A, *b, *c;                  /* stage matrix (nstages*nstages), completion weights, abscissae */
  PetscScalar *A_inv, *A_inv_rowsum, *I_s; /* inverse of A (indexed [i + j*nstages] by the solver), its row sums, identity */
  PetscReal   *binterp; /* Dense output formula */
};

/* Handle to the tableau; owned by TS_IRK::tableau. */
typedef struct _IRKTableau *IRKTableau;
26 
/* Implementation-private data of a TSIRK object, stored in ts->data (zero-initialized by PetscNew() in TSCreate_IRK()). */
typedef struct {
  char        *method_name; /* name of the registered scheme in use; owned here, set by TSIRKSetType_IRK() */
  PetscInt     order;   /* Classical approximation order of the method (TSIRKTableauCreate() sets it to nstages) */
  PetscInt     nstages; /* Number of stages */
  PetscBool    stiffly_accurate; /* only reported by TSView_IRK(); not assigned in this file */
  PetscInt     pinterp; /* Interpolation order; not assigned in this file, so it stays 0 from PetscNew() */
  IRKTableau   tableau; /* coefficients of the current scheme */
  Vec          U0;    /* Backup vector */
  Vec          Z;     /* Combined stage vector */
  Vec         *Y;     /* States computed during the step */
  Vec          Ydot;  /* Work vector holding time derivatives during residual evaluation */
  Vec          U;     /* U is used to compute Ydot = shift(Y-U) */
  Vec         *YdotI; /* per-stage work vectors: IFunction residuals in SNESTSFormFunction_IRK(), stage derivatives in TSStep_IRK() */
  Mat          TJ;    /* KAIJ matrix for the Jacobian of the combined system */
  PetscScalar *work;  /* Scalar work */
  TSStepStatus status; /* progress of the current step; read by TSEvaluateStep_IRK() and TSInterpolate_IRK() */
  PetscBool    rebuild_completion; /* not referenced in this file */
  PetscReal    ccfl;               /* not referenced in this file */
} TS_IRK;
46 
47 /*@C
48   TSIRKTableauCreate - create the tableau for `TSIRK` and provide the entries
49 
50   Not Collective
51 
52   Input Parameters:
53 + ts           - timestepping context
54 . nstages      - number of stages, this is the dimension of the matrices below
55 . A            - stage coefficients (dimension nstages*nstages, row-major)
56 . b            - step completion table (dimension nstages)
57 . c            - abscissa (dimension nstages)
58 . binterp      - coefficients of the interpolation formula (dimension nstages)
59 . A_inv        - inverse of A (dimension nstages*nstages, row-major)
60 . A_inv_rowsum - row sum of the inverse of A (dimension nstages)
61 - I_s          - identity matrix (dimension nstages*nstages)
62 
63   Level: advanced
64 
65 .seealso: [](ch_ts), `TSIRK`, `TSIRKRegister()`
66 @*/
67 PetscErrorCode TSIRKTableauCreate(TS ts, PetscInt nstages, const PetscReal *A, const PetscReal *b, const PetscReal *c, const PetscReal *binterp, const PetscScalar *A_inv, const PetscScalar *A_inv_rowsum, const PetscScalar *I_s)
68 {
69   TS_IRK    *irk = (TS_IRK *)ts->data;
70   IRKTableau tab = irk->tableau;
71 
72   PetscFunctionBegin;
73   irk->order = nstages;
74   PetscCall(PetscMalloc3(PetscSqr(nstages), &tab->A, PetscSqr(nstages), &tab->A_inv, PetscSqr(nstages), &tab->I_s));
75   PetscCall(PetscMalloc4(nstages, &tab->b, nstages, &tab->c, nstages, &tab->binterp, nstages, &tab->A_inv_rowsum));
76   PetscCall(PetscArraycpy(tab->A, A, PetscSqr(nstages)));
77   PetscCall(PetscArraycpy(tab->b, b, nstages));
78   PetscCall(PetscArraycpy(tab->c, c, nstages));
79   /* optional coefficient arrays */
80   if (binterp) PetscCall(PetscArraycpy(tab->binterp, binterp, nstages));
81   if (A_inv) PetscCall(PetscArraycpy(tab->A_inv, A_inv, PetscSqr(nstages)));
82   if (A_inv_rowsum) PetscCall(PetscArraycpy(tab->A_inv_rowsum, A_inv_rowsum, nstages));
83   if (I_s) PetscCall(PetscArraycpy(tab->I_s, I_s, PetscSqr(nstages)));
84   PetscFunctionReturn(PETSC_SUCCESS);
85 }
86 
87 /* Arrays should be freed with PetscFree3(A,b,c) */
88 static PetscErrorCode TSIRKCreate_Gauss(TS ts)
89 {
90   PetscInt     nstages;
91   PetscReal   *gauss_A_real, *gauss_b, *b, *gauss_c;
92   PetscScalar *gauss_A, *gauss_A_inv, *gauss_A_inv_rowsum, *I_s;
93   PetscScalar *G0, *G1;
94   PetscInt     i, j;
95   Mat          G0mat, G1mat, Amat;
96 
97   PetscFunctionBegin;
98   PetscCall(TSIRKGetNumStages(ts, &nstages));
99   PetscCall(PetscMalloc3(PetscSqr(nstages), &gauss_A_real, nstages, &gauss_b, nstages, &gauss_c));
100   PetscCall(PetscMalloc4(PetscSqr(nstages), &gauss_A, PetscSqr(nstages), &gauss_A_inv, nstages, &gauss_A_inv_rowsum, PetscSqr(nstages), &I_s));
101   PetscCall(PetscMalloc3(nstages, &b, PetscSqr(nstages), &G0, PetscSqr(nstages), &G1));
102   PetscCall(PetscDTGaussQuadrature(nstages, 0., 1., gauss_c, b));
103   for (i = 0; i < nstages; i++) gauss_b[i] = b[i]; /* copy to possibly-complex array */
104 
105   /* A^T = G0^{-1} G1 */
106   for (i = 0; i < nstages; i++) {
107     for (j = 0; j < nstages; j++) {
108       G0[i * nstages + j] = PetscPowRealInt(gauss_c[i], j);
109       G1[i * nstages + j] = PetscPowRealInt(gauss_c[i], j + 1) / (j + 1);
110     }
111   }
112   /* The arrays above are row-aligned, but we create dense matrices as the transpose */
113   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G0, &G0mat));
114   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G1, &G1mat));
115   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, gauss_A, &Amat));
116   PetscCall(MatLUFactor(G0mat, NULL, NULL, NULL));
117   PetscCall(MatMatSolve(G0mat, G1mat, Amat));
118   PetscCall(MatTranspose(Amat, MAT_INPLACE_MATRIX, &Amat));
119   for (i = 0; i < nstages; i++)
120     for (j = 0; j < nstages; j++) gauss_A_real[i * nstages + j] = PetscRealPart(gauss_A[i * nstages + j]);
121 
122   PetscCall(MatDestroy(&G0mat));
123   PetscCall(MatDestroy(&G1mat));
124   PetscCall(MatDestroy(&Amat));
125   PetscCall(PetscFree3(b, G0, G1));
126 
127   { /* Invert A */
128     /* PETSc does not provide a routine to calculate the inverse of a general matrix.
129      * To get the inverse of A, we form a sequential BAIJ matrix from it, consisting of a single block with block size
130      * equal to the dimension of A, and then use MatInvertBlockDiagonal(). */
131     Mat                A_baij;
132     PetscInt           idxm[1] = {0}, idxn[1] = {0};
133     const PetscScalar *A_inv;
134 
135     PetscCall(MatCreateSeqBAIJ(PETSC_COMM_SELF, nstages, nstages, nstages, 1, NULL, &A_baij));
136     PetscCall(MatSetOption(A_baij, MAT_ROW_ORIENTED, PETSC_FALSE));
137     PetscCall(MatSetValuesBlocked(A_baij, 1, idxm, 1, idxn, gauss_A, INSERT_VALUES));
138     PetscCall(MatAssemblyBegin(A_baij, MAT_FINAL_ASSEMBLY));
139     PetscCall(MatAssemblyEnd(A_baij, MAT_FINAL_ASSEMBLY));
140     PetscCall(MatInvertBlockDiagonal(A_baij, &A_inv));
141     PetscCall(PetscMemcpy(gauss_A_inv, A_inv, nstages * nstages * sizeof(PetscScalar)));
142     PetscCall(MatDestroy(&A_baij));
143   }
144 
145   /* Compute row sums A_inv_rowsum and identity I_s */
146   for (i = 0; i < nstages; i++) {
147     gauss_A_inv_rowsum[i] = 0;
148     for (j = 0; j < nstages; j++) {
149       gauss_A_inv_rowsum[i] += gauss_A_inv[i + nstages * j];
150       I_s[i + nstages * j] = 1. * (i == j);
151     }
152   }
153   PetscCall(TSIRKTableauCreate(ts, nstages, gauss_A_real, gauss_b, gauss_c, NULL, gauss_A_inv, gauss_A_inv_rowsum, I_s));
154   PetscCall(PetscFree3(gauss_A_real, gauss_b, gauss_c));
155   PetscCall(PetscFree4(gauss_A, gauss_A_inv, gauss_A_inv_rowsum, I_s));
156   PetscFunctionReturn(PETSC_SUCCESS);
157 }
158 
159 /*@C
160   TSIRKRegister -  adds a `TSIRK` implementation
161 
162   Not Collective, No Fortran Support
163 
164   Input Parameters:
165 + sname    - name of user-defined IRK scheme
166 - function - function to create method context
167 
168   Level: advanced
169 
170   Note:
171   `TSIRKRegister()` may be called multiple times to add several user-defined families.
172 
173   Example Usage:
174 .vb
175    TSIRKRegister("my_scheme", MySchemeCreate);
176 .ve
177 
178   Then, your scheme can be chosen with the procedural interface via
179 .vb
180   TSIRKSetType(ts, "my_scheme")
181 .ve
182   or at runtime via the option
183 .vb
184   -ts_irk_type my_scheme
185 .ve
186 
187 .seealso: [](ch_ts), `TSIRK`, `TSIRKRegisterAll()`
188 @*/
189 PetscErrorCode TSIRKRegister(const char sname[], PetscErrorCode (*function)(TS))
190 {
191   PetscFunctionBegin;
192   PetscCall(TSIRKInitializePackage());
193   PetscCall(PetscFunctionListAdd(&TSIRKList, sname, function));
194   PetscFunctionReturn(PETSC_SUCCESS);
195 }
196 
197 /*@C
198   TSIRKRegisterAll - Registers all of the implicit Runge-Kutta methods in `TSIRK`
199 
200   Not Collective, but should be called by all processes which will need the schemes to be registered
201 
202   Level: advanced
203 
204 .seealso: [](ch_ts), `TSIRK`, `TSIRKRegisterDestroy()`
205 @*/
206 PetscErrorCode TSIRKRegisterAll(void)
207 {
208   PetscFunctionBegin;
209   if (TSIRKRegisterAllCalled) PetscFunctionReturn(PETSC_SUCCESS);
210   TSIRKRegisterAllCalled = PETSC_TRUE;
211 
212   PetscCall(TSIRKRegister(TSIRKGAUSS, TSIRKCreate_Gauss));
213   PetscFunctionReturn(PETSC_SUCCESS);
214 }
215 
216 /*@C
217   TSIRKRegisterDestroy - Frees the list of schemes that were registered by `TSIRKRegister()`.
218 
219   Not Collective
220 
221   Level: advanced
222 
223 .seealso: [](ch_ts), `TSIRK`, `TSIRKRegister()`, `TSIRKRegisterAll()`
224 @*/
225 PetscErrorCode TSIRKRegisterDestroy(void)
226 {
227   PetscFunctionBegin;
228   TSIRKRegisterAllCalled = PETSC_FALSE;
229   PetscFunctionReturn(PETSC_SUCCESS);
230 }
231 
232 /*@C
233   TSIRKInitializePackage - This function initializes everything in the `TSIRK` package. It is called
234   from `TSInitializePackage()`.
235 
236   Level: developer
237 
238 .seealso: [](ch_ts), `TSIRK`, `PetscInitialize()`, `TSIRKFinalizePackage()`, `TSInitializePackage()`
239 @*/
240 PetscErrorCode TSIRKInitializePackage(void)
241 {
242   PetscFunctionBegin;
243   if (TSIRKPackageInitialized) PetscFunctionReturn(PETSC_SUCCESS);
244   TSIRKPackageInitialized = PETSC_TRUE;
245   PetscCall(TSIRKRegisterAll());
246   PetscCall(PetscRegisterFinalize(TSIRKFinalizePackage));
247   PetscFunctionReturn(PETSC_SUCCESS);
248 }
249 
250 /*@C
251   TSIRKFinalizePackage - This function destroys everything in the `TSIRK` package. It is
252   called from `PetscFinalize()`.
253 
254   Level: developer
255 
256 .seealso: [](ch_ts), `TSIRK`, `PetscFinalize()`, `TSInitializePackage()`
257 @*/
258 PetscErrorCode TSIRKFinalizePackage(void)
259 {
260   PetscFunctionBegin;
261   PetscCall(PetscFunctionListDestroy(&TSIRKList));
262   TSIRKPackageInitialized = PETSC_FALSE;
263   PetscFunctionReturn(PETSC_SUCCESS);
264 }
265 
266 /*
267  This function can be called before or after ts->vec_sol has been updated.
268 */
269 static PetscErrorCode TSEvaluateStep_IRK(TS ts, PetscInt order, Vec U, PetscBool *done)
270 {
271   TS_IRK      *irk   = (TS_IRK *)ts->data;
272   IRKTableau   tab   = irk->tableau;
273   Vec         *YdotI = irk->YdotI;
274   PetscScalar *w     = irk->work;
275   PetscReal    h;
276   PetscInt     j;
277 
278   PetscFunctionBegin;
279   switch (irk->status) {
280   case TS_STEP_INCOMPLETE:
281   case TS_STEP_PENDING:
282     h = ts->time_step;
283     break;
284   case TS_STEP_COMPLETE:
285     h = ts->ptime - ts->ptime_prev;
286     break;
287   default:
288     SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
289   }
290 
291   PetscCall(VecCopy(ts->vec_sol, U));
292   for (j = 0; j < irk->nstages; j++) w[j] = h * tab->b[j];
293   PetscCall(VecMAXPY(U, irk->nstages, w, YdotI));
294   PetscFunctionReturn(PETSC_SUCCESS);
295 }
296 
297 static PetscErrorCode TSRollBack_IRK(TS ts)
298 {
299   TS_IRK *irk = (TS_IRK *)ts->data;
300 
301   PetscFunctionBegin;
302   PetscCall(VecCopy(irk->U0, ts->vec_sol));
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
306 static PetscErrorCode TSStep_IRK(TS ts)
307 {
308   TS_IRK        *irk   = (TS_IRK *)ts->data;
309   IRKTableau     tab   = irk->tableau;
310   PetscScalar   *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
311   const PetscInt nstages = irk->nstages;
312   SNES           snes;
313   PetscInt       i, j, its, lits, bs;
314   TSAdapt        adapt;
315   PetscInt       rejections     = 0;
316   PetscBool      accept         = PETSC_TRUE;
317   PetscReal      next_time_step = ts->time_step;
318 
319   PetscFunctionBegin;
320   if (!ts->steprollback) PetscCall(VecCopy(ts->vec_sol, irk->U0));
321   PetscCall(VecGetBlockSize(ts->vec_sol, &bs));
322   for (i = 0; i < nstages; i++) PetscCall(VecStrideScatter(ts->vec_sol, i * bs, irk->Z, INSERT_VALUES));
323 
324   irk->status = TS_STEP_INCOMPLETE;
325   while (!ts->reason && irk->status != TS_STEP_COMPLETE) {
326     PetscCall(VecCopy(ts->vec_sol, irk->U));
327     PetscCall(TSGetSNES(ts, &snes));
328     PetscCall(SNESSolve(snes, NULL, irk->Z));
329     PetscCall(SNESGetIterationNumber(snes, &its));
330     PetscCall(SNESGetLinearSolveIterations(snes, &lits));
331     ts->snes_its += its;
332     ts->ksp_its += lits;
333     PetscCall(VecStrideGatherAll(irk->Z, irk->Y, INSERT_VALUES));
334     for (i = 0; i < nstages; i++) {
335       PetscCall(VecZeroEntries(irk->YdotI[i]));
336       for (j = 0; j < nstages; j++) PetscCall(VecAXPY(irk->YdotI[i], A_inv[i + j * nstages] / ts->time_step, irk->Y[j]));
337       PetscCall(VecAXPY(irk->YdotI[i], -A_inv_rowsum[i] / ts->time_step, irk->U));
338     }
339     irk->status = TS_STEP_INCOMPLETE;
340     PetscCall(TSEvaluateStep_IRK(ts, irk->order, ts->vec_sol, NULL));
341     irk->status = TS_STEP_PENDING;
342     PetscCall(TSGetAdapt(ts, &adapt));
343     PetscCall(TSAdaptChoose(adapt, ts, ts->time_step, NULL, &next_time_step, &accept));
344     irk->status = accept ? TS_STEP_COMPLETE : TS_STEP_INCOMPLETE;
345     if (!accept) {
346       PetscCall(TSRollBack_IRK(ts));
347       ts->time_step = next_time_step;
348       goto reject_step;
349     }
350 
351     ts->ptime += ts->time_step;
352     ts->time_step = next_time_step;
353     break;
354   reject_step:
355     ts->reject++;
356     accept = PETSC_FALSE;
357     if (!ts->reason && ++rejections > ts->max_reject && ts->max_reject >= 0) {
358       ts->reason = TS_DIVERGED_STEP_REJECTED;
359       PetscCall(PetscInfo(ts, "Step=%" PetscInt_FMT ", step rejections %" PetscInt_FMT " greater than current TS allowed, stopping solve\n", ts->steps, rejections));
360     }
361   }
362   PetscFunctionReturn(PETSC_SUCCESS);
363 }
364 
365 static PetscErrorCode TSInterpolate_IRK(TS ts, PetscReal itime, Vec U)
366 {
367   TS_IRK          *irk     = (TS_IRK *)ts->data;
368   PetscInt         nstages = irk->nstages, pinterp = irk->pinterp, i, j;
369   PetscReal        h;
370   PetscReal        tt, t;
371   PetscScalar     *bt;
372   const PetscReal *B = irk->tableau->binterp;
373 
374   PetscFunctionBegin;
375   PetscCheck(B, PetscObjectComm((PetscObject)ts), PETSC_ERR_SUP, "TSIRK %s does not have an interpolation formula", irk->method_name);
376   switch (irk->status) {
377   case TS_STEP_INCOMPLETE:
378   case TS_STEP_PENDING:
379     h = ts->time_step;
380     t = (itime - ts->ptime) / h;
381     break;
382   case TS_STEP_COMPLETE:
383     h = ts->ptime - ts->ptime_prev;
384     t = (itime - ts->ptime) / h + 1; /* In the interval [0,1] */
385     break;
386   default:
387     SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
388   }
389   PetscCall(PetscMalloc1(nstages, &bt));
390   for (i = 0; i < nstages; i++) bt[i] = 0;
391   for (j = 0, tt = t; j < pinterp; j++, tt *= t) {
392     for (i = 0; i < nstages; i++) bt[i] += h * B[i * pinterp + j] * tt;
393   }
394   PetscCall(VecMAXPY(U, nstages, bt, irk->YdotI));
395   PetscFunctionReturn(PETSC_SUCCESS);
396 }
397 
398 static PetscErrorCode TSIRKTableauReset(TS ts)
399 {
400   TS_IRK    *irk = (TS_IRK *)ts->data;
401   IRKTableau tab = irk->tableau;
402 
403   PetscFunctionBegin;
404   if (!tab) PetscFunctionReturn(PETSC_SUCCESS);
405   PetscCall(PetscFree3(tab->A, tab->A_inv, tab->I_s));
406   PetscCall(PetscFree4(tab->b, tab->c, tab->binterp, tab->A_inv_rowsum));
407   PetscFunctionReturn(PETSC_SUCCESS);
408 }
409 
410 static PetscErrorCode TSReset_IRK(TS ts)
411 {
412   TS_IRK *irk = (TS_IRK *)ts->data;
413 
414   PetscFunctionBegin;
415   PetscCall(TSIRKTableauReset(ts));
416   if (irk->tableau) PetscCall(PetscFree(irk->tableau));
417   if (irk->method_name) PetscCall(PetscFree(irk->method_name));
418   if (irk->work) PetscCall(PetscFree(irk->work));
419   PetscCall(VecDestroyVecs(irk->nstages, &irk->Y));
420   PetscCall(VecDestroyVecs(irk->nstages, &irk->YdotI));
421   PetscCall(VecDestroy(&irk->Ydot));
422   PetscCall(VecDestroy(&irk->Z));
423   PetscCall(VecDestroy(&irk->U));
424   PetscCall(VecDestroy(&irk->U0));
425   PetscCall(MatDestroy(&irk->TJ));
426   PetscFunctionReturn(PETSC_SUCCESS);
427 }
428 
429 static PetscErrorCode TSIRKGetVecs(TS ts, DM dm, Vec *U)
430 {
431   TS_IRK *irk = (TS_IRK *)ts->data;
432 
433   PetscFunctionBegin;
434   if (U) {
435     if (dm && dm != ts->dm) {
436       PetscCall(DMGetNamedGlobalVector(dm, "TSIRK_U", U));
437     } else *U = irk->U;
438   }
439   PetscFunctionReturn(PETSC_SUCCESS);
440 }
441 
442 static PetscErrorCode TSIRKRestoreVecs(TS ts, DM dm, Vec *U)
443 {
444   PetscFunctionBegin;
445   if (U) {
446     if (dm && dm != ts->dm) PetscCall(DMRestoreNamedGlobalVector(dm, "TSIRK_U", U));
447   }
448   PetscFunctionReturn(PETSC_SUCCESS);
449 }
450 
451 /*
452   This defines the nonlinear equations that is to be solved with SNES
453     G[e\otimes t + C*dt, Z, Zdot] = 0
454     Zdot = (In \otimes S)*Z - (In \otimes Se) U
455   where S = 1/(dt*A)
456 */
457 static PetscErrorCode SNESTSFormFunction_IRK(SNES snes, Vec ZC, Vec FC, TS ts)
458 {
459   TS_IRK            *irk     = (TS_IRK *)ts->data;
460   IRKTableau         tab     = irk->tableau;
461   const PetscInt     nstages = irk->nstages;
462   const PetscReal   *c       = tab->c;
463   const PetscScalar *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
464   DM                 dm, dmsave;
465   Vec                U, *YdotI = irk->YdotI, Ydot = irk->Ydot, *Y = irk->Y;
466   PetscReal          h = ts->time_step;
467   PetscInt           i, j;
468 
469   PetscFunctionBegin;
470   PetscCall(SNESGetDM(snes, &dm));
471   PetscCall(TSIRKGetVecs(ts, dm, &U));
472   PetscCall(VecStrideGatherAll(ZC, Y, INSERT_VALUES));
473   dmsave = ts->dm;
474   ts->dm = dm;
475   for (i = 0; i < nstages; i++) {
476     PetscCall(VecZeroEntries(Ydot));
477     for (j = 0; j < nstages; j++) PetscCall(VecAXPY(Ydot, A_inv[j * nstages + i] / h, Y[j]));
478     PetscCall(VecAXPY(Ydot, -A_inv_rowsum[i] / h, U)); /* Ydot = (S \otimes In)*Z - (Se \otimes In) U */
479     PetscCall(TSComputeIFunction(ts, ts->ptime + ts->time_step * c[i], Y[i], Ydot, YdotI[i], PETSC_FALSE));
480   }
481   PetscCall(VecStrideScatterAll(YdotI, FC, INSERT_VALUES));
482   ts->dm = dmsave;
483   PetscCall(TSIRKRestoreVecs(ts, dm, &U));
484   PetscFunctionReturn(PETSC_SUCCESS);
485 }
486 
487 /*
488    For explicit ODE, the Jacobian is
489      JC = I_n \otimes S - J \otimes I_s
490    For DAE, the Jacobian is
491      JC = M_n \otimes S - J \otimes I_s
492 */
493 static PetscErrorCode SNESTSFormJacobian_IRK(SNES snes, Vec ZC, Mat JC, Mat JCpre, TS ts)
494 {
495   TS_IRK          *irk     = (TS_IRK *)ts->data;
496   IRKTableau       tab     = irk->tableau;
497   const PetscInt   nstages = irk->nstages;
498   const PetscReal *c       = tab->c;
499   DM               dm, dmsave;
500   Vec             *Y = irk->Y, Ydot = irk->Ydot;
501   Mat              J;
502   PetscScalar     *S;
503   PetscInt         i, j, bs;
504 
505   PetscFunctionBegin;
506   PetscCall(SNESGetDM(snes, &dm));
507   /* irk->Ydot has already been computed in SNESTSFormFunction_IRK (SNES guarantees this) */
508   dmsave = ts->dm;
509   ts->dm = dm;
510   PetscCall(VecGetBlockSize(Y[nstages - 1], &bs));
511   if (ts->equation_type <= TS_EQ_ODE_EXPLICIT) { /* Support explicit formulas only */
512     PetscCall(VecStrideGather(ZC, (nstages - 1) * bs, Y[nstages - 1], INSERT_VALUES));
513     PetscCall(MatKAIJGetAIJ(JC, &J));
514     PetscCall(TSComputeIJacobian(ts, ts->ptime + ts->time_step * c[nstages - 1], Y[nstages - 1], Ydot, 0, J, J, PETSC_FALSE));
515     PetscCall(MatKAIJGetS(JC, NULL, NULL, &S));
516     for (i = 0; i < nstages; i++)
517       for (j = 0; j < nstages; j++) S[i + nstages * j] = tab->A_inv[i + nstages * j] / ts->time_step;
518     PetscCall(MatKAIJRestoreS(JC, &S));
519   } else SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_SUP, "TSIRK %s does not support implicit formula", irk->method_name); /* TODO: need the mass matrix for DAE  */
520   ts->dm = dmsave;
521   PetscFunctionReturn(PETSC_SUCCESS);
522 }
523 
/*
  Coarsen hook registered in TSSetUp_IRK() through DMCoarsenHookAdd(); this implementation does nothing
  when a coarse DM is created. The state transfer is done by the paired DMRestrictHook_TSIRK().
*/
static PetscErrorCode DMCoarsenHook_TSIRK(DM fine, DM coarse, void *ctx)
{
  PetscFunctionBegin;
  PetscFunctionReturn(PETSC_SUCCESS);
}
529 
530 static PetscErrorCode DMRestrictHook_TSIRK(DM fine, Mat restrct, Vec rscale, Mat inject, DM coarse, void *ctx)
531 {
532   TS  ts = (TS)ctx;
533   Vec U, U_c;
534 
535   PetscFunctionBegin;
536   PetscCall(TSIRKGetVecs(ts, fine, &U));
537   PetscCall(TSIRKGetVecs(ts, coarse, &U_c));
538   PetscCall(MatRestrict(restrct, U, U_c));
539   PetscCall(VecPointwiseMult(U_c, rscale, U_c));
540   PetscCall(TSIRKRestoreVecs(ts, fine, &U));
541   PetscCall(TSIRKRestoreVecs(ts, coarse, &U_c));
542   PetscFunctionReturn(PETSC_SUCCESS);
543 }
544 
/*
  Subdomain hook registered in TSSetUp_IRK() through DMSubDomainHookAdd(); this implementation does nothing
  when a subdomain DM is created. The state transfer is done by DMSubDomainRestrictHook_TSIRK().
*/
static PetscErrorCode DMSubDomainHook_TSIRK(DM dm, DM subdm, void *ctx)
{
  PetscFunctionBegin;
  PetscFunctionReturn(PETSC_SUCCESS);
}
550 
551 static PetscErrorCode DMSubDomainRestrictHook_TSIRK(DM dm, VecScatter gscat, VecScatter lscat, DM subdm, void *ctx)
552 {
553   TS  ts = (TS)ctx;
554   Vec U, U_c;
555 
556   PetscFunctionBegin;
557   PetscCall(TSIRKGetVecs(ts, dm, &U));
558   PetscCall(TSIRKGetVecs(ts, subdm, &U_c));
559 
560   PetscCall(VecScatterBegin(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD));
561   PetscCall(VecScatterEnd(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD));
562 
563   PetscCall(TSIRKRestoreVecs(ts, dm, &U));
564   PetscCall(TSIRKRestoreVecs(ts, subdm, &U_c));
565   PetscFunctionReturn(PETSC_SUCCESS);
566 }
567 
568 static PetscErrorCode TSSetUp_IRK(TS ts)
569 {
570   TS_IRK        *irk = (TS_IRK *)ts->data;
571   IRKTableau     tab = irk->tableau;
572   DM             dm;
573   Mat            J;
574   Vec            R;
575   const PetscInt nstages = irk->nstages;
576   PetscInt       vsize, bs;
577 
578   PetscFunctionBegin;
579   if (!irk->work) PetscCall(PetscMalloc1(irk->nstages, &irk->work));
580   if (!irk->Y) PetscCall(VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->Y));
581   if (!irk->YdotI) PetscCall(VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->YdotI));
582   if (!irk->Ydot) PetscCall(VecDuplicate(ts->vec_sol, &irk->Ydot));
583   if (!irk->U) PetscCall(VecDuplicate(ts->vec_sol, &irk->U));
584   if (!irk->U0) PetscCall(VecDuplicate(ts->vec_sol, &irk->U0));
585   if (!irk->Z) {
586     PetscCall(VecCreate(PetscObjectComm((PetscObject)ts->vec_sol), &irk->Z));
587     PetscCall(VecGetSize(ts->vec_sol, &vsize));
588     PetscCall(VecSetSizes(irk->Z, PETSC_DECIDE, vsize * irk->nstages));
589     PetscCall(VecGetBlockSize(ts->vec_sol, &bs));
590     PetscCall(VecSetBlockSize(irk->Z, irk->nstages * bs));
591     PetscCall(VecSetFromOptions(irk->Z));
592   }
593   PetscCall(TSGetDM(ts, &dm));
594   PetscCall(DMCoarsenHookAdd(dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts));
595   PetscCall(DMSubDomainHookAdd(dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts));
596 
597   PetscCall(TSGetSNES(ts, &ts->snes));
598   PetscCall(VecDuplicate(irk->Z, &R));
599   PetscCall(SNESSetFunction(ts->snes, R, SNESTSFormFunction, ts));
600   PetscCall(TSGetIJacobian(ts, &J, NULL, NULL, NULL));
601   if (!irk->TJ) {
602     /* Create the KAIJ matrix for solving the stages */
603     PetscCall(MatCreateKAIJ(J, nstages, nstages, tab->A_inv, tab->I_s, &irk->TJ));
604   }
605   PetscCall(SNESSetJacobian(ts->snes, irk->TJ, irk->TJ, SNESTSFormJacobian, ts));
606   PetscCall(VecDestroy(&R));
607   PetscFunctionReturn(PETSC_SUCCESS);
608 }
609 
610 static PetscErrorCode TSSetFromOptions_IRK(TS ts, PetscOptionItems PetscOptionsObject)
611 {
612   TS_IRK *irk        = (TS_IRK *)ts->data;
613   char    tname[256] = TSIRKGAUSS;
614 
615   PetscFunctionBegin;
616   PetscOptionsHeadBegin(PetscOptionsObject, "IRK ODE solver options");
617   {
618     PetscBool flg1, flg2;
619     PetscCall(PetscOptionsInt("-ts_irk_nstages", "Stages of the IRK method", "TSIRKSetNumStages", irk->nstages, &irk->nstages, &flg1));
620     PetscCall(PetscOptionsFList("-ts_irk_type", "Type of IRK method", "TSIRKSetType", TSIRKList, irk->method_name[0] ? irk->method_name : tname, tname, sizeof(tname), &flg2));
621     if (flg1 || flg2 || !irk->method_name[0]) { /* Create the method tableau after nstages or method is set */
622       PetscCall(TSIRKSetType(ts, tname));
623     }
624   }
625   PetscOptionsHeadEnd();
626   PetscFunctionReturn(PETSC_SUCCESS);
627 }
628 
629 static PetscErrorCode TSView_IRK(TS ts, PetscViewer viewer)
630 {
631   TS_IRK   *irk = (TS_IRK *)ts->data;
632   PetscBool iascii;
633 
634   PetscFunctionBegin;
635   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
636   if (iascii) {
637     IRKTableau tab = irk->tableau;
638     TSIRKType  irktype;
639     char       buf[512];
640 
641     PetscCall(TSIRKGetType(ts, &irktype));
642     PetscCall(PetscViewerASCIIPrintf(viewer, "  IRK type %s\n", irktype));
643     PetscCall(PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", irk->nstages, tab->c));
644     PetscCall(PetscViewerASCIIPrintf(viewer, "  Abscissa       c = %s\n", buf));
645     PetscCall(PetscViewerASCIIPrintf(viewer, "Stiffly accurate: %s\n", irk->stiffly_accurate ? "yes" : "no"));
646     PetscCall(PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", PetscSqr(irk->nstages), tab->A));
647     PetscCall(PetscViewerASCIIPrintf(viewer, "  A coefficients       A = %s\n", buf));
648   }
649   PetscFunctionReturn(PETSC_SUCCESS);
650 }
651 
652 static PetscErrorCode TSLoad_IRK(TS ts, PetscViewer viewer)
653 {
654   SNES    snes;
655   TSAdapt adapt;
656 
657   PetscFunctionBegin;
658   PetscCall(TSGetAdapt(ts, &adapt));
659   PetscCall(TSAdaptLoad(adapt, viewer));
660   PetscCall(TSGetSNES(ts, &snes));
661   PetscCall(SNESLoad(snes, viewer));
662   /* function and Jacobian context for SNES when used with TS is always ts object */
663   PetscCall(SNESSetFunction(snes, NULL, NULL, ts));
664   PetscCall(SNESSetJacobian(snes, NULL, NULL, NULL, ts));
665   PetscFunctionReturn(PETSC_SUCCESS);
666 }
667 
668 /*@
669   TSIRKSetType - Set the type of `TSIRK` scheme to use
670 
671   Logically Collective
672 
673   Input Parameters:
674 + ts      - timestepping context
675 - irktype - type of `TSIRK` scheme
676 
677   Options Database Key:
678 . -ts_irk_type <gauss> - set irk type
679 
680   Level: intermediate
681 
682 .seealso: [](ch_ts), `TSIRKGetType()`, `TSIRK`, `TSIRKType`, `TSIRKGAUSS`
683 @*/
684 PetscErrorCode TSIRKSetType(TS ts, TSIRKType irktype)
685 {
686   PetscFunctionBegin;
687   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
688   PetscAssertPointer(irktype, 2);
689   PetscTryMethod(ts, "TSIRKSetType_C", (TS, TSIRKType), (ts, irktype));
690   PetscFunctionReturn(PETSC_SUCCESS);
691 }
692 
693 /*@
694   TSIRKGetType - Get the type of `TSIRK` IMEX scheme being used
695 
696   Logically Collective
697 
698   Input Parameter:
699 . ts - timestepping context
700 
701   Output Parameter:
702 . irktype - type of `TSIRK` IMEX scheme
703 
704   Level: intermediate
705 
706 .seealso: [](ch_ts), `TSIRK`, `TSIRKType`, `TSIRKGAUSS`
707 @*/
708 PetscErrorCode TSIRKGetType(TS ts, TSIRKType *irktype)
709 {
710   PetscFunctionBegin;
711   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
712   PetscUseMethod(ts, "TSIRKGetType_C", (TS, TSIRKType *), (ts, irktype));
713   PetscFunctionReturn(PETSC_SUCCESS);
714 }
715 
716 /*@
717   TSIRKSetNumStages - Set the number of stages of `TSIRK` scheme to use
718 
719   Logically Collective
720 
721   Input Parameters:
722 + ts      - timestepping context
723 - nstages - number of stages of `TSIRK` scheme
724 
725   Options Database Key:
726 . -ts_irk_nstages <int> - set number of stages
727 
728   Level: intermediate
729 
730 .seealso: [](ch_ts), `TSIRKGetNumStages()`, `TSIRK`
731 @*/
732 PetscErrorCode TSIRKSetNumStages(TS ts, PetscInt nstages)
733 {
734   PetscFunctionBegin;
735   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
736   PetscTryMethod(ts, "TSIRKSetNumStages_C", (TS, PetscInt), (ts, nstages));
737   PetscFunctionReturn(PETSC_SUCCESS);
738 }
739 
740 /*@
741   TSIRKGetNumStages - Get the number of stages of `TSIRK` scheme
742 
743   Logically Collective
744 
745   Input Parameters:
746 + ts      - timestepping context
747 - nstages - number of stages of `TSIRK` scheme
748 
749   Level: intermediate
750 
751 .seealso: [](ch_ts), `TSIRKSetNumStages()`, `TSIRK`
752 @*/
753 PetscErrorCode TSIRKGetNumStages(TS ts, PetscInt *nstages)
754 {
755   PetscFunctionBegin;
756   PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
757   PetscAssertPointer(nstages, 2);
758   PetscTryMethod(ts, "TSIRKGetNumStages_C", (TS, PetscInt *), (ts, nstages));
759   PetscFunctionReturn(PETSC_SUCCESS);
760 }
761 
762 static PetscErrorCode TSIRKGetType_IRK(TS ts, TSIRKType *irktype)
763 {
764   TS_IRK *irk = (TS_IRK *)ts->data;
765 
766   PetscFunctionBegin;
767   *irktype = irk->method_name;
768   PetscFunctionReturn(PETSC_SUCCESS);
769 }
770 
771 static PetscErrorCode TSIRKSetType_IRK(TS ts, TSIRKType irktype)
772 {
773   TS_IRK *irk = (TS_IRK *)ts->data;
774   PetscErrorCode (*irkcreate)(TS);
775 
776   PetscFunctionBegin;
777   if (irk->method_name) {
778     PetscCall(PetscFree(irk->method_name));
779     PetscCall(TSIRKTableauReset(ts));
780   }
781   PetscCall(PetscFunctionListFind(TSIRKList, irktype, &irkcreate));
782   PetscCheck(irkcreate, PetscObjectComm((PetscObject)ts), PETSC_ERR_ARG_UNKNOWN_TYPE, "Unknown TSIRK type \"%s\" given", irktype);
783   PetscCall((*irkcreate)(ts));
784   PetscCall(PetscStrallocpy(irktype, &irk->method_name));
785   PetscFunctionReturn(PETSC_SUCCESS);
786 }
787 
788 static PetscErrorCode TSIRKSetNumStages_IRK(TS ts, PetscInt nstages)
789 {
790   TS_IRK *irk = (TS_IRK *)ts->data;
791 
792   PetscFunctionBegin;
793   PetscCheck(nstages > 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "input argument, %" PetscInt_FMT ", out of range", nstages);
794   irk->nstages = nstages;
795   PetscFunctionReturn(PETSC_SUCCESS);
796 }
797 
798 static PetscErrorCode TSIRKGetNumStages_IRK(TS ts, PetscInt *nstages)
799 {
800   TS_IRK *irk = (TS_IRK *)ts->data;
801 
802   PetscFunctionBegin;
803   PetscAssertPointer(nstages, 2);
804   *nstages = irk->nstages;
805   PetscFunctionReturn(PETSC_SUCCESS);
806 }
807 
808 static PetscErrorCode TSDestroy_IRK(TS ts)
809 {
810   PetscFunctionBegin;
811   PetscCall(TSReset_IRK(ts));
812   if (ts->dm) {
813     PetscCall(DMCoarsenHookRemove(ts->dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts));
814     PetscCall(DMSubDomainHookRemove(ts->dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts));
815   }
816   PetscCall(PetscFree(ts->data));
817   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", NULL));
818   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", NULL));
819   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", NULL));
820   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", NULL));
821   PetscFunctionReturn(PETSC_SUCCESS);
822 }
823 
824 /*MC
825       TSIRK - ODE and DAE solver using Implicit Runge-Kutta schemes
826 
827   Level: beginner
828 
829   Notes:
830   `TSIRK` uses the sparse Kronecker product matrix implementation of `MATKAIJ` to achieve good arithmetic intensity.
831 
  Gauss-Legendre methods are currently supported. These are A-stable symplectic methods with an arbitrary number of stages. The order of accuracy is 2s
833   when using s stages. The default method uses three stages and thus has an order of six. The number of stages (thus order) can be set with
834   -ts_irk_nstages or `TSIRKSetNumStages()`.
835 
836 .seealso: [](ch_ts), `TSCreate()`, `TS`, `TSSetType()`, `TSIRKSetType()`, `TSIRKGetType()`, `TSIRKGAUSS`, `TSIRKRegister()`, `TSIRKSetNumStages()`, `TSType`
837 M*/
838 PETSC_EXTERN PetscErrorCode TSCreate_IRK(TS ts)
839 {
840   TS_IRK *irk;
841 
842   PetscFunctionBegin;
843   PetscCall(TSIRKInitializePackage());
844 
845   ts->ops->reset          = TSReset_IRK;
846   ts->ops->destroy        = TSDestroy_IRK;
847   ts->ops->view           = TSView_IRK;
848   ts->ops->load           = TSLoad_IRK;
849   ts->ops->setup          = TSSetUp_IRK;
850   ts->ops->step           = TSStep_IRK;
851   ts->ops->interpolate    = TSInterpolate_IRK;
852   ts->ops->evaluatestep   = TSEvaluateStep_IRK;
853   ts->ops->rollback       = TSRollBack_IRK;
854   ts->ops->setfromoptions = TSSetFromOptions_IRK;
855   ts->ops->snesfunction   = SNESTSFormFunction_IRK;
856   ts->ops->snesjacobian   = SNESTSFormJacobian_IRK;
857 
858   ts->usessnes = PETSC_TRUE;
859 
860   PetscCall(PetscNew(&irk));
861   ts->data = (void *)irk;
862 
863   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", TSIRKSetType_IRK));
864   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", TSIRKGetType_IRK));
865   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", TSIRKSetNumStages_IRK));
866   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", TSIRKGetNumStages_IRK));
867   /* 3-stage IRK_Gauss is the default */
868   PetscCall(PetscNew(&irk->tableau));
869   irk->nstages = 3;
870   PetscCall(TSIRKSetType(ts, TSIRKDefault));
871   PetscFunctionReturn(PETSC_SUCCESS);
872 }
873