# Source code for gem.scheduling

"""Schedules operations to evaluate a multi-root expression DAG,
forming an ordered list of Impero terminals."""

import collections
import functools
import itertools

from gem import gem, impero
from gem.node import collect_refcount


class OrderedDefaultDict(collections.OrderedDict):
    """A dictionary that provides a default value and ordered iteration.

    :arg factory: The callable used to create the default value.

    See :class:`collections.OrderedDict` for description of the
    remaining arguments.
    """

    def __init__(self, factory, *args, **kwargs):
        self.factory = factory
        super(OrderedDefaultDict, self).__init__(*args, **kwargs)

    def __missing__(self, key):
        # Store the freshly created default under the key, so that the
        # caller can mutate the returned object in place and later
        # lookups see the very same object.
        val = self[key] = self.factory()
        return val

    def copy(self):
        """Return a shallow copy which keeps the default factory.

        The inherited :meth:`collections.OrderedDict.copy` re-creates
        the instance without passing ``factory`` and therefore fails
        for this class; this override supplies it explicitly.

        :returns: a new dictionary of the same type, with the same
                  factory, keys (in the same order) and values
        """
        return type(self)(self.factory, self)

    # Make copy.copy() take the same route as .copy()
    __copy__ = copy
class ReferenceStager(object):
    """Provides staging for nodes in reference counted expression
    DAGs.  A callback function is called once the reference count is
    exhausted."""

    def __init__(self, reference_count, callback):
        """Initialises a ReferenceStager.

        :arg reference_count: initial reference counts for all
                              expected nodes
        :arg callback: function to call on each node when
                       reference count is exhausted
        """
        # Count down on a private copy; the mapping passed in by the
        # caller is left untouched.
        self.waiting = reference_count.copy()
        self.callback = callback

    def decref(self, o):
        """Decreases the reference count of a node, and possibly
        triggering a callback (when the reference count drops to
        zero).

        :arg o: the node whose reference count is decreased; it must
                still have at least one outstanding reference
        """
        assert 1 <= self.waiting[o]

        self.waiting[o] -= 1
        if self.waiting[o] == 0:
            self.callback(o)

    def empty(self):
        """All reference counts exhausted?

        :returns: True if every tracked count is zero
        """
        return not any(self.waiting.values())
class Queue(object):
    """Special queue for operation scheduling.  GEM / Impero nodes are
    inserted when they are ready to be scheduled, i.e. any operation
    which depends on the operation to be inserted must have been
    scheduled already.  This class implements a heuristic for ordering
    operations within the constraints in a way which aims to achieve
    maximum loop fusion to minimise the size of temporaries which need
    to be introduced.
    """

    def __init__(self, callback):
        """Initialises a Queue.

        :arg callback: function called on each element "popped" from
                       the queue
        """
        # Must have deterministic iteration over the queue
        self.queue = OrderedDefaultDict(list)
        self.callback = callback

    def insert(self, indices, elem):
        """Insert element into queue.

        :arg indices: loop indices used by the scheduling heuristic
        :arg elem: element to be scheduled
        """
        # Elements sharing the same index tuple end up in one bucket
        self.queue[indices].append(elem)

    def process(self):
        """Pops elements from the queue and calls the callback
        function on them until the queue is empty.  The callback
        function can insert further elements into the queue.

        Buckets are visited so that the next bucket shares the longest
        possible leading part of its index tuple with the bucket
        processed last; within a bucket the most recently inserted
        element is popped first.
        """
        indices = ()
        while self.queue:
            # Find innermost non-empty outer loop: shorten the current
            # index tuple until it is a prefix of some queued key.
            while indices not in (i[:len(indices)] for i in self.queue.keys()):
                indices = indices[:-1]

            # Pick a loop: the first queued key (in insertion order)
            # that extends the common prefix.
            for i in self.queue.keys():
                if i[:len(indices)] == indices:
                    indices = i
                    break

            # Drain the bucket; the callback may refill it meanwhile.
            while self.queue[indices]:
                self.callback(self.queue[indices].pop())
            del self.queue[indices]
def handle(ops, push, decref, node):
    """Helper function for scheduling

    Processes one scheduled node: appends the Impero operation(s) it
    requires to ``ops`` and releases (or enqueues) what it depends on.

    :arg ops: list to which the emitted Impero operations are appended
    :arg push: function which enqueues a further Impero operation
    :arg decref: function which decreases the reference count of a
                 GEM node
    :arg node: the GEM node or Impero operation to process
    :raises AssertionError: if there is no handler for the node type
    """
    # The order of the tests matters: the catch-all ``gem.Node`` case
    # must come after the more specific GEM node types.
    if isinstance(node, gem.Variable):
        # Declared in the kernel header
        pass
    elif isinstance(node, gem.Constant):
        # Constant literals inlined, unless tensor-valued
        if node.shape:
            ops.append(impero.Evaluate(node))
    elif isinstance(node, gem.Zero):
        # should rarely happen
        assert not node.shape
    elif isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)):
        if node.indirect_children:
            # Do not inline;
            # Index expression can be involved if it contains VariableIndex.
            ops.append(impero.Evaluate(node))
        # Children are released whether or not the node was evaluated
        for child in itertools.chain(node.children, node.indirect_children):
            decref(child)
    elif isinstance(node, gem.IndexSum):
        ops.append(impero.Noop(node))
        push(impero.Accumulate(node))
    elif isinstance(node, gem.Node):
        ops.append(impero.Evaluate(node))
        for child in node.children:
            decref(child)
    elif isinstance(node, impero.Initialise):
        ops.append(node)
    elif isinstance(node, impero.Accumulate):
        ops.append(node)
        push(impero.Initialise(node.indexsum))
        decref(node.indexsum.children[0])
    elif isinstance(node, impero.Return):
        ops.append(node)
        decref(node.expression)
    elif isinstance(node, impero.ReturnAccumulate):
        ops.append(node)
        decref(node.indexsum.children[0])
    else:
        raise AssertionError("no handler for node type %s" % type(node))
def emit_operations(assignments, get_indices, emit_return_accumulate=True):
    """Makes an ordering of operations to evaluate a multi-root
    expression DAG.

    :arg assignments: Iterable of (variable, expression) pairs.
                      The value of expression is written into variable
                      upon execution.
    :arg get_indices: mapping from GEM nodes to an ordering of free
                      indices
    :arg emit_return_accumulate: emit ReturnAccumulate nodes? Set to
                                 False if the output variables are not
                                 guaranteed zero on entry to the
                                 kernel.
    :returns: list of Impero terminals correctly ordered to evaluate
              the assignments
    """
    # The pairs are traversed twice below (reference counting, then
    # staging), so materialise them in case a one-shot iterator was
    # passed in.
    assignments = list(assignments)

    # Prepare reference counts
    refcount = collect_refcount([e for _, e in assignments])

    # Stage return operations
    staging = []
    for variable, expression in assignments:
        if emit_return_accumulate and \
                refcount[expression] == 1 and isinstance(expression, gem.IndexSum) \
                and set(variable.free_indices) == set(expression.free_indices):
            staging.append(impero.ReturnAccumulate(variable, expression))
            # handle() only decrefs the summand of a ReturnAccumulate,
            # never the IndexSum itself, so take its sole reference
            # off the books here.
            refcount[expression] -= 1
        else:
            staging.append(impero.Return(variable, expression))

    # Prepare data structures
    # (both closures refer to ``queue``, which is bound further down,
    # before either of them is first called)
    def push_node(node):
        queue.insert(get_indices(node), node)

    def push_op(op):
        queue.insert(op.loop_shape(get_indices), op)

    ops = []

    stager = ReferenceStager(refcount, push_node)
    queue = Queue(functools.partial(handle, ops, push_op, stager.decref))

    # Enqueue return operations
    for op in staging:
        push_op(op)

    # Schedule operations
    queue.process()

    # Assert that nothing left unprocessed
    assert stager.empty()

    # Return
    # Operations were scheduled starting from the return operations,
    # i.e. in reverse execution order.
    ops.reverse()
    return ops