"""Generic abstract node class and utility functions for creating
expression DAG languages."""
import collections
import gem
class Node(object):
    """Abstract base class for the nodes of an expression DAG.

    Nodes are immutable once built.  A node may refer to other nodes,
    which are its *children*.  It may also hold extra data, or refer to
    objects that are not nodes; those are *not* children.

    Creating a node, or deciding whether two nodes are equal, needs both
    the children (if any) and the non-child data (if any).  Rebuilding a
    node with :meth:`reconstruct` only needs the replacement children;
    the non-child data is copied over from the existing node.

    Derived classes provide a ``children`` attribute (an iterable of child
    nodes).  They list the attribute names of their non-child constructor
    arguments in ``__front__`` (passed before the children) and
    ``__back__`` (passed after the children).
    """

    # The computed hash is cached in this slot by ``__hash__``.
    __slots__ = ('hash_value',)

    # Attribute names of non-child data passed to the constructor
    # *before* the children.  Derived classes may override this.
    __front__ = ()

    # Attribute names of non-child data passed to the constructor
    # *after* the children.  Derived classes may override this.
    __back__ = ()

    def _cons_args(self, children):
        """Build a constructor argument tuple (internal helper).

        The result is the values of ``self``'s ``__front__`` attributes,
        then ``children``, then the values of ``self``'s ``__back__``
        attributes.
        """
        leading = tuple(getattr(self, attr) for attr in self.__front__)
        trailing = tuple(getattr(self, attr) for attr in self.__back__)
        return leading + tuple(children) + trailing

    def __reduce__(self):
        # Pickle as "call type(self) with the current constructor args".
        return type(self), self._cons_args(self.children)

    def reconstruct(self, *args):
        """Return a new node of the same type whose children are ``args``.

        Non-child data is copied from ``self``; ``self`` is not modified.
        """
        cls = type(self)
        return cls(*self._cons_args(args))

    def __repr__(self):
        arguments = ", ".join(
            repr(arg) for arg in self._cons_args(self.children))
        return "{}({})".format(type(self).__name__, arguments)

    def __eq__(self, other):
        """Equality with cheap shortcuts.

        Identical objects are equal without further work.  Objects with
        different hashes are unequal.  Only when the hashes agree is the
        full :meth:`is_equal` comparison performed.
        """
        if self is other:
            return True
        if hash(self) != hash(other):
            return False
        return self.is_equal(other)

    def __ne__(self, other):
        """Negation of :meth:`__eq__`."""
        return not self.__eq__(other)

    def __hash__(self):
        """Return the hash, computing it with :meth:`get_hash` only once."""
        try:
            return self.hash_value
        except AttributeError:
            value = self.get_hash()
            self.hash_value = value
            return value

    def is_equal(self, other):
        """Full equality predicate.

        Derived classes should override this method rather than
        :meth:`__eq__` or :meth:`__ne__`.  By default, two nodes are equal
        when they have exactly the same type and equal constructor
        arguments.
        """
        if type(self) is not type(other):
            return False
        return (self._cons_args(self.children) ==
                other._cons_args(other.children))

    def get_hash(self):
        """Compute the hash value.

        Derived classes should override this method rather than
        :meth:`__hash__`.  By default, it hashes the node's type together
        with its constructor arguments.
        """
        key = (type(self),) + self._cons_args(self.children)
        return hash(key)
def _make_traversal_children(node):
    """Return the nodes that a traversal should visit below ``node``.

    For ``gem.Indexed`` and ``gem.FlexiblyIndexed`` nodes, this is the
    regular children followed by ``node.indirect_children``.  For every
    other node it is just ``node.children``.
    """
    if not isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)):
        return node.children
    # Index expressions can hide child nodes; expose them as well.
    return node.children + node.indirect_children
def pre_traversal(expression_dags):
    """Pre-order traversal of the nodes of expression DAGs.

    Each node reachable from the roots in ``expression_dags`` is yielded
    exactly once, depth-first.  A node is yielded before any child that
    is first discovered through it.  Roots are handled in the order
    given and children left to right, so the output is deterministic.

    Notes
    -----
    This function also walks through nodes in index expressions
    (e.g., `VariableIndex`s); see ``_make_traversal_children()``.
    """
    visited = set()
    for root in expression_dags:
        # A root may be repeated, or already reached from an earlier
        # root; never yield it twice.
        if root in visited:
            continue
        visited.add(root)
        stack = [root]
        while stack:
            current = stack.pop()
            yield current
            # Push in reverse so the first child is popped (visited) first.
            for child in reversed(_make_traversal_children(current)):
                if child not in visited:
                    visited.add(child)
                    stack.append(child)
def post_traversal(expression_dags):
    """Post-order traversal of the nodes of expression DAGs.

    Each node reachable from the roots in ``expression_dags`` is yielded
    exactly once, and only after all of its children have been yielded.
    Roots are handled in the order given and children left to right, so
    the output is deterministic.

    Notes
    -----
    This function also walks through nodes in index expressions
    (e.g., `VariableIndex`s); see ``_make_traversal_children()``.
    """
    seen = set()
    for root in expression_dags:
        # Some roots might be the same (or reachable from an earlier
        # root); each node must be yielded only once.
        if root in seen:
            continue
        seen.add(root)
        # Each stack frame pairs a node with an iterator over its
        # children.  After a child is finished, the scan resumes where it
        # left off instead of rescanning the child list from the start,
        # which was O(k**2) for a node with k children.
        stack = [(root, iter(_make_traversal_children(root)))]
        while stack:
            node, pending = stack[-1]
            for dep in pending:
                if dep not in seen:
                    # Descend into the first child not yet finished.
                    stack.append((dep, iter(_make_traversal_children(dep))))
                    break
            else:
                # Every child is finished, so this node is finished too.
                yield node
                seen.add(node)
                stack.pop()
# Module-wide default traversal (used, e.g., by collect_refcount below).
# It defaults to the more efficient pre-order traversal; use
# post_traversal explicitly when children must come before their parents.
traversal = pre_traversal
def collect_refcount(expression_dags):
    """Collects reference counts for a multi-root expression DAG.

    A node's count is the number of times it appears among the roots,
    plus the number of times it appears in the child list (as returned
    by ``_make_traversal_children()``) of each distinct reachable node.

    :arg expression_dags: an iterable of root nodes.  It is consumed
        only once, so generators and other one-shot iterators work.
    :returns: a :class:`collections.Counter` mapping nodes to counts.

    Notes
    -----
    This function also collects reference counts of nodes
    in index expressions (e.g., `VariableIndex`s); see
    ``_make_traversal_children()``.
    """
    # The roots are used twice below (counted, then traversed).  Passing
    # a one-shot iterator straight through would exhaust it in the
    # Counter, and the traversal would silently visit nothing.
    roots = list(expression_dags)
    counts = collections.Counter(roots)
    for node in traversal(roots):
        counts.update(_make_traversal_children(node))
    return counts
def noop_recursive(function):
    """Give ``function`` working recursion, without any caching.

    :arg function: a function with parameters (value, rec), where
        ``rec`` is the function to use for recursive calls.
    :returns: a one-argument function ``recursive(node)`` that evaluates
        ``function(node, recursive)``; results are not cached.
    """
    def recursive(node):
        # Hand ourselves over as ``rec`` so that nested calls are routed
        # back through this same wrapper.
        outcome = function(node, recursive)
        return outcome
    return recursive
def noop_recursive_arg(function):
    """Give ``function`` working recursion with an extra argument, no caching.

    :arg function: a function with parameters (value, rec, arg), where
        ``rec`` is the function to use for recursive calls.
    :returns: a two-argument function ``recursive(node, arg)`` that
        evaluates ``function(node, recursive, arg)``; results are not
        cached.
    """
    def recursive(node, arg):
        # Hand ourselves over as ``rec`` so that nested calls are routed
        # back through this same wrapper.
        outcome = function(node, recursive, arg)
        return outcome
    return recursive
class Memoizer(object):
    """Caching wrapper for functions with overridable recursive calls.

    Results are stored in ``self.cache``, keyed by node.  The lifetime of
    the cache is the lifetime of the object instance.

    :arg function: a function with parameters (value, rec), where
        ``rec`` is expected to be a function used for recursive calls.
        The Memoizer instance itself is passed as ``rec``, so recursive
        calls are cached too.
    :returns: a function with working recursion and caching
    """

    def __init__(self, function):
        self.cache = {}
        self.function = function

    def __call__(self, node):
        try:
            return self.cache[node]
        except KeyError:
            pass
        # Compute outside the ``except`` block.  Otherwise any exception
        # raised by ``function`` is implicitly chained to the cache-miss
        # KeyError, giving misleading "During handling of the above
        # exception, another exception occurred" tracebacks.
        result = self.function(node, self)
        self.cache[node] = result
        return result
class MemoizerArg(object):
    """Caching wrapper for functions with overridable recursive calls
    and an argument.

    Results are stored in ``self.cache``, keyed by the ``(node, arg)``
    pair.  The lifetime of the cache is the lifetime of the object
    instance.

    :arg function: a function with parameters (value, rec, arg), where
        ``rec`` is expected to be a function used for recursive calls.
        The MemoizerArg instance itself is passed as ``rec``, so
        recursive calls are cached too.
    :returns: a function with working recursion and caching
    """

    def __init__(self, function):
        self.cache = {}
        self.function = function

    def __call__(self, node, arg):
        cache_key = (node, arg)
        try:
            return self.cache[cache_key]
        except KeyError:
            pass
        # Compute outside the ``except`` block.  Otherwise any exception
        # raised by ``function`` is implicitly chained to the cache-miss
        # KeyError, giving misleading tracebacks.
        result = self.function(node, self, arg)
        self.cache[cache_key] = result
        return result
def reuse_if_untouched(node, self):
    """Reuse if untouched recipe.

    Applies ``self`` to every child of ``node``.  If every result compares
    equal (``==``) to the child it came from, ``node`` itself is returned.
    Otherwise a new node is built with ``node.reconstruct`` from the
    transformed children.
    """
    originals = node.children
    transformed = [self(child) for child in originals]
    untouched = all(new == old for new, old in zip(transformed, originals))
    return node if untouched else node.reconstruct(*transformed)
def reuse_if_untouched_arg(node, self, arg):
    """Reuse if untouched recipe, propagating an extra argument.

    Calls ``self(child, arg)`` for every child of ``node``.  If every
    result compares equal (``==``) to the child it came from, ``node``
    itself is returned.  Otherwise a new node is built with
    ``node.reconstruct`` from the transformed children.
    """
    originals = node.children
    transformed = [self(child, arg) for child in originals]
    untouched = all(new == old for new, old in zip(transformed, originals))
    return node if untouched else node.reconstruct(*transformed)