Source code for firedrake.fml.form_manipulation_language

"""A language for manipulating forms using labels."""

import ufl
import functools
import operator
from firedrake import Constant, Function
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union


__all__ = ["Label", "Term", "LabelledForm", "identity", "drop", "all_terms",
           "keep", "subject", "name_label"]

# ---------------------------------------------------------------------------- #
# Core routines for filtering terms
# ---------------------------------------------------------------------------- #


[docs] def identity(t: "Term") -> "Term": """ The identity map. Parameters ---------- t A term. Returns ------- Term The same term. """ return t
[docs] def drop(t: "Term") -> None: """Map all terms to ``None``. Parameters ---------- t A term. Returns ------- None None. """ return None
[docs] def keep(t: "Term") -> "Term": """Keep all terms. Functionally equivalent to identity. Parameters ---------- t A term. Returns ------- Term The same term. """ return t
[docs] def all_terms(t: "Term") -> bool: """Map all terms to ``True``. Parameters ---------- t A term. Returns ------- bool True. """ return True
# ---------------------------------------------------------------------------- # # Term class # ---------------------------------------------------------------------------- #
[docs] class Term(object): """A Term object contains a form and its labels.""" __slots__ = ["form", "labels"] def __init__(self, form: ufl.Form, label_dict: Mapping = None): """ Parameters ---------- form The form for this terms. label_dict Dictionary of key-value pairs corresponding to current form labels. Defaults to None. """ self.form = form self.labels = label_dict or {}
[docs] def get(self, label: "Label") -> Any: """Return the value of a label. Parameters ---------- label The label to return the value of. Returns ------- Any The value of a label. """ return self.labels.get(label.label)
[docs] def has_label( self, *labels: "Label", return_tuple: bool = False ) -> Union[Tuple[bool], bool]: """Return whether the specified labels are attached to this term. Parameters ---------- *labels A label or series of labels. A tuple is automatically returned if multiple labels are provided as arguments. return_tuple If True, forces a tuple to be returned even if only one label is provided as an argument. Defaults to False. Returns ------- bool Booleans corresponding to whether the term has the specified labels. """ if len(labels) == 1 and not return_tuple: return labels[0].label in self.labels else: return tuple(self.has_label(l) for l in labels)
def __add__(self, other: Union["Term", "LabelledForm"]) -> "LabelledForm": """Add a term or labelled form to this term. Parameters ---------- other The term or labelled form to add to this term. Returns ------- LabelledForm A labelled form containing the terms. """ if self is NullTerm: return other if other is None or other is NullTerm: return self elif isinstance(other, Term): return LabelledForm(self, other) elif isinstance(other, LabelledForm): return LabelledForm(self, *other.terms) else: return NotImplemented __radd__ = __add__ def __sub__(self, other: Union["Term", "LabelledForm"]) -> "LabelledForm": """Subtract a term or labelled form from this term. Parameters ---------- other The term or labelled form to subtract from this term. Returns ------- LabelledForm A labelled form containing the terms. """ other = other * Constant(-1.0) return self + other def __mul__( self, other: Union[float, Constant, ufl.algebra.Product] ) -> "Term": """Multiply this term by another quantity. Parameters ---------- other The quantity to multiply this term by. Returns ------- Term The product of the term with the quantity. """ return Term(other*self.form, self.labels) __rmul__ = __mul__ def __truediv__( self, other: Union[float, Constant, ufl.algebra.Product] ) -> "Term": """Divide this term by another quantity. Parameters ---------- other The quantity to divide this term by. Returns ------- Term The quotient of the term divided by the quantity. """ return self * (Constant(1.0) / other)
# This is necessary to be the initialiser for functools.reduce NullTerm = Term(None) # ---------------------------------------------------------------------------- # # Labelled form class # ---------------------------------------------------------------------------- #
[docs] class LabelledForm(object): """ A form, broken down into terms that pair individual forms with labels. The LabelledForm object holds a list of terms, which pair :class:`ufl.Form` objects with :class:`Label` s. The label_map routine allows the terms to be manipulated or selected based on particular filters. """ __slots__ = ["terms"] def __init__(self, *terms: Sequence[Term]): """ Parameters ---------- *terms : Term Terms to combine to make the LabelledForm. Raises ------ TypeError: If any argument is not a term. """ if len(terms) == 1 and isinstance(terms[0], LabelledForm): self.terms = terms[0].terms else: if any([type(term) is not Term for term in list(terms)]): raise TypeError('Can only pass terms or a LabelledForm to LabelledForm') self.terms = list(terms) def __add__( self, other: Union[ufl.Form, Term, "LabelledForm"] ) -> "LabelledForm": """Add a form, term or labelled form to this labelled form. Parameters ---------- other The form, term or labelled form to add to this labelled form. Returns ------- LabelledForm A labelled form containing the terms. """ if isinstance(other, ufl.Form): return LabelledForm(*self, Term(other)) elif type(other) is Term: return LabelledForm(*self, other) elif type(other) is LabelledForm: return LabelledForm(*self, *other) elif other is None: return self else: return NotImplemented __radd__ = __add__ def __sub__( self, other: Union[ufl.Form, Term, "LabelledForm"] ) -> "LabelledForm": """Subtract a form, term or labelled form from this labelled form. Parameters ---------- other The form, term or labelled form to subtract from this labelled form. Returns ------- LabelledForm A labelled form containing the terms. """ if type(other) is Term: return LabelledForm(*self, Constant(-1.)*other) elif type(other) is LabelledForm: return LabelledForm(*self, *[Constant(-1.)*t for t in other]) elif other is None: return self else: # Make new Term for other and subtract it return LabelledForm(*self, Term(Constant(-1.)*other)) def __mul__( self, other: Union[float, Constant, ufl.algebra.Product] ) -> "LabelledForm": """Multiply this labelled form by another quantity. Parameters ---------- other The quantity to multiply this labelled form by. All terms in the form are multiplied. Returns ------- LabelledForm The product of all terms with the quantity. """ return self.label_map(all_terms, lambda t: Term(other*t.form, t.labels)) def __truediv__( self, other: Union[float, Constant, ufl.algebra.Product] ) -> "LabelledForm": """Divide this labelled form by another quantity. Parameters ---------- other The quantity to divide this labelled form by. All terms in the form are divided. Returns ------- LabelledForm The quotient of all terms with the quantity. """ return self * (Constant(1.0) / other) __rmul__ = __mul__ def __iter__(self) -> Sequence: """Iterable of the terms in the labelled form.""" return iter(self.terms) def __len__(self) -> int: """Number of terms in the labelled form.""" return len(self.terms)
[docs] def label_map( self, term_filter: Callable[[Term], bool], map_if_true: Callable[[Term], Optional[Term]] = identity, map_if_false: Callable[[Term], Optional[Term]] = identity ) -> "LabelledForm": """Map selected terms in the labelled form, returning a new labelled form. Parameters ---------- term_filter A function to filter the labelled form's terms. map_if_true How to map the terms for which the term_filter returns True. Defaults to identity. map_if_false How to map the terms for which the term_filter returns False. Defaults to identity. Returns ------- LabelledForm A new labelled form with the terms mapped. """ # FIXME: The rendered docstring for this method is a mess, the lambda # hackery at the top goes some way to fix this, but this is probably a # bug in napoleon. new_labelled_form = LabelledForm( functools.reduce(operator.add, filter(lambda t: t is not None, (map_if_true(t) if term_filter(t) else map_if_false(t) for t in self.terms)), # Need to set an initialiser, otherwise the label_map # won't work if the term_filter is False for everything # None does not work, as then we add Terms to None # and the addition operation is defined from None # rather than the Term. NullTerm solves this. NullTerm)) # Drop the NullTerm new_labelled_form.terms = list(filter(lambda t: t is not NullTerm, new_labelled_form.terms)) return new_labelled_form
@property def form(self) -> ufl.Form: """Provide the whole form from the labelled form. Raises ------ TypeError If the labelled form has no terms. Returns ------- ufl.Form The whole form corresponding to all the terms. """ # Throw an error if there is no form if len(self.terms) == 0: raise TypeError('The labelled form cannot return a form as it has no terms') else: return functools.reduce(operator.add, (t.form for t in self.terms))
[docs] class Label(object): """Object for tagging forms, allowing them to be manipulated.""" __slots__ = ["label", "default_value", "value", "validator"] def __init__( self, label, *, value: Any = True, validator: Optional[Callable] = None ): """ Parameters ---------- label The name of the label. value The value for the label to take. Can be any type (subject to the validator). Defaults to True. validator Function to check the validity of any value later passed to the label. Defaults to None. """ self.label = label self.default_value = value self.validator = validator
[docs] def __call__( self, target: Union[ufl.Form, Term, LabelledForm], value: Any = None ) -> Union[Term, LabelledForm]: """Apply the label to a form or term. Parameters ---------- target The form, term or labelled form to be labelled. value The value to attach to this label. Defaults to None. Raises ------ ValueError If the `target` is not a ufl.Form, Term or LabelledForm. Returns ------- Union[Term, LabelledForm] A Term is returned if the target is a Term, otherwise a LabelledForm is returned. """ # if value is provided, check that we have a validator function # and validate the value, otherwise use default value if value is not None: assert self.validator, f'Label {self.label} requires a validator' assert self.validator(value), f'Value {value} for label {self.label} does not satisfy validator' self.value = value else: self.value = self.default_value if isinstance(target, LabelledForm): return LabelledForm(*(self(t, value) for t in target.terms)) elif isinstance(target, ufl.Form): return LabelledForm(Term(target, {self.label: self.value})) elif isinstance(target, Term): new_labels = target.labels.copy() new_labels.update({self.label: self.value}) return Term(target.form, new_labels) else: raise ValueError("Unable to label %s" % target)
[docs] def remove(self, target: Union[Term, LabelledForm]): """Remove a label from a term or labelled form. This removes any Label with this ``label`` from ``target``. If called on an LabelledForm, it acts term-wise. Parameters ---------- target Term or labelled form to have this label removed from. Raises ------ ValueError If the `target` is not a Term or a LabelledForm. """ if isinstance(target, LabelledForm): return LabelledForm(*(self.remove(t) for t in target.terms)) elif isinstance(target, Term): try: d = target.labels.copy() d.pop(self.label) return Term(target.form, d) except KeyError: return target else: raise ValueError("Unable to unlabel %s" % target)
[docs] def update_value(self, target: Union[Term, LabelledForm], new: Any): """Update the label of a term or labelled form. This updates the value of any Label with this ``label`` from ``target``. If called on an LabelledForm, it acts term-wise. Parameters ---------- target Term or labelled form to have this label updated. new The new value for this label to take. The type is subject to the label's validator (if it has one). Raises ------ ValueError If the `target` is not a Term or a LabelledForm. """ if isinstance(target, LabelledForm): return LabelledForm(*(self.update_value(t, new) for t in target.terms)) elif isinstance(target, Term): try: d = target.labels.copy() d[self.label] = new return Term(target.form, d) except KeyError: return target else: raise ValueError("Unable to relabel %s" % target)
# ---------------------------------------------------------------------------- # # Some common labels # ---------------------------------------------------------------------------- # subject = Label("subject", validator=lambda value: type(value) == Function) name_label = Label("name", validator=lambda value: type(value) == str)