# Source code for gem.impero_utils

"""Utilities for building an Impero AST from an ordered list of
terminal Impero operations, and for building any additional data
required for straightforward C code generation.

What this module does is independent of the generated code target.
"""

import collections
from functools import singledispatch
from itertools import chain, groupby

from gem.node import traversal, collect_refcount
from gem import gem, impero as imp, optimise, scheduling


# ImperoC is a named tuple bundling everything needed for C code generation.
#
# Attributes:
#     tree        - Impero AST describing the loop structure and operations
#     temporaries - List of GEM expressions which have assigned temporaries
#     declare     - Where to declare temporaries to get correct C code
#     indices     - Indices for declaring temporaries and referencing their
#                   values
ImperoC = collections.namedtuple('ImperoC', ['tree', 'temporaries', 'declare', 'indices'])


class NoopError(Exception):
    """Signals that the kernel contains no operations."""
def preprocess_gem(expressions, replace_delta=True, remove_componenttensors=True):
    """Lower GEM nodes that have no direct C translation.

    :arg expressions: GEM expressions to preprocess
    :arg replace_delta: if true, apply :func:`optimise.replace_delta`
    :arg remove_componenttensors: if true, apply
        :func:`optimise.remove_componenttensors`
    :returns: the preprocessed expressions (the input itself when both
              passes are disabled)
    """
    lowered = expressions
    # Component tensors are removed first; deltas are replaced on the
    # outcome of that pass.
    if remove_componenttensors:
        lowered = optimise.remove_componenttensors(lowered)
    if replace_delta:
        lowered = optimise.replace_delta(lowered)
    return lowered
def compile_gem(assignments, prefix_ordering, remove_zeros=False,
                emit_return_accumulate=True):
    """Compiles GEM to Impero.

    :arg assignments: list of (return variable, expression DAG root) pairs
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: remove zero assignment to return variables
    :arg emit_return_accumulate: emit ReturnAccumulate nodes
        (see :func:`~.scheduling.emit_operations`)?  If False,
        split into Accumulate/Return pairs.  Set to False if the
        output tensor of kernels is not guaranteed to be zero on entry.
    :returns: an :data:`ImperoC` tuple (Impero AST plus the data
              required for code generation)
    :raises NoopError: if scheduling yields no operations
    """
    # Optionally discard assignments whose right-hand side is a Zero
    if remove_zeros:
        kept = []
        for assignment in assignments:
            _, rhs = assignment
            if not isinstance(rhs, gem.Zero):
                kept.append(assignment)
        assignments = kept

    # Right-hand sides only
    expressions = [rhs for _, rhs in assignments]

    # Gather indices in first-seen order, so the result is deterministic
    first_seen = collections.OrderedDict()
    for node in traversal(expressions):
        if isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)):
            for index in node.index_ordering():
                first_seen.setdefault(index, None)
    all_indices = list(first_seen)

    # Global index order: the prefix first, then everything else
    apply_ordering = make_index_orderer(
        make_prefix_ordering(all_indices, prefix_ordering)
    )

    def get_indices(expr):
        return apply_ordering(expr.free_indices)

    # Schedule the operations
    ops = scheduling.emit_operations(assignments, get_indices, emit_return_accumulate)

    # Nothing to generate for an empty kernel
    if len(ops) == 0:
        raise NoopError()

    # Drop temporaries which can be inlined, then build the loop nest
    ops = inline_temporaries(expressions, ops)
    tree = make_loop_tree(ops, get_indices)

    # Temporaries and where/how to declare them
    temporaries = collect_temporaries(tree)
    declare, temporary_indices = place_declarations(tree, temporaries, get_indices)

    # Impero AST + other data for code generation
    return ImperoC(tree, temporaries, declare, temporary_indices)
def make_prefix_ordering(indices, prefix_ordering):
    """Creates an ordering of ``indices`` which starts with those
    indices in ``prefix_ordering``.

    :arg indices: iterable of indices, in the order they should appear
                  after the prefix
    :arg prefix_ordering: iterable of indices that must come first; may
                          be any iterable, including a one-shot iterator
    :returns: tuple holding ``prefix_ordering`` followed by the members
              of ``indices`` that are not in ``prefix_ordering``
    """
    # Materialise the prefix exactly once: it is both emitted and used
    # for membership tests, and a generator/iterator would be exhausted
    # by the first use, letting prefix indices reappear in the tail.
    prefix = tuple(prefix_ordering)
    # Need to return deterministically ordered indices
    return prefix + tuple(k for k in indices if k not in prefix)
def make_index_orderer(index_ordering):
    """Builds a sorter for index collections.

    :arg index_ordering: sequence of indices defining the reference order
    :returns: a function which takes an iterable of indices and returns
              them as a tuple, arranged as they appear in
              ``index_ordering``
    """
    # Position of every index within the reference order
    position = {}
    for rank, index in enumerate(index_ordering):
        position[index] = rank

    def apply_ordering(indices):
        # An index missing from the reference order raises KeyError
        return tuple(sorted(indices, key=position.__getitem__))

    return apply_ordering
def inline_temporaries(expressions, ops):
    """Inline temporaries which could be inlined without blowing up
    the code.

    :arg expressions: a multi-root GEM expression DAG, used for
                      reference counting
    :arg ops: ordered list of Impero terminals
    :returns: a filtered ``ops``, without the unnecessary
              :class:`impero.Evaluate`s
    """
    refcount = collect_refcount(expressions)

    # Inlining candidates: scalar-shaped expressions referenced once
    candidates = set()
    for op in ops:
        if not isinstance(op, imp.Evaluate):
            continue
        expr = op.expression
        if expr.shape == () and refcount[expr] == 1:
            candidates.add(expr)

    # A child with strictly fewer free indices than its parent would be
    # pulled into an inner loop by inlining, so it keeps its temporary.
    for parent in traversal(expressions):
        for child in parent.children:
            if child in candidates and set(child.free_indices) < set(parent.free_indices):
                candidates.remove(child)

    def is_kept(op):
        return not isinstance(op, imp.Evaluate) or op.expression not in candidates

    return [op for op in ops if is_kept(op)]
def collect_temporaries(tree):
    """Collects the GEM expressions to assign to temporaries from an
    Impero AST.

    :arg tree: Impero AST
    :returns: list of GEM expressions, in traversal order
    """
    temporaries = []
    for node in traversal((tree,)):
        # IndexSum temporaries could be recorded at either Initialise
        # or Accumulate; the choice only affects the ordering
        # (numbering).  Accumulate is used here.
        if isinstance(node, imp.Accumulate):
            temporaries.append(node.indexsum)
            continue
        if isinstance(node, imp.Evaluate):
            temporaries.append(node.expression)
    return temporaries
def make_loop_tree(ops, get_indices, level=0):
    """Creates an Impero AST with loops from a list of operations and
    their respective free indices.

    :arg ops: a list of Impero terminal nodes
    :arg get_indices: callable mapping from GEM nodes to an ordering of
                      free indices
    :arg level: depth of loop nesting
    :returns: Impero AST with loops, without declarations
    """
    def loop_index_at_level(op):
        # Empty when ``op`` has no loop at this depth, otherwise a
        # one-element slice holding the loop index.
        return op.loop_shape(get_indices)[level:level+1]

    body = []
    # Consecutive operations sharing the same loop index at this depth
    # end up in the same loop.
    for key, group in groupby(ops, loop_index_at_level):
        if not key:
            body.extend(group)
            continue
        nested = make_loop_tree(group, get_indices, level+1)
        body.append(imp.For(key[0], nested))

    # No-op terminals do not make it into the tree
    return imp.Block([stmt for stmt in body if not isinstance(stmt, imp.Noop)])
def place_declarations(tree, temporaries, get_indices):
    """Determines where and how to declare temporaries for an Impero AST.

    :arg tree: Impero AST to determine the declarations for
    :arg temporaries: list of GEM expressions which are assigned to
                      temporaries
    :arg get_indices: callable mapping from GEM nodes to an ordering of
                      free indices
    :returns: pair ``(declare, indices)``.  ``declare`` maps each
              :class:`impero.Block` to the list of temporaries to
              declare at the beginning of that block, and each Impero
              terminal to a boolean telling whether it declares the
              temporary it writes to.  ``indices`` maps each temporary
              to those of its ordered free indices which are not
              covered by the loops surrounding its declaration.
    """
    numbering = {t: n for n, t in enumerate(temporaries)}
    assert len(numbering) == len(temporaries)

    # Collect the total number of temporary references
    total_refcount = collections.Counter()
    for node in traversal((tree,)):
        if isinstance(node, imp.Terminal):
            total_refcount.update(temp_refcount(numbering, node))
    assert set(total_refcount) == set(temporaries)

    # Result
    declare = {}
    indices = {}

    @singledispatch
    def recurse(expr, loop_indices):
        """Visit an Impero AST to collect declarations.

        :arg expr: Impero tree node
        :arg loop_indices: loop indices (in order) from the outer
                           loops surrounding ``expr``
        :returns: :class:`collections.Counter` with the reference
                  counts for each temporary in the subtree whose root
                  is ``expr``
        """
        # Must be raised, not returned: a returned exception object
        # would be handed to the caller as if it were a Counter.
        raise AssertionError("unsupported expression type %s" % type(expr))

    @recurse.register(imp.Terminal)
    def recurse_terminal(expr, loop_indices):
        return temp_refcount(numbering, expr)

    @recurse.register(imp.For)
    def recurse_for(expr, loop_indices):
        return recurse(expr.children[0], loop_indices + (expr.index,))

    @recurse.register(imp.Block)
    def recurse_block(expr, loop_indices):
        # Temporaries declared at the beginning of the block are
        # collected here
        declare[expr] = []

        # Collect reference counts for the block
        refcount = collections.Counter()
        for statement in expr.children:
            refcount.update(recurse(statement, loop_indices))

        # Visit :class:`collections.Counter` in deterministic order
        for e in sorted(refcount.keys(), key=lambda t: numbering[t]):
            if refcount[e] == total_refcount[e]:
                # If all references are within this block, then this
                # block is the right place to declare the temporary.
                ordered = get_indices(e)
                assert loop_indices == ordered[:len(loop_indices)]
                indices[e] = ordered[len(loop_indices):]
                if indices[e]:
                    # Scalar-valued temporaries are not declared until
                    # their value is assigned.  This does not really
                    # matter, but produces a more compact and nicer to
                    # read C code.
                    declare[expr].append(e)
                # Remove expression from the ``refcount`` so it will
                # not be declared again.
                del refcount[e]
        return refcount

    # Populate result
    remainder = recurse(tree, ())
    assert not remainder

    # Set in ``declare`` for Impero terminals whether they should
    # declare the temporary that they are writing to.
    for node in traversal((tree,)):
        if isinstance(node, imp.Terminal):
            declare[node] = False
            if isinstance(node, imp.Evaluate):
                e = node.expression
            elif isinstance(node, imp.Initialise):
                e = node.indexsum
            else:
                continue

            # Only temporaries without remaining indices (scalars) are
            # declared by the terminal that writes them.
            if len(indices[e]) == 0:
                declare[node] = True

    return declare, indices
def temp_refcount(temporaries, op):
    """Collects the number of times temporaries are referenced when
    generating code for an Impero terminal.

    :arg temporaries: container of temporaries; only membership tests
                      are performed on it
    :arg op: Impero terminal
    :returns: :class:`collections.Counter` object mapping some of the
              elements of ``temporaries`` to the number of times they
              will be referenced from ``op``
    :raises AssertionError: if ``op`` is of an unhandled type
    """
    counter = collections.Counter()

    def visit(node, always_descend=False):
        """Counts temporary references below ``node``.

        The walk stops at temporaries, unless ``always_descend`` is
        set, in which case the children of ``node`` are visited even
        if ``node`` itself is a temporary."""
        is_temporary = node in temporaries
        if is_temporary:
            counter[node] += 1
        if always_descend or not is_temporary:
            for child in node.children:
                visit(child)

    if isinstance(op, imp.Initialise):
        counter[op.indexsum] += 1
    elif isinstance(op, imp.Accumulate):
        visit(op.indexsum, always_descend=True)
    elif isinstance(op, imp.Evaluate):
        visit(op.expression, always_descend=True)
    elif isinstance(op, imp.Return):
        visit(op.expression)
    elif isinstance(op, imp.ReturnAccumulate):
        visit(op.indexsum.children[0])
    elif isinstance(op, imp.Noop):
        pass
    else:
        raise AssertionError("unhandled operation: %s" % type(op))

    return counter