# Source code for firedrake.ensemble

import weakref

from firedrake.petsc import PETSc
from pyop2.mpi import MPI, internal_comm
from itertools import zip_longest

# Names exported by ``from firedrake.ensemble import *``.
__all__ = ("Ensemble",)


class Ensemble(object):
    def __init__(self, comm, M, **kwargs):
        """
        Create a set of space and ensemble subcommunicators.

        :arg comm: The communicator to split.
        :arg M: the size of the communicators used for spatial parallelism.
            Must be a positive integer.
        :kwarg ensemble_name: string used as communicator name prefix, for debugging.
        :raises ValueError: if ``M`` is not positive or does not divide
            ``comm.size`` exactly.
        """
        size = comm.size

        # A non-positive M must be rejected explicitly: M == 0 would raise
        # ZeroDivisionError below, and a negative M (e.g. M=-2 with size=4)
        # would pass the divisibility test and then yield negative Split colours.
        if M <= 0:
            raise ValueError("Invalid size of subcommunicators %d, must be positive" % M)
        if size % M != 0:
            raise ValueError("Invalid size of subcommunicators %d does not divide %d" % (M, size))

        rank = comm.rank

        # Global comm
        self.global_comm = comm
        # Internal global comm
        self._comm = internal_comm(comm, self)

        ensemble_name = kwargs.get("ensemble_name", "Ensemble")

        # User and internal communicator for spatial parallelism, contains a
        # contiguous chunk of M processes from `global_comm`.
        self.comm = self.global_comm.Split(color=(rank // M), key=rank)
        self.comm.name = f"{ensemble_name} spatial comm"
        # Free the split communicator when this Ensemble is garbage collected.
        weakref.finalize(self, self.comm.Free)
        self._spatial_comm = internal_comm(self.comm, self)

        # User and internal communicator for ensemble parallelism, contains all
        # processes in `global_comm` which have the same rank in `comm`.
        self.ensemble_comm = self.global_comm.Split(color=(rank % M), key=rank)
        self.ensemble_comm.name = f"{ensemble_name} ensemble comm"
        weakref.finalize(self, self.ensemble_comm.Free)
        self._ensemble_comm = internal_comm(self.ensemble_comm, self)

        # Internal sanity checks on the split (not user input validation).
        assert self.comm.size == M
        assert self.ensemble_comm.size == (size // M)

    def _check_function(self, f, g=None):
        """
        Check if function f (and possibly a second function g) is a
        valid argument for ensemble mpi routines

        :arg f: The function to check
        :arg g: Second function to check
        :raises ValueError: if function communicators mismatch each other or the ensemble
            spatial communicator, or if the functions are in different spaces
        """
        matching = {MPI.CONGRUENT, MPI.IDENT}
        if MPI.Comm.Compare(f._comm, self._spatial_comm) not in matching:
            raise ValueError("Function communicator does not match space communicator")

        if g is not None:
            if MPI.Comm.Compare(f._comm, g._comm) not in matching:
                raise ValueError("Mismatching communicators for functions")
            if f.function_space() != g.function_space():
                raise ValueError("Mismatching function spaces for functions")

    @PETSc.Log.EventDecorator()
    def allreduce(self, f, f_reduced, op=MPI.SUM):
        """
        Allreduce a function f into f_reduced over ``ensemble_comm`` .

        :arg f: The :class:`.Function` to allreduce.
        :arg f_reduced: the result of the reduction.
        :arg op: MPI reduction operator. Defaults to MPI.SUM.
        :returns: ``f_reduced``.
        :raises ValueError: if function communicators mismatch each other or the ensemble
            spatial communicator, or if the functions are in different spaces
        """
        self._check_function(f, f_reduced)

        with f_reduced.dat.vec_wo as vout, f.dat.vec_ro as vin:
            self._ensemble_comm.Allreduce(vin.array_r, vout.array, op=op)
        return f_reduced

    @PETSc.Log.EventDecorator()
    def iallreduce(self, f, f_reduced, op=MPI.SUM):
        """
        Allreduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` .

        The returned requests must be completed (e.g. with ``MPI.Request.Waitall``)
        before ``f_reduced`` is read or ``f`` is modified.

        :arg f: The :class:`.Function` to allreduce.
        :arg f_reduced: the result of the reduction.
        :arg op: MPI reduction operator. Defaults to MPI.SUM.
        :returns: list of MPI.Request objects (one for each of f.subfunctions).
        :raises ValueError: if function communicators mismatch each other or the ensemble
            spatial communicator, or if the functions are in different spaces
        """
        self._check_function(f, f_reduced)

        # f is only read, so request read-only access (as ireduce/isend do).
        return [self._ensemble_comm.Iallreduce(fdat.data_ro, rdat.data, op=op)
                for fdat, rdat in zip(f.dat, f_reduced.dat)]

    @PETSc.Log.EventDecorator()
    def reduce(self, f, f_reduced, op=MPI.SUM, root=0):
        """
        Reduce a function f into f_reduced over ``ensemble_comm`` to rank root

        :arg f: The :class:`.Function` to reduce.
        :arg f_reduced: the result of the reduction on rank root.
        :arg op: MPI reduction operator. Defaults to MPI.SUM.
        :arg root: rank to reduce to. Defaults to 0.
        :returns: ``f_reduced`` (only meaningful on rank root).
        :raises ValueError: if function communicators mismatch each other or the ensemble
            spatial communicator, or if the functions are in different spaces
        """
        self._check_function(f, f_reduced)

        if self.ensemble_comm.rank == root:
            with f_reduced.dat.vec_wo as vout, f.dat.vec_ro as vin:
                self._ensemble_comm.Reduce(vin.array_r, vout.array, op=op, root=root)
        else:
            # Non-root ranks contribute data but receive nothing.
            with f.dat.vec_ro as vin:
                self._ensemble_comm.Reduce(vin.array_r, None, op=op, root=root)

        return f_reduced

    @PETSc.Log.EventDecorator()
    def ireduce(self, f, f_reduced, op=MPI.SUM, root=0):
        """
        Reduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` to rank root

        :arg f: The :class:`.Function` to reduce.
        :arg f_reduced: the result of the reduction on rank root.
        :arg op: MPI reduction operator. Defaults to MPI.SUM.
        :arg root: rank to reduce to. Defaults to 0.
        :returns: list of MPI.Request objects (one for each of f.subfunctions).
        :raises ValueError: if function communicators mismatch each other or the ensemble
            spatial communicator, or if the functions are in different spaces
        """
        self._check_function(f, f_reduced)

        return [self._ensemble_comm.Ireduce(fdat.data_ro, rdat.data, op=op, root=root)
                for fdat, rdat in zip(f.dat, f_reduced.dat)]

    @PETSc.Log.EventDecorator()
    def bcast(self, f, root=0):
        """
        Broadcast a function f over ``ensemble_comm`` from rank root

        :arg f: The :class:`.Function` to broadcast.
        :arg root: rank to broadcast from. Defaults to 0.
        :returns: ``f``.
        :raises ValueError: if function communicator mismatches the ensemble
            spatial communicator.
        """
        self._check_function(f)

        # Read-write access: f is the send buffer on root and the receive
        # buffer everywhere else.
        with f.dat.vec as vec:
            self._ensemble_comm.Bcast(vec.array, root=root)
        return f

    @PETSc.Log.EventDecorator()
    def ibcast(self, f, root=0):
        """
        Broadcast (non-blocking) a function f over ``ensemble_comm`` from rank root

        :arg f: The :class:`.Function` to broadcast.
        :arg root: rank to broadcast from. Defaults to 0.
        :returns: list of MPI.Request objects (one for each of f.subfunctions).
        :raises ValueError: if function communicator mismatches the ensemble
            spatial communicator.
        """
        self._check_function(f)

        return [self._ensemble_comm.Ibcast(dat.data, root=root)
                for dat in f.dat]

    @PETSc.Log.EventDecorator()
    def send(self, f, dest, tag=0):
        """
        Send (blocking) a function f over ``ensemble_comm`` to another
        ensemble rank.

        :arg f: The :class:`.Function` to send
        :arg dest: the rank to send to
        :arg tag: the tag of the message. Defaults to 0
        :raises ValueError: if function communicator mismatches the ensemble
            spatial communicator.
        """
        self._check_function(f)
        for dat in f.dat:
            self._ensemble_comm.Send(dat.data_ro, dest=dest, tag=tag)

    @PETSc.Log.EventDecorator()
    def recv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, statuses=None):
        """
        Receive (blocking) a function f over ``ensemble_comm`` from
        another ensemble rank.

        :arg f: The :class:`.Function` to receive into
        :arg source: the rank to receive from. Defaults to MPI.ANY_SOURCE.
        :arg tag: the tag of the message. Defaults to MPI.ANY_TAG.
        :arg statuses: MPI.Status objects (one for each of f.subfunctions or None).
        :raises ValueError: if function communicator mismatches the ensemble
            spatial communicator, or if the number of ``statuses`` does not
            match the number of parts of ``f``.
        """
        self._check_function(f)
        if statuses is not None and len(statuses) != len(f.dat):
            raise ValueError("Need to provide enough status objects for all parts of the Function")
        # With statuses=None every part is received with status=None.
        for dat, status in zip_longest(f.dat, statuses or (), fillvalue=None):
            self._ensemble_comm.Recv(dat.data, source=source, tag=tag, status=status)

    @PETSc.Log.EventDecorator()
    def isend(self, f, dest, tag=0):
        """
        Send (non-blocking) a function f over ``ensemble_comm`` to another
        ensemble rank.

        :arg f: The :class:`.Function` to send
        :arg dest: the rank to send to
        :arg tag: the tag of the message. Defaults to 0.
        :returns: list of MPI.Request objects (one for each of f.subfunctions).
        :raises ValueError: if function communicator mismatches the ensemble
            spatial communicator.
        """
        self._check_function(f)
        return [self._ensemble_comm.Isend(dat.data_ro, dest=dest, tag=tag)
                for dat in f.dat]

    @PETSc.Log.EventDecorator()
    def irecv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG):
        """
        Receive (non-blocking) a function f over ``ensemble_comm`` from
        another ensemble rank.

        :arg f: The :class:`.Function` to receive into
        :arg source: the rank to receive from. Defaults to MPI.ANY_SOURCE.
        :arg tag: the tag of the message. Defaults to MPI.ANY_TAG.
        :returns: list of MPI.Request objects (one for each of f.subfunctions).
        :raises ValueError: if function communicator mismatches the ensemble
            spatial communicator.
        """
        self._check_function(f)
        return [self._ensemble_comm.Irecv(dat.data, source=source, tag=tag)
                for dat in f.dat]

    @PETSc.Log.EventDecorator()
    def sendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, recvtag=MPI.ANY_TAG, status=None):
        """
        Send (blocking) a function fsend and receive a function frecv over ``ensemble_comm`` to another
        ensemble rank.

        :arg fsend: The :class:`.Function` to send.
        :arg dest: the rank to send to.
        :arg sendtag: the tag of the send message. Defaults to 0.
        :arg frecv: The :class:`.Function` to receive into. Required despite
            the ``None`` default (kept for backwards compatibility of the signature).
        :arg source: the rank to receive from. Defaults to MPI.ANY_SOURCE.
        :arg recvtag: the tag of the received message. Defaults to MPI.ANY_TAG.
        :arg status: MPI.Status object or None.
        :raises ValueError: if ``frecv`` is not given, or if function communicator
            mismatches the ensemble spatial communicator.
        """
        if frecv is None:
            raise ValueError("A Function to receive into (frecv) must be provided")
        # functions don't necessarily have to match
        self._check_function(fsend)
        self._check_function(frecv)
        # Pass the local arrays (as allreduce/reduce do) rather than the Vec objects.
        with fsend.dat.vec_ro as sendvec, frecv.dat.vec_wo as recvvec:
            self._ensemble_comm.Sendrecv(sendvec.array_r, dest, sendtag=sendtag,
                                         recvbuf=recvvec.array, source=source, recvtag=recvtag,
                                         status=status)

    @PETSc.Log.EventDecorator()
    def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, recvtag=MPI.ANY_TAG):
        """
        Send a function fsend and receive a function frecv over ``ensemble_comm`` to another
        ensemble rank.

        :arg fsend: The :class:`.Function` to send.
        :arg dest: the rank to send to.
        :arg sendtag: the tag of the send message. Defaults to 0.
        :arg frecv: The :class:`.Function` to receive into. Required despite
            the ``None`` default (kept for backwards compatibility of the signature).
        :arg source: the rank to receive from. Defaults to MPI.ANY_SOURCE.
        :arg recvtag: the tag of the received message. Defaults to MPI.ANY_TAG.
        :returns: list of MPI.Request objects (one for each of fsend.subfunctions and frecv.subfunctions).
        :raises ValueError: if ``frecv`` is not given, or if function communicator
            mismatches the ensemble spatial communicator.
        """
        if frecv is None:
            raise ValueError("A Function to receive into (frecv) must be provided")
        # functions don't necessarily have to match
        self._check_function(fsend)
        self._check_function(frecv)
        # Send requests come first, followed by receive requests.
        requests = [self._ensemble_comm.Isend(dat.data_ro, dest=dest, tag=sendtag)
                    for dat in fsend.dat]
        requests.extend(self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag)
                        for dat in frecv.dat)
        return requests