Source code for firedrake.interpolation

import numpy
import os
import tempfile
import abc
import warnings
from collections.abc import Iterable
from typing import Literal
from functools import partial, singledispatch
from typing import Hashable

import FIAT
import ufl
import finat.ufl
from ufl.algorithms import extract_arguments, extract_coefficients, replace
from ufl.domain import as_domain, extract_unique_domain

from pyop2 import op2
from pyop2.caching import memory_and_disk_cache

from finat.element_factory import create_element, as_fiat_cell
from tsfc import compile_expression_dual_evaluation
from tsfc.ufl_utils import extract_firedrake_constants, hash_expr

import gem
import finat

import firedrake
import firedrake.bcs
from firedrake import tsfc_interface, utils, functionspaceimpl
from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology
from firedrake.petsc import PETSc
from firedrake.halo import _get_mtype as get_dat_mpi_type
from firedrake.cofunction import Cofunction
from mpi4py import MPI

from pyadjoint import stop_annotating, no_annotations

__all__ = (
    "interpolate",
    "Interpolator",
    "Interpolate",
    "DofNotDefinedError",
    "CrossMeshInterpolator",
    "SameMeshInterpolator",
)


class Interpolate(ufl.Interpolate):
    """Symbolic representation of the interpolation operator.

    Wraps :class:`ufl.Interpolate` and additionally stores the Firedrake
    specific interpolation options in the ``interp_data`` dictionary.
    """

    def __init__(self, expr, v,
                 subset=None,
                 access=None,
                 allow_missing_dofs=False,
                 default_missing_val=None,
                 matfree=True):
        """Symbolic representation of the interpolation operator.

        Parameters
        ----------
        expr : ufl.core.expr.Expr or ufl.BaseForm
            The UFL expression to interpolate.
        v : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument
            The function space to interpolate into or the coargument defined
            on the dual of the function space to interpolate into.
        subset : pyop2.types.set.Subset
            An optional subset to apply the interpolation over.
            Cannot, at present, be used when interpolating across meshes unless
            the target mesh is a :func:`.VertexOnlyMesh`.
        access : pyop2.types.access.Access
            The pyop2 access descriptor for combining updates to shared
            DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is
            supported at present when interpolating across meshes. See note in
            :func:`.interpolate` if changing this from default.
        allow_missing_dofs : bool
            For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes)
            in the target mesh that cannot be defined on the source mesh.
            For example, where nodes are point evaluations, points in the target mesh
            that are not in the source mesh. When ``False`` an error is raised
            should this occur. When ``True`` the corresponding values are either
            (a) unchanged if some ``output`` is given to the :meth:`interpolate` method
            or (b) set to zero.
            Can be overwritten with the ``default_missing_val`` kwarg of :meth:`interpolate`.
            This does not affect adjoint interpolation. Ignored if interpolating within
            the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a
            :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created).
        default_missing_val : float
            For interpolation across meshes: the optional value to assign to DoFs
            in the target mesh that are outside the source mesh. If this is not set
            then the values are either (a) unchanged if some ``output`` is given to
            the :meth:`interpolate` method or (b) set to zero.
            Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`.
        matfree : bool
            If ``False``, then construct the permutation matrix for interpolating
            between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
            and reduce operations.

        Raises
        ------
        RuntimeError
            If the rank or the shape of ``expr`` does not match the value
            shape of the target function space.
        """
        # Check function space
        expr = ufl.as_ufl(expr)
        if isinstance(v, functionspaceimpl.WithGeometry):
            # A bare function space was given: build the dual argument. If the
            # operand already carries argument number 0 the dual argument takes
            # number 1 (adjoint), otherwise number 0 (forward).
            expr_args = extract_arguments(expr)
            is_adjoint = len(expr_args) and expr_args[0].number() == 0
            v = Argument(v.dual(), 1 if is_adjoint else 0)

        V = v.arguments()[0].function_space()
        if len(expr.ufl_shape) != len(V.value_shape):
            raise RuntimeError(f'Rank mismatch: Expression rank {len(expr.ufl_shape)}, FunctionSpace rank {len(V.value_shape)}')

        if expr.ufl_shape != V.value_shape:
            # NOTE: this message must be an f-string, otherwise the shapes are
            # never interpolated into the text.
            raise RuntimeError(f'Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {V.value_shape}')
        super().__init__(expr, v)

        # -- Interpolate data (e.g. `subset` or `access`) -- #
        self.interp_data = {"subset": subset,
                            "access": access,
                            "allow_missing_dofs": allow_missing_dofs,
                            "default_missing_val": default_missing_val,
                            "matfree": matfree}

    function_space = ufl.Interpolate.ufl_function_space

    def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data):
        """Reconstruct this object with a new operand and/or dual argument.

        When no ``interp_data`` keyword arguments are supplied, a copy of the
        options stored on this instance is passed on instead.
        """
        interp_data = interp_data or self.interp_data.copy()
        return ufl.Interpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data)
@PETSc.Log.EventDecorator()
def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, default_missing_val=None, matfree=True):
    """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``.

    :arg expr: a UFL expression.
    :arg V: a :class:`.FunctionSpace` to interpolate into, or a :class:`.Cofunction`,
        or :class:`.Coargument`, or a :class:`ufl.form.Form` with one argument (a one-form).
        If a :class:`.Cofunction` or a one-form is provided, then we do adjoint interpolation.
    :kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the
        interpolation over. Cannot, at present, be used when interpolating
        across meshes unless the target mesh is a :func:`.VertexOnlyMesh`.
    :kwarg access: The pyop2 access descriptor for combining updates to shared
        DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is
        supported at present when interpolating across meshes unless the target
        mesh is a :func:`.VertexOnlyMesh`. See note below.
    :kwarg allow_missing_dofs: For interpolation across meshes: allow
        degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be
        defined on the source mesh. For example, where nodes are point
        evaluations, points in the target mesh that are not in the source mesh.
        When ``False`` an error is raised should this occur. When ``True`` the
        corresponding values are either (a) unchanged if some ``output`` is
        given to the :meth:`interpolate` method or (b) set to zero. In either
        case, if ``default_missing_val`` is specified, that value is used. This
        does not affect adjoint interpolation. Ignored if interpolating within
        the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a
        :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is
        created).
    :kwarg default_missing_val: For interpolation across meshes: the optional
        value to assign to DoFs in the target mesh that are outside the source
        mesh. If this is not set then the values are either (a) unchanged if
        some ``output`` is given to the :meth:`interpolate` method or (b) set
        to zero. Ignored if interpolating within the same mesh or onto a
        :func:`.VertexOnlyMesh`.
    :kwarg matfree: If ``False``, then construct the permutation matrix for interpolating
        between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
        and reduce operations.
    :returns: A symbolic :class:`.Interpolate` object

    .. note::

       If you use an access descriptor other than ``WRITE``, the
       behaviour of interpolation changes if interpolating into a
       function space, or an existing function. If the former, then
       the newly allocated function will be initialised with
       appropriate values (e.g. for MIN access, it will be initialised
       with MAX_FLOAT). On the other hand, if you provide a function,
       then it is assumed that its values should take part in the
       reduction (hence using MIN will compute the MIN between the
       existing values and any new values).
    """
    if isinstance(V, (Cofunction, Coargument)):
        # Already a dual object: use it directly.
        dual_arg = V
    elif isinstance(V, ufl.BaseForm):
        # Any other form is only acceptable if it is a one-form.
        rank = len(V.arguments())
        if rank != 1:
            raise TypeError(f"Expected a one-form, provided form had {rank} arguments")
        dual_arg = V
    elif isinstance(V, functionspaceimpl.WithGeometry):
        dual_arg = Coargument(V.dual(), 0)
        operand_args = extract_arguments(ufl.as_ufl(expr))
        if operand_args and operand_args[0].number() == 0:
            # The dual argument occupies number 0, so the operand's argument
            # is renumbered to 1.
            warnings.warn("Passing argument numbered 0 in expression for forward interpolation is deprecated. "
                          "Use a TrialFunction in the expression.")
            (arg,) = operand_args
            expr = replace(expr, {arg: arg.reconstruct(number=1)})
    else:
        raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}")

    return Interpolate(
        expr,
        dual_arg,
        subset=subset,
        access=access,
        allow_missing_dofs=allow_missing_dofs,
        default_missing_val=default_missing_val,
        matfree=matfree,
    )
class Interpolator(abc.ABC):
    """A reusable interpolation object.

    This object can be used to carry out the same interpolation
    multiple times (for example in a timestepping loop).

    Parameters
    ----------
    expr
        The underlying ufl.Interpolate or the operand to the ufl.Interpolate.
    V
        The :class:`.FunctionSpace` or :class:`.Function` to
        interpolate into.
    subset
        An optional :class:`pyop2.types.set.Subset` to apply the
        interpolation over. Cannot, at present, be used when interpolating
        across meshes unless the target mesh is a :func:`.VertexOnlyMesh`.
    freeze_expr
        Set to True to prevent the expression being
        re-evaluated on each call. Cannot, at present, be used when
        interpolating across meshes unless the target mesh is a
        :func:`.VertexOnlyMesh`.
    access
        The pyop2 access descriptor for combining updates to shared DoFs.
        Only ``op2.WRITE`` is supported at present when interpolating across meshes.
        Only ``op2.INC`` is supported for the matrix-free adjoint interpolation.
        See note in :func:`.interpolate` if changing this from default.
    bcs
        An optional list of boundary conditions to zero-out in the
        output function space. Interpolator rows or columns which are
        associated with boundary condition nodes are zeroed out when this is
        specified.
    allow_missing_dofs
        For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes)
        in the target mesh that cannot be defined on the source mesh. For example,
        where nodes are point evaluations, points in the target mesh that are not
        in the source mesh. When ``False`` an error is raised should this occur.
        When ``True`` the corresponding values are either (a) unchanged if some
        ``output`` is given to the :meth:`interpolate` method or (b) set to zero.
        Can be overwritten with the ``default_missing_val`` kwarg of
        :meth:`interpolate`. This does not affect adjoint interpolation. Ignored
        if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`
        (the behaviour of a :func:`.VertexOnlyMesh` in this scenario is, at
        present, set when it is created).
    matfree
        If ``False``, then construct the permutation matrix for interpolating
        between a VOM and its input ordering. Defaults to ``True`` which uses SF
        broadcast and reduce operations.

    Notes
    -----
    The :class:`Interpolator` holds a reference to the provided
    arguments (such that they won't be collected until the
    :class:`Interpolator` is also collected).
    """

    def __new__(cls, expr, V, **kwargs):
        """Select and allocate the concrete interpolator subclass.

        The choice depends on the number of arguments of the symbolic
        interpolation, on whether any function space involved is mixed, and
        on how the source and target meshes relate to each other.
        """
        V_target = V if isinstance(V, ufl.FunctionSpace) else V.function_space()
        if not isinstance(expr, ufl.Interpolate):
            expr = interpolate(expr, V_target)

        arguments = expr.arguments()
        has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments)
        if len(arguments) == 2 and has_mixed_arguments:
            return object.__new__(MixedInterpolator)

        (operand,) = expr.ufl_operands
        target_mesh = as_domain(V)
        source_mesh = extract_unique_domain(operand) or target_mesh
        # Interpolation between (sub)meshes is handled by the same-mesh code
        # path when both topologies are MeshTopology, they share the last
        # submesh ancestor and they have equal topological dimension.
        submesh_interp_implemented = (
            all(isinstance(mesh.topology, firedrake.mesh.MeshTopology)
                for mesh in [target_mesh, source_mesh])
            and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1]
            and target_mesh.topological_dimension() == source_mesh.topological_dimension()
        )

        if target_mesh is source_mesh or submesh_interp_implemented:
            return object.__new__(SameMeshInterpolator)
        if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
            return object.__new__(SameMeshInterpolator)
        if has_mixed_arguments or len(V_target) > 1:
            return object.__new__(MixedInterpolator)
        return object.__new__(CrossMeshInterpolator)

    def __init__(
        self,
        expr: ufl.Interpolate | ufl.classes.Expr,
        V: ufl.FunctionSpace | firedrake.function.Function,
        subset: op2.Subset | None = None,
        freeze_expr: bool = False,
        access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None,
        bcs: Iterable[firedrake.bcs.BCBase] | None = None,
        allow_missing_dofs: bool = False,
        matfree: bool = True
    ):
        """Store the interpolation options and derive the renumbered form.

        See the class docstring for the meaning of the parameters.
        """
        if not isinstance(expr, ufl.Interpolate):
            V_space = V if isinstance(V, ufl.FunctionSpace) else V.function_space()
            expr = interpolate(expr, V_space)
        dual_arg, operand = expr.argument_slots()

        self.ufl_interpolate = expr
        self.expr = operand
        self.V = V
        self.subset = subset
        self.freeze_expr = freeze_expr
        self.bcs = bcs
        self._allow_missing_dofs = allow_missing_dofs
        self.matfree = matfree
        self.callable = None

        # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of
        # self.ufl_interpolate (which carries the dual argument).
        # See github issue https://github.com/firedrakeproject/firedrake/issues/4592
        target_mesh = as_domain(V)
        source_mesh = extract_unique_domain(operand) or target_mesh
        vom_onto_other_vom = (
            (source_mesh is not target_mesh)
            and isinstance(self, SameMeshInterpolator)
            and isinstance(source_mesh.topology, VertexOnlyMeshTopology)
            and isinstance(target_mesh.topology, VertexOnlyMeshTopology)
        )

        if isinstance(self, CrossMeshInterpolator) or vom_onto_other_vom:
            # For bespoke interpolation, we currently rely on different assembly procedures:
            # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form)
            # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form)
            # 3) Interpolate(Coefficient(V1), Argument(V2.dual(), 0)) -> Forward action (1-form)
            # 4) Interpolate(Argument(V1, 0), Cofunction(V2.dual()) -> Adjoint action (1-form)
            # 5) Interpolate(Coefficient(V1), Cofunction(V2.dual()) -> Double action (0-form)

            # CrossMeshInterpolator._interpolate only supports forward interpolation (cases 1 and 3).
            # For case 2, we first redundantly assemble case 1 and then construct the transpose.
            # For cases 4 and 5, we take the forward Interpolate that corresponds to dropping the Cofunction,
            # and we separately compute the action against the dropped Cofunction within assemble().
            if not isinstance(dual_arg, ufl.Coargument):
                # Drop the Cofunction
                expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual())
            operand_args = extract_arguments(operand)
            if operand_args and operand_args[0].number() == 0:
                # Construct the symbolic forward Interpolate
                v0, v1 = expr.arguments()
                renumbering = {
                    v0: v0.reconstruct(number=v1.number()),
                    v1: v1.reconstruct(number=v0.number()),
                }
                expr = ufl.replace(expr, renumbering)

        dual_arg, operand = expr.argument_slots()
        self.expr_renumbered = operand
        self.ufl_interpolate_renumbered = expr

        if isinstance(dual_arg, ufl.Coargument):
            if access is None:
                # Default access for forward 1-form or 2-form (forward and adjoint)
                access = op2.WRITE
        else:
            # Matrix-free assembly of 0-form or 1-form requires INC access
            if access and access != op2.INC:
                raise ValueError("Matfree adjoint interpolation requires INC access")
            access = op2.INC
        self.access = access

    def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None):
        """
        .. warning::

            This method has been removed. Use the function :func:`interpolate` to return a symbolic
            :class:`Interpolate` object.
        """
        raise FutureWarning(
            "The 'interpolate' method on `Interpolator` objects has been "
            "removed. Use the `interpolate` function instead."
        )

    @abc.abstractmethod
    def _interpolate(self, *args, **kwargs):
        """
        Compute the interpolation operation of interest.

        .. note::
            This method is called when an :class:`Interpolate` object is being assembled.
        """
        pass

    def assemble(self, tensor=None, default_missing_val=None):
        """Assemble the operator (or its action)."""
        from firedrake.assemble import assemble

        # The renumbered Interpolate differs from the original one exactly when
        # __init__ dropped a Cofunction and/or swapped the argument numbers.
        needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate
        arguments = self.ufl_interpolate.arguments()

        if len(arguments) == 2:
            # Assembling the operator
            res = tensor.petscmat if tensor else PETSc.Mat()
            # Get the interpolation matrix
            petsc_mat = self.callable().handle
            if needs_adjoint:
                # Out-of-place Hermitian transpose
                petsc_mat.hermitianTranspose(out=res)
            elif tensor:
                petsc_mat.copy(tensor.petscmat)
            else:
                res = petsc_mat
            return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res)

        # Assembling the action
        cofunctions = ()
        if needs_adjoint:
            # The renumbered Interpolate has dropped Cofunctions.
            # We need to explicitly operate on them.
            dual_arg, _ = self.ufl_interpolate.argument_slots()
            if not isinstance(dual_arg, ufl.Coargument):
                cofunctions = (dual_arg,)

        if needs_adjoint and len(arguments) == 0:
            # 0-form: interpolate forward, then act with the dropped Cofunction.
            Iu = self._interpolate(default_missing_val=default_missing_val)
            return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor)

        return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint,
                                 default_missing_val=default_missing_val)
class DofNotDefinedError(Exception):
    r"""Raised when attempting to interpolate across function spaces where the
    target function space contains degrees of freedom (i.e. nodes) which cannot
    be defined in the source function space. This typically occurs when the
    target mesh covers a larger domain than the source mesh.

    Attributes
    ----------
    src_mesh : :func:`.Mesh`
        The source mesh.
    dest_mesh : :func:`.Mesh`
        The destination mesh.

    """

    def __init__(self, src_mesh, dest_mesh):
        # Forward the meshes to Exception so that ``args`` is populated
        # explicitly; the human readable message is produced by __str__.
        super().__init__(src_mesh, dest_mesh)
        self.src_mesh = src_mesh
        self.dest_mesh = dest_mesh

    def __str__(self):
        return (
            f"The given target function space on domain {self.dest_mesh!r} "
            "contains degrees of freedom which cannot be defined in the "
            f"source function space on domain {self.src_mesh!r}. "
            "This may be because the target mesh covers a larger domain than the "
            "source mesh. To disable this error, set allow_missing_dofs=True."
        )
class CrossMeshInterpolator(Interpolator):
    """
    Interpolate a function from one mesh and function space to another.

    For arguments, see :class:`.Interpolator`.
    """

    @no_annotations
    def __init__(
        self,
        expr,
        V,
        subset=None,
        freeze_expr=False,
        access=None,
        bcs=None,
        allow_missing_dofs=False,
        matfree=True
    ):
        """Set up the two symbolic interpolations used to cross meshes.

        A vertex-only mesh (VOM) is created at the node locations of the
        target space inside the source mesh. The expression is evaluated on
        that VOM and the values are then moved to the VOM's input ordering.
        """
        if subset:
            raise NotImplementedError("subset not implemented")
        if freeze_expr:
            # Probably just need to pass freeze_expr to the various
            # interpolators for this to work.
            raise NotImplementedError("freeze_expr not implemented")
        if bcs:
            raise NotImplementedError("bcs not implemented")
        if V.ufl_element().mapping() != "identity":
            # Identity mapping between reference cell and physical coordinates
            # implies point evaluation nodes. A more general version would
            # require finding the global coordinates of all quadrature points
            # of the target function space in the source mesh.
            raise NotImplementedError(
                "Can only interpolate into spaces with point evaluation nodes."
            )

        super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree)

        if self.access != op2.WRITE:
            raise NotImplementedError("access other than op2.WRITE not implemented")

        expr = self.expr_renumbered
        self.arguments = extract_arguments(expr)
        self.nargs = len(self.arguments)

        missing_points_behaviour = (
            MissingPointsBehaviour.IGNORE
            if self._allow_missing_dofs
            else MissingPointsBehaviour.ERROR
        )

        # setup
        V_dest = V.function_space() if isinstance(V, firedrake.Function) else V
        src_mesh = extract_unique_domain(expr)
        dest_mesh = as_domain(V_dest)
        src_mesh_gdim = src_mesh.geometric_dimension()
        dest_mesh_gdim = dest_mesh.geometric_dimension()
        if src_mesh_gdim != dest_mesh_gdim:
            raise ValueError(
                "geometric dimensions of source and destination meshes must match"
            )
        self.src_mesh = src_mesh
        self.dest_mesh = dest_mesh

        # Create a VOM at the nodes of V_dest in src_mesh. We don't include halo
        # node coordinates because interpolation doesn't usually include halos.
        # NOTE: it is very important to set redundant=False, otherwise the
        # input ordering VOM will only contain the points on rank 0!
        # QUESTION: Should any of the below have annotation turned off?
        ufl_scalar_element = V_dest.ufl_element()
        if isinstance(ufl_scalar_element, finat.ufl.MixedElement):
            if type(ufl_scalar_element) is finat.ufl.MixedElement:
                raise TypeError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator")

            # For a VectorElement or TensorElement the correct
            # VectorFunctionSpace equivalent is built from the scalar
            # sub-element.
            (ufl_scalar_element,) = set(ufl_scalar_element.sub_elements)
            if ufl_scalar_element.reference_value_shape != ():
                raise NotImplementedError(
                    "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
                )

        from firedrake.assemble import assemble
        V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element)
        f_dest_node_coords = assemble(Interpolate(dest_mesh.coordinates, V_dest_vec))
        dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, dest_mesh_gdim)
        try:
            self.vom_dest_node_coords_in_src_mesh = firedrake.VertexOnlyMesh(
                src_mesh,
                dest_node_coords,
                redundant=False,
                missing_points_behaviour=missing_points_behaviour,
            )
        except VertexOnlyMeshMissingPointsError:
            raise DofNotDefinedError(src_mesh, dest_mesh)

        # vom_dest_node_coords_in_src_mesh uses the parallel decomposition of
        # the global node coordinates of V_dest in the SOURCE mesh (src_mesh).
        # I first point evaluate my expression at these locations, giving a
        # P0DG function on the VOM. As described in the manual, this is an
        # interpolation operation.
        shape = V_dest.ufl_function_space().value_shape
        rank = len(shape)
        if rank == 0:
            fs_type = firedrake.FunctionSpace
        elif rank == 1:
            fs_type = partial(firedrake.VectorFunctionSpace, dim=shape[0])
        else:
            fs_type = partial(firedrake.TensorFunctionSpace, shape=shape)
        P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0)
        self.point_eval_interpolate = Interpolate(self.expr_renumbered, P0DG_vom)

        # The parallel decomposition of the nodes of V_dest in the DESTINATION
        # mesh (dest_mesh) is retrieved using the input_ordering attribute of the
        # VOM. This again is an interpolation operation, which, under the hood
        # is a PETSc SF reduce.
        P0DG_vom_i_o = fs_type(
            self.vom_dest_node_coords_in_src_mesh.input_ordering, "DG", 0
        )
        self.to_input_ordering_interpolate = Interpolate(
            firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o
        )
        # The P0DG function outputted by the above interpolation has the
        # correct parallel decomposition for the nodes of V_dest in dest_mesh so
        # we can safely assign the dat values. This is all done in the actual
        # interpolation method below.

    @PETSc.Log.EventDecorator()
    def _interpolate(
        self,
        *function,
        output=None,
        transpose=None,
        adjoint=False,
        default_missing_val=None,
        **kwargs,
    ):
        """Compute the interpolation.

        For arguments, see :class:`.Interpolator`.
        """
        from firedrake.assemble import assemble

        if transpose is not None:
            warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning)
            adjoint = transpose or adjoint
        if adjoint and not self.nargs:
            raise ValueError(
                "Can currently only apply adjoint interpolation with arguments."
            )
        if self.nargs != len(function):
            raise ValueError(
                "Passed %d Functions to interpolate, expected %d"
                % (len(function), self.nargs)
            )

        if self.nargs:
            (f_src,) = function
            if not hasattr(f_src, "dat"):
                raise ValueError(
                    "The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!"
                )
        else:
            f_src = self.expr

        # Work out which space the result lives in.
        V_is_function = isinstance(self.V, (firedrake.Function, firedrake.Cofunction))
        if adjoint:
            try:
                V_dest = self.expr.function_space().dual()
            except AttributeError:
                if self.nargs:
                    V_dest = self.arguments[-1].function_space().dual()
                else:
                    coeffs = extract_coefficients(self.expr)
                    if len(coeffs):
                        V_dest = coeffs[0].function_space().dual()
                    else:
                        raise ValueError(
                            "Can't adjoint interpolate an expression with no coefficients or arguments."
                        )
        elif V_is_function:
            V_dest = self.V.function_space()
        else:
            V_dest = self.V

        if output:
            if output.function_space() != V_dest:
                raise ValueError("Given output has the wrong function space!")
        elif V_is_function:
            output = self.V
        else:
            output = firedrake.Function(V_dest)

        if adjoint:
            # adjoint interpolation

            # f_src is a cofunction on V_dest.dual as originally specified when
            # creating the interpolator. Our first adjoint operation is to
            # assign the dat values to a P0DG cofunction on our input ordering
            # VOM. This has the parallel decomposition V_dest on our originally
            # specified dest_mesh. We can therefore safely create a P0DG
            # cofunction on the input-ordering VOM (which has this parallel
            # decomposition and ordering) and assign the dat values.
            vals_dest_decomp = firedrake.Cofunction(
                self.to_input_ordering_interpolate.function_space().dual()
            )
            vals_dest_decomp.dat.data_wo[:] = f_src.dat.data_ro[:]

            # The rest of the adjoint interpolation is merely the composition
            # of the adjoint interpolators in the reverse direction. NOTE: I
            # don't have to worry about skipping over missing points here
            # because I'm going from the input ordering VOM to the original VOM
            # and all points from the input ordering VOM are in the original.
            interp = action(expr_adjoint(self.to_input_ordering_interpolate), vals_dest_decomp)
            cofunction_on_vom = assemble(interp)

            # NOTE: if I wanted the default missing value to be applied to
            # adjoint interpolation I would have to do it here. However,
            # this would require me to implement default missing values for
            # adjoint interpolation from a point evaluation interpolator
            # which I haven't done. I wonder if it is necessary - perhaps the
            # adjoint operator always sets all the values of the resulting
            # cofunction? My initial attempt to insert setting the dat values
            # prior to performing the multHermitian operation in
            # SameMeshInterpolator.interpolate did not effect the result. For
            # now, I say in the docstring that it only applies to forward
            # interpolation.
            interp = action(expr_adjoint(self.point_eval_interpolate), cofunction_on_vom)
            assemble(interp, tensor=output)
        else:
            if f_src is self.expr:
                # f_src is already contained in self.point_eval_interpolate
                assert not self.nargs
                vals_src_decomp = assemble(self.point_eval_interpolate)
            else:
                vals_src_decomp = assemble(action(self.point_eval_interpolate, f_src))

            vals_dest_decomp = firedrake.Function(
                self.to_input_ordering_interpolate.function_space()
            )
            # We have to create the Function before interpolating so we can
            # set default missing values (if requested).
            if default_missing_val is not None:
                vals_dest_decomp.dat.data_wo[:] = default_missing_val
            elif self._allow_missing_dofs:
                # If we have allowed missing points we know we might end up
                # with points in the target mesh that are not in the source
                # mesh. However, since we haven't specified a default missing
                # value we expect the interpolation to leave these points
                # unchanged. By setting the dat values to NaN we can later
                # identify these points and skip over them when assigning to
                # the output function.
                vals_dest_decomp.dat.data_wo[:] = numpy.nan

            interp = action(self.to_input_ordering_interpolate, vals_src_decomp)
            assemble(interp, tensor=vals_dest_decomp)

            # we can now confidently assign this to a function on V_dest
            if self._allow_missing_dofs and default_missing_val is None:
                # Only overwrite the entries that actually received a value.
                indices = numpy.where(~numpy.isnan(vals_dest_decomp.dat.data_ro))[0]
                output.dat.data_wo[indices] = vals_dest_decomp.dat.data_ro[indices]
            else:
                output.dat.data_wo[:] = vals_dest_decomp.dat.data_ro[:]

        return output
class SameMeshInterpolator(Interpolator):
    """
    An interpolator for interpolation within the same mesh or onto a validly-
    defined :func:`.VertexOnlyMesh`.

    For arguments, see :class:`.Interpolator`.
    """

    @no_annotations
    def __init__(self, expr, V, subset=None, freeze_expr=False, access=None,
                 bcs=None, matfree=True, allow_missing_dofs=False, **kwargs):
        """Build the interpolation callable, deriving a cell subset if needed.

        When no ``subset`` is given and source and target are distinct
        ``MeshTopology`` objects, the active cells of the composed entity map
        decide whether the iteration has to be restricted to a subset.
        """
        if subset is None:
            if isinstance(expr, ufl.Interpolate):
                (operand,) = expr.ufl_operands
            else:
                operand = expr
            target_mesh = as_domain(V)
            source_mesh = extract_unique_domain(operand) or target_mesh
            target = target_mesh.topology
            source = source_mesh.topology
            both_mesh_topologies = all(
                isinstance(topology, firedrake.mesh.MeshTopology)
                for topology in [target, source]
            )
            if both_mesh_topologies and target is not source:
                composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None)
                if result_integral_type != "cell":
                    raise AssertionError("Only cell-cell interpolation supported")
                indices_active = composed_map.indices_active_with_halo
                # A subset is required as soon as any rank has an inactive cell.
                make_subset = not indices_active.all()
                make_subset = target.comm.allreduce(make_subset, op=MPI.LOR)
                if make_subset:
                    if not allow_missing_dofs:
                        raise ValueError("iteration (sub)set unclear: run with `allow_missing_dofs=True`")
                    subset = op2.Subset(target.cell_set, numpy.where(indices_active))
                # Otherwise: do not need subset as target <= source.

        super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr,
                         access=access, bcs=bcs, matfree=matfree,
                         allow_missing_dofs=allow_missing_dofs)

        expr = self.ufl_interpolate_renumbered
        try:
            self.callable = make_interpolator(expr, V, subset, self.access, bcs=bcs, matfree=matfree)
        except FIAT.hdiv_trace.TraceError:
            raise NotImplementedError("Can't interpolate onto traces sorry")
        self.arguments = expr.arguments()

    @PETSc.Log.EventDecorator()
    def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **kwargs):
        """Compute the interpolation.

        For arguments, see :class:`.Interpolator`.
        """
        if transpose is not None:
            warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning)
            adjoint = transpose or adjoint

        try:
            assembled_interpolator = self.frozen_assembled_interpolator
            copy_required = True
        except AttributeError:
            # Nothing frozen yet: evaluate the interpolation now.
            assembled_interpolator = self.callable()
            copy_required = False  # Return the original
            if self.freeze_expr:
                if len(self.arguments) == 2:
                    # Interpolation operator
                    self.frozen_assembled_interpolator = assembled_interpolator
                else:
                    # Interpolation action
                    self.frozen_assembled_interpolator = assembled_interpolator.copy()

        if len(self.arguments) == 2 and len(function) > 0:
            # Apply the assembled operator (or its Hermitian transpose).
            (function,) = function
            if not hasattr(function, "dat"):
                raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!")
            if adjoint:
                mul = assembled_interpolator.handle.multHermitian
                col, row = self.arguments
            else:
                mul = assembled_interpolator.handle.mult
                row, col = self.arguments
            V = row.function_space().dual()
            assert function.function_space() == col.function_space()

            result = output or firedrake.Function(V)
            with function.dat.vec_ro as x, result.dat.vec_wo as out:
                if x is out:
                    # Input and output share a vector: multiply into a
                    # duplicate and copy the result back.
                    out_ = out.duplicate()
                    mul(x, out_)
                    out_.copy(result=out)
                else:
                    mul(x, out)
            return result

        if output:
            output.assign(assembled_interpolator)
            return output

        if isinstance(self.V, firedrake.Function):
            if copy_required:
                self.V.assign(assembled_interpolator)
            return self.V

        if len(self.arguments) == 0:
            return assembled_interpolator.dat.data.item()
        if copy_required:
            return assembled_interpolator.copy()
        return assembled_interpolator
@PETSc.Log.EventDecorator()
def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
    """Build a zero-argument callable that performs the interpolation ``expr``.

    Parameters
    ----------
    expr : ufl.Interpolate
        The symbolic interpolation to assemble.
    V
        The target function space, or a :class:`firedrake.Function` to
        interpolate into (only allowed when ``expr`` has at most one argument).
    subset
        Optional iteration subset, forwarded to :func:`_interpolator`.
    access
        The pyop2 access descriptor for the output.
    bcs
        Optional boundary conditions.
    matfree : bool
        Forwarded to :class:`VomOntoVomWrapper` when interpolating between two
        vertex-only meshes.

    Returns
    -------
    callable
        Calling it runs the interpolation and returns the output object (a
        Function for rank 0/1, the ``op2.Mat`` for rank 2, or the
        :class:`VomOntoVomWrapper` for rank 2 between vertex-only meshes).

    Raises
    ------
    ValueError
        If ``expr`` is not a ``ufl.Interpolate``, has more than two arguments,
        has two arguments while ``V`` is a Function, or the meshes are
        incompatible.
    TypeError
        If a rank 2 interpolation involves a mixed function space.
    """
    if not isinstance(expr, ufl.Interpolate):
        raise ValueError(f"Expecting to interpolate a ufl.Interpolate, got {type(expr).__name__}.")
    dual_arg, operand = expr.argument_slots()
    target_mesh = as_domain(dual_arg)
    source_mesh = extract_unique_domain(operand) or target_mesh
    vom_onto_other_vom = ((source_mesh is not target_mesh)
                          and isinstance(source_mesh.topology, VertexOnlyMeshTopology)
                          and isinstance(target_mesh.topology, VertexOnlyMeshTopology))

    arguments = expr.arguments()
    rank = len(arguments)
    if rank <= 1:
        if rank == 0:
            R = firedrake.FunctionSpace(target_mesh, "Real", 0)
            f = firedrake.Function(R, dtype=utils.ScalarType)
        elif isinstance(V, firedrake.Function):
            f = V
            V = f.function_space()
        else:
            V_dest = arguments[0].function_space().dual()
            f = firedrake.Function(V_dest)
            if access in {firedrake.MIN, firedrake.MAX}:
                # Start from the identity element of the reduction.
                finfo = numpy.finfo(f.dat.dtype)
                if access == firedrake.MIN:
                    val = firedrake.Constant(finfo.max)
                else:
                    val = firedrake.Constant(finfo.min)
                f.assign(val)
        tensor = f.dat
    elif rank == 2:
        if isinstance(V, firedrake.Function):
            raise ValueError("Cannot interpolate an expression with an argument into a Function")
        Vrow = arguments[0].function_space()
        Vcol = arguments[1].function_space()
        if len(Vrow) > 1 or len(Vcol) > 1:
            raise TypeError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator")
        if (isinstance(target_mesh.topology, VertexOnlyMeshTopology)
                and target_mesh is not source_mesh
                and not vom_onto_other_vom):
            # The target is known to be a vertex-only mesh here, so only the
            # relationship to the source mesh remains to be validated.
            if target_mesh.geometric_dimension() != source_mesh.geometric_dimension():
                raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension")
            if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh:
                raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target")

        if vom_onto_other_vom:
            # We make our own linear operator for this case using PETSc SFs
            tensor = None
        else:
            Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow)
            Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol)
            sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset),
                                    [(Vrow_map, Vcol_map, None)],  # non-mixed
                                    name="%s_%s_sparsity" % (Vrow.name, Vcol.name),
                                    nest=False,
                                    block_sparse=True)
            tensor = op2.Mat(sparsity)
        f = tensor
    else:
        raise ValueError(f"Cannot interpolate an expression with {rank} arguments")

    if vom_onto_other_vom:
        wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, matfree)
        # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
        # data, including the correct data size and dimensional information
        # (so for vector function spaces in 2 dimensions we might need a
        # concatenation of 2 MPI.DOUBLE types when we are in real mode)
        if tensor is not None:
            # Callable will do interpolation into our pre-supplied function f
            # when it is called.
            assert f.dat is tensor
            wrapper.mpi_type, _ = get_dat_mpi_type(f.dat)
            assert len(arguments) == 1

            def run_vom_interpolation():
                wrapper.forward_operation(f.dat)
                return f
        else:
            assert len(arguments) == 2
            assert tensor is None
            # we know we will be outputting either a function or a cofunction,
            # both of which will use a dat as a data carrier. At present, the
            # data type does not depend on function space dimension, so we can
            # safely use the argument function space. NOTE: If this changes
            # after cofunctions are fully implemented, this will need to be
            # reconsidered.
            temp_source_func = firedrake.Function(Vcol)
            wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat)

            # Leave wrapper inside a callable so we can access the handle
            # property. If matfree is True, then the handle is a PETSc SF
            # pretending to be a PETSc Mat. If matfree is False, then this
            # will be a PETSc Mat representing the equivalent permutation
            # matrix
            def run_vom_interpolation():
                return wrapper

        return run_vom_interpolation

    loops = []
    # Initialise to zero if needed
    if access is op2.INC:
        loops.append(tensor.zero)

    # Arguments in the operand are allowed to be from a MixedFunctionSpace
    # We need to split the target space V and generate separate kernels
    if len(arguments) == 2:
        # Matrix case assumes that the spaces are not mixed
        expressions = {(0,): expr}
    elif isinstance(dual_arg, Coargument):
        # Split in the coargument
        expressions = dict(firedrake.formmanipulation.split_form(expr))
    else:
        # Split in the cofunction: split_form can only split in the coargument
        # Replace the cofunction with a coargument to construct the Jacobian
        interp = expr._ufl_expr_reconstruct_(operand, V)
        # Split the Jacobian into blocks
        interp_split = dict(firedrake.formmanipulation.split_form(interp))
        # Split the cofunction
        dual_split = dict(firedrake.formmanipulation.split_form(dual_arg))
        # Combine the splits by taking their action
        expressions = {i: action(interp_split[i], dual_split[i[-1:]])
                       for i in interp_split}

    # Interpolate each sub expression into each function space
    for indices, sub_expr in expressions.items():
        sub_tensor = tensor[indices[0]] if rank == 1 else tensor
        loops.extend(_interpolator(sub_tensor, sub_expr, subset, access, bcs=bcs))

    # Apply bcs
    if bcs and rank == 1:
        loops.extend(partial(bc.apply, f) for bc in bcs)

    def run_loops(loops, f):
        for loop in loops:
            loop()
        return f

    return partial(run_loops, loops, f)


@utils.known_pyop2_safe
def _interpolator(tensor, expr, subset, access, bcs=None):
    """Build the sequence of callables that interpolate ``expr`` into ``tensor``.

    Parameters
    ----------
    tensor
        The pyop2 ``Global``, ``Dat`` or ``Mat`` receiving the result.
    expr : ufl.Interpolate or ufl.ZeroBaseForm
        The (sub-)expression to interpolate.
    subset
        Optional subset of the target mesh cell set to iterate over.
    access
        The pyop2 access descriptor for the output; ``READ`` is rejected.
    bcs
        Optional boundary conditions, used to build the local-to-global maps
        in the matrix case.

    Returns
    -------
    tuple
        Callables to be invoked in order. For a ``Mat`` this is the parloop
        followed by ``tensor.assemble``; otherwise any copy-in operations,
        the parloop, then any copy-out operations.
    """
    if isinstance(expr, ufl.ZeroBaseForm):
        # Zero simplification, avoid code-generation
        if access is op2.INC:
            return ()
        elif access is op2.WRITE:
            return (partial(tensor.zero, subset=subset),)
        # Unclear how to avoid codegen for MIN and MAX
        # Reconstruct the expression as an Interpolate
        V = expr.arguments()[-1].function_space().dual()
        expr = interpolate(ufl.zero(V.value_shape), V)

    if not isinstance(expr, ufl.Interpolate):
        raise ValueError("Expecting to interpolate a ufl.Interpolate")

    arguments = expr.arguments()
    dual_arg, operand = expr.argument_slots()
    V = dual_arg.arguments()[0].function_space()
    try:
        to_element = create_element(V.ufl_element())
    except KeyError as err:
        # FInAT only elements
        raise NotImplementedError("Don't know how to create FIAT element for %s" % V.ufl_element()) from err

    if access is op2.READ:
        raise ValueError("Can't have READ access for output function")

    # NOTE: The par_loop is always over the target mesh cells.
    target_mesh = as_domain(V)
    source_mesh = extract_unique_domain(operand) or target_mesh
    if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
        if target_mesh is not source_mesh:
            if target_mesh.geometric_dimension() != source_mesh.geometric_dimension():
                raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension")
            if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh:
                raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target")
            # For trans-mesh interpolation we use a FInAT QuadratureElement as the
            # (base) target element with runtime point set expressions as their
            # quadrature rule point set and weights from their dual basis.
            # NOTE: This setup is useful for thinking about future design - in the
            # future this `rebuild` function can be absorbed into FInAT as a
            # transformer that eats an element and gives you an equivalent (which
            # may or may not be a QuadratureElement) that lets you do run time
            # tabulation. Alternatively (and this all depends on future design
            # decision about FInAT how dual evaluation should work) the
            # to_element's dual basis (which look rather like quadrature rules) can
            # have their pointset(s) directly replaced with run-time tabulated
            # equivalent(s) (i.e. finat.point_set.UnknownPointSet(s))
            rt_var_name = 'rt_X'
            try:
                cell = operand.ufl_element().ufl_cell()
            except AttributeError:
                # expression must be pure function of spatial coordinates so
                # domain has correct ufl cell
                cell = source_mesh.ufl_cell()
            to_element = rebuild(to_element, cell, rt_var_name)

    cell_set = target_mesh.cell_set
    if subset is not None:
        assert subset.superset == cell_set
        cell_set = subset

    parameters = {'scalar_type': utils.ScalarType}

    copyin = ()
    copyout = ()

    # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
    # contributions from the facet DOFs of the dual argument.
    # The incoming Cofunction needs to be weighted by the reciprocal of the DOF multiplicity.
    needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg()
    if needs_weight:
        # Create a buffer for the weighted Cofunction
        W = dual_arg.function_space()
        v = firedrake.Function(W)
        expr = expr._ufl_expr_reconstruct_(operand, v=v)
        copyin += (partial(dual_arg.dat.copy, v.dat),)

        # Compute the reciprocal of the DOF multiplicity
        wdat = W.make_dat()
        m_ = get_interp_node_map(source_mesh, target_mesh, W)
        wsize = W.finat_element.space_dimension() * W.block_size
        kernel_code = f"""
        void multiplicity(PetscScalar *restrict w) {{
            for (PetscInt i=0; i<{wsize}; i++) w[i] += 1;
        }}"""
        kernel = op2.Kernel(kernel_code, "multiplicity")
        op2.par_loop(kernel, cell_set, wdat(op2.INC, m_))
        with wdat.vec as w:
            w.reciprocal()

        # Create a callable to apply the weight
        with wdat.vec_ro as w, v.dat.vec as y:
            copyin += (partial(y.pointwiseMult, y, w),)

    # We need to pass both the ufl element and the finat element
    # because the finat elements might not have the right mapping
    # (e.g. L2 Piola, or tensor element with symmetries)
    # FIXME: for the runtime unknown point set (for cross-mesh
    # interpolation) we have to pass the finat element we construct
    # here. Ideally we would only pass the UFL element through.
    kernel = compile_expression(cell_set.comm, expr, to_element, V.ufl_element(),
                                domain=source_mesh, parameters=parameters)
    ast = kernel.ast
    oriented = kernel.oriented
    needs_cell_sizes = kernel.needs_cell_sizes
    coefficient_numbers = kernel.coefficient_numbers
    needs_external_coords = kernel.needs_external_coords
    name = kernel.name
    kernel = op2.Kernel(ast, name,
                        requires_zeroed_output_arguments=(access is not op2.INC),
                        flop_count=kernel.flop_count, events=(kernel.event,))

    parloop_args = [kernel, cell_set]

    coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers)
    if needs_external_coords:
        coefficients = [source_mesh.coordinates] + coefficients

    if any(c.dat == tensor for c in coefficients):
        # The output also appears as an input: write into a scratch Dat and
        # copy back afterwards so the kernel never reads partially-updated data.
        output = tensor
        tensor = op2.Dat(tensor.dataset)
        if access is not op2.WRITE:
            copyin += (partial(output.copy, tensor), )
        copyout += (partial(tensor.copy, output), )

    if isinstance(tensor, op2.Global):
        parloop_args.append(tensor(access))
    elif isinstance(tensor, op2.Dat):
        V_dest = arguments[-1].function_space()
        m_ = get_interp_node_map(source_mesh, target_mesh, V_dest)
        parloop_args.append(tensor(access, m_))
    else:
        assert access == op2.WRITE  # Other access descriptors not done for Matrices.
        Vrow = arguments[0].function_space()
        Vcol = arguments[1].function_space()
        assert tensor.handle.getSize() == (Vrow.dim(), Vcol.dim())
        rows_map = get_interp_node_map(source_mesh, target_mesh, Vrow)
        columns_map = get_interp_node_map(source_mesh, target_mesh, Vcol)
        lgmaps = None
        if bcs:
            if ufl.duals.is_dual(Vrow):
                Vrow = Vrow.dual()
            if ufl.duals.is_dual(Vcol):
                Vcol = Vcol.dual()
            bc_rows = [bc for bc in bcs if bc.function_space() == Vrow]
            bc_cols = [bc for bc in bcs if bc.function_space() == Vcol]
            lgmaps = [(Vrow.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))]
        parloop_args.append(tensor(access, (rows_map, columns_map), lgmaps=lgmaps))

    if oriented:
        co = target_mesh.cell_orientations()
        parloop_args.append(co.dat(op2.READ, co.cell_node_map()))

    if needs_cell_sizes:
        cs = source_mesh.cell_sizes
        parloop_args.append(cs.dat(op2.READ, cs.cell_node_map()))

    for coefficient in coefficients:
        m_ = get_interp_node_map(source_mesh, target_mesh, coefficient.function_space())
        parloop_args.append(coefficient.dat(op2.READ, m_))

    for const in extract_firedrake_constants(expr):
        parloop_args.append(const.dat(op2.READ))

    # Finally, add the target mesh reference coordinates if they appear in the kernel
    if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
        if target_mesh is not source_mesh:
            # NOTE: TSFC will sometimes drop run-time arguments in generated
            # kernels if they are deemed not-necessary.
            # FIXME: Checking for argument name in the inner kernel to decide
            # whether to add an extra coefficient is a stopgap until
            # compile_expression_dual_evaluation
            #   (a) outputs a coefficient map to indicate argument ordering in
            #       parloops as `compile_form` does and
            #   (b) allows the dual evaluation related coefficients to be supplied to
            #       them rather than having to be added post-hoc (likely by
            #       replacing `to_element` with a CoFunction/CoArgument as the
            #       target `dual` which would contain `dual` related
            #       coefficient(s))
            if any(arg.name == rt_var_name for arg in kernel.code[name].args):
                # Add the coordinates of the target mesh quadrature points in the
                # source mesh's reference cell as an extra argument for the inner
                # loop. (With a vertex only mesh this is a single point for each
                # vertex cell.)
                target_ref_coords = target_mesh.reference_coordinates
                m_ = target_ref_coords.cell_node_map()
                parloop_args.append(target_ref_coords.dat(op2.READ, m_))

    parloop = op2.ParLoop(*parloop_args)
    if isinstance(tensor, op2.Mat):
        return parloop, tensor.assemble
    else:
        return copyin + (parloop, ) + copyout


def get_interp_node_map(source_mesh, target_mesh, fs):
    """Return the map between cells of the target mesh and nodes of the function space.

    If the function space is defined on the source mesh then the node map is composed
    with a map between target and source cells.

    Raises
    ------
    ValueError
        If the target is a vertex-only mesh and ``fs`` lives on a mesh that is
        neither the source nor the target mesh.
    """
    if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
        coeff_mesh = fs.mesh()
        m_ = fs.cell_node_map()
        if coeff_mesh is target_mesh or not coeff_mesh:
            # NOTE: coeff_mesh is None is allowed e.g. when interpolating from
            # a Real space
            pass
        elif coeff_mesh is source_mesh:
            if m_:
                # Since the par_loop is over the target mesh cells we need to
                # compose a map that takes us from target mesh cells to the
                # function space nodes on the source mesh.
                if source_mesh.extruded:
                    # ExtrudedSet cannot be a map target so we need to build
                    # this ourselves
                    m_ = vom_cell_parent_node_map_extruded(target_mesh, m_)
                else:
                    m_ = compose_map_and_cache(target_mesh.cell_parent_cell_map, m_)
            else:
                # m_ is allowed to be None when interpolating from a Real space,
                # even in the trans-mesh case.
                pass
        else:
            raise ValueError("Have coefficient with unexpected mesh")
    else:
        m_ = fs.entity_node_map(target_mesh.topology, "cell", None, None)
    return m_


# Directory for cached expression kernels: taken from the environment when
# set, otherwise a per-user directory under the system temporary directory.
try:
    _expr_cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"]
except KeyError:
    _expr_cachedir = os.path.join(tempfile.gettempdir(),
                                  f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}")


def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]:
    """Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`."""
    dual_arg, operand = expr.argument_slots()
    return (hash_expr(operand), type(dual_arg), hash(ufl_element), utils.tuplify(parameters))


@memory_and_disk_cache(
    hashkey=_compile_expression_key,
    cachedir=tsfc_interface._cachedir
)
@PETSc.Log.EventDecorator()
def compile_expression(comm, *args, **kwargs):
    """Cached wrapper around :func:`tsfc.compile_expression_dual_evaluation`.

    ``comm`` is not forwarded to TSFC (presumably it is consumed by the
    caching decorator - confirm against :func:`memory_and_disk_cache`); all
    remaining arguments are passed through unchanged.
    """
    return compile_expression_dual_evaluation(*args, **kwargs)


@singledispatch
def rebuild(element, expr_cell, rt_var_name):
    """Return a runtime-tabulated equivalent of ``element`` (dispatch fallback)."""
    raise NotImplementedError(f"Cross mesh interpolation not implemented for a {element} element.")


@rebuild.register(finat.fiat_elements.ScalarFiatElement)
def rebuild_dg(element, expr_cell, rt_var_name):
    """Build a one-point runtime quadrature element standing in for ``element``.

    Raises
    ------
    NotImplementedError
        If ``element`` is not degree 0 on a FIAT ``Point`` reference cell.
    """
    # To tabulate on the given element (which is on a different mesh to the
    # expression) we must do so at runtime. We therefore create a quadrature
    # element with runtime points to evaluate for each point in the element's
    # dual basis. This exists on the same reference cell as the input element
    # and we can interpolate onto it before mapping the result back onto the
    # target space.
    expr_tdim = expr_cell.topological_dimension()
    # Need point evaluations and matching weights from dual basis.
    # This could use FIAT's dual basis as below:
    # num_points = sum(len(dual.get_point_dict()) for dual in element.fiat_equivalent.dual_basis())
    # weights = []
    # for dual in element.fiat_equivalent.dual_basis():
    #     pts = dual.get_point_dict().keys()
    #     for p in pts:
    #         for w, _ in dual.get_point_dict()[p]:
    #             weights.append(w)
    # assert len(weights) == num_points
    # but for now we just fix the values to what we know works:
    if element.degree != 0 or not isinstance(element.cell, FIAT.reference_element.Point):
        raise NotImplementedError("Cross mesh interpolation only implemented for P0DG on vertex cells.")
    num_points = 1
    weights = [1.]*num_points
    # gem.Variable name starting with rt_ forces TSFC runtime tabulation
    assert rt_var_name.startswith("rt_")
    runtime_points_expr = gem.Variable(rt_var_name, (num_points, expr_tdim))
    rule_pointset = finat.point_set.UnknownPointSet(runtime_points_expr)
    rule = finat.quadrature.QuadratureRule(rule_pointset, weights=weights)
    return finat.QuadratureElement(as_fiat_cell(expr_cell), rule)


@rebuild.register(finat.TensorFiniteElement)
def rebuild_te(element, expr_cell, rt_var_name):
    """Rebuild the base element and re-wrap it with the same shape and transpose."""
    return finat.TensorFiniteElement(rebuild(element.base_element,
                                             expr_cell, rt_var_name),
                                     element._shape,
                                     transpose=element._transpose)


def compose_map_and_cache(map1, map2):
    """
    Retrieve a :class:`pyop2.ComposedMap` map from the cache of map1
    using map2 as the cache key. The composed map maps from the iterset
    of map1 to the toset of map2. Makes :class:`pyop2.ComposedMap` and
    caches the result on map1 if the composed map is not found.

    :arg map1: The map with the desired iterset from which the result is
        retrieved or cached
    :arg map2: The map with the desired toset

    :returns:  The composed map
    """
    cache_key = hash((map2, "composed"))
    try:
        cmap = map1._cache[cache_key]
    except KeyError:
        # Real function space case separately
        cmap = None if map2 is None else op2.ComposedMap(map2, map1)
        map1._cache[cache_key] = cmap
    return cmap


def vom_cell_parent_node_map_extruded(vertex_only_mesh, extruded_cell_node_map):
    """Build a map from the cells of a vertex only mesh to the nodes of the
    source mesh where the source mesh is extruded.

    Parameters
    ----------
    vertex_only_mesh : :class:`mesh.MeshGeometry`
        The ``mesh.VertexOnlyMesh`` whose cells we iterate over.
    extruded_cell_node_map : :class:`pyop2.Map`
        The cell node map of the function space on the extruded mesh within
        which the ``mesh.VertexOnlyMesh`` is immersed.

    Returns
    -------
    :class:`pyop2.Map`
        The map from the cells of the vertex only mesh to the nodes of the
        source mesh's cell node map. The map iterset is the
        ``vertex_only_mesh.cell_set`` and the map toset is the
        ``extruded_cell_node_map.toset``.

    Raises
    ------
    TypeError
        If ``vertex_only_mesh`` does not have a vertex-only mesh topology.

    Notes
    -----

    For an extruded mesh the cell node map is a map from a
    :class:`pyop2.ExtrudedSet` (the cells of the extruded mesh) to a
    :class:`pyop2.Set` (the nodes of the extruded mesh).

    Take for example

    ``mx = ExtrudedMesh(UnitIntervalMesh(2), 3)`` with
    ``mx.layers = 4``

    which looks like

    .. code-block:: text

        -------------------layer 4-------------------
        | parent_cell_num =  2 | parent_cell_num =  5 |
        |                      |                      |
        | extrusion_height = 2 | extrusion_height = 2 |
        -------------------layer 3-------------------
        | parent_cell_num =  1 | parent_cell_num =  4 |
        |                      |                      |
        | extrusion_height = 1 | extrusion_height = 1 |
        -------------------layer 2-------------------
        | parent_cell_num =  0 | parent_cell_num =  3 |
        |                      |                      |
        | extrusion_height = 0 | extrusion_height = 0 |
        -------------------layer 1-------------------
          base_cell_num = 0       base_cell_num = 1

    If we declare ``FunctionSpace(mx, "CG", 2)`` then the node numbering (i.e.
    Degree of Freedom/DoF numbering) is

    .. code-block:: text

        6 ---------13----------20---------27---------34
        |                       |                     |
        5          12          19         26         33
        |                       |                     |
        4 ---------11----------18---------25---------32
        |                       |                     |
        3          10          17         24         31
        |                       |                     |
        2 ---------9-----------16---------23---------30
        |                       |                     |
        1          8           15         22         29
        |                       |                     |
        0 ---------7-----------14---------21---------28
            base_cell_num = 0     base_cell_num = 1

    Cell node map values for an extruded mesh are indexed by the base cell
    number (rows) and the degree of freedom (DoF) index (columns). So
    ``extruded_cell_node_map.values[0] = [14, 15, 16,  0,  1,  2,  7,  8,  9]``
    are all the DoF/node numbers for the ``base_cell_num = 0``.
    Similarly
    ``extruded_cell_node_map.values[1] = [28, 29, 30, 21, 22, 23, 14, 15, 16]``
    contain all 9 of the DoFs for ``base_cell_num = 1``.
    To get the DoFs/nodes for the rest of the cells we need to include the
    ``extruded_cell_node_map.offset``, which tells us how far each cell's
    DoFs/nodes are translated up from the first layer to the second, and
    multiply these by the given ``extrusion_height``. So in our example
    ``extruded_cell_node_map.offset = [2, 2, 2, 2, 2, 2, 2, 2, 2]`` (we index
    this with the DoF/node index - it's an array because each DoF/node in the
    extruded mesh cell, in principle, can be offset upwards by a different
    amount).
    For ``base_cell_num = 0`` with ``extrusion_height = 1``
    (``parent_cell_num = 1``) we add ``1*2 = 2`` to each of the DoFs/nodes in
    ``extruded_cell_node_map.values[0]`` to get
    ``extruded_cell_node_map.values[0] + 1 * extruded_cell_node_map.offset[0] =
    [16, 17, 18,  2,  3,  4,  9, 10, 11]`` where ``0`` is the DoF/node index.

    For each cell (vertex) of a vertex only mesh immersed in a parent
    extruded mesh, we can get the corresponding ``base_cell_num`` and
    ``extrusion_height`` of the parent extruded mesh. Armed with this
    information we use the above to work out the corresponding DoFs/nodes on
    the parent extruded mesh.

    """
    if not isinstance(vertex_only_mesh.topology, VertexOnlyMeshTopology):
        raise TypeError("The input mesh must be a VertexOnlyMesh")
    cnm = extruded_cell_node_map
    vmx = vertex_only_mesh
    dofs_per_target_cell = cnm.arity
    base_cells = vmx.cell_parent_base_cell_list
    heights = vmx.cell_parent_extrusion_height_list
    assert cnm.values_with_halo.shape[1] == dofs_per_target_cell
    assert len(cnm.offset) == dofs_per_target_cell
    # One row per vertex-only-mesh cell: the base cell's nodes shifted up by
    # the extrusion height times the per-node layer offset.
    target_cell_parent_node_list = [
        cnm.values_with_halo[base_cell, :] + height * cnm.offset[:]
        for base_cell, height in zip(base_cells, heights)
    ]
    return op2.Map(
        vmx.cell_set, cnm.toset, dofs_per_target_cell, target_cell_parent_node_list
    )


class GlobalWrapper(object):
    """Wrapper object that fakes a Global to behave like a Function."""
    def __init__(self, glob):
        self.dat = glob
        self.cell_node_map = lambda *arguments: None
        self.ufl_domain = lambda: None


class VomOntoVomWrapper(object):
    """Utility class for interpolating from one ``VertexOnlyMesh`` to its
    input ordering ``VertexOnlyMesh``, or vice versa.

    Parameters
    ----------
    V : `.FunctionSpace`
        The P0DG function space (which may be vector or tensor valued) on the
        source vertex-only mesh.
    source_vom : `.VertexOnlyMesh`
        The vertex-only mesh we interpolate from.
    target_vom : `.VertexOnlyMesh`
        The vertex-only mesh we interpolate to.
    expr : `ufl.Expr`
        The expression to interpolate. If ``arguments`` is not empty, those
        arguments must be present within it.
    matfree : bool
        If ``False``, the matrix representing the permutation of the points is
        constructed and used to perform the interpolation. If ``True``, then
        the interpolation is performed using the broadcast and reduce
        operations on the PETSc Star Forest.

    Raises
    ------
    ValueError
        If neither mesh is the input ordering of the other.
    """

    def __init__(self, V, source_vom, target_vom, expr, matfree):
        arguments = extract_arguments(expr)
        reduce = False
        if source_vom.input_ordering is target_vom:
            reduce = True
            original_vom = source_vom
        elif target_vom.input_ordering is source_vom:
            original_vom = target_vom
        else:
            raise ValueError(
                "The target vom and source vom must be linked by input ordering!"
            )
        self.V = V
        self.source_vom = source_vom
        self.expr = expr
        self.arguments = arguments
        self.reduce = reduce
        # note that interpolation doesn't include halo cells
        self.dummy_mat = VomOntoVomDummyMat(
            original_vom.input_ordering_without_halos_sf, reduce, V, source_vom, expr, arguments
        )
        if matfree:
            # If matfree, we use the SF to perform the interpolation
            self.handle = self.dummy_mat._wrap_dummy_mat()
        else:
            # Otherwise we create the permutation matrix
            self.handle = self.dummy_mat._create_permutation_mat()

    @property
    def mpi_type(self):
        """
        The MPI type to use for the PETSc SF.

        Should correspond to the underlying data type of the PETSc Vec.
        """
        # Read from the same object the setter writes to. ``self.handle`` is
        # a PETSc Mat and carries no ``mpi_type`` attribute.
        return self.dummy_mat.mpi_type

    @mpi_type.setter
    def mpi_type(self, val):
        self.dummy_mat.mpi_type = val

    def forward_operation(self, target_dat):
        """Interpolate the stored expression and write the result into ``target_dat``."""
        coeff = self.dummy_mat.expr_as_coeff()
        with coeff.dat.vec_ro as coeff_vec, target_dat.vec_wo as target_vec:
            self.handle.mult(coeff_vec, target_vec)


class VomOntoVomDummyMat(object):
    """Dummy object to stand in for a PETSc ``Mat`` when we are interpolating
    between vertex-only meshes.

    Parameters
    ----------
    sf: PETSc.sf
        The PETSc Star Forest (SF) to use for the operation
    forward_reduce : bool
        If ``True``, the action of the operator (accessed via the `mult`
        method) is to perform a SF reduce from the source vec to the target
        vec, whilst the adjoint action (accessed via the `multHermitian`
        method) is to perform a SF broadcast from the source vec to the target
        vec. If ``False``, the opposite is true.
    V : `.FunctionSpace`
        The P0DG function space (which may be vector or tensor valued) on the
        source vertex-only mesh.
    source_vom : `.VertexOnlyMesh`
        The vertex-only mesh we interpolate from.
    expr : `ufl.Expr`
        The expression to interpolate. If ``arguments`` is not empty, those
        arguments must be present within it.
    arguments : list of `ufl.Argument`
        The arguments in the expression.
    """

    def __init__(self, sf, forward_reduce, V, source_vom, expr, arguments):
        self.sf = sf
        self.forward_reduce = forward_reduce
        self.V = V
        self.source_vom = source_vom
        self.expr = expr
        self.arguments = arguments
        # Calculate correct local and global sizes for the matrix
        nroots, leaves, _ = sf.getGraph()
        self.nleaves = len(leaves)
        self._local_sizes = V.comm.allgather(nroots)
        self.source_size = (self.V.block_size * nroots, self.V.block_size * sum(self._local_sizes))
        self.target_size = (
            self.V.block_size * self.nleaves,
            self.V.block_size * V.comm.allreduce(self.nleaves, op=MPI.SUM),
        )

    @property
    def mpi_type(self):
        """
        The MPI type to use for the PETSc SF.

        Should correspond to the underlying data type of the PETSc Vec.
        """
        return self._mpi_type

    @mpi_type.setter
    def mpi_type(self, val):
        self._mpi_type = val

    def expr_as_coeff(self, source_vec=None):
        """
        Return a coefficient that corresponds to the expression used at
        construction, where the expression has been interpolated into the P0DG
        function space on the source vertex-only mesh.

        If the expression contains an argument, its values are taken from
        ``source_vec``.

        Raises
        ------
        NotImplementedError
            If the expression contains more than one argument.
        ValueError
            If the expression contains an argument but ``source_vec`` is
            ``None``.
        """
        # Since we always output a coefficient when we don't have arguments in
        # the expression, we should evaluate the expression on the source mesh
        # so its dat can be sent to the target mesh.
        with stop_annotating():
            element = self.V.ufl_element()  # Could be vector/tensor valued
            P0DG = firedrake.FunctionSpace(self.source_vom, element)
            # if we have any arguments in the expression we need to replace
            # them with equivalent coefficients now
            coeff_expr = self.expr
            if len(self.arguments):
                if len(self.arguments) > 1:
                    raise NotImplementedError(
                        "Can only interpolate expressions with one argument!"
                    )
                if source_vec is None:
                    raise ValueError("Need to provide a source dat for the argument!")
                arg = self.arguments[0]
                arg_coeff = firedrake.Function(arg.function_space())
                arg_coeff.dat.data_wo[:] = source_vec.getArray(readonly=True).reshape(
                    arg_coeff.dat.data_wo.shape
                )
                coeff_expr = ufl.replace(self.expr, {arg: arg_coeff})
            coeff = firedrake.Function(P0DG).interpolate(coeff_expr)
        return coeff

    def reduce(self, source_vec, target_vec):
        """Perform an SF reduce (``MPI.REPLACE``) from ``source_vec`` into ``target_vec``."""
        source_arr = source_vec.getArray(readonly=True)
        target_arr = target_vec.getArray()
        self.sf.reduceBegin(
            self.mpi_type,
            source_arr,
            target_arr,
            MPI.REPLACE,
        )
        self.sf.reduceEnd(
            self.mpi_type,
            source_arr,
            target_arr,
            MPI.REPLACE,
        )

    def broadcast(self, source_vec, target_vec):
        """Perform an SF broadcast (``MPI.REPLACE``) from ``source_vec`` into ``target_vec``."""
        source_arr = source_vec.getArray(readonly=True)
        target_arr = target_vec.getArray()
        self.sf.bcastBegin(
            self.mpi_type,
            source_arr,
            target_arr,
            MPI.REPLACE,
        )
        self.sf.bcastEnd(
            self.mpi_type,
            source_arr,
            target_arr,
            MPI.REPLACE,
        )

    def mult(self, mat, source_vec, target_vec):
        """Apply the operator: evaluate the expression, then reduce or broadcast."""
        # need to evaluate expression before doing mult
        coeff = self.expr_as_coeff(source_vec)
        with coeff.dat.vec_ro as coeff_vec:
            if self.forward_reduce:
                self.reduce(coeff_vec, target_vec)
            else:
                self.broadcast(coeff_vec, target_vec)

    def multHermitian(self, mat, source_vec, target_vec):
        """Apply the adjoint operator; delegates to :meth:`multTranspose`."""
        self.multTranspose(mat, source_vec, target_vec)

    def multTranspose(self, mat, source_vec, target_vec):
        """Apply the transpose operator.

        Raises
        ------
        NotImplementedError
            If the expression is anything other than exactly one bare argument.
        """
        # can only do adjoint if our expression exclusively contains a
        # single argument, making the application of the adjoint operator
        # straightforward (haven't worked out how to do this otherwise!)
        if len(self.arguments) != 1:
            raise NotImplementedError(
                "Can only apply adjoint to expressions with one argument!"
            )
        if self.arguments[0] is not self.expr:
            raise NotImplementedError(
                "Can only apply adjoint to expressions consisting of a single argument at the moment."
            )
        if self.forward_reduce:
            self.broadcast(source_vec, target_vec)
        else:
            # We need to ensure the target vec is zeroed for SF Reduce to
            # represent multHermitian in case the interpolation matrix is not
            # square (in which case it will have columns which are zero). This
            # happens when we interpolate from an input-ordering vertex-only
            # mesh to an immersed vertex-only mesh where the input ordering
            # contains points that are not in the immersed mesh. The resulting
            # interpolation matrix will have columns of zeros for the points
            # that are not in the immersed mesh. The adjoint interpolation
            # matrix will then have rows of zeros for those points.
            target_vec.zeroEntries()
            self.reduce(source_vec, target_vec)

    def _create_permutation_mat(self):
        """Creates the PETSc matrix that represents the interpolation operator from a vertex-only mesh to
        its input ordering vertex-only mesh"""
        mat = PETSc.Mat().createAIJ((self.target_size, self.source_size), nnz=1, comm=self.V.comm)
        mat.setUp()
        start = sum(self._local_sizes[:self.V.comm.rank])
        end = start + self.source_size[0]
        contiguous_indices = numpy.arange(start, end, dtype=utils.IntType)
        perm = numpy.zeros(self.nleaves, dtype=utils.IntType)
        # NOTE(review): MPI.INT is used for arrays of dtype utils.IntType;
        # this presumably assumes IntType is a 32-bit C int - confirm.
        self.sf.bcastBegin(MPI.INT, contiguous_indices, perm, MPI.REPLACE)
        self.sf.bcastEnd(MPI.INT, contiguous_indices, perm, MPI.REPLACE)
        rows = numpy.arange(self.target_size[0] + 1, dtype=utils.IntType)
        cols = (self.V.block_size * perm[:, None] + numpy.arange(self.V.block_size, dtype=utils.IntType)[None, :]).reshape(-1)
        mat.setValuesCSR(rows, cols, numpy.ones_like(cols, dtype=utils.IntType))
        mat.assemble()
        if self.forward_reduce:
            mat.transpose()
        return mat

    def _wrap_dummy_mat(self):
        """Return a PETSc ``Mat`` of python type whose context is this object."""
        mat = PETSc.Mat().create(comm=self.V.comm)
        if self.forward_reduce:
            mat_size = (self.source_size, self.target_size)
        else:
            mat_size = (self.target_size, self.source_size)
        mat.setSizes(mat_size)
        mat.setType(mat.Type.PYTHON)
        mat.setPythonContext(self)
        mat.setUp()
        return mat

    def duplicate(self, mat=None, op=None):
        """Return a fresh python-type ``Mat`` wrapping this object."""
        return self._wrap_dummy_mat()


class MixedInterpolator(Interpolator):
    """A reusable interpolation object between MixedFunctionSpaces.

    Parameters
    ----------
    expr
        The underlying ufl.Interpolate or the operand to the ufl.Interpolate.
    V
        The :class:`.FunctionSpace` or :class:`.Function` to
        interpolate into.
    bcs
        A list of boundary conditions.
    **kwargs
        Any extra kwargs are passed on to the sub Interpolators.
        For details see :class:`firedrake.interpolation.Interpolator`.
    """

    def __init__(self, expr, V, bcs=None, **kwargs):
        super().__init__(expr, V, bcs=bcs, **kwargs)
        expr = self.ufl_interpolate
        self.arguments = expr.arguments()
        rank = len(self.arguments)

        # We need a Coargument in order to split the Interpolate
        needs_action = not any(isinstance(a, Coargument) for a in self.arguments)
        if needs_action:
            dual_arg, operand = expr.argument_slots()
            # Split the dual argument
            dual_split = dict(firedrake.formmanipulation.split_form(dual_arg))
            # Create the Jacobian to be split into blocks
            expr = expr._ufl_expr_reconstruct_(operand, V)

        Isub = {}
        # Split in the arguments of the Interpolate
        for indices, form in firedrake.formmanipulation.split_form(expr):
            if isinstance(form, ufl.ZeroBaseForm):
                # Ensure block sparsity
                continue
            vi, _ = form.argument_slots()
            Vtarget = vi.function_space().dual()
            if bcs and rank != 0:
                args = form.arguments()
                Vsource = args[1-vi.number()].function_space()
                sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}]
            else:
                sub_bcs = None
            if needs_action:
                # Take the action of each sub-cofunction against each block
                form = action(form, dual_split[indices[-1:]])

            Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs)

        self._sub_interpolators = Isub
        self.callable = self._assemble_matnest

    def __getitem__(self, item):
        return self._sub_interpolators[item]

    def __iter__(self):
        return iter(self._sub_interpolators)

    def _assemble_matnest(self):
        """Assemble the operator."""
        shape = tuple(len(a.function_space()) for a in self.arguments)
        blocks = numpy.full(shape, PETSc.Mat(), dtype=object)
        # Assemble the sparse block matrix
        for i in self:
            blocks[i] = self[i].callable().handle
        petscmat = PETSc.Mat().createNest(blocks)
        tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat)
        return tensor.M

    def _interpolate(self, *function, output=None, adjoint=False, **kwargs):
        """Assemble the action."""
        rank = len(self.arguments)
        if rank == 0:
            result = sum(self[i].assemble(**kwargs) for i in self)
            return output.assign(result) if output is not None else result

        if output is None:
            output = firedrake.Function(self.arguments[-1].function_space().dual())

        if rank == 1:
            for k, sub_tensor in enumerate(output.subfunctions):
                sub_tensor.assign(sum(self[i].assemble(**kwargs) for i in self if i[0] == k))
        elif rank == 2:
            for k, sub_tensor in enumerate(output.subfunctions):
                sub_tensor.assign(sum(self[i]._interpolate(*function, adjoint=adjoint, **kwargs)
                                      for i in self if i[0] == k))
        return output