# Source code for firedrake.adjoint_utils.solving
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape
from firedrake.adjoint_utils.blocks import SolveVarFormBlock, SolveLinearSystemBlock, GenericSolveBlock, ProjectBlock
import ufl
# [docs]
def annotate_solve(solve):
    """Wrap the Firedrake :func:`.solve` call so that it is recorded on the tape.

    The returned wrapper annotates the model: it records which solves occur
    and which forms are involved, so that pyadjoint can construct the adjoint
    and tangent linear models automatically.

    Passing ``annotate=False`` to the wrapped routine disables the annotation,
    in which case it acts exactly like the plain Firedrake solve call. This is
    useful when the solve is known to be irrelevant or purely diagnostic for
    the adjoint computation (for example, projecting fields to other function
    spaces for visualisation).

    The overloaded solve also accepts optional callbacks for extracting
    adjoint solutions. Every callback has the same signature: it takes a
    single argument of type Function.

    Args:
        solve: the solve routine to wrap.

    Returns:
        The wrapping function, carrying the metadata of ``solve`` (via
        :func:`functools.wraps`) and returning whatever ``solve`` returns.

    Keyword Args:
        adj_cb (callable, optional):
            callback receiving the adjoint solution in the interior.
            The boundary values are zero.
        adj_bdy_cb (callable, optional):
            callback receiving the adjoint solution on the boundary.
            The interior values are not guaranteed to be zero.
        adj2_cb (callable, optional):
            callback receiving the second-order adjoint solution in the
            interior. The boundary values are zero.
        adj2_bdy_cb (callable, optional):
            callback receiving the second-order adjoint solution on the
            boundary. The interior values are not guaranteed to be zero.
        ad_block_tag (:obj:`string`, optional):
            tag used to label the resulting block on the Pyadjoint tape.
            This helps to identify which block is associated with which
            equation in the forward model.
    """
    @wraps(solve)
    def wrapper(*args, **kwargs):
        # The tag is consumed here so it never reaches the wrapped solve.
        ad_block_tag = kwargs.pop("ad_block_tag", None)
        annotate = annotate_tape(kwargs)

        if annotate:
            working_tape = get_working_tape()
            # An Equation as first argument selects the variational-form
            # block; anything else is treated as a linear system.
            is_equation = isinstance(args[0], ufl.equation.Equation)
            block_cls = SolveVarFormBlock if is_equation else SolveLinearSystemBlock
            # Block-specific keywords are popped out of ``kwargs``; the
            # keywords that remain are merged in and also go to the block.
            block_kwargs = block_cls.pop_kwargs(kwargs)
            block_kwargs.update(kwargs)
            block = block_cls(*args, ad_block_tag=ad_block_tag, **block_kwargs)
            # The block is added to the tape before the forward solve runs.
            working_tape.add_block(block)

        # The wrapped solve itself runs inside ``stop_annotating()``.
        with stop_annotating():
            output = solve(*args, **kwargs)

        if not annotate:
            return output

        # The second positional argument is registered as the block output;
        # if it cannot create a block variable itself, its ``function``
        # attribute is used instead.
        solution = args[1]
        if not hasattr(solution, "create_block_variable"):
            solution = solution.function
        block.add_output(solution.create_block_variable())
        return output

    return wrapper
# [docs]
def get_solve_blocks():
    """
    Collect the PDE solve blocks recorded on the working tape.

    A block is kept when its type is a subclass of ``GenericSolveBlock``
    and is not a subclass of ``ProjectBlock``, i.e. blocks stemming from
    calls of the ``project`` operator are excluded.

    Returns:
        list: the matching blocks, in the order given by the tape's
        ``get_blocks()``.
    """
    solve_blocks = []
    for candidate in get_working_tape().get_blocks():
        block_cls = type(candidate)
        # Skip anything that is not a solve block at all.
        if not issubclass(block_cls, GenericSolveBlock):
            continue
        # Projections are excluded even though they pass the check above.
        if issubclass(block_cls, ProjectBlock):
            continue
        solve_blocks.append(candidate)
    return solve_blocks