from .scheme import ContinuousPetrovGalerkinScheme, DiscontinuousGalerkinScheme
from .dirk_stepper import DIRKTimeStepper
from .explicit_stepper import ExplicitTimeStepper
from .discontinuous_galerkin_stepper import DiscontinuousGalerkinTimeStepper
from .galerkin_stepper import ContinuousPetrovGalerkinTimeStepper
from .imex import RadauIIAIMEXMethod, DIRKIMEXMethod
from .labeling import split_explicit
from .stage_derivative import StageDerivativeTimeStepper, AdaptiveTimeStepper
from .stage_value import StageValueTimeStepper
from .tools import AI
from .multistep import MultistepTimeStepper
from .tableaux.multistep_tableaux import MultistepTableau
# Keyword arguments shared by every Runge-Kutta and multistep stepper.
# TimeStepper strips these out of the caller's kwargs before validating the
# stepper-specific ones; "bcs" is then split off and passed on its own.
valid_base_kwargs = (
    "bcs",
    "form_compiler_parameters",
    "is_linear",
    "constant_jacobian",
    "restrict",
    "solver_parameters",
    "nullspace",
    "transpose_nullspace",
    "near_nullspace",
    "appctx",
    "options_prefix",
    "pre_apply_bcs",
)

# Additional keyword arguments accepted for each RK stage type, plus the
# Galerkin-in-time schemes ("dg" and "cpg").  Each entry is its own list.
valid_kwargs_per_stage_type = dict(
    deriv=["Fp", "stage_type", "bc_type", "splitting",
           "adaptive_parameters", "aux_indices", "sample_points"],
    value=["Fp", "stage_type", "basis_type", "update_solver_parameters",
           "splitting", "bounds", "use_collocation_update", "sample_points"],
    dirk=["Fp", "stage_type"],
    explicit=["Fp", "stage_type"],
    imex=["Fexp", "stage_type", "it_solver_parameters",
          "prop_solver_parameters", "splitting",
          "num_its_initial", "num_its_per_step"],
    dirkimex=["Fexp", "stage_type", "mass_parameters"],
    dg=["Fp", "sample_points"],
    cpg=["Fp", "bc_type", "aux_indices", "sample_points"],
)

# Keys permitted inside the ``adaptive_parameters`` dict (stage_type "deriv").
valid_adapt_parameters = [
    "tol", "dtmin", "dtmax", "KI", "KP",
    "max_reject", "onscale_factor", "safety_factor", "gamma0_params",
]

# Stepper-specific keyword arguments accepted for multistep methods.
valid_multistep_kwargs = ("Fp", "bounds", "startup_parameters")
[docs]
def imex_separation(F, Fexp_kwarg, label):
    """Split ``F`` into the implicit form and the explicit form for an IMEX scheme.

    The explicit part may be supplied either by labeling it inside ``F``
    (detected via :func:`split_explicit`) or through ``Fexp_kwarg``, but
    not both.

    :arg F: The full semi-discrete form.
    :arg Fexp_kwarg: An explicitly supplied explicit form, or ``None``.
    :arg label: Name of the calling scheme; used only in the error message.
    :returns: A pair ``(Fimp, Fexp)``. ``Fimp`` is the first value returned
        by :func:`split_explicit`. ``Fexp`` is ``Fexp_kwarg`` when given,
        otherwise the labeled explicit part.
    :raises ValueError: if no explicit part is found, or if one is given
        both as a label and as ``Fexp_kwarg``.
    """
    Fimp, labeled_explicit = split_explicit(F)

    if Fexp_kwarg is not None:
        # The caller handed us an explicit form; a labeled one would clash.
        if labeled_explicit is not None:
            raise ValueError("You specified an explicit part in two ways!")
        return Fimp, Fexp_kwarg

    if labeled_explicit is None:
        raise ValueError(f"Calling an {label} scheme with no explicit form. Did you really mean to do this?")
    return Fimp, labeled_explicit
[docs]
def _pop_base_kwargs(kwargs):
    """Remove the keyword arguments shared by all RK and multistep steppers.

    :arg kwargs: The caller's keyword-argument :class:`dict`; it is
        modified in place.
    :returns: A pair ``(bcs, base_kwargs)``. ``bcs`` is the popped
        ``"bcs"`` entry, or ``None``. ``base_kwargs`` holds the other entries
        of ``valid_base_kwargs`` that were present, in that tuple's order.
    """
    base_kwargs = {k: kwargs.pop(k) for k in valid_base_kwargs if k in kwargs}
    bcs = base_kwargs.pop("bcs", None)
    return bcs, base_kwargs


def TimeStepper(F, method, t, dt, u0, **kwargs):
    """Helper function to dispatch between various back-end classes
    for doing time stepping. Returns an instance of the
    appropriate class.

    Dispatch order:

    - A :class:`DiscontinuousGalerkinScheme` or
      :class:`ContinuousPetrovGalerkinScheme` ``method`` selects the
      corresponding Galerkin-in-time stepper.
    - A :class:`MultistepTableau` selects :class:`MultistepTimeStepper`.
    - Anything else is treated as a Runge-Kutta method and dispatched on
      ``stage_type``.

    :arg F: A :class:`ufl.Form` instance describing the semi-discrete problem
            ``F(t, u; v) == 0``, where ``u`` is the unknown
            :class:`firedrake.Function` and ``v`` is the
            :class:`firedrake.TestFunction`. To specify a linear problem,
            ``F`` must be of the form ``a(t; w, v) - L(t; v)``, where
            ``w`` is a :class:`firedrake.TrialFunction`.
    :arg method: A :class:`ButcherTableau` instance (for RK methods),
            a :class:`MultistepTableau` instance (for multistep methods), or
            a :class:`GalerkinScheme` instance (for CPG or DG methods)
            to be used in time marching.
    :arg t: a :class:`firedrake.Constant` or :class:`firedrake.Function`
            on the Real space over the same mesh as ``u0``. This serves as
            a variable referring to the current time.
    :arg dt: a :class:`firedrake.Constant` or :class:`firedrake.Function`
            on the Real space over the same mesh as ``u0``. This serves as
            a variable referring to the current time step size.
            The user may adjust this value between time steps.
    :arg u0: A :class:`firedrake.Function` containing the current
            state of the problem to be solved.
    :kwarg bcs: An iterable of :class:`firedrake.DirichletBC` or
            :class:`firedrake.EquationBC` containing
            the strongly-enforced boundary conditions. Irksome will
            manipulate these to obtain boundary conditions for each
            stage of the RK method.
    :kwarg constant_jacobian: A boolean flag indicating whether the Jacobian
            does not change between time steps. If ``dt`` is updated, the Jacobian
            may be flagged for an update via :func:`invalidate_jacobian`.
    :kwarg nullspace: A :class:`firedrake.VectorSpaceBasis`
            or :class:`firedrake.MixedVectorSpaceBasis` specifying a nullspace
            over the space of ``u0``.
    :kwarg stage_type: For RK methods, one of ``"deriv"`` (the default),
            ``"value"``, ``"dirk"``, ``"explicit"``, ``"imex"`` or
            ``"dirkimex"``. It selects between the stage derivative and stage
            value formulations, or a DIRK, explicit, or IMEX stepper. Support
            for :class:`firedrake.EquationBC` in ``bcs`` is limited to the
            stage derivative formulation.
    :kwarg Fp: An optional form forwarded to the stepper as ``Fp``
            (presumably used to build a preconditioner -- confirm against
            the stepper classes).
    :kwarg Fexp: For the ``"imex"`` and ``"dirkimex"`` stage types, the
            explicit part of the form. It may instead be labeled within
            ``F``, but not both.
    :kwarg splitting: A callable used to factor the Butcher matrix
            (defaults to ``AI``).
    :kwarg bc_type: For stage derivative formulation, how to manipulate
            the strongly-enforced boundary conditions (defaults to ``"DAE"``).
            Support for :class:`firedrake.EquationBC` in ``bcs`` is limited
            to DAE style BCs.
    :kwarg solver_parameters: A :class:`dict` of solver parameters that
            will be used in solving the algebraic problem associated
            with each time step.
    :kwarg update_solver_parameters: A :class:`dict` of parameters for
            inverting the mass matrix at each step (only used if
            stage_type is "value")
    :kwarg adaptive_parameters: A :class:`dict` of parameters for use with
            adaptive time stepping (only used if stage_type is "deriv").
            Allowed keys are listed in ``valid_adapt_parameters``.
    :kwarg bounds: Forwarded to the stage value and multistep steppers.
    :kwarg use_collocation_update: An optional kwarg indicating whether to use
            the terminal value of the collocation polynomial as the solution
            update. This is needed to bypass the mass matrix inversion when
            enforcing bounds constraints with an RK method that is not stiffly
            accurate. Currently, only constant-in-time boundary conditions are
            supported.
    :kwarg aux_indices: Only valid for continuous Petrov Galerkin time scheme. It
            specifies that some of the variables in `u0` are to be treated as
            auxiliary, that is, discretized in the lower-order DG test space.
    :kwarg startup_parameters: An optional :class:`dict` containing parameters
            used to automatically find starting values for multistep methods.
    :kwarg sample_points: An optional kwarg used to evaluate collocation methods
            at additional points in time.
    :returns: An instance of the time-stepper class selected by ``method``
            and ``stage_type``.
    :raises ValueError: if a keyword argument is not accepted by the selected
            RK stage type or by the multistep stepper, or if ``stage_type``
            is not recognized.
    :raises AssertionError: if a keyword argument is not accepted by a
            Galerkin-in-time scheme, if ``adaptive_parameters`` is given with
            a ``stage_type`` other than ``"deriv"``, or if it contains an
            unknown key.
    """
    # First pluck out the cases for Galerkin in time; these steppers take the
    # keyword arguments directly, so they are only validated here.
    if isinstance(method, DiscontinuousGalerkinScheme):
        assert set(kwargs.keys()).issubset(list(valid_base_kwargs) + valid_kwargs_per_stage_type["dg"])
        return DiscontinuousGalerkinTimeStepper(F, method, t, dt, u0, **kwargs)
    elif isinstance(method, ContinuousPetrovGalerkinScheme):
        assert set(kwargs.keys()).issubset(list(valid_base_kwargs) + valid_kwargs_per_stage_type["cpg"])
        return ContinuousPetrovGalerkinTimeStepper(F, method, t, dt, u0, **kwargs)

    # Then, pluck out the case for multistep methods.
    if isinstance(method, MultistepTableau):
        bcs, base_kwargs = _pop_base_kwargs(kwargs)
        for cur_kwarg in kwargs:
            if cur_kwarg not in valid_multistep_kwargs:
                raise ValueError(f"kwarg {cur_kwarg} is not allowable for MultistepTimeStepper")
        return MultistepTimeStepper(
            F, method, t, dt, u0, bcs=bcs,
            Fp=kwargs.get("Fp", None),
            startup_parameters=kwargs.get("startup_parameters", None),
            bounds=kwargs.get("bounds", None),
            **base_kwargs)

    # Everything else is a Runge-Kutta method, dispatched on ``stage_type``.
    rk_stage_types = ("deriv", "value", "dirk", "explicit", "imex", "dirkimex")
    stage_type = kwargs.pop("stage_type", "deriv")
    adapt_params = kwargs.pop("adaptive_parameters", None)
    if adapt_params is not None:
        assert stage_type == "deriv", "Adaptive time stepping is only implemented for derivative stage type"
    # Without this check an unrecognized stage_type would either raise a bare
    # KeyError in the validation loop below or fall through every branch and
    # silently return None.
    if stage_type not in rk_stage_types:
        raise ValueError(
            f"Unknown stage_type {stage_type!r}; expected one of {', '.join(rk_stage_types)}")

    bcs, base_kwargs = _pop_base_kwargs(kwargs)
    for cur_kwarg in kwargs:
        if cur_kwarg not in valid_kwargs_per_stage_type[stage_type]:
            raise ValueError(f"kwarg {cur_kwarg} is not allowable for stage_type {stage_type}")

    if stage_type == "deriv":
        bc_type = kwargs.get("bc_type", "DAE")
        splitting = kwargs.get("splitting", AI)
        if adapt_params is None:
            return StageDerivativeTimeStepper(
                F, method, t, dt, u0, bcs, Fp=kwargs.get("Fp", None),
                bc_type=bc_type, splitting=splitting,
                aux_indices=kwargs.get("aux_indices", None),
                sample_points=kwargs.get("sample_points", None),
                **base_kwargs)

        for param in adapt_params:
            assert param in valid_adapt_parameters, f"Unknown adaptive parameter {param}"
        # NOTE(review): Fp, aux_indices and sample_points are accepted for
        # stage_type "deriv" but are not forwarded to AdaptiveTimeStepper;
        # confirm whether that class supports them.
        return AdaptiveTimeStepper(
            F, method, t, dt, u0, bcs,
            bc_type=bc_type, splitting=splitting,
            tol=adapt_params.get("tol", 1e-3),
            dtmin=adapt_params.get("dtmin", 1.e-15),
            dtmax=adapt_params.get("dtmax", 1.0),
            KI=adapt_params.get("KI", 1/15),
            KP=adapt_params.get("KP", 0.13),
            max_reject=adapt_params.get("max_reject", 10),
            onscale_factor=adapt_params.get("onscale_factor", 1.2),
            safety_factor=adapt_params.get("safety_factor", 0.9),
            gamma0_params=adapt_params.get("gamma0_params"),
            **base_kwargs)
    elif stage_type == "value":
        return StageValueTimeStepper(
            F, method, t, dt, u0, bcs=bcs, Fp=kwargs.get("Fp", None),
            splitting=kwargs.get("splitting", AI),
            basis_type=kwargs.get("basis_type"),
            update_solver_parameters=kwargs.get("update_solver_parameters"),
            bounds=kwargs.get("bounds"),
            use_collocation_update=kwargs.get("use_collocation_update", False),
            sample_points=kwargs.get("sample_points", None),
            **base_kwargs)
    elif stage_type == "dirk":
        return DIRKTimeStepper(
            F, method, t, dt, u0, bcs, Fp=kwargs.get("Fp", None), **base_kwargs)
    elif stage_type == "explicit":
        return ExplicitTimeStepper(
            F, method, t, dt, u0, bcs, Fp=kwargs.get("Fp", None), **base_kwargs)
    elif stage_type == "imex":
        Fimp, Fexp = imex_separation(F, kwargs.get("Fexp"), stage_type)
        # These are passed positionally, so they must not also remain in
        # base_kwargs.
        appctx = base_kwargs.pop("appctx", None)
        nullspace = base_kwargs.pop("nullspace", None)
        splitting = kwargs.get("splitting", AI)
        it_solver_parameters = kwargs.get("it_solver_parameters")
        prop_solver_parameters = kwargs.get("prop_solver_parameters")
        num_its_initial = kwargs.get("num_its_initial", 0)
        num_its_per_step = kwargs.get("num_its_per_step", 0)
        return RadauIIAIMEXMethod(
            Fimp, Fexp, method, t, dt, u0, bcs,
            it_solver_parameters, prop_solver_parameters,
            splitting, appctx, nullspace,
            num_its_initial, num_its_per_step, **base_kwargs)
    else:  # stage_type == "dirkimex", guaranteed by the check above
        Fimp, Fexp = imex_separation(F, kwargs.get("Fexp"), stage_type)
        # These are passed positionally, so they must not also remain in
        # base_kwargs.
        appctx = base_kwargs.pop("appctx", None)
        nullspace = base_kwargs.pop("nullspace", None)
        solver_parameters = base_kwargs.pop("solver_parameters", None)
        mass_parameters = kwargs.get("mass_parameters")
        return DIRKIMEXMethod(
            Fimp, Fexp, method, t, dt, u0, bcs,
            solver_parameters, mass_parameters, appctx, nullspace, **base_kwargs)