from firedrake.preconditioners.base import PCBase, SNESBase, PCSNESBase
from firedrake.preconditioners.asm import validate_overlap
from firedrake.petsc import PETSc
from firedrake.cython.patchimpl import set_patch_residual, set_patch_jacobian
from firedrake.solving_utils import _SNESContext
from firedrake.utils import complex_mode, IntType
from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx
from firedrake.interpolation import interpolate
from firedrake.ufl_expr import extract_domains
from collections import namedtuple
import operator
from itertools import chain
from functools import cached_property, partial
import numpy
from finat.ufl import VectorElement, MixedElement
from ufl.domain import extract_unique_domain
from tsfc.ufl_utils import extract_firedrake_constants
import weakref
import petsctools
import ctypes
from pyop2 import op2
import pyop2.types
from pyop2.compilation import load
from pyop2.codegen.builder import Pack, MatPack, DatPack
from pyop2.codegen.representation import Comparison, Literal
from pyop2.codegen.rep2loopy import register_petsc_function
from pyop2.global_kernel import compile_global_kernel
from pyop2.mpi import COMM_SELF
# Names exported by ``from <this module> import *``; everything else defined
# in this file is treated as an implementation detail.
__all__ = ("PatchPC", "PlaneSmoother", "PatchSNES")
class DenseSparsity:
    """Minimal sparsity stand-in describing one dense block.

    It records the row and column sets it was built from and reports a
    single ``(1, 1)`` block.  Indexing always yields the object itself and
    every membership query succeeds.

    :arg rset: the row set; its ``size`` is recorded as ``_nrows``.
    :arg cset: the column set; its ``size`` is recorded as ``_ncols``.
    """

    def __init__(self, rset, cset):
        one_block = ((1, 1), )
        self.dsets = (rset, cset)
        self._nrows, self._ncols = rset.size, cset.size
        self.shape = (1, 1)
        # ``dims`` and ``_dims`` deliberately refer to the same tuple.
        self._dims = (one_block, )
        self.dims = self._dims

    def __getitem__(self, *key):
        # There is only one block, so any index maps back onto it.
        return self

    def __contains__(self, *key):
        # Dense: every entry is considered present.
        return True
class LocalPack(Pack):
    """Pack variant that indexes with the entity index instead of the loop index."""

    def pick_loop_indices(self, loop_index, layer_index, entity_index):
        # ``loop_index`` is intentionally unused: the returned pair is
        # always (entity, layer).
        chosen = (entity_index, layer_index)
        return chosen
class LocalMatPack(LocalPack, MatPack):
    """Matrix pack that inserts with ``MatSetValues`` for both flag values."""

    # Both the ``False`` and ``True`` keys select the same PETSc routine.
    insertion_names = dict.fromkeys((False, True), "MatSetValues")
class LocalMatKernelArg(op2.MatKernelArg):
    """Matrix kernel argument whose packing is done by :class:`LocalMatPack`."""
    pack = LocalMatPack
class LocalMatLegacyArg(op2.MatLegacyArg):
    """Legacy matrix argument that lowers to a :class:`LocalMatKernelArg`."""

    @property
    def global_kernel_arg(self):
        """Return the :class:`LocalMatKernelArg` for this argument's maps."""
        per_map_args = []
        for map_ in self.maps:
            per_map_args.append(map_._global_kernel_arg)
        return LocalMatKernelArg(self.data.dims, per_map_args)
class LocalMat(pyop2.types.AbstractMat):
    """Matrix placeholder backed by a :class:`DenseSparsity` on a single dataset.

    :arg dset: dataset used for both the rows and the columns.
    """

    def __init__(self, dset):
        self.dtype = numpy.dtype(PETSc.ScalarType)
        self._sparsity = DenseSparsity(dset, dset)

    def __call__(self, access, maps):
        """Wrap this matrix, ``maps`` and ``access`` in a :class:`LocalMatLegacyArg`."""
        legacy_arg = LocalMatLegacyArg(self, maps, access)
        return legacy_arg
class LocalDatPack(LocalPack, DatPack):
    """Dat pack that can optionally mask out negative map entries.

    :arg needs_mask: when true, :meth:`_mask` returns a ``map_ >= 0``
        comparison; otherwise it returns ``None``.
    """

    def __init__(self, needs_mask, *args, **kwargs):
        # Set the flag before delegating so it is available during the
        # parent initialisation.
        self.needs_mask = needs_mask
        super().__init__(*args, **kwargs)

    def _mask(self, map_):
        if not self.needs_mask:
            return None
        zero = Literal(numpy.int32(0))
        return Comparison(">=", map_, zero)
class LocalDatKernelArg(op2.DatKernelArg):
    """Dat kernel argument that remembers whether its pack needs masking.

    :kwarg needs_mask: forwarded to :class:`LocalDatPack` by :attr:`pack`.
    """

    def __init__(self, *args, needs_mask, **kwargs):
        super().__init__(*args, **kwargs)
        self.needs_mask = needs_mask

    @property
    def pack(self):
        """A :class:`LocalDatPack` constructor with ``needs_mask`` pre-bound."""
        mask_flag = self.needs_mask
        return partial(LocalDatPack, mask_flag)
class LocalDatLegacyArg(op2.DatLegacyArg):
    """Legacy dat argument that lowers to a :class:`LocalDatKernelArg`."""

    @property
    def global_kernel_arg(self):
        """Return the :class:`LocalDatKernelArg`, with no map argument if unmapped."""
        if self.map_ is None:
            map_arg = None
        else:
            map_arg = self.map_._global_kernel_arg
        return LocalDatKernelArg(
            self.data.dataset.dim,
            map_arg,
            needs_mask=self.data.needs_mask,
        )
class LocalDat(pyop2.types.AbstractDat):
    """Dat placeholder that carries only a dataset, dtype, shape and mask flag.

    :arg dset: the dataset this dat is defined on.
    :arg needs_mask: stored on the instance and included in the wrapper
        cache key.
    """

    def __init__(self, dset, needs_mask=False):
        self._dataset = dset
        self.dtype = numpy.dtype(PETSc.ScalarType)
        leading = (dset.total_size,)
        # A scalar dataset (cdim == 1) contributes no trailing dimensions.
        trailing = () if dset.cdim == 1 else dset.dim
        self._shape = leading + trailing
        self.needs_mask = needs_mask

    @cached_property
    def _wrapper_cache_key_(self):
        parent_key = super()._wrapper_cache_key_
        return parent_key + (self.needs_mask, )

    def __call__(self, access, map_=None):
        """Wrap this dat, ``map_`` and ``access`` in a :class:`LocalDatLegacyArg`."""
        return LocalDatLegacyArg(self, map_, access)

    def increment_dat_version(self):
        """Do nothing; this dat keeps no version counter."""
# LocalMatPack inserts with "MatSetValues"; register that name with the
# code generator (presumably so the generated wrapper may call it -- see
# pyop2.codegen.rep2loopy.register_petsc_function).
register_petsc_function("MatSetValues")
# Pairs the result of compile_global_kernel (``funptr``) with the kernel
# info it was built from (``kinfo``).
CompiledKernel = namedtuple('CompiledKernel', ["funptr", "kinfo"])
def get_map(V, base_mesh, base_integral_type):
    """Return the entity-node map of ``V`` for an integral over ``base_mesh``.

    :arg V: object exposing ``topological.entity_node_map``.
    :arg base_mesh: the mesh integrated over; its ``topology`` is passed on.
    :arg base_integral_type: the integral type, passed on unchanged.
    :returns: whatever ``entity_node_map`` returns; the two trailing
        arguments are always ``None``.
    """
    topological_space = V.topological
    return topological_space.entity_node_map(
        base_mesh.topology, base_integral_type, None, None
    )
def matrix_funptr(form, state):
    """Compile the kernels of a bilinear form into patch-assembly callbacks.

    :arg form: the bilinear form; test and trial space must be equal.
    :arg state: the coefficient that must not be split by the form
        compiler (passed as ``dont_split``), or ``None``.
    :returns: a triple ``(cell_kernels, int_facet_kernels,
        ext_facet_kernels)`` of lists of :class:`CompiledKernel`.
    :raises NotImplementedError: if test and trial spaces differ, if a
        kernel is restricted to a subdomain, or if its integral type is
        not cell, interior facet or exterior facet.
    :raises ValueError: if the active indices of ``state`` are not ``(0, )``.
    """
    from firedrake.tsfc_interface import compile_form
    test, trial = map(operator.methodcaller("function_space"), form.arguments())
    if test != trial:
        raise NotImplementedError("Only for matching test and trial spaces")
    if state is not None:
        dont_split = (state, )
    else:
        dont_split = ()
    kernels = compile_form(form, "subspace_form", split=False, dont_split=dont_split)
    all_meshes = extract_domains(form)
    cell_kernels = []
    int_facet_kernels = []
    ext_facet_kernels = []
    for kernel in kernels:
        kinfo = kernel.kinfo
        mesh = all_meshes[kinfo.domain_number]  # integration domain
        integral_type = kinfo.integral_type
        if kinfo.subdomain_id != ("otherwise",):
            raise NotImplementedError("Only for full domain integrals")
        if kinfo.integral_type not in {"cell", "interior_facet", "exterior_facet"}:
            raise NotImplementedError("Only for cell, interior facet, or exterior facet integrals")
        # OK, now we've validated the kernel, let's build the callback
        args = []
        # NOTE: this rebinds ``kernels`` to the output list for this
        # integral type.  The enclosing ``for`` keeps iterating the
        # compiled kernels because its iterator was created beforehand.
        if integral_type == "cell":
            kernels = cell_kernels
        elif integral_type == "interior_facet":
            kernels = int_facet_kernels
        elif integral_type == "exterior_facet":
            kernels = ext_facet_kernels
        # A one-element set/dataset stands in for the real dof set.
        toset = op2.Set(1, comm=test.comm)
        dofset = op2.DataSet(toset, 1)
        # Scalar dofs per iteration entity, summed over the sub-maps of
        # the test space.
        arity = sum(m.arity*s.cdim
                    for m, s in zip(get_map(test, mesh, integral_type),
                                    test.dof_dset))
        iterset = get_map(test, mesh, integral_type).iterset
        # The map values are all zero here; presumably the real indices
        # are supplied by the caller of the compiled kernel -- TODO confirm.
        entity_node_map = op2.Map(iterset,
                                  toset, arity,
                                  values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        # First kernel argument: the matrix that is incremented.
        mat = LocalMat(dofset)
        arg = mat(op2.INC, (entity_node_map, entity_node_map))
        args.append(arg.global_kernel_arg)
        # Argument for the state vector; only appended below if ``state``
        # actually occurs among the kernel's coefficients.
        statedat = LocalDat(dofset)
        state_entity_node_map = op2.Map(iterset,
                                        toset, arity,
                                        values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        statearg = statedat(op2.READ, state_entity_node_map)
        # Geometry arguments: coordinates, cell orientations, cell sizes.
        for i in kinfo.active_domain_numbers.coordinates:
            c = all_meshes[i].coordinates
            arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type))
            args.append(arg.global_kernel_arg)
        for i in kinfo.active_domain_numbers.cell_orientations:
            c = all_meshes[i].cell_orientations()
            arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type))
            args.append(arg.global_kernel_arg)
        for i in kinfo.active_domain_numbers.cell_sizes:
            c = all_meshes[i].cell_sizes
            arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type))
            args.append(arg.global_kernel_arg)
        # Coefficient arguments, in the order the kernel lists them.
        for n, indices in kinfo.coefficient_numbers:
            c = form.coefficients()[n]
            if c is state:
                if indices != (0, ):
                    raise ValueError(f"Active indices of state (dont_split) function must be (0, ), not {indices}")
                args.append(statearg.global_kernel_arg)
                continue
            for ind in indices:
                c_ = c.subfunctions[ind]
                map_ = get_map(c_.function_space(), mesh, integral_type)
                if c_.function_space().ufl_element().family() == "Real":
                    # Interior facet integrals double Real coefficients for the
                    # two sides of the facet, matching the TSFC-generated kernel.
                    arg = op2.GlobalKernelArg(
                        (c_.function_space().block_size,), double=integral_type.startswith("interior_facet")
                    )
                else:
                    arg = c_.dat(op2.READ, map_).global_kernel_arg
                args.append(arg)
        # Constants used by this kernel.
        all_constants = extract_firedrake_constants(form)
        for constant_index in kinfo.constant_numbers:
            args.append(all_constants[constant_index].dat(op2.READ).global_kernel_arg)
        # Facet kernels take the local facet numbers as their last argument.
        if integral_type == "interior_facet":
            arg = mesh.interior_facets.local_facet_dat(op2.READ)
            args.append(arg.global_kernel_arg)
        elif integral_type == "exterior_facet":
            arg = mesh.exterior_facets.local_facet_dat(op2.READ)
            args.append(arg.global_kernel_arg)
        # Only the communicator of this (empty) subset is used below; the
        # kernel itself is generated with ``subset=True``.
        iterset = op2.Subset(iterset, [])
        mod = op2.GlobalKernel(kinfo.kernel, args, subset=True)
        kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
    return cell_kernels, int_facet_kernels, ext_facet_kernels
def residual_funptr(form, state):
    """Compile the kernels of a linear (residual) form into patch callbacks.

    :arg form: the residual form, with a single (test) argument.
    :arg state: the coefficient holding the current state; its function
        space must equal the test space, and it is not split by the form
        compiler (passed as ``dont_split``).
    :returns: a triple ``(cell_kernels, int_facet_kernels,
        ext_facet_kernels)`` of lists of :class:`CompiledKernel`.
    :raises NotImplementedError: if state and test spaces differ, if a
        kernel is restricted to a subdomain, or if its integral type is
        not cell, interior facet or exterior facet.
    :raises ValueError: if the active indices of ``state`` are not ``(0, )``.
    """
    from firedrake.tsfc_interface import compile_form
    test, = map(operator.methodcaller("function_space"), form.arguments())
    # NOTE(review): ``state`` is dereferenced here, so the
    # ``state is not None`` test just below can never take its else branch.
    if state.function_space() != test:
        raise NotImplementedError("State and test space must be dual to one-another")
    if state is not None:
        dont_split = (state, )
    else:
        dont_split = ()
    kernels = compile_form(form, "subspace_form", split=False, dont_split=dont_split)
    all_meshes = extract_domains(form)
    cell_kernels = []
    int_facet_kernels = []
    ext_facet_kernels = []
    for kernel in kernels:
        kinfo = kernel.kinfo
        mesh = all_meshes[kinfo.domain_number]  # integration domain
        integral_type = kinfo.integral_type
        if kinfo.subdomain_id != ("otherwise",):
            raise NotImplementedError("Only for full domain integrals")
        if kinfo.integral_type not in {"cell", "interior_facet", "exterior_facet"}:
            raise NotImplementedError("Only for cell, interior facet, or exterior facet integrals")
        args = []
        # NOTE: this rebinds ``kernels`` to the output list for this
        # integral type.  The enclosing ``for`` keeps iterating the
        # compiled kernels because its iterator was created beforehand.
        if kinfo.integral_type == "cell":
            kernels = cell_kernels
        elif kinfo.integral_type == "interior_facet":
            kernels = int_facet_kernels
        elif kinfo.integral_type == "exterior_facet":
            kernels = ext_facet_kernels
        # A one-element set/dataset stands in for the real dof set.
        toset = op2.Set(1, comm=test.comm)
        dofset = op2.DataSet(toset, 1)
        # Scalar dofs per iteration entity, summed over the sub-maps of
        # the test space.
        arity = sum(m.arity*s.cdim
                    for m, s in zip(get_map(test, mesh, integral_type),
                                    test.dof_dset))
        iterset = get_map(test, mesh, integral_type).iterset
        # The map values are all zero here; presumably the real indices
        # are supplied by the caller of the compiled kernel -- TODO confirm.
        entity_node_map = op2.Map(iterset,
                                  toset, arity,
                                  values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        # The output dat is created with ``needs_mask=True`` so that its
        # pack guards on ``map >= 0`` (see LocalDatPack._mask).
        dat = LocalDat(dofset, needs_mask=True)
        statedat = LocalDat(dofset)
        state_entity_node_map = op2.Map(iterset,
                                        toset, arity,
                                        values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        statearg = statedat(op2.READ, state_entity_node_map)
        # First kernel argument: the residual that is incremented.
        arg = dat(op2.INC, entity_node_map)
        args.append(arg.global_kernel_arg)
        # Geometry arguments: coordinates, cell orientations, cell sizes.
        for i in kinfo.active_domain_numbers.coordinates:
            c = all_meshes[i].coordinates
            arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type))
            args.append(arg.global_kernel_arg)
        for i in kinfo.active_domain_numbers.cell_orientations:
            c = all_meshes[i].cell_orientations()
            arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type))
            args.append(arg.global_kernel_arg)
        for i in kinfo.active_domain_numbers.cell_sizes:
            c = all_meshes[i].cell_sizes
            arg = c.dat(op2.READ, get_map(c.function_space(), mesh, integral_type))
            args.append(arg.global_kernel_arg)
        # Coefficient arguments, in the order the kernel lists them.
        for n, indices in kinfo.coefficient_numbers:
            c = form.coefficients()[n]
            if c is state:
                if indices != (0, ):
                    raise ValueError(f"Active indices of state (dont_split) function must be (0, ), not {indices}")
                args.append(statearg.global_kernel_arg)
                continue
            for ind in indices:
                c_ = c.subfunctions[ind]
                map_ = get_map(c_.function_space(), mesh, integral_type)
                if c_.function_space().ufl_element().family() == "Real":
                    # Interior facet integrals double Real coefficients for the
                    # two sides of the facet, matching the TSFC-generated kernel.
                    arg = op2.GlobalKernelArg(
                        (c_.function_space().block_size,), double=integral_type.startswith("interior_facet")
                    )
                else:
                    arg = c_.dat(op2.READ, map_).global_kernel_arg
                args.append(arg)
        # Constants used by this kernel.
        all_constants = extract_firedrake_constants(form)
        for constant_index in kinfo.constant_numbers:
            args.append(all_constants[constant_index].dat(op2.READ).global_kernel_arg)
        # Facet kernels take the local facet numbers as their last argument.
        # NOTE(review): the facet data is taken from the unique domain of
        # the test space here, whereas matrix_funptr uses ``mesh`` (the
        # integration domain) -- confirm the two always coincide.
        if kinfo.integral_type == "interior_facet":
            arg = extract_unique_domain(test).interior_facets.local_facet_dat(op2.READ)
            args.append(arg.global_kernel_arg)
        elif kinfo.integral_type == "exterior_facet":
            arg = extract_unique_domain(test).exterior_facets.local_facet_dat(op2.READ)
            args.append(arg.global_kernel_arg)
        # Only the communicator of this (empty) subset is used below; the
        # kernel itself is generated with ``subset=True``.
        iterset = op2.Subset(iterset, [])
        mod = op2.GlobalKernel(kinfo.kernel, args, subset=True)
        kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
    return cell_kernels, int_facet_kernels, ext_facet_kernels
# We need to set C function pointer callbacks for PCPatch to work.
# Although petsc4py provides a high-level Python wrapper for them,
# this is very costly when going back and forth from C to Python only
# to extract function pointers and send them straight back to C. Here,
# since we know what the calling convention of the C function is, we
# just wrap up everything as a C function pointer and use that
# directly.
def make_struct(op_coeffs, op_maps, jacobian=False):
    """Build the C ``UserCtx`` declaration, the kernel call and a ctypes mirror.

    :arg op_coeffs: data arguments of the kernel; a ``None`` entry marks
        the state vector, which is passed directly and not stored.
    :arg op_maps: map arguments of the kernel; a ``None`` entry marks
        ``dofArrayWithAll``, which is passed directly and not stored.
    :arg jacobian: if true the kernel output is declared as ``Mat J``,
        otherwise as ``PetscScalar * restrict F``.
    :returns: a triple ``(struct, call, Struct)``: the C ``typedef``
        source, the C expression calling ``pyop2_call``, and a
        :class:`ctypes.Structure` subclass whose fields are the stored
        coefficients, the stored maps, ``point2facet`` and ``pyop2_call``
        (in that order).
    """
    import ctypes

    # Positional names; the position is kept in the name so that stored
    # entries stay uniquely numbered even when some are skipped.
    coeffs = ["state" if c is None else "c{}".format(i)
              for i, c in enumerate(op_coeffs)]
    maps = ["dofArrayWithAll" if m is None else "m{}".format(i)
            for i, m in enumerate(op_maps)]

    # Only these live in the context struct.
    stored_coeffs = [name for name in coeffs if name != "state"]
    stored_maps = [name for name in maps if name != "dofArrayWithAll"]

    coeff_struct = ";\n".join(" const PetscScalar *{}".format(name) for name in stored_coeffs)
    map_struct = ";\n".join(" const PetscInt *{}".format(name) for name in stored_maps)

    coeff_decl = ", ".join("const PetscScalar *restrict {}".format(name) for name in coeffs)
    map_decl = ", ".join("const PetscInt *restrict {}".format(name) for name in maps)

    coeff_call = ", ".join(name if name == "state" else "ctx->{}".format(name)
                           for name in coeffs)
    map_call = ", ".join(name if name == "dofArrayWithAll" else "ctx->{}".format(name)
                         for name in maps)

    out = "Mat J" if jacobian else "PetscScalar * restrict F"
    function = " void (*pyop2_call)(int start, int end, const PetscInt * restrict cells, {}, {}, const PetscInt *restrict dofArray, {})".format(out, coeff_decl, map_decl)

    # The ctypes layout must follow the C declaration field for field.
    fields = [(name, ctypes.c_voidp) for name in stored_coeffs + stored_maps]
    fields += [("point2facet", ctypes.c_voidp), ("pyop2_call", ctypes.c_voidp)]

    class Struct(ctypes.Structure):
        _fields_ = fields

    struct = """
typedef struct {{
{};
{};
const PetscInt *point2facet;
{};
}} UserCtx;""".format(coeff_struct, map_struct, function)
    call = "pyop2_call(0, npoints, whichPoints, out, {}, activeDofsArray, {})".format(coeff_call, map_call)
    return struct, call, Struct
def make_residual_wrapper(coeffs, maps, flops):
    """Generate the C source of the ``ComputeResidual`` patch callback.

    The generated function zeroes ``F``, optionally translates the patch
    points through ``ctx->point2facet`` (dropping points mapped to a
    negative number together with their dofs), calls the compiled kernel
    through ``ctx->pyop2_call`` and logs ``flops`` per remaining point.

    :arg coeffs: data arguments, forwarded to :func:`make_struct`.
    :arg maps: map arguments, forwarded to :func:`make_struct`.
    :arg flops: flop count per point, substituted into ``PetscLogFlops``.
    :returns: a pair ``(code, Struct)`` of the C source and the ctypes
        structure class matching its ``UserCtx``.
    """
    ctx_decl, kernel_call, ctx_struct = make_struct(coeffs, maps, jacobian=False)
    # Slots, in order: UserCtx declaration, kernel call, flops per point.
    template = """
#include <petsc.h>
{}
static PetscInt pointbuf[128];
PetscErrorCode ComputeResidual(PC pc,
PetscInt point,
Vec x,
Vec F,
IS points,
PetscInt ndof,
const PetscInt *dofArray,
const PetscInt *dofArrayWithAll,
void *ctx_)
{{
const PetscScalar *state = NULL;
const PetscInt *whichPoints = NULL;
const PetscInt *activeDofsArray = dofArray;
PetscScalar *out = NULL;
UserCtx *ctx = (UserCtx *)ctx_;
PetscInt npoints;
PetscInt *filtpoints = NULL;
PetscInt *filtdofs = NULL;
PetscErrorCode ierr;
PetscFunctionBeginUser;
ierr = ISGetSize(points, &npoints);CHKERRQ(ierr);
if (!npoints) PetscFunctionReturn(0);
ierr = VecSet(F, 0.0);CHKERRQ(ierr);
if (x) {{
ierr = VecGetArrayRead(x, &state);CHKERRQ(ierr);
}}
ierr = VecGetArray(F, &out);CHKERRQ(ierr);
ierr = ISGetIndices(points, &whichPoints);CHKERRQ(ierr);
if (ctx->point2facet) {{
PetscInt nvalid = 0;
PetscInt tDPP = ndof / npoints;
ierr = PetscMalloc1(npoints, &filtpoints);CHKERRQ(ierr);
if (ndof > 0) {{ ierr = PetscMalloc1(ndof, &filtdofs);CHKERRQ(ierr); }}
for (PetscInt i = 0; i < npoints; i++) {{
PetscInt fi = ctx->point2facet[whichPoints[i]];
if (fi >= 0) {{
filtpoints[nvalid] = fi;
for (PetscInt d = 0; d < tDPP; d++)
filtdofs[nvalid * tDPP + d] = dofArray[i * tDPP + d];
nvalid++;
}}
}}
ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
npoints = nvalid;
whichPoints = filtpoints;
activeDofsArray = filtdofs;
}}
if (npoints) ctx->{};
if (ctx->point2facet) {{
ierr = PetscFree(filtpoints);
ierr = PetscFree(filtdofs);
}} else {{
ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
}}
ierr = VecRestoreArray(F, &out);CHKERRQ(ierr);
if (x) {{
ierr = VecRestoreArrayRead(x, &state);CHKERRQ(ierr);
}}
PetscLogFlops({} * npoints);
PetscFunctionReturn(0);
}}
"""
    code = template.format(ctx_decl, kernel_call, flops)
    return code, ctx_struct
def make_jacobian_wrapper(coeffs, maps, flops):
    """Generate the C source of the ``ComputeJacobian`` patch callback.

    The generated function optionally translates the patch points through
    ``ctx->point2facet`` (dropping points mapped to a negative number
    together with their dofs), calls the compiled kernel through
    ``ctx->pyop2_call`` with the matrix as output and logs ``flops`` per
    remaining point.

    :arg coeffs: data arguments, forwarded to :func:`make_struct`.
    :arg maps: map arguments, forwarded to :func:`make_struct`.
    :arg flops: flop count per point, substituted into ``PetscLogFlops``.
    :returns: a pair ``(code, Struct)`` of the C source and the ctypes
        structure class matching its ``UserCtx``.
    """
    ctx_decl, kernel_call, ctx_struct = make_struct(coeffs, maps, jacobian=True)
    # Slots, in order: UserCtx declaration, kernel call, flops per point.
    template = """
#include <petsc.h>
{}
static PetscInt pointbuf[128];
PetscErrorCode ComputeJacobian(PC pc,
PetscInt point,
Vec x,
Mat out,
IS points,
PetscInt ndof,
const PetscInt *dofArray,
const PetscInt *dofArrayWithAll,
void *ctx_)
{{
const PetscScalar *state = NULL;
const PetscInt *whichPoints = NULL;
const PetscInt *activeDofsArray = dofArray;
UserCtx *ctx = (UserCtx *)ctx_;
PetscInt npoints;
PetscInt *filtpoints = NULL;
PetscInt *filtdofs = NULL;
PetscErrorCode ierr;
PetscFunctionBeginUser;
ierr = ISGetSize(points, &npoints);CHKERRQ(ierr);
if (!npoints) PetscFunctionReturn(0);
if (x) {{
ierr = VecGetArrayRead(x, &state);CHKERRQ(ierr);
}}
ierr = ISGetIndices(points, &whichPoints);CHKERRQ(ierr);
if (ctx->point2facet) {{
PetscInt nvalid = 0;
PetscInt tDPP = ndof / npoints;
ierr = PetscMalloc1(npoints, &filtpoints);CHKERRQ(ierr);
if (ndof > 0) {{ ierr = PetscMalloc1(ndof, &filtdofs);CHKERRQ(ierr); }}
for (PetscInt i = 0; i < npoints; i++) {{
PetscInt fi = ctx->point2facet[whichPoints[i]];
if (fi >= 0) {{
filtpoints[nvalid] = fi;
for (PetscInt d = 0; d < tDPP; d++)
filtdofs[nvalid * tDPP + d] = dofArray[i * tDPP + d];
nvalid++;
}}
}}
ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
npoints = nvalid;
whichPoints = filtpoints;
activeDofsArray = filtdofs;
}}
if (npoints) ctx->{};
if (ctx->point2facet) {{
ierr = PetscFree(filtpoints);
ierr = PetscFree(filtdofs);
}} else {{
ierr = ISRestoreIndices(points, &whichPoints);CHKERRQ(ierr);
}}
if (x) {{
ierr = VecRestoreArrayRead(x, &state);CHKERRQ(ierr);
}}
PetscLogFlops({} * npoints);
PetscFunctionReturn(0);
}}
"""
    code = template.format(ctx_decl, kernel_call, flops)
    return code, ctx_struct
def load_c_function(code, name, comm):
    """Compile ``code`` as C against PETSc and return the function ``name``.

    :arg code: C source to compile.
    :arg name: symbol to look up in the resulting library.
    :arg comm: communicator handed to the compilation routine.
    :returns: the ctypes function object, with its argument and return
        types set to the nine-argument callback signature used by the
        generated wrappers.
    """
    include_flags = petsctools.get_petsc_dirs(prefix="-I", subdir="include")
    search_flags = petsctools.get_petsc_dirs(prefix="-L", subdir="lib")
    rpath_flags = petsctools.get_petsc_dirs(prefix="-Wl,-rpath,", subdir="lib")
    link_flags = (*search_flags, *rpath_flags, "-lpetsc", "-lm")

    dll = load(code, "c", cppargs=include_flags, ldargs=link_flags, comm=comm)
    fn = getattr(dll, name)

    pointer = ctypes.c_voidp
    integer = ctypes.c_int
    # (pc, point, x, out, points, ndof, dofArray, dofArrayWithAll, ctx)
    fn.argtypes = [pointer, integer, pointer,
                   pointer, pointer, integer,
                   pointer, pointer, pointer]
    fn.restype = integer
    return fn
def make_c_arguments(form, kernel, state, integral_type, require_state=False,
                     require_facet_number=False):
    """Collect the data and map arguments stored in a kernel's context struct.

    :arg form: the form the kernel was compiled from.
    :arg kernel: a :class:`CompiledKernel`; only its ``kinfo`` is read.
    :arg state: the state coefficient (or ``None``).  Where it occurs, a
        ``None`` placeholder is emitted in both returned lists.
    :arg integral_type: integral type used to look up the maps.
    :arg require_state: if true, assert that ``state`` is among the
        kernel's coefficients.
    :arg require_facet_number: if true, append the local facet data for
        interior/exterior facet integrals.
    :returns: a pair ``(data_args, map_args)``.
    :raises ValueError: if the active indices of ``state`` are not ``(0, )``.
    """
    kinfo = kernel.kinfo
    all_meshes = extract_domains(form)
    mesh = all_meshes[kinfo.domain_number]
    active = kinfo.active_domain_numbers

    # Geometry first, then the form coefficients, in kernel order.
    coeffs = [all_meshes[i].coordinates for i in active.coordinates]
    coeffs += [all_meshes[i].cell_orientations() for i in active.cell_orientations]
    coeffs += [all_meshes[i].cell_sizes for i in active.cell_sizes]
    for n, indices in kinfo.coefficient_numbers:
        c = form.coefficients()[n]
        if c is not state:
            coeffs.extend(c.subfunctions[ind] for ind in indices)
            continue
        if indices != (0, ):
            raise ValueError(f"Active indices of state (dont_split) function must be (0, ), not {indices}")
        coeffs.append(c)
    if require_state:
        assert state in coeffs, "Couldn't find state vector in form coefficients"

    data_args = []
    map_args = []
    seen = set()
    for c in coeffs:
        if c is state:
            # Placeholders: the state and its map are supplied by the
            # callback's caller, not stored in the struct.
            data_args.append(None)
            map_args.append(None)
            continue
        data_args.extend(c.dat._kernel_args_)
        map_ = get_map(c.function_space(), mesh, integral_type)
        if map_ is None:
            continue
        # Each distinct map argument is passed only once.
        for k in map_._kernel_args_:
            if k in seen:
                continue
            map_args.append(k)
            seen.add(k)

    all_constants = extract_firedrake_constants(form)
    for constant_index in kinfo.constant_numbers:
        data_args.extend(all_constants[constant_index].dat._kernel_args_)

    if require_facet_number:
        if integral_type == "interior_facet":
            facet_dat = mesh.interior_facets.local_facet_dat
            data_args.extend(facet_dat._kernel_args_)
        elif integral_type == "exterior_facet":
            facet_dat = mesh.exterior_facets.local_facet_dat
            data_args.extend(facet_dat._kernel_args_)
    return data_args, map_args
def make_c_struct(data_args, map_args, function, struct, point2facet=None):
    """Instantiate ``struct`` with the kernel arguments and function pointer.

    :arg data_args: data arguments; ``None`` entries are skipped.
    :arg map_args: map arguments; ``None`` entries are skipped.
    :arg function: the compiled kernel; its address is stored last.
    :arg struct: structure class to instantiate.
    :arg point2facet: value for the ``point2facet`` slot; ``0`` is stored
        when this is ``None``.
    :returns: the populated ``struct`` instance.
    """
    values = [a for a in data_args if a is not None]
    values.extend(a for a in map_args if a is not None)
    values.append(0 if point2facet is None else point2facet)
    function_address = ctypes.cast(function, ctypes.c_voidp).value
    return struct(*values, function_address)
def bcdofs(bc, ghost=True):
    """Return the global dofs fixed by a DirichletBC.

    The numbering is the one given by concatenating all the subspaces of
    a mixed function space.

    :arg bc: the boundary condition.
    :arg ghost: if true, offsets count ``dof_count`` per preceding
        subspace and all of ``bc.nodes`` are used; otherwise offsets
        count ``dof_dset.size * block_size`` and nodes at or beyond
        ``dof_dset.size`` are dropped.
    :returns: a numpy array of dof indices.
    :raises NotImplementedError: if an index is taken on a space that is
        neither vector- nor mixed-valued.
    """
    # Walk up to the outermost space, then back down along bc._indices,
    # accumulating the offset of each selected subspace.
    space = bc.function_space()
    while space.parent is not None:
        space = space.parent

    indices = bc._indices
    last = len(indices) - 1
    offset = 0
    for position, idx in enumerate(indices):
        element = space.ufl_element()
        if isinstance(element, VectorElement):
            offset += idx
            assert position == last  # assert we're at the end of the chain
            assert space.sub(idx).block_size == 1
        elif isinstance(element, MixedElement):
            if ghost:
                offset += sum(space.sub(j).dof_count for j in range(idx))
            else:
                offset += sum(space.sub(j).dof_dset.size * space.sub(j).block_size
                              for j in range(idx))
        else:
            raise NotImplementedError("How are you taking a .sub?")
        space = space.sub(idx)

    parent = space.parent
    if parent is not None and isinstance(parent.ufl_element(), VectorElement):
        # A single component of a vector space: stride by the parent's
        # block size and take one component only.
        bs = parent.block_size
        components = range(0, 1)
    else:
        bs = space.block_size
        components = range(0, bs)

    nodes = bc.nodes
    if not ghost:
        nodes = nodes[nodes < space.dof_dset.size]
    return numpy.concatenate([nodes*bs + j for j in components]) + offset
def select_entity(p, dm=None, exclude=None):
    """Decide whether an entity survives filtering by a label.

    :arg p: the entity.
    :arg dm: the DMPlex object to query for labels.
    :arg exclude: The label marking points to exclude.
    :returns: ``True`` if no label is given or the label does not mark
        ``p``, ``False`` otherwise."""
    if exclude is not None:
        # A label value of -1 means the point is not marked, so it is
        # kept; any other value excludes it.
        return dm.getLabelValue(exclude, p) == -1
    return True
[docs]
class PlaneSmoother(object):
    """Patch constructor that bins mesh entities along an axis.

    Calling an instance with a PC returns ``(patches, iterationSet)``,
    where each patch collects the entities whose coordinate along the
    requested axis falls into one bin.  The sweeps are read from the
    ``pc_patch_construct_ps_sweeps`` option.
    """

    @staticmethod
    def coords(dm, p, coordinates):
        """Return the mean coordinate of the dofs in the closure of ``p``.

        :arg dm: the DMPlex used to compute the closure of ``p``.
        :arg p: the entity.
        :arg coordinates: the coordinate field to average.
        """
        section = coordinates.function_space().dm.getDefaultSection()
        data = coordinates.dat.data_ro_with_halos
        closure = dm.getTransitiveClosure(p, useCone=True)[0]
        # Only closure points that actually carry coordinate dofs count.
        carriers = [q for q in closure if section.getDof(q) > 0]

        gdim = data.shape[1]
        bary = numpy.zeros(gdim)
        ndof = 0
        for q in carriers:
            dof = section.getDof(q)
            offset = section.getOffset(q)
            bary += data[offset:offset + dof].reshape(dof, gdim).sum(axis=0)
            ndof += dof
        bary /= ndof
        return bary

    def sort_entities(self, dm, axis, dir, ndiv=None, divisions=None):
        """Sort the non-ghost entities along ``axis`` and split them into bins.

        :arg dm: the DMPlex; the mesh is read from its
            ``__firedrake_mesh__`` attribute.
        :arg axis: either an integer coordinate index or a callable
            mapping a coordinate to a scalar.
        :arg dir: ``+1`` or ``-1``; ``-1`` sorts in descending order.
        :arg ndiv: number of equispaced bins (used if ``divisions`` is
            not given).
        :arg divisions: explicit bin boundaries.
        :returns: a list of ``ndiv + 1`` tuples of entities.
        :raises RuntimeError: if neither ``ndiv`` nor ``divisions`` is set.
        :raises NotImplementedError: for general mixed meshes.
        """
        # compute
        # [(pStart, (x, y, z)), (pEnd, (x, y, z))]
        from firedrake.assemble import assemble
        if ndiv is None and divisions is None:
            raise RuntimeError("Must either set ndiv or divisions for PlaneSmoother!")
        mesh = dm.getAttr("__firedrake_mesh__")
        if len(set(mesh)) != 1:
            raise NotImplementedError("Not implemented for general mixed meshes")
        coordinates = mesh.unique().coordinates
        V = coordinates.function_space()
        if V.finat_element.is_dg():
            # We're using DG or DQ for our coordinates, so we got
            # a periodic mesh. We need to interpolate to CGk
            # with access descriptor MAX to define a consistent opinion
            # about where the vertices are.
            CGk = V.reconstruct(family="Lagrange")
            coordinates = assemble(interpolate(coordinates, CGk, access=op2.MAX))

        select = partial(select_entity, dm=dm, exclude="pyop2_ghost")
        entities = [(p, self.coords(dm, p, coordinates))
                    for p in range(*dm.getChart()) if select(p)]

        axis_is_index = isinstance(axis, int)
        if axis_is_index:
            def project(x):
                return x[axis]
        else:
            project = axis

        minx = project(min(entities, key=lambda z: project(z[1]))[1])
        maxx = project(max(entities, key=lambda z: project(z[1]))[1])

        def keyfunc(z):
            # Sort primarily by the projected coordinate; the remaining
            # components only break ties.
            coords = tuple(z[1])
            if axis_is_index:
                rest = tuple(coords[:axis] + coords[axis+1:])
            else:
                rest = coords
            return (project(coords), ) + rest

        ordered = sorted(entities, key=keyfunc, reverse=(dir == -1))
        (entities, coords) = zip(*ordered)
        coords = [project(c) for c in coords]

        if divisions is None:
            divisions = numpy.linspace(minx, maxx, ndiv+1)
        if ndiv is None:
            ndiv = numpy.size(divisions)-1
        indices = numpy.searchsorted(coords[::dir], divisions)

        out = [entities[indices[k]:indices[k+1]] for k in range(ndiv)]
        # Everything from the last boundary onwards forms a final bin.
        out.append(entities[indices[-1]:])
        return out

    def __call__(self, pc):
        """Build the patches for every sweep requested in the options.

        Sweeps are separated by ``:``; each has the form
        ``<axis><+|-><divisions>``, where axis and divisions are either
        integers or keys looked up in the application context.

        :arg pc: the PC whose DM and options prefix are used.
        :returns: a pair ``(patches, iterationSet)`` of a list of index
            sets and a stride index set enumerating them.
        :raises NotImplementedError: in complex mode.
        :raises ValueError: if the sweeps option is not set.
        :raises KeyError: if an axis or division key is missing from the
            application context.
        """
        import re
        if complex_mode:
            raise NotImplementedError("Sorry, plane smoothers not yet implemented in complex mode")
        dm = pc.getDM()
        context = dm.getAttr("__firedrake_ctx__")
        prefix = pc.getOptionsPrefix() or ""
        sentinel = object()
        sweeps = PETSc.Options(prefix).getString("pc_patch_construct_ps_sweeps", default=sentinel)
        if sweeps == sentinel:
            raise ValueError("Must set %spc_patch_construct_ps_sweeps" % prefix)

        directions = {'+': +1, '-': -1}
        patches = []
        for sweep in sweeps.split(':'):
            parts = re.split(r'([+-])', sweep)
            try:
                axis = int(parts[0])
            except ValueError:
                try:
                    axis = context.appctx[parts[0]]
                except KeyError:
                    raise KeyError("PlaneSmoother axis key %s not provided" % parts[0])
            dir = directions[parts[1]]
            # Either use equispaced bins for relaxation or get from appctx
            try:
                ndiv = int(parts[2])
                entities = self.sort_entities(dm, axis, dir, ndiv=ndiv)
            except ValueError:
                try:
                    divisions = context.appctx[parts[2]]
                    entities = self.sort_entities(dm, axis, dir, divisions=divisions)
                except KeyError:
                    raise KeyError("PlaneSmoother division key %s not provided" % parts[2:])
            # Empty bins do not become patches.
            for patch in entities:
                if patch:
                    patches.append(PETSc.IS().createGeneral(patch, comm=COMM_SELF))

        iterationSet = PETSc.IS().createStride(size=len(patches), first=0, step=1, comm=COMM_SELF)
        return (patches, iterationSet)
class PatchBase(PCSNESBase):
    """Shared implementation of the patch-based PC and SNES wrappers.

    ``initialize`` compiles the Jacobian (and, for a SNES, the residual)
    kernels of the problem attached to the DM, wraps them as C callbacks
    and hands them to an inner PETSc object of type ``"patch"``.

    Subclasses provide ``_petsc_prefix`` and ``configure_patch``; the
    ``form`` method and ``_objectname`` attribute used here are not
    defined in this class either.
    """

    def initialize(self, obj):
        """Create and set up the inner patch object for ``obj``.

        :arg obj: the outer PETSc PC or SNES.
        :raises ValueError: if no suitable context is attached to the DM.
        :raises NotImplementedError: for general mixed meshes and for
            extruded meshes.
        """
        ctx = get_appctx(obj.getDM())
        if ctx is None:
            raise ValueError("No context found on form")
        if not isinstance(ctx, _SNESContext):
            raise ValueError("Don't know how to get form from %r" % ctx)
        J, bcs = self.form(obj)
        V = J.arguments()[0].function_space()
        mesh = V.mesh()
        if len(set(mesh)) == 1:
            mesh_unique = mesh.unique()
        else:
            raise NotImplementedError("Not implemented for general mixed meshes")
        self.plex = mesh_unique.topology_dm
        # We need to attach the mesh and appctx to the plex, so that
        # PlaneSmoothers (and any other user-customised patch
        # constructors) can use firedrake's opinion of what
        # the coordinates are, rather than plex's.
        self.plex.setAttr("__firedrake_mesh__", weakref.proxy(mesh))
        self.ctx = ctx
        self.plex.setAttr("__firedrake_ctx__", weakref.proxy(ctx))
        if mesh_unique.cell_set._extruded:
            raise NotImplementedError("Not implemented on extruded meshes")
        # Validate the mesh overlap
        prefix = (obj.getOptionsPrefix() or "") + "patch_"
        opts = PETSc.Options(prefix)
        petsc_prefix = self._petsc_prefix
        patch_type = opts.getString(f"{petsc_prefix}construct_type")
        patch_dim = opts.getInt(f"{petsc_prefix}construct_dim", -1)
        patch_codim = opts.getInt(f"{petsc_prefix}construct_codim", -1)
        if patch_dim != -1:
            assert patch_codim == -1, "Cannot set both dim and codim"
        elif patch_codim != -1:
            assert patch_dim == -1, "Cannot set both dim and codim"
            patch_dim = self.plex.getDimension() - patch_codim
        else:
            patch_dim = 0
        validate_overlap(mesh_unique, patch_dim, patch_type)
        # The inner object has the same PETSc type (PC or SNES) as ``obj``.
        patch = obj.__class__().create(comm=mesh.comm)
        patch.setOptionsPrefix(prefix)
        self.configure_patch(patch, obj)
        patch.setType("patch")
        if isinstance(obj, PETSc.SNES):
            Jstate = ctx._problem.u
            is_snes = True
        else:
            Jstate = None
            is_snes = False
        # Boundary dofs, once including ghost dofs and once restricted to
        # owned dofs.
        if len(bcs) > 0:
            ghost_bc_nodes = numpy.unique(
                numpy.concatenate([bcdofs(bc, ghost=True) for bc in bcs],
                                  dtype=PETSc.IntType)
            )
            global_bc_nodes = numpy.unique(
                numpy.concatenate([bcdofs(bc, ghost=False) for bc in bcs],
                                  dtype=PETSc.IntType))
        else:
            ghost_bc_nodes = numpy.empty(0, dtype=PETSc.IntType)
            global_bc_nodes = numpy.empty(0, dtype=PETSc.IntType)

        # --- Jacobian callbacks (cell, interior facet, exterior facet) ---
        Jcell_kernels, Jint_facet_kernels, Jext_facet_kernels = matrix_funptr(J, Jstate)
        Jhas_cell_kernel = len(Jcell_kernels) > 0
        if Jhas_cell_kernel:
            Jcell_kernel, = Jcell_kernels
            Jcell_flops = Jcell_kernel.kinfo.kernel.num_flops
            Jop_data_args, Jop_map_args = make_c_arguments(J, Jcell_kernel, Jstate, "cell")
            code, Struct = make_jacobian_wrapper(Jop_data_args, Jop_map_args, Jcell_flops)
            Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm)
            Jop_struct = make_c_struct(Jop_data_args, Jop_map_args, Jcell_kernel.funptr, Struct)
        Jhas_int_facet_kernel = False
        if len(Jint_facet_kernels) > 0:
            Jint_facet_kernel, = Jint_facet_kernels
            Jhas_int_facet_kernel = True
            Jint_facet_flops = Jint_facet_kernel.kinfo.kernel.num_flops
            facet_Jop_data_args, facet_Jop_map_args = make_c_arguments(J, Jint_facet_kernel, Jstate,
                                                                       "interior_facet",
                                                                       require_facet_number=True)
            code, Struct = make_jacobian_wrapper(facet_Jop_data_args, facet_Jop_map_args, Jint_facet_flops)
            facet_Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm)
            point2facet = mesh_unique.interior_facets.point2facetnumber.ctypes.data
            facet_Jop_struct = make_c_struct(facet_Jop_data_args, facet_Jop_map_args,
                                             Jint_facet_kernel.funptr, Struct,
                                             point2facet=point2facet)
        Jhas_ext_facet_kernel = False
        if len(Jext_facet_kernels) > 0:
            Jext_facet_kernel, = Jext_facet_kernels
            Jhas_ext_facet_kernel = True
            Jext_facet_flops = Jext_facet_kernel.kinfo.kernel.num_flops
            ext_facet_Jop_data_args, ext_facet_Jop_map_args = make_c_arguments(J, Jext_facet_kernel, Jstate,
                                                                               "exterior_facet",
                                                                               require_facet_number=True)
            code, Struct = make_jacobian_wrapper(ext_facet_Jop_data_args, ext_facet_Jop_map_args, Jext_facet_flops)
            ext_facet_Jop_function = load_c_function(code, "ComputeJacobian", mesh.comm)
            ext_point2facet = mesh_unique.exterior_facets.point2facetnumber.ctypes.data
            ext_facet_Jop_struct = make_c_struct(ext_facet_Jop_data_args, ext_facet_Jop_map_args,
                                                 Jext_facet_kernel.funptr, Struct,
                                                 point2facet=ext_point2facet)

        # --- Residual callbacks, only needed for a nonlinear solver ---
        set_residual = hasattr(ctx, "F") and isinstance(obj, PETSc.SNES)
        if set_residual:
            F = ctx.F
            Fstate = ctx._problem.u
            Fcell_kernels, Fint_facet_kernels, Fext_facet_kernels = residual_funptr(F, Fstate)
            Fhas_cell_kernel = len(Fcell_kernels) > 0
            if Fhas_cell_kernel:
                Fcell_kernel, = Fcell_kernels
                Fcell_flops = Fcell_kernel.kinfo.kernel.num_flops
                Fop_data_args, Fop_map_args = make_c_arguments(F, Fcell_kernel, Fstate,
                                                               "cell",
                                                               require_state=True)
                code, Struct = make_residual_wrapper(Fop_data_args, Fop_map_args, Fcell_flops)
                Fop_function = load_c_function(code, "ComputeResidual", mesh.comm)
                Fop_struct = make_c_struct(Fop_data_args, Fop_map_args, Fcell_kernel.funptr, Struct)
            Fhas_int_facet_kernel = False
            if len(Fint_facet_kernels) > 0:
                Fint_facet_kernel, = Fint_facet_kernels
                Fhas_int_facet_kernel = True
                Fint_facet_flops = Fint_facet_kernel.kinfo.kernel.num_flops
                facet_Fop_data_args, facet_Fop_map_args = make_c_arguments(F, Fint_facet_kernel, Fstate,
                                                                           "interior_facet",
                                                                           require_state=True,
                                                                           require_facet_number=True)
                # This is a residual callback, so it must be generated by
                # the residual wrapper: only that one defines the
                # "ComputeResidual" symbol loaded on the next line (and
                # takes a Vec, not a Mat, as output).
                code, Struct = make_residual_wrapper(facet_Fop_data_args, facet_Fop_map_args, Fint_facet_flops)
                facet_Fop_function = load_c_function(code, "ComputeResidual", mesh.comm)
                point2facet = extract_unique_domain(F).interior_facets.point2facetnumber.ctypes.data
                facet_Fop_struct = make_c_struct(facet_Fop_data_args, facet_Fop_map_args,
                                                 Fint_facet_kernel.funptr, Struct,
                                                 point2facet=point2facet)
            Fhas_ext_facet_kernel = False
            if len(Fext_facet_kernels) > 0:
                Fext_facet_kernel, = Fext_facet_kernels
                Fhas_ext_facet_kernel = True
                Fext_facet_flops = Fext_facet_kernel.kinfo.kernel.num_flops
                ext_facet_Fop_data_args, ext_facet_Fop_map_args = make_c_arguments(F, Fext_facet_kernel, Fstate,
                                                                                   "exterior_facet",
                                                                                   require_state=True,
                                                                                   require_facet_number=True)
                code, Struct = make_residual_wrapper(ext_facet_Fop_data_args, ext_facet_Fop_map_args, Fext_facet_flops)
                ext_facet_Fop_function = load_c_function(code, "ComputeResidual", mesh.comm)
                ext_point2facet = extract_unique_domain(F).exterior_facets.point2facetnumber.ctypes.data
                ext_facet_Fop_struct = make_c_struct(ext_facet_Fop_data_args, ext_facet_Fop_map_args,
                                                     Fext_facet_kernel.funptr, Struct,
                                                     point2facet=ext_point2facet)

        # --- Describe the discretisation to the patch object ---
        patch.setDM(self.plex)
        patch.setPatchCellNumbering(mesh_unique._cell_numbering)
        offsets = numpy.append([0], numpy.cumsum([W.dof_count
                                                  for W in V])).astype(PETSc.IntType)
        patch.setPatchDiscretisationInfo([W.dm for W in V],
                                         numpy.array([W.block_size for
                                                      W in V], dtype=PETSc.IntType),
                                         [W.cell_node_list for W in V],
                                         offsets,
                                         ghost_bc_nodes,
                                         global_bc_nodes)

        # --- Register the callbacks ---
        # Each context struct is passed to PETSc as a raw address, so a
        # reference is kept on ``self`` to stop Python from freeing it
        # while the patch object is alive.
        if Jhas_cell_kernel:
            self.Jop_struct = Jop_struct
            set_patch_jacobian(patch, ctypes.cast(Jop_function, ctypes.c_voidp).value,
                               ctypes.addressof(Jop_struct), is_snes=is_snes)
        if Jhas_int_facet_kernel:
            self.facet_Jop_struct = facet_Jop_struct
            set_patch_jacobian(patch, ctypes.cast(facet_Jop_function, ctypes.c_voidp).value,
                               ctypes.addressof(facet_Jop_struct), is_snes=is_snes,
                               interior_facets=True)
        if Jhas_ext_facet_kernel:
            self.ext_facet_Jop_struct = ext_facet_Jop_struct
            set_patch_jacobian(patch, ctypes.cast(ext_facet_Jop_function, ctypes.c_voidp).value,
                               ctypes.addressof(ext_facet_Jop_struct), is_snes=is_snes,
                               exterior_facets=True)
        if set_residual:
            if Fhas_cell_kernel:
                self.Fop_struct = Fop_struct
                set_patch_residual(patch, ctypes.cast(Fop_function, ctypes.c_voidp).value,
                                   ctypes.addressof(Fop_struct), is_snes=is_snes)
            if Fhas_int_facet_kernel:
                # Keep the struct alive, like every other callback context.
                self.facet_Fop_struct = facet_Fop_struct
                set_patch_residual(patch, ctypes.cast(facet_Fop_function, ctypes.c_voidp).value,
                                   ctypes.addressof(facet_Fop_struct), is_snes=is_snes,
                                   interior_facets=True)
            if Fhas_ext_facet_kernel:
                self.ext_facet_Fop_struct = ext_facet_Fop_struct
                set_patch_residual(patch, ctypes.cast(ext_facet_Fop_function, ctypes.c_voidp).value,
                                   ctypes.addressof(ext_facet_Fop_struct), is_snes=is_snes,
                                   exterior_facets=True)
        patch.setPatchConstructType(PETSc.PC.PatchConstructType.PYTHON, operator=self.user_construction_op)
        patch.setAttr("ctx", ctx)
        patch.incrementTabLevel(1, parent=obj)
        patch.setFromOptions()
        patch.setUp()
        self.patch = patch

    def destroy(self, obj):
        """Remove the attributes set in ``initialize`` and destroy the patch."""
        # In this destructor we clean up the __firedrake_mesh__ we set
        # on the plex and the context we set on the patch object.
        # We have to check if these attributes are available because
        # the destroy function will be called by petsc4py when
        # PCPythonSetContext is called (which occurs before
        # initialize).
        if hasattr(self, "plex"):
            d = self.plex.getDict()
            try:
                del d["__firedrake_mesh__"]
            except KeyError:
                pass
        if hasattr(self, "patch"):
            try:
                del self.patch.getDict()["ctx"]
            except KeyError:
                pass
            self.patch.destroy()

    def user_construction_op(self, obj, *args, **kwargs):
        """Call the user-supplied patch construction routine.

        The routine is named by the ``<objectname>_patch_construct_python_type``
        option as ``module.attribute``.  If the attribute is a class it
        is instantiated first; the result is then called with ``obj``
        and the remaining arguments.

        :raises ValueError: if the option is not set.
        """
        prefix = obj.getOptionsPrefix() or ""
        sentinel = object()
        usercode = PETSc.Options(prefix).getString("%s_patch_construct_python_type" % self._objectname, default=sentinel)
        if usercode == sentinel:
            raise ValueError("Must set %s%s_patch_construct_python_type" % (prefix, self._objectname))
        (modname, funname) = usercode.rsplit('.', 1)
        # NOTE(review): for a dotted ``modname``, ``__import__`` returns
        # the top-level package, so ``funname`` is looked up there rather
        # than on the submodule -- confirm this is intended.
        mod = __import__(modname)
        fun = getattr(mod, funname)
        if isinstance(fun, type):
            fun = fun()
        return fun(obj, *args, **kwargs)

    def update(self, pc):
        """Set the inner patch object up again."""
        self.patch.setUp()

    def view(self, pc, viewer=None):
        """View the inner patch object with ``viewer``."""
        self.patch.view(viewer=viewer)
[docs]
class PatchPC(PCBase, PatchBase):
    """Preconditioner that delegates to the inner patch object."""

    _petsc_prefix = "pc_patch_"

    def apply(self, pc, x, y):
        """Apply the inner patch object to ``x``, writing into ``y``."""
        inner = self.patch
        inner.apply(x, y)

    def applyTranspose(self, pc, x, y):
        """Apply the transpose of the inner patch object to ``x``, writing into ``y``."""
        inner = self.patch
        inner.applyTranspose(x, y)
[docs]
class PatchSNES(SNESBase, PatchBase):
    """Nonlinear solver that delegates each step to the inner patch object."""

    _petsc_prefix = "snes_patch_"

    def step(self, snes, x, f, y):
        """Compute one update ``y`` from the current iterate ``x``.

        The inner patch solve is run starting from a copy of ``x``;
        ``y`` is then set to ``x`` minus the solve's result.

        :arg snes: the outer SNES; its ``vec_rhs`` is used as the right
            hand side when set.
        :arg x: the current iterate (not modified).
        :arg f: unused.
        :arg y: on exit, the update.
        """
        push_appctx(self.plex, self.ctx)
        try:
            x.copy(y)
            # NOTE(review): ``self.dummy`` is not assigned anywhere in the
            # code visible here; presumably ``configure_patch`` sets it --
            # verify.
            self.patch.solve(snes.vec_rhs or self.dummy, y)
            # y <- -(y - x) = x - y
            y.axpy(-1, x)
            y.scale(-1)
            snes.setConvergedReason(self.patch.getConvergedReason())
        finally:
            # Always undo the push, even if the solve raised, so the
            # context stack on the plex stays balanced.
            pop_appctx(self.plex)