Advanced usage

Saving and loading neural estimators

With regard to saving and loading, neural estimators behave in the same manner as regular Flux models. Therefore, the examples and recommendations outlined in the Flux documentation also apply directly to neural estimators. For example, to save the model state of the neural estimator θ̂:

using Flux
using BSON: @save, @load
model_state = Flux.state(θ̂)
@save "estimator.bson" model_state

Then, to load it in a new session, one may initialise a neural estimator with the same architecture used previously, and load the saved model state:

@load "estimator.bson" model_state
Flux.loadmodel!(θ̂, model_state)

It is also straightforward to save the entire neural estimator, including its architecture (see the Flux documentation on saving and loading models). However, the first approach outlined above is recommended for long-term storage.

For convenience, the function train() allows for the automatic saving of the model state during the training stage, via the argument savepath.

Storing expensive intermediate objects for data simulation

Parameters sampled from the prior distribution may be stored in two ways. Most simply, they can be stored as a $p \times K$ matrix, where $p$ is the number of parameters in the model and $K$ is the number of parameter vectors sampled from the prior distribution. Alternatively, they can be stored in a user-defined struct subtyping ParameterConfigurations, whose only requirement is a field θ that stores the $p \times K$ matrix of parameters. With this approach, one may store computationally expensive intermediate objects, such as Cholesky factors, for later use when conducting "on-the-fly" simulation, which is discussed below.
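
As a rough illustration of the second approach, the sketch below stores a $p \times K$ parameter matrix together with precomputed Cholesky factors. To keep it self-contained it uses only the standard library: the abstract type AbstractParameterConfigurations is a hypothetical stand-in for ParameterConfigurations, and the exponential covariance function, the prior, and the names ToyParameters and toyparameters are assumptions made for this example only.

```julia
using LinearAlgebra
using Random

# Hypothetical stand-in for NeuralEstimators' ParameterConfigurations,
# so that this sketch runs with the standard library only
abstract type AbstractParameterConfigurations end

struct ToyParameters <: AbstractParameterConfigurations
	θ # p × K matrix of parameters (the only required field)
	L # n × n × K array of lower Cholesky factors
end

# Sample K range parameters and precompute one Cholesky factor for each
function toyparameters(K::Integer, D::AbstractMatrix)
	ρ = 0.05 .+ 0.35 .* rand(K) # assumed prior: ρ ~ Uniform(0.05, 0.4)
	n = size(D, 1)
	L = Array{Float64}(undef, n, n, K)
	for k in 1:K
		Σ = exp.(-D ./ ρ[k]) # exponential covariance (an assumption)
		L[:, :, k] = cholesky(Symmetric(Σ)).L
	end
	θ = permutedims(ρ) # 1 × K matrix, i.e., p = 1
	ToyParameters(θ, L)
end

Random.seed!(1)
s = range(0, 1, length = 10)        # locations on the unit interval
D = [abs(a - b) for a in s, b in s] # pairwise distance matrix
parameters = toyparameters(5, D)
```

A simulator can then reuse parameters.L rather than refactorising a covariance matrix each time data are simulated.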

On-the-fly and just-in-time simulation

When data simulation is (relatively) computationally inexpensive, the training data set, $\mathcal{Z}_{\text{train}}$, can be simulated continuously during training, a technique coined "simulation-on-the-fly". Regularly refreshing $\mathcal{Z}_{\text{train}}$ leads to lower out-of-sample error and to a reduction in overfitting. This strategy therefore facilitates the use of larger, more representationally-powerful networks that are prone to overfitting when $\mathcal{Z}_{\text{train}}$ is fixed. Further, this technique allows data to be simulated "just-in-time", in the sense that they can be simulated in small batches, used to train the neural estimator, and then removed from memory. This can substantially reduce pressure on memory resources, particularly when working with large data sets.

One may also regularly refresh the set $\vartheta_{\text{train}}$ of parameter vectors used during training, and doing so leads to similar benefits. However, fixing $\vartheta_{\text{train}}$ allows computationally expensive terms, such as Cholesky factors when working with Gaussian process models, to be reused throughout training, which can substantially reduce the training time for some models. Hybrid approaches are also possible, whereby the parameters (and possibly the data) are held fixed for several epochs (i.e., several passes through the training set when performing stochastic gradient descent) before being refreshed.

The above strategies are facilitated with various methods of train().
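
The refresh schedules described above can be sketched without any package. The toy loop below assumes the model $\theta \sim N(0, 1)$ and $Z_i \mid \theta \sim N(\theta, 1)$ with $m$ replicates, and a one-parameter estimator $\hat{\theta}(\boldsymbol{Z}) = w\bar{Z}$ fitted by gradient descent under squared-error loss. All names (trainonthefly, θ_refresh_every, and so on) are illustrative, and the loop is not a description of how train() is implemented.

```julia
using Random
using Statistics: mean

Random.seed!(1)
m = 5                # replicates per data set
K = 1000             # parameter vectors per training set
epochs = 200
θ_refresh_every = 10 # hypothetical schedule: new parameters every 10 epochs
η = 0.1              # learning rate

sampleprior(K) = randn(K) # θ ~ N(0, 1)
simulatemeans(θ, m) = [mean(μ .+ randn(m)) for μ in θ] # sample mean of each data set

function trainonthefly(m, K, epochs, θ_refresh_every, η)
	w = 0.0 # single weight of the estimator θ̂(Z) = w * mean(Z)
	θ = sampleprior(K)
	for epoch in 1:epochs
		# Refresh the parameters only occasionally (hybrid approach) ...
		if epoch > 1 && (epoch - 1) % θ_refresh_every == 0
			θ = sampleprior(K)
		end
		# ... but simulate fresh data at every epoch ("on-the-fly"); the
		# previous batch is no longer referenced and can be garbage-collected
		Zbar = simulatemeans(θ, m)
		# Gradient of the empirical squared-error risk with respect to w
		g = 2 * mean((w .* Zbar .- θ) .* Zbar)
		w -= η * g
	end
	return w
end

w = trainonthefly(m, K, epochs, θ_refresh_every, η)
```

Under these assumptions the Bayes estimator under squared-error loss is the posterior mean, $m\bar{Z}/(m+1)$, so the fitted weight should settle near $5/6$.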

Regularisation

The term regularisation refers to a variety of techniques aimed at reducing overfitting when training a neural network, primarily by discouraging complex models.

One common regularisation technique is known as dropout (Srivastava et al., 2014), implemented in Flux's Dropout layer. Dropout involves temporarily dropping ("turning off") a randomly selected set of neurons (along with their connections) at each iteration of the training stage, and this results in a computationally-efficient form of model (neural-network) averaging.
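
To make the mechanism concrete, the following standalone function sketches the common "inverted" form of dropout, in which the surviving activations are rescaled so that their expected value is unchanged. It is an illustration only and not Flux's implementation; in practice one would simply use the Dropout layer.

```julia
using Random

# Inverted dropout: zero each element with probability p and rescale the
# survivors by 1 / (1 - p) so that the expected activation is unchanged
function dropout(x::AbstractArray, p::Real; rng = Random.default_rng())
	0 <= p < 1 || throw(ArgumentError("p must lie in [0, 1)"))
	keep = rand(rng, size(x)...) .>= p # true with probability 1 - p
	return x .* keep ./ (1 - p)
end

rng = MersenneTwister(1)
x = ones(10_000)
y = dropout(x, 0.5; rng = rng) # each element is either 0 or 2
```

A new mask is drawn at every call, which mirrors the way a different random subset of neurons is dropped at each training iteration.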

Another class of regularisation techniques involves modifying the loss function. For instance, L₁ regularisation (sometimes called lasso regression) adds to the loss a penalty based on the absolute value of the neural-network parameters. Similarly, L₂ regularisation (sometimes called ridge regression) adds to the loss a penalty based on the square of the neural-network parameters. Note that these penalty terms are not functions of the data or of the statistical-model parameters that we are trying to infer, and therefore do not modify the Bayes risk or the associated Bayes estimator. These regularisation techniques can be implemented straightforwardly by providing a custom optimiser to train() that includes a SignDecay object for L₁ regularisation, or a WeightDecay object for L₂ regularisation. See the Flux documentation for further details.

For example, the following code constructs a neural Bayes estimator using dropout and L₁ regularisation with penalty coefficient $\lambda = 10^{-4}$:

using NeuralEstimators
using Flux

# Generate data from the model Z ~ N(θ, 1) and θ ~ N(0, 1)
p = 1       # number of unknown parameters in the statistical model
m = 5       # number of independent replicates
d = 1       # dimension of each independent replicate
K = 3000    # number of training samples
θ_train = randn(1, K)
θ_val   = randn(1, K)
Z_train = [μ .+ randn(1, m) for μ ∈ eachcol(θ_train)]
Z_val   = [μ .+ randn(1, m) for μ ∈ eachcol(θ_val)]

# Architecture with dropout layers
ψ = Chain(
	Dense(1, 32, relu),
	Dropout(0.1),
	Dense(32, 32, relu),
	Dropout(0.5)
	)     
ϕ = Chain(
	Dense(32, 32, relu),
	Dropout(0.5),
	Dense(32, 1)
	)           
θ̂ = DeepSet(ψ, ϕ)

# Optimiser with L₁ regularisation (SignDecay) and penalty coefficient λ = 1e-4
optimiser = Flux.setup(OptimiserChain(SignDecay(1e-4), Adam()), θ̂)

# Train the estimator
train(θ̂, θ_train, θ_val, Z_train, Z_val; optimiser = optimiser)

Note that when the training data and/or parameters are held fixed during training, L₂ regularisation with penalty coefficient $\lambda = 10^{-4}$ is applied by default.
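
For reference, the L₁ and L₂ penalties discussed in this section can be written out directly. The sketch below uses hypothetical helper names and one particular scaling convention; libraries differ on constant factors (for example, whether the L₂ penalty is $\lambda\|\boldsymbol{\gamma}\|^2$ or $\tfrac{1}{2}\lambda\|\boldsymbol{\gamma}\|^2$), so it should not be read as a description of what SignDecay or WeightDecay compute internally.

```julia
# Penalty terms added to the loss, as functions of the network parameters γ
l1penalty(γ, λ) = λ * sum(abs, γ)  # L₁ (lasso-type) penalty
l2penalty(γ, λ) = λ * sum(abs2, γ) # L₂ (ridge-type) penalty

# Corresponding contributions to the gradient with respect to γ
l1gradient(γ, λ) = λ .* sign.(γ)
l2gradient(γ, λ) = 2 .* λ .* γ

γ = [0.5, -2.0, 0.0] # toy parameter vector
λ = 1e-4             # penalty coefficient
```

Neither penalty involves the data or the statistical-model parameters, which is why the Bayes risk is unaffected. The L₁ gradient has constant magnitude away from zero, whereas the L₂ gradient shrinks in proportion to the parameter itself.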

Expert summary statistics

Implicitly, neural estimators involve the learning of summary statistics. However, some summary statistics are available in closed form, simple to compute, and highly informative (e.g., sample quantiles, the empirical variogram, etc.). Often, explicitly incorporating these expert summary statistics in a neural estimator can simplify the optimisation problem, and lead to a better estimator.

The fusion of learned and expert summary statistics is facilitated by our implementation of the DeepSet framework. Note that this implementation also allows the user to construct a neural estimator using only expert summary statistics, following, for example, Gerber and Nychka (2021) and Rai et al. (2024). Note also that the user may specify arbitrary expert summary statistics; however, for convenience, several standard summary statistics are provided with the package (see User-defined summary statistics), including a fast approximate version of the empirical variogram.
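
As an illustration of what a user-specified expert summary statistic might look like, the function below maps a single data set, assumed here to be stored as a $d \times m$ matrix, to a vector containing the sample mean, the sample variance, and three sample quantiles. The name expertstatistics and the choice of statistics are assumptions for this sketch; see the DeepSet documentation for how such a function is supplied to the estimator.

```julia
using Statistics: mean, var, quantile

# Expert summary statistics for a single data set stored as a d × m matrix
# (d = dimension of each replicate, m = number of replicates)
function expertstatistics(Z::AbstractMatrix)
	z = vec(Z)
	return vcat(mean(z), var(z), quantile(z, [0.25, 0.5, 0.75]))
end

Z = [1.0 2.0 3.0 4.0 5.0] # d = 1, m = 5
S = expertstatistics(Z)   # mean, variance, and the three quartiles
```

Because the output has a fixed length regardless of $m$, statistics of this kind can be concatenated with learned summary statistics before being passed to the outer network.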

Variable sample sizes

A neural estimator in the Deep Set representation can be applied to data sets of arbitrary size. However, even when the neural Bayes estimator approximates the true Bayes estimator arbitrarily well, it is conditional on the number of replicates, $m$, and is not necessarily a Bayes estimator for $m^* \ne m$. Denote a data set comprising $m$ replicates as $\boldsymbol{Z}^{(m)} \equiv (\boldsymbol{Z}_1', \dots, \boldsymbol{Z}_m')'$. There are at least two (non-mutually exclusive) approaches one could adopt if data sets with varying $m$ are envisaged, which we describe below.

Piecewise estimators

If data sets with varying $m$ are envisaged, one could train $l$ neural Bayes estimators for different sample sizes, or groups thereof (e.g., a small-sample estimator and a large-sample estimator). Specifically, for sample-size changepoints $m_1$, $m_2$, $\dots$, $m_{l-1}$, one could construct a piecewise neural Bayes estimator,

\[\hat{\boldsymbol{\theta}}(\boldsymbol{Z}^{(m)}; \boldsymbol{\gamma}^*) = \begin{cases} \hat{\boldsymbol{\theta}}(\boldsymbol{Z}^{(m)}; \boldsymbol{\gamma}^*_{\tilde{m}_1}) & m \leq m_1,\\ \hat{\boldsymbol{\theta}}(\boldsymbol{Z}^{(m)}; \boldsymbol{\gamma}^*_{\tilde{m}_2}) & m_1 < m \leq m_2,\\ \quad \vdots \\ \hat{\boldsymbol{\theta}}(\boldsymbol{Z}^{(m)}; \boldsymbol{\gamma}^*_{\tilde{m}_l}) & m > m_{l-1}, \end{cases}\]

where $\boldsymbol{\gamma}^* \equiv (\boldsymbol{\gamma}^*_{\tilde{m}_1}, \dots, \boldsymbol{\gamma}^*_{\tilde{m}_l})$, and where $\boldsymbol{\gamma}^*_{\tilde{m}}$ are the neural-network parameters optimised for sample size $\tilde{m}$, chosen so that $\hat{\boldsymbol{\theta}}(\cdot; \boldsymbol{\gamma}^*_{\tilde{m}})$ is near-optimal over the range of sample sizes in which it is applied. This approach works well in practice, and it is less computationally burdensome than it first appears when used in conjunction with pre-training.

Piecewise neural estimators are implemented with the struct PiecewiseEstimator, and their construction is facilitated with trainx().
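
The selection rule in the piecewise definition above amounts to a search over the sorted changepoints. The following standalone sketch shows one way to express it; ToyPiecewise and estimatorindex are hypothetical names introduced for illustration, and the code is not the implementation of PiecewiseEstimator.

```julia
# Index of the estimator to use for a data set with m replicates, given the
# sorted changepoints m₁ < m₂ < … < m_{l-1}
estimatorindex(m::Integer, changepoints::AbstractVector{<:Integer}) =
	searchsortedfirst(changepoints, m)

# Hypothetical wrapper: l estimators (any callables) and l - 1 changepoints
struct ToyPiecewise
	estimators::Vector
	changepoints::Vector{Int}
end

function (e::ToyPiecewise)(Z::AbstractMatrix)
	m = size(Z, 2) # replicates are assumed to be stored as columns
	i = estimatorindex(m, e.changepoints)
	return e.estimators[i](Z)
end

# Toy "estimators" that simply report which piece was selected
pieces = [Z -> "small", Z -> "medium", Z -> "large"]
piecewise = ToyPiecewise(pieces, [30, 100])
```

With changepoints 30 and 100, data sets with $m \leq 30$ are routed to the first estimator, those with $30 < m \leq 100$ to the second, and those with $m > 100$ to the third.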

Training with variable sample sizes

Alternatively, one could treat the sample size as a random variable, $M$, with support over a set of positive integers, $\mathcal{M}$, in which case, for the neural Bayes estimator, the risk function becomes

\[\sum_{m \in \mathcal{M}} P(M=m)\left( \int_\Theta \int_{\mathcal{Z}^m} L(\boldsymbol{\theta}, \hat{\boldsymbol{\theta}}(\boldsymbol{z}^{(m)}))f(\boldsymbol{z}^{(m)} \mid \boldsymbol{\theta}) \mathrm{d} \boldsymbol{z}^{(m)} \mathrm{d} \Pi(\boldsymbol{\theta}) \right).\]

This approach does not materially alter the workflow, except that one must also sample the number of replicates before simulating the data during the training phase.

The following pseudocode illustrates how one may modify a general data simulator to train under a range of sample sizes, with the distribution of $M$ defined by passing any object that can be sampled using rand(m, K) (e.g., an integer range like 1:30, an integer-valued distribution from Distributions.jl, etc.):

function simulate(parameters, m)

	## Number of parameter vectors stored in parameters
	K = size(parameters, 2)

	## Generate K sample sizes from the prior distribution for M
	m̃ = rand(m, K)

	## Pseudocode for data simulation
	Z = [<simulate m̃[k] realisations from the model> for k ∈ 1:K]

	return Z
end

## Method that allows an integer to be passed for m
simulate(parameters, m::Integer) = simulate(parameters, range(m, m))
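
For a concrete, runnable instance of the pseudocode above, the sketch below fills in the simulation step for the toy model $Z \mid \theta \sim N(\theta, 1)$, assuming the parameters are stored as a $1 \times K$ matrix. The function name simulatetoy and the model are assumptions made for illustration.

```julia
using Random

# Concrete version of the simulator for the toy model Z | θ ~ N(θ, 1), where
# parameters is a 1 × K matrix and m is any object accepted by rand(m, K)
function simulatetoy(parameters::AbstractMatrix, m)
	K = size(parameters, 2)
	sizes = rand(m, K) # one sample size per parameter vector
	return [parameters[1, k] .+ randn(1, sizes[k]) for k in 1:K]
end

# Method that allows a fixed integer sample size
simulatetoy(parameters::AbstractMatrix, m::Integer) = simulatetoy(parameters, m:m)

Random.seed!(1)
θ = randn(1, 100)
Z = simulatetoy(θ, 1:30)   # sample sizes drawn uniformly from 1, …, 30
Zfixed = simulatetoy(θ, 5) # every data set has exactly 5 replicates
```

Each element of Z is a $1 \times m_k$ matrix, so the data sets in a single training set can have different numbers of replicates.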

Missing data

Neural networks do not naturally handle missing data, and this property can preclude their use in a broad range of applications. Here, we describe two techniques that alleviate this challenge in the context of parameter point estimation: the masking approach and the neural EM algorithm.

As a running example, we consider a Gaussian process model where the data are collected over a regular grid, but where some elements of the grid are unobserved. This situation often arises in, for example, remote-sensing applications, where the presence of cloud cover prevents measurement in some places. Below, we load the packages needed in this example, and define some aspects of the model that will remain constant throughout (e.g., the prior, the spatial domain, etc.). We also define structs and functions for sampling from the prior distribution and for simulating marginally from the data model.

using Distances
using Distributions
using Flux
using LinearAlgebra
using NeuralEstimators
using Statistics: mean

# Set the prior and define the number of parameters in the statistical model
Π = (
	τ = Uniform(0, 1.0),
	ρ = Uniform(0, 0.4)
)
p = length(Π)

# Define the (gridded) spatial domain and compute the distance matrix
points = range(0, 1, 16)
S = expandgrid(points, points)
D = pairwise(Euclidean(), S, dims = 1)

# Store model information for later use
ξ = (
	Π = Π,
	S = S,
	D = D
)

# Struct for storing parameters+Cholesky factors
struct Parameters <: ParameterConfigurations
	θ
	L
end

# Constructor for above struct
function Parameters(K::Integer, ξ)

	# Sample parameters from the prior
	Π = ξ.Π
	τ = rand(Π.τ, K)
	ρ = rand(Π.ρ, K)
	ν = 1 # fixed smoothness

	# Compute Cholesky factors  
	L = maternchols(ξ.D, ρ, ν)

	# Concatenate into matrix
	θ = permutedims(hcat(τ, ρ))

	Parameters(θ, L)
end

# Marginal simulation from the data model
function simulate(parameters::Parameters, m::Integer)

	K = size(parameters, 2)
	τ = parameters.θ[1, :]
	L = parameters.L
	n = isqrt(size(L, 1))

	Z = map(1:K) do k
		z = simulategaussian(L[:, :, k], m)
		z = z + τ[k] * randn(size(z)...)
		z = Float32.(z)
		z = reshape(z, n, n, 1, :)
		z
	end

	return Z
end

The masking approach

The first missing-data technique that we consider is the so-called masking approach of Wang et al. (2024). The strategy involves completing the data by replacing missing values with zeros, and using auxiliary variables to encode the missingness pattern, which are also passed into the network.

Let $\boldsymbol{Z}$ denote the complete-data vector. Then, the masking approach considers inference based on $\boldsymbol{W}$, a vector of indicator variables that encode the missingness pattern (with elements equal to one or zero if the corresponding element of $\boldsymbol{Z}$ is observed or missing, respectively), and

\[\boldsymbol{U} \equiv \boldsymbol{Z} \odot \boldsymbol{W},\]

where $\odot$ denotes elementwise multiplication and the product of a missing element and zero is defined to be zero. Irrespective of the missingness pattern, $\boldsymbol{U}$ and $\boldsymbol{W}$ have the same fixed dimensions and hence may be processed easily using a single neural network. A neural point estimator is then trained on realisations of $\{\boldsymbol{U}, \boldsymbol{W}\}$ which, by construction, do not contain any missing elements.
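
The construction of $\boldsymbol{U}$ and $\boldsymbol{W}$ can be sketched in a few lines of plain Julia. The function encodetoy below is a hypothetical, standalone illustration of the encoding just described and is not the package's implementation; stacking the two arrays along a third "channel" dimension is likewise an assumption of this sketch.

```julia
# Encode an incomplete data set as (U, W): U replaces missing values with zero,
# and W indicates which elements were observed
function encodetoy(Z::AbstractArray{Union{Missing, T}}) where {T <: Number}
	W = T.(.!ismissing.(Z))   # 1 if observed, 0 if missing
	U = coalesce.(Z, zero(T)) # missing values replaced by zero
	return U, W
end

Z = [1.5 missing; missing 4.0]
U, W = encodetoy(Z)
UW = cat(U, W; dims = 3) # stack along a new "channel" dimension
```

Neither U nor W contains missing values, and both have the same dimensions as Z whatever the missingness pattern.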

Since the missingness pattern $\boldsymbol{W}$ is now an input to the neural network, it must be incorporated during the training phase. When interest lies only in making inference from a single already-observed data set, $\boldsymbol{W}$ is fixed and known, and the Bayes risk remains unchanged. However, amortised inference, whereby one trains a single neural network that will be used to make inference with many data sets, requires a joint model for the data $\boldsymbol{Z}$ and the missingness pattern $\boldsymbol{W}$:

# Marginal simulation from the data model and a MCAR missingness model
function simulatemissing(parameters::Parameters, m::Integer)

	Z = simulate(parameters, m)   # simulate completely-observed data

	UW = map(Z) do z
		prop = rand()             # sample a missingness proportion
		z = removedata(z, prop)   # randomly remove a proportion of the data
		uw = encodedata(z)        # replace missing entries with zero and encode missingness pattern
		uw
	end

	return UW
end

Note that the helper functions removedata() and encodedata() facilitate the construction of augmented data sets $\{\boldsymbol{U}, \boldsymbol{W}\}$.

Next, we construct and train a masked neural Bayes estimator. Here, the first convolutional layer takes two input channels, since we store the augmented data $\boldsymbol{U}$ in the first channel and the missingness pattern $\boldsymbol{W}$ in the second. We construct a point estimator, but the masking approach is applicable with any other kind of estimator (see Estimators):

# Construct DeepSet object
ψ = Chain(
	Conv((10, 10), 2 => 16,  relu),
	Conv((5, 5),  16 => 32,  relu),
	Conv((3, 3),  32 => 64, relu),
	Flux.flatten
	)
ϕ = Chain(Dense(64, 256, relu), Dense(256, p, exp))
deepset = DeepSet(ψ, ϕ)

# Initialise point estimator
θ̂ = PointEstimator(deepset)

# Train the masked neural Bayes estimator
θ̂ = train(θ̂, Parameters, simulatemissing, m = 1, ξ = ξ, K = 1000, epochs = 10)

Once trained, we can apply our masked neural Bayes estimator to (incomplete) observed data. The data must be encoded in the same manner as during training. Below, we use simulated data as a surrogate for real data, with a missingness proportion of 0.25:

θ = Parameters(1, ξ)
Z = simulate(θ, 1)[1]
Z = removedata(Z, 0.25)
UW = encodedata(Z)
θ̂(UW)

The neural EM algorithm

Let $\boldsymbol{Z}_1$ and $\boldsymbol{Z}_2$ denote the observed and unobserved (i.e., missing) data, respectively, and let $\boldsymbol{Z} \equiv (\boldsymbol{Z}_1', \boldsymbol{Z}_2')'$ denote the complete data. A classical approach to facilitating inference when data are missing is the expectation-maximisation (EM) algorithm. The neural EM algorithm is an approximate version of the conventional (Bayesian) Monte Carlo EM algorithm which, at the $l$th iteration, updates the parameter vector through

\[\boldsymbol{\theta}^{(l)} = \argmax_{\boldsymbol{\theta}} \sum_{h = 1}^H \ell(\boldsymbol{\theta}; \boldsymbol{Z}_1, \boldsymbol{Z}_2^{(lh)}) + \log \pi_H(\boldsymbol{\theta}),\]

where realisations of the missing-data component, $\{\boldsymbol{Z}_2^{(lh)} : h = 1, \dots, H\}$, are sampled from the probability distribution of $\boldsymbol{Z}_2$ given $\boldsymbol{Z}_1$ and $\boldsymbol{\theta}^{(l-1)}$, and where $\pi_H(\boldsymbol{\theta}) \propto \{\pi(\boldsymbol{\theta})\}^H$ is a concentrated version of the original prior density. Given the conditionally simulated data, the neural EM algorithm performs the above EM update using a neural network that returns the MAP estimate (i.e., the posterior mode) from the conditionally simulated data. Such a neural network can be obtained by training a neural Bayes estimator under a continuous relaxation of the 0–1 loss function.

First, we construct a neural approximation of the MAP estimator. In this example, we will take $H=50$. When $H$ is taken to be reasonably large, one may lean on the Bernstein-von Mises theorem to train the neural Bayes estimator under linear or quadratic loss; otherwise, one should train the estimator under a continuous relaxation of the 0–1 loss (e.g., the tanhloss or kpowerloss in the limit $\kappa \to 0$):

# Construct DeepSet object
ψ = Chain(
	Conv((10, 10), 1 => 16,  relu),
	Conv((5, 5),  16 => 32,  relu),
	Conv((3, 3),  32 => 64, relu),
	Flux.flatten
	)
ϕ = Chain(
	Dense(64, 256, relu),
	Dense(256, p, exp)
	)
deepset = DeepSet(ψ, ϕ)

# Initialise point estimator
θ̂ = PointEstimator(deepset)

# Train neural Bayes estimator
H = 50
θ̂ = train(θ̂, Parameters, simulate, m = H, ξ = ξ, K = 1000, epochs = 10)

Next, we define a function for conditional simulation (see EM for details on the required format of this function):

function simulateconditional(Z::M, θ, ξ; nsims::Integer = 1) where {M <: AbstractMatrix{Union{Missing, T}}} where T

	# Save the original dimensions
	dims = size(Z)

	# Convert to vector
	Z = vec(Z)

	# Compute the indices of the observed and missing data
	I₁ = findall(z -> !ismissing(z), Z) # indices of observed data
	I₂ = findall(z -> ismissing(z), Z)  # indices of missing data
	n₁ = length(I₁)
	n₂ = length(I₂)

	# Extract the observed data and drop Missing from the eltype of the container
	Z₁ = Z[I₁]
	Z₁ = [Z₁...]

	# Distance matrices needed for covariance matrices
	D   = ξ.D # distance matrix for all locations in the grid
	D₂₂ = D[I₂, I₂]
	D₁₁ = D[I₁, I₁]
	D₁₂ = D[I₁, I₂]

	# Extract the parameters from θ
	τ = θ[1]
	ρ = θ[2]

	# Compute covariance matrices
	ν = 1 # fixed smoothness
	Σ₂₂ = matern.(UpperTriangular(D₂₂), ρ, ν); Σ₂₂[diagind(Σ₂₂)] .+= τ^2
	Σ₁₁ = matern.(UpperTriangular(D₁₁), ρ, ν); Σ₁₁[diagind(Σ₁₁)] .+= τ^2
	Σ₁₂ = matern.(D₁₂, ρ, ν)

	# Compute the Cholesky factor of Σ₁₁ and solve the lower triangular system
	L₁₁ = cholesky(Symmetric(Σ₁₁)).L
	x = L₁₁ \ Σ₁₂

	# Conditional covariance matrix, cov(Z₂ ∣ Z₁, θ),  and its Cholesky factor
	Σ = Σ₂₂ - x'x
	L = cholesky(Symmetric(Σ)).L

	# Conditional mean, E(Z₂ ∣ Z₁, θ)
	y = L₁₁ \ Z₁
	μ = x'y

	# Simulate from the distribution Z₂ ∣ Z₁, θ ∼ N(μ, Σ)
	z = randn(n₂, nsims)
	Z₂ = μ .+ L * z

	# Combine the observed and missing data to form the complete data
	Z = map(1:nsims) do l
		z = Vector{T}(undef, n₁ + n₂)
		z[I₁] = Z₁
		z[I₂] = Z₂[:, l]
		z
	end
	Z = stackarrays(Z, merge = false)

	# Convert Z to an array with appropriate dimensions
	Z = reshape(Z, dims..., 1, nsims)

	return Z
end

Now we can use the neural EM algorithm to get parameter point estimates from data containing missing values. The algorithm is implemented with the struct EM. Again, here we use simulated data as a surrogate for real data:

θ = Parameters(1, ξ)
Z = simulate(θ, 1)[1][:, :]     # simulate a single gridded field
Z = removedata(Z, 0.25)         # remove 25% of the data
θ₀ = mean.([Π...])              # initial estimate, the prior mean

neuralem = EM(simulateconditional, θ̂)
neuralem(Z, θ₀, ξ = ξ, nsims = H, use_ξ_in_simulateconditional = true)

Censored data

Coming soon, based on the methodology presented in Richards et al. (2023+).