Gridded data: non-stationary spatial fields

Here, we develop a neural Bayes estimator for a non-stationary spatial Gaussian process.

We assume that the data $\boldsymbol{Z} \equiv (Z_1, \dots, Z_n)'$ are observed at locations $\{\boldsymbol{s}_1, \dots, \boldsymbol{s}_n\}$ on a regular spatial grid. The data are modelled as correlated mean-zero Gaussian random variables with the Paciorek–Schervish covariance function (Paciorek and Schervish, 2006),

$$\textrm{cov}(Z_i, Z_j) = |\boldsymbol{\Sigma}(\boldsymbol{s}_i)|^{1/4}\,|\boldsymbol{\Sigma}(\boldsymbol{s}_j)|^{1/4}\,\bigg|\frac{\boldsymbol{\Sigma}(\boldsymbol{s}_i) + \boldsymbol{\Sigma}(\boldsymbol{s}_j)}{2}\bigg|^{-1/2} \exp\{-Q(\boldsymbol{s}_i, \boldsymbol{s}_j)^{3/4}\},$$

where $\boldsymbol{\Sigma}(\boldsymbol{s}) \equiv \rho(\boldsymbol{s})^2 \boldsymbol{I}_2$ is a spatially varying kernel matrix, and

$$Q(\boldsymbol{s}_i, \boldsymbol{s}_j) \equiv (\boldsymbol{s}_i - \boldsymbol{s}_j)' \bigg(\frac{\boldsymbol{\Sigma}(\boldsymbol{s}_i) + \boldsymbol{\Sigma}(\boldsymbol{s}_j)}{2}\bigg)^{-1} (\boldsymbol{s}_i - \boldsymbol{s}_j)$$

is the squared Mahalanobis distance. The parameter of interest is the spatially varying range field $\rho(\cdot)$, modelled as a smooth spatial process. Because both the input (data field) and output (range field) are on the same grid (i.e., each pixel has its own range parameter), we use an image-to-image (I2I) network following the approach of Sikorski et al. (2025).
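To make the covariance above concrete, the following standalone sketch evaluates it for a pair of locations under the isotropic kernel $\boldsymbol{\Sigma}(\boldsymbol{s}) = \rho(\boldsymbol{s})^2 \boldsymbol{I}_2$ used in this tutorial. The helper name `ps_cov` is ours, not part of NeuralEstimators:

```julia
# Paciorek–Schervish covariance between locations s and r with
# isotropic kernel matrices Σ(s) = ρs² I₂ and Σ(r) = ρr² I₂
function ps_cov(s::AbstractVector, r::AbstractVector, ρs::Real, ρr::Real)
    aₛ, aᵣ = ρs^2, ρr^2
    Q = 2 * sum(abs2, s .- r) / (aₛ + aᵣ)   # squared Mahalanobis distance
    logC = 0.5log(aₛ) + 0.5log(aᵣ) - log((aₛ + aᵣ) / 2) - sqrt(Q)^1.5
    exp(logC)
end

ps_cov([0.1, 0.2], [0.1, 0.2], 0.5, 0.5)  # = 1 at zero distance, equal ranges
```

At zero distance with equal local ranges the covariance reduces to one; mismatched local ranges deflate it below one even at zero distance, which is how the model encodes non-stationarity.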

Package dependencies

julia
using NeuralEstimators
using Flux
using Distances
using Distributions: Uniform
using Folds
using GaussianRandomFields
using LinearAlgebra
using MLUtils: unsqueeze
using Plots
using UnicodePlots
unicodeplots() # plotting directly in terminal

To improve computational efficiency, various GPU backends are supported. Once the relevant package is loaded and a compatible GPU is available, it will be used automatically:

julia
using CUDA, cuDNN
julia
using AMDGPU
julia
using Metal
julia
using oneAPI

Sampling parameters

The range field $\rho(\cdot)$ is generated by squaring a smooth GP realization (with a small positive offset to keep it bounded away from zero). We store both the range fields and the Cholesky factors of the corresponding Paciorek–Schervish covariance matrices (Paciorek and Schervish, 2006), which are needed to simulate data:

julia
struct Parameters <: AbstractParameterSet
    θ   # parameter array: grid_dim × grid_dim × d × K
    L   # K-vector of Cholesky factors, one per parameter field
end

function sampler(K::Integer; grid_dim = 32)

    # Spatial grid
    pts = range(0, 1, length = grid_dim)
    S   = expandgrid(pts, pts)
    n   = size(S, 1)

    # Spatially varying range fields drawn from a GP hyperprior
    θ = Folds.map(1:K) do k
        λ   = rand(Uniform(0.03, 0.75))
        cov = CovarianceFunction(2, Exponential(λ, σ = 0.25))
        grf = GaussianRandomField(cov, CirculantEmbedding(), pts, pts, minpadding = 100)
        sample(grf)
    end
    θ = stack(θ)
    θ = unsqueeze(θ, dims = 3)
    θ = θ.^2 .+ 0.03

    # Pairwise squared distances
    H = pairwise(Euclidean(), S, dims = 1).^2

    # Cholesky factors of the covariance matrix
    L = Folds.map(1:K) do k
        ρ = vec(θ[:, :, 1, k])
        Σ = Matrix{Float64}(undef, n, n)
        for i in 1:n 
            aₛ = ρ[i]^2
            for j in i:n
                aᵣ  = ρ[j]^2
                Q   = 2H[i, j] / (aₛ + aᵣ)
                logC = 0.5log(aₛ) + 0.5log(aᵣ) - log((aₛ + aᵣ) / 2) - sqrt(Q)^1.5
                Σ[i, j] = Σ[j, i] = exp(logC)
            end
        end
        cholesky(Symmetric(Σ)).L
    end

    Parameters(θ, L)
end

Simulating data

Data simulation is straightforward given the Cholesky factors computed above. We store the simulated data as a four-dimensional array, where the third dimension indexes channels (singleton for a univariate process) and the fourth indexes the independent data sets:

julia
function simulator(parameters::Parameters)
	Z = Folds.map(parameters.L) do L
		n = size(L, 1)
		z = L * randn(n, 1)
		grid_dim = isqrt(n)
		reshape(z, grid_dim, grid_dim, 1)
	end
	stack(Z)
end
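Simulation here relies on the standard identity that if $\boldsymbol{z} \sim \textrm{N}(\boldsymbol{0}, \boldsymbol{I})$, then $\boldsymbol{L}\boldsymbol{z} \sim \textrm{N}(\boldsymbol{0}, \boldsymbol{L}\boldsymbol{L}')$. A minimal standalone check on an arbitrary small covariance matrix (a 1-D exponential covariance, unrelated to the model above) confirms that the lower Cholesky factor reproduces the target covariance:

```julia
using LinearAlgebra

# Illustrative covariance matrix on a small 1-D grid
pts = range(0, 1, length = 10)
Σ   = [exp(-abs(s - t) / 0.2) for s in pts, t in pts]

# Lower Cholesky factor, as stored per parameter field in Parameters
L = cholesky(Symmetric(Σ)).L

# L * randn(n) has covariance L * L' = Σ
maximum(abs.(L * L' .- Σ))  # ≈ 0 up to floating-point error
```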

Constructing the neural network

For image-to-image tasks, a U-Net (Ronneberger et al., 2015) is a natural choice and among the best-performing architectures in the LatticeVision study of Sikorski et al. (2025).

julia
using ConcreteStructs: @concrete

"""
    Unet(channels::Int = 1, output_channels::Int = channels)

Initializes a [UNet](https://arxiv.org/pdf/1505.04597.pdf) instance with the given number of `channels`, typically equal to the number of channels in the input images.

The input grid must have a dimension that is a power of two (e.g., 32x32, 64x64).

Code adapted from https://github.com/DhairyaLGandhi/UNet.jl

# Examples
u = Unet(1)
dev = cpu
u = u |> dev
rand32(256, 256, 1, 1) |> dev |> u
rand32(128, 128, 1, 1) |> dev |> u
rand32(64, 64, 1, 1) |> dev |> u
rand32(32, 32, 1, 1) |> dev |> u
"""
@concrete struct Unet
  conv_down_blocks
  conv_blocks
  up_blocks
end

function UNetConvBlock(in_chs, out_chs, kernel = (3, 3))
    Chain(
        Conv(kernel, in_chs=>out_chs, pad=(1,1)),
        BatchNorm(out_chs),
        x -> leakyrelu.(x, 0.2f0)
    )
end

function ConvDown(in_chs, out_chs, kernel = (4,4))
    Chain(
        Conv(kernel, in_chs=>out_chs, pad=(1,1), stride=(2,2)),
        BatchNorm(out_chs),
        x -> leakyrelu.(x, 0.2f0)
    )
end

@concrete struct UNetUpBlock
  upsample
end

function UNetUpBlock(in_chs::Int, out_chs::Int; kernel = (2, 2), p = 0.5f0)
    UNetUpBlock(
        Chain(
            x -> leakyrelu.(x, 0.2f0),
            ConvTranspose(kernel, in_chs=>out_chs, stride=kernel),
            BatchNorm(out_chs),
            Dropout(p)
        )
    )
end

function (u::UNetUpBlock)(x, bridge)
    x = u.upsample(x)
    cat(x, bridge, dims = 3)
end

function Unet(channels::Int = 1, output_channels::Int = channels)
    conv_down_blocks = Chain(
        ConvDown(64, 64),
        ConvDown(128, 128),
        ConvDown(256, 256),
        ConvDown(512, 512)
    )

    conv_blocks = Chain(
        UNetConvBlock(channels, 3),
        UNetConvBlock(3, 64),
        UNetConvBlock(64, 128),
        UNetConvBlock(128, 256),
        UNetConvBlock(256, 512),
        UNetConvBlock(512, 1024),
        UNetConvBlock(1024, 1024)
    )

    up_blocks = Chain(
        UNetUpBlock(1024, 512),
        UNetUpBlock(1024, 256),
        UNetUpBlock(512, 128),
        UNetUpBlock(256, 64, p = 0.0f0),
        Chain(
            x -> leakyrelu.(x, 0.2f0),
            Conv((1, 1), 128 => output_channels)
        )
    )

    Unet(conv_down_blocks, conv_blocks, up_blocks)
end

function (u::Unet)(x::AbstractArray)
  op = u.conv_blocks[1:2](x)

  x1 = u.conv_blocks[3](u.conv_down_blocks[1](op))
  x2 = u.conv_blocks[4](u.conv_down_blocks[2](x1))
  x3 = u.conv_blocks[5](u.conv_down_blocks[3](x2))
  x4 = u.conv_blocks[6](u.conv_down_blocks[4](x3))

  up_x4 = u.conv_blocks[7](x4)

  up_x1 = u.up_blocks[1](up_x4, x3)
  up_x2 = u.up_blocks[2](up_x1, x2)
  up_x3 = u.up_blocks[3](up_x2, x1)
  up_x5 = u.up_blocks[4](up_x3, op)
  tanh.(u.up_blocks[end](up_x5))
end
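The channel arithmetic in the decoder follows from two facts, illustrated below with bare Flux layers at toy sizes (not the network above): a transposed convolution with kernel and stride (2, 2) doubles the spatial dimensions, and the skip connection concatenates the result with the matching encoder feature map along the channel dimension. This is why, for example, UNetUpBlock(1024, 512) halves the channels yet the next up block again receives 1024 of them:

```julia
using Flux

# ConvTranspose with kernel = stride = (2, 2) doubles spatial dimensions
up = ConvTranspose((2, 2), 8 => 4, stride = (2, 2))
x  = rand(Float32, 4, 4, 8, 1)   # 4×4 feature map, 8 channels, batch 1
y  = up(x)                       # 8×8 feature map, 4 channels

# Skip connection: concatenate with the encoder feature map along channels
bridge = rand(Float32, 8, 8, 4, 1)
z = cat(y, bridge, dims = 3)     # channels add: 4 + 4 = 8
```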

function Base.show(io::IO, u::Unet)
  println(io, "UNet:")

  for l in u.conv_down_blocks
    println(io, "  ConvDown($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
  end

  print(io, "\n")
  for l in u.conv_blocks
    println(io, "  UNetConvBlock($(size(l[1].weight)[end-1]), $(size(l[1].weight)[end]))")
  end

  print(io, "\n")
  for l in u.up_blocks
    l isa UNetUpBlock || continue
    println(io, "  UNetUpBlock($(size(l.upsample[2].weight)[end]), $(size(l.upsample[2].weight)[end-1]))")
  end
end

Constructing the neural estimator

julia
network = Unet(1)
estimator = PointEstimator(network)

Training the estimator

julia
K = 2500
θ_train = sampler(K)
θ_val   = sampler(K)
estimator = train(estimator, θ_train, θ_val, simulator)

Assessing the estimator

NB: when calling `estimate` with this architecture, `batchsize` must evenly divide the number of data sets (i.e., no partial batches).

julia
K_test  = 100
θ_test  = sampler(K_test)
Z_test  = simulator(θ_test)
estimates = estimate(estimator, Z_test; batchsize = K_test)

It is informative to visualise the estimated and true range surfaces side-by-side for a few test instances:

julia
# Choose a test instance
k = 1
θ_true = θ_test.θ[:, :, 1, k]
θ_hat  = estimates[:, :, 1, k]

# Plotting
grid_dim = size(Z_test, 1)
x = y = range(0, 1, length = grid_dim)
p1 = contour(x, y, θ_true);
p2 = contour(x, y, θ_hat);
plot(p1, p2)

Applying the estimator to observed data

julia
Z_obs = simulator(sampler(1))           # stand-in for observed data
θ_hat = estimate(estimator, Z_obs)      # range field estimate
θ_hat = dropdims(θ_hat, dims = (3, 4))  # convert to matrix
contour(x, y, θ_hat)