Source code for firedrake.ensemble.ensemble

from functools import wraps
import weakref
from contextlib import contextmanager
from itertools import zip_longest
from types import SimpleNamespace

from firedrake.petsc import PETSc
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from pyop2.mpi import MPI, internal_comm


def _ensemble_mpi_dispatch(func):
    """
    Dispatch an ensemble method to Firedrake or plain mpi4py.

    If any positional or keyword argument of the wrapped method is a
    :class:`.Function` or :class:`.Cofunction`, the specialised Firedrake
    implementation is invoked. Otherwise the call is forwarded to the
    method of the same name on the internal ensemble communicator, i.e.
    the standard mpi4py implementation.
    """
    @wraps(func)
    def _mpi_dispatch(self, *args, **kwargs):
        has_firedrake_args = any(
            isinstance(value, (Function, Cofunction))
            for value in (*args, *kwargs.values())
        )
        if has_firedrake_args:
            return func(self, *args, **kwargs)
        # Fall back to the mpi4py method of the same name.
        return getattr(self._ensemble_comm, func.__name__)(*args, **kwargs)
    return _mpi_dispatch


class Ensemble:
    def __init__(self, comm: MPI.Comm, M: int, **kwargs):
        """
        Create a set of space and ensemble subcommunicators.

        Wrapper methods around many MPI communication functions are provided
        for sending :class:`.Function` and :class:`.Cofunction` objects between
        spatial communicators. For non-Firedrake objects these wrappers will
        dispatch to the normal implementations on :class:`mpi4py.MPI.Comm`,
        which means that the same call site can be used for both Firedrake
        and non-Firedrake types.

        Parameters
        ----------
        comm :
            The communicator to split.
        M :
            The size of the communicators used for spatial parallelism.
            Must be an integer divisor of the size of ``comm``.
        kwargs :
            Can include an ``ensemble_name`` string used as a communicator
            name prefix, for debugging.

        Raises
        ------
        ValueError
            If ``M`` does not divide ``comm.size`` exactly.
        """
        size = comm.size

        if (size // M)*M != size:
            raise ValueError("Invalid size of subcommunicators %d does not divide %d" % (M, size))

        rank = comm.rank

        # Global comm
        self.global_comm = comm
        ensemble_name = kwargs.get("ensemble_name", "Ensemble")

        # User and internal communicator for spatial parallelism, contains a
        # contiguous chunk of M processes from `global_comm`.
        self.comm = self.global_comm.Split(color=(rank // M), key=rank)
        self.comm.name = f"{ensemble_name} spatial comm"
        # Free the split communicator once this Ensemble is garbage collected.
        weakref.finalize(self, self.comm.Free)

        # User and internal communicator for ensemble parallelism, contains all
        # processes in `global_comm` which have the same rank in `comm`.
        self.ensemble_comm = self.global_comm.Split(color=(rank % M), key=rank)
        self.ensemble_comm.name = f"{ensemble_name} ensemble comm"
        weakref.finalize(self, self.ensemble_comm.Free)

        # Keep a reference to the internal communicator because some methods
        # return non-blocking requests and we need to avoid cleaning up the
        # communicator before they complete. Note that this communicator should
        # *never* be passed to PETSc, as objects created with the communicator
        # will never get cleaned up.
        self._ensemble_comm = internal_comm(self.ensemble_comm, self)

        # Belt-and-braces: verify the splits produced the expected sizes.
        if (self.comm.size != M) or (self.ensemble_comm.size != (size // M)):
            raise ValueError(f"{M=} does not exactly divide {comm.size=}")

    @property
    def ensemble_size(self) -> int:
        """The number of ensemble members. """
        return self.ensemble_comm.size

    @property
    def ensemble_rank(self) -> int:
        """The rank of the local ensemble member. """
        return self.ensemble_comm.rank

    def _check_function(self, f: Function | Cofunction,
                        g: Function | Cofunction | None = None):
        """
        Check if :class:`.Function` ``f`` (and possibly a second
        :class:`.Function` ``g``) is a valid argument for ensemble MPI routines.

        Parameters
        ----------
        f :
            The :class:`.Function` to check.
        g :
            Second :class:`.Function` to check.

        Raises
        ------
        ValueError
            If Function communicators mismatch each other or the ensemble
            spatial communicator, or if the functions are in different spaces.
        """
        if MPI.Comm.Compare(f.comm, self.comm) not in {MPI.CONGRUENT, MPI.IDENT}:
            raise ValueError("Function communicator does not match space communicator")
        if g is not None:
            if MPI.Comm.Compare(f.comm, g.comm) not in {MPI.CONGRUENT, MPI.IDENT}:
                raise ValueError("Mismatching communicators for functions")
            if f.function_space() != g.function_space():
                raise ValueError("Mismatching function spaces for functions")
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def allreduce(self, f: Function | Cofunction, f_reduced: Function | Cofunction | None = None, op: MPI.Op = MPI.SUM ) -> Function | Cofunction: """ Allreduce a :class:`.Function` ``f`` into ``f_reduced``. Parameters ---------- f : The :class:`.Function` to allreduce. f_reduced : The result of the reduction. Must be in the same :func:`~firedrake.functionspace.FunctionSpace` as ``f``. op : MPI reduction operator. Defaults to MPI.SUM. Returns ------- Function | Cofunction : The result of the reduction. Raises ------ ValueError If Function communicators mismatch each other or the ensemble spatial communicator, or if the Functions are in different spaces """ f_reduced = f_reduced or Function(f.function_space()) self._check_function(f, f_reduced) with f_reduced.dat.vec_wo as vout, f.dat.vec_ro as vin: self._ensemble_comm.Allreduce(vin.array_r, vout.array, op=op) return f_reduced
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def iallreduce(self, f: Function | Cofunction, f_reduced: Function | Cofunction | None = None, op: MPI.Op = MPI.SUM ) -> list[MPI.Request]: """ Allreduce (non-blocking) a :class:`.Function` ``f`` into ``f_reduced``. Parameters ---------- f : The a :class:`.Function` to allreduce. f_reduced : The result of the reduction. Must be in the same :func:`~firedrake.functionspace.FunctionSpace` as ``f``. op : MPI reduction operator. Defaults to MPI.SUM. Returns ------- list[mpi4py.MPI.Request] : Requests one for each of ``f.subfunctions``. Raises ------ ValueError If Function communicators mismatch each other or the ensemble spatial communicator, or if the Functions are in different spaces """ f_reduced = f_reduced or Function(f.function_space()) self._check_function(f, f_reduced) return [self._ensemble_comm.Iallreduce(fdat.data, rdat.data, op=op) for fdat, rdat in zip(f.dat, f_reduced.dat)]
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def reduce(self, f: Function | Cofunction, f_reduced: Function | Cofunction | None = None, op: MPI.Op = MPI.SUM, root: int = 0 ) -> Function | Cofunction: """ Reduce a :class:`.Function` ``f`` into ``f_reduced``. Parameters ---------- f : The :class:`.Function` to reduce. f_reduced : The result of the reduction. Must be in the same :func:`~firedrake.functionspace.FunctionSpace` as ``f``. op : MPI reduction operator. Defaults to MPI.SUM. root : The ensemble rank to reduce to. Returns ------- Function | Cofunction : The result of the reduction. Raises ------ ValueError If Function communicators mismatch each other or the ensemble spatial communicator, or if the Functions are in different spaces """ f_reduced = f_reduced or Function(f.function_space()) self._check_function(f, f_reduced) if self.ensemble_comm.rank == root: with f_reduced.dat.vec_wo as vout, f.dat.vec_ro as vin: self._ensemble_comm.Reduce(vin.array_r, vout.array, op=op, root=root) else: with f.dat.vec_ro as vin: self._ensemble_comm.Reduce(vin.array_r, None, op=op, root=root) return f_reduced
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def ireduce(self, f: Function | Cofunction, f_reduced: Function | Cofunction | None = None, op: MPI.Op = MPI.SUM, root: int = 0 ) -> list[MPI.Request]: """ Reduce (non-blocking) a :class:`.Function` ``f`` into ``f_reduced``. Parameters ---------- f : The a :class:`.Function` to reduce. f_reduced : The result of the reduction. Must be in the same :func:`~firedrake.functionspace.FunctionSpace` as ``f``. op : MPI reduction operator. Defaults to MPI.SUM. root : The ensemble rank to reduce to. Returns ------- list[mpi4py.MPI.Request] Requests one for each of ``f.subfunctions``. Raises ------ ValueError If Function communicators mismatch each other or the ensemble spatial communicator, or if the Functions are in different spaces """ f_reduced = f_reduced or Function(f.function_space()) self._check_function(f, f_reduced) return [self._ensemble_comm.Ireduce(fdat.data_ro, rdat.data, op=op, root=root) for fdat, rdat in zip(f.dat, f_reduced.dat)]
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def bcast(self, f: Function | Cofunction, root: int = 0 ) -> Function | Cofunction: """ Broadcast a :class:`.Function` ``f`` over ``ensemble_comm`` from :attr:`~.Ensemble.ensemble_rank` ``root``. Parameters ---------- f : The :class:`.Function` to broadcast. root : The rank to broadcast from. Returns ------- Function | Cofunction : The result of the broadcast. Raises ------ ValueError If the Function communicator mismatches the ``ensemble.comm``. """ self._check_function(f) with f.dat.vec as vec: self._ensemble_comm.Bcast(vec.array, root=root) return f
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def ibcast(self, f: Function | Cofunction, root: int = 0 ) -> list[MPI.Request]: """ Broadcast (non-blocking) a :class:`.Function` ``f`` over ``ensemble_comm`` :attr:`~.Ensemble.ensemble_rank` ``root``. Parameters ---------- f : The :class:`.Function` to broadcast. root : The rank to broadcast from. Returns ------- list[mpi4py.MPI.Request] Requests one for each of ``f.subfunctions``. Raises ------ ValueError If the Function communicator mismatches the ``ensemble.comm``. """ self._check_function(f) return [self._ensemble_comm.Ibcast(dat.data, root=root) for dat in f.dat]
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def send(self, f: Function | Cofunction, dest: int, tag: int = 0): """ Send (blocking) a :class:`.Function` ``f`` over ``ensemble_comm`` to another :attr:`~.Ensemble.ensemble_rank`. Parameters ---------- f : The a :class:`.Function` to send. dest : The :attr:`~.Ensemble.ensemble_rank` to send ``f`` to. tag : The tag of the message. Raises ------ ValueError If the Function communicator mismatches the ``ensemble.comm``. """ self._check_function(f) for dat in f.dat: self._ensemble_comm.Send(dat.data_ro, dest=dest, tag=tag)
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def recv(self, f: Function | Cofunction, source: int = MPI.ANY_SOURCE, tag: int = MPI.ANY_TAG, statuses: list[MPI.Status] | MPI.Status = None, ) -> Function | Cofunction: """ Receive (blocking) a :class:`.Function` ``f`` over ``ensemble_comm`` from another :attr:`~.Ensemble.ensemble_rank`. Parameters ---------- f : The :class:`.Function` to receive into. source : The :attr:`~.Ensemble.ensemble_rank` to receive ``f`` from. tag : The tag of the message. statuses : The :class:`mpi4py.MPI.Status` of the internal recv calls (one for each of the ``subfunctions`` of ``f``). Returns ------- Function | Cofunction : ``f`` with the received data. Raises ------ ValueError If the Function communicator mismatches the ``ensemble.comm``. ValueError If the number of ``statuses`` provided is not the number of subfunctions of ``f``. """ self._check_function(f) if statuses is not None and isinstance(statuses, MPI.Status): statuses = [statuses] if statuses is not None and len(statuses) != len(f.dat): raise ValueError("Need to provide enough status objects for all parts of the Function") for dat, status in zip_longest(f.dat, statuses or (), fillvalue=None): self._ensemble_comm.Recv(dat.data, source=source, tag=tag, status=status) return f
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def isend(self, f: Function | Cofunction, dest: int, tag: int = 0 ) -> list[MPI.Request]: """ Send (non-blocking) a :class:`.Function` ``f`` over ``ensemble_comm`` to another :attr:`~.Ensemble.ensemble_rank`. Parameters ---------- f : The a :class:`.Function` to send. dest : The :attr:`~.Ensemble.ensemble_rank` to send ``f`` to. tag : The tag of the message. Returns ------- list[mpi4py.MPI.Request] Requests one for each of ``f.subfunctions``. Raises ------ ValueError If the Function communicator mismatches the ``ensemble.comm``. """ self._check_function(f) return [self._ensemble_comm.Isend(dat.data_ro, dest=dest, tag=tag) for dat in f.dat]
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def irecv(self, f: Function | Cofunction, source: int = MPI.ANY_SOURCE, tag: int = MPI.ANY_TAG ) -> list[MPI.Request]: """ Receive (non-blocking) a :class:`.Function` ``f`` over ``ensemble_comm`` from another :attr:`~.Ensemble.ensemble_rank`. Parameters ---------- f : The :class:`.Function` to receive into. source : The :attr:`~.Ensemble.ensemble_rank` to receive ``f`` from. tag : The tag of the message. Returns ------- list[mpi4py.MPI.Request] Requests one for each of ``f.subfunctions``. Raises ------ ValueError If the Function communicator mismatches the ``ensemble.comm``. """ self._check_function(f) return [self._ensemble_comm.Irecv(dat.data, source=source, tag=tag) for dat in f.dat]
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def sendrecv(self, fsend: Function | Cofunction, dest: int, sendtag: int = 0, frecv: Function | Cofunction | None = None, source: int = MPI.ANY_SOURCE, recvtag: int = MPI.ANY_TAG, statuses: list[MPI.Status] | MPI.Status = None ) -> Function | Cofunction: """ Send (blocking) a :class:`.Function` ``fsend`` and receive a :class:`.Function` ``frecv`` over ``ensemble_comm`` to/from other :attr:`~.Ensemble.ensemble_rank`. ``fsend`` and ``frecv`` do not need to be in the same function space but do need to have the same number of subfunctions. Parameters ---------- fsend : The a :class:`.Function` to send. dest : The :attr:`~.Ensemble.ensemble_rank` to send ``fsend`` to. sendtag : The tag of the send message. frecv : The :class:`.Function` to receive into. source : The :attr:`~.Ensemble.ensemble_rank` to receive ``frecv`` from. recvtag : The tag of the receive message. statuses : The :class:`mpi4py.MPI.Status` of the internal recv calls (one for each of the ``subfunctions`` of ``frecv``). Returns ------- Function | Cofunction ``frecv`` with the received data. Raises ------ ValueError If the Function communicators mismatches each other or the ``ensemble.comm``. ValueError If the number of ``statuses`` provided is not the number of subfunctions of ``f``. """ frecv = frecv or Function(fsend.function_space()) # functions don't necessarily have to match self._check_function(fsend) self._check_function(frecv) if statuses is not None and isinstance(statuses, MPI.Status): statuses = [statuses] if statuses is not None and len(statuses) != len(frecv.dat): raise ValueError("Need to provide enough status objects for all parts of the Function") with fsend.dat.vec_ro as sendvec, frecv.dat.vec_wo as recvvec: self._ensemble_comm.Sendrecv(sendvec, dest, sendtag=sendtag, recvbuf=recvvec, source=source, recvtag=recvtag, status=statuses) return frecv
[docs] @PETSc.Log.EventDecorator() @_ensemble_mpi_dispatch def isendrecv(self, fsend: Function | Cofunction, dest: int, sendtag: int = 0, frecv: Function | Cofunction | None = None, source: int = MPI.ANY_SOURCE, recvtag: int = MPI.ANY_TAG ) -> list[MPI.Request]: """ Send (non-blocking) a :class:`.Function` ``fsend`` and receive a :class:`.Function` ``frecv`` over ``ensemble_comm`` to/from other :attr:`~.Ensemble.ensemble_rank`. ``fsend`` and ``frecv`` do not need to be in the same function space. Parameters ---------- fsend : The a :class:`.Function` to send. dest : The :attr:`~.Ensemble.ensemble_rank` to send ``fsend`` to. sendtag : The tag of the send message. frecv : The :class:`.Function` to receive into. source : The :attr:`~.Ensemble.ensemble_rank` to receive ``frecv`` from. recvtag : The tag of the receive message. Returns ------- list[mpi4py.MPI.Request] Requests one for each of ``f.subfunctions``. Raises ------ ValueError If the Function communicators mismatches each other or the ``ensemble.comm``. """ frecv = frecv or Function(fsend.function_space()) # functions don't necessarily have to match self._check_function(fsend) self._check_function(frecv) requests = [] requests.extend([self._ensemble_comm.Isend(dat.data_ro, dest=dest, tag=sendtag) for dat in fsend.dat]) requests.extend([self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag) for dat in frecv.dat]) return requests
[docs] @contextmanager def sequential(self, *, synchronise: bool = False, reverse: bool = False, **kwargs): """ Context manager for executing code on each ensemble member consecutively (ordered by increasing :attr:`~.Ensemble.ensemble_rank`). Any data in ``kwargs`` will be made available in the returned context and will be communicated forward after each ensemble member exits. :class:`.Function` or :class:`.Cofunction` ``kwargs`` will be sent with the corresponding Ensemble methods. For example: .. code-block:: python3 with ensemble.sequential(index=0) as ctx: print(ensemble.ensemble_rank, ctx.index) ctx.index += 2 Would print: .. code-block:: 0 0 1 2 2 4 3 6 ... If ``reverse is True`` then the ensemble ranks will be looped through in decreasing order i.e. ``ensemble_rank == (ensemble_size - 1)`` will run first, then ``ensemble_rank == (ensemble_size - 2)`` etc. Parameters ---------- synchronise : If True then MPI_Barrier will be called on the ``global_comm`` at the beginning and end of this method. reverse : If True then will iterate through spatial comms in order of decreasing ``ensemble_rank``. kwargs : Data to be passed forward by each rank and made available in the returned ``ctx``. 
""" rank = self.ensemble_rank if reverse: # send backwards src = rank + 1 dst = rank - 1 first_rank = (rank == self.ensemble_size - 1) last_rank = (rank == 0) else: # send forwards src = rank - 1 dst = rank + 1 first_rank = (rank == 0) last_rank = (rank == self.ensemble_size - 1) if synchronise: self.global_comm.Barrier() if not first_rank: for i, (k, v) in enumerate(kwargs.items()): if isinstance(v, (Function, Cofunction)): # Functions are sent in-place, everything else is pickled recv_args = [kwargs[k]] else: recv_args = [] kwargs[k] = self.recv(*recv_args, source=src, tag=rank+i*100) ctx = SimpleNamespace(**kwargs) yield ctx if not last_rank: for i, v in enumerate((getattr(ctx, k) for k in kwargs.keys())): try: self.send(v, dest=dst, tag=dst+i*100) except Exception as error: raise TypeError( "Failed to send object of type {type(v)__name__}. kwargs for" " Ensemble.sequential must be Functions, Cofunctions," " or acceptable arguments to mpi4py.MPI.Comm.send." ) from error if synchronise: self.global_comm.Barrier()