# Source code for gem.coffee

"""This module contains an implementation of the COFFEE optimisation
algorithm operating on a GEM representation.

This file is NOT for code generation as a COFFEE AST.
"""

from collections import OrderedDict
import itertools
import logging

import numpy

from gem.gem import IndexSum, one
from gem.optimise import make_sum, make_product
from gem.refactorise import Monomial
from gem.utils import groupby


# Only the top-level entry point is exported as this module's public API.
__all__ = ['optimise_monomial_sum']


def monomial_sum_to_expression(monomial_sum):
    """Build a GEM expression equivalent to a sum of monomials.

    Monomials whose sets of summation indices coincide are collected into a
    single :class:`IndexSum` over the sum of their products; the resulting
    index sums are then added together.

    :arg monomial_sum: an iterable of :class:`Monomial`s

    :returns: GEM expression
    """
    def to_indexsum(group):
        # Each monomial contributes the product of its atomics and its rest.
        terms = [make_product(m.atomics + (m.rest,)) for m in group]
        # Members of a group share their sum indices up to ordering; the
        # ordering of the first member is the one used for the IndexSum.
        return IndexSum(make_sum(terms), group[0].sum_indices)

    grouped = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))
    return make_sum([to_indexsum(group) for _, group in grouped])


def index_extent(factor, linear_indices):
    """Multiply together the extents of the free indices of a GEM expression
    that are also linear indices.

    :arg factor: GEM expression
    :arg linear_indices: set of linear indices

    :returns: product of extents of linear indices (as computed by
              :func:`numpy.prod`, so ``1.0`` when no free index is linear)
    """
    relevant = (idx for idx in factor.free_indices if idx in linear_indices)
    return numpy.prod([idx.extent for idx in relevant])


def find_optimal_atomics(monomials, linear_indices):
    """Find optimal atomic common subexpressions, which produce least number of
    terms in the resultant IndexSum when factorised.

    This is a set-cover search: choose a subset of atomics such that every
    monomial contains at least one chosen atomic. The search below is a
    depth-first branch-and-bound that minimises the number of chosen atomics,
    breaking ties in favour of the larger total extent over
    ``linear_indices`` (see :func:`index_extent`). The search stops early once
    a fixed budget of complete covers has been examined; the best cover found
    so far is then returned and a warning is logged.

    :arg monomials: A list of :class:`Monomial`s, all of which should have
                    the same sum indices
    :arg linear_indices: tuple of linear indices

    :returns: tuple of atomic GEM expressions, ordered by first appearance
              among the monomials' atomics
    """
    # Distinct atomics across all monomials, de-duplicated while keeping
    # first-occurrence order.
    atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials))))

    def cost(solution):
        extent = sum(map(lambda atomic: index_extent(atomic, linear_indices), solution))
        # Prefer shorter solutions, but larger extents
        # (tuples compare lexicographically: length first, then -extent).
        return (len(solution), -extent)

    optimal_solution = set(atomics)  # pessimal but feasible solution
    solution = set()  # cover currently being built by the depth-first search

    # Budget on the number of complete covers examined: ``next(it)`` raises
    # StopIteration once the range is exhausted, unwinding the recursion.
    max_it = 1 << 12
    it = iter(range(max_it))

    def solve(idx):
        # Skip monomials already covered by an atomic in the current solution.
        while idx < len(monomials) and solution.intersection(monomials[idx].atomics):
            idx += 1

        if idx < len(monomials):
            # monomials[idx] is uncovered, so at least one more atomic is
            # needed; only branch while the result could still beat (or tie
            # in length with) the best cover found so far.
            if len(solution) < len(optimal_solution):
                # Branch: try covering monomials[idx] with each of its atomics.
                for atomic in monomials[idx].atomics:
                    solution.add(atomic)
                    solve(idx + 1)
                    solution.remove(atomic)
        else:
            # Every monomial is covered: keep this cover if it is cheaper.
            if cost(solution) < cost(optimal_solution):
                optimal_solution.clear()
                optimal_solution.update(solution)
            # Count this complete cover against the max_it budget.
            next(it)

    try:
        solve(0)
    except StopIteration:
        # Budget exhausted; optimal_solution holds the best cover seen so far.
        # NOTE(review): the cover that triggers StopIteration is the
        # (max_it + 1)-th one, and it has already been compared above.
        logger = logging.getLogger('tsfc')
        logger.warning("Solution to ILP problem may not be optimal: search "
                       "interrupted after examining %d solutions.", max_it)

    # Report in the deterministic first-occurrence order of ``atomics``
    # rather than in set iteration order.
    return tuple(atomic for atomic in atomics if atomic in optimal_solution)


def factorise_atomics(monomials, optimal_atomics, linear_indices):
    """Group and factorise monomials using a list of atomics as common
    subexpressions. Create new monomials for each group and optimise them
    recursively.

    Each monomial is assigned to the first atomic of ``optimal_atomics`` that
    it contains. Within each group that atomic is factored out, and the
    remaining factors are optimised recursively via :func:`optimise_monomials`.

    :arg monomials: a sized iterable (e.g. a list) of :class:`Monomial`s,
                    all of which should have the same sum indices
    :arg optimal_atomics: sequence of atomics to be used as common
                          subexpressions; every monomial must contain at
                          least one of them
    :arg linear_indices: tuple of linear indices

    :returns: an iterable of :class:`Monomial`s after factorisation
              (``monomials`` itself when there is nothing to factorise)
    :raises AssertionError: if a monomial contains none of ``optimal_atomics``
    """
    if not optimal_atomics or len(monomials) <= 1:
        return monomials

    def group_key(monomial):
        # Group each monomial under the first optimal atomic it contains.
        for oa in optimal_atomics:
            if oa in monomial.atomics:
                return oa
        # Explicit raise instead of ``assert False``: under ``python -O`` the
        # assert is stripped, the key silently becomes ``None`` and the
        # failure only surfaces later as a confusing ``list.remove`` error.
        raise AssertionError("Expect at least one optimal atomic per monomial.")

    factor_group = groupby(monomials, key=group_key)

    # We should not drop monomials
    assert sum(len(ms) for _, ms in factor_group) == len(monomials)

    sum_indices = next(iter(monomials)).sum_indices
    new_monomials = []
    # Use ``group`` rather than re-binding ``monomials`` so the argument is
    # not shadowed inside the loop.
    for oa, group in factor_group:
        # Strip the common factor ``oa`` from every monomial in the group. The
        # sum indices are left empty here and re-attached on the results below.
        sub_monomials = []
        for monomial in group:
            atomics = list(monomial.atomics)
            atomics.remove(oa)  # removes only the first occurrence
            sub_monomials.append(Monomial((), tuple(atomics), monomial.rest))
        # Continue to factorise the remaining expression
        sub_monomials = optimise_monomials(sub_monomials, linear_indices)
        if len(sub_monomials) == 1:
            # Factorised part is a product: add back the common atomic and add
            # to the result directly rather than forming a product node.
            # Retaining the monomial structure enables applying associativity
            # when forming GEM nodes later.
            sub_monomial, = sub_monomials
            new_monomials.append(
                Monomial(sum_indices, (oa,) + sub_monomial.atomics, sub_monomial.rest))
        else:
            # Factorised part is a summation: create a new GEM node for it and
            # multiply with the common factor.
            node = monomial_sum_to_expression(sub_monomials)
            # If the free indices of the new node intersect with linear indices,
            # add to the new monomial as `atomic`, otherwise add as `rest`.
            # Note: we might want to continue to factorise with the new atomics
            # by running optimise_monomials twice.
            if set(linear_indices) & set(node.free_indices):
                new_monomials.append(Monomial(sum_indices, (oa, node), one))
            else:
                new_monomials.append(Monomial(sum_indices, (oa, ), node))
    return new_monomials


def optimise_monomial_sum(monomial_sum, linear_indices):
    """Choose optimal common atomic subexpressions and factorise a
    :class:`MonomialSum` object to create a GEM expression.

    :arg monomial_sum: a :class:`MonomialSum` object
    :arg linear_indices: tuple of linear indices

    :returns: factorised GEM expression
    """
    # optimise_monomials requires all its monomials to share the same sum
    # indices, so optimise each such group independently.
    groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))
    new_monomials = []
    for _, monomials in groups:
        new_monomials.extend(optimise_monomials(monomials, linear_indices))
    return monomial_sum_to_expression(new_monomials)
def optimise_monomials(monomials, linear_indices):
    """Choose optimal common atomic subexpressions and factorise an iterable
    of monomials.

    :arg monomials: a list of :class:`Monomial`s, all of which should have
                    the same sum indices
    :arg linear_indices: tuple of linear indices

    :returns: a list of factorised :class:`Monomial`s
    """
    assert len(set(frozenset(m.sum_indices) for m in monomials)) <= 1, \
        "All monomials required to have same sum indices for factorisation"

    # Monomials without atomics have nothing to factor out; pass them through.
    result = [m for m in monomials if not m.atomics]  # skipped monomials
    active_monomials = [m for m in monomials if m.atomics]
    optimal_atomics = find_optimal_atomics(active_monomials, linear_indices)
    result += factorise_atomics(active_monomials, optimal_atomics, linear_indices)
    return result