Gridded data
Here, we develop a neural estimator for a spatial Gaussian process model with exponential covariance function, C(h) = exp(−‖h‖/θ), and unknown range parameter θ > 0.
Package dependencies
using NeuralEstimators
using Flux
using CairoMakie
using Distances
using Folds # parallel simulation (start Julia with --threads=auto)
using LinearAlgebra # Cholesky factorisation

To improve computational efficiency, various GPU backends are supported. Once the relevant package is loaded and a compatible GPU is available, it will be used automatically:
using CUDA
using AMDGPU
using Metal
using oneAPI

Sampling parameters
Simulation from Gaussian processes requires computing the Cholesky factor of a covariance matrix, which is expensive but reusable across repeated simulations from the same parameters. We therefore define a custom type Parameters subtyping AbstractParameterSet to store both the parameters and their corresponding Cholesky factors:
struct Parameters <: AbstractParameterSet
θ
L
end

We define two constructors: one that accepts an integer and samples from the prior (used during training), and one that accepts a parameter matrix directly (useful for parametric bootstrap at inference time):
function sampler(K::Integer)
θ = 0.5 * rand(K) # K samples from π(θ) = Unif(0, 0.5)
Parameters(NamedMatrix(θ = θ)) # Wrap as a named matrix and pass to matrix constructor
end
function Parameters(θ::AbstractMatrix)
# Spatial locations: 16×16 grid over the unit square
pts = range(0, 1, length = 16)
S = expandgrid(pts, pts)
# Pairwise distances, covariance matrices, and Cholesky factors
D = pairwise(Euclidean(), S, dims = 1)
K = size(θ, 2)
L = Folds.map(1:K) do k
Σ = exp.(-D ./ θ[k])
cholesky(Symmetric(Σ)).L
end
Parameters(θ, L)
end

Simulating data
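As background for the simulator below: if L is the lower Cholesky factor of a covariance matrix Σ, then z = L * randn(n) has covariance Σ, so the expensive factorisation can be computed once and reused across replicates. A self-contained sketch using only base Julia (a small hypothetical example, not package code):

```julia
using LinearAlgebra

# Exponential covariance over three points on a line, range parameter θ = 0.5
θ = 0.5
D = [0.0 1.0 2.0; 1.0 0.0 1.0; 2.0 1.0 0.0]  # pairwise distances
Σ = exp.(-D ./ θ)
L = cholesky(Symmetric(Σ)).L                  # factorise once...

Z = [L * randn(3) for _ in 1:1000]            # ...simulate many times
```

Each element of Z is one replicate of the process at the three locations; the factorisation cost is paid only once.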
We store each simulated data set as a four-dimensional array of dimension 16 × 16 × 1 × m (width, height, channels, and independent replicates), as expected by the convolutional layers used below:
function simulator(parameters::Parameters, m = 1)
Folds.map(parameters.L) do L
n = size(L, 1)
z = L * randn(n, m)
reshape(z, 16, 16, 1, m)
end
end

Constructing the neural network
For data collected over a regular grid, the inner network is typically a convolutional neural network (CNN; see, e.g., Dumoulin and Visin, 2016). Note that deeper architectures employing residual connections (see ResidualBlock) often lead to improved performance, and certain pooling layers (e.g., GlobalMeanPool) allow the network to accommodate grids of varying dimension; for further discussion, see Sainsbury-Dale et al. (2025, Sec. S3, S4).
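To see where the 256-unit input of the first Dense layer below comes from, one can trace the feature-map side lengths through the network: a "valid" 3×3 convolution reduces each side by 2, and 2×2 max pooling halves it (with flooring). A small arithmetic sketch:

```julia
# Trace feature-map side lengths through the CNN, for a 16×16 input grid
conv_out(s, k) = s - k + 1   # "valid" convolution with a k×k filter
pool_out(s, k) = fld(s, k)   # k×k max pooling (floor division)

s = 16
s = pool_out(conv_out(s, 3), 2)   # after the first Conv/MaxPool pair: 7
s = pool_out(conv_out(s, 3), 2)   # after the second Conv/MaxPool pair: 2
channels = 64
flattened = s * s * channels      # length of the flattened feature vector
```

Here flattened == 256, matching the input size of the first Dense layer in the outer network.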
d = 1 # dimension of the parameter vector θ
num_summaries = 3d # number of summary statistics for θ
# Inner network (CNN)
ψ = Chain(
Conv((3, 3), 1 => 32, relu), # 3×3 filter, 1 → 32 channels
MaxPool((2, 2)), # 2×2 max pooling
Conv((3, 3), 32 => 64, relu), # 3×3 filter, 32 → 64 channels
MaxPool((2, 2)), # 2×2 max pooling
Flux.flatten # flatten for fully connected layers
)
# Outer network
ϕ = Chain(Dense(256, 64, relu), Dense(64, num_summaries))
# DeepSet object
network = DeepSet(ψ, ϕ)

Constructing the neural estimator
We now construct a NeuralEstimator by wrapping the neural network in the subtype corresponding to the intended inferential method:
estimator = PointEstimator(network, d; num_summaries = num_summaries)     # point estimation
estimator = PosteriorEstimator(network, d; num_summaries = num_summaries) # posterior inference
estimator = RatioEstimator(network, d; num_summaries = num_summaries)     # likelihood-to-evidence ratio estimation

Training the estimator
We train the estimators using fixed parameter instances to avoid repeated Cholesky factorisations (see Storing expensive intermediate objects for data simulation and On-the-fly and just-in-time simulation for further discussion):
K = 5000
θ_train = sampler(K)
θ_val = sampler(K)
estimator = train(estimator, θ_train, θ_val, simulator)

The empirical risk (average loss) over the training and validation sets can be plotted using plotrisk.
One may wish to save a trained estimator and load it in a later session: see Saving and loading neural estimators for details on how this can be done.
Assessing the estimator
The function assess can be used to assess the trained estimator:
θ_test = sampler(1000) # test parameters
Z_test = simulator(θ_test) # test data
assessment = assess(estimator, θ_test, Z_test)

The resulting Assessment object contains ground-truth parameters, estimates, and other quantities that can be used to compute quantitative and qualitative diagnostics:
bias(assessment) # 0.005
rmse(assessment) # 0.032
plot(assessment)
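For intuition, the bias and rmse diagnostics above reduce to simple averages of the estimation errors over the test set. A base-Julia sketch with hypothetical truth and estimate vectors (not the package implementation):

```julia
using Statistics

# Hypothetical ground-truth parameters and corresponding point estimates
truth    = [0.10, 0.20, 0.30, 0.40]
estimate = [0.12, 0.19, 0.33, 0.38]

errors   = estimate .- truth
bias_est = mean(errors)              # average signed error
rmse_est = sqrt(mean(errors .^ 2))   # root-mean-squared error
```

Bias near zero with small RMSE indicates the estimator is, on average, neither over- nor under-estimating the parameter over the prior support.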
Applying the estimator to observed data
Once an estimator is deemed to be well calibrated, it may be applied to observed data (below, we use simulated data as a stand-in for observed data):
θ = Parameters(Matrix([0.1]')) # ground truth (not known in practice)
Z = simulator(θ) # stand-in for real data

estimate(estimator, Z)       # point estimate
sampleposterior(estimator, Z) # posterior sample

Note that missing data (e.g., due to cloud cover) can be accommodated using the missing-data methods implemented in the package.