Source code for firedrake.ensemble.ensemble_function

from functools import cached_property
from contextlib import contextmanager

from pyop2 import MixedDat
from firedrake.petsc import PETSc
from firedrake.ensemble.ensemble_functionspace import (
    EnsembleFunctionSpaceBase, EnsembleFunctionSpace, EnsembleDualSpace)
from firedrake.adjoint_utils import EnsembleFunctionMixin
from firedrake.function import Function
from firedrake.norms import norm

__all__ = ("EnsembleFunction", "EnsembleCofunction")


class EnsembleFunctionBase(EnsembleFunctionMixin):
    """
    A mixed (co)function defined on a :class:`~.ensemble.Ensemble`.
    The subcomponents are distributed over the ensemble members, and
    are specified locally in a :class:`~firedrake.EnsembleFunctionSpace`.

    Parameters
    ----------

    function_space : `~ensemble_functionspace.EnsembleFunctionSpace`.
        The function space of the (co)function.

    Notes
    -----
    Passing an `EnsembleDualSpace` to `EnsembleFunction`
    will return an instance of :class:`~firedrake.EnsembleCofunction`.

    This class does not carry UFL symbolic information, unlike a
    :class:`~firedrake.function.Function`. UFL expressions can only be defined
    locally on each ensemble member using a `~firedrake.function.Function`
    from `EnsembleFunction.subfunctions`.

    See Also
    --------
    - Primal ensemble objects: :class:`~ensemble_functionspace.EnsembleFunctionSpace` and :class:`~firedrake.EnsembleFunction`.
    - Dual ensemble objects: :class:`~firedrake.EnsembleDualSpace` and :class:`~firedrake.EnsembleCofunction`.
    """

    @PETSc.Log.EventDecorator()
    @EnsembleFunctionMixin._ad_annotate_init
    def __init__(self, function_space: EnsembleFunctionSpaceBase):
        self._fs = function_space

        # we hold all subcomponents on the local
        # ensemble member in one big mixed function.
        self._full_local_function = Function(function_space._full_local_space)

        # create a Vec containing the data for all subcomponents on all
        # ensemble members. Because we use the Vec of each local mixed
        # function as the storage, if the data in the Function Vec
        # is valid then the data in the EnsembleFunction Vec is valid.

        with self._full_local_function.dat.vec as fvec:
            n = function_space.nlocal_rank_dofs
            N = function_space.nglobal_dofs
            sizes = (n, N)
            self._vec = PETSc.Vec().createWithArray(
                fvec.array, size=sizes,
                comm=function_space.global_comm)
            self._vec.setFromOptions()

    def function_space(self):
        return self._fs

    @cached_property
    def subfunctions(self):
        """
        The (co)functions on the local ensemble member.
        """
        def local_function(i):
            V = self._fs.local_spaces[i]
            usubs = self._subcomponents(i)
            if len(usubs) == 1:
                dat = usubs[0].dat
            else:
                dat = MixedDat((u.dat for u in usubs))
            return Function(V, val=dat)

        return tuple(local_function(i)
                     for i in range(self._fs.nlocal_spaces))

    def _subcomponents(self, i):
        """
        Return the subfunctions of the local mixed function storage
        corresponding to the i-th local function.

        Firedrake doesn't support nested ``MixedFunctionSpace``, so internally
        :class:`~firedrake.ensemble.ensemble_functionspace.EnsembleFunctionSpace` flattens all the
        local :class:`~firedrake.functionspaceimpl.FunctionSpace` into a
        single ``MixedFunctionSpace``. This method retrieves the components of
        the flattened MixedFunction corresponding to the i-th local
        :class:`~firedrake.function.Function`.
        """
        return tuple(self._full_local_function.subfunctions[j]
                     for j in self._fs._component_indices(i))

    @PETSc.Log.EventDecorator()
    def riesz_representation(self, **kwargs):
        """
        Return the Riesz representation of this :class:`EnsembleFunction`
        with respect to the given Riesz map.

        Internally delegates to the
        :meth:`firedrake.function.Function.riesz_representation()`
        of each component.

        Parameters
        ----------
        riesz_map
            The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable.

        kwargs
            other arguments to be passed to the firedrake.riesz_map.
        """
        riesz = EnsembleFunction(self.function_space().dual())
        for uself, uriesz in zip(self.subfunctions, riesz.subfunctions):
            uriesz.assign(
                uself.riesz_representation(**kwargs))
        return riesz

    @PETSc.Log.EventDecorator()
    def assign(self, other, subsets=None):
        r"""Set the :class:`EnsembleFunction` to the value of another
        :class:`EnsembleFunction` other.

        Parameters
        ----------

        other : :class:`EnsembleFunction`
            The value to assign from.

        subsets : Collection[Optional[:class:`pyop2.types.set.Subset`]]
            One subset for each local :class:`firedrake.functionFunction`.
            None elements will be ignored. The values of each local function
            will only be assigned on the nodes on the corresponding subset.
        """
        if type(other) is not type(self):
            raise TypeError(
                f"Cannot assign {type(self).__name__} from {type(other).__name__}")
        for i in range(self._fs.nlocal_spaces):
            self.subfunctions[i].assign(
                other.subfunctions[i],
                subset=subsets[i] if subsets else None)
        return self

    @PETSc.Log.EventDecorator()
    def copy(self):
        """
        Return a deep copy of the :class:`EnsembleFunction`.
        """
        new = type(self)(self.function_space())
        new.assign(self)
        return new

    @PETSc.Log.EventDecorator()
    def zero(self, subsets=None):
        """
        Set values to zero.

        Parameters
        ----------

        subsets : Collection[Optional[:class:`pyop2.types.set.Subset`]]
            One subset for each local :class:`firedrake.function.Function`.
            None elements will be ignored.  The values of each local function
            will only be zeroed on the nodes on the corresponding subset.
        """
        for i in range(self._fs.nlocal_spaces):
            self.subfunctions[i].zero(
                subset=subsets[i] if subsets else None)
        return self

    @PETSc.Log.EventDecorator()
    def __iadd__(self, other):
        for us, uo in zip(self.subfunctions, other.subfunctions):
            us.assign(us + uo)
        return self

    @PETSc.Log.EventDecorator()
    def __imul__(self, other):
        if type(other) is type(self):
            for us, uo in zip(self.subfunctions, other.subfunctions):
                us.assign(us*uo)
        else:
            for us in self.subfunctions:
                us *= other
        return self

    @PETSc.Log.EventDecorator()
    def __add__(self, other):
        new = self.copy()
        new += other
        return new

    @PETSc.Log.EventDecorator()
    def __mul__(self, other):
        new = self.copy()
        new *= other
        return new

    @PETSc.Log.EventDecorator()
    def __rmul__(self, other):
        if type(other) is type(self):
            for us, uo in zip(self.subfunctions, other.subfunctions):
                us.assign(us*uo)
        else:
            for us in self.subfunctions:
                us *= other
        return self

    @contextmanager
    def vec(self):
        """
        Context manager for the global :class:`petsc4py.PETSc.Vec` with
        read/write access.

        It is invalid to access the ``Vec`` outside of a context manager.
        """
        # The globally defined _vec views the _full_local_function.dat.vec.
        # The data in _full_local_function.dat.vec is only valid inside the
        # context manager, so we need to activate that context manager before
        # yielding our _vec otherwise the data will not be up to date.
        # However, because the copies in the _full_local_function.dat.vec
        # context manager are done without _vec knowing, we have to manually
        # increment the state to make sure its still in sync.
        with self._full_local_function.dat.vec:
            self._vec.stateIncrease()
            yield self._vec

    @contextmanager
    def vec_ro(self):
        """
        Context manager for the global :class:`petsc4py.PETSc.Vec` with
        read only access.

        It is invalid to access the ``Vec`` outside of a context manager.
        """
        # The globally defined _vec views the _full_local_function.dat.vec.
        # The data in _full_local_function.dat.vec is only valid inside the
        # context manager, so we need to activate that context manager before
        # yielding our _vec otherwise the data will not be up to date.
        with self._full_local_function.dat.vec_ro:
            self._vec.stateIncrease()
            yield self._vec

    @contextmanager
    def vec_wo(self):
        """
        Context manager for the global :class:`petsc4py.PETSc.Vec` with
        write only access.

        It is invalid to access the ``Vec`` outside of a context manager.
        """
        # The globally defined _vec views the _full_local_function.dat.vec.
        # The data in _full_local_function.dat.vec is only valid inside the
        # context manager, so we need to activate that context manager before
        # yielding our _vec otherwise the data will not be copied back into
        # the _full_local_function properly when exiting the context manager.
        # Because the _full_local_function.dat.vec_wo context manager doesn't
        # copy any data on entry, this time we don't have to manually increase
        # _vec's state. If the user modifies _vec inside out context manager then
        # _vec will know and will handle incrementing it's state itself.
        with self._full_local_function.dat.vec_wo:
            yield self._vec


[docs] class EnsembleFunction(EnsembleFunctionBase): """ A mixed Function defined on a :class:`~.ensemble.Ensemble`. The subcomponents are distributed over the ensemble members, and are specified locally in a :class:`~firedrake.ensemble.ensemble_functionspace.EnsembleFunctionSpace`. Parameters ---------- function_space : :class:`~firedrake.ensemble.ensemble_functionspace.EnsembleFunctionSpace`. The function space of the Function. Notes ----- Passing an :class:`~firedrake.ensemble.ensemble_functionspace.EnsembleDualSpace` to ``EnsembleFunction`` will return an instance of :class:`EnsembleCofunction`. This class does not carry UFL symbolic information, unlike a :class:`~firedrake.function.Function`. UFL expressions can only be defined locally on each ensemble member using a :class:`~firedrake.function.Function` from ``EnsembleFunction.subfunctions``. See Also -------- :class:`~.ensemble_functionspace.EnsembleFunctionSpace` :class:`~.ensemble_function.EnsembleFunction` :class:`~.ensemble_functionspace.EnsembleDualSpace` :class:`~.ensemble_function.EnsembleCofunction` """ def __new__(cls, function_space: EnsembleFunctionSpaceBase): if isinstance(function_space, EnsembleDualSpace): return EnsembleCofunction(function_space) return super().__new__(cls) def __init__(self, function_space: EnsembleFunctionSpace): if not isinstance(function_space, EnsembleFunctionSpace): raise TypeError( "EnsembleFunction must be created using an EnsembleFunctionSpace") super().__init__(function_space)
[docs] def norm(self, *args, **kwargs): """Compute the norm of the function. Any arguments are forwarded to :func:`~firedrake.norms.norm`. """ return self._fs.ensemble_comm.allreduce( sum(norm(u, *args, **kwargs) for u in self.subfunctions))
[docs] class EnsembleCofunction(EnsembleFunctionBase): """ A mixed finite element Cofunction distributed over an ensemble. Parameters ---------- function_space : :class:`~firedrake.ensemble.ensemble_functionspace.EnsembleDualSpace` The function space of the cofunction. """ """ A mixed Cofunction defined on a :class:`~firedrake.ensemble.ensemble.Ensemble`. The subcomponents are distributed over the ensemble members, and are specified locally in a :class:`~firedrake.ensemble.ensemble_functionspace.EnsembleDualSpace`. Parameters ---------- function_space : `~firedrake.ensemble.ensemble_functionspace.EnsembleDualSpace`. The dual function space of the Cofunction. Notes ----- This class does not carry UFL symbolic information, unlike a :class:`~firedrake.cofunction.Cofunction`. UFL expressions can only be defined locally on each ensemble member using a `~firedrake.cofunction.Cofunction` from :meth:`EnsembleCofunction.subfunctions`. See Also -------- :class:`~.ensemble_functionspace.EnsembleFunctionSpace` :class:`~.ensemble_function.EnsembleFunction` :class:`~.ensemble_functionspace.EnsembleDualSpace` :class:`~.ensemble_function.EnsembleCofunction` """ def __init__(self, function_space: EnsembleDualSpace): if not isinstance(function_space, EnsembleDualSpace): raise TypeError( "EnsembleCofunction must be created using an EnsembleDualSpace") super().__init__(function_space)