Source code for firedrake.adjoint.transformed_functional

from collections.abc import Sequence
from contextlib import contextmanager
from numbers import Real
from operator import itemgetter
from typing import Optional, Union

import firedrake as fd
from firedrake.adjoint import Control, ReducedFunctional, Tape
from firedrake.functionspaceimpl import WithGeometry
import finat
import pyadjoint
from pyadjoint import no_annotations
from pyadjoint.enlisting import Enlist
from pyadjoint.reduced_functional import AbstractReducedFunctional
import ufl

# Public names of this module.
__all__ = [
    "L2RieszMap",
    "L2TransformedFunctional",
]


@contextmanager
def local_vector(u, *, readonly=False):
    """Temporarily access the local form of a PETSc vector.

    A local vector is created with ``u.createLocalVector()``, obtained with
    ``u.getLocalVector`` and yielded. On exit it is returned with
    ``u.restoreLocalVector``. The restore also happens if the body of the
    ``with`` statement raises, so the get/restore calls stay paired.

    Parameters
    ----------

    u
        The vector.
    readonly
        Passed to both ``getLocalVector`` and ``restoreLocalVector``.

    Yields
    ------

    The local vector.
    """

    u_local = u.createLocalVector()
    u.getLocalVector(u_local, readonly=readonly)
    try:
        yield u_local
    finally:
        # Restore even on error, otherwise the get/restore pair is unbalanced.
        u.restoreLocalVector(u_local, readonly=readonly)


class L2Cholesky:
    """Mass matrix Cholesky factorization for a (real) DG space.

    Only the diagonal block of the assembled mass matrix that is owned by the
    current process is factorized, and the factors are applied to
    process-local vectors. This presumably relies on the DG mass matrix having
    no off-process coupling.

    Parameters
    ----------

    space
        DG space.
    constant_jacobian
        Whether the mass matrix is constant. If true the factorization is
        computed once and cached. Otherwise it is recomputed for every action.

    Raises
    ------

    NotImplementedError
        In complex mode.
    """

    def __init__(self, space: WithGeometry, *, constant_jacobian: Optional[bool] = True):
        if fd.utils.complex_mode:
            raise NotImplementedError("complex not supported")

        self._space = space
        self._constant_jacobian = constant_jacobian
        # (M, M_local, pc) once a factorization has been cached, otherwise None
        self._cached_pc = None

    @property
    def space(self) -> fd.functionspaceimpl.WithGeometry:
        """Function space.
        """

        return self._space

    def _pc(self):
        """Return a PETSc ``PC`` holding a Cholesky factorization of the
        process-local diagonal block of the mass matrix.

        The result is reused when a cached factorization exists. Otherwise the
        mass matrix is assembled and factorized. The factorization is cached
        only if ``constant_jacobian`` is true.
        """

        import petsc4py.PETSc as PETSc

        if self._cached_pc is not None:
            _, _, pc = self._cached_pc
            return pc

        M = fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx,
                        mat_type="aij")
        M_local = M.petscmat.getDiagonalBlock()

        pc = PETSc.PC().create(M_local.comm)
        pc.setType(PETSc.PC.Type.CHOLESKY)
        pc.setFactorSolverType(PETSc.Mat.SolverType.PETSC)
        pc.setOperators(M_local)
        pc.setUp()

        if self._constant_jacobian:
            # Hold the assembled matrix and its local block with the PC
            self._cached_pc = M, M_local, pc
        return pc

    def C_inv_action(self, u: Union[fd.Function, fd.Cofunction]) -> fd.Cofunction:
        r"""For the Cholesky factorization of the mass matrix

        .. math::

            M = C C^T,

        compute the action of :math:`C^{-1}`.

        Parameters
        ----------

        u
            Compute :math:`C^{-1} \tilde{u}` where :math:`\tilde{u}` is the
            vector of degrees of freedom for :math:`u`.

        Returns
        -------

        firedrake.cofunction.Cofunction
            Has vector of degrees of freedom :math:`C^{-1} \tilde{u}`.
        """

        pc = self._pc()
        v = fd.Cofunction(self.space.dual())
        with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v:
            # The factorization is of the local block, so act on local vectors
            with local_vector(u_v, readonly=True) as u_v_s, local_vector(v_v) as v_v_s:
                pc.applySymmetricLeft(u_v_s, v_v_s)
        return v

    def C_T_inv_action(self, u: Union[fd.Function, fd.Cofunction]) -> fd.Function:
        r"""For the Cholesky factorization of the mass matrix

        .. math::

            M = C C^T,

        compute the action of :math:`C^{-T}`.

        Parameters
        ----------

        u
            Compute :math:`C^{-T} \tilde{u}` where :math:`\tilde{u}` is the
            vector of degrees of freedom for :math:`u`.

        Returns
        -------

        firedrake.function.Function
            Has vector of degrees of freedom :math:`C^{-T} \tilde{u}`.
        """

        pc = self._pc()
        v = fd.Function(self.space)
        with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v:
            # The factorization is of the local block, so act on local vectors
            with local_vector(u_v, readonly=True) as u_v_s, local_vector(v_v) as v_v_s:
                pc.applySymmetricRight(u_v_s, v_v_s)
        return v


class L2RieszMap(fd.RieszMap):
    """Riesz map associated with the :math:`L^2` inner product.

    Parameters
    ----------

    target
        The function space. Must be a
        :class:`firedrake.functionspaceimpl.WithGeometry`.
    kwargs
        Forwarded unchanged to the :class:`firedrake.RieszMap` constructor.

    Raises
    ------

    TypeError
        If ``target`` is not a ``WithGeometry``.
    """

    def __init__(self, target: WithGeometry, **kwargs):
        if isinstance(target, fd.functionspaceimpl.WithGeometry):
            super().__init__(target, ufl.L2, **kwargs)
        else:
            raise TypeError("Target must be a WithGeometry")
def is_dg_space(space: WithGeometry) -> bool:
    """Test whether a function space is discontinuous (DG).

    The space's UFL element is converted to a FInAT element, and that
    element's ``is_dg()`` answer is returned.

    Parameters
    ----------

    space
        The function space to test.

    Returns
    -------

    bool
        ``True`` if the space is DG.
    """

    ufl_element = space.ufl_element()
    finat_element, _ = finat.element_factory.convert(ufl_element)
    return finat_element.is_dg()
class L2TransformedFunctional(AbstractReducedFunctional):
    r"""Represents the functional

    .. math::

        J \circ \Pi \circ \Xi

    where

    - :math:`J` is the functional defining an optimization problem.
    - :math:`\Pi` is the :math:`L^2` projection from a DG space containing
      the control space as a subspace.
    - :math:`\Xi` represents a change of basis from an :math:`L^2`
      orthonormal basis to the finite element basis for the DG space.

    The optimization is therefore transformed into an optimization problem
    using an :math:`L^2` orthonormal basis for a DG finite element space.

    The transformation is related to the factorization in section 4.1 of
    https://doi.org/10.1137/18M1175239 -- specifically the factorization
    in their equation (4.2) can be related to :math:`\Pi \circ \Xi`.

    Parameters
    ----------

    functional
        Functional defining the optimization problem, :math:`J`.
    controls
        Controls. Each control must be a :class:`firedrake.function.Function`.
    space_D
        DG space containing the control space, or a sequence with one entry
        per control. A ``None`` entry (the default) selects the control space
        itself if it is DG, and otherwise its broken space.
    riesz_map
        Used for projecting from the DG space onto the control space when the
        two spaces differ. Its ``constant_jacobian`` attribute also decides
        whether the DG mass matrix factorization is cached. Defaults to an
        :class:`L2RieszMap` on each control space.
    alpha
        Modifies the functional, equivalent to adding an extra term to
        :math:`J \circ \Pi`

        .. math::

            \frac{1}{2} \alpha \left\| m_D - \Pi ( m_D ) \right\|_{L^2}^2.

        e.g. in a minimization problem this adds a penalty term which can
        be used to avoid ill-posedness due to the use of a larger DG space.
        The term is skipped for controls whose DG space is the control space
        itself.
    tape
        Tape used in evaluations involving :math:`J`.

    Raises
    ------

    TypeError
        If a control is not a :class:`firedrake.function.Function`.
    ValueError
        If ``space_D`` or ``riesz_map`` has the wrong number of entries.
    """

    @no_annotations
    def __init__(self, functional: pyadjoint.OverloadedType, controls: Union[Control, Sequence[Control]], *,
                 space_D: Optional[Union[None, WithGeometry, Sequence[Union[None, WithGeometry]]]] = None,
                 riesz_map: Optional[Union[L2RieszMap, Sequence[L2RieszMap]]] = None,
                 alpha: Optional[Real] = 0, tape: Optional[Tape] = None):
        if not all(isinstance(control.control, fd.Function) for control in Enlist(controls)):
            raise TypeError("controls must be Function objects")

        super().__init__()
        self._J = ReducedFunctional(functional, controls, tape=tape)
        self._space = tuple(control.control.function_space() for control in self._J.controls)

        # Resolve the DG space for each control
        if space_D is None:
            space_D = tuple(None for _ in self._space)
        space_D = Enlist(space_D)
        if len(space_D) != len(self._space):
            raise ValueError("Invalid length")
        self._space_D = tuple(
            (space if is_dg_space(space) else space.broken_space()) if space_D_i is None else space_D_i
            for space, space_D_i in zip(self._space, space_D))

        # Transformed controls live on the DG spaces and use the l2 Riesz map.
        # Delist/Enlist with the input controls so a single Control stays single.
        controls_D = tuple(Control(fd.Function(space_D_i), riesz_map="l2") for space_D_i in self._space_D)
        self._controls = Enlist(Enlist(controls).delist(controls_D))

        if riesz_map is None:
            riesz_map = tuple(map(L2RieszMap, self._space))
        self._riesz_map = Enlist(riesz_map)
        if len(self._riesz_map) != len(self._controls):
            raise ValueError("Invalid length")

        self._C = tuple(L2Cholesky(space_D_i, constant_jacobian=riesz_map_i.constant_jacobian)
                        for space_D_i, riesz_map_i in zip(self._space_D, self._riesz_map))
        self._alpha = alpha
        # (m_D, m_J) from the most recent __call__, used by the alpha term in
        # derivative and tlm
        self._m_k = None

        # Map the initial guess
        controls_t = self._dual_transform(tuple(control.control for control in self._J.controls),
                                          apply_riesz=False)
        for control, control_t in zip(self._controls, controls_t):
            control.control.assign(control_t)

    @property
    def controls(self) -> Enlist[Control]:
        """The transformed controls, defined on the DG spaces.
        """

        return Enlist(self._controls.delist())

    def _dual_transform(self, u, u_D=None, *, apply_riesz=False):
        r"""Map values associated with the control spaces to the transformed
        representation on the DG spaces.

        For each control, a cofunction on the DG space is formed:

        - if ``apply_riesz`` is false, by testing the Function ``u`` against
          the DG test functions;
        - if ``apply_riesz`` is true and the DG space is the control space,
          ``u`` is used directly;
        - otherwise by applying the Riesz map to ``u`` and testing the result
          against the DG test functions.

        ``u_D`` (if not ``None``) is then added, and :math:`C^{-1}` is applied,
        where :math:`M_D = C C^T` is the DG mass matrix factorization.

        Parameters
        ----------

        u
            Values, one per control.
        u_D
            Optional DG cofunctions, one per control (entries may be
            ``None``), added before applying :math:`C^{-1}`.
        apply_riesz
            Whether ``u`` holds dual values (cofunctions) on the control
            spaces.

        Returns
        -------

        The transformed values as Functions on the DG spaces, delisted to
        match ``u``.
        """

        u = Enlist(u)
        if len(u) != len(self.controls):
            raise ValueError("Invalid length")
        if u_D is None:
            u_D = tuple(None for _ in u)
        else:
            u_D = Enlist(u_D)
            if len(u_D) != len(self.controls):
                raise ValueError("Invalid length")

        def transform(C, u, u_D, space, space_D, riesz_map):
            if apply_riesz:
                if space is space_D:
                    # Callers pass u_D=None here, so u is not modified below
                    v = u
                else:
                    v = fd.assemble(fd.inner(riesz_map(u), fd.TestFunction(space_D)) * fd.dx)
            else:
                v = fd.assemble(fd.inner(u, fd.TestFunction(space_D)) * fd.dx)
            if u_D is not None:
                v.dat.axpy(1, u_D.dat)
            v = C.C_inv_action(v)
            return v.riesz_representation("l2")

        v = tuple(map(transform, self._C, u, u_D, self._space, self._space_D, self._riesz_map))
        return u.delist(v)

    def _primal_transform(self, u):
        r"""Map transformed values to the DG and control spaces.

        For each control computes :math:`m_D = C^{-T} u` on the DG space, and
        :math:`m_J`, its projection onto the control space using the Riesz
        map. If the DG space is the control space then :math:`m_J = m_D`.

        Parameters
        ----------

        u
            Transformed values, one per control.

        Returns
        -------

        tuple
            ``(m_D, m_J)``, each delisted to match ``u``.

        Raises
        ------

        NotImplementedError
            In complex mode.
        """

        u = Enlist(u)
        if len(u) != len(self.controls):
            raise ValueError("Invalid length")

        def transform(C, u, space, space_D, riesz_map):
            if fd.utils.complex_mode:
                # Would need to be adjoint
                raise NotImplementedError("complex not supported")
            v = C.C_T_inv_action(u)
            if space is space_D:
                w = v
            else:
                w = riesz_map(fd.assemble(fd.inner(v, fd.TestFunction(space)) * fd.dx))
            return v, w

        vw = tuple(map(transform, self._C, u, self._space, self._space_D, self._riesz_map))
        return u.delist(tuple(map(itemgetter(0), vw))), u.delist(tuple(map(itemgetter(1), vw)))

    def _evaluated_point(self):
        """Return the ``(m_D, m_J)`` pair stored by the most recent call.

        Raises
        ------

        RuntimeError
            If the functional has not yet been evaluated.
        """

        if self._m_k is None:
            raise RuntimeError("The functional must be evaluated before derivative information "
                               "can be computed when alpha != 0")
        return self._m_k

    def _alpha_cofunctions(self, x_D, x_J):
        r"""Assemble :math:`\alpha (x_D - x_J, \phi_D)_{L^2}` for each control.

        Entries are ``None`` for controls whose DG space is the control space.

        Raises
        ------

        RuntimeError
            In complex mode, for a control whose DG space differs from its
            control space.
        """

        v_alpha = []
        for space, space_D, x_D_i, x_J_i in zip(self._space, self._space_D, x_D, x_J):
            if space is space_D:
                v_alpha.append(None)
            else:
                if fd.utils.complex_mode:
                    raise RuntimeError("Not complex differentiable")
                v_alpha.append(fd.assemble(fd.Constant(self._alpha)
                                           * fd.inner(x_D_i - x_J_i, fd.TestFunction(space_D)) * fd.dx))
        return v_alpha

    def _convert_riesz(self, v):
        """Apply the transformed controls' Riesz maps to ``v``, one entry per
        control.
        """

        return tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map)
                     for v_i, control in zip(v, self.controls))

    @no_annotations
    def map_result(self, m: Union[fd.Function, Sequence[fd.Function]]) -> Union[fd.Function, Sequence[fd.Function]]:
        """Map the result of an optimization.

        Parameters
        ----------

        m
            The result of the optimization. Represents an expansion in the
            :math:`L^2` orthonormal basis for the DG space.

        Returns
        -------

        firedrake.function.Function or Sequence[firedrake.function.Function]
            The mapped result in the original control space.
        """

        _, m_J = self._primal_transform(m)
        return m_J

    @no_annotations
    def __call__(self, values: Union[fd.Function, Sequence[fd.Function]]) -> pyadjoint.AdjFloat:
        """Evaluate the functional.

        The mapped values are stored for use by :meth:`derivative` and
        :meth:`tlm`.

        Parameters
        ----------

        values
            Transformed control values.

        Returns
        -------

        pyadjoint.AdjFloat
            The functional value.
        """

        values = Enlist(values)
        m_D, m_J = self._primal_transform(values)
        J = self._J(m_J)
        if self._alpha != 0:
            for space, space_D, m_D_i, m_J_i in zip(self._space, self._space_D, m_D, m_J):
                if space is not space_D:
                    J += fd.assemble(0.5 * fd.Constant(self._alpha)
                                     * fd.inner(m_D_i - m_J_i, m_D_i - m_J_i) * fd.dx)
        self._m_k = m_D, m_J
        return J

    @no_annotations
    def derivative(self, adj_input: Optional[Real] = 1.0, apply_riesz: Optional[bool] = False
                   ) -> Union[fd.Function, fd.Cofunction, Sequence[Union[fd.Function, fd.Cofunction]]]:
        """Evaluate the derivative.

        Parameters
        ----------

        adj_input
            Only the value 1 is supported.
        apply_riesz
            Whether to apply the Riesz map to the result.

        Returns
        -------

        firedrake.function.Function, firedrake.cofunction.Cofunction, or Sequence[firedrake.function.Function or firedrake.cofunction.Cofunction]
            The derivative.

        Raises
        ------

        NotImplementedError
            If ``adj_input`` is not 1.
        RuntimeError
            If ``alpha`` is nonzero and the functional has not been evaluated.
        """

        if not isinstance(adj_input, Real) or adj_input != 1:
            raise NotImplementedError("adj_input != 1 not supported")

        u = Enlist(self._J.derivative())
        if self._alpha == 0:
            v_alpha = None
        else:
            v_alpha = self._alpha_cofunctions(*self._evaluated_point())
        v = self._dual_transform(u, v_alpha, apply_riesz=True)
        # NOTE(review): _dual_transform returns Functions (it ends with an l2
        # riesz_representation of a Cofunction), so with apply_riesz=False this
        # returns Functions, and the conversion below presumably yields
        # Cofunctions -- the reverse of the usual pyadjoint convention. The
        # dofs agree for the l2 Riesz map; confirm the types are intended.
        if apply_riesz:
            v = self._convert_riesz(v)
        return u.delist(v)

    @no_annotations
    def hessian(self, m_dot: Union[fd.Function, Sequence[fd.Function]], hessian_input: Optional[None] = None,
                evaluate_tlm: Optional[bool] = True, apply_riesz: Optional[bool] = False
                ) -> Union[fd.Function, fd.Cofunction, Sequence[Union[fd.Function, fd.Cofunction]]]:
        """Evaluate the Hessian action.

        Parameters
        ----------

        m_dot
            Action direction.
        hessian_input
            Only ``None`` is supported.
        evaluate_tlm
            Whether to re-evaluate the tangent-linear.
        apply_riesz
            Whether to apply the Riesz map to the result.

        Returns
        -------

        firedrake.function.Function, firedrake.cofunction.Cofunction, or Sequence[firedrake.function.Function or firedrake.cofunction.Cofunction]
            The Hessian action.

        Raises
        ------

        NotImplementedError
            If ``hessian_input`` is not ``None``.
        """

        if hessian_input is not None:
            raise NotImplementedError("hessian_input not None not supported")

        m_dot = Enlist(m_dot)
        m_dot_D, m_dot_J = self._primal_transform(m_dot)
        u = Enlist(self._J.hessian(m_dot.delist(m_dot_J), evaluate_tlm=evaluate_tlm))
        if self._alpha == 0:
            v_alpha = None
        else:
            v_alpha = self._alpha_cofunctions(m_dot_D, m_dot_J)
        v = self._dual_transform(u, v_alpha, apply_riesz=True)
        if apply_riesz:
            v = self._convert_riesz(v)
        return u.delist(v)

    @no_annotations
    def tlm(self, m_dot: Union[fd.Function, Sequence[fd.Function]]) -> pyadjoint.AdjFloat:
        """Evaluate a Jacobian action.

        Parameters
        ----------

        m_dot
            Action direction.

        Returns
        -------

        pyadjoint.AdjFloat
            The Jacobian action, i.e. the directional derivative of the
            functional in the direction ``m_dot``.

        Raises
        ------

        RuntimeError
            If ``alpha`` is nonzero and the functional has not been evaluated.
        """

        m_dot = Enlist(m_dot)
        m_dot_D, m_dot_J = self._primal_transform(m_dot)
        tau_J = self._J.tlm(m_dot.delist(m_dot_J))
        if self._alpha != 0:
            m_D, m_J = self._evaluated_point()
            for space, space_D, m_dot_D_i, m_D_i, m_J_i in zip(self._space, self._space_D, m_dot_D, m_D, m_J):
                if space is not space_D:
                    if fd.utils.complex_mode:
                        raise RuntimeError("Not complex differentiable")
                    tau_J += fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D_i - m_J_i, m_dot_D_i) * fd.dx)
        return tau_J