Objectives

  • Understand when to use Approximate Bayesian Computation (ABC) and when not to.
  • Understand the fundamental difference between ABC and likelihood-based inference.
  • Understand why and how to use summary statistics in ABC.

In this exercise, we run the Simulated Annealing ABC (SABC) algorithm to infer the parameter of the model “survival” in a Bayesian way, using a flat prior. Note that in practice we would not want to use an ABC algorithm for this problem. Why?

Define prior and a function to sample from the model

For ABC, we need to define the prior density and a function that generates random outputs from our model \(p(y|\theta)\). That is, we sample from the probabilistic model. Note, however, that we do not need a function that evaluates the associated density, i.e., the likelihood.
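To make this concrete, here is a minimal sketch of the simplest ABC variant, rejection sampling (not the SABC algorithm used in this exercise). It assumes exponentially distributed lifetimes and a flat prior on [0.01, 0.6]. The function names, observation times, tolerance, and synthetic data are illustrative assumptions, and the sketch uses only the Python standard library. Note that the likelihood is never evaluated.

```python
import random


def simulate_deaths(lam, times, n_individuals, rng):
    """Draw one data set: deaths per interval for exponential lifetimes."""
    counts = [0] * len(times)
    for _ in range(n_individuals):
        lifetime = rng.expovariate(lam)
        # Assign the death to the first observation time at or after it.
        # Individuals that outlive the last observation time are not counted.
        for i, t in enumerate(times):
            if lifetime <= t:
                counts[i] += 1
                break
    return counts


def rejection_abc(observed, times, n_individuals, n_draws, tolerance, rng):
    """Keep prior draws whose simulated data lie within `tolerance` of the data."""
    accepted = []
    for _ in range(n_draws):
        lam = rng.uniform(0.01, 0.6)  # sample from the flat prior
        simulated = simulate_deaths(lam, times, n_individuals, rng)
        distance = sum(abs(s - o) for s, o in zip(simulated, observed))
        if distance <= tolerance:
            accepted.append(lam)
    return accepted


rng = random.Random(1)
times = list(range(1, 11))                          # hypothetical observation times
observed = simulate_deaths(0.15, times, 100, rng)   # synthetic stand-in for real data
posterior_sample = rejection_abc(observed, times, 100, 2000, 30, rng)
print(len(posterior_sample), "accepted draws out of 2000")
```

The accepted values form an approximate posterior sample. A smaller tolerance gives a better approximation but fewer accepted draws, which is the trade-off that more elaborate algorithms such as SABC try to manage.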

Python

We use the Python package SimulatedAnnealingABC, which implements the SABC algorithm.

from pathlib import Path
import os
import matplotlib
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.special import gammaln

from simulated_annealing_abc import (
    DifferentialEvolution,
    SABCConfig,
    make_f_dist,
    sabc,
)

Next, we read the observed data, given as death counts per observation interval, which we try to explain with the model “survival”. The experiment was started with N=100 individuals.

obs = pd.read_csv("../../data/model_survival.csv", sep=r"\s+")
times = obs["time"].to_numpy(dtype=float)
deaths_obs = obs["deaths"].to_numpy(dtype=int)

For ABC, we always need a function that evaluates the prior density, a function that samples from the prior, and a function that generates random model outputs (our model). These functions are defined below:


SEED = 20260525
N_INDIVIDUALS = 100

class Prior:
    def __init__(self, lower: float = 0.01, upper: float = 0.6):
        self.lower = lower
        self.upper = upper

    def rvs(self, rng: np.random.Generator, size: int = 1) -> np.ndarray:
        return rng.uniform(self.lower, self.upper, size=(size, 1))

    def logpdf(self, theta: np.ndarray) -> np.ndarray:
        theta = np.atleast_2d(theta)
        lam = theta[:, 0]
        in_bounds = (self.lower <= lam) & (lam <= self.upper)
        lp = np.full(theta.shape[0], -np.inf)
        lp[in_bounds] = -np.log(self.upper - self.lower)
        return lp


prior = Prior()


def death_probabilities(lam: np.ndarray) -> np.ndarray:
    lam = np.asarray(lam, dtype=float).reshape(-1, 1)
    survival = np.exp(-lam * times.reshape(1, -1))
    survival = np.column_stack([np.ones(lam.shape[0]), survival, np.zeros(lam.shape[0])])
    return -np.diff(survival, axis=1)


def simulator(theta: np.ndarray, y: np.ndarray, rng: np.random.Generator) -> None:
    probs = death_probabilities(theta[:, 0])
    for i, p in enumerate(probs):
        y[i, :] = rng.multinomial(N_INDIVIDUALS, p)[:-1]

Julia

We use the Julia package SimulatedAnnealingABC.jl, which implements the SABC algorithm. We also need packages for distributions and data handling.

using Random
using SimulatedAnnealingABC
using Distributions
using CSV
using DataFrames
using AdaptiveMCMC
using MCMCChains

ENV["GKSwstype"] = "100"
## "100"
using Plots
using StatsPlots

Next, we load the observed death counts. The experiment started with N=100 individuals.

obs = CSV.read("../../data/model_survival.csv", DataFrame)
## 30×2 DataFrame
##  Row │ time   deaths
##      │ Int64  Int64
## ─────┼───────────────
##    1 │     1      14
##    2 │     2      10
##    3 │     3      11
##    4 │     4      12
##    5 │     5       3
##    6 │     6       6
##    7 │     7       4
##    8 │     8       7
##   ⋮  │   ⋮      ⋮
##   24 │    24       0
##   25 │    25       0
##   26 │    26       0
##   27 │    27       0
##   28 │    28       1
##   29 │    29       0
##   30 │    30       0
##       15 rows omitted

For ABC, we need a prior distribution, a function to sample from it, and a model function that generates random outputs.

prior = Uniform(0.01, 0.6)
## Uniform{Float64}(a=0.01, b=0.6)
d_prior(x) = pdf(prior, x)
## d_prior (generic function with 1 method)
r_prior() = rand(prior)
## r_prior (generic function with 1 method)

function survival_sampler(λ; N=100, t=obs.time)
    S = exp.(-λ .* t)
    S = [1; S; 0]
    p = -diff(S)
    rand(Multinomial(N, p))[1:(end-1)]
end
## survival_sampler (generic function with 1 method)
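Both samplers above implement the same probabilistic model. Written out, the survival function \(e^{-\lambda t}\) used in the code corresponds to exponentially distributed lifetimes with rate \(\lambda\). The symbols \(n\) (number of observation times), \(t_i\) (observation times), and \(y_i\) (death counts) are notation introduced here for illustration:

```latex
\[
p_i = e^{-\lambda t_{i-1}} - e^{-\lambda t_i}, \quad i = 1, \dots, n, \quad t_0 = 0,
\qquad
p_{n+1} = e^{-\lambda t_n},
\]
\[
(y_1, \dots, y_n, y_{n+1}) \sim \mathrm{Multinomial}\bigl(N;\, p_1, \dots, p_{n+1}\bigr).
\]
```

Here \(y_{n+1}\) is the number of individuals still alive at the last observation time. The samplers drop this last entry, so the simulated output has the same length as the observed death counts.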

1. ABC without summary statistics

For ABC we need a distance measure to compare the model output to the observations, in order to judge the plausibility of the parameters that generated this model output. Here we do not apply summary statistics; we compare the model output and the observations directly.
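As a small illustration of such a direct comparison, the sum of absolute differences between two count vectors can be sketched as follows. The function name `abs_distance` and the example counts are hypothetical and not part of the SABC packages.

```python
def abs_distance(simulated, observed):
    """Sum of absolute differences between two equally long count vectors."""
    if len(simulated) != len(observed):
        raise ValueError("simulated and observed must have the same length")
    return sum(abs(s - o) for s, o in zip(simulated, observed))


# Hypothetical death counts for three observation intervals.
print(abs_distance([14, 10, 11], [12, 10, 15]))  # prints 6
```

A distance of zero means the simulated data reproduce the observations exactly; larger values indicate parameters that are less plausible given the data.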

Python

We use the sum of the absolute differences as the distance function.

def raw_stats(y: np.ndarray, ss_out: np.ndarray) -> None:
    ss_out[:, :] = y

ss_obs_raw = deaths_obs.astype(float)

f_dist_raw = make_f_dist(
    n_samples=len(deaths_obs),
    ss_obs=ss_obs_raw,
    simulator=simulator,
    stats_fn=raw_stats,
    seed=123,
    distance="abs",
)

Now we can run SABC using f_dist_raw as distance function:

raw_config = SABCConfig(
    f_dist=f_dist_raw,
    prior=prior,
    n_particles=2000,
    v=1.2,
    algorithm="single_eps",
    proposal=DifferentialEvolution(n_para=1, rng=np.random.default_rng(22)),
    rng=np.random.default_rng(18),
    show_checkpoint=100,
)
raw_result = sabc(raw_config, n_simulation=1_000_000)
## Initialization done, starting population updates
## Update 100/499  avg_u=0.1278  eps=[0.0508]  ETA (DD:HH:MM)=00:00:00
## Update 200/499  avg_u=0.1191  eps=[0.0465]  ETA (DD:HH:MM)=00:00:00
## Update 300/499  avg_u=0.1159  eps=[0.0449]  ETA (DD:HH:MM)=00:00:00
## Update 400/499  avg_u=0.1139  eps=[0.0439]  ETA (DD:HH:MM)=00:00:00
## Update 499/499  avg_u=0.1123  eps=[0.0431]  ETA (DD:HH:MM)=00:00:00

Julia

We define a distance function based on the absolute differences between the model output and the observations. As written, it returns one absolute difference per observation interval rather than their sum.

f_dist(par, obs) = Float64.(abs.(survival_sampler(par) .- obs.deaths))
## f_dist (generic function with 1 method)

Now, we can run the SABC algorithm using this distance function.

sabc_res = sabc(f_dist, prior, obs;
                n_simulation = 1_000_000,
                n_particles = 2000)
## Approximate posterior sample with 2000 particles:
##   - algorithm: :single_eps
##   - simulations used: 1000000
##   - number of population updates: 499
##   - average transformed distance: 0.06317
##   - ϵ: [0.0229]
##   - number of population resamplings: 3
##   - acceptance rate: 0.01244
## The sample can be accessed with the field `population`.
## The history of ϵ can be accessed with the field `state.ϵ_history`.
## The history of ρ can be accessed with the field `state.ρ_history`.
## The history of u can be accessed with the field `state.u_history`.

2. Use sufficient summary statistics

This simple model has a set of two sufficient statistics (even though there is just one parameter!). You can try to derive them!

Python

Definition of the summary statistics and the distance function:

def sufficient_stats(y: np.ndarray, ss_out: np.ndarray) -> None:
    weights = np.arange(1, y.shape[1] + 1)
    ss_out[:, 0] = np.sum(y, axis=1)
    ss_out[:, 1] = y @ weights

ss_obs_suff = np.array([np.sum(deaths_obs), deaths_obs @ np.arange(1, len(deaths_obs) + 1)])

f_dist_suff = make_f_dist(
    n_samples=len(deaths_obs),
    ss_obs=ss_obs_suff,
    simulator=simulator,
    stats_fn=sufficient_stats,
    seed=456,
    distance="abs",
)
suff_config = SABCConfig(
    f_dist=f_dist_suff,
    prior=prior,
    n_particles=2000,
    v=1.2,
    algorithm="single_eps",
    proposal=DifferentialEvolution(n_para=1, rng=np.random.default_rng(23)),
    rng=np.random.default_rng(19),
    show_checkpoint=100,
)

suff_result = sabc(suff_config, n_simulation=1_000_000)
## Initialization done, starting population updates
## Update 100/499  avg_u=0.01652  eps=[0.0036]  ETA (DD:HH:MM)=00:00:00
## Update 200/499  avg_u=0.01131  eps=[0.0022]  ETA (DD:HH:MM)=00:00:00
## Update 300/499  avg_u=0.008787  eps=[0.0016]  ETA (DD:HH:MM)=00:00:00
## Update 400/499  avg_u=0.006785  eps=[0.0011]  ETA (DD:HH:MM)=00:00:00
## Update 499/499  avg_u=0.005877  eps=[0.0009]  ETA (DD:HH:MM)=00:00:00

Julia

Definition of the summary statistics and the distance function:

ss(y) = (sum(y), sum(y .* (1:length(y))))
## ss (generic function with 1 method)
f_dist_suff(par, obs) = Float64.(abs.(ss(survival_sampler(par)) .- ss(obs.deaths)))
## f_dist_suff (generic function with 1 method)

sabc_suff_res = sabc(f_dist_suff, prior, obs;
                     n_simulation = 1_000_000,
                     n_particles = 2000)
## Approximate posterior sample with 2000 particles:
##   - algorithm: :single_eps
##   - simulations used: 1000000
##   - number of population updates: 499
##   - average transformed distance: 0.001233
##   - ϵ: [0.0001312]
##   - number of population resamplings: 4
##   - acceptance rate: 0.01949
## The sample can be accessed with the field `population`.
## The history of ϵ can be accessed with the field `state.ϵ_history`.
## The history of ρ can be accessed with the field `state.ρ_history`.
## The history of u can be accessed with the field `state.u_history`.
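For readers who want to check their own derivation, here is a sketch of why the two statistics used above are sufficient. It assumes unit-spaced observation times \(t_i = i\) for \(i = 1, \dots, n\), as in the data set, and uses \(y_i\) for the death counts (notation introduced here). With exponential survival, the interval probabilities factorize:

```latex
\[
p_i = e^{-\lambda (i-1)} - e^{-\lambda i} = e^{-\lambda i}\,(e^{\lambda} - 1), \quad i = 1, \dots, n,
\qquad
p_{n+1} = e^{-\lambda n}.
\]
\[
\log L(\lambda)
= c(y) + \sum_{i=1}^{n} y_i \log p_i + \Bigl(N - \sum_{i=1}^{n} y_i\Bigr) \log p_{n+1}
\]
\[
= c(y) - \lambda \sum_{i=1}^{n} i\, y_i
+ \log\bigl(e^{\lambda} - 1\bigr) \sum_{i=1}^{n} y_i
- \lambda\, n \Bigl(N - \sum_{i=1}^{n} y_i\Bigr).
\]
```

The term \(c(y)\) collects the multinomial coefficient and does not depend on \(\lambda\). The remaining terms depend on the data only through \(\sum_i y_i\) and \(\sum_i i\, y_i\), which are exactly the two statistics computed in the code above.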

4. Compare the previously generated posterior samples to a sample from the true posterior

For this model we do not need to use ABC (and should not!) because we can evaluate the likelihood function and can therefore sample from the posterior with standard MCMC methods. This allows us to compare the two samples generated in the previous tasks to a sample of the “true posterior”.
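Because the model has a single parameter and the likelihood is available, the posterior can even be evaluated directly on a grid, without any sampling. The following is a minimal sketch under stated assumptions: the data are hypothetical, the function names are made up for illustration, and the prior is flat on [0.01, 0.6]. Only the Python standard library is used.

```python
import math


def log_likelihood(lam, times, deaths, n_individuals):
    """Multinomial log-likelihood, up to a constant that does not depend on lam."""
    survival = [1.0] + [math.exp(-lam * t) for t in times] + [0.0]
    probs = [survival[i] - survival[i + 1] for i in range(len(survival) - 1)]
    counts = list(deaths) + [n_individuals - sum(deaths)]  # append the survivors
    return sum(c * math.log(p) for c, p in zip(counts, probs) if c > 0)


def grid_posterior(times, deaths, n_individuals, lower=0.01, upper=0.6, n_grid=600):
    """Normalized posterior density on an equally spaced grid under a flat prior."""
    step = (upper - lower) / (n_grid - 1)
    grid = [lower + k * step for k in range(n_grid)]
    log_post = [log_likelihood(lam, times, deaths, n_individuals) for lam in grid]
    peak = max(log_post)  # subtract the maximum for numerical stability
    weights = [math.exp(lp - peak) for lp in log_post]
    total = sum(weights) * step
    return grid, [w / total for w in weights]


# Hypothetical data: 100 individuals observed at five time points.
times = [1, 2, 3, 4, 5]
deaths = [13, 11, 10, 9, 7]
grid, density = grid_posterior(times, deaths, 100)
mode = grid[density.index(max(density))]
print(f"posterior mode near lambda = {mode:.3f}")
```

Such a grid evaluation only works in very low dimensions, but it makes the point of this section explicit: when the likelihood can be evaluated, there is no need for a simulation-based approximation.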

Let us now compare the posterior samples generated with ABC, with and without summary statistics, to the sample from the true posterior. In the first figure below, you can see that not using summary statistics can lead to slow convergence: many proposals will be rejected since in high dimensions it is very unlikely to match the data with sufficient accuracy. However, the sample will eventually converge to the true distribution. The figure also shows that using two sufficient statistics leads to fast convergence to an unbiased result.
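The effect of dimension can be illustrated with a small simulation. The sketch below uses a hypothetical experiment (30 individuals, 10 unit-length intervals, rate 0.15) and counts how often data simulated at the data-generating parameter reproduce the observations exactly, once in terms of the full count vector and once in terms of the two summary statistics. Exact matching is used here as the extreme case of a small tolerance; all names and settings are assumptions made for illustration.

```python
import random


def simulate_deaths(lam, n_intervals, n_individuals, rng):
    """Deaths per unit-length interval for exponential lifetimes with rate lam."""
    counts = [0] * n_intervals
    for _ in range(n_individuals):
        interval = int(rng.expovariate(lam))  # index of the interval containing the death
        if interval < n_intervals:            # later deaths are not observed
            counts[interval] += 1
    return counts


def summaries(y):
    """Total number of deaths and time-weighted number of deaths."""
    return sum(y), sum((i + 1) * c for i, c in enumerate(y))


rng = random.Random(7)
lam, n_intervals, n_individuals = 0.15, 10, 30  # hypothetical small experiment
observed = simulate_deaths(lam, n_intervals, n_individuals, rng)

raw_hits = 0
summary_hits = 0
for _ in range(20000):
    y = simulate_deaths(lam, n_intervals, n_individuals, rng)
    raw_hits += y == observed
    summary_hits += summaries(y) == summaries(observed)

print("exact matches of the raw counts:       ", raw_hits)
print("exact matches of the summary statistics:", summary_hits)
```

Every exact match of the raw counts is also a match of the summaries, so the second count can never be smaller than the first. In practice it is typically much larger, which is why ABC with low-dimensional sufficient statistics accepts proposals far more readily at the same level of accuracy.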

Python

To generate a sample of the true posterior, we first define the log-posterior function (see Monday exercises):

def multinomial_logpmf(y: np.ndarray, p: np.ndarray) -> float:
    return float(
        gammaln(N_INDIVIDUALS + 1)
        - np.sum(gammaln(y + 1))
        + np.sum(y * np.log(p))
    )


def logposterior(lam: float) -> float:
    if not (prior.lower <= lam <= prior.upper):
        return -np.inf
    probs = death_probabilities(np.array([lam]))[0]
    y = np.append(deaths_obs, N_INDIVIDUALS - np.sum(deaths_obs))
    return multinomial_logpmf(y, probs) - np.log(prior.upper - prior.lower)

In the following, we use a random-walk Metropolis algorithm with a fixed proposal standard deviation (no adaptation) to generate a sample from the true posterior:

def run_metropolis(start: float = 0.3,
                   proposal_sd: float = 0.01,
                   n_iter: int = 20_000,
                   burn_in: int = 2_000,
                   seed: int = SEED):
    rng = np.random.default_rng(seed)
    chain = np.empty(n_iter)
    current = start
    current_lp = logposterior(current)
    for i in range(n_iter):
        proposed = current + rng.normal(0.0, proposal_sd)
        proposed_lp = logposterior(proposed)
        if np.log(rng.uniform()) < proposed_lp - current_lp:
            current = proposed
            current_lp = proposed_lp
        chain[i] = current
    return chain[burn_in:]
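The accept/reject step in the loop above is the standard Metropolis rule, evaluated on the log scale. Because the Gaussian random-walk proposal is symmetric, the proposal densities cancel and only the posterior ratio remains. In the notation \(\lambda\) for the current value and \(\lambda'\) for the proposed value (symbols introduced here):

```latex
\[
\alpha(\lambda \to \lambda') = \min\!\left(1,\; \frac{p(\lambda' \mid y)}{p(\lambda \mid y)}\right),
\]
\[
\text{accept } \lambda' \text{ if } \log u < \log p(\lambda' \mid y) - \log p(\lambda \mid y),
\qquad u \sim \mathrm{Uniform}(0, 1).
\]
```

Proposals outside the prior bounds have log-posterior \(-\infty\) and are therefore always rejected. The unknown normalizing constant of the posterior cancels in the ratio, so the unnormalized log-posterior is sufficient.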

Plotting the samples:

reference = run_metropolis()
raw_pop = raw_result.population[:, 0]
suff_pop = suff_result.population[:, 0]

plt.figure(figsize=(9, 6))
for sample, label, color, linewidth in [
    (reference, "Reference MCMC", "black", 2.5),
    (raw_pop, "ABC raw data", "#1f77b4", 1.8),
    (suff_pop, "ABC sufficient summary", "#2ca02c", 1.8),
]:
    density = np.histogram(sample, bins=15, density=True)
    centers = 0.5 * (density[1][1:] + density[1][:-1])
    plt.plot(centers, density[0], label=label, color=color, linewidth=linewidth)

plt.xlabel("lambda")
plt.ylabel("Density")
plt.title("Python: ABC posterior comparison")
plt.legend(frameon=False)
plt.tight_layout()
plt.show()

Julia

To get a sample from the “true” posterior, we first define the log-posterior function.

using AdaptiveMCMC
using MCMCChains
using Plots
using StatsPlots

function loglikelihood_survival(N::Int, t, y, λ)
    S = exp.(-λ .* t)
    S = [1; S; 0]
    p = -diff(S)
    y = [y; N - sum(y)]
    logpdf(Multinomial(N, p), y)
end
## loglikelihood_survival (generic function with 1 method)

function logposterior(par)
    ll = logpdf(prior, par[1])
    if isfinite(ll)
        ll += loglikelihood_survival(100, obs.time, obs.deaths, par[1])
    end
    ll
end
## logposterior (generic function with 1 method)

Then we use the AdaptiveMCMC package to run an MCMC sampler to get a reference solution.

mcmc_res = adaptive_rwm([0.3], logposterior,
                        10_000;
                        algorithm = :ram,
                        b = 1000)
## (X = [0.12157444173442723 0.12157444173442723 … 0.125474109831732 0.125474109831732], …, accRWM = [0.2044], …)
reference = vec(mcmc_res.X')
## 9001-element reshape(adjoint(::Matrix{Float64}), 9001) with eltype Float64:
##  0.12157444173442723
##  0.12157444173442723
##  0.12157444173442723
##  0.12157444173442723
##  0.12157444173442723
##  0.12157444173442723
##  0.12157444173442723
##  0.12157444173442723
##  0.15430554921558173
##  0.15430554921558173
##  ⋮
##  0.13427208940541574
##  0.13427208940541574
##  0.1256252376091502
##  0.1256252376091502
##  0.1256252376091502
##  0.1256252376091502
##  0.125474109831732
##  0.125474109831732
##  0.125474109831732
raw_pop = vec(sabc_res.population)
## 2000-element Vector{Float64}:
##  0.14042618367789836
##  0.2271746740067415
##  0.1322568986302252
##  0.18573279664038145
##  0.1878663257239192
##  0.18844049438411856
##  0.17196054651154805
##  0.15082342710207602
##  0.1657933990596474
##  0.2332937449118348
##  ⋮
##  0.19884919334394638
##  0.1998581634822923
##  0.10912834589361768
##  0.19836966401030495
##  0.19381472601823094
##  0.1721014396998603
##  0.18002252484362
##  0.14788139078637494
##  0.15262648983670468
suff_pop = vec(sabc_suff_res.population)
## 2000-element Vector{Float64}:
##  0.11682960647944805
##  0.1258278789796156
##  0.14211173862804896
##  0.14803454274100988
##  0.1453790943859507
##  0.12309278116974144
##  0.1503572313957907
##  0.140238357589841
##  0.14341873907566904
##  0.13244368383052393
##  ⋮
##  0.1366930804221639
##  0.15595126369327575
##  0.13560711732500946
##  0.13521477474162574
##  0.1533545045321874
##  0.12652152673440475
##  0.14609769373656142
##  0.13655408388399337
##  0.12747510920964972
plt = StatsPlots.density(reference, label = "Reference MCMC", xlab = "lambda",
                         title = "Julia: ABC posterior comparison",
                         linewidth = 3, color = :black);
StatsPlots.density!(plt, raw_pop, label = "ABC raw data", linewidth = 2);
StatsPlots.density!(plt, suff_pop, label = "ABC sufficient summary", linewidth = 2);
plt