# Source code for gem.optimise

"""A set of routines implementing various transformations on GEM
expressions."""

from collections import OrderedDict, defaultdict
from functools import singledispatch, partial, reduce
from itertools import combinations, permutations, zip_longest
from numbers import Integral

import numpy

from gem.utils import groupby
from gem.node import (Memoizer, MemoizerArg, reuse_if_untouched,
                      reuse_if_untouched_arg, traversal)
from gem.gem import (Node, Failure, Identity, Literal, Zero,
                     Product, Sum, Comparison, Conditional, Division,
                     Index, VariableIndex, Indexed, FlexiblyIndexed,
                     IndexSum, ComponentTensor, ListTensor, Delta,
                     partial_indexed, one)


@singledispatch
def literal_rounding(node, self):
    """Perform FFC rounding of FIAT tabulation matrices on the literals
    of a GEM expression.

    :arg node: root of the expression
    :arg self: function for recursive calls
    """
    # Dispatch fallback: no handler registered for this node type.
    raise AssertionError("cannot handle type %s" % type(node))


# Generic nodes carry no literals themselves: rebuild only if a child changed.
literal_rounding.register(Node)(reuse_if_untouched)
@literal_rounding.register(Literal)
def literal_rounding_literal(node, self):
    """Round Literal entries that lie within ``self.epsilon`` of their
    one-decimal rounding, mimicking COFFEE/FFC output formatting."""
    values = node.array
    tol = self.epsilon
    # Mimic the rounding applied at COFFEE formatting, which in turn
    # mimics FFC formatting.
    rounded = numpy.asarray(numpy.round(values, 1))
    rounded[numpy.logical_not(rounded)] = 0  # no minus zeros
    close_enough = abs(values - rounded) < tol
    return Literal(numpy.where(close_enough, rounded, values))
def ffc_rounding(expression, epsilon):
    """Perform FFC rounding of FIAT tabulation matrices on the literals
    of a GEM expression.

    :arg expression: GEM expression
    :arg epsilon: tolerance limit for rounding
    """
    # Memoizer caches per-node results and threads the tolerance through
    # the recursion as an attribute.
    rounder = Memoizer(literal_rounding)
    rounder.epsilon = epsilon
    return rounder(expression)
@singledispatch
def _replace_division(node, self):
    """Replace division with multiplication

    :param node: root of expression
    :param self: function for recursive calls
    """
    raise AssertionError("cannot handle type %s" % type(node))


_replace_division.register(Node)(reuse_if_untouched)


@_replace_division.register(Division)
def _replace_division_division(node, self):
    # a / b  ==>  a * (1 / b); only the reciprocal stays a Division.
    numerator, denominator = node.children
    return Product(self(numerator), Division(one, self(denominator)))
def replace_division(expressions):
    """Replace divisions with multiplications in expressions"""
    mapper = Memoizer(_replace_division)
    return [mapper(expr) for expr in expressions]
@singledispatch
def replace_indices(node, self, subst):
    """Replace free indices in a GEM expression.

    :arg node: root of the expression
    :arg self: function for recursive calls
    :arg subst: tuple of pairs; each pair is a substitution
                rule with a free index to replace and an index to
                replace with.
    """
    # Dispatch fallback: no handler registered for this node type.
    raise AssertionError("cannot handle type %s" % type(node))
# Generic nodes: rebuild only when a child changed under the substitution.
replace_indices.register(Node)(reuse_if_untouched_arg)


def _replace_indices_atomic(i, self, subst):
    """Apply the substitution rules ``subst`` to one index.

    A VariableIndex is rewritten by recursing into its expression; any
    other index is looked up directly in the substitution rules.
    """
    if isinstance(i, VariableIndex):
        rewritten = self(i.expression, subst)
        return i if rewritten == i.expression else VariableIndex(rewritten)
    return dict(subst).get(i, i)
@replace_indices.register(Delta)
def replace_indices_delta(node, self, subst):
    """Apply index substitution to both arguments of a Delta."""
    new_i = _replace_indices_atomic(node.i, self, subst)
    new_j = _replace_indices_atomic(node.j, self, subst)
    # Reuse the existing node when the substitution was a no-op.
    if (new_i, new_j) == (node.i, node.j):
        return node
    return Delta(new_i, new_j)
@replace_indices.register(Indexed)
def replace_indices_indexed(node, self, subst):
    """Apply index substitution to an Indexed node.

    If the child is a ComponentTensor, it is inlined by augmenting the
    substitution rules instead of indexing into it.
    """
    child, = node.children
    substitute = dict(subst)
    # BUG FIX: the original accumulated the substituted indices in a
    # *list* and later compared it against the tuple node.multiindex;
    # a list never equals a tuple, so the node-reuse branch was dead
    # and a fresh Indexed was always built.  Build a tuple instead.
    multiindex = tuple(_replace_indices_atomic(i, self, subst)
                       for i in node.multiindex)
    if isinstance(child, ComponentTensor):
        # Indexing into ComponentTensor
        # Inline ComponentTensor and augment the substitution rules
        substitute.update(zip(child.multiindex, multiindex))
        return self(child.children[0], tuple(sorted(substitute.items())))
    else:
        # Replace indices
        new_child = self(child, subst)
        if new_child == child and multiindex == node.multiindex:
            return node
        else:
            return Indexed(new_child, multiindex)
@replace_indices.register(FlexiblyIndexed)
def replace_indices_flexiblyindexed(node, self, subst):
    """Apply index substitution inside the dim2idxs description of a
    FlexiblyIndexed node; the child itself carries no free indices."""
    child, = node.children
    assert not child.free_indices

    def sub_offset(offset):
        # Integral offsets are literal constants; leave them alone.
        if isinstance(offset, Integral):
            return offset
        return _replace_indices_atomic(offset, self, subst)

    def sub_stride(stride):
        # Integral strides are literal constants; leave them alone.
        if isinstance(stride, Integral):
            return stride
        return self(stride, subst)

    dim2idxs = tuple(
        (sub_offset(offset),
         tuple((_replace_indices_atomic(i, self, subst), sub_stride(s))
               for i, s in idxs))
        for offset, idxs in node.dim2idxs
    )
    if dim2idxs == node.dim2idxs:
        return node
    return FlexiblyIndexed(child, dim2idxs)
def filtered_replace_indices(node, self, subst):
    """Wrapper for :func:`replace_indices`.  At each call removes
    substitution rules that do not apply."""
    if any(isinstance(src, VariableIndex) for src, _ in subst):
        raise NotImplementedError("Can not replace VariableIndex (will need inverse)")
    # Keep only rules whose source index is actually free in this node.
    applicable = tuple((src, dst) for src, dst in subst
                       if src in node.free_indices)
    return replace_indices(node, self, applicable)
def remove_componenttensors(expressions):
    """Removes all ComponentTensors in multi-root expression DAG."""
    # Start each root with an empty substitution; ComponentTensors are
    # inlined by replace_indices as they are encountered.
    mapper = MemoizerArg(filtered_replace_indices)
    return [mapper(root, ()) for root in expressions]
@singledispatch
def _constant_fold_zero(node, self):
    raise AssertionError("cannot handle type %s" % type(node))


_constant_fold_zero.register(Node)(reuse_if_untouched)


@_constant_fold_zero.register(Literal)
def _constant_fold_zero_literal(node, self):
    # A Literal consisting entirely of zeros becomes a symbolic Zero.
    if (node.array == 0).all():
        return Zero(node.shape)
    return node


@_constant_fold_zero.register(ListTensor)
def _constant_fold_zero_listtensor(node, self):
    folded = [self(child) for child in node.children]
    if all(isinstance(f, Zero) for f in folded):
        # Every entry is zero: collapse the whole tensor.
        return Zero(node.shape)
    if all(f == child for f, child in zip(folded, node.children)):
        return node
    return node.reconstruct(*folded)
def constant_fold_zero(exprs):
    """Produce symbolic zeros from Literals

    :arg exprs: An iterable of gem expressions.
    :returns: A list of gem expressions where any Literal containing
        only zeros is replaced by symbolic Zero of the appropriate shape.

    We need a separate path for ListTensor so that its `reconstruct`
    method will not be called when the new children are `Zero()`s;
    otherwise Literal `0`s would be reintroduced.
    """
    folder = Memoizer(_constant_fold_zero)
    return [folder(expr) for expr in exprs]
def _select_expression(expressions, index):
    """Helper function to select an expression from a list of
    expressions with an index.  This function expect sanitised input,
    one should normally call :py:func:`select_expression` instead.

    :arg expressions: a list of expressions
    :arg index: an index (free, fixed or variable)
    :returns: an expression
    """
    expr = expressions[0]
    # Trivial case: all candidates identical, no selection needed.
    if all(e == expr for e in expressions):
        return expr

    types = set(map(type, expressions))
    if types <= {Indexed, Zero}:
        # All candidates are (possibly zero) indexed expressions with
        # the same multiindex: push the selection below the Indexed.
        multiindex, = set(e.multiindex for e in expressions if isinstance(e, Indexed))
        # Shape only determined by free indices
        shape = tuple(i.extent for i in multiindex if isinstance(i, Index))

        def child(expression):
            if isinstance(expression, Indexed):
                return expression.children[0]
            elif isinstance(expression, Zero):
                return Zero(shape)
        return Indexed(_select_expression(list(map(child, expressions)), index), multiindex)

    if types <= {Literal, Zero, Failure}:
        # Leaf terminals: fall back to a literal table lookup.
        return partial_indexed(ListTensor(expressions), (index,))

    if types <= {ComponentTensor, Zero}:
        # Strip the ComponentTensor wrappers, select, then re-wrap.
        shape, = set(e.shape for e in expressions)
        multiindex = tuple(Index(extent=d) for d in shape)
        children = remove_componenttensors([Indexed(e, multiindex) for e in expressions])
        return ComponentTensor(_select_expression(children, index), multiindex)

    if len(types) == 1:
        # Homogeneous node type: factor the selection into the children
        # position by position.
        cls, = types
        if cls.__front__ or cls.__back__:
            raise NotImplementedError("How to factorise {} expressions?".format(cls.__name__))
        assert all(len(e.children) == len(expr.children) for e in expressions)
        assert len(expr.children) > 0
        return expr.reconstruct(*[_select_expression(nth_children, index)
                                  for nth_children in zip(*[e.children
                                                            for e in expressions])])

    raise NotImplementedError("No rule for factorising expressions of this kind.")
def select_expression(expressions, index):
    """Select an expression from a list of expressions with an index.
    Semantically equivalent to

        partial_indexed(ListTensor(expressions), (index,))

    but has a much more optimised implementation.

    :arg expressions: a list of expressions of the same shape
    :arg index: an index (free, fixed or variable)
    :returns: an expression of the same shape as the given expressions
    """
    # Check arguments
    shape = expressions[0].shape
    assert all(e.shape == shape for e in expressions)

    # Sanitise input expressions: scalarise and drop ComponentTensors.
    alpha = tuple(Index() for _ in shape)
    scalarised = remove_componenttensors([Indexed(e, alpha) for e in expressions])

    # Factor the expressions recursively and convert result
    selected = _select_expression(scalarised, index)
    return ComponentTensor(selected, alpha)
def delta_elimination(sum_indices, factors):
    """IndexSum-Delta cancellation.

    :arg sum_indices: free indices for contractions
    :arg factors: product factors
    :returns: optimised (sum_indices, factors)
    """
    sum_indices = list(sum_indices)  # copy for modification

    def substitute(expression, from_, to_):
        if from_ not in expression.free_indices:
            return expression
        elif isinstance(expression, Delta):
            mapper = MemoizerArg(filtered_replace_indices)
            return mapper(expression, ((from_, to_),))
        else:
            return Indexed(ComponentTensor(expression, (from_,)), (to_,))

    def cancellable():
        # Delta factors whose i or j is still a contraction index.
        return [(f, index)
                for f in factors if isinstance(f, Delta)
                for index in (f.i, f.j) if index in sum_indices]

    queue = cancellable()
    while queue:
        delta, from_ = queue[0]
        # The other argument of the Delta replaces the eliminated index.
        to_, = list({delta.i, delta.j} - {from_})
        sum_indices.remove(from_)
        factors = [substitute(f, from_, to_) for f in factors]
        queue = cancellable()

    return sum_indices, factors
def associate(operator, operands):
    """Apply associativity rules to construct an operation-minimal expression tree.

    For best performance give factors that have different set of free indices.

    :arg operator: associative binary operator
    :arg operands: list of operands
    :returns: (reduced expression, # of floating-point operations)
    """
    if len(operands) > 32:
        # O(N^3) algorithm
        raise NotImplementedError("Not expected such a complicated expression!")

    def reduction_cost(pair):
        """Operation count to reduce a pair of GEM expressions"""
        left, right = pair
        extents = [i.extent for i in set().union(left.free_indices, right.free_indices)]
        return numpy.prod(extents, dtype=int)

    flops = 0
    while len(operands) > 1:
        # Greedy algorithm: pick the cheapest pair to reduce next.
        cheapest = min(combinations(operands, 2), key=reduction_cost)
        flops += reduction_cost(cheapest)
        left, right = cheapest
        # Remove chosen factors, append their combination.
        operands.remove(left)
        operands.remove(right)
        operands.append(operator(left, right))

    result, = operands
    return result, flops
def sum_factorise(sum_indices, factors):
    """Optimise a tensor product through sum factorisation.

    :arg sum_indices: free indices for contractions
    :arg factors: product factors
    :returns: optimised GEM expression
    """
    if len(factors) == 0 and len(sum_indices) == 0:
        # Empty product
        return one

    if len(sum_indices) > 6:
        # All orderings are tried below, so cost is factorial in the
        # number of contraction indices.
        raise NotImplementedError("Too many indices for sum factorisation!")

    # Form groups by free indices
    groups = groupby(factors, key=lambda f: f.free_indices)
    groups = [reduce(Product, terms) for _, terms in groups]

    # Sum factorisation
    expression = None
    best_flops = numpy.inf

    # Consider all orderings of contraction indices
    for ordering in permutations(sum_indices):
        terms = groups[:]
        flops = 0
        # Apply contraction index by index
        for sum_index in ordering:
            # Select terms that need to be part of the contraction
            contract = [t for t in terms if sum_index in t.free_indices]
            deferred = [t for t in terms if sum_index not in t.free_indices]

            # Optimise associativity
            product, flops_ = associate(Product, contract)
            term = IndexSum(product, (sum_index,))
            # Account for the summation itself on top of the products.
            flops += flops_ + numpy.prod([i.extent for i in product.free_indices], dtype=int)

            # Replace the contracted terms with the result of the
            # contraction.
            terms = deferred + [term]

        # If some contraction indices were independent, then we may
        # still have several terms at this point.
        expr, flops_ = associate(Product, terms)
        flops += flops_

        # Keep the cheapest ordering found so far.
        if flops < best_flops:
            expression = expr
            best_flops = flops

    return expression
def make_sum(summands):
    """Constructs an operation-minimal sum of GEM expressions."""
    # Merge summands sharing the same free indices first, then let
    # associate() order the remaining additions; its flop count is
    # not needed here.
    groups = groupby(summands, key=lambda s: s.free_indices)
    merged = [reduce(Sum, terms) for _, terms in groups]
    expression, _ = associate(Sum, merged)
    return expression
def make_product(factors, sum_indices=()):
    """Constructs an operation-minimal (tensor) product of GEM expressions."""
    # Thin wrapper: sum factorisation also handles the pure product
    # case (empty sum_indices).
    return sum_factorise(sum_indices, factors)
def make_rename_map():
    """Creates a rename map for reusing the same index renames."""
    # defaultdict: a fresh Index is created on first lookup of each
    # index, and the same replacement is returned on later lookups.
    return defaultdict(Index)
def make_renamer(rename_map):
    r"""Creates a function for renaming indices when expanding products of
    IndexSums, i.e. applying to following rule:

        (\sum_i a_i)*(\sum_i b_i) ===> \sum_{i,i'} a_i*b_{i'}

    :arg rename_map: An rename map for renaming indices the same way
                     as functions returned by other calls of this
                     function.
    :returns: A function that takes an iterable of indices to rename,
              and returns (renamed indices, applier), where applier is
              a function that remap the free indices of GEM
              expressions from the old to the new indices.
    """
    def _renamer(rename_map, current_set, incoming):
        renamed = []
        renames = []
        for old in incoming:
            new = old
            # Chase replacements until we reach an index not yet in use.
            while new in current_set:
                new = rename_map[new]
            current_set.add(new)
            renamed.append(new)
            if old != new:
                renames.append((old, new))

        if not renames:
            # Nothing changed: applier is the identity.
            def applier(expr):
                return expr
        else:
            def applier(expr):
                # Only substitute renames that actually occur free.
                pairs = [(old, new) for old, new in renames
                         if old in expr.free_indices]
                if not pairs:
                    return expr
                current, replacement = zip(*pairs)
                return Indexed(ComponentTensor(expr, current), replacement)

        return tuple(renamed), applier
    return partial(_renamer, rename_map, set())
def traverse_product(expression, stop_at=None, rename_map=None):
    """Traverses a product tree and collects factors, also descending into
    tensor contractions (IndexSum).  The nominators of divisions are
    also broken up, but not the denominators.

    :arg expression: a GEM expression
    :arg stop_at: Optional predicate on GEM expressions.  If specified
                  and returns true for some subexpression, that
                  subexpression is not broken into further factors
                  even if it is a product-like expression.
    :arg rename_map: a rename map for consistent index renaming
    :returns: (sum_indices, terms)
              - sum_indices: list of indices to sum over
              - terms: list of product terms
    """
    if rename_map is None:
        rename_map = make_rename_map()
    renamer = make_renamer(rename_map)

    sum_indices = []
    terms = []

    # Explicit stack traversal; children are pushed reversed so they
    # are visited (and hence collected) left to right.
    stack = [expression]
    while stack:
        expr = stack.pop()
        if stop_at is not None and stop_at(expr):
            terms.append(expr)
        elif isinstance(expr, IndexSum):
            # Rename contraction indices to avoid clashes between
            # distinct IndexSums, then descend into the summand.
            indices, applier = renamer(expr.multiindex)
            sum_indices.extend(indices)
            stack.extend(remove_componenttensors(map(applier, expr.children)))
        elif isinstance(expr, Product):
            stack.extend(reversed(expr.children))
        elif isinstance(expr, Division):
            # Break up products in the dividend, but not in divisor.
            dividend, divisor = expr.children
            if dividend == one:
                terms.append(expr)
            else:
                # Keep the reciprocal as a single factor and continue
                # decomposing the dividend.
                stack.append(Division(one, divisor))
                stack.append(dividend)
        else:
            terms.append(expr)

    return sum_indices, terms
def traverse_sum(expression, stop_at=None):
    """Traverses a summation tree and collects summands.

    :arg expression: a GEM expression
    :arg stop_at: Optional predicate on GEM expressions.  If specified
                  and returns true for some subexpression, that
                  subexpression is not broken into further summands
                  even if it is an addition.
    :returns: list of summand expressions
    """
    summands = []
    pending = [expression]
    while pending:
        expr = pending.pop()
        if stop_at is not None and stop_at(expr):
            summands.append(expr)
        elif isinstance(expr, Sum):
            # Push children right-to-left so they pop left-to-right.
            pending.extend(reversed(expr.children))
        else:
            summands.append(expr)
    return summands
def contraction(expression, ignore=None):
    """Optimise the contractions of the tensor product at the root of
    the expression, including:

    - IndexSum-Delta cancellation
    - Sum factorisation

    :arg ignore: Optional set of indices to ignore when applying sum
        factorisation (otherwise all summation indices will be
        considered). Use this if your expression has many contraction
        indices.

    This routine was designed with finite element coefficient
    evaluation in mind.
    """
    # Eliminate annoying ComponentTensors
    expression, = remove_componenttensors([expression])

    # Flatten product tree, eliminate deltas, sum factorise
    def rebuild(expression):
        sum_indices, factors = delta_elimination(*traverse_product(expression))
        factors = remove_componenttensors(factors)
        if ignore is not None:
            # TODO: This is a really blunt instrument and one might
            # plausibly want the ignored indices to be contracted on
            # the inside rather than the outside.
            extra = tuple(i for i in sum_indices if i in ignore)
            to_factor = tuple(i for i in sum_indices if i not in ignore)
            return IndexSum(sum_factorise(to_factor, factors), extra)
        else:
            return sum_factorise(sum_indices, factors)

    # Sometimes the value shape is composed as a ListTensor, which
    # could get in the way of decomposing factors.  In particular,
    # this is the case for H(div) and H(curl) conforming tensor
    # product elements.  So if ListTensors are used, they are pulled
    # out to be outermost, so we can straightforwardly factorise each
    # of its entries.
    lt_fis = OrderedDict()  # ListTensor free indices
    for node in traversal((expression,)):
        if isinstance(node, Indexed):
            child, = node.children
            if isinstance(child, ListTensor):
                # OrderedDict used as an ordered set of the indices.
                lt_fis.update(zip_longest(node.multiindex, ()))
    lt_fis = tuple(index for index in lt_fis if index in expression.free_indices)

    if lt_fis:
        # Rebuild each split component
        tensor = ComponentTensor(expression, lt_fis)
        entries = [Indexed(tensor, zeta) for zeta in numpy.ndindex(tensor.shape)]
        entries = remove_componenttensors(entries)
        return Indexed(ListTensor(
            numpy.array(list(map(rebuild, entries))).reshape(tensor.shape)
        ), lt_fis)
    else:
        # Rebuild whole expression at once
        return rebuild(expression)
@singledispatch
def _replace_delta(node, self):
    raise AssertionError("cannot handle type %s" % type(node))


_replace_delta.register(Node)(reuse_if_untouched)


@_replace_delta.register(Delta)
def _replace_delta_delta(node, self):
    """Lower a Delta to Identity indexing (running indices) or to a
    Conditional equality test (fixed/variable indices)."""
    i, j = node.i, node.j

    if isinstance(i, Index) or isinstance(j, Index):
        # At least one running index: delta_{ij} == Identity[i, j].
        if isinstance(i, Index) and isinstance(j, Index):
            assert i.extent == j.extent
        if isinstance(i, Index):
            assert i.extent is not None
            size = i.extent
        if isinstance(j, Index):
            assert j.extent is not None
            size = j.extent
        return Indexed(Identity(size), (i, j))

    # Neither argument is a running index: lower to a conditional.
    def as_expression(index):
        if isinstance(index, int):
            return Literal(index)
        elif isinstance(index, VariableIndex):
            return index.expression
        else:
            raise ValueError("Cannot convert running index to expression.")

    lhs = as_expression(i)
    rhs = as_expression(j)
    return Conditional(Comparison("==", lhs, rhs), one, Zero())
def replace_delta(expressions):
    """Lowers all Deltas in a multi-root expression DAG."""
    mapper = Memoizer(_replace_delta)
    return [mapper(expr) for expr in expressions]
@singledispatch
def _unroll_indexsum(node, self):
    """Unrolls IndexSums below a certain extent.

    :arg node: root of the expression
    :arg self: function for recursive calls
    """
    raise AssertionError("cannot handle type %s" % type(node))


_unroll_indexsum.register(Node)(reuse_if_untouched)


@_unroll_indexsum.register(IndexSum)  # noqa
def _(node, self):
    unroll = tuple(filter(self.predicate, node.multiindex))
    if not unroll:
        return reuse_if_untouched(node, self)
    # Unrolling: expand the selected indices into an explicit Sum of
    # indexed entries.
    summand = self(node.children[0])
    shape = tuple(index.extent for index in unroll)
    tensor = ComponentTensor(summand, unroll)
    unrolled = reduce(Sum,
                      (Indexed(tensor, alpha) for alpha in numpy.ndindex(shape)),
                      Zero())
    # Keep summing over the indices that were not unrolled.
    kept = tuple(index for index in node.multiindex if index not in unroll)
    return IndexSum(unrolled, kept)
def unroll_indexsum(expressions, predicate):
    """Unrolls IndexSums below a specified extent.

    :arg expressions: list of expression DAGs
    :arg predicate: a predicate function on :py:class:`Index` objects
                    that tells whether to unroll a particular index
    :returns: list of expression DAGs with some unrolled IndexSums
    """
    mapper = Memoizer(_unroll_indexsum)
    mapper.predicate = predicate
    return [mapper(expr) for expr in expressions]
def aggressive_unroll(expression):
    """Aggressively unrolls all loop structures."""
    # Unroll expression shape into an explicit ListTensor of entries.
    if expression.shape:
        entries = numpy.empty(expression.shape, dtype=object)
        for alpha in numpy.ndindex(expression.shape):
            entries[alpha] = Indexed(expression, alpha)
        expression, = remove_componenttensors((ListTensor(entries),))

    # Unroll summation: the predicate accepts every index.
    expression, = unroll_indexsum((expression,), predicate=lambda index: True)
    expression, = remove_componenttensors((expression,))
    return expression