from functools import singledispatchmethod
from importlib import metadata
import sys
from ufl.constantvalue import as_ufl
from ufl.core.ufl_type import ufl_type
from ufl.corealg.dag_traverser import DAGTraverser
from ufl.algorithms.map_integrands import map_integrands
from ufl.algorithms.apply_derivatives import GenericDerivativeRuleset
from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering
from ufl.form import BaseForm
from ufl.classes import (Coefficient, Conj, Curl, ConstantValue, Derivative,
Div, Expr, Grad, Indexed, ReferenceGrad,
ReferenceValue, SpatialCoordinate, Variable)
from ufl.corealg.multifunction import MultiFunction
[docs]
class IrksomeImportOrderException(Exception):
    """Raised when a UFL multifunction has already run before Irksome was imported."""
[docs]
def check_irksome_import_order():
    """Check that irksome has been imported early enough.

    Due to the inadequacies of the UFL type system, it is not possible to
    define a new UFL type once any UFL MultiFunction has been used.
    This restriction can be removed once all MultiFunctions have been
    transitioned to ufl.corealg.dag_traverser.DAGTraverser.

    The number of cached multifunction handlers that is tolerated depends on
    whether Firedrake is installed and on its version, to account for the
    multifunctions that Firedrake/UFL instantiate themselves.  When any are
    tolerated, the handler cache is cleared and those known multifunction
    instances are rebuilt so they see the newly registered UFL types.

    :raises IrksomeImportOrderException: if more multifunction handlers have
        been cached than expected.
    """
    try:
        firedrake_version = metadata.version("firedrake")
    except metadata.PackageNotFoundError:
        # Firedrake is not installed, so the handler cache should be clean.
        # NOTE(review): this tests whether Firedrake is *installed*, not
        # whether it has been *imported* yet; confirm that is intended.
        expected_cache_size = 0
    else:
        # Compare release numbers numerically: a plain string comparison
        # mis-orders calendar versions (e.g. "2026.1.0" < "2026.04" is
        # False).  A version with no leading number is assumed current.
        release = _release_tuple(firedrake_version)
        if release and release < (2026, 4):
            expected_cache_size = 2
        else:
            expected_cache_size = 1

    if len(MultiFunction._handlers_cache) > expected_cache_size:
        raise IrksomeImportOrderException(
            "A UFL multifunction has already run.\n"
            "Irksome needs to be imported earlier.\n"
        )

    if expected_cache_size:
        # In the cases where Firedrake/UFL has already instantiated
        # multifunctions, clear the cache and reinstantiate them now that
        # TimeDerivative has been added to the UFL typecode list.
        MultiFunction._handlers_cache = {}
        if "ufl.formatting.ufl2unicode" in sys.modules:
            from ufl.formatting import ufl2unicode
            ufl2unicode._precrules = ufl2unicode.PrecedenceRules()
        if expected_cache_size > 1:
            # Older Firedrake releases also hold a multifunction instance on
            # ExtractSubBlock; rebuild it as well.
            from firedrake.formmanipulation import ExtractSubBlock
            ExtractSubBlock.index_inliner = ExtractSubBlock.IndexInliner()


def _release_tuple(version):
    """Return the leading numeric release components of ``version``.

    ``"2025.10.0"`` gives ``(2025, 10, 0)`` and ``"2026.04"`` gives
    ``(2026, 4)``.  Only the leading run of dot-separated integers is used,
    so suffixes such as ``rc1``, ``.dev0`` or ``+local`` are ignored.  An
    empty tuple is returned when ``version`` does not start with a digit.
    """
    import re

    match = re.match(r"[0-9]+(?:\.[0-9]+)*", version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split("."))
[docs]
@ufl_type(num_ops=1,
          inherit_shape_from_operand=0,
          inherit_indices_from_operand=0)
class TimeDerivative(Derivative):
    """UFL node standing for the time derivative of a single operand.

    The node has exactly one operand and takes its shape and free indices
    from it.  Form compilers do not currently understand these nodes, so
    Irksome pre-processes forms containing `TimeDerivative` nodes before
    they reach a form compiler.
    """

    __slots__ = ()

    def __new__(cls, f):
        # Allocation only needs the class; the operand is stored by __init__.
        return Derivative.__new__(cls)

    def __init__(self, f):
        operands = (f,)
        Derivative.__init__(self, operands)

    def __str__(self):
        operand = self.ufl_operands[0]
        return "d{" + str(operand) + "}/dt"
[docs]
def Dt(f, order=1):
    """Short-hand function to produce a :class:`TimeDerivative` of a given order.

    :arg f: the UFL expression to differentiate in time.
    :arg order: how many times to apply :class:`TimeDerivative`; must be a
        non-negative integer.  ``order=0`` returns ``f`` unchanged.
    :returns: ``f`` wrapped in ``order`` nested :class:`TimeDerivative` nodes.
    :raises ValueError: if ``order`` is negative (previously such an order
        was silently ignored and ``f`` was returned as-is).
    """
    if order < 0:
        raise ValueError(
            f"Time derivative order must be non-negative, got {order}")
    for _ in range(order):
        f = TimeDerivative(f)
    return f
[docs]
class TimeDerivativeRuleset(GenericDerivativeRuleset):
    """Apply AD rules to time derivative expressions.

    :arg t: optional expression identified (by object identity) with time;
        its time derivative is one.
    :arg timedep_coeffs: optional collection of terminals to treat as time
        dependent.  When ``None``, every :class:`~ufl.Coefficient` and
        :class:`~ufl.SpatialCoordinate` reached is treated as time dependent.
    """

    def __init__(self, t=None, timedep_coeffs=None):
        GenericDerivativeRuleset.__init__(self, ())
        self.t = t
        # Derivative of ``t`` with respect to itself.
        self._Id = as_ufl(1.0)
        self.timedep_coeffs = timedep_coeffs

    def _is_time(self, o):
        """Return True when ``o`` is the very object registered as time."""
        return self.t is not None and o is self.t

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o):
        """Fall back to the generic derivative rules for unregistered types."""
        return super().process(o)

    @process.register(ConstantValue)
    def constant(self, o):
        """Differentiate a constant: one if it is ``t``, otherwise independent."""
        return self._Id if self._is_time(o) else self.independent_terminal(o)

    @process.register(Coefficient)
    @process.register(SpatialCoordinate)
    def terminal(self, o):
        """Differentiate a coefficient or spatial coordinate.

        ``t`` itself differentiates to one; time-dependent terminals become
        ``TimeDerivative(o)``; anything else is treated as independent.
        """
        if self._is_time(o):
            return self._Id
        time_dependent = self.timedep_coeffs is None or o in self.timedep_coeffs
        if not time_dependent:
            return self.independent_terminal(o)
        return TimeDerivative(o)

    @process.register(TimeDerivative)
    @DAGTraverser.postorder
    def time_derivative(self, o, f):
        """Differentiate a node that is already a time derivative.

        ``f`` is the operand of ``o`` after processing by this ruleset.  If it
        is still a bare ``TimeDerivative`` it is wrapped once more; otherwise
        the rules are applied to it again.
        """
        if not isinstance(f, TimeDerivative):
            return self(f)
        return TimeDerivative(f)

    @process.register(Conj)
    @process.register(Curl)
    @process.register(Derivative)
    @process.register(Div)
    @process.register(Grad)
    @process.register(Indexed)
    @process.register(ReferenceGrad)
    @process.register(ReferenceValue)
    @process.register(Variable)
    @DAGTraverser.postorder
    def terminal_modifier(self, o, *operands):
        """Rebuild ``o`` of the same type around its processed operands."""
        return o._ufl_expr_reconstruct_(*operands)
[docs]
class TimeDerivativeRuleDispatcher(DAGTraverser):
    """Traverse an expression and splat out every :class:`TimeDerivative`.

    Expanding time derivatives this way lets later replacement work on more
    complex problems.  Each ``TimeDerivative`` node is replaced by the result
    of running a :class:`TimeDerivativeRuleset` on its operand.  Every other
    expression or form node goes through ``reuse_if_untouched``.

    :arg t: forwarded to :class:`TimeDerivativeRuleset`.
    :arg timedep_coeffs: forwarded to :class:`TimeDerivativeRuleset`.
    :arg kwargs: forwarded to :class:`~ufl.corealg.dag_traverser.DAGTraverser`.
    """

    def __init__(self, t=None, timedep_coeffs=None, **kwargs):
        super().__init__(**kwargs)
        self.rules = TimeDerivativeRuleset(t=t, timedep_coeffs=timedep_coeffs)

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o):
        """Defer to :class:`DAGTraverser` for unregistered types."""
        return super().process(o)

    @process.register(TimeDerivative)
    def time_derivative(self, o):
        """Replace ``Dt(f)`` by the expanded time derivative of ``f``."""
        (operand,) = o.ufl_operands
        return self.rules(operand)

    @process.register(Expr)
    @process.register(BaseForm)
    def _generic(self, o):
        """Hand every other node to ``reuse_if_untouched``."""
        return self.reuse_if_untouched(o)
[docs]
def apply_time_derivatives(expression, t=None, timedep_coeffs=None):
    """Expand the :class:`TimeDerivative` nodes in ``expression``.

    :arg expression: UFL expression or form whose integrands are processed.
    :arg t: optional time variable, matched by object identity.
    :arg timedep_coeffs: optional collection of time-dependent terminals;
        ``None`` treats every coefficient and spatial coordinate as time
        dependent.
    :returns: the result of mapping a :class:`TimeDerivativeRuleDispatcher`
        over the integrands of ``expression`` with ``map_integrands``.
    """
    return map_integrands(
        TimeDerivativeRuleDispatcher(t=t, timedep_coeffs=timedep_coeffs),
        expression,
    )
[docs]
def expand_time_derivatives(expression, t=None, timedep_coeffs=None):
    """Lower compound algebra in ``expression``, then expand its time derivatives.

    Lowering runs first, presumably so that compound operators reach the
    derivative rules in index notation (TODO confirm).  The lowered result
    is then passed to :func:`apply_time_derivatives` with the same ``t`` and
    ``timedep_coeffs``.
    """
    lowered = apply_algebra_lowering(expression)
    return apply_time_derivatives(lowered, t=t, timedep_coeffs=timedep_coeffs)