Weak constraint 4DVar data assimilation

This tutorial was contributed by Josh Hope-Collins

Weak constraint 4DVar

Data assimilation is the process of using real-world observations to improve the accuracy of a simulation, and is commonly used in weather and climate modelling. A particular variant called “weak constraint” 4DVar (WC4DVar) allows the use of parallel-in-time solvers.

In data assimilation problems we want to find an approximation \(x_{j}\) of the true values \(x^{t}_{j}\) of a timeseries, \(0\leq j\leq N\), and we have the following available to us:

  1. Observation operators \(\mathcal{H}_{j}\), and incomplete and imperfect (noisy) observations of \(x^{t}_{j}\) at each time point, \(y_j=\mathcal{H}_{j}(x^{t}_{j}) + r_{j}\), where the noise is \(r_{j}\sim\mathcal{N}(0,R_{j})\) with covariance matrix \(R_{j}\).

  2. An imperfect PDE model \(\mathcal{M}_{j}\) that propagates from one value to the next, \(x^{t}_{j}=\mathcal{M}_{j}(x^{t}_{j-1})+q_{j}\), where the noise is \(q_{j}\sim\mathcal{N}(0,Q_{j})\) with covariance matrix \(Q_{j}\).

  3. A prior estimate of the initial condition, called the “background”, \(x_{b}=x^{t}_{0}+b\), where the noise is \(b\sim\mathcal{N}(0,B)\) with covariance matrix \(B\).

We want to find a timeseries that minimises the misfits with the background, the observations, and the propagator, which we formulate as finding the minimiser \(\mathbf{x}=(x_{0}, x_{1}, \dots, x_{N})\) of the following objective functional:

\[\min_{\mathbf{x}} \mathcal{J}(\mathbf{x}) = \|x_{0} - x_{b}\|_{B^{-1}}^{2} + \sum^{N}_{j=0}\|\mathcal{H}_{j}(x_{j}) - y_{j}\|_{R_{j}^{-1}}^{2} + \sum^{N}_{j=1}\|x_{j} - \mathcal{M}_{j}(x_{j-1})\|_{Q_{j}^{-1}}^{2}\]

The “weak constraint” is that we have allowed our PDE model \(\mathcal{M}\) to be imperfect, rather than requiring the entire timeseries to be a perfect trajectory of \(\mathcal{M}\) as is the standard in “strong constraint” 4DVar. This accounts for numerical and modelling errors in the PDE model, and also enables time-parallelism. Each model misfit term \(x_{j}-\mathcal{M}_{j}(x_{j-1})\) depends only on the two neighbouring values \(x_{j-1}\) and \(x_{j}\), so it can be evaluated independently of the model misfit terms earlier or later in the timeseries.
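As a concrete illustration, the objective can be evaluated directly for a small toy problem. The sketch below uses NumPy with dense covariance matrices and hypothetical linear callables `model` and `obs` standing in for \(\mathcal{M}_{j}\) and \(\mathcal{H}_{j}\); it is not how Firedrake or fdvar store or evaluate anything. Note that each model misfit term touches only `x[j - 1]` and `x[j]`.

```python
import numpy as np

def weighted_sq_norm(v, cov):
    """Return v^T cov^{-1} v (the squared cov^{-1}-norm) without forming the inverse."""
    return float(v @ np.linalg.solve(cov, v))

def wc4dvar_objective(x, xb, y, model, obs, B, Q, R):
    """Evaluate the weak constraint objective for x = [x_0, ..., x_N].

    model(j, x_prev) and obs(j, x_j) stand in for M_j and H_j (hypothetical
    callables); B, Q, R are dense covariance matrices, here shared by all j.
    """
    J = weighted_sq_norm(x[0] - xb, B)                       # background misfit
    for j in range(len(x)):                                  # observation misfits, j = 0..N
        J += weighted_sq_norm(obs(j, x[j]) - y[j], R)
    for j in range(1, len(x)):                               # model misfits, j = 1..N:
        J += weighted_sq_norm(x[j] - model(j, x[j - 1]), Q)  # only x_{j-1} and x_j needed
    return J

# Toy linear problem: 3 spatial points, N = 4 stages, observing the first two points.
n, N = 3, 4
Mmat = 0.9 * np.eye(n)
Hmat = np.eye(n)[:2]

def model(j, x_prev):
    return Mmat @ x_prev

def obs(j, x_j):
    return Hmat @ x_j

xb = np.ones(n)
traj = [xb]
for j in range(1, N + 1):
    traj.append(model(j, traj[-1]))
y = [obs(j, xj) for j, xj in enumerate(traj)]

B, Q, R = np.eye(n), 0.1 * np.eye(n), 0.01 * np.eye(2)
print(wc4dvar_objective(traj, xb, y, model, obs, B, Q, R))  # 0.0: exact, noise-free data
```

Perturbing any single \(x_{j}\) makes the objective positive, and only the terms that involve that \(x_{j}\) change.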

The minimisation is solved using a Gauss-Newton method, where at each iteration \(k\) the increment \(\delta x^{k} = x^{k+1} - x^{k}\) is calculated by minimising the linearised objective functional \(J(\mathbf{\delta x})\) (often called the “incremental formulation” in 4DVar literature).

\[\min_{\mathbf{\delta x}} J(\mathbf{\delta x}) = \|\delta x_{0} - b_{0}\|_{B^{-1}}^{2} + \sum^{N}_{j=0}\|d_{j} - H_{j}\delta x_{j}\|_{R_{j}^{-1}}^{2} + \sum^{N}_{j=1}\|\delta x_{j} - M_{j}\delta x_{j-1} - c_{j}\|_{Q_{j}^{-1}}^{2}\]

where \(H\) and \(M\) are linearisations of \(\mathcal{H}\) and \(\mathcal{M}\) respectively about the current iterate \(\mathbf{x}^{k}\), and the “misfits” are defined as

\[b_{0} = x_{b} - x^{k}_{0}, \quad d_{j} = y_{j} - \mathcal{H}_{j}(x^{k}_{j}), \quad c_{j} = \mathcal{M}_{j}(x^{k}_{j-1}) - x^{k}_{j}.\]

This is a linear least squares problem which can be written in terms of the Hessian matrix \(\mathbf{S}\) of \(J\).

\[\mathbf{S}\delta\mathbf{x} = (\mathbf{L}^{T}\mathbf{D}^{-1}\mathbf{L} + \mathbf{H}^{T}\mathbf{R}^{-1}\mathbf{H})\mathbf{\delta x} = \mathbf{L}^{T}\mathbf{D}^{-1}\mathbf{b} + \mathbf{H}^{T}\mathbf{R}^{-1}\mathbf{d}\]

with \(\mathbf{b}=(b_{0}, c_{1}, c_{2}, \dots, c_{N})^{T}\) and \(\mathbf{d}=(d_{0}, d_{1}, d_{2}, \dots, d_{N})^{T}\). The matrices in the Hessian are constructed from the linearised propagator and observation operators

\[\begin{split}\mathbf{L} = \begin{pmatrix} I & & & & \\ -M_{1} & I & & & \\ & -M_{2} & I & & \\ & & \ddots & \ddots & \\ & & & -M_{N} & I \\ \end{pmatrix}, \quad \mathbf{H} = \begin{pmatrix} H_{0} & & & & \\ & H_{1} & & & \\ & & H_{2} & & \\ & & & \ddots & \\ & & & & H_{N} \\ \end{pmatrix},\end{split}\]

and the covariance operators for the background, model, and observation errors

\[\begin{split}\mathbf{D} = \begin{pmatrix} B & & & & \\ & Q_{1} & & & \\ & & Q_{2} & & \\ & & & \ddots & \\ & & & & Q_{N} \\ \end{pmatrix}, \quad \mathbf{R} = \begin{pmatrix} R_{0} & & & & \\ & R_{1} & & & \\ & & R_{2} & & \\ & & & \ddots & \\ & & & & R_{N} \\ \end{pmatrix}.\end{split}\]

The observation operator matrix \(\textbf{H}\) and the covariance matrices \(\textbf{D}\) and \(\textbf{R}\) are all block diagonal, with one block per timestep, so both the action and the inverse of each can be applied parallel-in-time. The model integration matrix \(\textbf{L}\) is block lower bidiagonal, so its action can be applied parallel-in-time with a communication overhead for time-halos that is independent of the number of timesteps.
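To check the algebra above, the sketch below assembles the block matrices for a tiny problem with dense NumPy arrays (random toy blocks, assumed purely for illustration), solves the primal normal equations, and confirms that the result is the minimiser of the linearised functional by comparing against a whitened least squares solve. Real WC4DVar systems are never assembled like this; the point is only the structure.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, N = 3, 2, 4                        # state size, observation size, number of stages

def spd(k):
    """Random symmetric positive definite matrix, used as a toy covariance."""
    G = rng.standard_normal((k, k))
    return G @ G.T + k * np.eye(k)

def block_diag(blocks):
    """Dense block diagonal matrix from a list of 2D arrays."""
    out = np.zeros((sum(blk.shape[0] for blk in blocks),
                    sum(blk.shape[1] for blk in blocks)))
    r = c = 0
    for blk in blocks:
        out[r:r + blk.shape[0], c:c + blk.shape[1]] = blk
        r, c = r + blk.shape[0], c + blk.shape[1]
    return out

Ms = [None] + [0.5 * rng.standard_normal((n, n)) for _ in range(N)]   # M_1..M_N
H = block_diag([rng.standard_normal((m, n)) for _ in range(N + 1)])   # H_0..H_N
D = block_diag([spd(n) for _ in range(N + 1)])                        # B, Q_1..Q_N
R = block_diag([spd(m) for _ in range(N + 1)])                        # R_0..R_N

L = np.eye(n * (N + 1))                  # block lower bidiagonal: I on the diagonal,
for j in range(1, N + 1):                # -M_j on the sub-diagonal
    L[j*n:(j+1)*n, (j-1)*n:j*n] = -Ms[j]

b = rng.standard_normal(n * (N + 1))     # (b_0, c_1, ..., c_N)
d = rng.standard_normal(m * (N + 1))     # (d_0, ..., d_N)

# Primal system S dx = L^T D^{-1} b + H^T R^{-1} d.
S = L.T @ np.linalg.solve(D, L) + H.T @ np.linalg.solve(R, H)
rhs = L.T @ np.linalg.solve(D, b) + H.T @ np.linalg.solve(R, d)
dx = np.linalg.solve(S, rhs)

# Same minimiser from the whitened least squares form of J: if C = G G^T (Cholesky),
# then v^T C^{-1} v equals ||G^{-1} v||^2.
WD = np.linalg.inv(np.linalg.cholesky(D))
WR = np.linalg.inv(np.linalg.cholesky(R))
dx_ls = np.linalg.lstsq(np.vstack([WD @ L, WR @ H]),
                        np.concatenate([WD @ b, WR @ d]), rcond=None)[0]
print(np.allclose(dx, dx_ls))  # True
```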

Assuming that \(\mathbf{S}\) is dominated by the first term, which describes the model integration, we can construct a preconditioner \(\mathbf{\tilde{S}}\) using an approximate model integration operator \(\mathbf{\tilde{L}}\approx\mathbf{L}\) (and, if desired, an approximate covariance \(\mathbf{\tilde{D}}\approx\mathbf{D}\)):

\[\mathbf{\tilde{S}} = \mathbf{\tilde{L}}^{T}\mathbf{\tilde{D}}^{-1}\mathbf{\tilde{L}} \approx \mathbf{S}\]

Unfortunately, using the exact integration operator \(\mathbf{\tilde{L}}=\mathbf{L}\) is impractical because the inverse is a block dense lower triangular matrix:

\[\begin{split}\textbf{L}^{-1} = \begin{pmatrix} I & & & & \\ M_{1,1} & I & & & \\ M_{1,2} & M_{2,2} & I & & \\ \vdots & \vdots & \ddots & \ddots & \\ M_{1,N} & M_{2,N} & \dots & M_{N,N} & I \\ \end{pmatrix}\end{split}\]

where \(M_{i,j}=M_{j}M_{j-1}\cdots M_{i}\) is the integration from timestep \(i-1\) to timestep \(j\) (so \(M_{j,j}=M_{j}\)). Clearly, the inverse \(\mathbf{L}^{-1}\) cannot be applied parallel-in-time because each timestep depends on all previous steps, which motivates the need for cheaper approximations \(\mathbf{\tilde{L}}\). For example, we could build \(\mathbf{\tilde{L}}\) using an approximation \(\tilde{M}\approx M\). A very simple approximation sometimes used in WC4DVar is \(\tilde{M}=I\), in which case \(\mathbf{\tilde{L}}^{-1}\) must still be applied sequentially, but this is very cheap because \(\tilde{M}_{i,j}=I\;\forall i,j\).
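The block structure of \(\mathbf{L}^{-1}\) can be checked numerically. The NumPy sketch below (toy random blocks, assumed for illustration) inverts a small \(\mathbf{L}\) and confirms that block \((i,j)\) below the diagonal is the product \(M_{i}M_{i-1}\cdots M_{j+1}\), i.e. the propagation from timestep \(j\) to timestep \(i\). With \(\tilde{M}=I\) every sub-diagonal block of \(\mathbf{\tilde{L}}^{-1}\) is the identity, so applying it reduces to a running (cumulative) sum over timesteps.

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(1)
n, N = 2, 4
Ms = [None] + [rng.standard_normal((n, n)) for _ in range(N)]   # M_1..M_N

def build_L(Ms, n, N):
    """Block lower bidiagonal L with I on the diagonal and -M_j below it."""
    L = np.eye(n * (N + 1))
    for j in range(1, N + 1):
        L[j*n:(j+1)*n, (j-1)*n:j*n] = -Ms[j]
    return L

def block(A, i, j, n):
    """Block (i, j) of a matrix made of n x n blocks."""
    return A[i*n:(i+1)*n, j*n:(j+1)*n]

Linv = np.linalg.inv(build_L(Ms, n, N))

# Block (i, j), i > j, of L^{-1} is M_i M_{i-1} ... M_{j+1} (M_{j+1} applied first).
for i in range(N + 1):
    for j in range(i):
        prop = reduce(np.matmul, [Ms[k] for k in range(i, j, -1)])
        assert np.allclose(block(Linv, i, j, n), prop)

# With M~ = I, applying L~^{-1} is just a running sum over timesteps.
Ltilde = build_L([None] + [np.eye(n)] * N, n, N)
r = rng.standard_normal(n * (N + 1))
running_sum = np.cumsum(r.reshape(N + 1, n), axis=0).ravel()
print(np.allclose(np.linalg.solve(Ltilde, r), running_sum))  # True
```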

Finding \(\mathbf{\delta x}\) by solving the linear system with the Hessian \(\mathbf{S}\) is referred to as the “primal” formulation. An alternative is the “saddle point” formulation, which finds \(\mathbf{\delta x}\) at each Gauss-Newton iteration by solving a linear system with the saddle point matrix \(\mathbf{A}\) arising from the KKT conditions:

\[\begin{split}\mathbf{A}\delta\mathbf{w} = \begin{pmatrix} \mathbf{D} & \mathbf{0} & \mathbf{L} \\ \mathbf{0} & \mathbf{R} & \mathbf{H} \\ \mathbf{L}^{T} & \mathbf{H}^{T} & \mathbf{0} \\ \end{pmatrix} \begin{pmatrix} \delta\mathbf{\eta} \\ \delta\mathbf{\lambda} \\ \delta\mathbf{x} \\ \end{pmatrix} = \begin{pmatrix} \mathbf{b} \\ \mathbf{d} \\ \mathbf{0} \\ \end{pmatrix}\end{split}\]

where \(\delta\mathbf{\eta}\) and \(\delta\mathbf{\lambda}\) are Lagrange multipliers for the model and observation misfits respectively. The action of \(\mathbf{A}\) is also parallel-in-time.
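One way to convince oneself that the saddle point and primal formulations agree is to assemble both for a toy problem. The NumPy sketch below (random toy blocks, with the same \(M_{j}\) and \(H_{j}\) at every stage for brevity, all assumed for illustration) solves \(\mathbf{A}\delta\mathbf{w}\) and checks that its \(\delta\mathbf{x}\) component matches the primal solution. Note that \(\mathbf{A}\) contains \(\mathbf{D}\) and \(\mathbf{R}\) themselves, not their inverses.

```python
import numpy as np

rng = np.random.default_rng(2)
n, m, N = 3, 2, 3
nx, ny = n * (N + 1), m * (N + 1)

# Toy blocks: the same M_j and H_j at every stage.
M = 0.8 * rng.standard_normal((n, n))
H0 = rng.standard_normal((m, n))
L = np.eye(nx) - np.kron(np.eye(N + 1, k=-1), M)   # I on the diagonal, -M below it
H = np.kron(np.eye(N + 1), H0)

Bc = np.array([[2.0, 0.5, 0.0], [0.5, 2.0, 0.5], [0.0, 0.5, 2.0]])
D = np.kron(np.eye(N + 1), 0.1 * Bc)               # Q_j = 0.1 B for j >= 1
D[:n, :n] = Bc                                      # background block B
R = 0.05 * np.eye(ny)

b = rng.standard_normal(nx)
d = rng.standard_normal(ny)

# Saddle point matrix acting on (eta, lambda, x).
A = np.block([
    [D,                  np.zeros((nx, ny)), L],
    [np.zeros((ny, nx)), R,                  H],
    [L.T,                H.T,                np.zeros((nx, nx))],
])
w = np.linalg.solve(A, np.concatenate([b, d, np.zeros(nx)]))
dx_saddle = w[nx + ny:]

# Primal Hessian solve for comparison.
S = L.T @ np.linalg.solve(D, L) + H.T @ np.linalg.solve(R, H)
rhs = L.T @ np.linalg.solve(D, b) + H.T @ np.linalg.solve(R, d)
dx_primal = np.linalg.solve(S, rhs)
print(np.allclose(dx_saddle, dx_primal))  # True
```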

There is a wide range of research on preconditioning saddle point systems, often based on partial and/or approximate Schur LDU factorisations. Notice that, up to a sign, the primal Hessian \(\mathbf{S}\) is the Schur complement of the saddle point matrix \(\mathbf{A}\) after eliminating \(\delta\mathbf{\eta}\) and \(\delta\mathbf{\lambda}\) (the Schur complement itself is \(-\mathbf{S}\)). Using the same approximation \(\mathbf{\tilde{S}}\) as above we can construct several block preconditioners:

\[\begin{split}\mathbf{P}_{D} = \begin{pmatrix} \mathbf{D} & \mathbf{0} & \mathbf{0} \\ \mathbf{0} & \mathbf{R} & \mathbf{0} \\ \mathbf{0} & \mathbf{0} & \mathbf{\tilde{S}} \\ \end{pmatrix}, \quad \mathbf{P}_{U} = \begin{pmatrix} \mathbf{D} & \mathbf{0} & \mathbf{L} \\ \mathbf{0} & \mathbf{R} & \mathbf{H} \\ \mathbf{0} & \mathbf{0} & \mathbf{\tilde{S}} \\ \end{pmatrix}, \quad \mathbf{P}_{L} = \begin{pmatrix} \mathbf{D} & \mathbf{0} & \mathbf{0} \\ \mathbf{0} & \mathbf{R} & \mathbf{0} \\ \mathbf{L}^{T} & \mathbf{H}^{T} & \mathbf{\tilde{S}} \\ \end{pmatrix}\end{split}\]
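The appeal of \(\mathbf{P}_{U}\) can be seen on a toy problem. Writing \(\mathbf{A}_{00}=\mathrm{diag}(\mathbf{D},\mathbf{R})\) and \(\mathbf{A}_{01}=\begin{pmatrix}\mathbf{L}\\\mathbf{H}\end{pmatrix}\), the exact Schur complement of \(\mathbf{A}\) is \(-\mathbf{A}_{01}^{T}\mathbf{A}_{00}^{-1}\mathbf{A}_{01}=-\mathbf{S}\). If \(\mathbf{P}_{U}\) uses this exact block then \((\mathbf{P}_{U}^{-1}\mathbf{A}-I)^{2}=0\), so GMRES converges in at most two iterations; in practice \(\mathbf{\tilde{S}}\) is only an approximation, and its quality governs the iteration count. The NumPy sketch below checks the identity using random dense stand-ins for these blocks (an assumption for illustration).

```python
import numpy as np

rng = np.random.default_rng(3)
k, nx = 8, 4                        # size of the (eta, lambda) block and of the x block

# Toy 2x2 block saddle point matrix A = [[A00, A01], [A01^T, 0]].
G = rng.standard_normal((k, k))
A00 = G @ G.T + k * np.eye(k)       # SPD, playing the role of diag(D, R)
A01 = rng.standard_normal((k, nx))  # playing the role of (L; H), full column rank
A = np.block([[A00, A01], [A01.T, np.zeros((nx, nx))]])

# Exact Schur complement after eliminating the first block: -A01^T A00^{-1} A01 = -S.
S = A01.T @ np.linalg.solve(A00, A01)
P_U = np.block([[A00, A01], [np.zeros((nx, k)), -S]])

E = np.linalg.solve(P_U, A) - np.eye(k + nx)
print(np.linalg.norm(E @ E))        # close to zero (round-off only)
```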

In this demo we will solve the WC4DVar system for the advection-diffusion equation using the saddle point formulation preconditioned with \(\mathbf{P_{U}}\). We will go through how to set up and solve the WC4DVar system in Firedrake, with the following steps:

  1. Define the finite element model for the advection-diffusion equation \(\mathcal{M}\).

  2. Define the observation operator \(\mathcal{H}\).

  3. Define the error covariance operators \(B\), \(R\), and \(Q\).

  4. Generate synthetic “ground-truth” observation data for \(y_{j}\).

  5. Create a ReducedFunctional for \(\mathcal{J}\).

  6. Specify a solver configuration and calculate an optimised \(\mathbf{x}\).

Constructing the 4DVar system

First we import Firedrake, including everything from the adjoint module. As we will be generating some random noise, we set the random number generator seed to a fixed value. We will integrate forward in time with an implicit Runge-Kutta method using the excellent Irksome library (to install Irksome either install fdvar with the “demos” optional dependency using pip install fdvar[demos], or see the installation instructions on the Irksome website).

import os
import numpy as np
from firedrake import *
from firedrake.adjoint import *
from irksome import Dt, TimeStepper, GaussLegendre

np.random.seed(13)

We use the advection-diffusion equation in one spatial dimension \(z\), with a spatially varying advection velocity \(c(z)\), a time-dependent forcing term \(g(t)\), and periodic boundary conditions.

\[ \begin{align}\begin{aligned}\partial_{t}u + \vec{c}(z)\cdot\nabla u - \nu\nabla^{2}u = g(t) &\\t \in [0, T], \quad z \in \Omega = [0, 1) &\\u(0, t) = u(1, t) &\\c(z) = 1 + \overline{c}\cos(2\pi z)\end{aligned}\end{align} \]

The reference state \(\hat{u}\) that we will use to generate the “ground-truth” trajectory \(x^{t}\) is just a simple sinusoid.

\[\hat{u} = \overline{u}\sin(2\pi z)\]

For the time integration we use the implicit midpoint rule with the semi-discrete weak form:

\[\int_{\Omega}\left(\partial_{t}u\right)v\mathrm{d}x + \int_{\Omega}\left(\vec{c}\cdot\nabla u \right)v\mathrm{d}x + \int_{\Omega}\nu\nabla u\cdot\nabla v\mathrm{d}x - \int_{\Omega}gv\mathrm{d}x = 0, \quad \forall v \in V\]

where \(V\) is the function space for the solution.

First we create the mesh and function spaces. To enable time-parallelism we use Firedrake’s Ensemble, which splits COMM_WORLD into several ensemble members, with spatial parallelism within each ensemble member and time-parallelism between members. Here we specify just one MPI rank per ensemble member, and the number of ensemble members automatically adjusts to use all available ranks. The communicator ensemble.comm is used for the spatial parallelism, so is the one we use to construct the mesh. We create the CG1 function space for \(V\), and the space of real numbers to hold the time \(t\).

ensemble = Ensemble(COMM_WORLD, 1)
ensemble_rank = ensemble.ensemble_rank
ensemble_size = ensemble.ensemble_size

mesh = PeriodicUnitIntervalMesh(100, comm=ensemble.comm)

V = FunctionSpace(mesh, "CG", 1)
Vr = FunctionSpace(mesh, "R", 0)

The control \(\mathbf{x}\) is a timeseries distributed in time over the Ensemble, with each timestep \(x_{j}\) being a Firedrake Function. For this we use an EnsembleFunctionSpace which represents a mixed function space with each component living on a particular ensemble member. To initialise the EnsembleFunctionSpace we just need the Ensemble and a list of FunctionSpace for the local components. We split the number of observation stages N equally across the ensemble members, and include an extra component on the first member for the initial condition \(x_{0}\). The observations are taken at intervals of \(T_{\textrm{stage}}=n_{t}\Delta t\), where \(n_{t}\) is the number of timesteps between each observation.

N = 8
Tstage = 1e-1
nt = 3

if os.getenv("FIREDRAKE_CI") == "1":
    N = 4
    Tstage = 2e-2
    nt = 2

nlocal_stages = N//ensemble_size
nlocal_spaces = nlocal_stages + int(ensemble_rank == 0)

W = EnsembleFunctionSpace([V for _ in range(nlocal_spaces)], ensemble)
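The bookkeeping for splitting the timeseries across the Ensemble is easy to get wrong, so here is a plain-Python sketch of which global components \(x_{j}\) each ensemble member owns under the layout above (N/ensemble_size stages per member, plus \(x_{0}\) on member 0). The helper `local_indices` is hypothetical and not part of Firedrake.

```python
def local_indices(rank, size, N):
    """Global timestep indices j of the x_j held by ensemble member `rank`.

    Mirrors the demo's layout: N // size stages per member, plus the initial
    condition x_0 on member 0. Assumes N is divisible by size.
    """
    if N % size != 0:
        raise ValueError("N must be divisible by the ensemble size")
    nlocal_stages = N // size
    first = rank * nlocal_stages + 1          # first stage index owned by this member
    stages = list(range(first, first + nlocal_stages))
    return ([0] if rank == 0 else []) + stages

for rank in range(4):
    print(rank, local_indices(rank, 4, 8))
# 0 [0, 1, 2]
# 1 [3, 4]
# 2 [5, 6]
# 3 [7, 8]
```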

Defining the propagator

We construct the propagator \(\mathcal{M}\) for the advection-diffusion scheme, using Irksome to provide the time integrator. The forcing term \(g(t)\) is rather involved, but just ensures that there is some non-trivial variation in the solution and prevents it decaying to zero over long time periods due to the diffusion. We use Irksome’s Dt symbol to signify the time derivative term, and the one-stage GaussLegendre method, which is equivalent to the implicit midpoint rule.

dt = Function(Vr).assign(Tstage/nt)

t = Function(Vr).zero()
z, = SpatialCoordinate(mesh)

cbar = Constant(0.2)
c = Function(V).project(1 + cbar*cos(2*pi*z))

reynolds = 100
nu = Constant(1/reynolds)

u = Function(V)
v = TestFunction(V)

ubar = Constant(0.3)
reference_ic = Function(V).project(ubar*sin(2*pi*z))

g = (
    ubar*cos(2*pi*z)*(
        - sin(2*pi*(z + 0.1*sin(2*pi*t)))
        + ubar*cos(2*pi*t + 1)*sin(2*pi*(3*z - 2*t))
    )
)

F = (
    inner(Dt(u), v)*dx
    + inner(c, u.dx(0))*v*dx
    + inner(nu*grad(u), grad(v))*dx
    - inner(g, v)*dx(degree=4)
)

solver_parameters = {
    "snes_type": "ksponly",
    "ksp_type": "preonly",
    "pc_type": "lu",
    "ksp_reuse_preconditioner": None,
}

tableau = GaussLegendre(1)

stepper = TimeStepper(
    F, tableau, t, dt, u,
    solver_parameters=solver_parameters,
    options_prefix="irk")
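The equivalence between the one-stage Gauss-Legendre method used above and the implicit midpoint rule is easy to verify for a linear system \(\dot{u}=Au\). The one-stage Gauss-Legendre Butcher tableau has \(a_{11}=1/2\), \(b_{1}=1\), \(c_{1}=1/2\); the NumPy sketch below (a random toy matrix, not the finite element operator) computes one step both ways.

```python
import numpy as np

rng = np.random.default_rng(4)
A = rng.standard_normal((5, 5))       # toy linear operator, du/dt = A u
u0 = rng.standard_normal(5)
dt = 0.1
I = np.eye(5)

# One-stage Gauss-Legendre: the stage k solves k = A (u0 + dt/2 * k), then u1 = u0 + dt * k.
k = np.linalg.solve(I - 0.5 * dt * A, A @ u0)
u1_gl = u0 + dt * k

# Implicit midpoint rule: (u1 - u0)/dt = A (u0 + u1)/2.
u1_mid = np.linalg.solve(I - 0.5 * dt * A, (I + 0.5 * dt * A) @ u0)

print(np.allclose(u1_gl, u1_mid))  # True
```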

For convenience we make a Python function for the propagator \(\mathcal{M}(x)\).

def M(x):
    stepper.u0.assign(x)
    for _ in range(nt):
        stepper.stages.zero()
        stepper.advance()
        t.assign(t + dt)
    return stepper.u0.copy(deepcopy=True)

Defining the observation operator

Our observations will be point evaluations at a set of random locations in the domain, which are defined using a VertexOnlyMesh. The observation operator \(\mathcal{H}\) is then simply interpolation onto this mesh.

stations = np.random.random_sample((20, 1))
vom = VertexOnlyMesh(mesh, stations)
U = FunctionSpace(vom, "DG", 0)

def H(x):
    return assemble(interpolate(x, U))
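For CG1 (piecewise linear) functions, evaluating at a station is a linear combination of the two surrounding nodal values, so the linear observation operator is a sparse matrix with two nonzeros per row. The NumPy sketch below builds that matrix for a uniform periodic mesh of \([0,1)\) with 100 cells; it only illustrates what the interpolation onto the VertexOnlyMesh computes, and the helper name is hypothetical, not Firedrake's implementation.

```python
import numpy as np

def p1_point_evaluation_matrix(nodes_count, stations):
    """Matrix Hmat with (Hmat @ u)[s] = u_h(stations[s]) for a P1 function on a
    uniform periodic mesh of [0, 1) with nodes at i / nodes_count."""
    h = 1.0 / nodes_count
    Hmat = np.zeros((len(stations), nodes_count))
    for s, z in enumerate(stations):
        i = int(np.floor(z / h)) % nodes_count     # left node of the containing cell
        w = z / h - np.floor(z / h)                 # local coordinate in [0, 1)
        Hmat[s, i] += 1.0 - w
        Hmat[s, (i + 1) % nodes_count] += w         # periodic wrap-around
    return Hmat

rng = np.random.default_rng(13)
ncells = 100
stations = rng.random(20)
Hmat = p1_point_evaluation_matrix(ncells, stations)

nodes = np.arange(ncells) / ncells
u = np.sin(2 * np.pi * nodes)
print(np.max(np.abs(Hmat @ u - np.sin(2 * np.pi * stations))))  # small: O(h^2) interpolation error
```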

Defining the error covariance operators

We need to do three things with covariance operators: apply the action \(B\), apply the inverse \(B^{-1}\), and generate physically relevant noise. If \(w\sim\mathcal{N}(0,I)\) is a vector of white noise then \(B^{1/2}w=v\sim\mathcal{N}(0,B)\), i.e. \(B^{1/2}\) transforms uncorrelated noise to correlated noise with covariance \(B\).
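The claim that \(B^{1/2}w\sim\mathcal{N}(0,B)\) can be checked empirically. The sketch below uses a Cholesky factor as the square root (any \(G\) with \(GG^{T}=B\) works; this is a swap-in for illustration, not what AutoregressiveCovariance uses) and compares the sample covariance of many draws with a toy \(B\).

```python
import numpy as np

rng = np.random.default_rng(5)
B = np.array([[1.0, 0.6, 0.2],
              [0.6, 1.0, 0.6],
              [0.2, 0.6, 1.0]])          # toy covariance matrix
G = np.linalg.cholesky(B)                # G @ G.T == B, so G plays the role of B^{1/2}

nsamples = 200_000
w = rng.standard_normal((3, nsamples))   # white noise, w ~ N(0, I)
v = G @ w                                # correlated noise, v ~ N(0, B)

sample_cov = v @ v.T / nsamples
print(np.round(sample_cov, 2))           # close to B
```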

Firedrake provides an implementation of diffusion-based autoregressive covariance operators with the AutoregressiveCovariance class. The action \(Bx=y\) of an m-th order autoregressive covariance operator is equivalent to \(m\) Backward Euler steps of a diffusion equation with initial condition \(x\), where the diffusion coefficient depends on the correlation lengthscale. This makes this type of covariance operator well suited to finite element models. If \(m\) is even then an efficient square root \(B^{1/2}\) can be calculated by taking just \(m/2\) Backward Euler steps.
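The idea can be sketched with a finite-difference Laplacian on a periodic grid instead of a finite element discretisation: each Backward Euler step solves \((I-\kappa\Delta)u^{\textrm{new}}=u^{\textrm{old}}\), the action of \(B\) is \(m\) such steps scaled by \(\sigma^{2}\), and the square root is \(m/2\) steps scaled by \(\sigma\). The function names, the unit step size, and treating \(\kappa\) as a free parameter are assumptions for illustration; AutoregressiveCovariance sets its own coefficients and normalisation from the lengthscale and may differ.

```python
import numpy as np

def periodic_laplacian(n):
    """Second-difference matrix on a uniform periodic grid of [0, 1)."""
    h = 1.0 / n
    lap = -2.0 * np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1)
    lap[0, -1] = lap[-1, 0] = 1.0
    return lap / h**2

def diffusion_steps(x, kappa, nsteps):
    """Apply nsteps Backward Euler steps (unit step) of du/dt = kappa * u_zz."""
    K = np.eye(len(x)) - kappa * periodic_laplacian(len(x))
    for _ in range(nsteps):
        x = np.linalg.solve(K, x)
    return x

def cov_action(x, sigma, kappa, m):
    """Hypothetical B x: m diffusion steps scaled by sigma^2."""
    return sigma**2 * diffusion_steps(x, kappa, m)

def cov_sqrt_action(x, sigma, kappa, m):
    """Hypothetical B^{1/2} x for even m: m/2 diffusion steps scaled by sigma."""
    return sigma * diffusion_steps(x, kappa, m // 2)

n, sigma, kappa, m = 50, 0.1, 1e-3, 2
I = np.eye(n)
Bmat = np.column_stack([cov_action(e, sigma, kappa, m) for e in I])
Gmat = np.column_stack([cov_sqrt_action(e, sigma, kappa, m) for e in I])
print(np.allclose(Gmat @ Gmat.T, Bmat))  # True: Gmat is a square root of Bmat
```

Because the diffusion operator is symmetric, \(m/2\) steps composed with their transpose reproduce \(m\) steps, which is why an even \(m\) gives a cheap square root.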

We create the background and model error covariance operators with specified lengthscales \(L\) and standard deviations \(\sigma\). The variance of the model error is made proportional to the length of the observation stage \(T_{\textrm{stage}}\).

sigma_b = sqrt(1e-2)
B = AutoregressiveCovariance(V, L=0.2, sigma=sigma_b, m=2, seed=2)

sigma_q = sqrt(1e-3*Tstage)
Q = AutoregressiveCovariance(V, L=0.05, sigma=sigma_q, m=2, seed=17)

The observations are treated as uncorrelated, i.e. a diagonal covariance operator, which is created by setting \(m=0\).

sigma_r = sqrt(1e-3)
R = AutoregressiveCovariance(U, L=0, sigma=sigma_r, m=0, seed=18)

Firedrake provides an abstract base class CovarianceOperatorBase for implementing new covariance operators.

Generating observational data

We can use a known reference initial condition \(\hat{x}\) to generate synthetic “ground-truth” observations \(y_{i}\). We do this by adding noise consistently with the original definition of the problem, i.e. we add noise \(b_{j}\sim\mathcal{N}(0,B)\) at the initial condition, then at each observation time we add noise \(q_{j}\sim\mathcal{N}(0,Q)\) to the solution and add noise \(r_{j}\sim\mathcal{N}(0,R)\) to the observations. This process is detailed below:

  1. \(x_{b} \leftarrow \hat{x} + b_{b}\)

  2. \(x^{t}_{0} \leftarrow \hat{x} + b_{0}\)

  3. \(y_{0} \leftarrow \mathcal{H}(x^{t}_{0}) + r_{0}\)

  4. for \(j=1\) to \(j=N\) do

    1. \(x^{t}_{j} \leftarrow \mathcal{M}(x^{t}_{j-1}) + q_{j}\)

    2. \(y_{j} \leftarrow \mathcal{H}(x^{t}_{j}) + r_{j}\)

  5. end for

Note that we generate both the background \(x_{b}\) and the “truth” initial condition \(x^{t}_{0}\) by perturbing \(\hat{x}\), which means that both states will contain noise (rather than one or the other being completely deterministic).
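The same procedure, written serially with NumPy for a toy linear model, looks like this. The matrices `Mmat` and `Hmat`, the covariances, and the use of Cholesky factors as square roots are all assumptions for illustration; the numbered comments follow the steps listed above.

```python
import numpy as np

rng = np.random.default_rng(6)
n, m, N = 4, 2, 5
Mmat = 0.95 * np.eye(n) + 0.05 * np.eye(n, k=1)   # toy propagator M
Hmat = np.eye(n)[::2]                              # observe every other point: H
B, Q, R = 0.1 * np.eye(n), 0.01 * np.eye(n), 0.001 * np.eye(m)

def sample(cov):
    """Draw one sample from N(0, cov) using a Cholesky square root."""
    return np.linalg.cholesky(cov) @ rng.standard_normal(cov.shape[0])

x_hat = np.sin(2 * np.pi * np.arange(n) / n)       # reference initial condition

xb = x_hat + sample(B)                             # 1. background
xt = [x_hat + sample(B)]                           # 2. "truth" initial condition
y = [Hmat @ xt[0] + sample(R)]                     # 3. first observation
for j in range(1, N + 1):                          # 4. loop over observation stages
    xt.append(Mmat @ xt[j - 1] + sample(Q))
    y.append(Hmat @ xt[j] + sample(R))

print(len(xt), len(y), y[0].shape)  # 6 6 (2,)
```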

The code below uses this process to generate synthetic observation data. Because our timeseries is distributed over the Ensemble, each observation \(y_{j}\) needs to live on the right ensemble member. To do this we use the Ensemble.sequential context manager, which runs the code within the context on each ensemble member in turn. Any kwarg passed to Ensemble.sequential is made available in the ctx object, and is sent forward to the next ensemble member once the local code block is complete. After running the local part of the timeseries on each ensemble member, this allows us to pass forward the state xt and the time t to the next member.

xb = Function(V).assign(reference_ic + B.sample())
xt = Function(V).assign(reference_ic + B.sample())

# send ground-truth initial condition to all ranks.
truth_ic = ensemble.bcast(xt, root=0).copy(deepcopy=True)

if ensemble_rank == 0:
    y = [Function(U).assign(H(xt) + R.sample())]
else:
    y = []

t.zero()
with ensemble.sequential(state=xt, t=t) as ctx:
    t.assign(ctx.t)
    xt.assign(ctx.state)

    for _ in range(nlocal_stages):
        xt.assign(M(xt) + Q.sample())
        y.append(Function(U).assign(H(xt) + R.sample()))

    ctx.state.assign(xt)

# send ground-truth end condition to all ranks.
truth_end = ensemble.bcast(xt.copy(deepcopy=True), root=ensemble_size-1)
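The data flow that Ensemble.sequential enables in the code above can be sketched without MPI: each member receives the state from its predecessor, advances its own stages, and hands the result on. The plain-Python sketch below emulates this with a loop over hypothetical members and a scalar toy model, and checks that the pieces match a single serial run; it is only an analogy for the data flow, not how Ensemble communicates.

```python
def advance(x):
    """Stand-in for one observation stage of the model, x -> M(x)."""
    return 0.9 * x + 1.0

def run_member(x_in, nlocal_stages):
    """Work done by one ensemble member: advance its local stages in order."""
    local = []
    x = x_in
    for _ in range(nlocal_stages):
        x = advance(x)
        local.append(x)
    return local, x          # local part of the timeseries, and the state to send on

N, nmembers = 8, 4
state = 0.0                  # initial condition, owned by member 0
pieces = []
for member in range(nmembers):           # members run one after another
    local, state = run_member(state, N // nmembers)
    pieces.append(local)

serial, x = [], 0.0
for _ in range(N):
    x = advance(x)
    serial.append(x)

print([v for p in pieces for v in p] == serial)  # True
```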

Now that we have the “ground-truth” observations, we can create a function that returns, for each observation index i, a callback computing the misfit \(\mathcal{H}(x)-y_{i}\) between a state and that observation.

def observation_error(i):
    return lambda x: Function(U).assign(H(x) - y[i])
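Wrapping the lambda in a function like observation_error is not just cosmetic: Python closures look up variables when they are called, not when they are created, so building the callbacks directly in a loop would make every one of them use the last index. The plain-Python sketch below (toy scalar observations) shows the difference.

```python
y = [10.0, 20.0, 30.0]       # toy observations

# Late binding: every lambda sees the final value of i.
late = [lambda x: x - y[i] for i in range(3)]

# A factory function binds i at creation time, like observation_error(i).
def make_error(i):
    return lambda x: x - y[i]

bound = [make_error(i) for i in range(3)]

print([f(0.0) for f in late])   # [-30.0, -30.0, -30.0]
print([f(0.0) for f in bound])  # [-10.0, -20.0, -30.0]
```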

Building the ReducedFunctional

Now we have all the pieces ready to start assembling the 4DVar system. continue_annotation tells Pyadjoint to start recording any code that is executed from now on. The WC4DVarReducedFunctional class will manage recording, constructing, and solving the 4DVar system. To initialise it, it needs an EnsembleFunction as a Control, and the components to evaluate the functional at the initial condition, i.e. the background state and covariance for \(\|x_{0}-x_{b}\|_{B^{-1}}^{2}\), and the observation error and covariance for \(\|\mathcal{H}_{0}(x_{0})-y_{0}\|_{R_{0}^{-1}}^{2}\).

from fdvar import WC4DVarReducedFunctional

continue_annotation()

control = EnsembleFunction(W)

Jhat = WC4DVarReducedFunctional(
    Control(control),
    background=xb,
    background_covariance=B,
    observation_covariance=R,
    observation_error=observation_error(0),
    gauss_newton=True)

All Firedrake operations are “taped” by pyadjoint, so all we need to do to initialise the stages is to run \(\mathcal{M}\) and \(\mathcal{H}\) within the recording_stages context manager below. For each stage, we integrate forward from stage.control (i.e. \(x_{j-1}\)), and then set the observation by providing the state (i.e. \(x_{j}=\mathcal{M}_{j}(x_{j-1})\)), the observation error operator, and the covariances.

t.zero()
with Jhat.recording_stages(t=t) as stages:
    for stage, ctx in stages:
        t.assign(ctx.t)
        xn1 = M(stage.control)

        obs_error = observation_error(stage.observation_index)

        stage.set_observation(
            state=xn1,
            observation_error=obs_error,
            observation_covariance=R,
            forward_model_covariance=Q)

pause_annotation()

To ensure that the initial guess for \(x_{j}\) is a continuous trajectory over the entire Ensemble, the recording_stages context manager wraps ensemble.sequential. The control for the first stage is set to \(x_{b}\), and the control for subsequent stages is set to the value of the state passed to set_observation by the previous stage.

Jhat now has a record of all operations in the model, and can use this to a) re-evaluate \(\hat{J}(x)\) with different control values, b) calculate the derivative with respect to the controls, and c) apply the action of the Hessian.

We save a copy of the initial control to compare the optimised state to.

prior = control.copy()

Solving the 4DVar system

TAO is PETSc’s optimisation library and provides a range of optimisation methods. Pyadjoint provides a TAOSolver which creates all the necessary callbacks for TAO from a ReducedFunctional.

Configuring the WC4DVar solver

Just like the timestepper, the TAO solver is configured using a set of options strings. We will configure the solver to use the saddle point formulation \(\mathbf{A}\) preconditioned by the upper triangular Schur factorisation \(\mathbf{P}_{U}\) with the approximate Schur complement \(\mathbf{\tilde{S}}=\mathbf{\tilde{L}}^{T}\mathbf{D}^{-1}\mathbf{\tilde{L}}\) where \(\mathbf{\tilde{L}}\) is constructed using \(\tilde{M}=I\).

To make this a bit simpler, we will define a couple of options sets for components of the full solver. The covariance_parameters below can be used to solve the matrices \(\mathbf{D}\) and \(\mathbf{R}\) in \(\mathbf{P}_{U}\), and \(\mathbf{D}^{-1}\) in \(\mathbf{\tilde{S}}^{-1}\). These matrices are block diagonal with one block per component of an EnsembleFunctionSpace, so we can use the EnsembleBJacobiPC. Just like PETSc’s PCBJacobi this creates a sub KSP for each block (i.e. for each covariance operator \(B\), \(Q\), or \(R\)). On each block we use the CovariancePC, which will automatically apply the inverse or the action depending on whether it acts on e.g. \(\mathbf{D}\) or \(\mathbf{D}^{-1}\).

covariance_parameters = {
    'pc_type': 'python',
    'pc_python_type': 'firedrake.EnsembleBJacobiPC',
    'sub_pc_type': 'python',
    'sub_pc_python_type': 'firedrake.CovariancePC',
}

The schur_parameters specify the approximate Schur complement \(\mathbf{\tilde{S}}\), which is implemented with the WC4DVarSchurPC. This preconditioner requires options to solve \(\mathbf{D}^{-1}\), given in wcschur_d, and to solve \(\mathbf{\tilde{L}}\), given in wcschur_l. For \(\mathbf{D}^{-1}\) we can use the covariance_parameters. For \(\mathbf{\tilde{L}}\) we use the AllAtOnceRFGaussSeidelPC, which uses forward substitution to solve \(\mathbf{\tilde{L}}\). This preconditioner has one option, pc_aaogs_type, which can be a) 'model', i.e. \(\tilde{M}=M\) and \(\mathbf{\tilde{L}}=\mathbf{L}\), or b) 'identity', i.e. \(\tilde{M}=I\).

schur_parameters = {
    'ksp_type': 'preonly',
    'pc_type': 'python',
    'pc_python_type': 'fdvar.WC4DVarSchurPC',
    'wcschur_l': {
        'pc_type': 'python',
        'pc_python_type': 'fdvar.AllAtOnceRFGaussSeidelPC',
        'pc_aaogs_type': 'identity',
    },
    'wcschur_d': covariance_parameters,
}
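Forward substitution with \(\mathbf{\tilde{L}}\), as used by the AllAtOnceRFGaussSeidelPC configured above, is a sequential sweep over timesteps: \(x_{0}=r_{0}\), then \(x_{j}=r_{j}+\tilde{M}_{j}x_{j-1}\). The NumPy sketch below implements both choices of pc_aaogs_type, 'model' (\(\tilde{M}=M\)) and 'identity' (\(\tilde{M}=I\)), on toy random blocks; the function name and arguments are hypothetical and do not mirror the preconditioner's internals.

```python
import numpy as np

def forward_substitution(r_blocks, Ms, kind="model"):
    """Solve L~ x = r for block lower bidiagonal L~ with -M~_j below the diagonal.

    r_blocks: list [r_0, ..., r_N]; Ms: list [None, M_1, ..., M_N].
    kind="model" uses M~_j = M_j, kind="identity" uses M~_j = I.
    """
    x = [r_blocks[0]]
    for j in range(1, len(r_blocks)):
        prev = Ms[j] @ x[j - 1] if kind == "model" else x[j - 1]
        x.append(r_blocks[j] + prev)
    return x

rng = np.random.default_rng(7)
n, N = 3, 4
Ms = [None] + [rng.standard_normal((n, n)) for _ in range(N)]
r = [rng.standard_normal(n) for _ in range(N + 1)]

# Reference: assemble L and solve directly.
L = np.eye(n * (N + 1))
for j in range(1, N + 1):
    L[j*n:(j+1)*n, (j-1)*n:j*n] = -Ms[j]
x_ref = np.linalg.solve(L, np.concatenate(r))

x = forward_substitution(r, Ms, kind="model")
print(np.allclose(np.concatenate(x), x_ref))  # True
```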

Now we set up the full solver options in the 'tao_parameters' below.

  • At the top level of the dictionary, 'tao_gttol': 5e-2 sets the convergence tolerance on the relative reduction of the gradient norm. Next, we specify a Newton method using 'tao_type': 'nls', which needs options for the linear solver in the 'tao_nls' dictionary.

  • At each Newton iteration, we use 'ksp_type': 'preonly' to replace the linear solve with the WC4DVarSaddlePC preconditioner, which solves the saddle point system \(\mathbf{A}\) and returns the \(\mathbf{\delta x}\) part of the solution.

  • To use a Schur complement factorisation we have to tell PETSc’s PCFieldSplit how to reinterpret the \(3\times3\) block matrix \(\mathbf{A}\) as a \(2\times2\) block matrix \(\mathbf{\hat{A}}\). This is done using the 'pc_fieldsplit_{0,1}_fields' options, where the Schur complement is formed on the '1_fields' after eliminating the '0_fields'.

\[\begin{split}\mathbf{\hat{A}} = \begin{pmatrix} \mathbf{\hat{A}}_{00} & \mathbf{\hat{A}}_{01} \\ \mathbf{\hat{A}}_{10} & \mathbf{0} \\ \end{pmatrix}, \quad \mathbf{\hat{A}}_{00} = \begin{pmatrix} \mathbf{D} & \mathbf{0} \\ \mathbf{0} & \mathbf{R} \\ \end{pmatrix}, \quad \mathbf{\hat{A}}_{01} = \begin{pmatrix} \mathbf{L} \\ \mathbf{H} \\ \end{pmatrix}, \quad \mathbf{\hat{A}}_{10} = \begin{pmatrix} \mathbf{L}^{T} & \mathbf{H}^{T} \\ \end{pmatrix}.\end{split}\]
  • The upper triangular \(\mathbf{P}_{U}\) preconditioner is specified using the 'pc_fieldsplit_schur_fact_type' option.

\[\begin{split}\mathbf{P}_{U} = \begin{pmatrix} \mathbf{\hat{A}}_{00} & \mathbf{\hat{A}}_{01} \\ \mathbf{0} & \mathbf{\tilde{S}} \\ \end{pmatrix}\end{split}\]
  • The solver options for \(\mathbf{\hat{A}}_{00}\) are in the 'fieldsplit_0' dictionary. This block-diagonal matrix is solved by splitting it apart using an 'additive' fieldsplit, then solving \(\mathbf{D}\) and \(\mathbf{R}\) separately using the 'covariance_parameters'.

  • The Schur complement is specified in the 'fieldsplit_1' dictionary using the 'schur_parameters' above.

tao_parameters = {
    'tao_monitor': None,  # Print out diagnostics.
    'tao_converged_reason': None,
    'tao_gttol': 5e-2,  # Gradient reduction.
    'tao_max_it': 30,
    'tao_type': 'nls',  # Newton iterations
    'tao_ls_type': 'unit',  # without linesearch.
    'tao_nls': {
        'ksp_type': 'preonly',  # Replace the hessian solve with the PC.
        'ksp_monitor_short': None,
        'pc_type': 'python',
        'pc_python_type': 'fdvar.WC4DVarSaddlePC',
        'wcsaddle': {
            'ksp_monitor_short': None,
            'ksp_converged_rate': None,  # Print contraction rate.
            'ksp_converged_maxits': None,
            'ksp_rtol': 1e-2,
            'ksp_min_it': 10,
            'ksp_max_it': 100,
            'ksp_gmres_restart': 100,
            'ksp_type': 'gmres',
            'pc_type': 'fieldsplit',
            'pc_fieldsplit_type': 'schur',  # Use a schur (LDU) factorisation,
            'pc_fieldsplit_0_fields': '0,1',  # eliminating the first two fields,
            'pc_fieldsplit_1_fields': '2',  # forming the schur complement on the third,
            'pc_fieldsplit_schur_fact_type': 'upper',  # and using just the DU part of the LDU.
            'fieldsplit_0': {
                'ksp_type': 'preonly',
                'pc_type': 'fieldsplit',  # Solve the covariance
                'pc_fieldsplit_type': 'additive',  # matrices separately.
                'fieldsplit_ksp_type': 'preonly',
                'fieldsplit': covariance_parameters,
            },
            'fieldsplit_1': schur_parameters,
        },
    }
}

if os.getenv("FIREDRAKE_CI") == "1":
    tao_parameters["tao_gttol"] = 0.9
    tao_parameters["tao_nls"]["wcsaddle"]["ksp_rtol"] = 1e-1
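With 'pc_fieldsplit_schur_fact_type': 'upper' as set above, applying \(\mathbf{P}_{U}^{-1}\) amounts to block back substitution: first solve with the Schur complement approximation for the \(\delta\mathbf{x}\) block, then solve with \(\mathbf{\hat{A}}_{00}\) using the updated right-hand side. The NumPy sketch below checks this against an explicit solve with \(\mathbf{P}_{U}\); the matrices are dense toy stand-ins, `apply_PU_inverse` is a hypothetical name, and the approximate Schur complement is just a random SPD matrix.

```python
import numpy as np

def apply_PU_inverse(A00, A01, S_approx, r0, r1):
    """Apply [[A00, A01], [0, S_approx]]^{-1} to (r0, r1) by block back substitution."""
    y1 = np.linalg.solve(S_approx, r1)          # Schur complement block (delta x) first
    y0 = np.linalg.solve(A00, r0 - A01 @ y1)    # then the (eta, lambda) block
    return y0, y1

rng = np.random.default_rng(8)
k, nx = 6, 3
G = rng.standard_normal((k, k))
A00 = G @ G.T + k * np.eye(k)                   # stand-in for diag(D, R)
A01 = rng.standard_normal((k, nx))              # stand-in for (L; H)
F = rng.standard_normal((nx, nx))
S_approx = F @ F.T + nx * np.eye(nx)            # stand-in for S~
r0, r1 = rng.standard_normal(k), rng.standard_normal(nx)

y0, y1 = apply_PU_inverse(A00, A01, S_approx, r0, r1)

P_U = np.block([[A00, A01], [np.zeros((nx, k)), S_approx]])
y_ref = np.linalg.solve(P_U, np.concatenate([r0, r1]))
print(np.allclose(np.concatenate([y0, y1]), y_ref))  # True
```

In the configured solver the \(\mathbf{\hat{A}}_{00}\) solve is itself split additively into separate \(\mathbf{D}\) and \(\mathbf{R}\) solves, which is possible because \(\mathbf{\hat{A}}_{00}\) is block diagonal.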

Computing the optimised state

Now that we have a reduced functional and a set of TAO parameters, we can solve the optimisation problem using Pyadjoint’s TAOSolver.

from pyadjoint import TAOSolver

tao = TAOSolver(MinimizationProblem(Jhat),
                parameters=tao_parameters,
                options_prefix="")
xopts = tao.solve()

Lastly, we compare the error between the optimised solution and ground truth data with the error between the initial guess and the ground truth data, at both the initial and final times.

prior_ic = ensemble.bcast(prior.subfunctions[0], root=0)
xopts_ic = ensemble.bcast(xopts.subfunctions[0], root=0)

prior_end = ensemble.bcast(prior.subfunctions[-1], root=ensemble_size-1)
xopts_end = ensemble.bcast(xopts.subfunctions[-1], root=ensemble_size-1)

PETSc.Sys.Print()

PETSc.Sys.Print("Errors at initial timestep:")
prior_error = errornorm(truth_ic, prior_ic)/norm(truth_ic)
xopts_error = errornorm(truth_ic, xopts_ic)/norm(truth_ic)
PETSc.Sys.Print(f"{prior_error = :.3e}")
PETSc.Sys.Print(f"{xopts_error = :.3e}")
PETSc.Sys.Print(f"Error reduction factor = {xopts_error/prior_error:.3e}")
PETSc.Sys.Print()

PETSc.Sys.Print("Errors at final timestep:")
prior_error = errornorm(truth_end, prior_end)/norm(truth_end)
xopts_error = errornorm(truth_end, xopts_end)/norm(truth_end)
PETSc.Sys.Print(f"{prior_error = :.3e}")
PETSc.Sys.Print(f"{xopts_error = :.3e}")
PETSc.Sys.Print(f"Error reduction factor = {xopts_error/prior_error:.3e}")
PETSc.Sys.Print()

Representative output of these print statements is shown below (exact values may change slightly due to the random number generator). At the initial and final conditions the optimised solution matches the ground truth roughly 14 times and 20 times more accurately than the prior solution, respectively.

Errors at initial timestep:
prior_error = 6.723e-01
xopts_error = 4.925e-02
Error reduction factor = 7.326e-02

Errors at final timestep:
prior_error = 8.843e-01
xopts_error = 4.333e-02
Error reduction factor = 4.900e-02
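The “times more accurate” figures are the reciprocals of the error reduction factors. Using the representative numbers from the output above:

```python
prior_ic, xopts_ic = 6.723e-01, 4.925e-02     # representative values, initial timestep
prior_end, xopts_end = 8.843e-01, 4.333e-02   # representative values, final timestep

print(f"{prior_ic / xopts_ic:.1f}")    # 13.7
print(f"{prior_end / xopts_end:.1f}")  # 20.4
```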

Running the demo yourself

A runnable python script of this demo can be found here.

The python script can be run in parallel as long as the number of observation stages \(N\) is divisible by the number of MPI ranks nranks:

mpiexec -n <nranks> python wc4dvar_advection.py

Additional options can be passed to the TAOSolver from the command line, for example changing the tolerance of the saddle point solve:

python wc4dvar_advection.py -tao_nls_wcsaddle_ksp_rtol 1e-3