Estimators
The package provides several classes of neural estimator, organised within a type hierarchy rooted at the abstract supertype NeuralEstimator.
NeuralEstimators.NeuralEstimator Type
NeuralEstimator
An abstract supertype for all neural estimators.
Posterior estimators
PosteriorEstimator approximates the posterior distribution, using a flexible, parametric family of distributions (see Approximate distributions).
NeuralEstimators.PosteriorEstimator Type
PosteriorEstimator <: NeuralEstimator
PosteriorEstimator(summary_network, q::ApproximateDistribution)
PosteriorEstimator(summary_network, num_parameters::Integer; num_summaries::Integer, q = nothing, kwargs...)
A neural estimator that approximates the posterior distribution p(θ ∣ Z), comprising a summary_network and an approximate distribution q (see the available in-built Approximate distributions).
The summary_network maps data Z to a vector of summary statistics, which are used to parameterise q. The precise way in which the summary statistics condition q depends on the choice of approximate distribution: for example, GaussianMixture uses an MLP to map the summary statistics to the mixture parameters, while NormalisingFlow conditions its invertible transformations directly on the summary statistics.
The convenience constructor builds q internally given num_parameters and num_summaries, with any additional keyword arguments passed to the constructor of q.
Keyword arguments
- num_summaries::Integer: the number of summary statistics output by summary_network. Must match the output dimension of summary_network.
- q::Type{<:ApproximateDistribution}: the type of approximate distribution to use. Defaults to NormalisingFlow when using Flux, and GaussianMixture when using Lux.
- kwargs...: additional keyword arguments passed to the constructor of q.
Examples
using NeuralEstimators, Flux
# Data Z|μ,σ ~ N(μ, σ²) with priors μ ~ N(0, 1) and σ ~ U(0, 1)
d, m = 2, 100 # dimension of θ and number of replicates
sampler(K) = NamedMatrix(μ = randn(K), σ = rand(K))
simulator(θ::AbstractVector) = θ["μ"] .+ θ["σ"] .* sort(randn(m))
simulator(θ::AbstractMatrix) = reduce(hcat, map(simulator, eachcol(θ)))
# Neural network
num_summaries = 3d
summary_network = Chain(Dense(m, 64, gelu), Dense(64, 64, gelu), Dense(64, num_summaries))
# Initialise the estimator, with q built internally
estimator = PosteriorEstimator(summary_network, d; num_summaries = num_summaries)
# Or, build q explicitly
q = NormalisingFlow(d; num_summaries = num_summaries)
estimator = PosteriorEstimator(summary_network, q)
# Training
estimator = train(estimator, sampler, simulator, K = 3000)
# Assess the estimator
θ_test = sampler(250)
Z_test = simulator(θ_test);
assessment = assess(estimator, θ_test, Z_test)
# Inference with observed data
θ = sampler(1)
Z = simulator(θ)
sampleposterior(estimator, Z) # posterior draws
posteriormean(estimator, Z) # point estimate
Ratio estimators
RatioEstimator approximates the likelihood-to-evidence ratio, enabling both frequentist and Bayesian inference through various downstream algorithms.
NeuralEstimators.RatioEstimator Type
RatioEstimator <: NeuralEstimator
RatioEstimator(summary_network, num_parameters; num_summaries, kwargs...)
A neural estimator that estimates the likelihood-to-evidence ratio
r(Z, θ) ≡ p(Z ∣ θ) / p(Z),
where p(Z ∣ θ) is the likelihood and p(Z) is the marginal likelihood (evidence).
The estimator jointly summarises the data Z and the parameters θ: the data are summarised by summary_network, while the parameters are summarised by an internal MLP that outputs 2 * num_parameters summaries by default.
For numerical stability, training is done on the log-scale using the relation log r(Z, θ) = logit(c(Z, θ)), where c(Z, θ) = r(Z, θ)/(1 + r(Z, θ)) is the optimal classifier for discriminating dependent parameter–data pairs from independent ones.
Given data Z and parameters θ, the estimated log-ratio can be obtained using logratio, and can be used in various Bayesian (e.g., Hermans et al., 2020) or frequentist (e.g., Walchessen et al., 2024) inferential algorithms. For Bayesian inference, posterior samples can be obtained via simple grid-based sampling using sampleposterior.
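The log-scale relation above can be illustrated in plain Julia, independent of the package (the variables r and c below are illustrative): if a classifier outputs c = r/(1 + r), then applying the logit recovers the log-ratio exactly.

```julia
# Identity behind log-scale training: if c = r / (1 + r), then logit(c) = log(r).
logit(c) = log(c / (1 - c))

r = 2.5            # a likelihood-to-evidence ratio
c = r / (1 + r)    # the corresponding classifier probability
logit(c) ≈ log(r)  # holds exactly for any r > 0
```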
Keyword arguments
- num_summaries::Integer: the number of summaries output by summary_network. Must match the output dimension of summary_network.
- num_summaries_θ::Integer = 2 * num_parameters: the number of summaries output by the parameter summary network.
- summary_network_θ_kwargs::NamedTuple = (;): keyword arguments passed to the MLP constructor for the parameter summary network.
- kwargs...: additional keyword arguments passed to the MLP constructor for the inference network.
Examples
using NeuralEstimators, Flux, CairoMakie
# Data Z|μ,σ ~ N(μ, σ²) with priors μ ~ U(0, 1) and σ ~ U(0, 1)
d, m = 2, 100 # dimension of θ and number of replicates
sampler(K) = NamedMatrix(μ = rand(K), σ = rand(K))
simulator(θ::AbstractVector) = θ["μ"] .+ θ["σ"] .* sort(randn(m))
simulator(θ::AbstractMatrix) = reduce(hcat, map(simulator, eachcol(θ)))
# Neural network
num_summaries = 3d
summary_network = Chain(Dense(m, 64, gelu), Dense(64, 64, gelu), Dense(64, num_summaries))
# Initialise the estimator
estimator = RatioEstimator(summary_network, d; num_summaries = num_summaries)
# Train the estimator
estimator = train(estimator, sampler, simulator, K = 1000)
# Plot the risk history
plotrisk()
# Assess the estimator
θ_test = sampler(250)
Z_test = simulator(θ_test);
grid = expandgrid(0:0.01:1, 0:0.01:1)' # fine gridding of the parameter space
assessment = assess(estimator, θ_test, Z_test; grid = grid)
plot(assessment)
# Generate "observed" data
θ = sampler(1)
z = simulator(θ)
# Grid-based evaluation and sampling
logratio(estimator, z; grid = grid) # log of likelihood-to-evidence ratios
sampleposterior(estimator, z; grid = grid) # posterior samples
Bayes estimators
Neural Bayes estimators are implemented as subtypes of BayesEstimator. The general-purpose PointEstimator supports user-defined loss functions (see Loss functions). IntervalEstimator and its generalisation QuantileEstimator are designed for posterior quantile estimation based on user-specified probability levels, automatically configuring the quantile loss and enforcing non-crossing constraints.
NeuralEstimators.BayesEstimator Type
BayesEstimator <: NeuralEstimator
An abstract supertype for neural Bayes estimators.
NeuralEstimators.PointEstimator Type
PointEstimator <: BayesEstimator
PointEstimator(network)
PointEstimator(summary_network, inference_network)
PointEstimator(summary_network, num_parameters; num_summaries, kwargs...)
A neural point estimator mapping data to a point summary of the posterior distribution.
The neural network can be provided in two ways:
- As a single network that maps data directly to the parameter space.
- As a summary_network that maps data to a vector of summary statistics, with the inference_network constructed internally based on num_parameters and num_summaries.
Examples
using NeuralEstimators, Flux, CairoMakie
# Data Z|μ,σ ~ N(μ, σ²) with priors μ ~ N(0, 1) and σ ~ U(0, 1)
d, m = 2, 100 # dimension of θ and number of replicates
sampler(K) = NamedMatrix(μ = randn(K), σ = rand(K))
simulator(θ::AbstractVector) = θ["μ"] .+ θ["σ"] .* sort(randn(m))
simulator(θ::AbstractMatrix) = reduce(hcat, map(simulator, eachcol(θ)))
# Neural network, an MLP mapping m inputs into d outputs
network = Chain(Dense(m, 64, gelu), Dense(64, 64, gelu), Dense(64, d))
# Initialise a neural point estimator
estimator = PointEstimator(network)
# Train the estimator
estimator = train(estimator, sampler, simulator)
# Plot the risk history
plotrisk()
# Assess the estimator
θ_test = sampler(1000)
Z_test = simulator(θ_test)
assessment = assess(estimator, θ_test, Z_test)
bias(assessment)
rmse(assessment)
plot(assessment)
# Apply to observed data (here, simulated as a stand-in)
θ = sampler(1) # ground truth (not known in practice)
Z = simulator(θ) # stand-in for real observations
estimate(estimator, Z) # point estimate
NeuralEstimators.IntervalEstimator Type
IntervalEstimator <: BayesEstimator
IntervalEstimator(summary_network, num_parameters; num_summaries, kwargs...)
IntervalEstimator(summary_network, num_parameters, num_summaries; kwargs...)
A neural estimator that jointly estimates marginal posterior credible intervals based on the probability levels probs (by default, 95% central credible intervals).
The estimator summarises the data using the summary_network, whose output is passed to two MLP inference networks, u and v.
The estimator employs a representation that prevents quantile crossing. Specifically, given data Z, the interval is constructed as
[c(u(Z)), c(u(Z) + g(v(Z)))],
where u(Z) and v(Z) are the outputs of the two inference networks, g is a monotonically increasing function (by default, softplus) that ensures a positive interval width, and c is a monotonically increasing function (or vector of functions, one per parameter).
The function(s) c may be specified as a Compress object, which can constrain the interval estimator's output to the prior support. If these functions are unspecified, they will be set to the identity function so that the range of the intervals will be unrestricted.
The return value when applied to data using estimate() is a matrix with 2d rows, where the first d rows contain the lower bounds and the last d rows the upper bounds; interval() can be used to format this output in a readable form.
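The non-crossing construction can be sketched in plain Julia (a minimal illustration, not the package's implementation; the vectors l and v stand in for the outputs of the inference networks u and v):

```julia
# Upper bound = lower bound + strictly positive increment, so no crossing.
softplus(x) = log1p(exp(x))

l = [-0.3, 1.2]           # stand-in outputs of the lower-bound network u
v = [-5.0, 0.7]           # stand-in outputs of the width network v (any reals)
lower = l
upper = l .+ softplus.(v) # softplus is strictly positive
all(upper .> lower)       # true for any real-valued v
```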
See also QuantileEstimator.
Keyword arguments
- num_summaries::Integer: the number of summaries output by summary_network. Must match the output dimension of summary_network.
- c::Union{Function, Compress} = identity: monotonically increasing function(s) mapping to the prior support of each parameter.
- probs = [0.025, 0.975]: probability levels for the lower and upper bounds.
- g = softplus: monotonically increasing function used to ensure a positive interval width.
- kwargs...: additional keyword arguments passed to the MLP constructors for the inference networks u and v.
Examples
using NeuralEstimators, Flux
# Data Z|μ,σ ~ N(μ, σ²) with priors μ ~ U(0, 1) and σ ~ U(0, 1)
d, m = 2, 100 # dimension of θ and number of replicates
sampler(K) = NamedMatrix(μ = rand(K), σ = rand(K))
simulator(θ::AbstractVector) = θ["μ"] .+ θ["σ"] .* sort(randn(m))
simulator(θ::AbstractMatrix) = reduce(hcat, simulator.(eachcol(θ)))
# Neural network
num_summaries = 3d
summary_network = Chain(Dense(m, 64, relu), Dense(64, 64, relu), Dense(64, num_summaries))
# Initialise and train the estimator
estimator = IntervalEstimator(summary_network, d; num_summaries = num_summaries)
estimator = train(estimator, sampler, simulator, K = 3000)
# Assessment
θ_test = sampler(1000)
Z_test = simulator(θ_test)
assessment = assess(estimator, θ_test, Z_test)
coverage(assessment)
# Inference
θ = sampler(1)
Z = simulator(θ);
estimate(estimator, Z)
interval(estimator, Z)
NeuralEstimators.QuantileEstimator Type
QuantileEstimator <: BayesEstimator
QuantileEstimator(summary_network, num_parameters; num_summaries, kwargs...)
A neural estimator that jointly estimates a fixed set of marginal posterior quantiles, with probability levels probs. This generalises IntervalEstimator to support an arbitrary number of probability levels.
Given data Z, the estimator approximates quantiles of the marginal posterior distributions θᵢ ∣ Z for each parameter θᵢ, i = 1, …, d. Alternatively, with the keyword argument i set to a positive integer, the estimator approximates quantiles of the full conditional distribution of θᵢ, namely θᵢ ∣ Z, θ₋ᵢ, where θ₋ᵢ denotes the parameter vector with its ith element removed.
The estimator employs a representation that prevents quantile crossing: the quantile at the lowest probability level is estimated directly, and each subsequent quantile is obtained by adding a positive increment, computed by applying the monotonically increasing function g to the output of an inference network. If g = nothing, the quantiles are instead estimated independently, without the non-crossing constraint.
The return value is a matrix whose rows stack the estimated quantiles over the parameters and probability levels (when i is specified, only the single parameter θᵢ is targeted). The function quantiles can be used to format this output in a readable form.
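One way to realise the non-crossing representation is sketched below in plain Julia (an illustration under assumed network outputs, not the package's exact implementation): the first quantile is modelled directly, and each later quantile adds a softplus-transformed increment.

```julia
# Non-crossing quantiles via cumulative positive increments.
softplus(x) = log1p(exp(x))

q1 = 0.1                        # quantile at the lowest probability level
increments = [-2.0, 0.5, -0.1]  # stand-in network outputs for higher levels
qs = [q1; accumulate(+, softplus.(increments); init = q1)]
issorted(qs)                    # true: ordered by construction
```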
Keyword arguments
- num_summaries::Integer: the number of summaries output by summary_network. Must match the output dimension of summary_network.
- probs = [0.025, 0.5, 0.975]: probability levels for the quantiles.
- g = softplus: monotonically increasing function applied to enforce non-crossing quantiles.
- i::Union{Integer, Nothing} = nothing: if set to a positive integer, the estimator targets the full conditional distribution of θᵢ ∣ Z, θ₋ᵢ.
- num_summaries_θ::Integer = 2 * (num_parameters - 1): number of summaries for the parameter summary network (only used when i is specified).
- summary_network_θ_kwargs::NamedTuple = (;): keyword arguments for the parameter summary network MLP (only used when i is specified).
- kwargs...: additional keyword arguments passed to the MLP constructors for the inference networks.
Examples
using NeuralEstimators, Flux
# Data Z|μ,σ ~ N(μ, σ²) with priors μ ~ U(0, 1) and σ ~ U(0, 1)
d, m = 2, 100 # dimension of θ and number of replicates
sampler(K) = NamedMatrix(μ = rand(K), σ = rand(K))
simulator(θ::AbstractVector) = θ["μ"] .+ θ["σ"] .* sort(randn(m))
simulator(θ::AbstractMatrix) = reduce(hcat, simulator.(eachcol(θ)))
# Neural network
num_summaries = 3d
summary_network = Chain(Dense(m, 64, gelu), Dense(64, 64, gelu), Dense(64, num_summaries))
# ---- Quantiles of θᵢ ∣ 𝐙, i = 1, …, d ----
# Initialise the estimator
estimator = QuantileEstimator(summary_network, d; num_summaries = num_summaries)
# Training
estimator = train(estimator, sampler, simulator)
# Assessment
θ_test = sampler(1000)
Z_test = simulator(θ_test)
assessment = assess(estimator, θ_test, Z_test)
# Inference
θ = sampler(1)
Z = simulator(θ);
estimate(estimator, Z)
quantiles(estimator, Z)
# ---- Quantiles of θᵢ ∣ 𝐙, θ₋ᵢ ----
# Initialise estimators respectively targeting quantiles of μ∣Z,σ and σ∣Z,μ
q₁ = QuantileEstimator(summary_network, d; num_summaries = num_summaries, i = 1)
q₂ = QuantileEstimator(summary_network, d; num_summaries = num_summaries, i = 2)
# Training
q₁ = train(q₁, sampler, simulator)
q₂ = train(q₂, sampler, simulator)
# Inference: Estimate quantiles of μ∣Z,σ with known σ
σ₀ = θ["σ", 1]
θ₋ᵢ = [σ₀;]
estimate(q₁, (Z, θ₋ᵢ))
quantiles(q₁, (Z, θ₋ᵢ))
Ensembles
Ensemble combines multiple estimators, aggregating their individual estimates to improve accuracy.
NeuralEstimators.Ensemble Type
Ensemble <: NeuralEstimator
Ensemble(estimators)
Ensemble(architecture::Function, J::Integer)
(ensemble::Ensemble)(Z; aggr = mean)
Defines an ensemble of estimators which, when applied to data Z, returns the mean (or another summary defined by aggr) of the individual estimates (see, e.g., Sainsbury-Dale et al., 2025, Sec. S5).
The ensemble can be initialised with a collection of trained estimators and then applied immediately to observed data. Alternatively, the ensemble can be initialised with a collection of untrained estimators (or a function defining the architecture of each estimator, and the number of estimators in the ensemble), trained with train(), and then applied to observed data. In the latter case, where the ensemble is trained directly, if savepath is specified, both the ensemble and its component estimators will be saved.
Note that train() currently acts sequentially on the component estimators, using the Adam optimiser.
The ensemble components can be accessed by indexing the ensemble; the number of component estimators can be obtained using length().
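The aggregation step amounts to an elementwise summary across component estimates; below is a plain-Julia sketch with the default aggr = mean (the matrices are illustrative d × K estimates, not output from the package):

```julia
using Statistics: mean

# Stand-in estimates from three ensemble components (d × K matrices)
component_estimates = [
    [1.0 2.0; 3.0 4.0],
    [1.2 1.8; 2.8 4.2],
    [0.8 2.2; 3.2 3.8],
]
mean(component_estimates)  # elementwise mean: [1.0 2.0; 3.0 4.0]
```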
See also Parallel (Flux/Lux), which can be used to mimic ensemble methods with an appropriately chosen connection.
Note
Ensemble is currently only implemented for the Flux backend.
Examples
using NeuralEstimators, Flux
# Data Z|θ ~ N(θ, 1) with θ ~ N(0, 1)
d = 1 # dimension of the parameter vector θ
n = 1 # dimension of each replicate of Z
m = 30 # number of replicates in each data set
sampler(K) = randn32(d, K)
simulator(θ, m) = [μ .+ randn32(n, m) for μ ∈ eachcol(θ)]
# Neural-network architecture of each ensemble component
function architecture()
ψ = Chain(Dense(n, 64, relu), Dense(64, 64, relu))
ϕ = Chain(Dense(64, 64, relu), Dense(64, d))
network = DeepSet(ψ, ϕ)
PointEstimator(network)
end
# Initialise ensemble with three component estimators
ensemble = Ensemble(architecture, 3)
ensemble[1] # access component estimators by indexing
ensemble[1:2] # indexing with an iterable collection returns the corresponding ensemble
length(ensemble) # number of component estimators
# Training
ensemble = train(ensemble, sampler, simulator, m = m, epochs = 5)
# Assessment
θ = sampler(1000)
Z = simulator(θ, m)
assessment = assess(ensemble, θ, Z)
rmse(assessment)
# Apply to data
ensemble(Z)
Helper functions
The following helper functions operate on an estimator to inspect its components or apply parts of it to data. For the main inference functions used post-training, see Inference with observed data.
NeuralEstimators.summarynetwork Function
summarynetwork(estimator::NeuralEstimator)
Returns the summary network of estimator.
See also summarystatistics.
NeuralEstimators.setsummarynetwork Function
setsummarynetwork(estimator::NeuralEstimator, network)
Returns a new estimator identical to estimator but with the summary network replaced by network. Useful for transfer learning.
Note that RatioEstimator has a second summary network for the parameters, accessible via estimator.summary_network_θ, which is not affected by this function.
See also summarynetwork.
NeuralEstimators.summarystatistics Function
summarystatistics(estimator::NeuralEstimator, Z; batchsize = 32, device = nothing, use_gpu = true)
Computes learned summary statistics by applying the summary network of estimator to data Z.
If Z is a DataSet object, the learned summary statistics are concatenated with the precomputed expert summary statistics stored in Z.S.
The device used for computation can be specified via device (e.g., cpu_device(), gpu_device(), or reactant_device(), the latter requiring Lux.jl) or inferred automatically by setting use_gpu = true (default) to use a GPU if one is available. The device argument takes priority over use_gpu if both are provided.
See also summarynetwork.
Lux.jl convenience wrapper
Both Flux.jl and Lux.jl are supported. These frameworks differ in a key way: Flux networks store their trainable parameters and states inside the network object, while Lux networks store them externally as explicit, separate objects.
For convenience, LuxEstimator bundles a Lux-based estimator together with its parameters and states for a unified, backend-agnostic API.
NeuralEstimators.LuxEstimator Type
LuxEstimator(estimator::NeuralEstimator, ps, st)
LuxEstimator(estimator::NeuralEstimator; rng::AbstractRNG = Random.default_rng())
Wraps a NeuralEstimator containing Lux.jl networks together with their parameters ps and states st.
The convenience constructor automatically calls Lux.setup(rng, estimator) to initialise ps and st.
Examples
using NeuralEstimators, Lux
network = Lux.Chain(Lux.Dense(10, 64, gelu), Lux.Dense(64, 2))
estimator = LuxEstimator(PointEstimator(network))
# Training, assessment, and inference proceed identically to the Flux API:
estimator = train(estimator, ...)
estimate(estimator, ...)
assess(estimator, ...)
# Access parameters and states directly if needed
estimator.ps
estimator.st