Source code for firedrake.adjoint_utils.ensemble_function
from pyadjoint.overloaded_type import OverloadedType
from firedrake.petsc import PETSc
from .checkpointing import disk_checkpointing
from functools import wraps
[docs]
class EnsembleFunctionMixin(OverloadedType):
    """
    Basic functionality for EnsembleFunction to be OverloadedTypes.
    Note that currently no EnsembleFunction operations are taped.

    Enables EnsembleFunction to do the following:
    - Be a Control for a NumpyReducedFunctional (_ad_to_list and _ad_assign_numpy)
    - Be used with pyadjoint TAO solver (_ad_{to,from}_petsc)
    - Be used as a Control for Taylor tests (_ad_dot)

    The methods below rely on the concrete class providing ``vec_ro()`` and
    ``vec_wo()`` context managers that yield its (distributed) PETSc Vec, as
    well as ``subfunctions``, ``ensemble``, ``function_space()`` and
    ``copy()``.
    """

    @staticmethod
    def _ad_annotate_init(init):
        """Decorate an ``__init__`` so the instance is also an OverloadedType.

        ``OverloadedType.__init__`` runs before the wrapped initialiser.
        Afterwards the ``_ad_add``/``_ad_mul``/``_ad_iadd``/``_ad_imul``/
        ``_ad_copy`` hooks are bound to the instance's own ``__add__``,
        ``__mul__``, ``__iadd__``, ``__imul__`` and ``copy``.

        :arg init: the ``__init__`` method to wrap.
        :returns: the wrapped initialiser.
        """
        @wraps(init)
        def wrapper(self, *args, **kwargs):
            OverloadedType.__init__(self)
            init(self, *args, **kwargs)
            self._ad_add = self.__add__
            self._ad_mul = self.__mul__
            self._ad_iadd = self.__iadd__
            self._ad_imul = self.__imul__
            self._ad_copy = self.copy
        return wrapper

    @staticmethod
    def _ad_to_list(m):
        """Return all entries of ``m``'s global Vec as a Python list.

        The values are gathered onto every rank of the Vec's communicator
        with a to-all scatter.

        :arg m: the EnsembleFunction to read.
        :returns: list of all global Vec values.
        """
        with m.vec_ro() as gvec:
            # Scatter.toAll is a classmethod returning a (Scatter, Vec) pair;
            # the Vec is a sequential vector sized for the whole global Vec.
            scatter, lvec = PETSc.Scatter.toAll(gvec)
            scatter.scatter(gvec, lvec,
                            addv=PETSc.InsertMode.INSERT_VALUES)
            # Destroy explicitly (all ranks reach this line) rather than
            # leaving destruction of a parallel object to the garbage collector.
            scatter.destroy()
        values = lvec.array_r.tolist()
        lvec.destroy()
        return values

    @staticmethod
    def _ad_assign_numpy(dst, src, offset):
        """Fill ``dst`` from the flat sequence ``src``, starting at ``offset``.

        Each rank copies only its locally owned slice
        ``src[offset + begin:offset + end]``, where ``(begin, end)`` is the
        Vec's ``owner_range``.

        :arg dst: the EnsembleFunction to overwrite.
        :arg src: flat indexable sequence of values covering the global Vec.
        :arg offset: index in ``src`` where ``dst``'s values start.
        :returns: ``(dst, new_offset)``, where ``new_offset`` is ``offset``
            advanced by the global size of ``dst``'s Vec (presumably so that
            several controls can be unpacked from one array in turn).
        """
        with dst.vec_wo() as vec:
            begin, end = vec.owner_range
            vec.array[:] = src[offset + begin: offset + end]
            offset += vec.size
        return dst, offset

    def _ad_dot(self, other, options=None):
        """Return the inner product of ``self`` and ``other`` over the ensemble.

        The per-subfunction ``_ad_dot`` values (``options`` forwarded) are
        summed locally and then combined with ``allreduce`` on
        ``self.ensemble.ensemble_comm`` (mpi4py's default op is SUM).
        """
        local_dot = sum(
            uself._ad_dot(uother, options=options)
            for uself, uother in zip(self.subfunctions, other.subfunctions))
        return self.ensemble.ensemble_comm.allreduce(local_dot)

    def _ad_convert_riesz(self, value, options=None):
        """Not supported for EnsembleFunctions.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError(
            "Riesz map conversion not implemented for EnsembleFunctions")

    def _ad_init_zero(self, dual=False):
        """Return a freshly constructed object on ``self``'s function space.

        The result is an EnsembleFunction, or an EnsembleCofunction on the
        dual space when ``dual`` is true. It is assumed to be zero-initialised
        by its constructor (TODO confirm).
        """
        # Local import, presumably to avoid a circular import with firedrake.
        from firedrake import EnsembleFunction, EnsembleCofunction
        if dual:
            return EnsembleCofunction(self.function_space().dual())
        return EnsembleFunction(self.function_space())

    def _ad_create_checkpoint(self):
        """Return an in-memory copy of ``self`` to use as a checkpoint.

        :raises NotImplementedError: if disk checkpointing is active.
        """
        if disk_checkpointing():
            raise NotImplementedError(
                "Disk checkpointing not implemented for EnsembleFunctions")
        return self.copy()

    def _ad_restore_at_checkpoint(self, checkpoint):
        """Return ``checkpoint`` if it has exactly the same type as ``self``.

        Any other checkpoint type is treated as a disk checkpoint, which is
        unsupported.

        :raises NotImplementedError: for any other checkpoint type.
        """
        if type(checkpoint) is type(self):
            return checkpoint
        raise NotImplementedError(
            "Disk checkpointing not implemented for EnsembleFunctions")

    def _ad_from_petsc(self, vec):
        """Copy the values of the PETSc Vec ``vec`` into ``self``."""
        # vec_wo is a context-manager method and must be called (as in
        # _ad_assign_numpy); a bare bound method has no __enter__.
        with self.vec_wo() as self_v:
            vec.copy(self_v)

    def _ad_to_petsc(self, vec=None):
        """Return a PETSc Vec holding a copy of ``self``'s values.

        :arg vec: optional Vec to copy into. When None, a duplicate of
            ``self``'s Vec is created and returned.
        """
        with self.vec_ro() as self_v:
            # Explicit None test rather than ``vec or ...``, so the choice
            # does not depend on how PETSc objects define truthiness.
            target = vec if vec is not None else self_v.duplicate()
            return self_v.copy(target)