# Source code for irksome.stage_value
# Formulate RK methods to solve for stage values rather than the stage derivatives.
import numpy
from FIAT import Bernstein, ufc_simplex
from FIAT.barycentric_interpolation import LagrangePolynomialSet
from ufl import Form, as_tensor, as_ufl
from .tableaux.ButcherTableaux import CollocationButcherTableau
from .ufl.deriv import expand_time_derivatives
from .ufl.manipulation import (has_nonlinear_time_derivative,
split_time_derivative_terms,
remove_time_derivatives)
from .tools import AI, extract_timedep_arguments, dot, reshape, replace
from .constant import vecconst
from .base_time_stepper import StageCoupledTimeStepper
from .backend import get_backend
[docs]
def to_value(u0, stages, vandermonde):
    """Map the unknown stage coefficients to values at the stages.

    When a change of basis is in play (e.g. Bernstein), the full
    coefficient vector is [u0; ZZ] and the corresponding Lagrange
    values are [u0; UU]: the value at the left endpoint is the same in
    both representations.  Because u0 is not one of the unknowns, the
    first row of the Vandermonde matrix (which is [1, 0, ...]) is
    dropped and only the remaining rows are applied.

    :arg u0: the current solution; its ``ufl_shape`` fixes the shape of
         each stage entry.
    :arg stages: the unknown stage coefficients.
    :arg vandermonde: change-of-basis matrix, or None to work directly
         in the Lagrange basis.
    :returns: the reshaped stages if ``vandermonde`` is None, otherwise
         the result of applying ``vandermonde[1:]`` to [u0; stages].
    """
    # One leading axis enumerating the stages, followed by the field shape.
    target_shape = (-1, *u0.ufl_shape)
    coeffs = reshape(stages, target_shape)

    if vandermonde is not None:
        initial = reshape(u0, target_shape)
        all_coeffs = numpy.concatenate((initial, coeffs))
        # Skip the first row: it only reproduces u0.
        return dot(vandermonde[1:], all_coeffs)

    # Lagrange basis: coefficients already are the stage values.
    return coeffs
[docs]
def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=AI, vandermonde=None, aux_indices=None, backend: str = "firedrake"):
    """Given a time-dependent variational form and a
    :class:`ButcherTableau`, produce UFL for the s-stage RK method.

    :arg F: a :class:`ufl.Form` instance describing the semi-discrete problem.
    :arg butch: the :class:`ButcherTableau` for the RK method being used to
         advance in time.
    :arg t: a :class:`Constant` or :class:`Function`
         on the Real space over the same mesh as `u0`.  This serves as
         a variable referring to the current time.
    :arg dt: a :class:`Constant` or :class:`Function`
         on the Real space over the same mesh as `u0`.  This serves as
         a variable referring to the current time step size.
         The user may adjust this value between time steps.
    :arg u0: a :class:`Function` referring to the state of
         the PDE system at time `t`
    :arg stages: a :class:`Function` representing the stages to be solved for.
         It lives in a :class:`FunctionSpace` corresponding to the
         s-way tensor product of the space on which the semidiscrete
         form lives.
    :kwarg bcs: optionally, a :class:`DirichletBC` object (or iterable thereof)
         containing (possibly time-dependent) boundary conditions imposed
         on the system.
    :kwarg splitting: a callable that maps the (floating point) Butcher matrix
         a to a pair of matrices `A1, A2` such that `butch.A = A1 A2`.  This is used
         to vary between the classical RK formulation and Butcher's reformulation
         that leads to a denser mass matrix with block-diagonal stiffness.
         Only `AI` and `IA` are currently supported.
    :kwarg vandermonde: a numpy array encoding a change of basis to the Lagrange
         polynomials associated with the collocation nodes from some other
         (e.g. Bernstein or Chebyshev) basis.  This allows us to solve for the
         coefficients in some basis rather than the values at particular stages,
         which can be useful for satisfying bounds constraints.
         If none is provided, we assume it is the identity, working in the
         Lagrange basis.
    :kwarg aux_indices: a list of field indices, currently ignored.
    :kwarg backend: name of the backend, handed to ``get_backend`` and to
         ``vecconst``.  Defaults to ``"firedrake"``.

    :returns: a 2-tuple of
       - `Fnew`, the :class:`Form`
       - `bcnew`, a list of :class:`DirichletBC` objects to be posed
         on the stages
    :raises NotImplementedError: if the factor `A2` returned by `splitting`
         is singular.
    """
    v, u = extract_timedep_arguments(F, u0)
    backend_cls = get_backend(backend)
    V = backend_cls.get_function_space(v)
    assert V == backend_cls.get_function_space(u0)

    c = vecconst(butch.c, backend=backend)
    bA1, bA2 = splitting(butch.A)
    try:
        bA2inv = numpy.linalg.inv(bA2)
    except numpy.linalg.LinAlgError as err:
        # Keep the numerical failure attached as the cause of the error.
        raise NotImplementedError("We require A = A1 A2 with A2 invertible") from err
    A1 = vecconst(bA1, backend=backend)
    A2inv = vecconst(bA2inv, backend=backend)

    # s-way product space for the stage variables
    num_stages = butch.num_stages
    Vbig = stages.function_space()
    test = backend_cls.TestFunction(Vbig)

    # set up the pieces we need to work with to do our substitutions
    v_np = reshape(test, (num_stages, *v.ufl_shape))
    w_np = to_value(u0, stages, vandermonde)
    A1Tv = dot(A1.T, v_np)
    A2invTv = dot(A2inv.T, v_np)

    # First, process terms with a time derivative.  We assume something of
    # the form inner(Dt(g(u0)), v)*dx.  With the time derivative stripped,
    # each stage i contributes g(stage value i) - g(u0), tested against
    # row i of A2^{-T} v.  There is no division by dt here; instead the
    # remaining terms below are tested against dt * (A1^T v).
    split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u,))
    F_dtless = remove_time_derivatives(split_form.time)
    F_remainder = expand_time_derivatives(split_form.remainder, t=t, timedep_coeffs=())

    Fnew = Form([])

    # Terms with time derivatives: use two evaluations so that
    # Dt(g(u)) is discretised as g(U_i) - g(u0), not g(U_i - u0).
    # These are identical for linear g but differ for nonlinear g,
    # and the two-evaluation form is what gives mass conservation.
    for i in range(num_stages):
        repl_new = {t: t + c[i] * dt,
                    v: A2invTv[i],
                    u: w_np[i]}
        # Evaluate g at the old solution u0 (not substituted) and
        # old time t (not substituted).
        repl_old = {v: A2invTv[i], u: u0}
        Fnew += replace(F_dtless, repl_new) - replace(F_dtless, repl_old)

    # Handle the rest of the terms
    for i in range(num_stages):
        # replace the solution with stage values
        repl = {t: t + c[i] * dt,
                v: A1Tv[i] * dt,
                u: w_np[i]}
        Fnew += replace(F_remainder, repl)

    if bcs is None:
        bcs = []
    bcsnew = []

    if vandermonde is not None:
        Vander_inv = vecconst(numpy.linalg.inv(vandermonde.astype(float)), backend=backend)

    # For each BC, we need a new BC for each stage
    # so we need to figure out how the function is indexed (mixed + vec)
    # and then set it to have the value of the original argument at
    # time t+C[i]*dt.
    for bc in bcs:
        bcarg = as_ufl(bc._original_arg)
        g_np = numpy.array([replace(bcarg, {t: t + ci * dt}) for ci in c])
        if vandermonde is not None:
            # Move the known u0 contribution (first Vandermonde column) to
            # the right-hand side, then convert the stage values of the
            # boundary data into coefficients in the working basis.
            g_np -= vandermonde[1:, 0] * bcarg
            g_np = Vander_inv[1:, 1:] @ g_np

        for i in range(num_stages):
            Vbigi = backend_cls.stage2spaces4bc(bc, V, Vbig, i)
            bcsnew.extend(bc.reconstruct(V=Vbigi, g=as_tensor(g_np[i])))

    return Fnew, bcsnew
[docs]
class StageValueTimeStepper(StageCoupledTimeStepper):
    """Time stepper that solves for RK stage values rather than stage
    derivatives.

    The stage system is built by :func:`getFormStage`.  After the stage
    solve, ``u0`` is advanced by one of several update rules selected in
    ``__init__``:

    - the terminal value of the collocation polynomial, if
      ``use_collocation_update`` is set;
    - for tableaux that are not stiffly accurate and with no change of
      basis, either a variational update solve (see
      :meth:`get_update_solver`) or the algebraic ``b^T A^{-1}`` shortcut;
    - otherwise, a copy of the last stage.

    :arg F: the semi-discrete variational form.
    :arg butcher_tableau: the Butcher tableau of the RK method.
    :arg t: the variable referring to the current time.
    :arg dt: the variable referring to the current time step size.
    :arg u0: the current solution, overwritten on update.
    :kwarg bcs: boundary conditions, forwarded to the base class.
    :kwarg solver_parameters: forwarded to the base class.
    :kwarg update_solver_parameters: solver parameters for the update
         solve built by :meth:`get_update_solver`, when one is needed.
    :kwarg splitting: forwarded to the base class; read back as
         ``self.splitting`` when building the stage form.
    :kwarg basis_type: ``None`` or ``"Lagrange"`` for no change of basis;
         ``"Bernstein"`` is the only other type :meth:`tabulate_poly`
         implements.
    :kwarg appctx: forwarded to the base class.
    :kwarg bounds: forwarded to the base class.
    :kwarg use_collocation_update: update ``u0`` with the terminal value
         of the collocation polynomial.
    :kwarg sample_points: forwarded to the base class.
    :kwarg backend: name of the backend; forwarded to the base class, to
         ``vecconst`` and to :func:`getFormStage`.
    """

    # Backend name used when building forms and constants.  The class-level
    # default matches the default of ``__init__``; ``__init__`` overrides it
    # per instance.
    _backend_name = "firedrake"

    def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
                 solver_parameters=None,
                 update_solver_parameters=None,
                 splitting=AI, basis_type=None,
                 appctx=None, bounds=None,
                 use_collocation_update=False,
                 sample_points=None,
                 backend: str = "firedrake",
                 **kwargs):
        # Record the backend name first: tabulate_poly (below) and
        # get_form_and_bcs both need it to build objects for the same
        # backend the caller requested.
        # NOTE(review): get_form_and_bcs is presumably invoked from the base
        # class constructor -- hence this must precede super().__init__.
        self._backend_name = backend
        self.butcher_tableau = butcher_tableau
        self.basis_type = basis_type

        if basis_type is None or basis_type == 'Lagrange':
            vandermonde = None
        else:
            # Tabulate the chosen basis at the left endpoint plus the
            # collocation nodes.
            nodes = numpy.insert(butcher_tableau.c, 0, 0.0)
            pts = numpy.reshape(nodes, (-1, 1))
            vandermonde = self.tabulate_poly(pts).T
        self.vandermonde = vandermonde

        super().__init__(F, t, dt, u0, butcher_tableau.num_stages, bcs=bcs,
                         solver_parameters=solver_parameters,
                         appctx=appctx,
                         splitting=splitting, butcher_tableau=butcher_tableau, bounds=bounds,
                         sample_points=sample_points, backend=backend,
                         **kwargs)

        self.num_fields = len(self._backend.get_function_space(u0))
        self.set_initial_guess()

        if use_collocation_update:
            # Use the terminal value of the collocation polynomial to update the solution.
            # Note: collocation update is only implemented for constant-in-time boundary conditions.
            # TODO: create an assertion to check for constant-in-time boundary conditions.
            self.collocation_vander = self.tabulate_poly((1.0,))
            self._update = self._update_collocation
        elif (not butcher_tableau.is_stiffly_accurate) and (vandermonde is None):
            # Conservative variational update is needed only when Dt's
            # argument is nonlinear in u0; for any g linear in u0
            # (g = c*u, g = c(x)*u, g = M*u; affine g = u + f(t,x) too,
            # with the f(t,x) piece handled by the remainder via the
            # Dt-split) the bAinv shortcut commutes with g and is exact.
            # It is also the only correct path under DAE structure: the
            # conservative variational head reduces to 0 on algebraic
            # blocks where Dt is absent, so it does not determine u_new
            # there.
            if has_nonlinear_time_derivative(F, u0):
                self.unew, self.update_solver = self.get_update_solver(update_solver_parameters)
                self._update = self._update_general
            else:
                A = butcher_tableau.A
                b = butcher_tableau.b
                try:
                    # Only the linear solve can raise LinAlgError.
                    bAinv = numpy.linalg.solve(A.T, b)
                except numpy.linalg.LinAlgError:
                    # Singular A: fall back to the variational update.
                    self.unew, self.update_solver = self.get_update_solver(update_solver_parameters)
                    self._update = self._update_general
                else:
                    self.bAinv = vecconst(bAinv, backend=backend)
                    self.update_scale = 1-numpy.sum(self.bAinv)
                    self._update = self._update_Ainv
        else:
            self._update = self._update_stiff_acc

    def _update_Ainv(self):
        """Update ``u0`` in place as
        ``update_scale * u0 + sum_s bAinv[s] * stage_s`` for each field."""
        nf = self.num_fields
        ns = self.num_stages
        scale = self.update_scale
        bAinv = self.bAinv
        for i, u0bit in enumerate(self.u0.subfunctions):
            u0bit *= scale
            # Stages are stored stage-major: field i of stage s sits at nf*s + i.
            u0bit += sum(self.stages.subfunctions[nf * s + i] * bAinv[s] for s in range(ns))

    def _update_stiff_acc(self):
        """Update ``u0`` by copying each field of the last stage into it."""
        for i, u0bit in enumerate(self.u0.subfunctions):
            u0bit.assign(self.stages.subfunctions[self.num_fields*(self.num_stages-1)+i])

    def get_update_solver(self, update_solver_parameters):
        """Build a conservative variational update solve for u_new.

        For a mass term ``inner(Dt(g(u)), v) * dx`` the update head is

            inner(g(u_new) - g(u_0), v) * dx

        evaluated at the stage-solve test function ``v``.  For
        ``g = identity`` it reduces to ``inner(u_new - u_0, v) * dx``,
        so the discrete update equation is unchanged in the linear
        case.  The remaining (non-time-derivative) part of the form is
        contributed by the standard RK quadrature
        ``sum_i b_i * F_remainder(stage_i)``, scaled by ``dt``.

        ``update_solver_parameters`` does not inherit from
        ``solver_parameters``.  The update solve is a different
        problem from the stage solve -- it is posed on ``V`` rather
        than ``V^s = V x ... x V``, and its Jacobian is a (nonlinear)
        weighted mass matrix rather than the stage operator.  Stage-
        tuned options such as fieldsplit indices, ``snes_type='ksponly'``,
        lagged Jacobians, or custom multigrid transfers generally do
        not apply.  If ``update_solver_parameters`` is None, Firedrake's
        default solver parameters are used (typically a sparse direct
        solve).  Pass an explicit dict to override.

        :arg update_solver_parameters: solver parameters for the update
             solve, or None.
        :returns: a 2-tuple ``(unew, update_solver)`` of the function the
             update is solved into and the solver that computes it.
        """
        # only form update stuff if we need it
        # which means neither stiffly accurate nor Vandermonde
        backend_cls = self._backend
        # Build the tableau constants for the same backend as everything else.
        C = vecconst(self.butcher_tableau.c, backend=self._backend_name)
        B = vecconst(self.butcher_tableau.b, backend=self._backend_name)

        F = self.F
        t = self.t
        dt = self.dt
        u0 = self.u0

        v, u = extract_timedep_arguments(F, u0)
        unew = backend_cls.Function(backend_cls.get_function_space(u))

        split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u,))
        F_dtless = remove_time_derivatives(split_form.time)
        F_remainder = expand_time_derivatives(split_form.remainder, t=t, timedep_coeffs=())

        # Head of the update: g(u_new) - g(u_0) in the time-derivative terms.
        Fupdate = replace(F_dtless, {u: unew}) - replace(F_dtless, {u: u0})

        u_np = to_value(u0, self.stages, self.vandermonde)
        for i in range(self.num_stages):
            repl = {t: t + C[i] * dt,
                    u: u_np[i]}
            Fupdate += dt * B[i] * replace(F_remainder, repl)

        # And the BC's for the update -- just the original BC at t+dt
        update_bcs = []
        for bc in self.orig_bcs:
            bcarg = as_ufl(bc._original_arg)
            gcur = replace(bcarg, {t: t + dt})
            update_bcs.append(bc.reconstruct(g=gcur))

        update_problem = backend_cls.create_variational_problem(Fupdate, unew, update_bcs)
        update_solver = backend_cls.create_variational_solver(update_problem, solver_parameters=update_solver_parameters)
        return unew, update_solver

    def _update_general(self):
        """Update ``u0`` by running the variational update solve."""
        # Constant-in-time initial guess to prevent singular Jacobian
        self.unew.assign(self.u0)
        self.update_solver.solve()
        self.u0.assign(self.unew)

    def _update_collocation(self):
        """Update ``u0`` by combining ``u0`` and the stages with the
        weights in ``collocation_vander`` (the basis tabulated at 1.0)."""
        stage_vals = numpy.array(self.u0.subfunctions + self.stages.subfunctions, dtype=object)
        for i, u0bit in enumerate(self.u0.subfunctions):
            # Every num_fields-th entry starting at i: field i of u0, then of each stage.
            u0bit.assign(stage_vals[i::self.num_fields] @ self.collocation_vander)

    def get_form_and_bcs(self, stages, F=None, bcs=None, tableau=None):
        """Build the stage form and stage boundary conditions.

        :arg stages: the function holding the unknown stages.
        :kwarg F: form to discretise; ``self.F`` is used if this is falsy.
        :kwarg bcs: boundary conditions; ``self.orig_bcs`` if None.
        :kwarg tableau: Butcher tableau; ``self.butcher_tableau`` is used
             if this is falsy.
        :returns: the 2-tuple returned by :func:`getFormStage`.
        """
        if bcs is None:
            bcs = self.orig_bcs
        return getFormStage(F or self.F,
                            tableau or self.butcher_tableau,
                            self.t, self.dt, self.u0,
                            stages, bcs=bcs,
                            splitting=self.splitting,
                            vandermonde=self.vandermonde,
                            backend=self._backend_name)

    def set_initial_guess(self):
        """Set a constant-in-time initial guess"""
        for k in range(self.num_stages):
            for i, u0bit in enumerate(self.u0.subfunctions):
                sbit = self.stages.subfunctions[self.num_fields * k + i]
                sbit.assign(u0bit)

    def tabulate_poly(self, sample_points):
        """Tabulate the time basis at the given sample points.

        The basis is built on the nodes ``[0, c_1, ..., c_s]`` (Lagrange)
        or is the Bernstein basis of degree ``num_stages``, depending on
        ``self.basis_type``.

        :arg sample_points: the points at which to tabulate.
        :returns: the zeroth-derivative tabulation, wrapped by ``vecconst``.
        :raises ValueError: if the tableau is not a
             :class:`CollocationButcherTableau`, if the nodes
             ``[0, c_1, ..., c_s]`` are not distinct, or if
             ``self.basis_type`` is not implemented.
        """
        if not isinstance(self.butcher_tableau, CollocationButcherTableau):
            raise ValueError("Need a collocation method to evaluate the collocation polynomial")

        nodes = numpy.insert(self.butcher_tableau.c, 0, 0.0)
        if len(set(nodes)) != len(nodes):
            raise ValueError("Need non-confluent collocation method for polynomial evaluation")

        ref_el = ufc_simplex(1)
        if self.basis_type is None or self.basis_type == "Lagrange":
            lag_basis = LagrangePolynomialSet(ref_el, nodes)
            vander = vecconst(lag_basis.tabulate(sample_points, 0)[(0,)], backend=self._backend_name)
        elif self.basis_type == "Bernstein":
            bern_element = Bernstein(ref_el, self.butcher_tableau.num_stages)
            vander = vecconst(bern_element.tabulate(0, sample_points)[(0,)], backend=self._backend_name)
        else:
            raise ValueError(f"Unknown or unimplemented basis transformation type {self.basis_type}")
        return vander