# Source code for firedrake.adjoint.transformed_functional
from collections.abc import Sequence
from contextlib import contextmanager
from numbers import Real
from operator import itemgetter
from typing import Optional, Union
import firedrake as fd
from firedrake.adjoint import Control, ReducedFunctional, Tape
from firedrake.functionspaceimpl import WithGeometry
import finat
import pyadjoint
from pyadjoint import no_annotations
from pyadjoint.enlisting import Enlist
from pyadjoint.reduced_functional import AbstractReducedFunctional
import ufl
# Names exported by ``from firedrake.adjoint.transformed_functional import *``.
__all__ = ["L2RieszMap", "L2TransformedFunctional"]
@contextmanager
def local_vector(u, *, readonly=False):
    """Borrow the local form of a PETSc vector for the duration of a block.

    A local vector is created with ``u.createLocalVector()`` and filled via
    ``u.getLocalVector``, then yielded. It is always handed back with
    ``u.restoreLocalVector`` -- even if the body of the ``with`` statement
    raises -- so the borrow is never left outstanding.

    Parameters
    ----------
    u
        The vector to borrow from (presumably a ``petsc4py.PETSc.Vec``).
    readonly
        Passed to both ``getLocalVector`` and ``restoreLocalVector``; whether
        the local vector is only read.

    Yields
    ------
    The local vector.
    """
    u_local = u.createLocalVector()
    u.getLocalVector(u_local, readonly=readonly)
    try:
        yield u_local
    finally:
        # Restore even on error, otherwise ``u`` stays in a borrowed state.
        u.restoreLocalVector(u_local, readonly=readonly)
class L2Cholesky:
    """Mass matrix Cholesky factorization for a (real) DG space.

    The :math:`L^2` mass matrix for ``space`` is assembled and the
    process-local diagonal block of the assembled matrix is factorized as

    .. math::

        M = C C^T

    using a PETSc Cholesky preconditioner. (Using only the diagonal block
    presumably relies on the DG mass matrix having no off-process coupling.)

    Parameters
    ----------
    space
        DG space.
    constant_jacobian
        Whether the mass matrix is constant. If true the factorization is
        computed once and cached; otherwise it is recomputed on every use.

    Raises
    ------
    NotImplementedError
        In complex mode.
    """

    def __init__(self, space: WithGeometry, *, constant_jacobian: Optional[bool] = True):
        if fd.utils.complex_mode:
            raise NotImplementedError("complex not supported")
        self._space = space
        self._constant_jacobian = constant_jacobian
        # (M, M_local, pc) once built, only when constant_jacobian is true.
        # The matrices are kept alongside the PC, presumably so that the
        # operators it references stay alive.
        self._cached_pc = None

    @property
    def space(self) -> fd.functionspaceimpl.WithGeometry:
        """Function space.
        """
        return self._space

    def _pc(self):
        """Return a set-up PETSc Cholesky PC for the local mass matrix block.

        Uses the cached PC when available, otherwise assembles the mass
        matrix and factorizes its diagonal block, caching the result only when
        ``constant_jacobian`` is true.
        """
        import petsc4py.PETSc as PETSc

        if self._cached_pc is not None:
            _, _, pc = self._cached_pc
            return pc

        M = fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx,
                        mat_type="aij")
        M_local = M.petscmat.getDiagonalBlock()
        pc = PETSc.PC().create(M_local.comm)
        pc.setType(PETSc.PC.Type.CHOLESKY)
        pc.setFactorSolverType(PETSc.Mat.SolverType.PETSC)
        pc.setOperators(M_local)
        pc.setUp()
        if self._constant_jacobian:
            self._cached_pc = M, M_local, pc
        return pc

    @staticmethod
    def _apply_local(apply, u, v):
        """Call ``apply(x, y)`` on the local parts of the DoF vectors of ``u``
        (read-only) and ``v`` (written), and return ``v``.

        The PC operates on the process-local diagonal block, so it is applied
        to the local vectors rather than the global ones.
        """
        with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v:
            with local_vector(u_v, readonly=True) as u_v_s, local_vector(v_v) as v_v_s:
                apply(u_v_s, v_v_s)
        return v

    def C_inv_action(self, u: Union[fd.Function, fd.Cofunction]) -> fd.Cofunction:
        r"""For the Cholesky factorization

        .. math::

            M = C C^T,

        compute the action of :math:`C^{-1}`.

        Parameters
        ----------
        u
            Compute :math:`C^{-1} \tilde{u}` where :math:`\tilde{u}` is the
            vector of degrees of freedom for :math:`u`.

        Returns
        -------
        firedrake.cofunction.Cofunction
            Has vector of degrees of freedom :math:`C^{-1} \tilde{u}`.
        """
        pc = self._pc()
        v = fd.Cofunction(self.space.dual())
        return self._apply_local(pc.applySymmetricLeft, u, v)

    def C_T_inv_action(self, u: Union[fd.Function, fd.Cofunction]) -> fd.Function:
        r"""For the Cholesky factorization

        .. math::

            M = C C^T,

        compute the action of :math:`C^{-T}`.

        Parameters
        ----------
        u
            Compute :math:`C^{-T} \tilde{u}` where :math:`\tilde{u}` is the
            vector of degrees of freedom for :math:`u`.

        Returns
        -------
        firedrake.function.Function
            Has vector of degrees of freedom :math:`C^{-T} \tilde{u}`.
        """
        pc = self._pc()
        v = fd.Function(self.space)
        return self._apply_local(pc.applySymmetricRight, u, v)
class L2RieszMap(fd.RieszMap):
    """Riesz map defined by the :math:`L^2` inner product.

    Parameters
    ----------
    target
        Function space on which the map is defined.
    kwargs
        Forwarded unchanged to the :class:`firedrake.RieszMap` constructor.

    Raises
    ------
    TypeError
        If ``target`` is not a
        :class:`~firedrake.functionspaceimpl.WithGeometry`.
    """

    def __init__(self, target: WithGeometry, **kwargs):
        if isinstance(target, fd.functionspaceimpl.WithGeometry):
            super().__init__(target, ufl.L2, **kwargs)
            return
        raise TypeError("Target must be a WithGeometry")
def is_dg_space(space: WithGeometry) -> bool:
    """Test whether ``space`` is a discontinuous Galerkin (DG) space.

    The UFL element of ``space`` is converted to a FInAT element and that
    element's ``is_dg()`` answer is returned.

    Parameters
    ----------
    space
        Function space to test.

    Returns
    -------
    bool
        ``True`` if the FInAT element reports itself as DG.
    """
    ufl_element = space.ufl_element()
    finat_element, _deps = finat.element_factory.convert(ufl_element)
    return finat_element.is_dg()
class L2TransformedFunctional(AbstractReducedFunctional):
    r"""Represents the functional

    .. math::

        J \circ \Pi \circ \Xi

    where

        - :math:`J` is the functional defining an optimization problem.
        - :math:`\Pi` is the :math:`L^2` projection from a DG space containing
          the control space as a subspace.
        - :math:`\Xi` represents a change of basis from an :math:`L^2`
          orthonormal basis to the finite element basis for the DG space.

    The optimization is therefore transformed into an optimization problem
    using an :math:`L^2` orthonormal basis for a DG finite element space.

    The transformation is related to the factorization in section 4.1 of
    https://doi.org/10.1137/18M1175239 -- specifically the factorization
    in their equation (4.2) can be related to :math:`\Pi \circ \Xi`.

    Parameters
    ----------
    functional
        Functional defining the optimization problem, :math:`J`.
    controls
        Controls. Each must wrap a :class:`firedrake.function.Function`.
    space_D
        DG space containing the control space. Where ``None`` (for all
        controls or for an individual control), the control space itself is
        used if it is DG, and its broken space otherwise.
    riesz_map
        Used for projecting from the DG space onto the control space. Ignored
        for DG controls. Defaults to an :class:`L2RieszMap` on each control
        space.
    alpha
        Modifies the functional, equivalent to adding an extra term to
        :math:`J \circ \Pi`

        .. math::

            \frac{1}{2} \alpha \left\| m_D - \Pi ( m_D ) \right\|_{L^2}^2.

        e.g. in a minimization problem this adds a penalty term which can
        be used to avoid ill-posedness due to the use of a larger DG space.
    tape
        Tape used in evaluations involving :math:`J`.
    """

    @no_annotations
    def __init__(self, functional: pyadjoint.OverloadedType, controls: Union[Control, Sequence[Control]], *,
                 space_D: Optional[Union[None, WithGeometry, Sequence[Union[None, WithGeometry]]]] = None,
                 riesz_map: Optional[Union[L2RieszMap, Sequence[L2RieszMap]]] = None,
                 alpha: Optional[Real] = 0,
                 tape: Optional[Tape] = None):
        if not all(isinstance(control.control, fd.Function) for control in Enlist(controls)):
            raise TypeError("controls must be Function objects")

        super().__init__()
        self._J = ReducedFunctional(functional, controls, tape=tape)

        # Control spaces and the DG spaces containing them.
        self._space = tuple(control.control.function_space()
                            for control in self._J.controls)
        if space_D is None:
            space_D = tuple(None for _ in self._space)
        self._space_D = Enlist(space_D)
        if len(self._space_D) != len(self._space):
            raise ValueError("Invalid length")
        self._space_D = tuple((space if is_dg_space(space) else space.broken_space())
                              if space_D_i is None else space_D_i
                              for space, space_D_i in zip(self._space, self._space_D))

        # Transformed controls live on the DG spaces and hold coefficients in
        # an L^2 orthonormal basis, and use the "l2" Riesz map.
        self._controls = tuple(Control(fd.Function(space_D_i), riesz_map="l2")
                               for space_D_i in self._space_D)
        # Preserve whether the ``controls`` argument was a single control or a
        # sequence.
        self._controls = Enlist(Enlist(controls).delist(self._controls))

        if riesz_map is None:
            riesz_map = tuple(map(L2RieszMap, self._space))
        self._riesz_map = Enlist(riesz_map)
        if len(self._riesz_map) != len(self._controls):
            raise ValueError("Invalid length")
        self._C = tuple(L2Cholesky(space_D_i, constant_jacobian=riesz_map_i.constant_jacobian)
                        for space_D_i, riesz_map_i in zip(self._space_D, self._riesz_map))

        self._alpha = alpha
        # (m_D, m_J) from the most recent __call__; read by derivative and tlm
        # when alpha != 0.
        self._m_k = None

        # Map the initial guess
        controls_t = self._dual_transform(tuple(control.control for control in self._J.controls), apply_riesz=False)
        for control, control_t in zip(self._controls, controls_t):
            control.control.assign(control_t)

    @property
    def controls(self) -> Enlist[Control]:
        """The transformed controls, defined on the DG spaces."""
        return Enlist(self._controls.delist())

    def _dual_transform(self, u, u_D=None, *, apply_riesz=False):
        """Apply :math:`C^{-1}` to DG cofunctions formed from ``u``.

        For each control a cofunction on the DG space is formed:

        - ``apply_riesz=True``: ``u`` is used directly when the control space
          is the DG space; otherwise ``riesz_map(u)`` is tested against the DG
          space.
        - ``apply_riesz=False``: ``u`` is tested against the DG space.

        The corresponding entry of ``u_D`` (if given and not ``None``) is then
        added, :math:`C^{-1}` is applied, and the result is returned as a
        :class:`firedrake.function.Function` via the ``"l2"`` Riesz
        representation.
        """
        u = Enlist(u)
        if len(u) != len(self.controls):
            raise ValueError("Invalid length")
        if u_D is None:
            u_D = tuple(None for _ in u)
        else:
            u_D = Enlist(u_D)
            if len(u_D) != len(self.controls):
                raise ValueError("Invalid length")

        def transform(C, u, u_D, space, space_D, riesz_map):
            if apply_riesz:
                if space is space_D:
                    # Aliases the input; callers pass u_D=None for these
                    # controls, so the axpy below never modifies ``u``.
                    v = u
                else:
                    v = fd.assemble(fd.inner(riesz_map(u), fd.TestFunction(space_D)) * fd.dx)
            else:
                v = fd.assemble(fd.inner(u, fd.TestFunction(space_D)) * fd.dx)
            if u_D is not None:
                v.dat.axpy(1, u_D.dat)
            v = C.C_inv_action(v)
            return v.riesz_representation("l2")

        v = tuple(map(transform, self._C, u, u_D, self._space, self._space_D, self._riesz_map))
        return u.delist(v)

    def _primal_transform(self, u):
        r"""Map coefficients in the :math:`L^2` orthonormal DG basis to
        functions.

        Returns a pair ``(m_D, m_J)``. ``m_D`` is the DG function with DoF
        vector :math:`C^{-T} \tilde{u}`. ``m_J`` is ``m_D`` mapped to the
        control space with ``riesz_map``, or ``m_D`` itself when the control
        space is the DG space.
        """
        u = Enlist(u)
        if len(u) != len(self.controls):
            raise ValueError("Invalid length")

        def transform(C, u, space, space_D, riesz_map):
            if fd.utils.complex_mode:
                # Would need to be adjoint
                raise NotImplementedError("complex not supported")
            v = C.C_T_inv_action(u)
            if space is space_D:
                w = v
            else:
                w = riesz_map(fd.assemble(fd.inner(v, fd.TestFunction(space)) * fd.dx))
            return v, w

        vw = tuple(map(transform, self._C, u, self._space, self._space_D, self._riesz_map))
        return u.delist(tuple(map(itemgetter(0), vw))), u.delist(tuple(map(itemgetter(1), vw)))

    def _evaluated_controls(self):
        """Return ``(m_D, m_J)`` stored by the most recent evaluation.

        Raises
        ------
        RuntimeError
            If the functional has not been evaluated yet.
        """
        if self._m_k is None:
            raise RuntimeError("Functional must be evaluated before derivative or tlm when alpha != 0")
        return self._m_k

    def _alpha_dual(self, a_D, a_J):
        """Return, per control, ``alpha * inner(a_D - a_J, v_D) * dx`` assembled
        against the DG test function ``v_D``, with ``None`` for controls whose
        space is the DG space.
        """
        v_alpha = []
        for space, space_D, a_D_i, a_J_i in zip(self._space, self._space_D, a_D, a_J):
            if space is space_D:
                v_alpha.append(None)
            else:
                if fd.utils.complex_mode:
                    raise RuntimeError("Not complex differentiable")
                v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(a_D_i - a_J_i, fd.TestFunction(space_D)) * fd.dx))
        return v_alpha

    def _convert_riesz(self, v):
        """Apply each transformed control's Riesz map to the matching entry of ``v``."""
        # NOTE(review): ``_dual_transform`` returns Functions, so this converts
        # Function -> Cofunction; confirm this direction matches the pyadjoint
        # ``apply_riesz`` convention.
        return tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map)
                     for v_i, control in zip(v, self.controls))

    @no_annotations
    def map_result(self, m: Union[fd.Function, Sequence[fd.Function]]) -> Union[fd.Function, Sequence[fd.Function]]:
        """Map the result of an optimization.

        Parameters
        ----------
        m
            The result of the optimization. Represents an expansion in the
            :math:`L^2` orthonormal basis for the DG space.

        Returns
        -------
        firedrake.function.Function or Sequence[firedrake.function.Function]
            The mapped result in the original control space.
        """
        _, m_J = self._primal_transform(m)
        return m_J

    @no_annotations
    def __call__(self, values: Union[fd.Function, Sequence[fd.Function]]) -> pyadjoint.AdjFloat:
        """Evaluate the functional.

        Parameters
        ----------
        values
            Control values.

        Returns
        -------
        pyadjoint.AdjFloat
            The functional value.
        """
        values = Enlist(values)
        m_D, m_J = self._primal_transform(values)
        J = self._J(m_J)
        if self._alpha != 0:
            for space, space_D, m_D_i, m_J_i in zip(self._space, self._space_D, m_D, m_J):
                if space is not space_D:
                    J += fd.assemble(0.5 * fd.Constant(self._alpha) * fd.inner(m_D_i - m_J_i, m_D_i - m_J_i) * fd.dx)
        # Stored for use by derivative and tlm.
        self._m_k = m_D, m_J
        return J

    @no_annotations
    def derivative(self, adj_input: Optional[Real] = 1.0,
                   apply_riesz: Optional[bool] = False) -> Union[fd.Function, fd.Cofunction, list[fd.Function, fd.Cofunction]]:
        """Evaluate the derivative.

        Parameters
        ----------
        adj_input
            Must be 1; other values are not supported.
        apply_riesz
            Whether to apply the Riesz map to the result.

        Returns
        -------
        firedrake.function.Function, firedrake.cofunction.Cofunction, or list[firedrake.function.Function or firedrake.cofunction.Cofunction]
            The derivative.

        Raises
        ------
        NotImplementedError
            If ``adj_input`` is not 1.
        RuntimeError
            If ``alpha`` is nonzero and the functional has not been evaluated.
        """
        if not isinstance(adj_input, Real) or adj_input != 1:
            raise NotImplementedError("adj_input != 1 not supported")

        u = Enlist(self._J.derivative())
        if self._alpha == 0:
            v_alpha = None
        else:
            v_alpha = self._alpha_dual(*self._evaluated_controls())
        v = self._dual_transform(u, v_alpha, apply_riesz=True)
        if apply_riesz:
            v = self._convert_riesz(v)
        return u.delist(v)

    @no_annotations
    def hessian(self, m_dot: Union[fd.Function, Sequence[fd.Function]],
                hessian_input: Optional[None] = None, evaluate_tlm: Optional[bool] = True,
                apply_riesz: Optional[bool] = False) -> Union[fd.Function, fd.Cofunction, list[fd.Function, fd.Cofunction]]:
        """Evaluate the Hessian action.

        Parameters
        ----------
        m_dot
            Action direction.
        hessian_input
            Must be ``None``; other values are not supported.
        evaluate_tlm
            Whether to re-evaluate the tangent-linear.
        apply_riesz
            Whether to apply the Riesz map to the result.

        Returns
        -------
        firedrake.function.Function, firedrake.cofunction.Cofunction, or list[firedrake.function.Function or firedrake.cofunction.Cofunction]
            The Hessian action.

        Raises
        ------
        NotImplementedError
            If ``hessian_input`` is not ``None``.
        """
        if hessian_input is not None:
            raise NotImplementedError("hessian_input not None not supported")

        m_dot = Enlist(m_dot)
        m_dot_D, m_dot_J = self._primal_transform(m_dot)
        u = Enlist(self._J.hessian(m_dot.delist(m_dot_J), evaluate_tlm=evaluate_tlm))
        if self._alpha == 0:
            v_alpha = None
        else:
            # The penalty term is quadratic, so its Hessian action uses the
            # direction in place of the evaluation point.
            v_alpha = self._alpha_dual(m_dot_D, m_dot_J)
        v = self._dual_transform(u, v_alpha, apply_riesz=True)
        if apply_riesz:
            v = self._convert_riesz(v)
        return u.delist(v)

    @no_annotations
    def tlm(self, m_dot: Union[fd.Function, Sequence[fd.Function]]) -> Union[fd.Function, list[fd.Function]]:
        """Evaluate a Jacobian action.

        Parameters
        ----------
        m_dot
            Action direction.

        Returns
        -------
        firedrake.function.Function or list[firedrake.function.Function]
            The Jacobian action.

        Raises
        ------
        RuntimeError
            If ``alpha`` is nonzero and the functional has not been evaluated.
        """
        m_dot = Enlist(m_dot)
        m_dot_D, m_dot_J = self._primal_transform(m_dot)
        tau_J = self._J.tlm(m_dot.delist(m_dot_J))
        if self._alpha != 0:
            m_D, m_J = self._evaluated_controls()
            for space, space_D, m_dot_D_i, m_D_i, m_J_i in zip(self._space, self._space_D, m_dot_D, m_D, m_J):
                if space is not space_D:
                    if fd.utils.complex_mode:
                        raise RuntimeError("Not complex differentiable")
                    tau_J += fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D_i - m_J_i, m_dot_D_i) * fd.dx)
        return tau_J
