Source code for firedrake.embedding

# -*- coding: utf-8 -*-
"""Module for utility functions for scalable HDF5 I/O."""
import finat.ufl
import ufl

def get_embedding_dg_element(element):
    """Build the discontinuous element into which ``element`` is embedded.

    The result is defined on the same cell as ``element``, uses the degree
    reported by ``element.degree()``, and has the same value shape as
    ``element``.

    :arg element: the UFL element to embed; it must provide ``cell``,
        ``degree()`` and ``value_shape``.
    :returns: for a scalar-valued ``element``, a "DG"/"DQ"
        :class:`finat.ufl.FiniteElement` (or a
        :class:`finat.ufl.TensorProductElement` of such elements when the
        degree on a tensor-product cell is given per sub-cell).  For a
        rank-1 value shape the scalar element is wrapped in a
        :class:`finat.ufl.VectorElement`; for higher-rank value shapes it
        is wrapped in a :class:`finat.ufl.TensorElement`.
    """
    def family(c):
        # Simplices use the "DG" family; every other cell uses "DQ".
        return "DG" if c.is_simplex() else "DQ"

    cell = element.cell
    # NOTE: ``degree`` must be the element's degree, not an alias of
    # ``family``: it is compared against ``int`` and zipped with the
    # sub-cells below.
    degree = element.degree()

    if isinstance(cell, ufl.TensorProductCell):
        if isinstance(degree, int):
            # A single degree applies to the whole tensor-product cell.
            scalar_element = finat.ufl.FiniteElement("DQ", cell=cell, degree=degree)
        else:
            # One degree per sub-cell: build the matching tensor-product
            # of discontinuous elements.
            scalar_element = finat.ufl.TensorProductElement(
                *(finat.ufl.FiniteElement(family(c), cell=c, degree=d)
                  for (c, d) in zip(cell.sub_cells(), degree))
            )
    else:
        scalar_element = finat.ufl.FiniteElement(family(cell), cell=cell, degree=degree)

    shape = element.value_shape
    if len(shape) == 0:
        DG = scalar_element
    elif len(shape) == 1:
        shape, = shape
        DG = finat.ufl.VectorElement(scalar_element, dim=shape)
    else:
        # Only a TensorElement carries symmetry information to preserve.
        if isinstance(element, finat.ufl.TensorElement):
            symmetry = element.symmetry()
        else:
            symmetry = None
        DG = finat.ufl.TensorElement(scalar_element, shape=shape, symmetry=symmetry)
    return DG
# Element families that are returned unchanged by
# ``get_embedding_element_for_checkpointing`` (no DG embedding is built).
native_elements_for_checkpointing = {
    "Lagrange",
    "Discontinuous Lagrange",
    "Q",
    "DQ",
    "Real",
}
def get_embedding_element_for_checkpointing(element):
    """Convert the given UFL element to an element that :class:`~.CheckpointFile` can handle.

    :arg element: the UFL element to convert; it must provide ``family()``.
    :returns: ``element`` itself when its family is listed in
        ``native_elements_for_checkpointing``; otherwise the element
        returned by :func:`get_embedding_dg_element` for ``element``.
    """
    if element.family() in native_elements_for_checkpointing:
        return element
    return get_embedding_dg_element(element)
def get_embedding_method_for_checkpointing(element):
    """Return the method used to embed element in dg space.

    The checks are applied in this order, and the first match wins:

    1. H(div)/H(curl)/``WithMapping`` wrappers -> ``"project"``.
    2. Vector/tensor elements -> the method of their single distinct
       sub-element.
    3. Elements whose family is in the list below -> ``"interpolate"``.
    4. Tensor-product elements -> ``"project"`` if any factor needs
       projection, else ``"interpolate"``.
    5. Anything else -> ``"project"``.

    :arg element: the UFL element to classify.
    :returns: either ``"project"`` or ``"interpolate"``.
    :raises ValueError: if a vector/tensor element does not have exactly
        one distinct sub-element (the single-target unpacking fails).
    """
    if isinstance(element, (finat.ufl.HDivElement, finat.ufl.HCurlElement, finat.ufl.WithMapping)):
        return "project"

    if isinstance(element, (finat.ufl.VectorElement, finat.ufl.TensorElement)):
        # Unpacking a one-element set asserts that all sub-elements are equal.
        elem, = set(element.sub_elements)
        return get_embedding_method_for_checkpointing(elem)

    if element.family() in ['Lagrange', 'Discontinuous Lagrange',
                            'Nedelec 1st kind H(curl)', 'Raviart-Thomas',
                            'Nedelec 2nd kind H(curl)', 'Brezzi-Douglas-Marini',
                            'Q', 'DQ',
                            'S', 'DPC', 'Real']:
        return "interpolate"

    if isinstance(element, finat.ufl.TensorProductElement):
        # Classify every factor first, then combine the results.
        methods = [get_embedding_method_for_checkpointing(elem) for elem in element.sub_elements]
        return "project" if "project" in methods else "interpolate"

    return "project"