from .tools import replace
from .constant import vecconst
from .ufl.manipulation import split_time_derivative_terms, remove_time_derivatives
from .ufl.deriv import expand_time_derivatives
from .base_time_stepper import BaseTimeStepper
from .tableaux.multistep_tableaux import MultistepTableau
from .bcs import stage2spaces4bc
from ufl import Form
from ufl.constantvalue import as_ufl
from firedrake import NonlinearVariationalProblem, NonlinearVariationalSolver, derivative, Constant
class MultistepTimeStepper(BaseTimeStepper):
    """Front-end class for advancing a time-dependent PDE via a general multistep method.

    :arg F: A :class:`ufl.Form` instance describing the semi-discrete problem
        F(t, u; v) == 0, where `u` is the unknown
        :class:`firedrake.Function` and `v` is the
        :class:`firedrake.TestFunction`.
    :arg method: A :class:`MultistepTableau` corresponding to the desired
        multistep method.
    :arg t: a :class:`Function` on the Real space over the same mesh as
        `u0`. This serves as a variable referring to the current time.
        (:meth:`startup` also handles a :class:`Constant` here.)
    :arg dt: a :class:`Function` on the Real space over the same mesh as
        `u0`. This serves as a variable referring to the current time step.
        The user may adjust this value between time steps.
        (:meth:`startup` also handles a :class:`Constant` here.)
    :arg u0: A :class:`firedrake.Function` containing the current
        state of the problem to be solved. It is used directly as the
        unknown of each step's solve.
    :arg bcs: An iterable of :class:`firedrake.DirichletBC` containing
        the strongly-enforced boundary conditions.
    :arg Fp: An optional :class:`ufl.Form` whose derivative with respect to
        the unknown is passed as ``Jp`` (the preconditioning operator) to the
        nonlinear problem.
    :arg solver_parameters: A :class:`dict` of solver parameters that
        will be used in solving the algebraic problem associated
        with each time step.
    :arg bounds: Optional value passed as ``bounds=`` to the nonlinear
        solver's ``solve`` call on every step.
    :arg appctx: An optional :class:`dict` containing application context.
        This gets included with particular things that Irksome will
        pass into the nonlinear solver so that, say, user-defined preconditioners
        have access to it.
    :arg nullspace: Optional nullspace, forwarded to the base class and to
        :class:`NonlinearVariationalSolver`.
    :arg transpose_nullspace: Optional transpose nullspace, forwarded to
        :class:`NonlinearVariationalSolver`.
    :arg near_nullspace: Optional near nullspace, forwarded to
        :class:`NonlinearVariationalSolver`.
    :arg startup_parameters: An optional :class:`dict` used to construct a
        single-step TimeStepper to be used to find the required starting
        values. Keys read by :meth:`startup` are ``'tableau'``,
        ``'stepper_kwargs'`` and ``'num_startup_steps'``.
    :arg kwargs: ``form_compiler_parameters``, ``is_linear`` and ``restrict``
        are popped and given to :class:`NonlinearVariationalProblem`; all
        remaining keyword arguments go to :class:`NonlinearVariationalSolver`.
    """

    def __init__(self, F, method, t, dt, u0, bcs=None, Fp=None, solver_parameters=None, bounds=None, appctx=None, nullspace=None,
                 transpose_nullspace=None, near_nullspace=None, startup_parameters=None, **kwargs):
        assert isinstance(method, MultistepTableau)
        # NOTE(review): self.F, self.t, self.dt, self.u0, self.orig_bcs and
        # self.appctx are used below but not assigned in this class;
        # presumably set by BaseTimeStepper.__init__ — confirm.
        super().__init__(F, t, dt, u0,
                         bcs=bcs, appctx=appctx, nullspace=nullspace)
        self.num_prev_steps = len(method.b) - 1
        self.a = vecconst(method.a)
        self.b = vecconst(method.b)

        # us[0] .. us[-2] are independent copies holding previous solution
        # levels; us[-1] *is* u0 and is the unknown solved for on each step.
        self.us = [u0.copy(deepcopy=True) for _ in range(len(self.a) - 1)]
        self.us.append(u0)

        # NOTE(review): get_form_and_bcs is not defined in this class as
        # shown; presumably inherited — confirm.
        Fnew, bcsnew = self.get_form_and_bcs(F, t, dt, u0, self.a, self.b, bcs=bcs)
        if Fp is not None:
            Fpnew, _ = self.get_form_and_bcs(Fp, t, dt, u0, self.a, self.b, bcs=bcs)
            # Fp describes the preconditioner, so its Jacobian must be given as
            # Jp. Passing it as J would replace the true Jacobian of Fnew.
            Jp = derivative(Fpnew, self.us[-1])
        else:
            Jp = None

        self.problem = NonlinearVariationalProblem(
            Fnew, self.us[-1], bcs=bcsnew, Jp=Jp,
            form_compiler_parameters=kwargs.pop("form_compiler_parameters", None),
            is_linear=kwargs.pop("is_linear", False),
            restrict=kwargs.pop("restrict", False))
        self.solver = NonlinearVariationalSolver(self.problem, appctx=self.appctx,
                                                 nullspace=nullspace,
                                                 transpose_nullspace=transpose_nullspace,
                                                 near_nullspace=near_nullspace,
                                                 solver_parameters=solver_parameters,
                                                 **kwargs)
        self.num_steps = 0
        self.num_nonlinear_iterations = 0
        self.num_linear_iterations = 0
        self.startup_parameters = startup_parameters
        self.bounds = bounds

    def startup(self):
        """Compute the starting values required by the multistep method.

        This is an optional helper that fills ``us[1:]`` mechanically by
        running a single-step TimeStepper. It takes ``num_prev_steps - 1``
        steps of size ``dt``. Each of those steps is split into
        ``num_startup_steps`` sub-steps of size ``dt / num_startup_steps``.
        ``us[0]`` receives the current ``u0``; ``u0`` itself is advanced in
        place.

        ``self.t`` is not modified. ``self.startup_t`` holds a copy of the
        time that is advanced during the startup phase.

        :raises ValueError: if the method needs starting values but no
            ``startup_parameters`` were given.
        """
        def copy_real(x):
            # Independent copy of a Constant or a Real-space Function.
            return Constant(x) if isinstance(x, Constant) else x.copy(deepcopy=True)

        if self.num_prev_steps == 1:
            # A one-step method needs no starting values (and no parameters).
            self.startup_t = copy_real(self.t)
            return
        if self.startup_parameters is None:
            # Previously this error was *returned*, silently doing nothing.
            raise ValueError('No startup parameters provided')

        butcher_tableau = self.startup_parameters.get('tableau', None)
        if isinstance(butcher_tableau, MultistepTableau):
            assert butcher_tableau.num_total_steps == 2, "Cannot use a multistep method to start a multistep method"
        stepper_kwargs = self.startup_parameters.get('stepper_kwargs', {})
        num_startup_steps = self.startup_parameters.get('num_startup_steps', 1)
        assert isinstance(num_startup_steps, int) and num_startup_steps > 0

        # delayed import to avoid a circular import
        from .stepper import TimeStepper

        if isinstance(self.dt, Constant):
            startup_dt = Constant(self.dt / num_startup_steps)
        else:
            startup_dt = self.dt.copy(deepcopy=True)
            startup_dt.assign(startup_dt / num_startup_steps)
        self.startup_t = copy_real(self.t)
        self.us[0].assign(self.u0)

        # Rewrite the form in terms of the startup clock so the original t is untouched.
        F_startup = replace(self.F, {self.t: self.startup_t})
        v, = F_startup.arguments()
        V = v.function_space()

        # grab a copy of the boundary conditions w.r.t. startup_t
        startup_bcs = []
        if self.orig_bcs is not None:
            for bc in self.orig_bcs:
                bcarg = as_ufl(bc._original_arg)
                bcarg_startup = replace(bcarg, {self.t: self.startup_t})
                bc_space = stage2spaces4bc(bc, V, V, 0)
                # NOTE(review): extend() relies on the reconstructed BC being
                # iterable — confirm against firedrake's BC API.
                startup_bcs.extend(bc.reconstruct(V=bc_space, g=bcarg_startup))

        self.startup_TS = TimeStepper(F_startup, butcher_tableau, self.startup_t, startup_dt, self.u0,
                                      bcs=startup_bcs, **stepper_kwargs)

        # advance the system and assign values to previous steps
        for i in range(self.num_prev_steps - 1):
            for _ in range(num_startup_steps):
                self.startup_TS.advance()
                self.startup_t.assign(self.startup_t + startup_dt)
            self.us[i + 1].assign(self.u0)

    def advance(self):
        """Advance the solution by one time step.

        Solves for the new value in ``us[-1]`` (which is ``u0``). Then it
        shifts the stored history down by one level and updates the solver
        statistics. ``t`` is not modified here.
        """
        self.solver.solve(bounds=self.bounds)
        # Shift oldest-first: us[i] <- us[i+1]. Each us[i+1] is read before it is overwritten.
        for older, newer in zip(self.us[:-1], self.us[1:]):
            older.assign(newer)
        # update solver statistics
        self.num_steps += 1
        self.num_nonlinear_iterations += self.solver.snes.getIterationNumber()
        self.num_linear_iterations += self.solver.snes.getLinearSolveIterations()

    def solver_stats(self):
        """Return ``(num_steps, num_nonlinear_iterations, num_linear_iterations)``."""
        return (self.num_steps, self.num_nonlinear_iterations, self.num_linear_iterations)
# Keyword-argument names that MultistepTimeStepper.__init__ accepts
# (Fp, bounds, startup_parameters); how callers consume this tuple is not
# shown in this file.
valid_multistep_kwargs = (
    "Fp",
    "bounds",
    "startup_parameters",
)