Gridded data: non-stationary spatial fields
Here, we develop a neural Bayes estimator for a non-stationary spatial Gaussian process.
We assume that the data $\boldsymbol{Z}$ are a realization of a mean-zero Gaussian process with non-stationary covariance function
$$C(\boldsymbol{s}, \boldsymbol{r}) = \frac{\rho(\boldsymbol{s})\,\rho(\boldsymbol{r})}{\{\rho(\boldsymbol{s})^2 + \rho(\boldsymbol{r})^2\}/2}\,\exp\!\big\{-Q(\boldsymbol{s}, \boldsymbol{r})^{3/4}\big\},$$
where
$$Q(\boldsymbol{s}, \boldsymbol{r}) = \frac{2\,\|\boldsymbol{s} - \boldsymbol{r}\|^2}{\rho(\boldsymbol{s})^2 + \rho(\boldsymbol{r})^2}$$
is the squared Mahalanobis distance. The parameter of interest is the spatially-varying range field $\rho(\cdot)$, treated as an image over the spatial grid.
Package dependencies
using NeuralEstimators
using Flux
using Distances
using Distributions: Uniform
using Folds
using GaussianRandomFields
using LinearAlgebra
using MLUtils: unsqueeze
using Plots
using UnicodePlots
unicodeplots() # plotting directly in terminal
To improve computational efficiency, various GPU backends are supported. Once the relevant package is loaded and a compatible GPU is available, it will be used automatically:
using CUDA, cuDNN
using AMDGPU
using Metal
using oneAPI
Sampling parameters
The range field $\rho(\cdot)$ is itself modelled stochastically: it is obtained by squaring (and shifting, for positivity) a Gaussian random field drawn from a hyperprior, as implemented in the sampler below.
# Container pairing each sampled range field with the lower Cholesky factor of
# the data covariance matrix it induces (used for fast Gaussian simulation).
# NOTE(review): confirm that NeuralEstimators exports `AbstractParameterSet`;
# the package documentation typically subtypes `ParameterConfigurations`.
struct Parameters <: AbstractParameterSet
θ # parameter array: grid_dim × grid_dim × d × K
L # K-vector of Cholesky factors, one per parameter field
end
# Sample `K` parameter configurations. Each configuration is a spatially
# varying range field drawn from a Gaussian-process hyperprior, paired with
# the Cholesky factor of the covariance matrix that the field induces.
# Returns a `Parameters` object; `grid_dim` controls the grid resolution.
function sampler(K::Integer; grid_dim = 32)
# Spatial grid: grid_dim^2 points on the unit square
pts = range(0, 1, length = grid_dim)
S = expandgrid(pts, pts)
n = size(S, 1)
# Spatially varying range fields drawn from a GP hyperprior: a GRF with
# exponential covariance (length scale λ ~ U(0.03, 0.75), σ = 0.25) is
# squared and shifted by 0.03 so the range is strictly positive.
θ = Folds.map(1:K) do k
λ = rand(Uniform(0.03, 0.75))
cov = CovarianceFunction(2, Exponential(λ, σ = 0.25))
grf = GaussianRandomField(cov, CirculantEmbedding(), pts, pts, minpadding = 100)
sample(grf)
end
θ = stack(θ) # grid_dim × grid_dim × K
θ = unsqueeze(θ, dims = 3) # insert singleton "parameter" dim: grid_dim × grid_dim × 1 × K
θ = θ.^2 .+ 0.03
# Pairwise squared distances
H = pairwise(Euclidean(), S, dims = 1).^2
# Cholesky factors of the covariance matrix, one per parameter field
L = Folds.map(1:K) do k
ρ = vec(θ[:, :, 1, k]) # range at each grid point (column-major vectorisation)
Σ = Matrix{Float64}(undef, n, n)
for i in 1:n
aₛ = ρ[i]^2
for j in i:n # fill upper triangle; mirrored below
aᵣ = ρ[j]^2
# Squared Mahalanobis distance between sites i and j
Q = 2H[i, j] / (aₛ + aᵣ)
# Log-covariance; note sqrt(Q)^1.5 ≡ Q^0.75, i.e. a powered-exponential
# decay in the Mahalanobis distance
logC = 0.5log(aₛ) + 0.5log(aᵣ) - log((aₛ + aᵣ) / 2) - sqrt(Q)^1.5
Σ[i, j] = Σ[j, i] = exp(logC)
end
end
cholesky(Symmetric(Σ)).L
end
Parameters(θ, L)
end
Simulating data
Data simulation is trivial once we have the Cholesky factors computed above. We store the simulated data set as a four-dimensional array, where the third dimension is the number of channels (singleton for a univariate process) and the fourth stores the independent data sets:
# Simulate one realization of the Gaussian process for each parameter
# configuration, via its precomputed Cholesky factor. The result is a 4-D
# array: grid_dim × grid_dim × 1 (channel) × K (independent data sets).
function simulator(parameters::Parameters)
    fields = Folds.map(parameters.L) do chol
        npts = size(chol, 1)
        width = isqrt(npts) # grid is square, so npts = width^2
        reshape(chol * randn(npts, 1), width, width, 1)
    end
    stack(fields)
end
Constructing the neural network
For image-to-image tasks, a U-Net (Ronneberger et al., 2015) is a natural choice and among the best-performing architectures in the LatticeVision paper of Sikorski et al. (2025).
using ConcreteStructs: @concrete
"""
Unet(channels::Int = 1, output_channels::Int = channels)
Initializes a [UNet](https://arxiv.org/pdf/1505.04597.pdf) instance with the given number of `channels`, typically equal to the number of channels in the input images.
The input grid must have a dimension that is a power of two (e.g., 32x32, 64x64).
Code adapted from https://github.com/DhairyaLGandhi/UNet.jl
# Examples
u = Unet(1)
dev = cpu
u = u |> dev
rand32(256, 256, 1, 1) |> dev |> u
rand32(128, 128, 1, 1) |> dev |> u
rand32(64, 64, 1, 1) |> dev |> u
rand32(32, 32, 1, 1) |> dev |> u
"""
@concrete struct Unet
conv_down_blocks
conv_blocks
up_blocks
end
# Basic convolutional block: padded 3×3 convolution (spatial size preserved),
# batch normalisation, then leaky ReLU with slope 0.2.
function UNetConvBlock(in_chs, out_chs, kernel = (3, 3))
    conv = Conv(kernel, in_chs => out_chs, pad = (1, 1))
    norm = BatchNorm(out_chs)
    act = x -> leakyrelu.(x, 0.2f0)
    Chain(conv, norm, act)
end
# Downsampling block: strided 4×4 convolution that halves the spatial
# resolution, followed by batch normalisation and leaky ReLU (slope 0.2).
function ConvDown(in_chs, out_chs, kernel = (4, 4))
    conv = Conv(kernel, in_chs => out_chs, pad = (1, 1), stride = (2, 2))
    norm = BatchNorm(out_chs)
    act = x -> leakyrelu.(x, 0.2f0)
    Chain(conv, norm, act)
end
# Decoder stage: wraps the upsampling chain; the functor defined further below
# applies it and concatenates the result with the matching encoder features.
@concrete struct UNetUpBlock
upsample # Chain: leaky ReLU → transposed conv → batch norm → dropout
end
# Build the upsampling chain for one decoder stage: activation, 2×2 transposed
# convolution (doubling the spatial resolution), batch norm, and dropout with
# probability `p`.
function UNetUpBlock(in_chs::Int, out_chs::Int; kernel = (2, 2), p = 0.5f0)
    act = x -> leakyrelu.(x, 0.2f0)
    deconv = ConvTranspose(kernel, in_chs => out_chs, stride = kernel)
    UNetUpBlock(Chain(act, deconv, BatchNorm(out_chs), Dropout(p)))
end
# Apply an upsampling stage, then concatenate the result with the skip
# connection `bridge` along the channel dimension (dim 3 in WHCN layout).
function (u::UNetUpBlock)(x, bridge)
    upsampled = u.upsample(x)
    cat(upsampled, bridge, dims = 3)
end
# Assemble the U-Net: four strided downsampling stages, seven convolutional
# feature blocks, and four upsampling stages followed by a 1×1 output
# convolution mapping back to `output_channels`.
function Unet(channels::Int = 1, output_channels::Int = channels)
    # Encoder: each stage halves the spatial resolution
    down = Chain(
        ConvDown(64, 64),
        ConvDown(128, 128),
        ConvDown(256, 256),
        ConvDown(512, 512),
    )
    # Feature blocks applied at full resolution and after each downsampling
    features = Chain(
        UNetConvBlock(channels, 3),
        UNetConvBlock(3, 64),
        UNetConvBlock(64, 128),
        UNetConvBlock(128, 256),
        UNetConvBlock(256, 512),
        UNetConvBlock(512, 1024),
        UNetConvBlock(1024, 1024),
    )
    # Decoder: input channel counts account for the skip-connection concatenation
    up = Chain(
        UNetUpBlock(1024, 512),
        UNetUpBlock(1024, 256),
        UNetUpBlock(512, 128),
        UNetUpBlock(256, 64, p = 0.0f0),
        Chain(x -> leakyrelu.(x, 0.2f0), Conv((1, 1), 128 => output_channels)),
    )
    Unet(down, features, up)
end
# U-Net forward pass: the encoder path stores intermediate feature maps, which
# the decoder then concatenates back in as skip connections.
function (u::Unet)(x::AbstractArray)
op = u.conv_blocks[1:2](x) # full-resolution features
x1 = u.conv_blocks[3](u.conv_down_blocks[1](op)) # 1/2 resolution
x2 = u.conv_blocks[4](u.conv_down_blocks[2](x1)) # 1/4 resolution
x3 = u.conv_blocks[5](u.conv_down_blocks[3](x2)) # 1/8 resolution
x4 = u.conv_blocks[6](u.conv_down_blocks[4](x3)) # 1/16 resolution (bottleneck)
up_x4 = u.conv_blocks[7](x4)
up_x1 = u.up_blocks[1](up_x4, x3) # upsample + skip connection
up_x2 = u.up_blocks[2](up_x1, x2)
up_x3 = u.up_blocks[3](up_x2, x1)
up_x5 = u.up_blocks[4](up_x3, op)
# NOTE(review): the final tanh bounds the output to (-1, 1); confirm the target
# range field (≈ 0.03 upwards, per `sampler`) always lies within this interval.
tanh.(u.up_blocks[end](up_x5))
end
# Human-readable summary of the network: one line per encoder, feature, and
# decoder block, reporting input and output channel counts (taken from the
# convolution weight dimensions).
function Base.show(io::IO, u::Unet)
    println(io, "UNet:")
    for blk in u.conv_down_blocks
        w = size(blk[1].weight)
        println(io, " ConvDown($(w[end-1]), $(w[end]))")
    end
    print(io, "\n")
    for blk in u.conv_blocks
        w = size(blk[1].weight)
        println(io, " UNetConvBlock($(w[end-1]), $(w[end]))")
    end
    print(io, "\n")
    for blk in u.up_blocks
        blk isa UNetUpBlock || continue # skip the final output Chain
        w = size(blk.upsample[2].weight)
        println(io, " UNetUpBlock($(w[end]), $(w[end-1]))")
    end
end
Constructing the neural estimator
network = Unet(1)
estimator = PointEstimator(network)
Training the estimator
K = 2500 # number of parameter configurations for training/validation
θ_train = sampler(K)
θ_val = sampler(K)
estimator = train(estimator, θ_train, θ_val, simulator)
Assessing the estimator
NB: when calling estimate with this architecture, batchsize must evenly divide the number of data sets (i.e., no partial batches).
K_test = 100 # number of test parameter configurations
θ_test = sampler(K_test)
Z_test = simulator(θ_test) # one simulated data set per test configuration
estimates = estimate(estimator, Z_test; batchsize = K_test)
It is informative to visualise the estimated and true range surfaces side-by-side for a few test instances:
# Choose a test instance
k = 1
θ_true = θ_test.θ[:, :, 1, k] # true range field for instance k
θ_hat = estimates[:, :, 1, k] # corresponding neural estimate
# Plotting
grid_dim = size(Z_test, 1)
x = y = range(0, 1, length = grid_dim)
p1 = contour(x, y, θ_true); # true field
p2 = contour(x, y, θ_hat); # estimated field
plot(p1, p2)
Applying the estimator to observed data
# Apply the trained estimator to (stand-in) observed data and plot the result
Z_obs = simulator(sampler(1)) # stand-in for observed data
θ_hat = estimate(estimator, Z_obs) # range field estimate
θ_hat = dropdims(θ_hat, dims = (3, 4)) # convert to matrix
contour(x, y, θ_hat)