firedrake.ml.jax package
Submodules
firedrake.ml.jax.fem_operator module
- class firedrake.ml.jax.fem_operator.FiredrakeJaxOperator(F: ReducedFunctional)
Bases: object
JAX custom operator representing a set of Firedrake operations expressed as a reduced functional F.
FiredrakeJaxOperator executes forward and backward passes by directly calling the reduced functional F.
- Parameters:
F – The reduced functional to wrap.
- forward = None
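The forward/backward delegation described above can be illustrated without Firedrake or JAX. In the sketch below, `ToyReducedFunctional` and `ToyFemOperator` are hypothetical stand-ins invented for illustration. They only mimic the general shape of the idea (call the functional to evaluate it, ask it for a derivative to back-propagate) and are not the library's implementation.

```python
class ToyReducedFunctional:
    """Hypothetical stand-in for a reduced functional J(m) = sum(m_i^2)."""

    def __call__(self, m):
        # Forward evaluation at the control values m; keep them for derivative().
        self._m = list(m)
        return sum(v * v for v in self._m)

    def derivative(self, adj_input=1.0):
        # Gradient of J at the last evaluation point, scaled by the incoming adjoint.
        return [adj_input * 2.0 * v for v in self._m]


class ToyFemOperator:
    """Hypothetical wrapper: both passes call the reduced functional directly."""

    def __init__(self, F):
        self.F = F

    def forward(self, x):
        # Forward pass: evaluate the reduced functional.
        return self.F(x)

    def backward(self, x, grad_output):
        # Backward pass: re-evaluate at x, then ask the functional for its derivative.
        self.F(x)
        return self.F.derivative(adj_input=grad_output)


op = ToyFemOperator(ToyReducedFunctional())
value = op.forward([1.0, 2.0, 3.0])
grad = op.backward([1.0, 2.0, 3.0], grad_output=1.0)
```

In JAX, a forward/backward pair of this kind is the sort of thing that `jax.custom_vjp` lets a user register for a function. Whether `FiredrakeJaxOperator` is wired up exactly that way is not stated here.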
- firedrake.ml.jax.fem_operator.fem_operator(F: ReducedFunctional) → FiredrakeJaxOperator
Cast a Firedrake reduced functional to a JAX operator.
The resulting FiredrakeJaxOperator takes JAX tensors as inputs and returns JAX tensors as outputs.
- Parameters:
F – The reduced functional to wrap.
- Returns:
A JAX custom operator that wraps the reduced functional F.
- Return type:
FiredrakeJaxOperator
- firedrake.ml.jax.fem_operator.from_jax(x: jax.Array, V: WithGeometry | None = None) → Function | Constant
Convert a JAX tensor x into a Firedrake object.
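- Parameters:
x – JAX tensor to convert.
V – Function space of the corresponding Function, or None when x is to be mapped to a Constant.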
- Returns:
Firedrake object representing the JAX tensor x.
- Return type:
Function | Constant
- firedrake.ml.jax.fem_operator.to_jax(x: Function | Constant, gather: bool | None = False, batched: bool | None = False, **kwargs) → jax.Array
Convert a Firedrake object x into a JAX tensor.
- Parameters:
x – Firedrake object to convert.
gather – If True, gather data from all processes.
batched – If True, add a batch dimension to the tensor.
kwargs – Additional arguments to be passed to the jax.Array constructor, such as:
- device: the device on which the tensor is allocated
- dtype: the desired data type of the returned tensor (default: type of x.dat.data)
- Returns:
JAX tensor representing the Firedrake object x.
- Return type:
jax.Array
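The three functions above could be combined roughly as follows. This is a sketch under assumptions, not a verified recipe: the PDE, the mesh size `n`, and the helper name `build_operator` are chosen only for illustration, and calling the returned operator as `G(x)` is assumed from the statement that it takes JAX tensors as inputs. The imports sit inside the function so the block can be loaded without Firedrake installed.

```python
def build_operator(n=8):
    """Hypothetical helper: wrap a small PDE-constrained functional as a JAX operator."""
    # Imported lazily; these require a working Firedrake installation with JAX.
    from firedrake import (
        UnitSquareMesh, FunctionSpace, Function, TestFunction,
        inner, grad, dx, solve, assemble,
    )
    from firedrake.adjoint import (
        continue_annotation, pause_annotation, Control, ReducedFunctional,
    )
    from firedrake.ml.jax.fem_operator import fem_operator, to_jax, from_jax

    mesh = UnitSquareMesh(n, n)
    V = FunctionSpace(mesh, "CG", 1)

    # Record the operations below on the adjoint tape.
    continue_annotation()
    f = Function(V).assign(1.0)   # control: the source term
    u = Function(V)
    v = TestFunction(V)
    F = (inner(grad(u), grad(v)) + inner(u, v) - inner(f, v)) * dx
    solve(F == 0, u)
    J = assemble(inner(u, u) * dx)
    pause_annotation()

    # Reduced functional f -> J, cast to a JAX operator.
    G = fem_operator(ReducedFunctional(J, Control(f)))

    x = to_jax(f)             # Firedrake Function -> jax.Array
    f_back = from_jax(x, V)   # jax.Array -> Firedrake Function on V
    return G, x, f_back
```

With Firedrake available, `G, x, _ = build_operator()` followed by `G(x)` would be the natural way to try it out, subject to the calling-convention assumption noted above.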
firedrake.ml.jax.ml_operator module
- class firedrake.ml.jax.ml_operator.JaxOperator(*operands: Expr | BaseForm, function_space: WithGeometryBase, derivatives: tuple | None = None, argument_slots: tuple[BaseCoefficient | BaseArgument] = (), operator_data: dict | None = {})
Bases: MLOperator
External operator class representing machine learning models implemented in JAX.
The JaxOperator allows users to embed machine learning models implemented in JAX into PDE systems implemented in Firedrake. The actual evaluation of the JaxOperator is delegated to the specified JAX model. Similarly, differentiation through the JaxOperator is achieved using JAX differentiation on the associated JAX model.
- Parameters:
*operands – Operands of the JaxOperator.
function_space – The function space the ML operator is mapping to.
derivatives – Tuple specifying the derivative multi-index.
argument_slots – Tuple containing the arguments of the linear form associated with the ML operator, i.e., the arguments with respect to which the ML operator is linear. These arguments can be ufl.argument.BaseArgument objects, as a result of differentiation, or both ufl.coefficient.BaseCoefficient and ufl.argument.BaseArgument objects, as a result of taking the action on a given function.
operator_data – Dictionary to stash external data specific to the ML operator. This dictionary must contain the following: (i) 'model': The machine learning model implemented in JAX. (ii) 'inputs_format': The format of the inputs to the ML model: 0 for models acting globally on the inputs, 1 for models acting locally/pointwise on the inputs. Other strategies can also be considered by subclassing the JaxOperator class.
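The difference between the two input formats can be pictured with plain Python lists. The sketch below is conceptual only: `global_model`, `pointwise_model`, and `apply_model` are hypothetical names, and how Firedrake actually arranges the data passed to the JAX model is not described here.

```python
import math


def global_model(values):
    # Acts on the whole vector at once: every output depends on all inputs.
    mean = sum(values) / len(values)
    return [v - mean for v in values]


def pointwise_model(value):
    # Acts on a single entry; it never sees the other entries.
    return math.tanh(value)


def apply_model(model, values, inputs_format=0):
    # Hypothetical dispatcher mirroring the two documented formats.
    if inputs_format == 0:
        return model(values)               # global: one call on all inputs
    if inputs_format == 1:
        return [model(v) for v in values]  # local/pointwise: one call per entry
    raise ValueError(f"unsupported inputs_format: {inputs_format}")


dofs = [0.0, 1.0, 2.0]
global_out = apply_model(global_model, dofs, inputs_format=0)
local_out = apply_model(pointwise_model, dofs, inputs_format=1)
```

A model that mixes information across entries, such as the mean subtraction here, only makes sense in the global format, whereas an entry-wise nonlinearity fits the pointwise one.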
- firedrake.ml.jax.ml_operator.ml_operator(model: Callable, function_space: WithGeometryBase, inputs_format: int | None = 0) → Callable
Helper function for instantiating the JaxOperator class.
This function facilitates a two-stage instantiation that separates the class arguments that are fixed, such as the function space or the ML model, from the operands of the operator, which may change, e.g. when the operator is used in a time loop.
Example
    # Stage 1: Partially initialise the operator.
    N = ml_operator(model, function_space=V)
    # Stage 2: Define the operands and use the operator in a UFL expression.
    F = (inner(grad(u), grad(v)) + inner(N(u), v) - inner(f, v)) * dx
- Parameters:
model – The JAX model to embed in Firedrake.
function_space – The function space into which the machine learning model is mapping.
inputs_format – The format of the input data of the ML model: 0 for models acting globally on the inputs, 1 for models acting locally/pointwise on the inputs. Other strategies can also be considered by subclassing the JaxOperator class.
- Returns:
The partially initialised JaxOperator class.
- Return type:
Callable
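The two-stage instantiation described for this helper can be mimicked in plain Python with `functools.partial`. The names `ToyMLOperator` and `toy_ml_operator` are hypothetical stand-ins, and using `partial` is an assumption made for illustration rather than a statement about how `ml_operator` is implemented.

```python
from functools import partial


class ToyMLOperator:
    """Hypothetical stand-in for an operator with fixed and changing arguments."""

    def __init__(self, *operands, function_space, operator_data):
        self.operands = operands
        self.function_space = function_space
        self.operator_data = operator_data


def toy_ml_operator(model, function_space, inputs_format=0):
    # Stage 1: bind the arguments that stay fixed for the operator's lifetime.
    operator_data = {"model": model, "inputs_format": inputs_format}
    return partial(
        ToyMLOperator,
        function_space=function_space,
        operator_data=operator_data,
    )


# Strings stand in for the function space and the operands.
N = toy_ml_operator(model=abs, function_space="V")

# Stage 2: supply the operands, which may differ on each call (e.g. per time step).
op_old = N("u_old")
op_new = N("u_new")
```

Both instances share the same model and function space, and only the operands differ, which is the separation the two stages are meant to provide.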
