Automatic Differentiation for GMRF Hyperparameters
Introduction
GaussianMarkovRandomFields.jl provides comprehensive automatic differentiation (AD) support for gradient-based inference and optimization. This tutorial demonstrates how to compute gradients through GMRF operations to optimize hyperparameters like precision parameters, mean field values, and other model parameters.
We currently support AD via Zygote, Enzyme, and ForwardDiff.
AD may break
Our current AD rules cover the most common workflows with GMRFs. Less common operations may or may not work. If one of these backends breaks for your use case, please open an issue.
Basic Setup
We'll start by loading the required packages:
using GaussianMarkovRandomFields
using DifferentiationInterface
using Zygote, Enzyme
using LinearAlgebra
using LinearSolve
using Distributions
using Random
Random.seed!(123)
Random.TaskLocalRNG()
Example Problem: Hyperparameter Optimization
Consider a problem where we have Poisson count observations and want to infer the hyperparameters of a simple IID (independent and identically distributed) prior.
The model has two hyperparameters:
The mean parameter μ
The precision parameter τ
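Written out, the generative model assumed here is (our reading of the setup: IIDModel yields a scaled-identity precision Q = τI, and ExponentialFamily(Poisson) uses the canonical log link):

x \mid \tau, \mu \sim \mathcal{N}\!\left(\mu \mathbf{1}_n,\; (\tau I_n)^{-1}\right), \qquad y_i \mid x_i \sim \mathrm{Poisson}\!\left(\exp(x_i)\right), \quad i = 1, \dots, n.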
Start by generating some synthetic data:
n = 50 # Number of observations
τ_true = 4.0
μ_true = 5.0
5.0
Next, let's define the prior:
function build_prior(log_τ, log_μ, n)
τ = exp(log_τ)
μ = exp(log_μ)
model = IIDModel(n)
Q = precision_matrix(model; τ = τ)
return GMRF(μ * ones(n), Q, LinearSolve.DiagonalFactorization())
end
build_prior (generic function with 1 method)
Finally, sample a ground-truth latent field and generate observations:
true_gmrf = build_prior(log(τ_true), log(μ_true), n)
x_latent = rand(true_gmrf)
obs_model = ExponentialFamily(Poisson)
y_obs = rand(conditional_distribution(obs_model, x_latent))
println("Generated $n observations with τ = $τ_true, μ = $μ_true")
Generated 50 observations with τ = 4.0, μ = 5.0
Computing Gradients with DifferentiationInterface
Now we'll define an objective function that maps hyperparameters to a scalar loss, and compute its gradient using both Zygote and Enzyme.
The objective function takes hyperparameters [log_τ, log_μ], builds a GMRF prior, computes a Gaussian approximation to the posterior, and returns the negative log marginal likelihood.
function objective(θ::Vector{Float64}, y::Vector{Int}, n::Int)
log_τ, log_μ = θ
prior = build_prior(log_τ, log_μ, n)
obs_model = ExponentialFamily(Poisson)
likelihood = obs_model(y)
posterior = gaussian_approximation(prior, likelihood)
x_map = mean(posterior)
return -logpdf(prior, x_map) - loglik(x_map, likelihood) + logpdf(posterior, x_map)
end
objective (generic function with 1 method)
Initialize hyperparameters (perturbed from truth)
θ_init = [log(τ_true) + 0.2, log(μ_true) - 0.3]
2-element Vector{Float64}:
1.5862943611198905
1.3094379124341002
Compute gradient with Zygote
backend_zygote = AutoZygote()
grad_zygote = DifferentiationInterface.gradient(
θ -> objective(θ, y_obs, n),
backend_zygote,
θ_init
)
println("Zygote gradient computed, norm: $(norm(grad_zygote))")
Zygote gradient computed, norm: 1106.5782216602752
Compute gradient with Enzyme
backend_enzyme = AutoEnzyme(; function_annotation = Enzyme.Const)
grad_enzyme = DifferentiationInterface.gradient(
θ -> objective(θ, y_obs, n),
backend_enzyme,
θ_init
)
println("Enzyme gradient computed, norm: $(norm(grad_enzyme))")
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/LJjsP/src/compiler.jl:4608
Enzyme gradient computed, norm: 1106.5782216602754
Verify gradients match
max_diff = maximum(abs.(grad_zygote - grad_enzyme))
println("Maximum difference between backends: $(max_diff)")
Maximum difference between backends: 2.2737367544323206e-13
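If you also need the objective value alongside the gradient, for example to monitor progress, DifferentiationInterface can return both in one call. A minimal sketch reusing the Zygote backend from above:
# Compute the objective value and its gradient in a single call
val, grad = DifferentiationInterface.value_and_gradient(
    θ -> objective(θ, y_obs, n),
    backend_zygote,
    θ_init
)
println("Objective: $val, gradient norm: $(norm(grad))")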
Optimization with Optim.jl
We can use these gradients with optimization libraries like Optim.jl to estimate the hyperparameters by minimizing the approximate negative log marginal likelihood. Optim.jl lets you choose the backend used for automatic differentiation through its autodiff parameter.
using Optim
Optimize using L-BFGS with Zygote-based autodiff
result = optimize(
θ -> objective(θ, y_obs, n),
θ_init,
LBFGS(; alphaguess = Optim.LineSearches.InitialStatic(; alpha = 0.001)),
autodiff = AutoZygote()
)
* Status: success
* Candidate solution
Final objective value: 2.850519e+02
* Found with
Algorithm: L-BFGS
* Convergence measures
|x - x'| = 1.93e-10 ≰ 0.0e+00
|x - x'|/|x'| = 1.21e-10 ≰ 0.0e+00
|f(x) - f(x')| = 1.14e-13 ≰ 0.0e+00
|f(x) - f(x')|/|f(x')| = 3.99e-16 ≰ 0.0e+00
|g(x)| = 1.14e-12 ≤ 1.0e-08
* Work counters
Seconds run: 0 (vs limit Inf)
Iterations: 8
f(x) calls: 54
∇f(x) calls: 54
Extract optimal parameters
θ_opt = Optim.minimizer(result)
τ_opt = exp(θ_opt[1])
μ_opt = exp(θ_opt[2])
println("\nOptimization results:")
println(" Iterations: $(result.iterations)")
println(" Estimated τ: $(round(τ_opt, digits = 2)) (true: $τ_true)")
println(" Estimated μ: $(round(μ_opt, digits = 2)) (true: $μ_true)")
println(" Converged: $(Optim.converged(result))")
Optimization results:
Iterations: 8
Estimated τ: 3.98 (true: 4.0)
Estimated μ: 4.96 (true: 5.0)
Converged: true
Choosing a Backend
First, you need to choose between forward- and reverse-mode AD. The usual rule of thumb for a function with n inputs and m outputs is: if n is small or n << m, use forward mode; otherwise, use reverse mode.
This same advice applies here, with the added caveat that ForwardDiff currently does not support Gaussian approximations. If you need to autodiff through Gaussian approximations, use Zygote or Enzyme.
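For completeness, here is a sketch of a forward-mode gradient for a loss that avoids the Gaussian approximation entirely, differentiating only the prior log-density at the sampled latent field. It assumes the GMRF construction in build_prior is ForwardDiff-compatible, in line with the support noted above:
using ForwardDiff  # backend behind AutoForwardDiff()

# A loss that only touches the prior (no Gaussian approximation):
# negative prior log-density of the previously sampled latent field
prior_loss(θ) = -logpdf(build_prior(θ[1], θ[2], n), x_latent)

backend_fd = AutoForwardDiff()
grad_fd = DifferentiationInterface.gradient(prior_loss, backend_fd, θ_init)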
As verified above, Zygote and Enzyme produce gradients that agree to floating-point precision, so the choice between them comes down to performance and ease of use.
Zygote has low compilation latency and works in most cases. By contrast, Enzyme incurs a large compilation overhead on first use and may not work in some situations. The upside is that once compilation is complete, Enzyme is generally much faster than Zygote.
In practice, our recommendation is: Start with Zygote for prototyping. For large-scale problems, switch to Enzyme.
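Switching the optimization above to Enzyme is then a one-line change. A sketch, assuming your Optim version accepts any ADTypes backend for autodiff (as it does for AutoZygote above):
result_enzyme = optimize(
    θ -> objective(θ, y_obs, n),
    θ_init,
    LBFGS(; alphaguess = Optim.LineSearches.InitialStatic(; alpha = 0.001)),
    autodiff = AutoEnzyme(; function_annotation = Enzyme.Const)
)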
Solver Considerations
Enzyme is particularly finicky about type stability, which causes issues with a GMRF's linear solver.
Using the two-argument GMRF constructor gives you the default linear solver:
using SparseArrays
Q_sparse = sprand(10, 10, 0.3)
Q_sparse = Q_sparse + Q_sparse' + 10I # Make symmetric positive definite
x_default_solver = GMRF(zeros(10), Q_sparse)
x_default_solver.linsolve_cache.alg
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.CHOLMODFactorization, true)
Unfortunately, the default linear solver is not type-stable. It checks the type of the precision matrix at runtime and only then decides on an algorithm. In our experience, Enzyme very much does not like this behaviour.
To avoid these issues, always pass a specialized linear solver to your GMRF, e.g. CHOLMOD for general sparse matrices:
x_cholmod = GMRF(zeros(10), Q_sparse, LinearSolve.CHOLMODFactorization())
GMRF{Float64} with 10 variables
Algorithm: LinearSolve.CHOLMODFactorization{Nothing}
Mean: [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]
Q_sqrt: not available
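As a quick smoke test (a sketch; whether this differentiates cleanly still depends on the AD rules covering your particular workflow), you can differentiate a simple log-density through the explicitly-solved GMRF with Enzyme:
# Negative log-density at the zero vector, as a function of the mean;
# Q_sparse and backend_enzyme are defined above
loss_cholmod(m) = -logpdf(GMRF(m, Q_sparse, LinearSolve.CHOLMODFactorization()), zeros(10))
grad_mean = DifferentiationInterface.gradient(loss_cholmod, backend_enzyme, ones(10))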
Conclusion
GaussianMarkovRandomFields.jl provides custom chain rules for common GMRF workflows. As a user, you should not have to worry about the details of this. AD should "just work". If it doesn't, please open an issue on GitHub. For Enzyme, as mentioned above, pay extra attention to type stability.
For more details on AD implementation and advanced usage, see the Automatic Differentiation Reference.
This page was generated using Literate.jl.