"""GEM is the intermediate language of TSFC for describing
tensor-valued mathematical expressions and tensor operations.
It is similar to Einstein's notation.
Its design was heavily inspired by UFL, with some major differences:
- GEM has got nothing FEM-specific.
- In UFL free indices are just unrolled shape, thus UFL is very
restrictive about operations on expressions with different sets of
free indices. GEM is much more relaxed about free indices.
Similarly to UFL, all GEM nodes have 'shape' and 'free_indices'
attributes / properties. Unlike UFL, however, index extents live on
the Index objects in GEM, not on all the nodes that have those free
indices.
"""
from abc import ABCMeta
from itertools import chain
from operator import attrgetter
from numbers import Integral, Number
import numpy
from numpy import asarray
from gem.node import Node as NodeBase, traversal
from FIAT.orientation_utils import Orientation as FIATOrientation
# Names exported by ``from <this module> import *``.
# NOTE(review): 'extract_type' is listed below but no definition of it is
# visible in this part of the file -- confirm it is defined elsewhere.
__all__ = ['Node', 'Identity', 'Literal', 'Zero', 'Failure',
'Variable', 'Sum', 'Product', 'Division', 'FloorDiv', 'Remainder', 'Power',
'MathFunction', 'MinValue', 'MaxValue', 'Comparison',
'LogicalNot', 'LogicalAnd', 'LogicalOr', 'Conditional',
'Index', 'VariableIndex', 'Indexed', 'ComponentTensor',
'IndexSum', 'ListTensor', 'Concatenate', 'Delta', 'OrientationVariableIndex',
'index_sum', 'partial_indexed', 'reshape', 'view',
'indices', 'as_gem', 'FlexiblyIndexed',
'Inverse', 'Solve', 'extract_type', 'uint_type']
# dtype that VariableIndex, FloorDiv and Remainder require of their
# operands (numpy's ``uintc``).
uint_type = numpy.dtype(numpy.uintc)
class NodeMeta(type):
    """Metaclass of GEM nodes.

    Calling a node class goes through :meth:`__call__`, which fills in
    ``free_indices`` and ``dtype`` on the freshly built object whenever
    the constructor left them unset.
    """

    def __call__(self, *args, **kwargs):
        # Ordinary construction first.
        node = super().__call__(*args, **kwargs)

        # Default free indices: deduplicated union over the children.
        if not hasattr(node, 'free_indices'):
            gathered = chain.from_iterable(child.free_indices
                                           for child in node.children)
            node.free_indices = unique(gathered)

        # Default dtype: derived from the children.
        if not hasattr(node, 'dtype'):
            node.dtype = node.inherit_dtype_from_children(node.children)

        return node
[docs]
class Node(NodeBase, metaclass=NodeMeta):
    """Abstract GEM node class.

    Supplies indexing (``node[i, j]``), the arithmetic operators
    ``+ - * / // % @`` with their reflected forms, the transpose
    property ``T`` and the rule for deriving a node's dtype from the
    dtypes of its children.
    """

    __slots__ = ('free_indices', 'dtype')

    def is_equal(self, other):
        """Common subexpression eliminating equality predicate.

        When two (sub)expressions are equal, the children of one
        object are reassigned to the children of the other, so some
        duplicated subexpressions are eliminated.
        """
        result = NodeBase.is_equal(self, other)
        if result:
            self.children = other.children
        return result

    def __getitem__(self, indices):
        """Index this node; a single non-iterable index is wrapped in a tuple."""
        try:
            indices = tuple(indices)
        except TypeError:
            indices = (indices, )
        return Indexed(self, indices)

    def __add__(self, other):
        return componentwise(Sum, self, as_gem(other))

    def __radd__(self, other):
        return as_gem(other).__add__(self)

    def __sub__(self, other):
        # a - b is expressed as a + (-1) * b.
        return componentwise(
            Sum, self,
            componentwise(Product, Literal(-1), as_gem(other)))

    def __rsub__(self, other):
        return as_gem(other).__sub__(self)

    def __mul__(self, other):
        return componentwise(Product, self, as_gem(other))

    def __rmul__(self, other):
        return as_gem(other).__mul__(self)

    def __matmul__(self, other):
        """Contract the last axis of ``self`` with the first axis of ``other``.

        Two scalar operands are simply multiplied.

        :raises ValueError: if exactly one operand is scalar, or if the
            contracted extents differ.
        """
        other = as_gem(other)
        if not self.shape and not other.shape:
            return Product(self, other)
        elif not (self.shape and other.shape):
            raise ValueError("Both objects must have shape for matmul")
        elif self.shape[-1] != other.shape[0]:
            raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul")
        # k is the contracted index; i and j label the remaining axes.
        *i, k = indices(len(self.shape))
        _, *j = indices(len(other.shape))
        expr = Product(Indexed(self, tuple(i) + (k, )),
                       Indexed(other, (k, ) + tuple(j)))
        return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j))

    def __rmatmul__(self, other):
        return as_gem(other).__matmul__(self)

    @property
    def T(self):
        """This node with the order of its axes reversed."""
        i = indices(len(self.shape))
        return ComponentTensor(Indexed(self, i), tuple(reversed(i)))

    def __truediv__(self, other):
        other = as_gem(other)
        if other.shape:
            raise ValueError("Denominator must be scalar")
        return componentwise(Division, self, other)

    def __rtruediv__(self, other):
        return as_gem(other).__truediv__(self)

    def __floordiv__(self, other):
        other = as_gem_uint(other)
        if other.shape:
            raise ValueError("Denominator must be scalar")
        return componentwise(FloorDiv, self, other)

    def __rfloordiv__(self, other):
        return as_gem_uint(other).__floordiv__(self)

    def __mod__(self, other):
        other = as_gem_uint(other)
        if other.shape:
            raise ValueError("Denominator must be scalar")
        return componentwise(Remainder, self, other)

    def __rmod__(self, other):
        return as_gem_uint(other).__mod__(self)

    @staticmethod
    def inherit_dtype_from_children(children):
        """Derive a dtype from the dtypes of ``children``.

        :arg children: sized collection of objects with a ``dtype``
            attribute.
        :returns: ``None`` if there are no children or any child has
            dtype ``None``; otherwise ``numpy.result_type`` of the
            children's dtypes.
        """
        # numpy.result_type() needs at least one argument, so an empty
        # collection is treated like "dtype not determined".
        if len(children) == 0 or any(c.dtype is None for c in children):
            # Set dtype = None will let _assign_dtype()
            # assign the default dtype for this node later.
            return None
        return numpy.result_type(*(c.dtype for c in children))
class Terminal(Node):
    """Abstract base class for terminal GEM nodes."""

    __slots__ = ('_dtype',)

    # Terminal nodes have no children.
    children = ()

    # Terminals use the equality predicate of the base node class directly.
    is_equal = NodeBase.is_equal

    @property
    def dtype(self):
        """Data type of the node.

        We only need to set dtype (or _dtype) on terminal nodes, and
        other nodes inherit dtype from their children.

        Currently dtype is significant only for nodes under index DAGs
        (DAGs underneath `VariableIndex`s representing indices), and
        `VariableIndex` checks if the dtype of the node that it wraps is
        of uint_type. _assign_dtype() will then assign uint_type to those nodes.

        dtype can be `None` otherwise, and _assign_dtype() will assign
        the default dtype to those nodes.

        :raises AttributeError: if ``_dtype`` has not been set.
        """
        if not hasattr(self, '_dtype'):
            raise AttributeError(f"Must set _dtype on terminal node, {type(self)}")
        return self._dtype
class Scalar(Node):
    """Abstract class for scalar-valued GEM nodes.

    Fixes ``shape`` to the empty tuple for every subclass.
    """
    __slots__ = ()
    # Scalars have no shape.
    shape = ()
[docs]
class Failure(Terminal):
    """Abstract class for failure GEM nodes.

    Stores a shape together with an exception object; its dtype is
    always ``None``.
    """

    __slots__ = ('shape', 'exception')
    __front__ = ('shape', 'exception')

    def __init__(self, shape, exception):
        self._dtype = None
        self.exception = exception
        self.shape = shape
class Constant(Terminal):
    """Abstract base class for constant types.

    Convention:

    - array: numpy array of values
    - value: float or complex value (scalars only)
    """
[docs]
class Zero(Constant):
    """Symbolic zero tensor.

    :arg shape: shape of the tensor (``()`` for a scalar).
    :arg dtype: dtype of the node, or ``None``.
    """

    __slots__ = ('shape',)
    __front__ = ('shape',)
    __back__ = ('dtype',)

    def __init__(self, shape=(), dtype=None):
        self.shape = shape
        self._dtype = dtype

    @property
    def value(self):
        """The scalar zero as a Python number (scalar nodes only).

        The zero is created with this node's dtype, or with ``float``
        when the dtype is ``None``.
        """
        assert not self.shape
        # Test for None explicitly: the truthiness of a numpy dtype object
        # is not a reliable "is it set?" check.
        dtype = float if self.dtype is None else self.dtype
        return numpy.array(0, dtype=dtype).item()
[docs]
class Identity(Constant):
    """Identity matrix of size ``dim`` x ``dim``."""

    __slots__ = ('dim',)
    __front__ = ('dim',)
    __back__ = ('dtype',)

    def __init__(self, dim, dtype=None):
        self._dtype = dtype
        self.dim = dim

    @property
    def shape(self):
        """Shape of the matrix: ``(dim, dim)``."""
        n = self.dim
        return (n, n)

    @property
    def array(self):
        """The identity matrix as a numpy array of this node's dtype."""
        return numpy.eye(self.dim, dtype=self.dtype)
[docs]
class Literal(Constant):
    """Tensor-valued constant.

    :arg array: anything ``numpy.asarray`` accepts.
    :arg dtype: target dtype. If ``None``, the values are stored as
        float when that is a safe cast and as complex otherwise.
    """

    __slots__ = ('array',)
    __front__ = ('array',)
    __back__ = ('dtype',)

    def __new__(cls, array, dtype=None):
        # The arguments are consumed by __init__; nothing to do with
        # them at allocation time.
        return super(Literal, cls).__new__(cls)

    def __init__(self, array, dtype=None):
        array = asarray(array)
        if dtype is None:
            # Assume float or complex.
            try:
                self.array = array.astype(float, casting="safe")
            except TypeError:
                self.array = array.astype(complex)
        else:
            # Can be int, etc.
            self.array = array.astype(dtype)
        # The node's dtype is whatever the stored array ended up with.
        self._dtype = self.array.dtype

    def is_equal(self, other):
        """Two literals are equal if they have the same type, the same
        shape and element-wise equal values."""
        if type(self) is not type(other):
            return False
        if self.shape != other.shape:
            return False
        return tuple(self.array.flat) == tuple(other.array.flat)

    def get_hash(self):
        """Hash over the type, the shape and the flattened values."""
        return hash((type(self), self.shape, tuple(self.array.flat)))

    @property
    def value(self):
        """The single value of a scalar literal, as a numpy scalar of
        the array's dtype."""
        assert self.shape == ()
        return self.array.dtype.type(self.array)

    @property
    def shape(self):
        """Shape of the stored array."""
        return self.array.shape
[docs]
class Variable(Terminal):
    """Symbolic variable tensor, identified by a name and a shape."""

    __slots__ = ('name', 'shape')
    __front__ = ('name', 'shape')
    __back__ = ('dtype',)

    def __init__(self, name, shape, dtype=None):
        self._dtype = dtype
        self.shape = shape
        self.name = name
[docs]
class Sum(Scalar):
    """Sum of two scalar expressions, with constant folding."""

    __slots__ = ('children',)

    def __new__(cls, a, b):
        assert not a.shape
        assert not b.shape

        # Adding a symbolic zero is a no-op.
        if isinstance(a, Zero):
            return b
        if isinstance(b, Zero):
            return a

        # Two constants fold into a single literal.
        if isinstance(a, Constant) and isinstance(b, Constant):
            total = a.value + b.value
            return Literal(total, dtype=Node.inherit_dtype_from_children([a, b]))

        node = super().__new__(cls)
        node.children = a, b
        return node
[docs]
class Product(Scalar):
    """Product of two scalar expressions, with constant folding."""

    __slots__ = ('children',)

    def __new__(cls, a, b):
        assert not a.shape
        assert not b.shape

        # Anything times a symbolic zero is zero.
        if isinstance(a, Zero) or isinstance(b, Zero):
            return Zero()

        # Multiplication by one returns the other factor.
        if a == one:
            return b
        if b == one:
            return a

        # Two constants fold into a single literal.
        if isinstance(a, Constant) and isinstance(b, Constant):
            folded = a.value * b.value
            return Literal(folded, dtype=Node.inherit_dtype_from_children([a, b]))

        node = super().__new__(cls)
        node.children = a, b
        return node
[docs]
class Division(Scalar):
    """Quotient of two scalar expressions, with constant folding.

    :raises ValueError: if the denominator is a symbolic zero.
    """

    __slots__ = ('children',)

    def __new__(cls, a, b):
        assert not a.shape
        assert not b.shape

        if isinstance(b, Zero):
            raise ValueError("division by zero")

        # Zero numerator, or division by one.
        if isinstance(a, Zero):
            return Zero()
        if b == one:
            return a

        # Two constants fold into a single literal.
        if isinstance(a, Constant) and isinstance(b, Constant):
            quotient = a.value / b.value
            return Literal(quotient, dtype=Node.inherit_dtype_from_children([a, b]))

        node = super().__new__(cls)
        node.children = a, b
        return node
[docs]
class FloorDiv(Scalar):
    """Floor division of two scalar expressions of ``uint_type``.

    :raises ValueError: if the dtype inherited from the operands is not
        ``uint_type``, or if the denominator is a symbolic zero.
    """

    __slots__ = ('children',)

    def __new__(cls, a, b):
        assert not a.shape
        assert not b.shape
        dtype = Node.inherit_dtype_from_children([a, b])
        if dtype != uint_type:
            raise ValueError(f"dtype ({dtype}) != uint_type ({uint_type})")
        # Constant folding
        if isinstance(b, Zero):
            raise ValueError("division by zero")
        if isinstance(a, Zero):
            return Zero(dtype=dtype)
        # Division by the constant one returns the numerator unchanged.
        if isinstance(b, Constant) and b.value == 1:
            return a
        if isinstance(a, Constant) and isinstance(b, Constant):
            return Literal(a.value // b.value, dtype=dtype)
        self = super(FloorDiv, cls).__new__(cls)
        self.children = a, b
        return self
[docs]
class Remainder(Scalar):
    """Remainder of the division of two scalar expressions of ``uint_type``.

    :raises ValueError: if the dtype inherited from the operands is not
        ``uint_type``, or if the denominator is a symbolic zero.
    """

    __slots__ = ('children',)

    def __new__(cls, a, b):
        assert not a.shape
        assert not b.shape

        dtype = Node.inherit_dtype_from_children([a, b])
        if dtype != uint_type:
            raise ValueError(f"dtype ({dtype}) != uint_type ({uint_type})")

        if isinstance(b, Zero):
            raise ValueError("division by zero")

        # A zero numerator and a denominator of one both give zero.
        if isinstance(a, Zero):
            return Zero(dtype=dtype)
        b_is_constant = isinstance(b, Constant)
        if b_is_constant and b.value == 1:
            return Zero(dtype=dtype)

        # Two constants fold into a single literal.
        if isinstance(a, Constant) and b_is_constant:
            return Literal(a.value % b.value, dtype=dtype)

        node = super().__new__(cls)
        node.children = a, b
        return node
[docs]
class Power(Scalar):
    """``base`` raised to ``exponent`` (both scalar), with constant folding.

    :raises ValueError: if both base and exponent are symbolic zeros.
    """

    __slots__ = ('children',)

    def __new__(cls, base, exponent):
        assert not base.shape
        assert not exponent.shape
        dtype = Node.inherit_dtype_from_children([base, exponent])

        zero_base = isinstance(base, Zero)
        zero_exponent = isinstance(exponent, Zero)

        if zero_base:
            if zero_exponent:
                raise ValueError("cannot solve 0^0")
            # 0 ** x folds to zero.
            return Zero(dtype=dtype)
        if zero_exponent:
            # x ** 0 folds to one.
            return Literal(1, dtype=dtype)
        if isinstance(base, Constant) and isinstance(exponent, Constant):
            return Literal(base.value ** exponent.value, dtype=dtype)

        node = super().__new__(cls)
        node.children = base, exponent
        return node
[docs]
class MathFunction(Scalar):
    """Named mathematical function applied to scalar arguments."""

    __slots__ = ('name', 'children')
    __front__ = ('name',)

    def __new__(cls, name, *args):
        assert isinstance(name, str)
        assert all(arg.shape == () for arg in args)

        # conj, real and imag of a symbolic zero are that same zero.
        if name in {'conj', 'real', 'imag'}:
            (operand,) = args
            if isinstance(operand, Zero):
                return operand

        node = super().__new__(cls)
        node.name = name
        node.children = args
        return node
[docs]
class MinValue(Scalar):
    """Node with two scalar operands, ``a`` and ``b``."""

    __slots__ = ('children',)

    def __init__(self, a, b):
        for operand in (a, b):
            assert not operand.shape
        self.children = (a, b)
[docs]
class MaxValue(Scalar):
    """Node with two scalar operands, ``a`` and ``b``."""

    __slots__ = ('children',)

    def __init__(self, a, b):
        for operand in (a, b):
            assert not operand.shape
        self.children = (a, b)
[docs]
class Comparison(Scalar):
    """Comparison of two scalar expressions.

    :arg op: one of ``">"``, ``">="``, ``"=="``, ``"!="``, ``"<"``, ``"<="``.
    :raises ValueError: for any other operator.
    """

    __slots__ = ('operator', 'children')
    __front__ = ('operator',)

    def __init__(self, op, a, b):
        assert not a.shape
        assert not b.shape

        allowed = (">", ">=", "==", "!=", "<", "<=")
        if op not in allowed:
            raise ValueError("invalid operator")

        self.operator = op
        self.children = (a, b)
        # Do not inherit dtype from children.
        self.dtype = None
[docs]
class LogicalNot(Scalar):
    """Node with a single scalar operand."""

    __slots__ = ('children',)

    def __init__(self, expression):
        assert not expression.shape
        self.children = (expression,)
[docs]
class LogicalAnd(Scalar):
    """Node with two scalar operands, ``a`` and ``b``."""

    __slots__ = ('children',)

    def __init__(self, a, b):
        for operand in (a, b):
            assert not operand.shape
        self.children = (a, b)
[docs]
class LogicalOr(Scalar):
    """Node with two scalar operands, ``a`` and ``b``."""

    __slots__ = ('children',)

    def __init__(self, a, b):
        for operand in (a, b):
            assert not operand.shape
        self.children = (a, b)
[docs]
class Conditional(Node):
    """Selects between two scalar expressions based on a scalar condition.

    The dtype is inherited from the two branches only, not from the
    condition.
    """

    __slots__ = ('children', 'shape')

    def __new__(cls, condition, then, else_):
        assert not condition.shape
        assert then.shape == else_.shape == ()

        # If both branches are the same, just return one of them. In
        # particular, this will help constant-fold zeros.
        if then == else_:
            return then

        node = super().__new__(cls)
        node.children = (condition, then, else_)
        node.shape = then.shape
        node.dtype = Node.inherit_dtype_from_children([then, else_])
        return node
class IndexBase(metaclass=ABCMeta):
    """Abstract base class for indices."""


# Plain integers count as indices too (virtual subclass registration).
IndexBase.register(int)
[docs]
class Index(IndexBase):
    """Free index.

    :arg name: optional name used by ``str`` and ``repr``.
    :arg extent: optional extent of the index.
    """

    # Not true object count, just for naming purposes
    _count = 0

    __slots__ = ('name', 'extent', 'count')

    def __init__(self, name=None, extent=None):
        self.name = name
        Index._count += 1
        self.count = Index._count
        self.extent = extent

    def set_extent(self, value):
        """Set the extent, or check it against an extent set earlier.

        :raises ValueError: if a different extent was already set.
        """
        if self.extent is None:
            self.extent = value
            return
        if self.extent != value:
            raise ValueError("Inconsistent index extents!")

    def __str__(self):
        return "i_%d" % self.count if self.name is None else self.name

    def __repr__(self):
        label = self.count if self.name is None else self.name
        return "Index(%r)" % label

    def __lt__(self, other):
        # Allow sorting of free indices in Python 3
        return id(self) < id(other)

    def __getstate__(self):
        return (self.name, self.extent, self.count)

    def __setstate__(self, state):
        name, extent, count = state
        self.name, self.extent, self.count = name, extent, count
[docs]
class VariableIndex(IndexBase):
    """An index that is constant during a single execution of the
    kernel, but whose value is not known at compile time.

    :arg expression: scalar GEM node of dtype ``uint_type``.
    :raises ValueError: if the expression's dtype is not ``uint_type``.
    """

    __slots__ = ('expression',)

    def __init__(self, expression):
        assert isinstance(expression, Node)
        assert not expression.shape
        if expression.dtype != uint_type:
            raise ValueError(f"expression.dtype ({expression.dtype}) != uint_type ({uint_type})")
        self.expression = expression

    def __eq__(self, other):
        if self is other:
            return True
        # Same exact type and equal wrapped expressions.
        return type(self) is type(other) and self.expression == other.expression

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        key = (type(self), self.expression)
        return hash(key)

    def __str__(self):
        return str(self.expression)

    def __repr__(self):
        return f"{type(self)!r}({self.expression!r})"

    def __reduce__(self):
        # Rebuild by calling the class with the wrapped expression.
        return (type(self), (self.expression,))
[docs]
class Indexed(Scalar):
    """Scalar obtained by indexing every axis of a shaped expression.

    Construction may return something other than an ``Indexed`` node:
    the aggregate itself for an empty multiindex, a ``Zero`` for a zero
    aggregate, or the selected entry when all indices are fixed
    integers and the aggregate is a constant or a ``ListTensor``.
    """

    __slots__ = ('children', 'multiindex', 'indirect_children')
    __back__ = ('multiindex',)

    def __new__(cls, aggregate, multiindex):
        # Accept numpy or any integer, but cast to int.
        multiindex = tuple(int(entry) if isinstance(entry, Integral) else entry
                           for entry in multiindex)

        # Set index extents from shape
        assert len(aggregate.shape) == len(multiindex)
        for entry, extent in zip(multiindex, aggregate.shape):
            assert isinstance(entry, IndexBase)
            if isinstance(entry, Index):
                entry.set_extent(extent)
            elif isinstance(entry, int) and not (0 <= entry < extent):
                raise IndexError("Invalid literal index")

        # Empty multiindex
        if not multiindex:
            return aggregate

        # Zero folding
        if isinstance(aggregate, Zero):
            return Zero(dtype=aggregate.dtype)

        # All indices fixed
        if all(isinstance(entry, int) for entry in multiindex):
            if isinstance(aggregate, Constant):
                picked = aggregate.array[multiindex]
                return Literal(picked, dtype=aggregate.dtype)
            elif isinstance(aggregate, ListTensor):
                return aggregate.array[multiindex]

        node = super().__new__(cls)
        node.children = (aggregate,)
        node.multiindex = multiindex
        node.indirect_children = tuple(
            entry.expression for entry in node.multiindex
            if isinstance(entry, VariableIndex))

        # Free indices: those of the aggregate plus those introduced by
        # the multiindex (directly, or through variable index expressions).
        introduced = []
        for entry in multiindex:
            if isinstance(entry, Index):
                introduced.append(entry)
            elif isinstance(entry, VariableIndex):
                introduced.extend(entry.expression.free_indices)
        node.free_indices = unique(aggregate.free_indices + tuple(introduced))
        return node

    def index_ordering(self):
        """Running indices in the order of indexing in this node."""
        ordering = []
        for entry in self.multiindex:
            if isinstance(entry, Index):
                ordering.append(entry)
            elif isinstance(entry, VariableIndex):
                ordering.extend(entry.expression.free_indices)
        return tuple(ordering)
[docs]
class FlexiblyIndexed(Scalar):
    """Flexible indexing of :py:class:`Variable`s to implement views and
    reshapes (splitting dimensions only)."""

    __slots__ = ('children', 'dim2idxs', 'indirect_children')
    __back__ = ('dim2idxs',)

    def __init__(self, variable, dim2idxs):
        """Construct a flexibly indexed node.

        Parameters
        ----------
        variable : Node
            `Node` that has a shape.
        dim2idxs : tuple
            Tuple of (offset, ((index, stride), (...), ...)) mapping indices,
            where offset is {Node, int}, index is {Index, VariableIndex, int}, and
            stride is {Node, int}.

        For example, if ``variable`` is rank two, and ``dim2idxs`` is

            ((1, ((i, 12), (j, 4), (k, 1))), (0, ()))

        then this corresponds to the indexing:

            variable[1 + i*12 + j*4 + k][0]

        Raises
        ------
        NotImplementedError
            If an integer index is combined with a non-integral offset.
        ValueError
            If an index has an unexpected type, or if the offset plus the
            largest reachable position is an integer that is not smaller
            than the dimension.
        """
        assert variable.shape
        assert len(variable.shape) == len(dim2idxs)

        normalised = []
        collected = []
        for dim, (offset, idxs) in zip(variable.shape, dim2idxs):
            # Integer indices are folded into the offset; `last` tracks the
            # largest position reachable through Index entries.
            offset_ = offset
            kept = []
            last = 0

            if isinstance(offset, Node):
                collected.extend(offset.free_indices)

            for index, stride in idxs:
                if isinstance(index, Index):
                    assert index.extent is not None
                    collected.append(index)
                    kept.append((index, stride))
                    last += (index.extent - 1) * stride
                elif isinstance(index, VariableIndex):
                    base_indices = index.expression.free_indices
                    assert all(base_index.extent is not None for base_index in base_indices)
                    collected.extend(base_indices)
                    kept.append((index, stride))
                    # last += (unknown_extent - 1) * stride
                elif isinstance(index, int):
                    # TODO: Attach dtype to each Node.
                    # Here, we should simply be able to do:
                    # >>> offset_ += index * stride
                    # but "+" and "*" are not currently correctly overloaded
                    # for indices (integers); they assume floats.
                    if not isinstance(offset, Integral):
                        raise NotImplementedError(f"Found non-Integral offset : {offset}")
                    factor = stride.value if isinstance(stride, Constant) else stride
                    offset_ += index * factor
                else:
                    raise ValueError("Unexpected index type for flexible indexing")

                if isinstance(stride, Node):
                    collected.extend(stride.free_indices)

            if dim is not None and isinstance(offset_ + last, Integral) and offset_ + last >= dim:
                raise ValueError("Offset {0} and indices {1} exceed dimension {2}".format(offset, idxs, dim))

            normalised.append((offset_, tuple(kept)))

        self.children = (variable,)
        self.dim2idxs = tuple(normalised)
        self.free_indices = unique(collected)

        # Nodes that occur as offsets, strides or variable index expressions.
        indirect = []
        for offset, idxs in self.dim2idxs:
            if isinstance(offset, Node):
                indirect.append(offset)
            for idx, stride in idxs:
                if isinstance(idx, VariableIndex):
                    indirect.append(idx.expression)
                if isinstance(stride, Node):
                    indirect.append(stride)
        self.indirect_children = tuple(indirect)

    def index_ordering(self):
        """Running indices in the order of indexing in this node."""
        ordering = []
        for offset, idxs in self.dim2idxs:
            if isinstance(offset, Node):
                ordering.extend(offset.free_indices)
            for index, stride in idxs:
                if isinstance(index, Index):
                    ordering.append(index)
                elif isinstance(index, VariableIndex):
                    ordering.extend(index.expression.free_indices)
                if isinstance(stride, Node):
                    ordering.extend(stride.free_indices)
        return tuple(ordering)
[docs]
class ComponentTensor(Node):
    """Turns free indices of a scalar expression into tensor shape.

    The shape is given by the extents of ``multiindex``; those indices
    are removed from the free indices of the result.
    """

    __slots__ = ('children', 'multiindex', 'shape')
    __back__ = ('multiindex',)

    def __new__(cls, expression, multiindex):
        assert not expression.shape

        # Empty multiindex
        if not multiindex:
            return expression

        # Collect shape
        shape = tuple(idx.extent for idx in multiindex)
        assert all(extent >= 0 for extent in shape)

        # Zero folding
        if isinstance(expression, Zero):
            return Zero(shape, dtype=expression.dtype)

        node = super().__new__(cls)
        node.children = (expression,)
        node.multiindex = multiindex
        node.shape = shape

        # Collect free indices
        available = set(expression.free_indices)
        bound = set(multiindex)
        assert bound <= available
        node.free_indices = unique(available - bound)
        return node
[docs]
class IndexSum(Scalar):
    """Sum of a scalar expression over the indices in ``multiindex``.

    Indices with extent at most one are not summed over: they are fixed
    to zero instead. With no indices left, the summand itself is returned.
    """

    __slots__ = ('children', 'multiindex')
    __back__ = ('multiindex',)

    def __new__(cls, summand, multiindex):
        # Sum zeros
        assert not summand.shape
        if isinstance(summand, Zero):
            return summand

        # Unroll singleton sums
        singletons = tuple(idx for idx in multiindex if idx.extent <= 1)
        if singletons:
            assert numpy.prod([idx.extent for idx in singletons]) == 1
            fixed_at_zero = (0,) * len(singletons)
            summand = Indexed(ComponentTensor(summand, singletons), fixed_at_zero)
            multiindex = tuple(idx for idx in multiindex
                               if idx not in singletons)

        # No indices case
        multiindex = tuple(multiindex)
        if not multiindex:
            return summand

        node = super().__new__(cls)
        node.children = (summand,)
        node.multiindex = multiindex

        # Collect shape and free indices
        available = set(summand.free_indices)
        summed = set(multiindex)
        assert summed <= available
        node.free_indices = unique(available - summed)
        return node
[docs]
class ListTensor(Node):
    """Tensor assembled from an array of GEM expressions.

    Children that themselves have shape are indexed component by
    component, so the stored array always holds scalar expressions.
    An array consisting only of constants folds into a ``Literal``.
    """

    __slots__ = ('array',)

    def __new__(cls, array):
        array = asarray(array)
        assert numpy.prod(array.shape)
        dtype = Node.inherit_dtype_from_children(tuple(array.flat))

        # Handle children with shape
        child_shape = array.flat[0].shape
        assert all(elem.shape == child_shape for elem in array.flat)

        if child_shape:
            # Destroy structure
            direct_array = numpy.empty(array.shape + child_shape, dtype=object)
            for alpha in numpy.ndindex(array.shape):
                for beta in numpy.ndindex(child_shape):
                    direct_array[alpha + beta] = Indexed(array[alpha], beta)
            array = direct_array

        # Constant folding
        if all(isinstance(elem, Constant) for elem in array.flat):
            return Literal(numpy.vectorize(attrgetter('value'))(array), dtype=dtype)

        self = super(ListTensor, cls).__new__(cls)
        self.array = array
        return self

    @property
    def children(self):
        """The entries of the array, flattened into a tuple."""
        return tuple(self.array.flat)

    @property
    def shape(self):
        """Shape of the stored array."""
        return self.array.shape

    def __reduce__(self):
        return type(self), (self.array,)

    def reconstruct(self, *args):
        """Build a new ``ListTensor`` of the same shape from ``args``."""
        return ListTensor(asarray(args).reshape(self.array.shape))

    def __repr__(self):
        return "ListTensor(%r)" % self.array.tolist()

    def is_equal(self, other):
        """Common subexpression eliminating equality predicate.

        On equality this node's array is replaced by the other node's
        array.
        """
        if type(self) is not type(other):
            return False
        # Different shapes can never be equal; checking first keeps the
        # element-wise comparison below from broadcasting (or failing)
        # on mismatched shapes.
        if self.shape != other.shape:
            return False
        if (self.array == other.array).all():
            self.array = other.array
            return True
        return False

    def get_hash(self):
        """Hash over the type, the shape and the children."""
        return hash((type(self), self.shape, self.children))
[docs]
class Concatenate(Node):
    """Flattens and concatenates GEM expressions by shape.

    Similar to what UFL MixedElement does to value shape. For
    example, if children have shapes (2, 2), (), and (3,) then the
    concatenated expression has shape (8,).
    """

    __slots__ = ('children',)

    def __new__(cls, *children):
        dtype = Node.inherit_dtype_from_children(children)

        # Concatenating only zeros gives a zero vector of the total size.
        if all(isinstance(child, Zero) for child in children):
            sizes = (numpy.prod(child.shape, dtype=int) for child in children)
            return Zero((int(sum(sizes)),), dtype=dtype)

        node = super().__new__(cls)
        node.children = children
        return node

    @property
    def shape(self):
        """One-dimensional shape: the total number of entries of all children."""
        total = sum(numpy.prod(child.shape, dtype=int) for child in self.children)
        return (int(total),)
[docs]
class Delta(Scalar, Terminal):
    """Kronecker delta of two indices.

    Equal indices fold to ``one``; two fixed integer indices fold to
    ``one`` or ``Zero``.

    :raises NotImplementedError: if either index is a ``VariableIndex``.
    """

    __slots__ = ('i', 'j')
    __front__ = ('i', 'j')
    __back__ = ('dtype',)

    def __new__(cls, i, j, dtype=None):
        assert isinstance(i, IndexBase)
        assert isinstance(j, IndexBase)

        # \delta_{i,i} = 1
        if i == j:
            return one

        # Fixed indices
        if isinstance(i, int) and isinstance(j, int):
            return one if i == j else Zero()

        node = super().__new__(cls)
        node.i = i
        node.j = j

        # Set up free indices
        running = []
        for idx in (i, j):
            if isinstance(idx, Index):
                running.append(idx)
            elif isinstance(idx, VariableIndex):
                raise NotImplementedError("Can not make Delta with VariableIndex")
        node.free_indices = tuple(unique(running))
        node._dtype = dtype
        return node
[docs]
class Inverse(Node):
    """The inverse of a square matrix.

    A 1x1 matrix is inverted directly as one divided by its single entry.
    """

    __slots__ = ('children', 'shape')

    def __new__(cls, tensor):
        rows_cols = tensor.shape
        assert len(rows_cols) == 2
        assert rows_cols[0] == rows_cols[1]

        # Invert 1x1 matrix
        if tensor.shape == (1, 1):
            multiindex = (Index(), Index())
            reciprocal = Division(one, Indexed(tensor, multiindex))
            return ComponentTensor(reciprocal, multiindex)

        node = super().__new__(cls)
        node.children = (tensor,)
        node.shape = tensor.shape
        return node
[docs]
class Solve(Node):
    """Solution of a square matrix equation with (potentially) multiple right hand sides.

    Represents the X obtained by solving AX = B.
    """

    __slots__ = ('children', 'shape')

    def __init__(self, A, B):
        # Shape requirements: A is square, B is shaped, leading extents agree.
        assert B.shape
        assert len(A.shape) == 2
        n_rows, n_cols = A.shape[0], A.shape[1]
        assert n_rows == n_cols
        assert n_rows == B.shape[0]

        self.children = (A, B)
        # Column extent of A followed by the trailing extents of B.
        self.shape = A.shape[1:] + B.shape[1:]
[docs]
class OrientationVariableIndex(VariableIndex, FIATOrientation):
    """VariableIndex representing a fiat orientation.

    Notes
    -----
    In the current implementation, we need to extract
    `VariableIndex.expression` as index arithmetic
    is not supported (indices are not `Node`).
    """

    def __floordiv__(self, other):
        if isinstance(other, VariableIndex):
            divisor = other.expression
        else:
            divisor = as_gem_uint(other)
        return type(self)(FloorDiv(self.expression, divisor))

    def __rfloordiv__(self, other):
        if isinstance(other, VariableIndex):
            dividend = other.expression
        else:
            dividend = as_gem_uint(other)
        return type(self)(FloorDiv(dividend, self.expression))

    def __mod__(self, other):
        if isinstance(other, VariableIndex):
            divisor = other.expression
        else:
            divisor = as_gem_uint(other)
        return type(self)(Remainder(self.expression, divisor))

    def __rmod__(self, other):
        if isinstance(other, VariableIndex):
            dividend = other.expression
        else:
            dividend = as_gem_uint(other)
        return type(self)(Remainder(dividend, self.expression))
def unique(indices):
    """Sorts free indices and eliminates duplicates.

    The ordering key is the object identity (``id``) of each index.

    :arg indices: iterable of indices
    :returns: sorted tuple of unique free indices
    """
    distinct = set(indices)
    ordered = sorted(distinct, key=id)
    return tuple(ordered)
[docs]
def index_sum(expression, indices):
    """Eliminates indices from the free indices of an expression by
    summing over them. Skips any index that is not a free index of
    the expression.

    :arg expression: GEM expression to sum
    :arg indices: iterable of candidate summation indices
    :returns: the result of ``IndexSum`` over the indices that are kept
    """
    kept = filter(lambda idx: idx in expression.free_indices, indices)
    return IndexSum(expression, tuple(kept))
[docs]
def partial_indexed(tensor, indices):
    """Generalised indexing into a tensor by eating shape off the front.

    The number of indices may be less than or equal to the rank of the tensor,
    so the result may have a non-empty shape.

    :arg tensor: tensor-valued GEM expression
    :arg indices: indices, at most as many as the rank of the tensor;
        any iterable is accepted
    :returns: a potentially tensor-valued expression
    :raises ValueError: if there are more indices than the tensor has axes
    """
    # Normalise once so that lists and iterators work like tuples below
    # (both len() and tuple concatenation are needed).
    indices = tuple(indices)
    if not indices:
        return tensor

    rank = len(tensor.shape)
    if len(indices) > rank:
        raise ValueError("More indices than rank!")
    if len(indices) == rank:
        return Indexed(tensor, indices)

    # Fewer indices than axes: fresh indices stand in for the remaining
    # axes and become the shape of the result.
    shape_indices = tuple(Index() for _ in range(rank - len(indices)))
    return ComponentTensor(
        Indexed(tensor, indices + shape_indices),
        shape_indices)
def strides_of(shape):
    """Calculate cumulative strides from per-dimension capacities.

    For example:

        [2, 3, 4] ==> [12, 4, 1]

    :arg shape: iterable of per-dimension capacities
    :returns: list of strides, one per dimension; the last is always 1
    """
    # The stride of an axis is the product of all capacities after it:
    # reverse the trailing capacities, accumulate, and reverse back.
    trailing = list(shape)[1:]
    reversed_tail = numpy.flipud(trailing)
    suffix_products = numpy.flipud(numpy.cumprod(reversed_tail))
    return list(suffix_products) + [1]
def decompose_variable_view(expression):
    """Extract information from a shaped node.

    Decompose ComponentTensor + FlexiblyIndexed.

    :arg expression: a ``Variable``, ``Inverse``, ``Solve`` or
        ``ComponentTensor`` node
    :returns: triple ``(variable, dim2idxs, indexes)``
    :raises ValueError: for any other kind of expression
    """
    if isinstance(expression, (Variable, Inverse, Solve)):
        # One fresh index per axis, each with offset 0 and stride 1.
        variable = expression
        indexes = tuple(Index(extent=extent) for extent in expression.shape)
        dim2idxs = tuple((0, ((index, 1),)) for index in indexes)
    elif isinstance(expression, ComponentTensor):
        inner = expression.children[0]
        indexes = expression.multiindex
        if isinstance(inner, FlexiblyIndexed):
            # Reuse the indexing already recorded on the inner node.
            variable = inner.children[0]
            dim2idxs = inner.dim2idxs
        else:
            variable = expression
            dim2idxs = tuple((0, ((index, 1),)) for index in indexes)
    else:
        raise ValueError("Cannot handle {} objects.".format(type(expression).__name__))

    return variable, dim2idxs, indexes
[docs]
def reshape(expression, *shapes):
    """Reshape a variable (splitting indices only).

    :arg expression: view of a :py:class:`Variable`
    :arg shapes: one shape tuple for each dimension of the variable.
    :raises ValueError: if a shape does not multiply up to the known
        extent of the dimension it splits.
    """
    variable, dim2idxs, indexes = decompose_variable_view(expression)
    assert len(indexes) == len(shapes)
    shape_of = dict(zip(indexes, shapes))

    new_dim2idxs = []
    # One list of replacement indices per original index, kept in the
    # order of `indexes`.
    replacements = [[] for _ in range(len(indexes))]
    for offset, idxs in dim2idxs:
        split_idxs = []
        for index, stride in idxs:
            assert isinstance(index, Index)
            dim = index.extent
            shape = shape_of[index]
            if dim is not None and numpy.prod(shape) != dim:
                raise ValueError("Shape {} does not match extent {}.".format(shape, dim))
            bucket = replacements[indexes.index(index)]
            for extent, inner_stride in zip(shape, strides_of(shape)):
                fresh = Index(extent=extent)
                split_idxs.append((fresh, inner_stride * stride))
                bucket.append(fresh)
        new_dim2idxs.append((offset, tuple(split_idxs)))

    expr = FlexiblyIndexed(variable, tuple(new_dim2idxs))
    return ComponentTensor(expr, tuple(chain.from_iterable(replacements)))
[docs]
def view(expression, *slices):
    """View a part of a shaped object.

    :arg expression: a node that has a shape
    :arg slices: one slice object for each dimension of the expression.
    :raises ValueError: if a slice has no stop and the extent is
        unknown, or if the stop exceeds the known extent.
    """
    variable, dim2idxs, indexes = decompose_variable_view(expression)
    assert len(indexes) == len(slices)
    slice_of = dict(zip(indexes, slices))

    new_dim2idxs = []
    replacements = [None] * len(slices)
    for offset, idxs in dim2idxs:
        shifted_offset = offset
        sliced_idxs = []
        for index, stride in idxs:
            assert isinstance(index, Index)
            dim = index.extent
            s = slice_of[index]

            # Missing (falsy) slice fields fall back to the defaults.
            start = s.start or 0
            stop = s.stop or dim
            if stop is None:
                raise ValueError("Unknown extent!")
            if dim is not None and stop > dim:
                raise ValueError("Slice exceeds dimension extent!")
            step = s.step or 1

            shifted_offset += start * stride
            # Number of positions visited by the slice.
            extent = 1 + (stop - start - 1) // step
            fresh = Index(extent=extent)
            replacements[indexes.index(index)] = fresh
            sliced_idxs.append((fresh, step * stride))
        new_dim2idxs.append((shifted_offset, tuple(sliced_idxs)))

    expr = FlexiblyIndexed(variable, tuple(new_dim2idxs))
    return ComponentTensor(expr, tuple(replacements))
# Static one object for quicker constant folding: Product, Division,
# Delta and Inverse compare against or return this shared instance.
one = Literal(1)
# Syntax sugar
[docs]
def indices(n):
    """Make some :class:`Index` objects.

    :arg n: The number of indices to make.
    :returns: A tuple of `n` :class:`Index` objects.
    """
    fresh = [Index() for _ in range(n)]
    return tuple(fresh)
def componentwise(op, *exprs):
    """Apply gem op to exprs component-wise and wrap up in a ComponentTensor.

    :arg op: function that returns a gem Node.
    :arg exprs: expressions to apply op to.
    :raises ValueError: if the expressions have mismatching shapes.
    :returns: New gem Node constructed from op.

    Each expression must either have the same shape, or else be
    scalar. Shaped expressions are indexed, the op is applied to the
    scalar expressions and the result is wrapped up in a ComponentTensor.
    """
    shapes = {expr.shape for expr in exprs}
    non_scalar = shapes - {()}
    if len(non_scalar) > 1:
        raise ValueError("expressions must have matching shape (or else be scalar)")

    # The common shape is the largest tuple among the shapes seen.
    shape = max(shapes)
    i = indices(len(shape))
    scalars = tuple(Indexed(expr, i) if expr.shape else expr for expr in exprs)
    return ComponentTensor(op(*scalars), i)
[docs]
def as_gem(expr):
    """Attempt to convert an expression into GEM of scalar type.

    Parameters
    ----------
    expr : Node or Number
        The expression.

    Returns
    -------
    Node
        A GEM representation of the expression: ``expr`` itself if it
        already is a node, otherwise a ``Literal`` wrapping the number.

    Raises
    ------
    ValueError
        If conversion was not possible.
    """
    if isinstance(expr, Node):
        return expr
    if isinstance(expr, Number):
        return Literal(expr)
    # An f-string rather than "%r" % expr: %-formatting unpacks a tuple
    # argument, so a tuple would fail with TypeError (or lose its
    # parentheses) instead of giving the documented ValueError.
    raise ValueError(f"Do not know how to convert {expr!r} to GEM")
def as_gem_uint(expr):
    """Attempt to convert an expression into GEM of uint type.

    Parameters
    ----------
    expr : Node or Integral
        The expression.

    Returns
    -------
    Node
        A GEM representation of the expression: ``expr`` itself if it
        already is a node, otherwise a ``Literal`` of dtype ``uint_type``.

    Raises
    ------
    ValueError
        If conversion was not possible.
    """
    if isinstance(expr, Node):
        return expr
    if isinstance(expr, Integral):
        return Literal(expr, dtype=uint_type)
    # An f-string rather than "%r" % expr: %-formatting unpacks a tuple
    # argument, so a tuple would fail with TypeError (or lose its
    # parentheses) instead of giving the documented ValueError.
    raise ValueError(f"Do not know how to convert {expr!r} to GEM")