Source code for firedrake.adjoint_utils.ensemble_function
from pyadjoint.overloaded_type import OverloadedType
from firedrake.petsc import PETSc
from .checkpointing import disk_checkpointing
from functools import wraps
[docs]
class EnsembleFunctionMixin(OverloadedType):
    """
    Basic functionality for EnsembleFunction to be OverloadedTypes.
    Note that currently no EnsembleFunction operations are taped.
    Enables EnsembleFunction to do the following:
    - Be a Control for a NumpyReducedFunctional (_ad_to_list and _ad_assign_numpy)
    - Be used with pyadjoint TAO solver (_ad_{to,from}_petsc)
    - Be used as a Control for Taylor tests (_ad_dot)
    """

    @staticmethod
    def _ad_annotate_init(init):
        """Decorate an ``__init__`` so that it also sets up the OverloadedType.

        The returned wrapper initialises :class:`OverloadedType`, runs the
        wrapped ``init``, and then binds the ``_ad_add``, ``_ad_mul``,
        ``_ad_iadd``, ``_ad_imul`` and ``_ad_copy`` hooks to the instance's
        own ``__add__``, ``__mul__``, ``__iadd__``, ``__imul__`` and ``copy``.
        """
        @wraps(init)
        def wrapper(self, *args, **kwargs):
            OverloadedType.__init__(self)
            init(self, *args, **kwargs)
            self._ad_add = self.__add__
            self._ad_mul = self.__mul__
            self._ad_iadd = self.__iadd__
            self._ad_imul = self.__imul__
            self._ad_copy = self.copy
        return wrapper

    @staticmethod
    def _ad_to_list(m):
        """Return every entry of ``m``'s global Vec as a Python list.

        The global Vec is scattered to a sequential Vec holding all of the
        entries, so each rank returns the complete list.
        """
        with m.vec_ro() as gvec:
            # Scatter.toAll is a classmethod returning both the scatter
            # context and a correctly sized sequential output Vec, so the
            # pair must be unpacked before scattering.
            scatter, lvec = PETSc.Scatter.toAll(gvec)
            scatter.scatter(
                gvec, lvec, addv=PETSc.InsertMode.INSERT_VALUES)
        return lvec.array_r.tolist()

    @staticmethod
    def _ad_assign_numpy(dst, src, offset):
        """Fill ``dst`` from the array ``src``, starting at ``offset``.

        Only the slice of ``src`` matching this rank's ownership range of the
        Vec is copied into the local array. Returns ``dst`` together with
        ``offset`` advanced by the global size of the Vec.
        """
        with dst.vec_wo() as vec:
            begin, end = vec.owner_range
            vec.array[:] = src[offset + begin: offset + end]
            offset += vec.size
        return dst, offset

    def _ad_dot(self, other, options=None):
        """Return the dot product of ``self`` and ``other``.

        The ``_ad_dot`` of each pair of local subfunctions is summed, and the
        result is combined with ``allreduce`` over the ensemble communicator.
        """
        local_dot = sum(uself._ad_dot(uother, options=options)
                        for uself, uother in zip(self.subfunctions,
                                                 other.subfunctions))
        return self.ensemble.ensemble_comm.allreduce(local_dot)

    def _ad_convert_riesz(self, value, options=None):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def _ad_init_zero(self, dual=False):
        """Return a newly constructed object on this function space.

        If ``dual`` is true an ``EnsembleCofunction`` on the dual space is
        returned, otherwise an ``EnsembleFunction`` on the same space.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from firedrake import EnsembleFunction, EnsembleCofunction
        if dual:
            return EnsembleCofunction(self.function_space().dual())
        else:
            return EnsembleFunction(self.function_space())

    def _ad_create_checkpoint(self):
        """Return a copy of ``self`` to be used as a checkpoint.

        Raises ``NotImplementedError`` if disk checkpointing is active.
        """
        if disk_checkpointing():
            raise NotImplementedError(
                "Disk checkpointing not implemented for EnsembleFunctions")
        else:
            return self.copy()

    def _ad_restore_at_checkpoint(self, checkpoint):
        """Return ``checkpoint`` if it has exactly the same type as ``self``.

        Any other checkpoint type raises ``NotImplementedError``.
        """
        if type(checkpoint) is type(self):
            return checkpoint
        raise NotImplementedError(
            "Disk checkpointing not implemented for EnsembleFunctions")

    def _ad_from_petsc(self, vec):
        """Copy the entries of the PETSc Vec ``vec`` into ``self``."""
        # vec_wo is called to obtain the context manager, matching its use
        # in _ad_assign_numpy.
        with self.vec_wo() as self_v:
            vec.copy(self_v)

    def _ad_to_petsc(self, vec=None):
        """Copy the entries of ``self`` into a PETSc Vec and return it.

        If ``vec`` is not given (or is falsy), a duplicate of ``self._vec``
        is used as the destination.
        """
        # vec_ro is called to obtain the context manager, matching its use
        # in _ad_to_list.
        with self.vec_ro() as self_v:
            return self_v.copy(vec or self._vec.duplicate())