"""
This file contains the functions needed to estimate the total number of
floating point operations performed by a scheduled impero_c tree. The
counts are approximations; some node types are costed heuristically.
"""
import gem.gem as gem
import gem.impero as imp
from functools import singledispatch
import numpy
import math
[docs]
@singledispatch
def statement(tree, temporaries):
    """Count the flops of a statement, dispatching on the statement type.

    This is the fallback used when no handler is registered for the
    type of ``tree``.

    :arg tree: the statement node to inspect.
    :arg temporaries: Expressions that are assigned to temporaries
    :raises NotImplementedError: always, since no handler matched.
    """
    raise NotImplementedError()
[docs]
@statement.register(imp.Block)
def statement_block(tree, temporaries):
    """Return the combined flop count of every child statement of a block.

    :arg tree: the block whose ``children`` are counted.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: sum of the flop counts of the children.
    """
    return sum(statement(stmt, temporaries) for stmt in tree.children)
[docs]
@statement.register(imp.For)
def statement_for(tree, temporaries):
    """Return the flop count of a loop: body cost times the index extent.

    :arg tree: the loop node; it must have exactly one child.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: flop count of the single child multiplied by
        ``tree.index.extent``.
    """
    trip_count = tree.index.extent
    assert trip_count is not None
    # Unpacking fails unless the loop has exactly one child.
    (body,) = tree.children
    return statement(body, temporaries) * trip_count
[docs]
@statement.register(imp.Initialise)
def statement_initialise(tree, temporaries):
    """Return the flop count of an initialisation, which is counted as zero.

    :arg tree: the initialise node (not inspected).
    :arg temporaries: Expressions that are assigned to temporaries (unused)
    :returns: 0
    """
    return 0
[docs]
@statement.register(imp.Accumulate)
def statement_accumulate(tree, temporaries):
    """Return the flop count of an accumulate statement.

    :arg tree: the accumulate node; the first child of its ``indexsum``
        is the expression that is costed.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: flops of the costed expression plus one (presumably the
        addition into the accumulator).
    """
    summand = tree.indexsum.children[0]
    return expression_flops(summand, temporaries) + 1
[docs]
@statement.register(imp.Return)
def statement_return(tree, temporaries):
    """Return the flop count of a return statement.

    :arg tree: the return node; its ``expression`` is costed.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: flops of ``tree.expression`` plus one.
    """
    returned = tree.expression
    return expression_flops(returned, temporaries) + 1
[docs]
@statement.register(imp.ReturnAccumulate)
def statement_returnaccumulate(tree, temporaries):
    """Return the flop count of a return-accumulate statement.

    :arg tree: the return-accumulate node; the first child of its
        ``indexsum`` is the expression that is costed.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: flops of the costed expression plus one (presumably the
        addition into the output).
    """
    summand = tree.indexsum.children[0]
    return expression_flops(summand, temporaries) + 1
[docs]
@statement.register(imp.Evaluate)
def statement_evaluate(tree, temporaries):
    """Return the flop count of evaluating an expression.

    The expression is costed with ``top=True``, so it is counted even
    if it is itself one of the temporaries.

    :arg tree: the evaluate node; its ``expression`` is costed.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: flops of ``tree.expression``.
    """
    return expression_flops(tree.expression, temporaries, top=True)
[docs]
@singledispatch
def flops(expr, temporaries):
    """Count the flops of an expression node, dispatching on its type.

    This is the fallback used when no handler is registered for the
    type of ``expr``.

    :arg expr: the expression node to inspect.
    :arg temporaries: Expressions that are assigned to temporaries
    :raises NotImplementedError: always, naming the unsupported type.
    """
    message = f"Don't know how to count flops of {type(expr)}"
    raise NotImplementedError(message)
[docs]
@flops.register(gem.Failure)
def flops_failure(expr, temporaries):
    """Reject Failure nodes, which are not expected when counting flops.

    :arg expr: the Failure node (not inspected).
    :arg temporaries: Expressions that are assigned to temporaries (unused)
    :raises ValueError: always.
    """
    message = "Not expecting a Failure node"
    raise ValueError(message)
[docs]
@flops.register(gem.Variable)
@flops.register(gem.Identity)
@flops.register(gem.Delta)
@flops.register(gem.Zero)
@flops.register(gem.Literal)
@flops.register(gem.Index)
@flops.register(gem.VariableIndex)
def flops_zero(expr, temporaries):
    """Return zero: these node types are counted as costing no flops.

    :arg expr: the node (not inspected).
    :arg temporaries: Expressions that are assigned to temporaries (unused)
    :returns: 0
    """
    # Setting up any of these GEM nodes performs no floating point operations.
    return 0
[docs]
@flops.register(gem.LogicalNot)
@flops.register(gem.LogicalAnd)
@flops.register(gem.LogicalOr)
@flops.register(gem.ListTensor)
def flops_zeroplus(expr, temporaries):
    """Return the flop count of a node that is itself free.

    The node adds nothing of its own; only its children are counted.

    :arg expr: the node whose ``children`` are costed.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: sum of the flop counts of the children.
    """
    child_costs = (expression_flops(operand, temporaries)
                   for operand in expr.children)
    return sum(child_costs)
[docs]
@flops.register(gem.Product)
def flops_product(expr, temporaries):
    """Return the flop count of a product of two operands.

    Multiplication by the literal -1 is not counted as a flop: in that
    case only the other operand is costed.

    :arg expr: the product node; it must have exactly two children.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: flops of the operands, plus one for the multiplication
        unless one operand is the literal -1.
    """
    left, right = expr.children

    def is_minus_one(node):
        return isinstance(node, gem.Literal) and node.value == -1

    # The left operand is checked first, so (-1) * (-1) costs the right one.
    if is_minus_one(left):
        return expression_flops(right, temporaries)
    if is_minus_one(right):
        return expression_flops(left, temporaries)

    operand_costs = (expression_flops(operand, temporaries)
                     for operand in expr.children)
    return 1 + sum(operand_costs)
[docs]
@flops.register(gem.Sum)
@flops.register(gem.Division)
@flops.register(gem.Comparison)
@flops.register(gem.MathFunction)
@flops.register(gem.MinValue)
@flops.register(gem.MaxValue)
def flops_oneplus(expr, temporaries):
    """Return the flop count of a node costed as one operation.

    :arg expr: the node whose ``children`` are costed.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: one plus the sum of the flop counts of the children.
    """
    child_costs = (expression_flops(operand, temporaries)
                   for operand in expr.children)
    return 1 + sum(child_costs)
[docs]
@flops.register(gem.Power)
def flops_power(expr, temporaries):
    """Return the flop count of raising a base to an exponent.

    A literal, positive, whole-number exponent ``p`` is costed as
    ``ceil(log2(p))`` operations on top of the base. Every other
    exponent (non-literal, non-positive or fractional) is costed with a
    fixed heuristic of 5. The exponent expression itself is not costed.

    :arg expr: the power node; its children are ``(base, exponent)``.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: flops of the base plus the cost of the exponentiation.
    """
    heuristic_cost = 5
    base, exponent = expr.children
    cost_of_base = expression_flops(base, temporaries)

    if not isinstance(exponent, gem.Literal):
        return cost_of_base + heuristic_cost

    power = exponent.value
    is_positive_whole = power > 0 and power == math.floor(power)
    if is_positive_whole:
        return cost_of_base + int(math.ceil(math.log2(power)))
    return cost_of_base + heuristic_cost
[docs]
@flops.register(gem.Conditional)
def flops_conditional(expr, temporaries):
    """Return the flop count of a conditional expression.

    The condition is always counted; of the two branches only the more
    expensive one is counted.

    :arg expr: the conditional node; it must have exactly three
        children: condition, then-branch and else-branch.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: flops of the condition plus the larger branch cost.
    """
    # map() is lazy, so children are costed one at a time while unpacking.
    test_cost, then_cost, else_cost = map(
        lambda child: expression_flops(child, temporaries), expr.children
    )
    return test_cost + max(then_cost, else_cost)
[docs]
@flops.register(gem.Indexed)
@flops.register(gem.FlexiblyIndexed)
def flops_indexed(expr, temporaries):
    """Return the average flop count per entry of an indexed aggregate.

    The total cost of the children is divided by the number of entries
    in the first child, computed as the product of its ``shape``.

    :arg expr: the indexing node whose ``children`` are costed.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: children flops divided by the entry count of the first
        child (true division, so the result may be fractional).
    """
    total = sum(expression_flops(operand, temporaries)
                for operand in expr.children)
    entry_count = numpy.prod(expr.children[0].shape, dtype=int)
    return total / entry_count
[docs]
@flops.register(gem.IndexSum)
def flops_indexsum(expr, temporaries):
    """Reject IndexSum nodes, which are not expected inside expressions.

    :arg expr: the IndexSum node (not inspected).
    :arg temporaries: Expressions that are assigned to temporaries (unused)
    :raises ValueError: always.
    """
    message = "Not expecting IndexSum"
    raise ValueError(message)
[docs]
@flops.register(gem.Inverse)
def flops_inverse(expr, temporaries):
    """Return the flop count of a matrix inverse.

    The inversion is costed as ``2n^3``, where ``n`` is the first
    entry of the two-entry ``expr.shape``.

    :arg expr: the inverse node; its ``shape`` must have two entries.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: ``2n^3`` plus the sum of the flop counts of the children.
    """
    rows, _ = expr.shape
    inversion_cost = 2 * rows**3
    child_costs = (expression_flops(operand, temporaries)
                   for operand in expr.children)
    return inversion_cost + sum(child_costs)
[docs]
@flops.register(gem.Solve)
def flops_solve(expr, temporaries):
    """Return the flop count of a linear solve.

    With ``(n, m) = expr.shape`` the solve is costed as ``2nm`` plus
    the ``2n^3`` cost of inverting the matrix A.

    :arg expr: the solve node; its ``shape`` must have two entries.
    :arg temporaries: Expressions that are assigned to temporaries
    :returns: ``2nm + 2n^3`` plus the sum of the flop counts of the
        children.
    """
    rows, cols = expr.shape
    solve_cost = 2 * rows * cols + 2 * rows**3
    child_costs = (expression_flops(operand, temporaries)
                   for operand in expr.children)
    return solve_cost + sum(child_costs)
[docs]
@flops.register(gem.ComponentTensor)
def flops_componenttensor(expr, temporaries):
    """Reject ComponentTensor nodes, which are not expected here.

    :arg expr: the ComponentTensor node (not inspected).
    :arg temporaries: Expressions that are assigned to temporaries (unused)
    :raises ValueError: always.
    """
    message = "Not expecting ComponentTensor"
    raise ValueError(message)
[docs]
def expression_flops(expression, temporaries, top=False):
    """Approximate the flops needed to evaluate an expression.

    An expression that is one of the temporaries is counted as free,
    unless it is the root expression (``top`` is true).

    :arg expression: GEM expression.
    :arg temporaries: Expressions that are assigned to temporaries
    :arg top: are we at the root?
    :returns: flop count for the expression
    """
    if top or expression not in temporaries:
        return flops(expression, temporaries)
    return 0
[docs]
def count_flops(impero_c):
    """Approximate the flops required by a scheduled impero_c tree.

    Counting is best-effort: if any node raises ``ValueError`` or
    ``NotImplementedError`` the whole count is reported as 0.

    :arg impero_c: a :class:`~.Impero_C` object.
    :returns: approximate flop count for the tree, or 0 if it could
        not be counted.
    """
    try:
        root = impero_c.tree
        temporaries = set(impero_c.temporaries)
        total = statement(root, temporaries)
    except (ValueError, NotImplementedError):
        total = 0
    return total