EKI and MF-EKI Demo

This demo showcases single-fidelity ensemble Kalman inversion (EKI) and multi-fidelity EKI (MF-EKI) on a convection-diffusion-reaction (CDR) model with two inferred parameters. It mirrors the UQ-style workflow used elsewhere in the demos and keeps the physics lightweight enough to run quickly.

The forward model solves a steady 2D CDR equation on a structured grid, and the quantity of interest (QoI) is the right-boundary flux functional used in the existing UQ examples. EKI estimates the diffusion coefficient nu and the reaction rate sigma from a synthetic observation. MF-EKI augments each iteration with reduced-order model (ROM) evaluations built from full-order model (FOM) snapshots, illustrating how multi-fidelity updates can reduce error at lower cost.

Run the demo

From the repository root, run:

python docs/source/demos/notebooks/eki_mf_eki_demo.py

Results

EKI and MF-EKI error convergence

Mean observation error across iterations for EKI and MF-EKI.

Implementation

  1import os
  2import shutil
  3import sys
  4import numpy as np
  5import matplotlib.pyplot as plt
  6
  7from romtools.workflows.parameters import UniformParameter
  8from romtools.workflows.parameter_spaces import HeterogeneousParameterSpace
  9from romtools.workflows.models import QoiModel
 10from romtools.workflows.model_builders import QoiModelBuilder
 11from romtools.workflows.inverse.run_eki import run_eki
 12from romtools.workflows.inverse.run_mf_eki import run_mf_eki
 13
 14CDR_PATH = os.path.abspath(
 15    os.path.join(
 16        os.path.dirname(__file__),
 17        "convection_diffusion_reaction_system_code",
 18    )
 19)
 20if CDR_PATH not in sys.path:
 21    sys.path.insert(0, CDR_PATH)
 22
 23import cdr  # noqa: E402
 24import cdr_rom  # noqa: E402
 25
 26
 27class CdrFomQoiModel:
 28    def __init__(self, system: cdr.AdvectionDiffusionSystem, b_vec: np.ndarray):
 29        self._system = system
 30        self._b_vec = b_vec
 31
 32    def populate_run_directory(self, run_directory: str, parameter_sample: dict) -> None:
 33        with open(os.path.join(run_directory, "params.txt"), "w", encoding="utf-8") as handle:
 34            for key, value in parameter_sample.items():
 35                handle.write(f"{key}: {value}\n")
 36
 37    def run_model(self, run_directory: str, parameter_sample: dict) -> int:
 38        nu = float(parameter_sample["nu"])
 39        sigma = float(parameter_sample["sigma"])
 40        u = cdr.solveFom(self._system, self._b_vec, nu, sigma)
 41        np.savez(os.path.join(run_directory, "solution.npz"), u=u)
 42        return 0
 43
 44    def compute_qoi(self, run_directory: str, parameter_sample: dict) -> np.ndarray:
 45        data = np.load(os.path.join(run_directory, "solution.npz"))
 46        u = data["u"]
 47        qoi = np.dot(self._system.C, u)
 48        return np.array([qoi])
 49
 50
 51class CdrRomQoiModel:
 52    def __init__(self, system: cdr.AdvectionDiffusionSystem, basis: np.ndarray, b_vec: np.ndarray):
 53        self._system = system
 54        self._basis = basis
 55        self._b_vec = b_vec
 56        self._rom = cdr_rom.primalGalerkinROM(system, basis)
 57
 58    def populate_run_directory(self, run_directory: str, parameter_sample: dict) -> None:
 59        with open(os.path.join(run_directory, "params.txt"), "w", encoding="utf-8") as handle:
 60            for key, value in parameter_sample.items():
 61                handle.write(f"{key}: {value}\n")
 62
 63    def run_model(self, run_directory: str, parameter_sample: dict) -> int:
 64        nu = float(parameter_sample["nu"])
 65        sigma = float(parameter_sample["sigma"])
 66        u_hat = cdr_rom.solveRom(self._rom, self._b_vec, nu, sigma)
 67        u = self._basis @ u_hat
 68        np.savez(os.path.join(run_directory, "solution.npz"), u=u, u_hat=u_hat)
 69        return 0
 70
 71    def compute_qoi(self, run_directory: str, parameter_sample: dict) -> np.ndarray:
 72        data = np.load(os.path.join(run_directory, "solution.npz"))
 73        u = data["u"]
 74        qoi = np.dot(self._system.C, u)
 75        return np.array([qoi])
 76
 77
 78class CdrRomBuilder:
 79    def __init__(self, system: cdr.AdvectionDiffusionSystem, b_vec: np.ndarray, rom_dim: int = 12):
 80        self._system = system
 81        self._b_vec = b_vec
 82        self._rom_dim = rom_dim
 83
 84    def build_from_training_dirs(self, offline_data_dir: str, training_data_dirs):
 85        snapshots = []
 86        for run_dir in training_data_dirs:
 87            solution_path = os.path.join(run_dir, "solution.npz")
 88            if os.path.exists(solution_path):
 89                data = np.load(solution_path)
 90                snapshots.append(data["u"])
 91        if not snapshots:
 92            raise RuntimeError("No training snapshots found for ROM construction.")
 93        snapshot_matrix = np.column_stack(snapshots)
 94        u, _, _ = np.linalg.svd(snapshot_matrix, full_matrices=False)
 95        basis = u[:, : min(self._rom_dim, u.shape[1])]
 96        return CdrRomQoiModel(self._system, basis, self._b_vec)
 97
 98
 99def _collect_error_history(work_dir: str, mf: bool) -> list:
100    history = []
101    iteration = 0
102    while True:
103        restart_path = os.path.join(work_dir, f"iteration_{iteration}", "restart.npz")
104        if not os.path.exists(restart_path):
105            break
106        data = np.load(restart_path, allow_pickle=True)
107        if mf:
108            sample_one_fom_results = data["sample_one_fom_results"].item()
109            errors = sample_one_fom_results["errors"]
110        else:
111            errors = data["errors"]
112        history.append(float(np.mean(np.linalg.norm(errors, axis=0))))
113        iteration += 1
114    return history
115
116
def main():
    """Run EKI and MF-EKI on the CDR problem and plot their error histories.

    Steps:
        1. Build a synthetic observation from the FOM at known parameters.
        2. Run single-fidelity EKI and MF-EKI from the same prior box.
        3. Plot the per-iteration mean observation error of both methods.

    Files are written next to this script. The work directory
    ``eki_mf_eki_work/`` is deleted and recreated on every run, and the figure
    is saved as ``eki_mf_eki_demo.png``.
    """
    np.random.seed(1)

    system = cdr.AdvectionDiffusionSystem(Nx=25, Ny=25)
    b_vec = np.array([1.0, 1.0])

    # Synthetic "truth": the observation is the FOM QoI at known parameters.
    nu_true = 0.04
    sigma_true = 0.3
    u_true = cdr.solveFom(system, b_vec, nu_true, sigma_true)
    observations = np.array([np.dot(system.C, u_true)])
    observations_covariance = np.eye(1) * 1e-5

    parameter_space = HeterogeneousParameterSpace(
        [
            UniformParameter("nu", 0.01, 0.08),
            UniformParameter("sigma", 0.1, 0.6),
        ]
    )

    # Resolve outputs relative to this script (consistent with CDR_PATH), not the
    # current working directory. Running from anywhere other than the repo root
    # then cannot create -- or rmtree -- an unrelated docs/source/... tree.
    # When run from the repo root as documented, the paths are unchanged.
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.join(demo_dir, "eki_mf_eki_work")
    eki_dir = os.path.join(base_dir, "eki")
    mf_dir = os.path.join(base_dir, "mf_eki")
    shutil.rmtree(base_dir, ignore_errors=True)
    os.makedirs(base_dir, exist_ok=True)

    fom_model = CdrFomQoiModel(system, b_vec)
    rom_builder = CdrRomBuilder(system, b_vec, rom_dim=12)

    run_eki(
        model=fom_model,
        parameter_space=parameter_space,
        observations=observations,
        observations_covariance=observations_covariance,
        absolute_eki_directory=eki_dir,
        ensemble_size=18,
        max_iterations=8,
        evaluation_concurrency=1,
    )

    run_mf_eki(
        model=fom_model,
        rom_model_builder=rom_builder,
        parameter_space=parameter_space,
        observations=observations,
        observations_covariance=observations_covariance,
        absolute_eki_directory=mf_dir,
        fom_ensemble_size=8,
        rom_extra_ensemble_size=12,
        rom_tolerance=0.1,
        max_iterations=8,
        fom_evaluation_concurrency=1,
        rom_evaluation_concurrency=1,
    )

    eki_history = _collect_error_history(eki_dir, mf=False)
    mf_history = _collect_error_history(mf_dir, mf=True)

    fig, ax = plt.subplots(figsize=(6.5, 4.0))
    ax.plot(eki_history, marker="o", label="EKI (FOM)")
    ax.plot(mf_history, marker="s", label="MF-EKI (FOM+ROM)")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Mean observation error")
    ax.set_title("EKI vs MF-EKI on a convection-diffusion-reaction model")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()

    output_path = os.path.join(demo_dir, "eki_mf_eki_demo.png")
    fig.savefig(output_path, dpi=180)
    # Release the figure explicitly; pyplot otherwise keeps it alive.
    plt.close(fig)
    print(f"Wrote {output_path}")


if __name__ == "__main__":
    main()
173
174    plt.figure(figsize=(6.5, 4.0))
175    plt.plot(eki_history, marker="o", label="EKI (FOM)")
176    plt.plot(mf_history, marker="s", label="MF-EKI (FOM+ROM)")
177    plt.xlabel("Iteration")
178    plt.ylabel("Mean observation error")
179    plt.title("EKI vs MF-EKI on a convection-diffusion-reaction model")
180    plt.grid(True, alpha=0.3)
181    plt.legend()
182
183    output_path = os.path.abspath("docs/source/demos/notebooks/eki_mf_eki_demo.png")
184    plt.tight_layout()
185    plt.savefig(output_path, dpi=180)
186    print(f"Wrote {output_path}")
187
188
189if __name__ == "__main__":
190    main()