"""
An interpreter for GEM trees.
"""
import numpy
import operator
from collections import OrderedDict
from functools import singledispatch
import itertools
from gem import gem, node
from gem.optimise import replace_delta
__all__ = ("evaluate", )
class Result(object):
    """An array object that tracks which axes of the array correspond to
    gem free indices (and what those free indices are).

    :arg arr: The array.
    :arg fids: The free indices (any iterable; stored as a tuple).

    The first ``len(fids)`` axes of the provided array correspond to
    the free indices, the remaining axes are the shape of each entry.
    """

    def __init__(self, arr, fids=None):
        self.arr = arr
        # Always keep a tuple: the methods below rely on ``.index`` and
        # users of this class concatenate ``fids`` with other tuples.
        self.fids = tuple(fids) if fids is not None else ()

    def broadcast(self, fids):
        """Given some free indices, return a broadcasted array which
        contains extra dimensions that correspond to indices in fids
        that are not in ``self.fids``.

        Note that inserted dimensions will have length one.

        :arg fids: The free indices for broadcasting.  Every index in
            ``self.fids`` must appear in it.
        :returns: the reordered array with the length-one axes inserted.
        """
        # Select free indices, in the order requested by ``fids``
        axes = tuple(self.fids.index(fi) for fi in fids if fi in self.fids)
        assert len(axes) == len(self.fids)
        # Add shape (the trailing axes keep their relative order)
        axes += tuple(range(len(self.fids), self.arr.ndim))
        # Move axes, insert extra axes
        arr = numpy.transpose(self.arr, axes)
        for i, fi in enumerate(fids):
            if fi not in self.fids:
                arr = numpy.expand_dims(arr, axis=i)
        return arr

    def filter(self, idx, fids):
        """Given an index tuple and some free indices, return a
        "filtered" index tuple which removes entries that correspond
        to indices in fids that are not in ``self.fids``.

        :arg idx: The index tuple to filter.
        :arg fids: The free indices for the index tuple.
        :returns: the entries of ``idx`` for ``self.fids`` (in the order
            of ``self.fids``) followed by the entries past ``len(fids)``.
        """
        return tuple(idx[fids.index(i)] for i in self.fids) + idx[len(fids):]

    def __getitem__(self, idx):
        # Indices may arrive as lists; convert so numpy treats them as a
        # multi-dimensional index.
        return self.arr[tuple(idx)]

    def __setitem__(self, idx, val):
        self.arr[idx] = val

    @property
    def tshape(self):
        """The total shape of the result array."""
        return self.arr.shape

    @property
    def fshape(self):
        """The shape of the free index part of the result array."""
        return self.tshape[:len(self.fids)]

    @property
    def shape(self):
        """The shape of the shape part of the result array."""
        return self.tshape[len(self.fids):]

    def __repr__(self):
        return "Result(%r, %r)" % (self.arr, self.fids)

    def __str__(self):
        return repr(self)

    @classmethod
    def empty(cls, *children, dtype=float):
        """Build an empty Result object.

        :arg children: The children used to determine the shape and
            free indices.  All children must have the same shape.
        :kwarg dtype: The data type of the result array (default
            ``float``).  Keyword-only, so a misspelled keyword raises
            ``TypeError`` instead of silently producing a float array.
        :returns: a Result with an uninitialised array whose free indices
            are those of the children, in order of first appearance.
        """
        assert all(children[0].shape == c.shape for c in children)
        fids = []
        for f in itertools.chain(*(c.fids for c in children)):
            if f not in fids:
                fids.append(f)
        shape = tuple(i.extent for i in fids) + children[0].shape
        return cls(numpy.empty(shape, dtype=dtype), tuple(fids))
@singledispatch
def _evaluate(expression, self):
    """Fallback evaluation rule, used when no handler is registered for
    the type of the expression.

    :arg expression: The expression to evaluate.
    :arg self: The callback handler (should provide bindings).
    :raises ValueError: always, naming the unhandled node type.
    """
    node_type = type(expression)
    raise ValueError("Unhandled node type %s" % node_type)
@_evaluate.register(gem.Zero)
def _evaluate_zero(e, self):
    """Evaluate a zero node to a float array of zeros of shape ``e.shape``."""
    zeros = numpy.zeros(shape=e.shape, dtype=float)
    return Result(zeros)
@_evaluate.register(gem.Failure)
def _evaluate_failure(e, self):
    """Evaluate a failure node to a float array of shape ``e.shape``
    filled with NaN."""
    nans = numpy.full(shape=e.shape, fill_value=numpy.nan, dtype=float)
    return Result(nans)
@_evaluate.register(gem.Constant)
def _evaluate_constant(e, self):
    """Wrap the array stored on the constant in a Result with no free
    indices."""
    array = e.array
    return Result(array)
@_evaluate.register(gem.Delta)
def _evaluate_delta(e, self):
    """Pass the delta through ``replace_delta`` and evaluate the single
    expression it returns."""
    (lowered,) = replace_delta((e,))
    return self(lowered)
@_evaluate.register(gem.Variable)
def _evaluate_variable(e, self):
    """Fetch the data bound to a variable from ``self.bindings``.

    :raises ValueError: if the variable has no binding, or if the shape
        of the bound data differs from ``e.shape``.
    """
    try:
        bound = self.bindings[e]
    except KeyError:
        raise ValueError("Binding for %s not found" % e)

    # The bound data must match the shape declared on the variable.
    shapes_agree = bound.shape == e.shape
    if not shapes_agree:
        message = "Binding for %s has wrong shape. %s, not %s." % (
            e, bound.shape, e.shape)
        raise ValueError(message)
    return Result(bound)
@_evaluate.register(gem.Power)
@_evaluate.register(gem.Division)
@_evaluate.register(gem.Product)
@_evaluate.register(gem.Sum)
def _evaluate_operator(e, self):
    """Evaluate a binary arithmetic node (sum, product, division, power).

    Both operands are broadcast to the combined free indices and the
    operator is applied to the whole arrays at once.
    """
    binary_ops = {
        gem.Sum: operator.add,
        gem.Product: operator.mul,
        gem.Division: operator.truediv,
        gem.Power: operator.pow,
    }
    apply_op = binary_ops[type(e)]
    lhs, rhs = tuple(self(child) for child in e.children)
    # ``Result.empty`` is used for its merged free indices (and its
    # shape check); the array it allocates is replaced below.
    out = Result.empty(lhs, rhs)
    free = out.fids
    left = lhs.broadcast(free)
    right = rhs.broadcast(free)
    out.arr = apply_op(left, right)
    return out
@_evaluate.register(gem.MathFunction)
def _evaluate_mathfunction(e, self):
    """Evaluate a math function node entry by entry.

    Only the function names in the table below are supported; any other
    ``e.name`` raises ``KeyError``.
    """
    operands = [self(child) for child in e.children]
    out = Result.empty(*operands)
    table = {
        "abs": abs,
        "log": numpy.log,
        "real": operator.attrgetter("real"),
        "imag": operator.attrgetter("imag"),
        "conj": operator.methodcaller("conjugate"),
    }
    func = table[e.name]
    free = out.fids
    for position in numpy.ndindex(out.tshape):
        # Each operand only sees the index entries for its own free indices.
        args = [operand[operand.filter(position, free)]
                for operand in operands]
        out[position] = func(*args)
    return out
@_evaluate.register(gem.MaxValue)
@_evaluate.register(gem.MinValue)
def _evaluate_minmaxvalue(e, self):
    """Evaluate a min/max node entry by entry using the builtin ``min``
    or ``max`` on the operands' values."""
    operands = [self(child) for child in e.children]
    out = Result.empty(*operands)
    pick = {gem.MaxValue: max,
            gem.MinValue: min}[type(e)]
    free = out.fids
    for position in numpy.ndindex(out.tshape):
        candidates = [operand[operand.filter(position, free)]
                      for operand in operands]
        out[position] = pick(*candidates)
    return out
@_evaluate.register(gem.Comparison)
def _evaluate_comparison(e, self):
    """Evaluate a comparison node entry by entry into a boolean array."""
    operands = [self(child) for child in e.children]
    comparators = {
        "==": operator.eq,
        "!=": operator.ne,
        "<": operator.lt,
        "<=": operator.le,
        ">": operator.gt,
        ">=": operator.ge,
    }
    compare = comparators[e.operator]
    out = Result.empty(*operands, dtype=bool)
    free = out.fids
    for position in numpy.ndindex(out.tshape):
        values = [operand[operand.filter(position, free)]
                  for operand in operands]
        out[position] = compare(*values)
    return out
@_evaluate.register(gem.LogicalNot)
def _evaluate_logicalnot(e, self):
    """Evaluate a logical negation entry by entry into a boolean array.

    The operand must evaluate to a boolean array.
    """
    val = self(e.children[0])
    assert val.arr.dtype == numpy.dtype("bool")
    # ``dtype`` must be passed by keyword: a positional ``bool`` would be
    # treated as another child and fail on ``bool.shape``.
    result = Result.empty(val, dtype=bool)
    for idx in numpy.ndindex(result.tshape):
        result[idx] = not val[val.filter(idx, result.fids)]
    return result
@_evaluate.register(gem.LogicalAnd)
def _evaluate_logicaland(e, self):
    """Evaluate a logical conjunction entry by entry into a boolean array.

    Both operands must evaluate to boolean arrays.
    """
    a, b = [self(o) for o in e.children]
    assert a.arr.dtype == numpy.dtype("bool")
    assert b.arr.dtype == numpy.dtype("bool")
    # ``dtype`` must be passed by keyword: a positional ``bool`` would be
    # treated as a third child and fail on ``bool.shape``.
    result = Result.empty(a, b, dtype=bool)
    for idx in numpy.ndindex(result.tshape):
        result[idx] = (a[a.filter(idx, result.fids)]
                       and b[b.filter(idx, result.fids)])
    return result
@_evaluate.register(gem.LogicalOr)
def _evaluate_logicalor(e, self):
    """Evaluate a logical disjunction entry by entry into a boolean array.

    Both operands must evaluate to boolean arrays.
    """
    lhs, rhs = tuple(self(child) for child in e.children)
    bool_dtype = numpy.dtype("bool")
    assert lhs.arr.dtype == bool_dtype
    assert rhs.arr.dtype == bool_dtype
    out = Result.empty(lhs, rhs, dtype=bool)
    free = out.fids
    for position in numpy.ndindex(out.tshape):
        # The right operand is only read when the left entry is false.
        out[position] = (lhs[lhs.filter(position, free)]
                         or rhs[rhs.filter(position, free)])
    return out
@_evaluate.register(gem.Conditional)
def _evaluate_conditional(e, self):
    """Evaluate a conditional entry by entry, taking each value from the
    "then" or the "else" operand according to the boolean condition."""
    cond, then, else_ = tuple(self(child) for child in e.children)
    assert cond.arr.dtype == numpy.dtype("bool")
    out = Result.empty(cond, then, else_)
    free = out.fids
    for position in numpy.ndindex(out.tshape):
        chosen = then if cond[cond.filter(position, free)] else else_
        out[position] = chosen[chosen.filter(position, free)]
    return out
@_evaluate.register(gem.Indexed)
def _evaluate_indexed(e, self):
    """Indexing maps shape to free indices.

    Axes of the child that are already free indices are kept whole.
    Each shape axis is then either kept whole (free index), or reduced
    to a single entry (fixed index, or variable index whose expression
    is evaluated first).
    """
    val = self(e.children[0])
    fids = tuple(i for i in e.multiindex if isinstance(i, gem.Index))

    # First pick up all the existing free indices
    idx = [slice(None)] * len(val.fids)
    # Now grab the shape axes
    for i in e.multiindex:
        if isinstance(i, gem.Index):
            # Free index, want entire extent
            idx.append(slice(None))
        elif isinstance(i, gem.VariableIndex):
            # Variable index, evaluate inner expression.  The handler
            # returns a single Result, which must not be tuple-unpacked.
            result = self(i.expression)
            # The index expression must be a scalar with no free indices
            assert not result.tshape
            idx.append(result[()])
        else:
            # Fixed index, just pick that value
            idx.append(i)
    assert len(idx) == len(val.tshape)
    return Result(val[idx], val.fids + fids)
@_evaluate.register(gem.ComponentTensor)
def _evaluate_componenttensor(e, self):
    """Component tensors map free indices to shape.

    The child's axes are reordered to: free indices not named in
    ``e.multiindex``, then the named ones (in multiindex order), then the
    child's existing shape axes.
    """
    val = self(e.children[0])
    bound = e.multiindex

    # Free indices that stay free, together with their axis numbers
    unbound = [(axis, fid) for axis, fid in enumerate(val.fids)
               if fid not in bound]
    axes = [axis for axis, _ in unbound]
    remaining = tuple(fid for _, fid in unbound)

    # Bound free indices become the leading shape axes
    axes += [val.fids.index(i) for i in bound]
    # Existing shape axes stay at the end
    axes += range(len(val.fshape), len(val.tshape))

    return Result(numpy.transpose(val.arr, axes=axes), remaining)
@_evaluate.register(gem.IndexSum)
def _evaluate_indexsum(e, self):
    """Sum the child over the axes of the indices in ``e.multiindex``;
    the other free indices are kept in their original order."""
    val = self(e.children[0])
    summed = e.multiindex
    axes = tuple(val.fids.index(fi) for fi in summed)
    kept = tuple(fi for fi in val.fids if fi not in summed)
    total = val.arr.sum(axis=axes)
    return Result(total, kept)
@_evaluate.register(gem.ListTensor)
def _evaluate_listtensor(e, self):
    """Assemble the evaluated children into one array of shape
    ``free index shape + e.shape``."""
    entries = [self(child) for child in e.children]
    template = Result.empty(*entries)
    free = template.fids
    free_shape = template.fshape

    # Expand every child to the full free-index shape
    pieces = [numpy.broadcast_to(entry.broadcast(free), free_shape)
              for entry in entries]
    # Stack along a new leading axis, move it last, then unfold it
    # into the tensor shape.
    stacked = numpy.asarray(pieces)
    moved = numpy.moveaxis(stacked, 0, -1)
    tensor = moved.reshape(free_shape + e.shape)
    return Result(tensor, free)
@_evaluate.register(gem.Concatenate)
def _evaluate_concatenate(e, self):
    """Flatten the shape of every child and join the children along one
    trailing axis, keeping the combined free indices in front."""
    operands = [self(child) for child in e.children]
    # Unique free indices, in order of first appearance
    all_fids = itertools.chain.from_iterable(o.fids for o in operands)
    fids = tuple(OrderedDict.fromkeys(all_fids))
    fshape = tuple(i.extent for i in fids)

    def flattened(operand):
        # Temporary with the full free-index shape; the assignment
        # broadcasts over free indices the operand does not have.
        full = numpy.empty(fshape + operand.shape)
        full[:] = operand.broadcast(fids)
        # Collapse the shape axes into a single trailing axis
        return full.reshape(fshape + (-1,))

    pieces = [flattened(o) for o in operands]
    return Result(numpy.concatenate(pieces, axis=-1), fids)
def evaluate(expressions, bindings=None):
    """Evaluate some GEM expressions given variable bindings.

    :arg expressions: A single GEM expression, or iterable of
        expressions to evaluate.
    :kwarg bindings: An optional dict mapping GEM :class:`gem.Variable`
        nodes to data.
    :returns: a list of the evaluated expressions.
    """
    try:
        exprs = tuple(expressions)
    except TypeError:
        # Not iterable: treat it as a single expression
        exprs = (expressions, )
    # One memoizer is shared by all expressions, so they use the same
    # bindings and the same cache.
    mapper = node.Memoizer(_evaluate)
    mapper.bindings = bindings if bindings is not None else {}
    return [mapper(expr) for expr in exprs]