# Source code for firedrake.preconditioners.patch

from firedrake.preconditioners.base import PCBase, SNESBase, PCSNESBase
from firedrake.petsc import PETSc
from firedrake.cython.patchimpl import set_patch_residual, set_patch_jacobian
from firedrake.solving_utils import _SNESContext
from firedrake.utils import cached_property, complex_mode, IntType
from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx
from firedrake.interpolation import Interpolate

from collections import namedtuple
import operator
from itertools import chain
from functools import partial
import numpy
from finat.ufl import VectorElement, MixedElement
from ufl.domain import extract_unique_domain
from tsfc.ufl_utils import extract_firedrake_constants
import weakref

import ctypes
from pyop2 import op2
import pyop2.types
from pyop2.compilation import load
from pyop2.codegen.builder import Pack, MatPack, DatPack
from pyop2.codegen.representation import Comparison, Literal
from pyop2.codegen.rep2loopy import register_petsc_function
from pyop2.global_kernel import compile_global_kernel
from pyop2.mpi import COMM_SELF
from pyop2.utils import get_petsc_dir

__all__ = ("PatchPC", "PlaneSmoother", "PatchSNES")


class DenseSparsity(object):
    """Stand-in sparsity object describing a single ``1 x 1`` block.

    Indexing an instance returns the instance itself and every
    membership test succeeds.

    :arg rset: the row data set; only its ``size`` is read.
    :arg cset: the column data set; only its ``size`` is read.
    """

    def __init__(self, rset, cset):
        self.dsets = (rset, cset)
        self._nrows, self._ncols = rset.size, cset.size
        self.shape = (1, 1)
        # ``dims`` and ``_dims`` refer to the very same tuple.
        self._dims = (((1, 1), ), )
        self.dims = self._dims

    def __getitem__(self, *args):
        # Whatever block is asked for, hand back this object.
        return self

    def __contains__(self, *args):
        return True


class LocalPack(Pack):
    """Pack whose loop indices are the entity index and the layer index."""

    def pick_loop_indices(self, loop_index, layer_index, entity_index):
        # ``loop_index`` is not used: the pair is always (entity, layer).
        return entity_index, layer_index


class LocalMatPack(LocalPack, MatPack):
    """Matrix pack using ``MatSetValues`` for both insertion-name keys."""

    # Both the ``False`` and the ``True`` key map to the same routine.
    insertion_names = dict.fromkeys((False, True), "MatSetValues")


class LocalMatKernelArg(op2.MatKernelArg):
    """Matrix kernel argument that is packed with :class:`LocalMatPack`."""

    pack = LocalMatPack


class LocalMatLegacyArg(op2.MatLegacyArg):
    """Legacy matrix argument that produces a :class:`LocalMatKernelArg`."""

    @property
    def global_kernel_arg(self):
        """The :class:`LocalMatKernelArg` built from this argument's data
        dimensions and the kernel arguments of its maps."""
        kernel_maps = list(map(operator.attrgetter("_global_kernel_arg"), self.maps))
        return LocalMatKernelArg(self.data.dims, kernel_maps)


class LocalMat(pyop2.types.AbstractMat):
    """Matrix placeholder carrying a :class:`DenseSparsity` and a scalar dtype.

    :arg dset: data set used for both the rows and the columns.
    """

    def __init__(self, dset):
        self.dtype = numpy.dtype(PETSc.ScalarType)
        self._sparsity = DenseSparsity(dset, dset)

    def __call__(self, access, maps):
        """Return a :class:`LocalMatLegacyArg` for this matrix.

        :arg access: the access descriptor.
        :arg maps: the maps to access the matrix through.
        """
        legacy_arg = LocalMatLegacyArg(self, maps, access)
        return legacy_arg


class LocalDatPack(LocalPack, DatPack):
    """Dat pack that can mask out map entries that are negative.

    :arg needs_mask: whether :meth:`_mask` should produce a mask.

    All remaining arguments are forwarded to the parent constructor.
    """

    def __init__(self, needs_mask, *args, **kwargs):
        self.needs_mask = needs_mask
        super().__init__(*args, **kwargs)

    def _mask(self, map_):
        """Return the expression ``map_ >= 0`` if masking is enabled, else ``None``."""
        if not self.needs_mask:
            return None
        zero = Literal(numpy.int32(0))
        return Comparison(">=", map_, zero)


class LocalDatKernelArg(op2.DatKernelArg):
    """Dat kernel argument that remembers whether its pack needs a mask.

    :kwarg needs_mask: stored on the instance and forwarded to
        :class:`LocalDatPack` by :attr:`pack`.

    All other arguments are forwarded to the parent constructor.
    """

    def __init__(self, *args, needs_mask, **kwargs):
        super().__init__(*args, **kwargs)
        self.needs_mask = needs_mask

    @property
    def pack(self):
        # A callable that constructs a LocalDatPack with ``needs_mask``
        # already supplied as its first argument.
        return partial(LocalDatPack, self.needs_mask)


class LocalDatLegacyArg(op2.DatLegacyArg):
    """Legacy dat argument that produces a :class:`LocalDatKernelArg`."""

    @property
    def global_kernel_arg(self):
        """The :class:`LocalDatKernelArg` for this argument.

        The map kernel argument is ``None`` when no map is attached.
        """
        if self.map_ is None:
            kernel_map = None
        else:
            kernel_map = self.map_._global_kernel_arg
        return LocalDatKernelArg(self.data.dataset.dim, kernel_map,
                                 needs_mask=self.data.needs_mask)


class LocalDat(pyop2.types.AbstractDat):
    """Dat placeholder that records a data set, dtype, shape and mask flag.

    :arg dset: the data set this dat is defined on.
    :arg needs_mask: whether packing of this dat should apply a mask.
    """

    def __init__(self, dset, needs_mask=False):
        self._dataset = dset
        self.dtype = numpy.dtype(PETSc.ScalarType)
        leading = (dset.total_size,)
        # Only data sets with ``cdim != 1`` contribute trailing extents.
        if dset.cdim == 1:
            trailing = ()
        else:
            trailing = dset.dim
        self._shape = leading + trailing
        self.needs_mask = needs_mask

    @cached_property
    def _wrapper_cache_key_(self):
        # Extend the parent's key with the mask flag.
        parent_key = super()._wrapper_cache_key_
        return parent_key + (self.needs_mask, )

    def __call__(self, access, map_=None):
        """Return a :class:`LocalDatLegacyArg` for this dat.

        :arg access: the access descriptor.
        :arg map_: optional map to access the dat through.
        """
        return LocalDatLegacyArg(self, map_, access)

    def increment_dat_version(self):
        # Intentionally a no-op.
        pass


# Register ``MatSetValues`` (the routine named in
# ``LocalMatPack.insertion_names``) with PyOP2's loopy code generation.
register_petsc_function("MatSetValues")


# Pairs a compiled wrapper (``funptr``) with the kernel info (``kinfo``)
# of the kernel it was compiled from.
CompiledKernel = namedtuple('CompiledKernel', ["funptr", "kinfo"])


def matrix_funptr(form, state):
    """Compile the kernels of a bilinear form into patch-assembly wrappers.

    :arg form: the bilinear form; its test and trial spaces must match.
    :arg state: the coefficient that must not be split (it is passed to
        the wrapper through a :class:`LocalDat`), or ``None``.
    :returns: a pair ``(cell_kernels, int_facet_kernels)`` of lists of
        :class:`CompiledKernel`.
    :raises NotImplementedError: if test and trial spaces differ, if an
        integral is not over the full domain, or if an integral is neither
        a cell nor an interior facet integral.
    :raises ValueError: if the active indices of ``state`` are not ``(0, )``.
    """
    from firedrake.tsfc_interface import compile_form
    test, trial = map(operator.methodcaller("function_space"), form.arguments())
    if test != trial:
        raise NotImplementedError("Only for matching test and trial spaces")

    dont_split = () if state is None else (state, )
    compiled = compile_form(form, "subspace_form", split=False, dont_split=dont_split)

    cell_kernels = []
    int_facet_kernels = []
    # integral type -> (node-map method name, list collecting the results)
    dispatch = {
        "cell": ("cell_node_map", cell_kernels),
        "interior_facet": ("interior_facet_node_map", int_facet_kernels),
    }
    for kernel in compiled:
        kinfo = kernel.kinfo

        if kinfo.subdomain_id != ("otherwise",):
            raise NotImplementedError("Only for full domain integrals")
        if kinfo.integral_type not in {"cell", "interior_facet"}:
            raise NotImplementedError("Only for cell or interior facet integrals")

        # The kernel is valid: assemble the argument list for its wrapper.
        map_name, collected = dispatch[kinfo.integral_type]
        get_map = operator.methodcaller(map_name)

        toset = op2.Set(1, comm=test.comm)
        dofset = op2.DataSet(toset, 1)
        arity = sum(m.arity*s.cdim
                    for m, s in zip(get_map(test),
                                    test.dof_dset))
        iterset = get_map(test).iterset
        # Zero-valued placeholder map with one row per iteration entity.
        entity_node_map = op2.Map(iterset,
                                  toset, arity,
                                  values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        mat = LocalMat(dofset)

        # The output matrix comes first in the argument list.
        args = [mat(op2.INC, (entity_node_map, entity_node_map))]
        statedat = LocalDat(dofset)
        state_entity_node_map = op2.Map(iterset,
                                        toset, arity,
                                        values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        statearg = statedat(op2.READ, state_entity_node_map)

        mesh = form.ufl_domains()[kinfo.domain_number]
        args.append(mesh.coordinates.dat(op2.READ, get_map(mesh.coordinates)))
        if kinfo.oriented:
            orientations = mesh.cell_orientations()
            args.append(orientations.dat(op2.READ, get_map(orientations)))
        if kinfo.needs_cell_sizes:
            sizes = mesh.cell_sizes
            args.append(sizes.dat(op2.READ, get_map(sizes)))
        for n, indices in kinfo.coefficient_numbers:
            coefficient = form.coefficients()[n]
            if coefficient is state:
                if indices != (0, ):
                    raise ValueError(f"Active indices of state (dont_split) function must be (0, ), not {indices}")
                args.append(statearg)
            else:
                for ind in indices:
                    sub = coefficient.subfunctions[ind]
                    sub_map = get_map(sub)
                    args.append(sub.dat(op2.READ, sub_map))

        all_constants = extract_firedrake_constants(form)
        for constant_index in kinfo.constant_numbers:
            args.append(all_constants[constant_index].dat(op2.READ))

        if kinfo.integral_type == "interior_facet":
            args.append(mesh.interior_facets.local_facet_dat(op2.READ))
        # The wrapper is compiled for iteration over a subset.
        empty_subset = op2.Subset(iterset, [])

        wrapper_knl_args = tuple(a.global_kernel_arg for a in args)
        mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True)
        collected.append(CompiledKernel(compile_global_kernel(mod, empty_subset.comm), kinfo))
    return cell_kernels, int_facet_kernels


def residual_funptr(form, state):
    """Compile the kernels of a linear form into patch-residual wrappers.

    :arg form: the residual form (exactly one argument).
    :arg state: the coefficient that must not be split; its function
        space must equal the test space of ``form``.
    :returns: a pair ``(cell_kernels, int_facet_kernels)`` of lists of
        :class:`CompiledKernel`.
    :raises NotImplementedError: if state and test space differ, if an
        integral is not over the full domain, or if an integral is neither
        a cell nor an interior facet integral.
    :raises ValueError: if the active indices of ``state`` are not ``(0, )``.
    """
    from firedrake.tsfc_interface import compile_form
    test, = map(operator.methodcaller("function_space"), form.arguments())

    # NOTE(review): ``state`` is dereferenced here, before the ``None``
    # test below, so passing ``None`` fails on this line.
    if state.function_space() != test:
        raise NotImplementedError("State and test space must be dual to one-another")

    dont_split = (state, ) if state is not None else ()
    compiled = compile_form(form, "subspace_form", split=False, dont_split=dont_split)

    cell_kernels = []
    int_facet_kernels = []
    # integral type -> (node-map method name, list collecting the results)
    dispatch = {
        "cell": ("cell_node_map", cell_kernels),
        "interior_facet": ("interior_facet_node_map", int_facet_kernels),
    }
    for kernel in compiled:
        kinfo = kernel.kinfo

        if kinfo.subdomain_id != ("otherwise",):
            raise NotImplementedError("Only for full domain integrals")
        if kinfo.integral_type not in {"cell", "interior_facet"}:
            raise NotImplementedError("Only for cell integrals or interior_facet integrals")

        map_name, collected = dispatch[kinfo.integral_type]
        get_map = operator.methodcaller(map_name)

        toset = op2.Set(1, comm=test.comm)
        dofset = op2.DataSet(toset, 1)
        arity = sum(m.arity*s.cdim
                    for m, s in zip(get_map(test),
                                    test.dof_dset))
        iterset = get_map(test).iterset
        # Zero-valued placeholder map with one row per iteration entity.
        entity_node_map = op2.Map(iterset,
                                  toset, arity,
                                  values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        # The output dat is the one that is packed with a mask.
        dat = LocalDat(dofset, needs_mask=True)

        statedat = LocalDat(dofset)
        state_entity_node_map = op2.Map(iterset,
                                        toset, arity,
                                        values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        statearg = statedat(op2.READ, state_entity_node_map)

        # The output dat comes first in the argument list.
        args = [dat(op2.INC, entity_node_map)]

        mesh = form.ufl_domains()[kinfo.domain_number]
        args.append(mesh.coordinates.dat(op2.READ, get_map(mesh.coordinates)))

        if kinfo.oriented:
            orientations = mesh.cell_orientations()
            args.append(orientations.dat(op2.READ, get_map(orientations)))
        if kinfo.needs_cell_sizes:
            sizes = mesh.cell_sizes
            args.append(sizes.dat(op2.READ, get_map(sizes)))
        for n, indices in kinfo.coefficient_numbers:
            coefficient = form.coefficients()[n]
            if coefficient is state:
                if indices != (0, ):
                    raise ValueError(f"Active indices of state (dont_split) function must be (0, ), not {indices}")
                args.append(statearg)
            else:
                for ind in indices:
                    sub = coefficient.subfunctions[ind]
                    sub_map = get_map(sub)
                    args.append(sub.dat(op2.READ, sub_map))

        all_constants = extract_firedrake_constants(form)
        for constant_index in kinfo.constant_numbers:
            args.append(all_constants[constant_index].dat(op2.READ))

        if kinfo.integral_type == "interior_facet":
            facets = extract_unique_domain(test).interior_facets
            args.append(facets.local_facet_dat(op2.READ))
        # The wrapper is compiled for iteration over a subset.
        empty_subset = op2.Subset(iterset, [])

        wrapper_knl_args = tuple(a.global_kernel_arg for a in args)
        mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True)
        collected.append(CompiledKernel(compile_global_kernel(mod, empty_subset.comm), kinfo))
    return cell_kernels, int_facet_kernels


# We need to set C function pointer callbacks for PCPatch to work.
# Although petsc4py provides a high-level Python wrapper for them,
# this is very costly when going back and forth from C to Python only
# to extract function pointers and send them straight back to C. Here,
# since we know what the calling convention of the C function is, we
# just wrap up everything as a C function pointer and use that
# directly.
def make_struct(op_coeffs, op_maps, jacobian=False):
    """Build the C ``UserCtx`` declaration, the wrapper call and a ctypes mirror.

    Entries of ``op_coeffs``/``op_maps`` that are ``None`` are not stored
    in the context: in the generated call they are replaced by the local
    C names ``state`` and ``dofArrayWithAll`` respectively. Every other
    entry at position ``i`` becomes the context member ``c<i>`` (or
    ``m<i>``).

    :arg op_coeffs: the coefficient arguments.
    :arg op_maps: the map arguments.
    :arg jacobian: if true the output argument is declared as ``Mat J``,
        otherwise as ``PetscScalar * restrict F``.
    :returns: a triple ``(struct, call, Struct)``: the C source of the
        ``UserCtx`` typedef, the C call expression for ``pyop2_call`` and
        a :class:`ctypes.Structure` subclass with the matching fields.
    """
    import ctypes

    coeffs = ["state" if c is None else "c{}".format(i)
              for i, c in enumerate(op_coeffs)]
    maps = ["dofArrayWithAll" if m is None else "m{}".format(i)
            for i, m in enumerate(op_maps)]
    # Names that live in the context struct (everything but the placeholders).
    ctx_coeffs = [c for c in coeffs if c != "state"]
    ctx_maps = [m for m in maps if m != "dofArrayWithAll"]

    coeff_struct = ";\n".join("  const PetscScalar *" + c for c in ctx_coeffs)
    map_struct = ";\n".join("  const PetscInt    *" + m for m in ctx_maps)
    coeff_decl = ", ".join("const PetscScalar *restrict " + c for c in coeffs)
    map_decl = ", ".join("const PetscInt *restrict " + m for m in maps)
    coeff_call = ", ".join(c if c == "state" else "ctx->" + c for c in coeffs)
    map_call = ", ".join(m if m == "dofArrayWithAll" else "ctx->" + m for m in maps)

    out = "Mat J" if jacobian else "PetscScalar * restrict F"
    function = ("  void (*pyop2_call)(int start, int end, const PetscInt * restrict cells, "
                + out + ", " + coeff_decl
                + ", const PetscInt *restrict dofArray, " + map_decl + ")")

    # Field order must mirror the member order of the C struct below.
    fields = [(name, ctypes.c_voidp) for name in ctx_coeffs]
    fields += [(name, ctypes.c_voidp) for name in ctx_maps]
    fields += [("point2facet", ctypes.c_voidp), ("pyop2_call", ctypes.c_voidp)]

    class Struct(ctypes.Structure):
        _fields_ = fields

    struct = "\n".join([
        "",
        "typedef struct {",
        coeff_struct + ";",
        map_struct + ";",
        "  const PetscInt    *point2facet;",
        function + ";",
        "} UserCtx;",
    ])
    call = f"pyop2_call(0, npoints, whichPoints, out, {coeff_call}, dofArray, {map_call})"

    return struct, call, Struct


def make_residual_wrapper(coeffs, maps, flops):
    """Generate the C source of the ``ComputeResidual`` callback.

    :arg coeffs: coefficient arguments, passed on to :func:`make_struct`.
    :arg maps: map arguments, passed on to :func:`make_struct`.
    :arg flops: flop count per point, multiplied by the number of points
        in the ``PetscLogFlops`` call of the generated code.
    :returns: a pair ``(code, struct)`` of the C source and the
        :class:`ctypes.Structure` subclass describing ``UserCtx``.
    """
    struct_decl, pyop2_call, struct = make_struct(coeffs, maps, jacobian=False)

    # The generated function zeroes F, translates points through
    # ``ctx->point2facet`` when that pointer is set (using the static
    # ``pointbuf`` for up to 128 points and a heap buffer beyond that),
    # and then calls the PyOP2 wrapper stored in the context.
    # NOTE(review): the return code of PetscFree is assigned to ``ierr``
    # but not checked in the generated code.
    template = """
#include <petsc.h>
{}
static PetscInt pointbuf[128];
PetscErrorCode ComputeResidual(PC pc,
                               PetscInt point,
                               Vec x,
                               Vec F,
                               IS points,
                               PetscInt ndof,
                               const PetscInt *dofArray,
                               const PetscInt *dofArrayWithAll,
                               void *ctx_)
{{
   const PetscScalar *state       = NULL;
   const PetscInt    *whichPoints = NULL;
   PetscScalar       *out         = NULL;
   UserCtx           *ctx         = (UserCtx *)ctx_;
   PetscInt           npoints;
   PetscErrorCode     ierr;
   PetscFunctionBeginUser;
   ierr = ISGetSize(points, &npoints);CHKERRQ(ierr);
   if (!npoints) PetscFunctionReturn(0);
   ierr = VecSet(F, 0.0);CHKERRQ(ierr);
   if (x) {{
     ierr = VecGetArrayRead(x, &state);CHKERRQ(ierr);
   }}
   ierr = VecGetArray(F, &out);CHKERRQ(ierr);
   ierr = ISGetIndices(points, &whichPoints);CHKERRQ(ierr);
   if (ctx->point2facet) {{
     PetscInt *pointsArray = NULL;
     if (npoints > 128) {{
       ierr = PetscMalloc1(npoints, &pointsArray);CHKERRQ(ierr);
     }} else {{
       pointsArray = pointbuf;
     }}
     for (PetscInt i = 0; i < npoints; i++) {{
       pointsArray[i] = ctx->point2facet[whichPoints[i]];
     }}
     ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
     whichPoints = pointsArray;
   }}
   ctx->{};
   if (ctx->point2facet) {{
     if (npoints > 128) {{
       ierr = PetscFree(whichPoints);
     }}
   }} else {{
     ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
   }}
   ierr = VecRestoreArray(F, &out);CHKERRQ(ierr);
   if (x) {{
     ierr = VecRestoreArrayRead(x, &state);CHKERRQ(ierr);
   }}
   PetscLogFlops({} * npoints);
   PetscFunctionReturn(0);
}}
"""
    return template.format(struct_decl, pyop2_call, flops), struct


def make_jacobian_wrapper(coeffs, maps, flops):
    """Generate the C source of the ``ComputeJacobian`` callback.

    :arg coeffs: coefficient arguments, passed on to :func:`make_struct`.
    :arg maps: map arguments, passed on to :func:`make_struct`.
    :arg flops: flop count per point, multiplied by the number of points
        in the ``PetscLogFlops`` call of the generated code.
    :returns: a pair ``(code, struct)`` of the C source and the
        :class:`ctypes.Structure` subclass describing ``UserCtx``.
    """
    struct_decl, pyop2_call, struct = make_struct(coeffs, maps, jacobian=True)

    # The generated function translates points through
    # ``ctx->point2facet`` when that pointer is set (using the static
    # ``pointbuf`` for up to 128 points and a heap buffer beyond that),
    # and then calls the PyOP2 wrapper stored in the context with the
    # output matrix.
    # NOTE(review): the return code of PetscFree is assigned to ``ierr``
    # but not checked in the generated code.
    template = """
#include <petsc.h>
{}

static PetscInt pointbuf[128];
PetscErrorCode ComputeJacobian(PC pc,
                               PetscInt point,
                               Vec x,
                               Mat out,
                               IS points,
                               PetscInt ndof,
                               const PetscInt *dofArray,
                               const PetscInt *dofArrayWithAll,
                               void *ctx_)
{{
   const PetscScalar *state       = NULL;
   const PetscInt    *whichPoints = NULL;
   UserCtx           *ctx         = (UserCtx *)ctx_;
   PetscInt           npoints;
   PetscErrorCode     ierr;
   PetscFunctionBeginUser;
   ierr = ISGetSize(points, &npoints);CHKERRQ(ierr);
   if (!npoints) PetscFunctionReturn(0);
   if (x) {{
     ierr = VecGetArrayRead(x, &state);CHKERRQ(ierr);
   }}
   ierr = ISGetIndices(points, &whichPoints);CHKERRQ(ierr);
   if (ctx->point2facet) {{
     PetscInt *pointsArray = NULL;
     if (npoints > 128) {{
       ierr = PetscMalloc1(npoints, &pointsArray);CHKERRQ(ierr);
     }} else {{
       pointsArray = pointbuf;
     }}
     for (PetscInt i = 0; i < npoints; i++) {{
       pointsArray[i] = ctx->point2facet[whichPoints[i]];
     }}
     ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
     whichPoints = pointsArray;
   }}
   ctx->{};
   if (ctx->point2facet) {{
     if (npoints > 128) {{
       ierr = PetscFree(whichPoints);
     }}
   }} else {{
     ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
   }}
   if (x) {{
     ierr = VecRestoreArrayRead(x, &state);CHKERRQ(ierr);
   }}
   PetscLogFlops({} * npoints);
   PetscFunctionReturn(0);
}}
"""
    return template.format(struct_decl, pyop2_call, flops), struct


def load_c_function(code, name, comm):
    """Compile ``code`` against PETSc and return the function called ``name``.

    :arg code: the C source to compile.
    :arg name: the symbol to look up in the compiled library.
    :arg comm: the communicator passed to the compilation routine.
    :returns: the ctypes function object with ``argtypes`` and
        ``restype`` set for the nine-argument patch callback signature.
    """
    include_flags = [f"-I{d}/include" for d in get_petsc_dir()]
    link_flags = [f"-L{d}/lib" for d in get_petsc_dir()]
    link_flags += [f"-Wl,-rpath,{d}/lib" for d in get_petsc_dir()]
    link_flags += ["-lpetsc", "-lm"]
    dll = load(code, "c", cppargs=include_flags, ldargs=link_flags, comm=comm)
    fn = getattr(dll, name)
    pointer = ctypes.c_voidp
    integer = ctypes.c_int
    fn.argtypes = [pointer, integer, pointer,
                   pointer, pointer, integer,
                   pointer, pointer, pointer]
    fn.restype = integer
    return fn


def make_c_arguments(form, kernel, state, get_map, require_state=False,
                     require_facet_number=False):
    """Collect the data and map kernel arguments for a compiled kernel.

    :arg form: the form the kernel was compiled from.
    :arg kernel: the compiled kernel (its ``kinfo`` is read).
    :arg state: the state coefficient, or ``None``; its slot in both
        returned lists is filled with ``None``.
    :arg get_map: callable returning the node map of a coefficient
        (or ``None``).
    :arg require_state: if true, assert that ``state`` is among the
        coefficients used by the kernel.
    :arg require_facet_number: if true, append the arguments of the
        interior facet ``local_facet_dat``.
    :returns: a pair ``(data_args, map_args)`` of lists.
    :raises ValueError: if the active indices of ``state`` are not ``(0, )``.
    """
    kinfo = kernel.kinfo
    mesh = form.ufl_domains()[kinfo.domain_number]
    coeffs = [mesh.coordinates]
    if kinfo.oriented:
        coeffs.append(mesh.cell_orientations())
    if kinfo.needs_cell_sizes:
        coeffs.append(mesh.cell_sizes)
    for n, indices in kinfo.coefficient_numbers:
        coefficient = form.coefficients()[n]
        if coefficient is not state:
            coeffs.extend([coefficient.subfunctions[ind] for ind in indices])
            continue
        if indices != (0, ):
            raise ValueError(f"Active indices of state (dont_split) function must be (0, ), not {indices}")
        coeffs.append(coefficient)
    if require_state:
        assert state in coeffs, "Couldn't find state vector in form coefficients"

    data_args, map_args = [], []
    # Map arguments are only recorded the first time they are met.
    seen = set()
    for coefficient in coeffs:
        if coefficient is state:
            data_args.append(None)
            map_args.append(None)
        else:
            data_args.extend(coefficient.dat._kernel_args_)
        node_map = get_map(coefficient)
        if node_map is None:
            continue
        for karg in node_map._kernel_args_:
            if karg in seen:
                continue
            map_args.append(karg)
            seen.add(karg)

    all_constants = extract_firedrake_constants(form)
    for constant_index in kinfo.constant_numbers:
        data_args.extend(all_constants[constant_index].dat._kernel_args_)

    if require_facet_number:
        data_args.extend(mesh.interior_facets.local_facet_dat._kernel_args_)
    return data_args, map_args


def make_c_struct(data_args, map_args, function, struct, point2facet=None):
    """Instantiate the ctypes context structure for a patch callback.

    :arg data_args: data arguments; ``None`` entries are skipped.
    :arg map_args: map arguments; ``None`` entries are skipped.
    :arg function: ctypes function whose address is stored last.
    :arg struct: the structure class to instantiate.
    :arg point2facet: value for the second-to-last field; ``0`` is
        stored when this is ``None``.
    :returns: an instance of ``struct``.
    """
    present = [a for a in chain(data_args, map_args) if a is not None]
    facet_pointer = 0 if point2facet is None else point2facet
    present.append(facet_pointer)
    function_address = ctypes.cast(function, ctypes.c_voidp).value
    return struct(*present, function_address)


def bcdofs(bc, ghost=True):
    """Return the dofs fixed by a DirichletBC.

    The numbering is the one given by concatenating all the subspaces
    of a mixed function space.

    :arg bc: the boundary condition.
    :arg ghost: if true, offsets are computed from ``dof_count`` and all
        of ``bc.nodes`` are used; otherwise offsets are computed from
        ``dof_dset.size * block_size`` and only nodes below
        ``dof_dset.size`` are kept.
    :returns: a numpy array of dof indices.
    :raises NotImplementedError: if an indexed space is neither a vector
        nor a mixed space.
    """
    # Climb to the outermost space, then descend again along the
    # indices of the boundary condition, accumulating the offset.
    Z = bc.function_space()
    while Z.parent is not None:
        Z = Z.parent

    indices = bc._indices
    last = len(indices) - 1
    offset = 0

    for i, idx in enumerate(indices):
        element = Z.ufl_element()
        if isinstance(element, VectorElement):
            offset += idx
            assert i == last  # assert we're at the end of the chain
            assert Z.sub(idx).block_size == 1
        elif isinstance(element, MixedElement):
            preceding = [Z.sub(j) for j in range(idx)]
            if ghost:
                offset += sum(W.dof_count for W in preceding)
            else:
                offset += sum(W.dof_dset.size * W.block_size for W in preceding)
        else:
            raise NotImplementedError("How are you taking a .sub?")

        Z = Z.sub(idx)

    parent = Z.parent
    if parent is not None and isinstance(parent.ufl_element(), VectorElement):
        # A single component of a vector space: stride by the parent's
        # block size and emit one dof per node.
        bs = parent.block_size
        ncomponents = 1
    else:
        bs = Z.block_size
        ncomponents = bs
    nodes = bc.nodes
    if not ghost:
        nodes = nodes[nodes < Z.dof_dset.size]

    per_component = [nodes*bs + j for j in range(ncomponents)]
    return numpy.concatenate(per_component) + offset


def select_entity(p, dm=None, exclude=None):
    """Decide whether an entity passes a label-based filter.

    :arg p: the entity.
    :arg dm: the DMPlex object to query for labels.
    :arg exclude: The label marking points to exclude.
    :returns: ``True`` if no label is given or the label does not mark
        ``p``, ``False`` otherwise."""
    if exclude is not None:
        # A label value of -1 means the point is not marked, so we
        # keep it; any other value means it is excluded.
        return dm.getLabelValue(exclude, p) == -1
    return True


class PlaneSmoother(object):
    """Patch construction callback that groups mesh entities into slabs.

    Entities are ordered along an axis (a coordinate index or a callable
    of the coordinates) and binned into divisions; each non-empty bin
    becomes one patch. The sweeps are read from the option
    ``pc_patch_construct_ps_sweeps``.
    """

    @staticmethod
    def coords(dm, p, coordinates):
        """Return the average of the coordinate dofs in the closure of ``p``.

        :arg dm: the DMPlex to take the closure in.
        :arg p: the entity.
        :arg coordinates: the coordinate function.
        :returns: a numpy array of length ``gdim``.
        """
        coordinatesV = coordinates.function_space()
        data = coordinates.dat.data_ro_with_halos
        coordinatesDM = coordinatesV.dm
        coordinatesSection = coordinatesDM.getDefaultSection()

        # Only points that actually carry coordinate dofs contribute.
        closure_of_p = [x for x in dm.getTransitiveClosure(p, useCone=True)[0] if coordinatesSection.getDof(x) > 0]

        gdim = data.shape[1]
        bary = numpy.zeros(gdim)
        ndof = 0
        for p_ in closure_of_p:
            (dof, offset) = (coordinatesSection.getDof(p_), coordinatesSection.getOffset(p_))
            bary += data[offset:offset + dof].reshape(dof, gdim).sum(axis=0)
            ndof += dof
        bary /= ndof
        return bary

    def sort_entities(self, dm, axis, dir, ndiv=None, divisions=None):
        """Sort the non-ghost entities along ``axis`` and bin them.

        :arg dm: the DMPlex; the mesh is read from its
            ``__firedrake_mesh__`` attribute.
        :arg axis: either an integer coordinate index or a callable
            mapping coordinates to a scalar.
        :arg dir: ``-1`` sorts in descending order, otherwise ascending.
        :arg ndiv: number of equispaced divisions between the extreme
            values (used when ``divisions`` is ``None``).
        :arg divisions: explicit division points.
        :returns: a list of ``ndiv + 1`` tuples of entities.
        :raises RuntimeError: if neither ``ndiv`` nor ``divisions`` is given.
        """
        # compute
        # [(pStart, (x, y, z)), (pEnd, (x, y, z))]
        from firedrake.assemble import assemble

        if ndiv is None and divisions is None:
            raise RuntimeError("Must either set ndiv or divisions for PlaneSmoother!")

        mesh = dm.getAttr("__firedrake_mesh__")
        coordinates = mesh.coordinates
        V = coordinates.function_space()
        if V.finat_element.is_dg():
            # We're using DG or DQ for our coordinates, so we got
            # a periodic mesh. We need to interpolate to CGk
            # with access descriptor MAX to define a consistent opinion
            # about where the vertices are.
            CGk = V.reconstruct(family="Lagrange")
            coordinates = assemble(Interpolate(coordinates, CGk, access=op2.MAX))

        select = partial(select_entity, dm=dm, exclude="pyop2_ghost")
        entities = [(p, self.coords(dm, p, coordinates)) for p in
                    filter(select, range(*dm.getChart()))]

        if isinstance(axis, int):
            minx = min(entities, key=lambda z: z[1][axis])[1][axis]
            maxx = max(entities, key=lambda z: z[1][axis])[1][axis]

            def keyfunc(z):
                # Sort by the chosen coordinate first, then by the others.
                coords = tuple(z[1])
                return (coords[axis], ) + tuple(coords[:axis] + coords[axis+1:])
        else:
            minx = axis(min(entities, key=lambda z: axis(z[1]))[1])
            maxx = axis(max(entities, key=lambda z: axis(z[1]))[1])

            def keyfunc(z):
                # Sort by the user-supplied scalar first, then by position.
                coords = tuple(z[1])
                return (axis(coords), ) + coords

        s = sorted(entities, key=keyfunc, reverse=(dir == -1))
        (entities, coords) = zip(*s)
        if isinstance(axis, int):
            coords = [c[axis] for c in coords]
        else:
            coords = [axis(c) for c in coords]

        if divisions is None:
            divisions = numpy.linspace(minx, maxx, ndiv+1)
        if ndiv is None:
            ndiv = numpy.size(divisions)-1
        indices = numpy.searchsorted(coords[::dir], divisions)

        out = []
        for k in range(ndiv):
            out.append(entities[indices[k]:indices[k+1]])
        # Whatever lies beyond the last division point forms a final bin.
        out.append(entities[indices[-1]:])

        return out

    def __call__(self, pc):
        """Construct the patches for ``pc``.

        Each colon-separated sweep in ``pc_patch_construct_ps_sweeps``
        has the form ``<axis><+|-><divisions>``. ``<axis>`` is an
        integer or a key into the application context, and
        ``<divisions>`` is an integer count or a key into the
        application context.

        :arg pc: the PETSc object whose DM and options prefix are used.
        :returns: a pair ``(patches, iterationSet)`` of a list of index
            sets and a stride index set numbering them.
        :raises NotImplementedError: in complex mode.
        :raises ValueError: if the sweeps option is not set.
        :raises KeyError: if an axis or division key is missing from
            the application context.
        """
        if complex_mode:
            raise NotImplementedError("Sorry, plane smoothers not yet implemented in complex mode")
        dm = pc.getDM()
        context = dm.getAttr("__firedrake_ctx__")
        prefix = pc.getOptionsPrefix() or ""
        sentinel = object()
        sweeps = PETSc.Options(prefix).getString("pc_patch_construct_ps_sweeps", default=sentinel)
        if sweeps == sentinel:
            raise ValueError("Must set %spc_patch_construct_ps_sweeps" % prefix)

        patches = []
        import re
        for sweep in sweeps.split(':'):
            # The capturing group keeps the sign as its own list entry.
            sweep_split = re.split(r'([+-])', sweep)
            try:
                axis = int(sweep_split[0])
            except ValueError:
                try:
                    axis = context.appctx[sweep_split[0]]
                except KeyError:
                    raise KeyError("PlaneSmoother axis key %s not provided" % sweep_split[0])

            dir = {'+': +1, '-': -1}[sweep_split[1]]
            # Either use equispaced bins for relaxation or get from appctx
            try:
                ndiv = int(sweep_split[2])
                entities = self.sort_entities(dm, axis, dir, ndiv=ndiv)
            except ValueError:
                try:
                    divisions = context.appctx[sweep_split[2]]
                    entities = self.sort_entities(dm, axis, dir, divisions=divisions)
                except KeyError:
                    raise KeyError("PlaneSmoother division key %s not provided" % sweep_split[2:])

            for patch in entities:
                if not patch:
                    continue
                else:
                    iset = PETSc.IS().createGeneral(patch, comm=COMM_SELF)
                    patches.append(iset)

        iterationSet = PETSc.IS().createStride(size=len(patches), first=0, step=1, comm=COMM_SELF)
        return (patches, iterationSet)
class PatchBase(PCSNESBase):
    """Shared implementation of the patch preconditioner and nonlinear solver.

    Subclasses provide ``configure_patch(patch, obj)``, which is called
    on the freshly created inner PETSc object before its type is set.
    """

    def initialize(self, obj):
        """Create and set up the inner ``patch`` object for ``obj``.

        Compiles the Jacobian (and, for a SNES with a residual form, the
        residual) kernels, wraps them as C callbacks and registers them
        on the patch object together with the discretisation information.

        :arg obj: the outer PETSc PC or SNES.
        :raises ValueError: if no suitable context is attached to the DM.
        :raises NotImplementedError: on extruded meshes.
        """
        ctx = get_appctx(obj.getDM())
        if ctx is None:
            raise ValueError("No context found on form")
        if not isinstance(ctx, _SNESContext):
            raise ValueError("Don't know how to get form from %r" % ctx)

        J, bcs = self.form(obj)
        V = J.arguments()[0].function_space()
        mesh = V.mesh()
        self.plex = mesh.topology_dm
        # We need to attach the mesh and appctx to the plex, so that
        # PlaneSmoothers (and any other user-customised patch
        # constructors) can use firedrake's opinion of what
        # the coordinates are, rather than plex's.
        self.plex.setAttr("__firedrake_mesh__", weakref.proxy(mesh))
        self.ctx = ctx
        self.plex.setAttr("__firedrake_ctx__", weakref.proxy(ctx))

        if mesh.cell_set._extruded:
            raise NotImplementedError("Not implemented on extruded meshes")

        if "overlap_type" not in mesh._distribution_parameters:
            if mesh.comm.size > 1:
                # Want to do
                # warnings.warn("You almost surely want to set an overlap_type in your mesh's distribution_parameters.")
                # but doesn't warn!
                PETSc.Sys.Print("Warning: you almost surely want to set an overlap_type in your mesh's distribution_parameters.")

        patch = obj.__class__().create(comm=mesh.comm)
        patch.setOptionsPrefix((obj.getOptionsPrefix() or "") + "patch_")
        self.configure_patch(patch, obj)
        patch.setType("patch")

        if isinstance(obj, PETSc.SNES):
            Jstate = ctx._problem.u
            is_snes = True
        else:
            Jstate = None
            is_snes = False

        if len(bcs) > 0:
            ghost_bc_nodes = numpy.unique(
                numpy.concatenate([bcdofs(bc, ghost=True) for bc in bcs], dtype=PETSc.IntType)
            )
            global_bc_nodes = numpy.unique(
                numpy.concatenate([bcdofs(bc, ghost=False) for bc in bcs], dtype=PETSc.IntType))
        else:
            ghost_bc_nodes = numpy.empty(0, dtype=PETSc.IntType)
            global_bc_nodes = numpy.empty(0, dtype=PETSc.IntType)

        # Jacobian: exactly one cell kernel is expected.
        Jcell_kernels, Jint_facet_kernels = matrix_funptr(J, Jstate)
        Jcell_kernel, = Jcell_kernels
        Jcell_flops = Jcell_kernel.kinfo.kernel.num_flops
        Jop_data_args, Jop_map_args = make_c_arguments(J, Jcell_kernel, Jstate,
                                                       operator.methodcaller("cell_node_map"))
        code, Struct = make_jacobian_wrapper(Jop_data_args, Jop_map_args, Jcell_flops)
        Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm)
        Jop_struct = make_c_struct(Jop_data_args, Jop_map_args, Jcell_kernel.funptr, Struct)

        Jhas_int_facet_kernel = False
        if len(Jint_facet_kernels) > 0:
            Jint_facet_kernel, = Jint_facet_kernels
            Jhas_int_facet_kernel = True
            Jint_facet_flops = Jint_facet_kernel.kinfo.kernel.num_flops
            facet_Jop_data_args, facet_Jop_map_args = make_c_arguments(J, Jint_facet_kernel, Jstate,
                                                                       operator.methodcaller("interior_facet_node_map"),
                                                                       require_facet_number=True)
            code, Struct = make_jacobian_wrapper(facet_Jop_data_args, facet_Jop_map_args, Jint_facet_flops)
            facet_Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm)
            point2facet = mesh.interior_facets.point2facetnumber.ctypes.data
            facet_Jop_struct = make_c_struct(facet_Jop_data_args, facet_Jop_map_args,
                                             Jint_facet_kernel.funptr, Struct,
                                             point2facet=point2facet)

        set_residual = hasattr(ctx, "F") and isinstance(obj, PETSc.SNES)
        if set_residual:
            F = ctx.F
            Fstate = ctx._problem.u
            Fcell_kernels, Fint_facet_kernels = residual_funptr(F, Fstate)
            Fcell_kernel, = Fcell_kernels
            Fcell_flops = Fcell_kernel.kinfo.kernel.num_flops
            Fop_data_args, Fop_map_args = make_c_arguments(F, Fcell_kernel, Fstate,
                                                           operator.methodcaller("cell_node_map"),
                                                           require_state=True)
            code, Struct = make_residual_wrapper(Fop_data_args, Fop_map_args, Fcell_flops)
            Fop_function = load_c_function(code, "ComputeResidual", mesh.comm)
            Fop_struct = make_c_struct(Fop_data_args, Fop_map_args, Fcell_kernel.funptr, Struct)

            Fhas_int_facet_kernel = False
            if len(Fint_facet_kernels) > 0:
                Fint_facet_kernel, = Fint_facet_kernels
                Fhas_int_facet_kernel = True
                Fint_facet_flops = Fint_facet_kernel.kinfo.kernel.num_flops
                facet_Fop_data_args, facet_Fop_map_args = make_c_arguments(F, Fint_facet_kernel, Fstate,
                                                                           operator.methodcaller("interior_facet_node_map"),
                                                                           require_state=True,
                                                                           require_facet_number=True)
                # The residual wrapper is the one that defines the
                # ``ComputeResidual`` symbol loaded on the next line.
                code, Struct = make_residual_wrapper(facet_Fop_data_args, facet_Fop_map_args, Fint_facet_flops)
                facet_Fop_function = load_c_function(code, "ComputeResidual", mesh.comm)
                point2facet = extract_unique_domain(F).interior_facets.point2facetnumber.ctypes.data
                facet_Fop_struct = make_c_struct(facet_Fop_data_args, facet_Fop_map_args,
                                                 Fint_facet_kernel.funptr, Struct,
                                                 point2facet=point2facet)

        patch.setDM(self.plex)
        patch.setPatchCellNumbering(mesh._cell_numbering)

        offsets = numpy.append([0], numpy.cumsum([W.dof_count for W in V])).astype(PETSc.IntType)
        patch.setPatchDiscretisationInfo([W.dm for W in V],
                                         numpy.array([W.block_size for W in V], dtype=PETSc.IntType),
                                         [W.cell_node_list for W in V],
                                         offsets,
                                         ghost_bc_nodes,
                                         global_bc_nodes)
        # The callbacks receive the raw address of each context struct,
        # so every struct is stored on ``self`` to keep it alive.
        self.Jop_struct = Jop_struct
        set_patch_jacobian(patch, ctypes.cast(Jop_function, ctypes.c_voidp).value,
                           ctypes.addressof(Jop_struct), is_snes=is_snes)
        if Jhas_int_facet_kernel:
            self.facet_Jop_struct = facet_Jop_struct
            set_patch_jacobian(patch, ctypes.cast(facet_Jop_function, ctypes.c_voidp).value,
                               ctypes.addressof(facet_Jop_struct), is_snes=is_snes,
                               interior_facets=True)
        if set_residual:
            self.Fop_struct = Fop_struct
            set_patch_residual(patch, ctypes.cast(Fop_function, ctypes.c_voidp).value,
                               ctypes.addressof(Fop_struct), is_snes=is_snes)
            if Fhas_int_facet_kernel:
                self.facet_Fop_struct = facet_Fop_struct
                set_patch_residual(patch, ctypes.cast(facet_Fop_function, ctypes.c_voidp).value,
                                   ctypes.addressof(facet_Fop_struct), is_snes=is_snes,
                                   interior_facets=True)

        patch.setPatchConstructType(PETSc.PC.PatchConstructType.PYTHON, operator=self.user_construction_op)
        patch.setAttr("ctx", ctx)
        patch.incrementTabLevel(1, parent=obj)
        patch.setFromOptions()
        patch.setUp()
        self.patch = patch

    def destroy(self, obj):
        """Remove the attributes set in :meth:`initialize` and destroy the patch."""
        # In this destructor we clean up the __firedrake_mesh__ we set
        # on the plex and the context we set on the patch object.
        # We have to check if these attributes are available because
        # the destroy function will be called by petsc4py when
        # PCPythonSetContext is called (which occurs before
        # initialize).
        if hasattr(self, "plex"):
            d = self.plex.getDict()
            try:
                del d["__firedrake_mesh__"]
            except KeyError:
                pass
        if hasattr(self, "patch"):
            try:
                del self.patch.getDict()["ctx"]
            except KeyError:
                pass
            self.patch.destroy()

    def user_construction_op(self, obj, *args, **kwargs):
        """Call the user-supplied Python patch construction routine.

        The routine is named by the option
        ``<objectname>_patch_construct_python_type`` as
        ``module.path.name``. If the named object is a class it is
        instantiated (without arguments) before being called.

        :arg obj: the PETSc object; passed on as first argument.
        :returns: whatever the user routine returns.
        :raises ValueError: if the option is not set.
        """
        import importlib

        prefix = obj.getOptionsPrefix() or ""
        sentinel = object()
        usercode = PETSc.Options(prefix).getString("%s_patch_construct_python_type" % self._objectname, default=sentinel)
        if usercode == sentinel:
            raise ValueError("Must set %s%s_patch_construct_python_type" % (prefix, self._objectname))

        (modname, funname) = usercode.rsplit('.', 1)
        # ``__import__("a.b")`` returns the top-level package ``a``, so
        # import the named module itself to look the attribute up.
        mod = importlib.import_module(modname)
        try:
            fun = getattr(mod, funname)
        except AttributeError:
            # Fall back to the previous lookup on the top-level package
            # so existing option values keep working.
            fun = getattr(__import__(modname), funname)
        if isinstance(fun, type):
            fun = fun()
        return fun(obj, *args, **kwargs)

    def update(self, pc):
        """Set up the inner patch object again."""
        self.patch.setUp()

    def view(self, pc, viewer=None):
        """View the inner patch object with ``viewer``."""
        self.patch.view(viewer=viewer)
class PatchPC(PCBase, PatchBase):
    """Preconditioner that delegates to an inner PETSc patch object."""

    def configure_patch(self, patch, pc):
        """Give ``patch`` the same operators as ``pc``."""
        (A, P) = pc.getOperators()
        patch.setOperators(A, P)

    def apply(self, pc, x, y):
        """Apply the inner patch object to ``x``, storing the result in ``y``."""
        self.patch.apply(x, y)

    def applyTranspose(self, pc, x, y):
        """Apply the transpose of the inner patch object to ``x``, storing the result in ``y``."""
        self.patch.applyTranspose(x, y)
class PatchSNES(SNESBase, PatchBase):
    """Nonlinear solver that delegates to an inner PETSc patch object."""

    def configure_patch(self, patch, snes):
        """Configure ``patch`` for a single iteration using the residual of ``snes``.

        :arg patch: the inner solver to configure.
        :arg snes: the outer SNES whose residual function is reused.
        """
        patch.setTolerances(max_it=1)
        patch.setConvergenceTest("skip")

        (f, residual) = snes.getFunction()
        assert residual is not None
        (fun, args, kargs) = residual
        patch.setFunction(fun, f.duplicate(), args=args, kargs=kargs)

        # Need an empty RHS for the solve,
        # PCApply can't deal with RHS = NULL,
        # and this goes through a call to PCApply at some point
        self.dummy = f.duplicate()

    def step(self, snes, x, f, y):
        """Perform one patch solve and store the resulting step in ``y``.

        ``y`` is set to ``x`` minus the patch solution obtained when
        starting from ``x``. The converged reason of the inner solver is
        copied to ``snes``.
        """
        push_appctx(self.plex, self.ctx)
        try:
            x.copy(y)
            self.patch.solve(snes.vec_rhs or self.dummy, y)
            y.axpy(-1, x)
            y.scale(-1)
            snes.setConvergedReason(self.patch.getConvergedReason())
        finally:
            # Always undo the push, even if the solve raised.
            pop_appctx(self.plex)