# Source code for firedrake.preconditioners.patch

from firedrake.preconditioners.base import PCBase, SNESBase, PCSNESBase
from firedrake.preconditioners.asm import validate_overlap
from firedrake.petsc import PETSc
from firedrake.cython.patchimpl import set_patch_residual, set_patch_jacobian
from firedrake.solving_utils import _SNESContext
from firedrake.utils import complex_mode, IntType
from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx
from firedrake.interpolation import interpolate
from firedrake.ufl_expr import extract_domains

from collections import namedtuple
import operator
from itertools import chain
from functools import cached_property, partial
import numpy
from finat.ufl import VectorElement, MixedElement
from ufl.domain import extract_unique_domain
from tsfc.ufl_utils import extract_firedrake_constants
import weakref
import petsctools

import ctypes
from pyop2 import op2
import pyop2.types
from pyop2.compilation import load
from pyop2.codegen.builder import Pack, MatPack, DatPack
from pyop2.codegen.representation import Comparison, Literal
from pyop2.codegen.rep2loopy import register_petsc_function
from pyop2.global_kernel import compile_global_kernel
from pyop2.mpi import COMM_SELF

__all__ = ("PatchPC", "PlaneSmoother", "PatchSNES")


class DenseSparsity:
    """Minimal sparsity stand-in describing a single dense 1x1 block.

    Indexing the object with anything returns the object itself, and
    every membership test succeeds.

    :arg rset: row data set; its ``size`` gives the number of rows.
    :arg cset: column data set; its ``size`` gives the number of columns.
    """

    def __init__(self, rset, cset):
        block_dims = (((1, 1), ), )
        self.shape = (1, 1)
        self._nrows, self._ncols = rset.size, cset.size
        # ``dims`` and ``_dims`` deliberately refer to the same tuple.
        self._dims = block_dims
        self.dims = block_dims
        self.dsets = (rset, cset)

    def __getitem__(self, *args):
        # Any index selects the one and only block.
        return self

    def __contains__(self, *args):
        # Dense: every entry is present.
        return True


class LocalPack(Pack):
    """Pack whose loop indices are the entity index and the layer index."""

    def pick_loop_indices(self, loop_index, layer_index, entity_index):
        # ``loop_index`` is intentionally unused: the entity index takes
        # its place in the returned pair.
        chosen = (entity_index, layer_index)
        return chosen


class LocalMatPack(LocalPack, MatPack):
    """Matrix pack that always inserts with ``MatSetValues``."""

    # Both keys select the same PETSc insertion routine.
    insertion_names = dict.fromkeys((False, True), "MatSetValues")


class LocalMatKernelArg(op2.MatKernelArg):
    """Matrix kernel argument whose pack class is :class:`LocalMatPack`."""

    pack = LocalMatPack


class LocalMatLegacyArg(op2.MatLegacyArg):
    """Legacy matrix argument that converts to a :class:`LocalMatKernelArg`."""

    @property
    def global_kernel_arg(self):
        """The :class:`LocalMatKernelArg` built from this argument's maps."""
        kernel_maps = []
        for map_ in self.maps:
            kernel_maps.append(map_._global_kernel_arg)
        return LocalMatKernelArg(self.data.dims, kernel_maps)


class LocalMat(pyop2.types.AbstractMat):
    """Matrix placeholder with a dense sparsity over a single data set.

    :arg dset: the data set used for both rows and columns.
    """

    def __init__(self, dset):
        scalar_type = numpy.dtype(PETSc.ScalarType)
        self._sparsity = DenseSparsity(dset, dset)
        self.dtype = scalar_type

    def __call__(self, access, maps):
        """Return the legacy argument for this matrix.

        :arg access: the access descriptor.
        :arg maps: the maps used to index the matrix.
        """
        return LocalMatLegacyArg(self, maps, access)


class LocalDatPack(LocalPack, DatPack):
    """Dat pack that can optionally mask out negative map entries.

    :arg needs_mask: whether :meth:`_mask` should produce a mask.

    Remaining arguments are forwarded to the parent constructor.
    """

    def __init__(self, needs_mask, *args, **kwargs):
        # Set before delegating, exactly as the parent may rely on it.
        self.needs_mask = needs_mask
        super().__init__(*args, **kwargs)

    def _mask(self, map_):
        """Return the comparison ``map_ >= 0`` if masking is on, else None."""
        if not self.needs_mask:
            return None
        zero = Literal(numpy.int32(0))
        return Comparison(">=", map_, zero)


class LocalDatKernelArg(op2.DatKernelArg):
    """Dat kernel argument that remembers whether it must be masked.

    :kwarg needs_mask: forwarded to :class:`LocalDatPack` when packing.
    """

    def __init__(self, *args, needs_mask, **kwargs):
        super().__init__(*args, **kwargs)
        self.needs_mask = needs_mask

    @property
    def pack(self):
        """A :class:`LocalDatPack` constructor with ``needs_mask`` pre-bound."""
        mask_flag = self.needs_mask
        return partial(LocalDatPack, mask_flag)


class LocalDatLegacyArg(op2.DatLegacyArg):
    """Legacy dat argument that converts to a :class:`LocalDatKernelArg`."""

    @property
    def global_kernel_arg(self):
        """The :class:`LocalDatKernelArg` for this argument.

        The map argument is None when the dat is accessed without a map.
        """
        if self.map_ is None:
            map_arg = None
        else:
            map_arg = self.map_._global_kernel_arg
        return LocalDatKernelArg(
            self.data.dataset.dim,
            map_arg,
            needs_mask=self.data.needs_mask,
        )


class LocalDat(pyop2.types.AbstractDat):
    """Dat placeholder carrying only a data set, dtype, shape and mask flag.

    :arg dset: the data set the dat is defined on.
    :arg needs_mask: whether packing should mask negative map entries.
    """

    def __init__(self, dset, needs_mask=False):
        self._dataset = dset
        self.dtype = numpy.dtype(PETSc.ScalarType)
        leading = (dset.total_size,)
        # Scalar data sets contribute no trailing dimensions.
        if dset.cdim == 1:
            trailing = ()
        else:
            trailing = dset.dim
        self._shape = leading + trailing
        self.needs_mask = needs_mask

    @cached_property
    def _wrapper_cache_key_(self):
        # The mask flag is part of the key so masked and unmasked dats
        # do not share a cache entry.
        parent_key = super()._wrapper_cache_key_
        return parent_key + (self.needs_mask, )

    def __call__(self, access, map_=None):
        """Return the legacy argument for this dat.

        :arg access: the access descriptor.
        :arg map_: optional map used to index the dat.
        """
        return LocalDatLegacyArg(self, map_, access)

    def increment_dat_version(self):
        # Intentionally a no-op.
        pass


# LocalMatPack inserts values with MatSetValues; register that name with
# the code generator (presumably so generated code may call it -- confirm
# against pyop2.codegen.rep2loopy).
register_petsc_function("MatSetValues")


# Pairs the function pointer of a compiled global kernel (``funptr``) with
# the kernel info it was generated from (``kinfo``).
CompiledKernel = namedtuple('CompiledKernel', ["funptr", "kinfo"])


def get_map(V, base_mesh, base_integral_type):
    """Return the entity-node map of a function space for an integral type.

    :arg V: the function space; its ``topological`` counterpart is queried.
    :arg base_mesh: the mesh whose topology the map is taken over.
    :arg base_integral_type: the integral type selecting the entities.
    :returns: whatever ``entity_node_map`` returns for these arguments.
    """
    topological_space = V.topological
    return topological_space.entity_node_map(
        base_mesh.topology, base_integral_type, None, None
    )


def matrix_funptr(form, state):
    """Compile the integrals of a bilinear form into patch-local kernels.

    The form is compiled unsplit (``split=False``), keeping ``state``
    unsplit as well when it is given. The output matrix and the state
    are represented by :class:`LocalMat` / :class:`LocalDat` objects
    indexed through zero-filled placeholder maps, and each global kernel
    is built over an empty subset of the iteration set.

    :arg form: the bilinear form; test and trial spaces must match.
    :arg state: the coefficient of ``form`` to treat as the state, or None.
    :returns: a tuple of three lists of :class:`CompiledKernel`, for
        cell, interior facet and exterior facet integrals respectively.
    :raises NotImplementedError: if test and trial spaces differ, if an
        integral is restricted to a subdomain, or if the integral type is
        not cell, interior facet or exterior facet.
    :raises ValueError: if the state is split into indices other than
        ``(0, )``.
    """
    from firedrake.tsfc_interface import compile_form
    test, trial = (argument.function_space() for argument in form.arguments())
    if test != trial:
        raise NotImplementedError("Only for matching test and trial spaces")

    dont_split = () if state is None else (state, )

    compiled = compile_form(form, "subspace_form", split=False, dont_split=dont_split)

    all_meshes = extract_domains(form)
    cell_kernels, int_facet_kernels, ext_facet_kernels = [], [], []
    # Where the compiled kernel of each supported integral type is collected.
    destinations = {
        "cell": cell_kernels,
        "interior_facet": int_facet_kernels,
        "exterior_facet": ext_facet_kernels,
    }
    for kernel in compiled:
        kinfo = kernel.kinfo
        mesh = all_meshes[kinfo.domain_number]  # integration domain
        integral_type = kinfo.integral_type

        if kinfo.subdomain_id != ("otherwise",):
            raise NotImplementedError("Only for full domain integrals")
        if integral_type not in destinations:
            raise NotImplementedError("Only for cell, interior facet, or exterior facet integrals")
        destination = destinations[integral_type]

        # OK, now we've validated the kernel, let's build the callback
        args = []

        toset = op2.Set(1, comm=test.comm)
        dofset = op2.DataSet(toset, 1)
        test_map = get_map(test, mesh, integral_type)
        arity = sum(m.arity*s.cdim for m, s in zip(test_map, test.dof_dset))
        iterset = test_map.iterset

        def placeholder_map():
            # Zero-filled map with the right arity for the whole test space.
            zeros = numpy.zeros(iterset.total_size*arity, dtype=IntType)
            return op2.Map(iterset, toset, arity, values=zeros)

        def read_arg(function):
            # Read-only kernel argument for a function living on some mesh.
            entity_map = get_map(function.function_space(), mesh, integral_type)
            return function.dat(op2.READ, entity_map).global_kernel_arg

        entity_node_map = placeholder_map()
        mat = LocalMat(dofset)
        args.append(mat(op2.INC, (entity_node_map, entity_node_map)).global_kernel_arg)

        statedat = LocalDat(dofset)
        statearg = statedat(op2.READ, placeholder_map())

        active = kinfo.active_domain_numbers
        for i in active.coordinates:
            args.append(read_arg(all_meshes[i].coordinates))
        for i in active.cell_orientations:
            args.append(read_arg(all_meshes[i].cell_orientations()))
        for i in active.cell_sizes:
            args.append(read_arg(all_meshes[i].cell_sizes))

        for n, indices in kinfo.coefficient_numbers:
            coefficient = form.coefficients()[n]
            if coefficient is state:
                if indices != (0, ):
                    raise ValueError(f"Active indices of state (dont_split) function must be (0, ), not {indices}")
                args.append(statearg.global_kernel_arg)
                continue
            for ind in indices:
                piece = coefficient.subfunctions[ind]
                piece_map = get_map(piece.function_space(), mesh, integral_type)
                if piece.function_space().ufl_element().family() == "Real":
                    # Interior facet integrals double Real coefficients for the
                    # two sides of the facet, matching the TSFC-generated kernel.
                    piece_arg = op2.GlobalKernelArg(
                        (piece.function_space().block_size,), double=integral_type.startswith("interior_facet")
                    )
                else:
                    piece_arg = piece.dat(op2.READ, piece_map).global_kernel_arg
                args.append(piece_arg)

        all_constants = extract_firedrake_constants(form)
        for constant_index in kinfo.constant_numbers:
            args.append(all_constants[constant_index].dat(op2.READ).global_kernel_arg)

        # Facet integrals take the local facet numbers as a final argument.
        if integral_type == "interior_facet":
            args.append(mesh.interior_facets.local_facet_dat(op2.READ).global_kernel_arg)
        elif integral_type == "exterior_facet":
            args.append(mesh.exterior_facets.local_facet_dat(op2.READ).global_kernel_arg)
        iterset = op2.Subset(iterset, [])

        mod = op2.GlobalKernel(kinfo.kernel, args, subset=True)
        destination.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
    return cell_kernels, int_facet_kernels, ext_facet_kernels


def residual_funptr(form, state):
    """Compile the integrals of a linear form into patch-local kernels.

    The form is compiled unsplit (``split=False``), keeping ``state``
    unsplit as well when it is given. The output vector is a masked
    :class:`LocalDat` and the state an unmasked one, both indexed through
    zero-filled placeholder maps; each global kernel is built over an
    empty subset of the iteration set.

    :arg form: the linear (residual) form.
    :arg state: the coefficient of ``form`` to treat as the state, or
        None. When given it must live in the test space of ``form``.
    :returns: a tuple of three lists of :class:`CompiledKernel`, for
        cell, interior facet and exterior facet integrals respectively.
    :raises NotImplementedError: if the state and test spaces differ, if
        an integral is restricted to a subdomain, or if the integral type
        is not cell, interior facet or exterior facet.
    :raises ValueError: if the state is split into indices other than
        ``(0, )``.
    """
    from firedrake.tsfc_interface import compile_form
    test, = map(operator.methodcaller("function_space"), form.arguments())

    # Check for a missing state *before* dereferencing it, so that the
    # stateless branch is actually reachable.
    if state is None:
        dont_split = ()
    else:
        if state.function_space() != test:
            raise NotImplementedError("State and test space must be dual to one-another")
        dont_split = (state, )

    compiled_kernels = compile_form(form, "subspace_form", split=False, dont_split=dont_split)

    all_meshes = extract_domains(form)
    cell_kernels = []
    int_facet_kernels = []
    ext_facet_kernels = []
    for kernel in compiled_kernels:
        kinfo = kernel.kinfo
        mesh = all_meshes[kinfo.domain_number]  # integration domain
        integral_type = kinfo.integral_type

        if kinfo.subdomain_id != ("otherwise",):
            raise NotImplementedError("Only for full domain integrals")
        if integral_type not in {"cell", "interior_facet", "exterior_facet"}:
            raise NotImplementedError("Only for cell, interior facet, or exterior facet integrals")
        args = []

        # Pick the list this kernel is collected in (kept under its own
        # name so the sequence being iterated is never rebound).
        if integral_type == "cell":
            target_kernels = cell_kernels
        elif integral_type == "interior_facet":
            target_kernels = int_facet_kernels
        else:
            target_kernels = ext_facet_kernels

        toset = op2.Set(1, comm=test.comm)
        dofset = op2.DataSet(toset, 1)
        arity = sum(m.arity*s.cdim
                    for m, s in zip(get_map(test, mesh, integral_type),
                                    test.dof_dset))
        iterset = get_map(test, mesh, integral_type).iterset
        # Placeholder (zero-filled) map for the output vector.
        entity_node_map = op2.Map(iterset,
                                  toset, arity,
                                  values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        dat = LocalDat(dofset, needs_mask=True)

        # Placeholder (zero-filled) map for the state vector.
        statedat = LocalDat(dofset)
        state_entity_node_map = op2.Map(iterset,
                                        toset, arity,
                                        values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        statearg = statedat(op2.READ, state_entity_node_map)

        arg = dat(op2.INC, entity_node_map)
        args.append(arg.global_kernel_arg)

        for i in kinfo.active_domain_numbers.coordinates:
            c = all_meshes[i].coordinates
            arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type))
            args.append(arg.global_kernel_arg)
        for i in kinfo.active_domain_numbers.cell_orientations:
            c = all_meshes[i].cell_orientations()
            arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type))
            args.append(arg.global_kernel_arg)
        for i in kinfo.active_domain_numbers.cell_sizes:
            c = all_meshes[i].cell_sizes
            arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type))
            args.append(arg.global_kernel_arg)
        for n, indices in kinfo.coefficient_numbers:
            c = form.coefficients()[n]
            if c is state:
                if indices != (0, ):
                    raise ValueError(f"Active indices of state (dont_split) function must be (0, ), not {indices}")
                args.append(statearg.global_kernel_arg)
                continue
            for ind in indices:
                c_ = c.subfunctions[ind]
                map_ = get_map(c_.function_space(), mesh, integral_type)
                if c_.function_space().ufl_element().family() == "Real":
                    # Interior facet integrals double Real coefficients for the
                    # two sides of the facet, matching the TSFC-generated kernel.
                    arg = op2.GlobalKernelArg(
                        (c_.function_space().block_size,), double=integral_type.startswith("interior_facet")
                    )
                else:
                    arg = c_.dat(op2.READ, map_).global_kernel_arg
                args.append(arg)

        all_constants = extract_firedrake_constants(form)
        for constant_index in kinfo.constant_numbers:
            args.append(all_constants[constant_index].dat(op2.READ).global_kernel_arg)

        # Facet numbers come from the integration domain, the same mesh the
        # iteration set and every other map above are taken from.
        if integral_type == "interior_facet":
            arg = mesh.interior_facets.local_facet_dat(op2.READ)
            args.append(arg.global_kernel_arg)
        elif integral_type == "exterior_facet":
            arg = mesh.exterior_facets.local_facet_dat(op2.READ)
            args.append(arg.global_kernel_arg)
        iterset = op2.Subset(iterset, [])

        mod = op2.GlobalKernel(kinfo.kernel, args, subset=True)
        target_kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
    return cell_kernels, int_facet_kernels, ext_facet_kernels


# We need to set C function pointer callbacks for PCPatch to work.
# Although petsc4py provides a high-level Python wrapper for them,
# this is very costly when going back and forth from C to Python only
# to extract function pointers and send them straight back to C. Here,
# since we know what the calling convention of the C function is, we
# just wrap up everything as a C function pointer and use that
# directly.
def make_struct(op_coeffs, op_maps, jacobian=False):
    """Build the C ``UserCtx`` declaration, kernel call and ctypes mirror.

    Coefficient ``i`` is named ``c{i}`` and map ``i`` is named ``m{i}``.
    A None coefficient is named ``state`` and a None map is named
    ``dofArrayWithAll``; those are passed to the call directly instead of
    being stored in the context struct.

    :arg op_coeffs: coefficient arguments, with None marking the state.
    :arg op_maps: map arguments, with None marking the state's map.
    :arg jacobian: if True the output argument is ``Mat J``, otherwise
        ``PetscScalar * restrict F``.
    :returns: a tuple ``(struct, call, Struct)`` of the C typedef source,
        the C call expression (without the ``ctx->`` prefix on the
        function pointer) and the matching :class:`ctypes.Structure`
        subclass.
    """
    import ctypes

    coeff_names = ["state" if coeff is None else "c{}".format(pos)
                   for pos, coeff in enumerate(op_coeffs)]
    map_names = ["dofArrayWithAll" if map_ is None else "m{}".format(pos)
                 for pos, map_ in enumerate(op_maps)]
    # Only these names become members of the context struct.
    stored_coeffs = [name for name in coeff_names if name != "state"]
    stored_maps = [name for name in map_names if name != "dofArrayWithAll"]

    coeff_struct = ";\n".join("  const PetscScalar *{}".format(name) for name in stored_coeffs)
    map_struct = ";\n".join("  const PetscInt    *{}".format(name) for name in stored_maps)
    coeff_decl = ", ".join("const PetscScalar *restrict {}".format(name) for name in coeff_names)
    map_decl = ", ".join("const PetscInt *restrict {}".format(name) for name in map_names)
    coeff_call = ", ".join(name if name == "state" else "ctx->{}".format(name) for name in coeff_names)
    map_call = ", ".join(name if name == "dofArrayWithAll" else "ctx->{}".format(name) for name in map_names)

    out = "Mat J" if jacobian else "PetscScalar * restrict F"
    function = "  void (*pyop2_call)(int start, int end, const PetscInt * restrict cells, {}, {}, const PetscInt *restrict dofArray, {})".format(out, coeff_decl, map_decl)

    # The ctypes layout must list members in the same order as the C struct.
    fields = [(name, ctypes.c_voidp) for name in stored_coeffs + stored_maps]
    fields += [("point2facet", ctypes.c_voidp), ("pyop2_call", ctypes.c_voidp)]

    class Struct(ctypes.Structure):
        _fields_ = fields

    struct = """
typedef struct {{
{};
{};
  const PetscInt    *point2facet;
{};
}} UserCtx;""".format(coeff_struct, map_struct, function)
    call = "pyop2_call(0, npoints, whichPoints, out, {}, activeDofsArray, {})".format(coeff_call, map_call)

    return struct, call, Struct


def make_residual_wrapper(coeffs, maps, flops):
    """Generate the C ``ComputeResidual`` callback for a residual kernel.

    The generated function zeroes ``F``, optionally translates the patch
    points through ``ctx->point2facet`` (dropping points that map to a
    negative facet number, together with their dofs), and then invokes
    the compiled kernel through the ``pyop2_call`` pointer stored in the
    user context.

    :arg coeffs: coefficient arguments, with None marking the state.
    :arg maps: map arguments, with None marking the state's map.
    :arg flops: flop count per point, logged as ``flops * npoints``.
    :returns: a tuple ``(code, Struct)`` of the C source and the
        :class:`ctypes.Structure` subclass mirroring ``UserCtx``.
    """
    struct_decl, pyop2_call, struct = make_struct(coeffs, maps, jacobian=False)

    # NOTE: the error codes of both PetscFree calls are checked, like
    # every other PETSc call in this wrapper.
    return """
#include <petsc.h>
{}
static PetscInt pointbuf[128];
PetscErrorCode ComputeResidual(PC pc,
                               PetscInt point,
                               Vec x,
                               Vec F,
                               IS points,
                               PetscInt ndof,
                               const PetscInt *dofArray,
                               const PetscInt *dofArrayWithAll,
                               void *ctx_)
{{
   const PetscScalar *state          = NULL;
   const PetscInt    *whichPoints    = NULL;
   const PetscInt    *activeDofsArray = dofArray;
   PetscScalar       *out            = NULL;
   UserCtx           *ctx            = (UserCtx *)ctx_;
   PetscInt           npoints;
   PetscInt          *filtpoints     = NULL;
   PetscInt          *filtdofs       = NULL;
   PetscErrorCode     ierr;
   PetscFunctionBeginUser;
   ierr = ISGetSize(points, &npoints);CHKERRQ(ierr);
   if (!npoints) PetscFunctionReturn(0);
   ierr = VecSet(F, 0.0);CHKERRQ(ierr);
   if (x) {{
     ierr = VecGetArrayRead(x, &state);CHKERRQ(ierr);
   }}
   ierr = VecGetArray(F, &out);CHKERRQ(ierr);
   ierr = ISGetIndices(points, &whichPoints);CHKERRQ(ierr);
   if (ctx->point2facet) {{
     PetscInt nvalid = 0;
     PetscInt tDPP   = ndof / npoints;
     ierr = PetscMalloc1(npoints, &filtpoints);CHKERRQ(ierr);
     if (ndof > 0) {{ ierr = PetscMalloc1(ndof, &filtdofs);CHKERRQ(ierr); }}
     for (PetscInt i = 0; i < npoints; i++) {{
       PetscInt fi = ctx->point2facet[whichPoints[i]];
       if (fi >= 0) {{
         filtpoints[nvalid] = fi;
         for (PetscInt d = 0; d < tDPP; d++)
           filtdofs[nvalid * tDPP + d] = dofArray[i * tDPP + d];
         nvalid++;
       }}
     }}
     ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
     npoints        = nvalid;
     whichPoints    = filtpoints;
     activeDofsArray = filtdofs;
   }}
   if (npoints) ctx->{};
   if (ctx->point2facet) {{
     ierr = PetscFree(filtpoints);CHKERRQ(ierr);
     ierr = PetscFree(filtdofs);CHKERRQ(ierr);
   }} else {{
     ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
   }}
   ierr = VecRestoreArray(F, &out);CHKERRQ(ierr);
   if (x) {{
     ierr = VecRestoreArrayRead(x, &state);CHKERRQ(ierr);
   }}
   PetscLogFlops({} * npoints);
   PetscFunctionReturn(0);
}}
""".format(struct_decl, pyop2_call, flops), struct


def make_jacobian_wrapper(coeffs, maps, flops):
    """Generate the C ``ComputeJacobian`` callback for a Jacobian kernel.

    The generated function optionally translates the patch points
    through ``ctx->point2facet`` (dropping points that map to a negative
    facet number, together with their dofs), and then invokes the
    compiled kernel through the ``pyop2_call`` pointer stored in the user
    context, inserting into the matrix ``out``.

    :arg coeffs: coefficient arguments, with None marking the state.
    :arg maps: map arguments, with None marking the state's map.
    :arg flops: flop count per point, logged as ``flops * npoints``.
    :returns: a tuple ``(code, Struct)`` of the C source and the
        :class:`ctypes.Structure` subclass mirroring ``UserCtx``.
    """
    struct_decl, pyop2_call, struct = make_struct(coeffs, maps, jacobian=True)

    # NOTE: the error codes of both PetscFree calls are checked, like
    # every other PETSc call in this wrapper.
    return """
#include <petsc.h>
{}

static PetscInt pointbuf[128];
PetscErrorCode ComputeJacobian(PC pc,
                               PetscInt point,
                               Vec x,
                               Mat out,
                               IS points,
                               PetscInt ndof,
                               const PetscInt *dofArray,
                               const PetscInt *dofArrayWithAll,
                               void *ctx_)
{{
   const PetscScalar *state          = NULL;
   const PetscInt    *whichPoints    = NULL;
   const PetscInt    *activeDofsArray = dofArray;
   UserCtx           *ctx            = (UserCtx *)ctx_;
   PetscInt           npoints;
   PetscInt          *filtpoints     = NULL;
   PetscInt          *filtdofs       = NULL;
   PetscErrorCode     ierr;
   PetscFunctionBeginUser;
   ierr = ISGetSize(points, &npoints);CHKERRQ(ierr);
   if (!npoints) PetscFunctionReturn(0);
   if (x) {{
     ierr = VecGetArrayRead(x, &state);CHKERRQ(ierr);
   }}
   ierr = ISGetIndices(points, &whichPoints);CHKERRQ(ierr);
   if (ctx->point2facet) {{
     PetscInt nvalid = 0;
     PetscInt tDPP   = ndof / npoints;
     ierr = PetscMalloc1(npoints, &filtpoints);CHKERRQ(ierr);
     if (ndof > 0) {{ ierr = PetscMalloc1(ndof, &filtdofs);CHKERRQ(ierr); }}
     for (PetscInt i = 0; i < npoints; i++) {{
       PetscInt fi = ctx->point2facet[whichPoints[i]];
       if (fi >= 0) {{
         filtpoints[nvalid] = fi;
         for (PetscInt d = 0; d < tDPP; d++)
           filtdofs[nvalid * tDPP + d] = dofArray[i * tDPP + d];
         nvalid++;
       }}
     }}
     ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
     npoints        = nvalid;
     whichPoints    = filtpoints;
     activeDofsArray = filtdofs;
   }}
   if (npoints) ctx->{};
   if (ctx->point2facet) {{
     ierr = PetscFree(filtpoints);CHKERRQ(ierr);
     ierr = PetscFree(filtdofs);CHKERRQ(ierr);
   }} else {{
     ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
   }}
   if (x) {{
     ierr = VecRestoreArrayRead(x, &state);CHKERRQ(ierr);
   }}
   PetscLogFlops({} * npoints);
   PetscFunctionReturn(0);
}}
""".format(struct_decl, pyop2_call, flops), struct


def load_c_function(code, name, comm):
    """Compile C source against PETSc and return one of its functions.

    :arg code: the C source to compile.
    :arg name: the name of the function to look up in the compiled library.
    :arg comm: the communicator passed on to the compilation routine.
    :returns: the ctypes function object, with argument and return types
        set to match the ``ComputeResidual`` / ``ComputeJacobian``
        signature generated in this module.
    """
    include_flags = petsctools.get_petsc_dirs(prefix="-I", subdir="include")
    library_flags = petsctools.get_petsc_dirs(prefix="-L", subdir="lib")
    rpath_flags = petsctools.get_petsc_dirs(prefix="-Wl,-rpath,", subdir="lib")
    link_flags = (*library_flags, *rpath_flags, "-lpetsc", "-lm")

    library = load(code, "c", cppargs=include_flags, ldargs=link_flags, comm=comm)
    function = getattr(library, name)

    pointer = ctypes.c_voidp
    function.argtypes = [
        pointer,       # PC pc
        ctypes.c_int,  # PetscInt point
        pointer,       # Vec x
        pointer,       # Vec F / Mat out
        pointer,       # IS points
        ctypes.c_int,  # PetscInt ndof
        pointer,       # const PetscInt *dofArray
        pointer,       # const PetscInt *dofArrayWithAll
        pointer,       # void *ctx_
    ]
    function.restype = ctypes.c_int
    return function


def make_c_arguments(form, kernel, state, integral_type, require_state=False,
                     require_facet_number=False):
    """Collect the data and map arguments a compiled kernel is called with.

    Arguments are gathered in the order coordinates, cell orientations,
    cell sizes, form coefficients, constants and (optionally) local facet
    numbers. The state contributes a None entry to both lists; each
    distinct map argument is listed only once.

    :arg form: the form the kernel was compiled from.
    :arg kernel: the compiled kernel; only its ``kinfo`` is consulted.
    :arg state: the state coefficient, or None.
    :arg integral_type: the integral type used to look up the maps.
    :arg require_state: if True, assert that ``state`` occurs among the
        coefficients.
    :arg require_facet_number: if True, append the local facet numbers
        for interior or exterior facet integrals.
    :returns: a tuple ``(data_args, map_args)`` of lists.
    :raises ValueError: if the state is split into indices other than
        ``(0, )``.
    """
    all_meshes = extract_domains(form)
    kinfo = kernel.kinfo
    mesh = all_meshes[kinfo.domain_number]
    active = kinfo.active_domain_numbers

    # Geometric quantities first, then the coefficients of the form.
    coeffs = [all_meshes[i].coordinates for i in active.coordinates]
    coeffs += [all_meshes[i].cell_orientations() for i in active.cell_orientations]
    coeffs += [all_meshes[i].cell_sizes for i in active.cell_sizes]
    for n, indices in kinfo.coefficient_numbers:
        coefficient = form.coefficients()[n]
        if coefficient is not state:
            coeffs += [coefficient.subfunctions[ind] for ind in indices]
            continue
        if indices != (0, ):
            raise ValueError(f"Active indices of state (dont_split) function must be (0, ), not {indices}")
        coeffs.append(coefficient)
    if require_state:
        assert state in coeffs, "Couldn't find state vector in form coefficients"

    data_args = []
    map_args = []
    seen = set()
    for coefficient in coeffs:
        if coefficient is state:
            # Placeholders: the state and its map are supplied separately.
            data_args.append(None)
            map_args.append(None)
        else:
            data_args.extend(coefficient.dat._kernel_args_)
        map_ = get_map(coefficient.function_space(), mesh, integral_type)
        if map_ is None:
            continue
        for map_arg in map_._kernel_args_:
            if map_arg in seen:
                continue
            map_args.append(map_arg)
            seen.add(map_arg)

    all_constants = extract_firedrake_constants(form)
    for constant_index in kinfo.constant_numbers:
        data_args.extend(all_constants[constant_index].dat._kernel_args_)

    if require_facet_number and integral_type in ("interior_facet", "exterior_facet"):
        if integral_type == "interior_facet":
            facets = mesh.interior_facets
        else:
            facets = mesh.exterior_facets
        data_args.extend(facets.local_facet_dat._kernel_args_)
    return data_args, map_args


def make_c_struct(data_args, map_args, function, struct, point2facet=None):
    """Instantiate a context struct from kernel arguments.

    None entries in ``data_args`` and ``map_args`` are dropped; the
    remaining values are followed by the ``point2facet`` pointer (0 when
    not given) and the address of ``function``.

    :arg data_args: data arguments, possibly containing None.
    :arg map_args: map arguments, possibly containing None.
    :arg function: the kernel function, castable to a void pointer.
    :arg struct: the struct class to instantiate.
    :arg point2facet: optional pointer value for the ``point2facet`` slot.
    :returns: the populated ``struct`` instance.
    """
    values = []
    for candidate in chain(data_args, map_args):
        if candidate is not None:
            values.append(candidate)
    # A null pointer signals "no point-to-facet translation".
    values.append(0 if point2facet is None else point2facet)
    function_address = ctypes.cast(function, ctypes.c_voidp).value
    return struct(*values, function_address)


def bcdofs(bc, ghost=True):
    """Return the global dofs fixed by a DirichletBC.

    The numbering is the one given by concatenating all the subspaces of
    a mixed function space.

    :arg bc: the boundary condition.
    :arg ghost: if True, offsets use ``dof_count`` of the preceding
        subspaces and all nodes are kept; otherwise offsets use
        ``dof_dset.size * block_size`` and nodes at or beyond
        ``dof_dset.size`` are dropped.
    :returns: a numpy array of dof indices.
    :raises NotImplementedError: if an index is taken into a space that
        is neither vector- nor mixed-valued.
    """
    # Climb to the outermost function space.
    Z = bc.function_space()
    while Z.parent is not None:
        Z = Z.parent

    indices = bc._indices
    last = len(indices) - 1
    offset = 0

    # Walk back down, accumulating the offset of each selected subspace.
    for position, idx in enumerate(indices):
        element = Z.ufl_element()
        if isinstance(element, VectorElement):
            offset += idx
            assert position == last  # assert we're at the end of the chain
            assert Z.sub(idx).block_size == 1
        elif isinstance(element, MixedElement):
            preceding = [Z.sub(j) for j in range(idx)]
            if ghost:
                offset += sum(sub.dof_count for sub in preceding)
            else:
                offset += sum(sub.dof_dset.size * sub.block_size for sub in preceding)
        else:
            raise NotImplementedError("How are you taking a .sub?")

        Z = Z.sub(idx)

    parent = Z.parent
    if parent is not None and isinstance(parent.ufl_element(), VectorElement):
        # A single component of a vector space: stride by the parent's
        # block size and take just one component.
        bs = parent.block_size
        components = range(0, 1)
    else:
        bs = Z.block_size
        components = range(0, bs)

    nodes = bc.nodes
    if not ghost:
        nodes = nodes[nodes < Z.dof_dset.size]

    return numpy.concatenate([nodes*bs + j for j in components]) + offset


def select_entity(p, dm=None, exclude=None):
    """Decide whether an entity survives filtering by a label.

    :arg p: the entity.
    :arg dm: the DMPlex object to query for labels.
    :arg exclude: The label marking points to exclude.
    :returns: True if no label is given or the label does not mark
        ``p``, False otherwise."""
    if exclude is not None:
        # A label value of -1 means the point is not marked, so only
        # then do we keep it.
        return dm.getLabelValue(exclude, p) == -1
    return True


class PlaneSmoother(object):
    """Patch constructor that groups mesh entities into slabs along an axis.

    Calling an instance with a PC returns the patches and the iteration
    set built from the sweeps given by the
    ``pc_patch_construct_ps_sweeps`` option.
    """

    @staticmethod
    def coords(dm, p, coordinates):
        """Return the mean coordinate of the dofs in the closure of ``p``.

        :arg dm: the DMPlex object used to compute the closure.
        :arg p: the entity.
        :arg coordinates: the coordinate function.
        :returns: a numpy array of length ``gdim``.
        """
        coordinatesV = coordinates.function_space()
        data = coordinates.dat.data_ro_with_halos
        coordinatesDM = coordinatesV.dm
        coordinatesSection = coordinatesDM.getDefaultSection()

        # Only points that actually carry coordinate dofs contribute.
        closure_of_p = [x for x in dm.getTransitiveClosure(p, useCone=True)[0]
                        if coordinatesSection.getDof(x) > 0]

        gdim = data.shape[1]
        bary = numpy.zeros(gdim)
        ndof = 0
        for p_ in closure_of_p:
            (dof, offset) = (coordinatesSection.getDof(p_), coordinatesSection.getOffset(p_))
            bary += data[offset:offset + dof].reshape(dof, gdim).sum(axis=0)
            ndof += dof
        bary /= ndof
        return bary

    def sort_entities(self, dm, axis, dir, ndiv=None, divisions=None):
        """Sort the non-ghost entities along an axis and bin them.

        :arg dm: the DMPlex object; the mesh is read from its
            ``__firedrake_mesh__`` attribute.
        :arg axis: either an integer coordinate index or a callable
            mapping a coordinate to a scalar.
        :arg dir: +1 for ascending order, -1 for descending order.
        :arg ndiv: number of equispaced bins (if ``divisions`` is None).
        :arg divisions: explicit bin boundaries.
        :returns: a list of tuples of entities, one per bin plus a final
            tuple holding the remainder.
        :raises RuntimeError: if neither ``ndiv`` nor ``divisions`` is set.
        :raises NotImplementedError: for general mixed meshes.
        """
        # compute
        # [(pStart, (x, y, z)), (pEnd, (x, y, z))]
        from firedrake.assemble import assemble

        if ndiv is None and divisions is None:
            raise RuntimeError("Must either set ndiv or divisions for PlaneSmoother!")

        mesh = dm.getAttr("__firedrake_mesh__")
        if len(set(mesh)) == 1:
            mesh_unique = mesh.unique()
        else:
            raise NotImplementedError("Not implemented for general mixed meshes")
        coordinates = mesh_unique.coordinates
        V = coordinates.function_space()
        if V.finat_element.is_dg():
            # We're using DG or DQ for our coordinates, so we got
            # a periodic mesh. We need to interpolate to CGk
            # with access descriptor MAX to define a consistent opinion
            # about where the vertices are.
            CGk = V.reconstruct(family="Lagrange")
            coordinates = assemble(interpolate(coordinates, CGk, access=op2.MAX))

        select = partial(select_entity, dm=dm, exclude="pyop2_ghost")
        entities = [(p, self.coords(dm, p, coordinates)) for p in
                    filter(select, range(*dm.getChart()))]

        if isinstance(axis, int):
            minx = min(entities, key=lambda z: z[1][axis])[1][axis]
            maxx = max(entities, key=lambda z: z[1][axis])[1][axis]

            def keyfunc(z):
                coords = tuple(z[1])
                return (coords[axis], ) + tuple(coords[:axis] + coords[axis+1:])
        else:
            minx = axis(min(entities, key=lambda z: axis(z[1]))[1])
            maxx = axis(max(entities, key=lambda z: axis(z[1]))[1])

            def keyfunc(z):
                coords = tuple(z[1])
                return (axis(coords), ) + coords

        s = sorted(entities, key=keyfunc, reverse=(dir == -1))

        (entities, coords) = zip(*s)
        if isinstance(axis, int):
            coords = [c[axis] for c in coords]
        else:
            coords = [axis(c) for c in coords]

        if divisions is None:
            divisions = numpy.linspace(minx, maxx, ndiv+1)
        if ndiv is None:
            ndiv = numpy.size(divisions)-1
        indices = numpy.searchsorted(coords[::dir], divisions)

        out = []
        for k in range(ndiv):
            out.append(entities[indices[k]:indices[k+1]])
        out.append(entities[indices[-1]:])

        return out

    def __call__(self, pc):
        """Build the patches for ``pc`` from the configured sweeps.

        Each sweep in the colon-separated ``pc_patch_construct_ps_sweeps``
        option has the form ``<axis><+|-><divisions>``, where the axis
        and the divisions are either integers or keys into the
        application context.

        :arg pc: the PETSc PC whose DM and options prefix are used.
        :returns: a tuple ``(patches, iterationSet)`` of a list of index
            sets and a stride index set over them.
        :raises NotImplementedError: in complex mode.
        :raises ValueError: if the sweeps option is not set.
        :raises KeyError: if an axis or division key is missing from the
            application context.
        """
        if complex_mode:
            raise NotImplementedError("Sorry, plane smoothers not yet implemented in complex mode")
        dm = pc.getDM()
        context = dm.getAttr("__firedrake_ctx__")
        prefix = pc.getOptionsPrefix() or ""
        sentinel = object()
        sweeps = PETSc.Options(prefix).getString("pc_patch_construct_ps_sweeps", default=sentinel)
        if sweeps == sentinel:
            raise ValueError("Must set %spc_patch_construct_ps_sweeps" % prefix)

        patches = []
        import re
        for sweep in sweeps.split(':'):
            sweep_split = re.split(r'([+-])', sweep)
            try:
                axis = int(sweep_split[0])
            except ValueError:
                try:
                    axis = context.appctx[sweep_split[0]]
                except KeyError:
                    raise KeyError("PlaneSmoother axis key %s not provided" % sweep_split[0])

            dir = {'+': +1, '-': -1}[sweep_split[1]]
            # Either use equispaced bins for relaxation or get from appctx
            try:
                ndiv = int(sweep_split[2])
                entities = self.sort_entities(dm, axis, dir, ndiv=ndiv)
            except ValueError:
                try:
                    divisions = context.appctx[sweep_split[2]]
                    entities = self.sort_entities(dm, axis, dir, divisions=divisions)
                except KeyError:
                    raise KeyError("PlaneSmoother division key %s not provided" % sweep_split[2:])

            for patch in entities:
                # Empty bins do not become patches.
                if not patch:
                    continue
                else:
                    iset = PETSc.IS().createGeneral(patch, comm=COMM_SELF)
                    patches.append(iset)

        iterationSet = PETSc.IS().createStride(size=len(patches), first=0, step=1, comm=COMM_SELF)
        return (patches, iterationSet)
class PatchBase(PCSNESBase):
    """Shared implementation of the patch-based PC and SNES wrappers.

    Subclasses provide ``_petsc_prefix`` (the PETSc option prefix of the
    patch object, e.g. ``"pc_patch_"``) and ``configure_patch``, which
    transfers operators/functions from the outer object to the inner
    patch object before its type is set.
    """

    def initialize(self, obj):
        """Create and set up the inner PETSc patch object for ``obj``.

        The Jacobian form (and, for a SNES with a residual form on the
        context, the residual form) is compiled into C callbacks that are
        registered on the patch object for cells, interior facets and
        exterior facets as available.

        :arg obj: the outer ``PETSc.PC`` or ``PETSc.SNES``.
        :raises ValueError: if no suitable application context is attached
            to the DM of ``obj``.
        :raises NotImplementedError: for general mixed meshes and for
            extruded meshes.
        """
        ctx = get_appctx(obj.getDM())
        if ctx is None:
            raise ValueError("No context found on form")
        if not isinstance(ctx, _SNESContext):
            raise ValueError("Don't know how to get form from %r" % ctx)

        J, bcs = self.form(obj)
        V = J.arguments()[0].function_space()
        mesh = V.mesh()
        if len(set(mesh)) == 1:
            mesh_unique = mesh.unique()
        else:
            raise NotImplementedError("Not implemented for general mixed meshes")
        self.plex = mesh_unique.topology_dm
        # We need to attach the mesh and appctx to the plex, so that
        # PlaneSmoothers (and any other user-customised patch
        # constructors) can use firedrake's opinion of what
        # the coordinates are, rather than plex's.
        self.plex.setAttr("__firedrake_mesh__", weakref.proxy(mesh))
        self.ctx = ctx
        self.plex.setAttr("__firedrake_ctx__", weakref.proxy(ctx))

        if mesh_unique.cell_set._extruded:
            raise NotImplementedError("Not implemented on extruded meshes")

        # Validate the mesh overlap
        prefix = (obj.getOptionsPrefix() or "") + "patch_"
        opts = PETSc.Options(prefix)
        petsc_prefix = self._petsc_prefix
        patch_type = opts.getString(f"{petsc_prefix}construct_type")
        patch_dim = opts.getInt(f"{petsc_prefix}construct_dim", -1)
        patch_codim = opts.getInt(f"{petsc_prefix}construct_codim", -1)
        if patch_dim != -1:
            assert patch_codim == -1, "Cannot set both dim and codim"
        elif patch_codim != -1:
            assert patch_dim == -1, "Cannot set both dim and codim"
            patch_dim = self.plex.getDimension() - patch_codim
        else:
            patch_dim = 0
        validate_overlap(mesh_unique, patch_dim, patch_type)

        patch = obj.__class__().create(comm=mesh.comm)
        patch.setOptionsPrefix(prefix)
        self.configure_patch(patch, obj)
        patch.setType("patch")

        if isinstance(obj, PETSc.SNES):
            Jstate = ctx._problem.u
            is_snes = True
        else:
            Jstate = None
            is_snes = False

        if len(bcs) > 0:
            ghost_bc_nodes = numpy.unique(
                numpy.concatenate([bcdofs(bc, ghost=True) for bc in bcs], dtype=PETSc.IntType)
            )
            global_bc_nodes = numpy.unique(
                numpy.concatenate([bcdofs(bc, ghost=False) for bc in bcs], dtype=PETSc.IntType))
        else:
            ghost_bc_nodes = numpy.empty(0, dtype=PETSc.IntType)
            global_bc_nodes = numpy.empty(0, dtype=PETSc.IntType)

        # --- Jacobian callbacks -------------------------------------------
        Jcell_kernels, Jint_facet_kernels, Jext_facet_kernels = matrix_funptr(J, Jstate)
        Jhas_cell_kernel = len(Jcell_kernels) > 0
        if Jhas_cell_kernel:
            Jcell_kernel, = Jcell_kernels
            Jcell_flops = Jcell_kernel.kinfo.kernel.num_flops
            Jop_data_args, Jop_map_args = make_c_arguments(J, Jcell_kernel, Jstate, "cell")
            code, Struct = make_jacobian_wrapper(Jop_data_args, Jop_map_args, Jcell_flops)
            Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm)
            Jop_struct = make_c_struct(Jop_data_args, Jop_map_args, Jcell_kernel.funptr, Struct)

        Jhas_int_facet_kernel = False
        if len(Jint_facet_kernels) > 0:
            Jint_facet_kernel, = Jint_facet_kernels
            Jhas_int_facet_kernel = True
            Jint_facet_flops = Jint_facet_kernel.kinfo.kernel.num_flops
            facet_Jop_data_args, facet_Jop_map_args = make_c_arguments(J, Jint_facet_kernel, Jstate,
                                                                       "interior_facet",
                                                                       require_facet_number=True)
            code, Struct = make_jacobian_wrapper(facet_Jop_data_args, facet_Jop_map_args, Jint_facet_flops)
            facet_Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm)
            point2facet = mesh_unique.interior_facets.point2facetnumber.ctypes.data
            facet_Jop_struct = make_c_struct(facet_Jop_data_args, facet_Jop_map_args,
                                             Jint_facet_kernel.funptr, Struct,
                                             point2facet=point2facet)

        Jhas_ext_facet_kernel = False
        if len(Jext_facet_kernels) > 0:
            Jext_facet_kernel, = Jext_facet_kernels
            Jhas_ext_facet_kernel = True
            Jext_facet_flops = Jext_facet_kernel.kinfo.kernel.num_flops
            ext_facet_Jop_data_args, ext_facet_Jop_map_args = make_c_arguments(J, Jext_facet_kernel, Jstate,
                                                                               "exterior_facet",
                                                                               require_facet_number=True)
            code, Struct = make_jacobian_wrapper(ext_facet_Jop_data_args, ext_facet_Jop_map_args, Jext_facet_flops)
            ext_facet_Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm)
            ext_point2facet = mesh_unique.exterior_facets.point2facetnumber.ctypes.data
            ext_facet_Jop_struct = make_c_struct(ext_facet_Jop_data_args, ext_facet_Jop_map_args,
                                                 Jext_facet_kernel.funptr, Struct,
                                                 point2facet=ext_point2facet)

        # --- Residual callbacks (nonlinear case only) ---------------------
        set_residual = hasattr(ctx, "F") and isinstance(obj, PETSc.SNES)
        if set_residual:
            F = ctx.F
            Fstate = ctx._problem.u
            Fcell_kernels, Fint_facet_kernels, Fext_facet_kernels = residual_funptr(F, Fstate)

            Fhas_cell_kernel = len(Fcell_kernels) > 0
            if Fhas_cell_kernel:
                Fcell_kernel, = Fcell_kernels
                Fcell_flops = Fcell_kernel.kinfo.kernel.num_flops
                Fop_data_args, Fop_map_args = make_c_arguments(F, Fcell_kernel, Fstate,
                                                               "cell",
                                                               require_state=True)
                code, Struct = make_residual_wrapper(Fop_data_args, Fop_map_args, Fcell_flops)
                Fop_function = load_c_function(code, "ComputeResidual", mesh.comm)
                Fop_struct = make_c_struct(Fop_data_args, Fop_map_args, Fcell_kernel.funptr, Struct)

            Fhas_int_facet_kernel = False
            if len(Fint_facet_kernels) > 0:
                Fint_facet_kernel, = Fint_facet_kernels
                Fhas_int_facet_kernel = True
                Fint_facet_flops = Fint_facet_kernel.kinfo.kernel.num_flops
                facet_Fop_data_args, facet_Fop_map_args = make_c_arguments(F, Fint_facet_kernel, Fstate,
                                                                           "interior_facet",
                                                                           require_state=True,
                                                                           require_facet_number=True)
                # Use the residual wrapper, as the cell and exterior-facet
                # residual branches do: the function loaded below is
                # "ComputeResidual".
                code, Struct = make_residual_wrapper(facet_Fop_data_args, facet_Fop_map_args, Fint_facet_flops)
                facet_Fop_function = load_c_function(code, "ComputeResidual", mesh.comm)
                point2facet = extract_unique_domain(F).interior_facets.point2facetnumber.ctypes.data
                facet_Fop_struct = make_c_struct(facet_Fop_data_args, facet_Fop_map_args,
                                                 Fint_facet_kernel.funptr, Struct,
                                                 point2facet=point2facet)

            Fhas_ext_facet_kernel = False
            if len(Fext_facet_kernels) > 0:
                Fext_facet_kernel, = Fext_facet_kernels
                Fhas_ext_facet_kernel = True
                Fext_facet_flops = Fext_facet_kernel.kinfo.kernel.num_flops
                ext_facet_Fop_data_args, ext_facet_Fop_map_args = make_c_arguments(F, Fext_facet_kernel, Fstate,
                                                                                   "exterior_facet",
                                                                                   require_state=True,
                                                                                   require_facet_number=True)
                code, Struct = make_residual_wrapper(ext_facet_Fop_data_args, ext_facet_Fop_map_args, Fext_facet_flops)
                ext_facet_Fop_function = load_c_function(code, "ComputeResidual", mesh.comm)
                ext_point2facet = extract_unique_domain(F).exterior_facets.point2facetnumber.ctypes.data
                ext_facet_Fop_struct = make_c_struct(ext_facet_Fop_data_args, ext_facet_Fop_map_args,
                                                     Fext_facet_kernel.funptr, Struct,
                                                     point2facet=ext_point2facet)

        patch.setDM(self.plex)
        patch.setPatchCellNumbering(mesh_unique._cell_numbering)

        offsets = numpy.append([0], numpy.cumsum([W.dof_count for W in V])).astype(PETSc.IntType)
        patch.setPatchDiscretisationInfo([W.dm for W in V],
                                         numpy.array([W.block_size for W in V], dtype=PETSc.IntType),
                                         [W.cell_node_list for W in V],
                                         offsets,
                                         ghost_bc_nodes,
                                         global_bc_nodes)

        # Only the raw address of each struct is handed to the patch
        # object, so every struct is also stored on ``self`` to keep it
        # alive for as long as the callbacks may be invoked.
        if Jhas_cell_kernel:
            self.Jop_struct = Jop_struct
            set_patch_jacobian(patch, ctypes.cast(Jop_function, ctypes.c_voidp).value,
                               ctypes.addressof(Jop_struct), is_snes=is_snes)
        if Jhas_int_facet_kernel:
            self.facet_Jop_struct = facet_Jop_struct
            set_patch_jacobian(patch, ctypes.cast(facet_Jop_function, ctypes.c_voidp).value,
                               ctypes.addressof(facet_Jop_struct), is_snes=is_snes,
                               interior_facets=True)
        if Jhas_ext_facet_kernel:
            self.ext_facet_Jop_struct = ext_facet_Jop_struct
            set_patch_jacobian(patch, ctypes.cast(ext_facet_Jop_function, ctypes.c_voidp).value,
                               ctypes.addressof(ext_facet_Jop_struct), is_snes=is_snes,
                               exterior_facets=True)
        if set_residual:
            if Fhas_cell_kernel:
                self.Fop_struct = Fop_struct
                set_patch_residual(patch, ctypes.cast(Fop_function, ctypes.c_voidp).value,
                                   ctypes.addressof(Fop_struct), is_snes=is_snes)
            if Fhas_int_facet_kernel:
                # Keep the struct referenced, like every other callback
                # struct, so its address stays valid after this method returns.
                self.facet_Fop_struct = facet_Fop_struct
                set_patch_residual(patch, ctypes.cast(facet_Fop_function, ctypes.c_voidp).value,
                                   ctypes.addressof(facet_Fop_struct), is_snes=is_snes,
                                   interior_facets=True)
            if Fhas_ext_facet_kernel:
                self.ext_facet_Fop_struct = ext_facet_Fop_struct
                set_patch_residual(patch, ctypes.cast(ext_facet_Fop_function, ctypes.c_voidp).value,
                                   ctypes.addressof(ext_facet_Fop_struct), is_snes=is_snes,
                                   exterior_facets=True)

        patch.setPatchConstructType(PETSc.PC.PatchConstructType.PYTHON, operator=self.user_construction_op)
        patch.setAttr("ctx", ctx)
        patch.incrementTabLevel(1, parent=obj)
        patch.setFromOptions()
        patch.setUp()
        self.patch = patch

    def destroy(self, obj):
        """Remove the attributes set during ``initialize`` and destroy the patch.

        :arg obj: the outer PETSc object (unused).
        """
        # In this destructor we clean up the __firedrake_mesh__ we set
        # on the plex and the context we set on the patch object.
        # We have to check if these attributes are available because
        # the destroy function will be called by petsc4py when
        # PCPythonSetContext is called (which occurs before
        # initialize).
        if hasattr(self, "plex"):
            d = self.plex.getDict()
            try:
                del d["__firedrake_mesh__"]
            except KeyError:
                pass
        if hasattr(self, "patch"):
            try:
                del self.patch.getDict()["ctx"]
            except KeyError:
                pass
            self.patch.destroy()

    def user_construction_op(self, obj, *args, **kwargs):
        """Dispatch patch construction to a user-specified Python callable.

        The callable is named by the option
        ``<objectname>_patch_construct_python_type`` (looked up under the
        options prefix of ``obj``) as a dotted path ``module.name``.  If
        the named attribute is a class it is instantiated with no
        arguments first.

        :arg obj: the PETSc patch object; forwarded to the callable along
            with ``*args`` and ``**kwargs``.
        :returns: whatever the user callable returns.
        :raises ValueError: if the option is not set.
        """
        # Function-scope import keeps this block self-contained.
        import importlib

        prefix = obj.getOptionsPrefix() or ""
        sentinel = object()
        usercode = PETSc.Options(prefix).getString("%s_patch_construct_python_type" % self._objectname, default=sentinel)
        if usercode is sentinel:
            raise ValueError("Must set %s%s_patch_construct_python_type" % (prefix, self._objectname))

        (modname, funname) = usercode.rsplit('.', 1)
        # importlib.import_module returns the named (sub)module itself,
        # whereas __import__("pkg.sub") returns the top-level package, on
        # which the attribute lookup would fail for nested modules.
        mod = importlib.import_module(modname)
        fun = getattr(mod, funname)
        if isinstance(fun, type):
            fun = fun()
        return fun(obj, *args, **kwargs)

    def update(self, pc):
        """Re-run ``setUp`` on the inner patch object."""
        self.patch.setUp()

    def view(self, pc, viewer=None):
        """View the inner patch object with the given ``viewer``."""
        self.patch.view(viewer=viewer)
class PatchPC(PCBase, PatchBase):
    """Patch-based preconditioner; the inner patch object is a PC.

    The inner object is configured through options starting with
    ``pc_patch_``.
    """

    _petsc_prefix = "pc_patch_"

    def configure_patch(self, patch, pc):
        """Give ``patch`` the same operators as the outer ``pc``."""
        amat, pmat = pc.getOperators()
        patch.setOperators(amat, pmat)

    def apply(self, pc, x, y):
        """Apply the inner patch preconditioner to ``x``, writing into ``y``."""
        inner = self.patch
        inner.apply(x, y)

    def applyTranspose(self, pc, x, y):
        """Apply the transpose of the inner patch preconditioner to ``x``."""
        inner = self.patch
        inner.applyTranspose(x, y)
class PatchSNES(SNESBase, PatchBase):
    """Patch-based nonlinear solver; the inner patch object is a SNES.

    The inner object is configured through options starting with
    ``snes_patch_``.
    """

    _petsc_prefix = "snes_patch_"

    def configure_patch(self, patch, snes):
        """Configure the inner ``patch`` SNES from the outer ``snes``.

        The inner solver is limited to one iteration with the convergence
        test skipped, and reuses the residual callback of ``snes`` with a
        duplicate of its residual vector.
        """
        patch.setTolerances(max_it=1)
        patch.setConvergenceTest("skip")

        (f, residual) = snes.getFunction()
        assert residual is not None
        (fun, args, kargs) = residual
        patch.setFunction(fun, f.duplicate(), args=args, kargs=kargs)

        # Need an empty RHS for the solve,
        # PCApply can't deal with RHS = NULL,
        # and this goes through a call to PCApply at some point
        self.dummy = f.duplicate()

    def step(self, snes, x, f, y):
        """Compute one update ``y`` from the current iterate ``x``.

        ``y`` is set to ``x - (patch solution started from x)`` and the
        converged reason of the inner solver is copied to ``snes``.  The
        application context is pushed onto the plex for the duration of
        the solve and is always popped again, even if the solve raises.
        """
        push_appctx(self.plex, self.ctx)
        try:
            x.copy(y)
            self.patch.solve(snes.vec_rhs or self.dummy, y)
            # y currently holds the patch solution; turn it into x - y.
            y.axpy(-1, x)
            y.scale(-1)
            snes.setConvergedReason(self.patch.getConvergedReason())
        finally:
            # Never leave the context pushed on the DM.
            pop_appctx(self.plex)