Source code for firedrake.adjoint_utils.blocks.function

import ufl
from ufl.corealg.traversal import traverse_unique_terminals
from ufl.formatting.ufl2unicode import ufl2unicode
from pyadjoint import Block, OverloadedType, AdjFloat
from pyadjoint.tape import stop_annotating
import firedrake
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint, \
    DelegatedFunctionCheckpoint
from .block_utils import isconstant


class FunctionAssignBlock(Block):
    """Block recording ``func.assign(other)`` on the tape.

    ``other`` is handled in one of three ways:

    * an :class:`~pyadjoint.OverloadedType` becomes the single dependency;
    * a Python ``float`` or ``int`` is wrapped in an :class:`AdjFloat` and
      becomes the single dependency;
    * anything else is treated as a point-wise evaluated UFL expression: it
      is stored in ``self.expr`` and each of its overloaded terminals becomes
      a dependency.

    ``func`` itself is not used by the constructor.
    """

    def __init__(self, func, other, ad_block_tag=None):
        super().__init__(ad_block_tag=ad_block_tag)
        # Not assigned anywhere else in this class; __str__ still consults it.
        self.other = None
        # The assigned UFL expression, or None when ``other`` is a single
        # overloaded object or a plain number.
        self.expr = None
        if isinstance(other, OverloadedType):
            self.add_dependency(other, no_duplicates=True)
        elif isinstance(other, (float, int)):
            self.add_dependency(AdjFloat(other), no_duplicates=True)
        else:
            # Assume that this is a point-wise evaluated UFL expression
            # (firedrake only).
            for op in traverse_unique_terminals(other):
                if isinstance(op, OverloadedType):
                    self.add_dependency(op, no_duplicates=True)
            self.expr = other

    def _replace_with_saved_output(self):
        """Return ``self.expr`` with each dependency replaced by its saved output.

        Returns None when this block does not record a UFL expression.
        """
        if self.expr is None:
            return None
        replace_map = {dep.output: dep.saved_output
                       for dep in self.get_dependencies()}
        return ufl.replace(self.expr, replace_map)

    def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
        """Prepare the adjoint input shared by all dependency contributions.

        A :class:`~firedrake.Cofunction` adjoint input is converted to its l2
        Riesz representation.

        Returns
        -------
        The converted adjoint input when there is no expression, otherwise
        the tuple ``(expr, adj_input_func)``, where ``expr`` is the assigned
        expression evaluated at the saved dependency values.
        """
        adj_input_func, = adj_inputs
        if isinstance(adj_input_func, firedrake.Cofunction):
            adj_input_func = adj_input_func.riesz_representation(riesz_map="l2")
        if self.expr is None:
            return adj_input_func
        expr = self._replace_with_saved_output()
        return expr, adj_input_func

    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
                               prepared=None):
        """Compute the adjoint contribution for one dependency.

        Parameters
        ----------
        inputs : list
            Saved dependency values (not used here).
        adj_inputs : list
            Adjoint values of the block output.
        block_variable : pyadjoint.block_variable.BlockVariable
            The dependency whose adjoint contribution is computed.
        idx : int
            Index of ``block_variable`` among the dependencies (not used here).
        prepared :
            The result of :meth:`prepare_evaluate_adj`.

        Returns
        -------
        For an :class:`AdjFloat` dependency, the summed adjoint input (or the
        input itself when it has no ``dat``). For a constant dependency, a
        Function on the constant's space when assigned directly, or the
        ``dat.inner`` pairing when it appears in an expression. Otherwise,
        the l2 Riesz representation of a Function.
        """
        dep_output = block_variable.output
        if self.expr is None:
            if isinstance(dep_output, AdjFloat):
                try:
                    # Adjoint of a broadcast is just a sum
                    return adj_inputs[0].dat.data_ro.sum()
                except AttributeError:
                    # Catch the case where adj_inputs[0] is just a float
                    return adj_inputs[0]
            if isconstant(dep_output):
                R = dep_output._ad_function_space(
                    prepared.function_space().mesh()
                )
                return self._adj_assign_constant(prepared, R)
            adj_output = firedrake.Function(dep_output.function_space())
            adj_output.assign(prepared)
            return adj_output.riesz_representation(riesz_map="l2")

        # Linear combination: differentiate the saved expression with respect
        # to this dependency.
        expr, adj_input_func = prepared
        adj_output = firedrake.Function(adj_input_func.function_space())
        if isconstant(dep_output):
            # Differentiate in the direction of a unit Constant and pair the
            # result with the adjoint input through the dat inner product.
            mesh = adj_output.function_space().mesh()
            diff_expr = ufl.algorithms.expand_derivatives(
                ufl.derivative(
                    expr,
                    block_variable.saved_output,
                    firedrake.Constant(1., domain=mesh)
                )
            )
            adj_output.assign(diff_expr)
            return adj_output.dat.inner(adj_input_func.dat)

        diff_expr = ufl.algorithms.expand_derivatives(
            ufl.derivative(expr, block_variable.saved_output, adj_input_func)
        )
        # Firedrake does not support assignment of conjugate functions
        adj_output.interpolate(ufl.conj(diff_expr))
        return adj_output.riesz_representation(riesz_map="l2")

    def _adj_assign_constant(self, adj_output, constant_fs):
        """Reduce ``adj_output`` to a Function on the constant space ``constant_fs``.

        A scalar (or shape ``(1,)``) constant receives the sum of all entries
        of ``adj_output``. Otherwise, component ``i`` receives the sum of the
        entries of sub-function ``i`` of ``adj_output``.
        """
        r = firedrake.Function(constant_fs)
        shape = r.ufl_shape
        if shape == () or shape[0] == 1:
            # Scalar Constant
            r.dat.data[:] = adj_output.dat.data_ro.sum()
        else:
            # We assume the shape of the constant == shape of the output
            # function if not scalar. This assumption is due to FEniCS not
            # supporting products with non-scalar constants in assign.
            # Firedrake's Function.sub takes only the index. The sub-function
            # is only read here, so a view is sufficient (no deep copy).
            values = [adj_output.sub(i).dat.data_ro.sum()
                      for i in range(shape[0])]
            r.assign(firedrake.Constant(values))
        return r

    def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs):
        # Yields None when no expression is recorded.
        return self._replace_with_saved_output()

    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
                               prepared=None):
        """Compute the tangent-linear value of the assigned output.

        Without an expression the TLM input passes straight through.
        Otherwise, the directional derivatives of the saved expression along
        each dependency's TLM value are accumulated.
        """
        if self.expr is None:
            return tlm_inputs[0]
        expr = prepared
        space = block_variable.output.function_space()
        dudm = firedrake.Function(space)
        dudmi = firedrake.Function(space)
        for dep in self.get_dependencies():
            # Dependencies without a TLM value contribute nothing.
            if dep.tlm_value is not None:
                dudmi.assign(ufl.algorithms.expand_derivatives(
                    ufl.derivative(expr, dep.saved_output, dep.tlm_value)))
                dudm.dat += 1.0 * dudmi.dat
        return dudm

    def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
                                 relevant_dependencies):
        return self.prepare_evaluate_adj(inputs, hessian_inputs,
                                         relevant_dependencies)

    def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
                                   block_variable, idx, relevant_dependencies,
                                   prepared=None):
        # Current implementation assumes a linear combination in the Hessian;
        # otherwise second-order derivatives would be needed here.
        return self.evaluate_adj_component(inputs, hessian_inputs,
                                           block_variable, idx, prepared)

    def prepare_recompute_component(self, inputs, relevant_outputs):
        # Yields None when no expression is recorded.
        return self._replace_with_saved_output()

    def recompute_component(self, inputs, block_variable, idx, prepared=None):
        """Recompute the assignment.

        Parameters
        ----------
        inputs : list of Function or Constant
            The variables in the RHS of the assignment.
        block_variable : pyadjoint.block_variable.BlockVariable
            The output block variable.
        idx : int
            Index associated to the inputs list.
        prepared :
            The precomputed RHS value (the expression at the saved
            dependency values), or None when no expression is recorded.

        Notes
        -----
        The value is recomputed only if the checkpoint was not delegated to
        another :class:`~firedrake.function.Function`.

        Returns
        -------
        Function or DelegatedFunctionCheckpoint
            The recomputed Function (passed through ``maybe_disk_checkpoint``),
            or the existing delegated checkpoint when checkpointing was
            delegated.
        """
        if isinstance(block_variable.checkpoint, DelegatedFunctionCheckpoint):
            return block_variable.checkpoint
        if self.expr is None:
            prepared = inputs[0]
        output = firedrake.Function(block_variable.output.function_space())
        output.assign(prepared)
        return maybe_disk_checkpoint(output)

    def __str__(self):
        # Explicit None checks: a UFL expression (e.g. ``Zero``) may be falsy,
        # and such an expression can have no dependencies at all.
        if self.expr is not None:
            rhs = self.expr
        elif self.other is not None:
            rhs = self.other
        else:
            rhs = self.get_dependencies()[0].output
        if isinstance(rhs, ufl.core.expr.Expr):
            rhs_str = ufl2unicode(rhs)
        else:
            rhs_str = str(rhs)
        return f"assign({rhs_str})"
class SubfunctionBlock(Block):
    """Block recording the extraction of sub-function ``idx`` from ``func``.

    The parent ``func`` is the only dependency.
    """

    def __init__(self, func, idx, ad_block_tag=None):
        super().__init__(ad_block_tag=ad_block_tag)
        self.add_dependency(func)
        self.idx = idx

    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
                               prepared=None):
        """Embed the sub-function adjoint into component ``self.idx`` of a
        Cofunction on the parent's dual space."""
        adj_input = adj_inputs[0]
        dual_space = block_variable.output.function_space().dual()
        eval_adj = firedrake.Cofunction(dual_space)
        target = eval_adj.sub(self.idx)
        # Exact type check: anything that is not precisely a Cofunction is
        # unwrapped through its ``function`` attribute.
        if type(adj_input) is firedrake.Cofunction:
            target.assign(adj_input)
        else:
            target.assign(adj_input.function)
        return eval_adj

    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
                               prepared=None):
        """Return component ``self.idx`` of the parent's TLM value."""
        parent_tlm = tlm_inputs[0]
        return firedrake.Function.sub(parent_tlm, self.idx)

    def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
                                   block_variable, idx, relevant_dependencies,
                                   prepared=None):
        """Embed the Hessian input into component ``self.idx`` of a
        Cofunction on the parent's dual space."""
        dual_space = block_variable.output.function_space().dual()
        hessian_dual = firedrake.Cofunction(dual_space)
        hessian_dual.sub(self.idx).assign(hessian_inputs[0])
        return hessian_dual

    def recompute_component(self, inputs, block_variable, idx, prepared):
        """Re-extract sub-function ``self.idx`` from the saved parent."""
        extracted = firedrake.Function.sub(inputs[0], self.idx)
        return maybe_disk_checkpoint(extracted)

    def __str__(self):
        parent = self.get_dependencies()[0]
        return f"{parent}[{self.idx}]"
class FunctionMergeBlock(Block):
    """Block recording the write-back of a sub-function into its parent.

    Dependency 0 is the sub-function ``func``; the objects in
    ``func._ad_outputs`` are added after it, and ``recompute_component``
    treats dependency 1 as the parent. ``idx`` is the index of the
    sub-function within the parent.
    """

    def __init__(self, func, idx, ad_block_tag=None):
        super().__init__(ad_block_tag=ad_block_tag)
        self.add_dependency(func)
        self.idx = idx
        for output in func._ad_outputs:
            self.add_dependency(output)

    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
                               prepared=None):
        """Return the adjoint contribution for dependency ``idx``.

        The sub-function (dependency 0) receives component ``self.idx`` of the
        adjoint input. Every other dependency receives the adjoint input
        unchanged.
        """
        if idx == 0:
            return adj_inputs[0].subfunctions[self.idx]
        return adj_inputs[0]

    def evaluate_tlm(self, markings=False):
        """Propagate the sub-function's TLM value into the block output.

        Parameters
        ----------
        markings : bool
            Accepted so that callers passing ``markings`` (as
            ``pyadjoint.Block.evaluate_tlm`` allows) do not fail; ignored here.
        """
        tlm_input = self.get_dependencies()[0].tlm_value
        if tlm_input is None:
            return
        output = self.get_outputs()[0]
        fs = output.output.function_space()
        f = type(output.output)(fs)
        # Function.assign returns the Function it was called on, which here is
        # the sub-function view. Register the full-space ``f`` as the TLM
        # output, not that return value. Do not record this assignment on the
        # tape.
        with stop_annotating():
            type(output.output).assign(f.sub(self.idx), tlm_input)
        output.add_tlm_output(f)

    def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
                                   block_variable, idx, relevant_dependencies,
                                   prepared=None):
        return hessian_inputs[0]

    def recompute_component(self, inputs, block_variable, idx, prepared):
        """Rebuild the parent with component ``self.idx`` replaced by the
        saved sub-function."""
        sub_func = inputs[0]
        parent_in = inputs[1]
        # Constructing from an existing Function produces a new copy of it.
        parent_out = type(parent_in)(parent_in)
        parent_out.sub(self.idx).assign(sub_func)
        return maybe_disk_checkpoint(parent_out)

    def __str__(self):
        deps = self.get_dependencies()
        return f"{deps[1]}[{self.idx}].assign({deps[0]})"