Automatic Differentiation and MCMC

Introduction

Gaussian Markov Random Fields (GMRFs) are powerful tools for Bayesian inference. When the parameters of a GMRF model are unknown, we often want to infer them from observed data. Traditional optimization methods can be limiting, but modern MCMC methods like NUTS (No-U-Turn Sampler) provide a robust approach for full Bayesian inference.

Automatic differentiation (AD) is crucial for efficient MCMC sampling, as it enables gradient-based samplers to explore complex posterior geometries effectively. This tutorial demonstrates how to leverage AD with GMRFs for Bayesian parameter inference using NUTS in Turing.jl.

The key to AD support in GaussianMarkovRandomFields.jl is using the :autodiffable solver backend, which uses LDLFactorizations.jl instead of the default sparse Cholesky factorization. This enables ForwardDiff.jl to track derivatives through the linear algebra operations.

using LDLFactorizations, Distributions
using GaussianMarkovRandomFields
using Random, LinearAlgebra, SparseArrays
using Plots

MCMC Parameter Inference for CAR Models

We'll demonstrate Bayesian inference for the parameters of a conditional autoregressive (CAR) model. Given observations from a CAR process, we'll infer both the spatial correlation parameter (ρ) and variance parameter (σ) using NUTS sampling.

Problem setup: CAR model parameter inference

We'll tackle a simple 1D time series problem: given observations sampled from a conditional autoregressive (CAR) process, infer the CAR parameter (ρ).

using Turing, SparseArrays

Set up a 1D grid (time points)

xs = 0:0.1:2  # 21 time points
N = length(xs)
21

Create adjacency matrix for second-order CAR (neighbors and next-neighbors)

W = spzeros(N, N)
for i = 1:N
    for k in [-2, -1, 1, 2]
        j = i + k
        if 1 <= j <= N
            W[i, j] = 1.0 / abs(k)
        end
    end
end

Generate synthetic observations from true CAR process

Random.seed!(123)
true_ρ = 0.85  # True CAR parameter
true_σ = 0.01  # True field variance
μ = zeros(N)   # Zero mean
21-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 ⋮
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

Generate true CAR model and sample from it

true_car = generate_car_model(W, true_ρ; μ = μ, σ = true_σ,
                             solver_blueprint=CholeskySolverBlueprint{:autodiffable}())
GMRF{Float64, LinearMaps.WrappedMap{Float64, SparseArrays.SparseMatrixCSC{Float64, Int64}}, CholeskySolver{:autodiffable, TakahashiStrategy, Float64, LinearMaps.WrappedMap{Float64, SparseArrays.SparseMatrixCSC{Float64, Int64}}, LDLFactorizations.LDLFactorization{Float64, Int64, Int64, Int64}}}(
mean: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
precision: 21×21 LinearMaps.WrappedMap{Float64} of
  21×21 SparseArrays.SparseMatrixCSC{Float64, Int64} with 99 stored entries
solver: CholeskySolver{:autodiffable, TakahashiStrategy, Float64, LinearMaps.WrappedMap{Float64, SparseArrays.SparseMatrixCSC{Float64, Int64}}, LDLFactorizations.LDLFactorization{Float64, Int64, Int64, Int64}}([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 21×21 LinearMaps.WrappedMap{Float64} of
  21×21 SparseArrays.SparseMatrixCSC{Float64, Int64} with 99 stored entries, LDLFactorizations.LDLFactorization{Float64, Int64, Int64, Int64}(true, true, false, 21, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11  …  13, 14, 15, 16, 17, 18, 20, 20, 21, -1], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2  …  2, 2, 2, 2, 2, 2, 2, 2, 1, 0], [3, 4, 5, 6, 7, 8, 9, 10, 11, 12  …  14, 15, 16, 17, 18, 20, 21, 21, 21, 21], [21, 20, 19, 18, 17, 16, 15, 14, 13, 12  …  10, 9, 8, 7, 6, 5, 4, 1, 3, 2], [19, 21, 20, 18, 17, 16, 15, 14, 13, 12  …  10, 9, 8, 7, 6, 5, 4, 3, 2, 1], [1, 3, 5, 7, 9, 11, 13, 15, 17, 19  …  25, 27, 29, 31, 33, 35, 37, 39, 40, 40], Int64[], Int64[], [2, 3, 3, 4, 4, 5, 5, 6, 6, 7  …  17, 17, 18, 18, 20, 20, 21, 20, 21, 21], [-0.5666666666666667, -0.2833333333333333, -0.5404624277456647, -0.21056977704376548, -0.47147725121474093, -0.18558717738373193, -0.4373921712696389, -0.17697600536466962, -0.4208036487942443, -0.1726450317492599  …  -0.16924310308829454, -0.40744205480342677, -0.16924258203470016, -0.40744115840424483, -0.1692423531404636, -0.40744076462519874, -0.1692422525895785, -0.2833333333333333, -0.5666666666666667, -0.5286965448138948], [150.0, 201.83333333333334, 229.0028901734104, 240.14554918010617, 246.1698409122172, 248.887820205458, 250.13278555011902, 250.6838947502651, 250.927847503607, 251.0351996989583  …  251.10317683356055, 251.11229565170635, 251.11630175652869, 251.1180616785763, 251.11883480533368, 251.1191744345867, 251.1193236305166, 150.0, 239.07772250412165, 127.81352387802943], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [18, 20, 3, 4, 5, 6, 7, 8, 9, 10  …  12, -1, -1, -1, -1, -1, -1, 18, 19, 20], 0.0, 0.0, 0.0, 21), TakahashiStrategy(), nothing, nothing, nothing)
)

Our "observations" are just a sample from the true CAR

observations = rand(true_car)

println("Generated observations from CAR process with $(N) time points")
println("True CAR parameter ρ: $(true_ρ)")
println("True variance parameter σ: $(true_σ)")
Generated observations from CAR process with 21 time points
True CAR parameter ρ: 0.85
True variance parameter σ: 0.01

Bayesian model in Turing

We'll use a simple model:

  • x ~ CAR(W, ρ, σ) where ρ and σ are parameters to infer
  • ρ ~ Uniform(0.5, 0.99)
  • σ ~ Uniform(0.001, 0.1)
  • y = x (direct observation of the CAR process)

Again, the crucial bit here is using the :autodiffable solver backend.

@model function car_model(y, W, μ)
    # Prior on CAR parameter
    ρ ~ Uniform(0.5, 0.99)

    # Prior on variance parameter
    σ ~ Uniform(0.001, 0.1)

    # CAR process
    car_dist = generate_car_model(W, ρ; μ = μ, σ = σ,
                                 solver_blueprint=CholeskySolverBlueprint{:autodiffable}())

    # Direct observation
    y ~ car_dist
end
car_model (generic function with 2 methods)

Create the model

model = car_model(observations, W, μ)
DynamicPPL.Model{typeof(Main.car_model), (:y, :W, :μ), (), (), Tuple{Vector{Float64}, SparseArrays.SparseMatrixCSC{Float64, Int64}, Vector{Float64}}, Tuple{}, DynamicPPL.DefaultContext}(Main.car_model, (y = [0.07840387299483453, 0.08930313551294565, 0.0568103036193451, 0.08267097408818341, -0.025373176002426064, -0.0525144955980359, -0.16214702025581984, -0.05502378606163386, -0.051828462039471106, -0.011947395381926016  …  0.016379935558217797, 0.021789192443508, -0.07393817365539637, -0.05328451077296106, -0.020084724732943468, 0.000678582049044622, -0.03016630854446354, -0.08709273978627284, -0.13240380240329042, -0.03370866463496189], W = sparse([2, 3, 1, 3, 4, 1, 2, 4, 5, 2  …  20, 17, 18, 20, 21, 18, 19, 21, 19, 20], [1, 1, 2, 2, 2, 3, 3, 3, 3, 4  …  18, 19, 19, 19, 19, 20, 20, 20, 21, 21], [1.0, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5  …  0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 1.0], 21, 21), μ = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), NamedTuple(), DynamicPPL.DefaultContext())

MCMC Sampling with NUTS

NUTS requires gradients, which are automatically computed via ForwardDiff thanks to our autodifferentiable GMRF implementation.

println("Starting MCMC sampling...")
Random.seed!(456)

sampler = NUTS()
chain = sample(model, sampler, 1000, progress=false)

println("MCMC sampling completed!")
Starting MCMC sampling...
┌ Info: Found initial step size
└   ϵ = 1.5812500000000003
MCMC sampling completed!

Analyze results

Extract CAR parameter samples

ρ_samples = chain[:ρ].data[:, 1]
ρ_mean = mean(ρ_samples)
ρ_std = std(ρ_samples)
0.12628870284765983

Extract variance parameter samples

σ_samples = chain[:σ].data[:, 1]
σ_mean = mean(σ_samples)
σ_std = std(σ_samples)

println("Posterior summary for CAR parameter ρ:")
println("True value: $(true_ρ)")
println("Posterior mean: $(round(ρ_mean, digits=4)) ± $(round(ρ_std, digits=4))")
println("95% credible interval: $(round(quantile(ρ_samples, 0.025), digits=4)) - $(round(quantile(ρ_samples, 0.975), digits=4))")

ρ_in_ci = quantile(ρ_samples, 0.025) <= true_ρ <= quantile(ρ_samples, 0.975)
println("True ρ value in 95% CI: $(ρ_in_ci)")

println("\nPosterior summary for variance parameter σ:")
println("True value: $(true_σ)")
println("Posterior mean: $(round(σ_mean, digits=4)) ± $(round(σ_std, digits=4))")
println("95% credible interval: $(round(quantile(σ_samples, 0.025), digits=4)) - $(round(quantile(σ_samples, 0.975), digits=4))")

σ_in_ci = quantile(σ_samples, 0.025) <= true_σ <= quantile(σ_samples, 0.975)
println("True σ value in 95% CI: $(σ_in_ci)")
Posterior summary for CAR parameter ρ:
True value: 0.85
Posterior mean: 0.7714 ± 0.1263
95% credible interval: 0.515 - 0.9665
True ρ value in 95% CI: true

Posterior summary for variance parameter σ:
True value: 0.01
Posterior mean: 0.0097 ± 0.0039
95% credible interval: 0.0044 - 0.0197
True σ value in 95% CI: true

Plot posterior for CAR parameter

p1 = histogram(ρ_samples, bins=20, alpha=0.7, label="Posterior samples",
               xlabel="CAR Parameter ρ", ylabel="Density",
               title="Posterior Distribution of ρ")
vline!([true_ρ], label="True Value", color=:red, linewidth=2)
vline!([ρ_mean], label="Posterior Mean", color=:blue, linewidth=2, linestyle=:dash)
Example block output

Plot posterior for variance parameter

p2 = histogram(σ_samples, bins=20, alpha=0.7, label="Posterior samples",
               xlabel="Variance Parameter σ", ylabel="Density",
               title="Posterior Distribution of σ")
vline!([true_σ], label="True Value", color=:red, linewidth=2)
vline!([σ_mean], label="Posterior Mean", color=:blue, linewidth=2, linestyle=:dash)
Example block output

Plot trace of CAR parameter

p3 = plot(ρ_samples, label="ρ samples", xlabel="Iteration", ylabel="CAR Parameter ρ",
          title="MCMC Trace for ρ")
hline!([true_ρ], label="True Value", color=:red, linewidth=2)
Example block output

Plot trace of variance parameter

p4 = plot(σ_samples, label="σ samples", xlabel="Iteration", ylabel="Variance Parameter σ",
          title="MCMC Trace for σ")
hline!([true_σ], label="True Value", color=:red, linewidth=2)
Example block output

Combine plots

combined_plot = plot(p1, p2, p3, p4, layout=(2,2), size=(1000, 800))
Example block output

Conclusion

This tutorial demonstrated how AD enables advanced sampling methods like NUTS for full Bayesian inference with GMRFs.

In practice, MCMC still quickly becomes prohibitively expensive, so people instead use Integrated Nested Laplace Approximations (INLA). More on this soon.


This page was generated using Literate.jl.