from functools import wraps
import ufl
from ufl.domain import extract_unique_domain
from pyadjoint.overloaded_type import create_overloaded_object, FloatingType
from pyadjoint.tape import annotate_tape, stop_annotating, get_working_tape
from firedrake.adjoint_utils.blocks import FunctionAssignBlock, ProjectBlock, SubfunctionBlock, FunctionMergeBlock, SupermeshProjectBlock
import firedrake
from .checkpointing import disk_checkpointing, CheckpointFunction, \
CheckpointBase, checkpoint_init_data, DelegatedFunctionCheckpoint
[docs]
class FunctionMixin(FloatingType):
    """Pyadjoint overloading support shared by Firedrake functions.

    The ``_ad_annotate_*`` static methods are decorators: each wraps an
    existing method so that, when annotation is active, the call is recorded
    as a block on the working tape. The remaining ``_ad_*`` methods are the
    hooks pyadjoint calls on an overloaded value (checkpointing, arithmetic,
    inner products, conversion to and from numpy/PETSc).
    """

    @staticmethod
    def _ad_annotate_init(init):
        """Wrap ``__init__`` so the ``FloatingType`` state is initialised first.

        The adjoint-specific keyword arguments are removed from ``kwargs`` and
        handed to ``FloatingType.__init__``; the remaining arguments are passed
        to both ``FloatingType.__init__`` and the wrapped ``init``.
        """
        @wraps(init)
        def wrapper(self, *args, **kwargs):
            # Keyword arguments are evaluated left to right, so every pop below
            # happens before the trailing ``**kwargs`` is unpacked.
            FloatingType.__init__(self, *args,
                                  block_class=kwargs.pop("block_class", None),
                                  _ad_floating_active=kwargs.pop("_ad_floating_active", False),
                                  _ad_args=kwargs.pop("_ad_args", None),
                                  output_block_class=kwargs.pop("output_block_class", None),
                                  _ad_output_args=kwargs.pop("_ad_output_args", None),
                                  _ad_outputs=kwargs.pop("_ad_outputs", None),
                                  ad_block_tag=kwargs.pop("ad_block_tag", None), **kwargs)
            init(self, *args, **kwargs)
        return wrapper

    @staticmethod
    def _ad_annotate_project(project):
        """Wrap ``project`` so that projecting ``b`` into ``self`` is taped.

        A ``SupermeshProjectBlock`` is recorded when ``b`` is a Function whose
        mesh differs from the target space's mesh, otherwise a
        ``ProjectBlock``. The projection itself runs with annotation stopped.
        """
        @wraps(project)
        def wrapper(self, b, *args, **kwargs):
            ad_block_tag = kwargs.pop("ad_block_tag", None)
            annotate = annotate_tape(kwargs)
            if annotate:
                bcs = kwargs.get("bcs", [])
                if isinstance(b, firedrake.Function) and extract_unique_domain(b) != self.function_space().mesh():
                    block = SupermeshProjectBlock(b, self.function_space(), self, bcs, ad_block_tag=ad_block_tag)
                else:
                    block = ProjectBlock(b, self.function_space(), self, bcs, ad_block_tag=ad_block_tag)
                tape = get_working_tape()
                tape.add_block(block)
            with stop_annotating():
                output = project(self, b, *args, **kwargs)
            if annotate:
                block.add_output(output.create_block_variable())
            return output
        return wrapper

    @staticmethod
    def _ad_annotate_subfunctions(subfunctions):
        """Wrap ``subfunctions`` so each returned component stays linked to ``self``.

        Every component is re-created with ``type(self)`` as a floating object
        whose ``block_class`` is ``SubfunctionBlock`` (with args ``[self, i]``)
        and whose ``output_block_class`` is ``FunctionMergeBlock`` (with output
        ``self``).
        """
        @wraps(subfunctions)
        def wrapper(self, *args, **kwargs):
            ad_block_tag = kwargs.pop("ad_block_tag", None)
            with stop_annotating():
                output = subfunctions(self, *args, **kwargs)
            return tuple(type(self)(sub.function_space(),
                                    sub,
                                    block_class=SubfunctionBlock,
                                    _ad_floating_active=True,
                                    _ad_args=[self, i],
                                    _ad_output_args=[i],
                                    output_block_class=FunctionMergeBlock,
                                    _ad_outputs=[self],
                                    ad_block_tag=ad_block_tag)
                         for i, sub in enumerate(output))
        return wrapper

    @staticmethod
    def _ad_annotate_copy(copy):
        """Wrap ``copy`` so that a deep copy is taped as a ``FunctionAssignBlock``.

        Raises ``NotImplementedError`` when annotation is active and a shallow
        copy was requested.
        """
        @wraps(copy)
        def wrapper(self, *args, **kwargs):
            ad_block_tag = kwargs.pop("ad_block_tag", None)
            annotate = annotate_tape(kwargs)
            func = copy(self, *args, **kwargs)

            if annotate:
                # ``deepcopy`` may be given by keyword or positionally.
                # NOTE(review): assumes the wrapped ``copy`` takes ``deepcopy``
                # as its first positional parameter (``copy(self, deepcopy=False)``).
                deepcopy = kwargs.get("deepcopy", args[0] if args else False)
                if deepcopy:
                    block = FunctionAssignBlock(func, self, ad_block_tag=ad_block_tag)
                    tape = get_working_tape()
                    tape.add_block(block)
                    block.add_output(func.create_block_variable())
                else:
                    # TODO: Implement. Here we would need to use floating types.
                    raise NotImplementedError("Currently kwargs['deepcopy'] must be set True")

            return func

        return wrapper

    @staticmethod
    def _ad_annotate_assign(assign):
        """Wrap ``assign`` so that ``self.assign(other)`` is taped.

        To disable the annotation, pass ``annotate=False``; the call then acts
        exactly like the plain Firedrake ``assign``. Self-assignment is never
        annotated.
        """
        @wraps(assign)
        def wrapper(self, other, *args, **kwargs):
            ad_block_tag = kwargs.pop("ad_block_tag", None)
            # Do not annotate in case of self assignment.
            annotate = annotate_tape(kwargs) and self != other
            if annotate:
                if not isinstance(other, ufl.core.operator.Operator):
                    other = create_overloaded_object(other)
                block = FunctionAssignBlock(self, other, ad_block_tag=ad_block_tag)
                tape = get_working_tape()
                tape.add_block(block)

            with stop_annotating():
                ret = assign(self, other, *args, **kwargs)

            if annotate:
                block_var = self.create_block_variable()
                block.add_output(block_var)

                if isinstance(other, type(self)):
                    if self.function_space().mesh() == other.function_space().mesh():
                        # Same-mesh copy: reuse the source's checkpoint instead of storing another.
                        block_var._checkpoint = DelegatedFunctionCheckpoint(other.block_variable)

            return ret

        return wrapper

    @staticmethod
    def _ad_not_implemented(func):
        """Wrap ``func`` so it raises ``NotImplementedError`` when annotation is active."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            if annotate_tape(kwargs):
                raise NotImplementedError("Automatic differentiation is not supported for this operation.")
            return func(*args, **kwargs)
        return wrapper

    @staticmethod
    def _ad_annotate_inplace(op, expression):
        """Return a taping wrapper for the in-place arithmetic operator ``op``.

        ``expression(self, other)`` builds the UFL expression whose value
        ``op`` stores in ``self``; it is recorded as a ``FunctionAssignBlock``
        targeting ``self``. ``ad_block_tag`` is removed from ``kwargs`` and
        ``annotate_tape`` is consulted before ``op`` runs, so adjoint-only
        keywords are handled before the rest are forwarded to ``op``.
        ``op`` is expected to return ``self``, as in-place operators do.
        """
        @wraps(op)
        def wrapper(self, other, **kwargs):
            ad_block_tag = kwargs.pop("ad_block_tag", None)
            annotate = annotate_tape(kwargs)
            if annotate:
                # Create the block -- and so register ``self`` as a dependency --
                # *before* ``self`` is modified in place (as the assign wrapper
                # does), so whatever is captured for the dependency reflects the
                # value prior to the operation.
                block = FunctionAssignBlock(self, expression(self, other),
                                            ad_block_tag=ad_block_tag)
                get_working_tape().add_block(block)
            with stop_annotating():
                result = op(self, other, **kwargs)
            if annotate:
                block.add_output(result.create_block_variable())
            return result
        return wrapper

    @staticmethod
    def _ad_annotate_iadd(__iadd__):
        """Tape ``self += other`` as an assignment of ``self + other``."""
        return FunctionMixin._ad_annotate_inplace(__iadd__, lambda a, b: a + b)

    @staticmethod
    def _ad_annotate_isub(__isub__):
        """Tape ``self -= other`` as an assignment of ``self - other``."""
        return FunctionMixin._ad_annotate_inplace(__isub__, lambda a, b: a - b)

    @staticmethod
    def _ad_annotate_imul(__imul__):
        """Tape ``self *= other`` as an assignment of ``self*other``."""
        return FunctionMixin._ad_annotate_inplace(__imul__, lambda a, b: a*b)

    @staticmethod
    def _ad_annotate_itruediv(__itruediv__):
        """Tape ``self /= other`` as an assignment of ``self/other``."""
        return FunctionMixin._ad_annotate_inplace(__itruediv__, lambda a, b: a/b)

    def _ad_create_checkpoint(self):
        """Return a checkpoint of the current value.

        A ``CheckpointFunction`` when disk checkpointing is enabled, otherwise
        an in-memory deep copy.
        """
        if disk_checkpointing():
            return CheckpointFunction(self)
        else:
            return self.copy(deepcopy=True)

    def _ad_convert_riesz(self, value, riesz_map=None):
        """Return the Riesz representation of ``value`` (map defaults to ``"L2"``)."""
        return value.riesz_representation(riesz_map=riesz_map or "L2")

    def _ad_init_zero(self, dual=False):
        """Return a new Function on this space, or a Cofunction on its dual if ``dual``."""
        from firedrake import Function, Cofunction
        if dual:
            return Cofunction(self.function_space().dual())
        else:
            return Function(self.function_space())

    def _ad_restore_at_checkpoint(self, checkpoint):
        """Return the value stored in ``checkpoint``.

        ``CheckpointBase`` instances are restored; anything else (e.g. an
        in-memory copy) is returned as is.
        """
        if isinstance(checkpoint, CheckpointBase):
            return checkpoint.restore()
        else:
            return checkpoint

    def _ad_will_add_as_dependency(self):
        """Method called when the object is added as a Block dependency."""
        # NOTE(review): presumably ``checkpoint_init_data`` controls how initial
        # data is checkpointed while the base hook runs -- confirm in .checkpointing.
        with checkpoint_init_data():
            super()._ad_will_add_as_dependency()

    def _ad_mul(self, other):
        """Return a new Function holding ``other * self``."""
        from firedrake import Function
        r = Function(self.function_space())
        # `self` can be a Cofunction in which case only left multiplication
        # with a scalar is allowed.
        r.assign(other * self)
        return r

    def _ad_add(self, other):
        """Return a new Function holding ``self + other``."""
        from firedrake import Function
        r = Function(self.function_space())
        Function.assign(r, self + other)
        return r

    def _ad_dot(self, other, options=None):
        """Return the inner product of ``self`` with ``other``.

        If ``other`` is a ``Cofunction`` the result is the assembled action of
        ``other`` on ``self``. Otherwise ``options["riesz_representation"]``
        (default ``"L2"``) selects the inner product:

        * ``"l2"`` -- ``Dat.inner`` of the underlying data vectors;
        * ``"L2"`` -- integral of ``inner(self, other)`` over ``dx``;
        * ``"H1"`` -- integral of ``inner(self, other)
          + inner(grad(self), grad(other))`` over ``dx``.

        Raises ``NotImplementedError`` for any other representation.
        """
        from firedrake import assemble, action, Cofunction

        if isinstance(other, Cofunction):
            return assemble(action(other, self))

        options = {} if options is None else options
        riesz_representation = options.get("riesz_representation", "L2")
        if riesz_representation == "l2":
            return self.dat.inner(other.dat)
        elif riesz_representation == "L2":
            return assemble(firedrake.inner(self, other)*firedrake.dx)
        elif riesz_representation == "H1":
            # H1 inner product: both the values and the gradients are paired.
            return assemble((firedrake.inner(self, other)
                             + firedrake.inner(firedrake.grad(self), firedrake.grad(other)))*firedrake.dx)
        else:
            raise NotImplementedError(
                "Unknown Riesz representation %s" % riesz_representation)

    @staticmethod
    def _ad_assign_numpy(dst, src, offset):
        """Copy this process's owned slice of the global array ``src`` into ``dst``.

        ``offset`` is the position in ``src`` where ``dst``'s entries start.
        Returns ``dst`` and ``offset`` advanced by the global size of ``dst``'s
        layout vector.
        """
        range_begin, range_end = dst.dat.dataset.layout_vec.getOwnershipRange()
        m_a_local = src[offset + range_begin:offset + range_end]
        dst.dat.data_wo[...] = m_a_local.reshape(dst.dat.data_wo.shape)
        offset += dst.dat.dataset.layout_vec.size
        return dst, offset

    @staticmethod
    def _ad_to_list(m):
        """Return the global data of ``m`` as a Python list."""
        return m.dat.global_data.tolist()

    def _ad_copy(self):
        """Return a new Function holding a copy of ``self``'s value."""
        from firedrake import Function

        r = Function(self.function_space())
        r.assign(self)
        return r

    def _ad_dim(self):
        """Return the dimension of the function space."""
        return self.function_space().dim()

    def _ad_imul(self, other):
        """Multiply ``self`` by ``other`` in place and return ``self``."""
        self *= other
        return self

    def _ad_iadd(self, other):
        """Add ``other`` to ``self`` in place and return ``self``."""
        self += other
        return self

    def _ad_function_space(self, mesh):
        """Return the UFL function space of ``self`` (``mesh`` is unused)."""
        return self.ufl_function_space()

    def _reduce(self, r, r0):
        """Fold ``r(entry, accumulator)`` over the local data, starting from ``r0``."""
        for value in self.dat.data_ro:
            r0 = r(value, r0)
        return r0

    def _applyUnary(self, f):
        """Replace each local data entry ``x`` with ``f(x)``, in place."""
        src = self.dat.data_ro
        dst = self.dat.data_wo
        for i, value in enumerate(src):
            dst[i] = f(value)

    def _applyBinary(self, f, y):
        """Replace each local entry ``x`` with ``f(x, y_entry)``, in place."""
        src = self.dat.data_ro
        other = y.dat.data_ro
        dst = self.dat.data_wo
        for i, value in enumerate(src):
            dst[i] = f(value, other[i])

    def _ad_from_petsc(self, vec):
        """Overwrite ``self``'s data with the contents of the PETSc vector ``vec``."""
        with self.dat.vec_wo as self_v:
            vec.copy(result=self_v)

    def _ad_to_petsc(self, vec=None):
        """Return ``self``'s data as a PETSc vector.

        Copies into ``vec`` when a truthy vector is supplied, otherwise returns
        a new copy.
        """
        with self.dat.vec_ro as self_v:
            if vec:
                self_v.copy(result=vec)
            else:
                vec = self_v.copy()
        return vec

    def __deepcopy__(self, memodict=None):
        """Return a deep copy; ``memodict`` is accepted for the copy protocol but unused."""
        return self.copy(deepcopy=True)
[docs]
class CofunctionMixin(FunctionMixin):
    """Adjoint mixin for cofunctions; overrides the inner product."""

    def _ad_dot(self, other, options=None):
        """Return the assembled action of this cofunction on ``other``.

        ``options`` is accepted so the signature matches
        ``FunctionMixin._ad_dot(self, other, options=None)``; it is ignored
        because the pairing of a cofunction with a function needs no Riesz
        representation.
        """
        return firedrake.assemble(firedrake.action(self, other))