Source code for gem.unconcatenate

"""Utility functions for decomposing Concatenate nodes.

The exported functions are flatten and unconcatenate.
- flatten: destroys the structure preserved within Concatenate nodes,
           essentially reducing FInAT provided tabulations to what
           FIAT could have provided, so old code can continue to work.
- unconcatenate: split up (variable, expression) pairs along
                 Concatenate nodes, thus recovering the structure
                 within them, yet eliminating the Concatenate nodes.

Let us see an example on unconcatenate.  Let us consider the form

    div(v) * dx

where v is an RTCF7 test function.  This means that the assembled
local vector has 8 * 7 + 7 * 8 = 112 entries.  So the compilation of
the form starts with a single assignment pair [(v, e)].  v is now the
indexed return variable, something equivalent to

    Indexed(Variable('A', (112,)), (j,))

where j is the basis function index of the argument.  e is just a GEM
quadrature expression with j as its only free index.  This will
contain the tabulation of the RTCF7 element, which will cause
something like

    C_j := Indexed(Concatenate(A, B), (j,))

to appear as a subexpression in e.  unconcatenate splits e along C_j
into e_1 and e_2 such that

    e_1 := e /. C_j -> A_{ja1,ja2}, and
    e_2 := e /. C_j -> B_{jb1,jb2}.

The split indices ja1, ja2, jb1, and jb2 have extents 8, 7, 7, and 8
respectively (matching the 8 * 7 + 7 * 8 decomposition above).  So the
result of unconcatenate will be the list of pairs

    [(v_1, e_1), (v_2, e_2)]

where v_1 is the first 56 entries of v, reshaped as an 8 x 7 matrix,
indexed with (ja1, ja2), and similarly, v_2 is the second 56 entries
of v, reshaped as a 7 x 8 matrix, indexed with (jb1, jb2).

The unconcatenated form allows for sum factorisation of tensor product
elements as usual.  This pair splitting is also applicable to
coefficient evaluation: take the local basis function coefficients as
the variable, the FInAT tabulation of the element as the expression,
and apply "matrix-vector multiplication" for each pair after
unconcatenation, and then add up the results.
"""

from functools import singledispatch
from itertools import chain

import numpy

from gem.node import Memoizer, reuse_if_untouched
from gem.gem import (ComponentTensor, Concatenate, FlexiblyIndexed,
                     Index, Indexed, Literal, Node, partial_indexed,
                     reshape, view)
from gem.optimise import remove_componenttensors
from gem.interpreter import evaluate


__all__ = ['flatten', 'unconcatenate']


def find_group(expressions):
    """Find every indexed Concatenate node sharing one free index.

    The DAG is walked depth-first from the roots.  The first
    ``Indexed(Concatenate, (i,))`` node encountered fixes the index
    ``i``; from then on only indexed Concatenate nodes with that same
    index are collected.  The walk never descends below a collected
    node, nor into a subexpression that carries none of the roots'
    free indices.

    Pre-condition: ComponentTensor nodes surrounding Concatenate nodes
    must be removed.

    :arg expressions: a multi-root GEM expression DAG
    :returns: a list of GEM nodes, or None if no indexed Concatenate
              node was found
    """
    root_indices = set(chain.from_iterable(e.free_indices for e in expressions))

    # Result variables
    group_index = None
    group = []

    # Hand-rolled pre-order traversal so that we can avoid going
    # unnecessarily deep in the DAG.
    visited = set()
    stack = []
    for root in expressions:
        if root in visited:
            continue
        visited.add(root)
        stack.append(root)

    while stack:
        node = stack.pop()
        # Subtrees without any relevant free index cannot hold a match.
        if root_indices.isdisjoint(node.free_indices):
            continue

        if isinstance(node, Indexed):
            (aggregate,) = node.children
            if isinstance(aggregate, Concatenate):
                (i,) = node.multiindex
                assert i in root_indices
                if group_index is None or group_index == i:
                    group_index = i
                    group.append(node)
                    # Do not descend below a collected node
                    continue

        for child in reversed(node.children):
            if child not in visited:
                visited.add(child)
                stack.append(child)

    if group_index is None:
        return None
    return group


def split_variable(variable_ref, index, multiindices):
    """Split a flexibly indexed variable along a concatenation index.

    The variable is first wrapped as a tensor whose leading axis is
    ``index`` and whose remaining axes are its other indices (in their
    original order).  Consecutive slices of that leading axis are then
    reshaped to the extents of each multiindex in turn, and indexed with
    that multiindex followed by the other indices.

    :param variable_ref: flexibly indexed variable to split
    :param index: :py:class:`Concatenate` index to split along
    :param multiindices: one multiindex for each split variable

    :returns: generator of split indexed variables
    """
    assert isinstance(variable_ref, FlexiblyIndexed)
    remaining = list(variable_ref.index_ordering())
    remaining.remove(index)
    remaining = tuple(remaining)

    tensor = ComponentTensor(variable_ref, (index,) + remaining)
    # Keep every non-split axis whole and unreshaped.
    full_slices = (slice(None),) * len(remaining)
    trailing_shapes = [(other.extent,) for other in remaining]

    start = 0
    for multiindex in multiindices:
        extents = tuple(sub.extent for sub in multiindex)
        stop = start + numpy.prod(extents, dtype=int)
        piece = reshape(view(tensor, slice(start, stop), *full_slices),
                        extents, *trailing_shapes)
        (sub_ref,) = remove_componenttensors(
            (Indexed(piece, multiindex + remaining),))
        yield sub_ref
        start = stop


def _replace_node(node, self):
    """Memoizer callback: substitute one node using ``self.mapping``.

    A node for which ``self.cut`` is true is returned unchanged without
    consulting the mapping.  Otherwise the mapped replacement is
    returned if there is one; failing that, the node is rebuilt from
    its recursively processed children via ``reuse_if_untouched``.

    :param node: root of expression
    :param self: function for recursive calls (carries the ``mapping``
                 and ``cut`` attributes set by :func:`replace_node`)
    """
    assert isinstance(node, Node)
    if not self.cut(node):
        try:
            replacement = self.mapping[node]
        except KeyError:
            replacement = reuse_if_untouched(node, self)
        return replacement
    return node


def replace_node(expression, mapping, cut=None):
    """Substitute subexpressions of an expression using a given mapping.

    :param expression: a GEM expression
    :param mapping: a :py:class:`dict` containing the substitutions
    :param cut: cutting predicate; if returns true, it is assumed that
                no replacements would take place in the subexpression.
                When omitted, no subexpression is cut.
    :returns: the rewritten expression
    """
    def never_cut(node):
        return False

    replacer = Memoizer(_replace_node)
    replacer.mapping = mapping
    replacer.cut = cut if cut else never_cut
    return replacer(expression)


def _unconcatenate(cache, pairs):
    """Split (variable, expression) pairs until no Concatenate group is left.

    Core of :func:`unconcatenate`.  Assumes that the input has already
    been sanitised, i.e. ComponentTensor nodes have been removed as
    :func:`find_group` requires.

    Each round finds one group of indexed Concatenate nodes sharing a
    free index, replaces that index by one multiindex per Concatenate
    child, and splits every pair whose variable carries the index.
    This is a loop rather than tail recursion because Python does not
    eliminate tail calls; each group would otherwise cost a stack frame.

    :arg cache: :py:class:`dict` from concatenation index to the tuple
        of split multiindices; entries are added for unseen indices.
    :arg pairs: list of (indexed variable, expression) pairs
    :returns: list of (indexed variable, expression) pairs
    :raises ValueError: if a group's index is not a free index of any
        variable, so the group could never be eliminated (previously
        this recursed until RecursionError).
    """
    while True:
        concat_group = find_group([e for v, e in pairs])
        if concat_group is None:
            return pairs

        # Get the index split from a representative of the group
        concat_ref = next(iter(concat_group))
        assert isinstance(concat_ref, Indexed)
        concat_expr, = concat_ref.children
        index, = concat_ref.multiindex
        assert isinstance(concat_expr, Concatenate)
        try:
            multiindices = cache[index]
        except KeyError:
            # One fresh index per axis of each Concatenate child
            multiindices = tuple(tuple(Index(extent=d) for d in child.shape)
                                 for child in concat_expr.children)
            cache[index] = multiindices

        def cut(node, index=index):
            """No need to rebuild expressions that are independent of
            the relevant concatenation index."""
            return index not in node.free_indices

        # Build Concatenate node replacement mappings: the k-th mapping
        # sends every indexed Concatenate node of the group to its k-th
        # child, indexed with the k-th split multiindex.
        mappings = [{} for _ in multiindices]
        for ref in concat_group:
            aggregate, = ref.children
            # Every Concatenate in the group must split the same way.
            assert len(aggregate.children) == len(multiindices)
            for mapping, child, multiindex in zip(mappings,
                                                  aggregate.children,
                                                  multiindices):
                sub_ref, = remove_componenttensors((Indexed(child, multiindex),))
                mapping[ref] = sub_ref

        # Finally, split assignment pairs
        split_pairs = []
        split_any = False
        for var, expr in pairs:
            if index not in var.free_indices:
                split_pairs.append((var, expr))
            else:
                split_any = True
                for sub_var, mapping in zip(split_variable(var, index, multiindices),
                                            mappings):
                    split_pairs.append((sub_var, replace_node(expr, mapping, cut)))

        if not split_any:
            # Nothing changed, so the same group would be found forever.
            raise ValueError("Concatenate index %s is not a free index of any "
                             "assignment variable; cannot unconcatenate" % (index,))

        # Run again, there may be other Concatenate groups
        pairs = split_pairs


def unconcatenate(pairs, cache=None):
    """Split (indexed variable, expression) pairs along the
    :py:class:`Concatenate` nodes embedded in the expressions.

    :param pairs: list of (indexed variable, expression) pairs
    :param cache: index splitting cache :py:class:`dict` (optional);
                  a fresh empty one is used when omitted
    :returns: list of (indexed variable, expression) pairs
    """
    cache = {} if cache is None else cache

    # Sanitise: eliminate index renaming due to ComponentTensor nodes
    expressions = remove_componenttensors([e for _, e in pairs])
    variables = [v for v, _ in pairs]
    return _unconcatenate(cache, list(zip(variables, expressions)))
@singledispatch
def _flatten(node, self):
    """Replace Concatenate nodes with Literal nodes.

    This generic fallback is only reached for objects that are not GEM
    nodes (Node and Concatenate have their own registrations below), so
    it rejects them.

    :arg node: root of the expression
    :arg self: function for recursive calls
    """
    raise AssertionError("cannot handle type %s" % type(node))


# Any other GEM node: rebuild only if one of its children changed.
_flatten.register(Node)(reuse_if_untouched)


@_flatten.register(Concatenate)
def _flatten_concatenate(node, self):
    """Evaluate a Concatenate node and substitute a Literal for it.

    The evaluated array is wrapped in a Literal and passed through
    ``partial_indexed`` together with ``result.fids`` (presumably the
    free indices of the evaluated node; confirm against
    gem.interpreter).
    """
    (result,) = evaluate([node])
    literal = Literal(result.arr)
    return partial_indexed(literal, result.fids)
def flatten(expressions):
    """Flatten Concatenate nodes, and destroy the structure they express.

    Each root is rewritten with the ``_flatten`` dispatcher, which
    replaces Concatenate nodes by Literal-based expressions.  One
    Memoizer instance is shared across all roots.

    :arg expressions: a multi-root expression DAG
    :returns: list of rewritten expressions, in input order
    """
    flattener = Memoizer(_flatten)
    return [flattener(expression) for expression in expressions]