
Support Float32 and type-stable warmup/NUTS #199

Open
SamuelBrand1 wants to merge 1 commit into tpapp:master from SamuelBrand1:master

Conversation

@SamuelBrand1

This PR adds support for non-Float64 floating point types (primarily Float32) throughout the sampling pipeline. When a user provides a Float32 log density problem and Float32 initial position, the entire warmup and inference should run in Float32 without any API changes. The default Float64 approach should just "happen" without change. The idea is that Float32 models (e.g. GPU-based or memory-constrained applications) might want a sampler that preserves their numeric type. Previously, Float64 literals and default parameters in the warmup pipeline caused silent type promotion, making all arithmetic run in Float64 regardless of input type.

This PR closes #198

The key design decision: default_warmup_stages() continues to create Float64-typed stages (no API change). At the two warmup entry points — warmup(_, ::InitialStepsizeSearch, _) and warmup(_, ::TuningNUTS, _) — the position's element type is detected via typeof(one(eltype(Q.q))), and the adaptation parameters are converted to match using _oftype. This means a user simply needs to provide a Float32 initial position and a Float32 log density:

```julia
results = mcmc_with_warmup(rng, ℓ_float32, N; initialization = (q = randn(Float32, d),))
# → results.posterior_matrix is Float32, results.ϵ is Float32
```

Changes

No API changes

All public functions (mcmc_with_warmup, default_warmup_stages, fixed_stepsize_warmup_stages, etc.) retain their existing signatures and Float64 defaults. Float32 conversion happens automatically inside the warmup pipeline when it encounters a non-Float64 position.

src/hamiltonian.jl

  • GaussianKineticEnergy(N, m⁻¹): Wrapped m⁻¹ with float() so integer inputs work correctly.
  • rand_p: Uses randn(rng, eltype(κ.W), ...) instead of untyped randn to match the kinetic energy's element type.
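
The `rand_p` change can be illustrated standalone (a sketch, not package code): untyped `randn` always produces Float64, so a Float32 kinetic energy would silently promote the momentum draw.

```julia
using Random

rng = Random.default_rng()
# Untyped randn defaults to Float64, silently breaking a Float32 pipeline:
@show eltype(randn(rng, 3))           # Float64
# Passing the kinetic energy's element type keeps the draw in Float32:
@show eltype(randn(rng, Float32, 3))  # Float32
```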

src/NUTS.jl

  • NUTS struct: Parameterized as NUTS{T<:Real, S} so min_Δ carries its type.
  • TreeStatisticsNUTS struct: Parameterized as TreeStatisticsNUTS{T<:Real} so π and acceptance_rate carry their type.
  • sample_tree: Uses oftype(π₀, ϵ) and oftype(π₀, min_Δ) to convert the stepsize and divergence threshold to match the log density type.
  • rand_bool_logprob: Uses typeof(logprob) in randexp to match the log-probability type.
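
The `oftype` pattern used in `sample_tree` can be sketched in isolation (the values below are illustrative, not the package's defaults):

```julia
# Convert Float64 defaults to match a Float32 log density value.
π₀ = -1.2f0       # a Float32 log density at the initial point
ϵ = 0.1           # a Float64 default stepsize
min_Δ = -1000.0   # a Float64 divergence threshold (illustrative value)
ϵ′, min_Δ′ = oftype(π₀, ϵ), oftype(π₀, min_Δ)
@show typeof(ϵ′), typeof(min_Δ′)   # both Float32
```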

src/stepsize.jl

  • InitialStepsizeSearch struct: Parameterized as InitialStepsizeSearch{T<:Real} with promote_type in constructor.
  • initial_adaptation_state: Uses log(T(10)) instead of log(10) where T comes from DualAveraging{T}, preventing Float64 promotion of μ.
  • adapt_stepsize: Uses oftype(μ, m) before sqrt and exponentiation to prevent Int → Float64 promotion via √m and m^(-κ).
  • Type conversion helpers (_oftype): Small functions that convert DualAveraging, InitialStepsizeSearch, and FixedStepsize parameters to a target float type. No-ops when types already match.
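
A minimal sketch of the `_oftype` idea, using a hypothetical stand-in type (`MyDualAveraging` and its fields are assumptions for illustration, not the PR's actual definitions):

```julia
# Hypothetical stand-in for a parameterized adaptation struct.
struct MyDualAveraging{T}
    δ::T
    γ::T
    κ::T
    t₀::T
end

# No-op when the element type already matches:
_oftype(::Type{T}, da::MyDualAveraging{T}) where {T} = da
# Otherwise convert every field to the target float type:
_oftype(::Type{T}, da::MyDualAveraging) where {T} =
    MyDualAveraging(T(da.δ), T(da.γ), T(da.κ), T(da.t₀))

da64 = MyDualAveraging(0.8, 0.05, 0.75, 10.0)
da32 = _oftype(Float32, da64)
@show typeof(da32.δ)   # Float32
```

The no-op method means the Float64 default path allocates nothing new and behaves exactly as before.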

src/mcmc.jl

  • initialize_warmup_state: Uses one(eltype(q)) for the default kinetic energy scale, so GaussianKineticEnergy matches the position type.
  • warmup for InitialStepsizeSearch: Converts stepsize search parameters to match the position's element type via _oftype before calling find_initial_stepsize.
  • warmup for TuningNUTS: Converts stepsize_adaptation (e.g. DualAveraging) to match the position's element type via _oftype before calling initial_adaptation_state.
  • TuningNUTS struct: Parameterized as TuningNUTS{M, D, T<:Real}.
  • ϵs vector: Uses Vector{typeof(float(ϵ))} to match the stepsize type.
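
The `float(ϵ)` trick can be checked directly in the REPL: `float` leaves a Float32 stepsize alone but maps an integer to Float64, so the recorded ϵs always match the stepsize type.

```julia
@show typeof(float(0.0625f0))   # Float32 (unchanged)
@show typeof(float(1))          # Float64 (integer promoted to float)
ϵs = Vector{typeof(float(0.0625f0))}()   # Vector{Float32}
push!(ϵs, 0.0625f0)
```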

test/test_mcmc.jl

Added a "Float32 support" test section with:

  • Type propagation: Verifies posterior_matrix, logdensities, tree statistics, and final ϵ are all Float32.
  • No type promotion in compute: Uses a strict log density that asserts eltype(q) === Float32 on every evaluation — catches any accidental promotion in leapfrog, stepsize adaptation, or metric updates during the full warmup + inference pipeline.
  • Sample correctness: 5D Float32 normal, 10000 samples, checks mean and std convergence.
  • Fixed stepsize: Float32 with FixedStepsize warmup path.
  • Stepwise API: Float32 through mcmc_keep_warmup / mcmc_steps / mcmc_next_step.
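
The "strict log density" idea from the tests can be sketched in isolation (this is schematic; the PR's actual test goes through the LogDensityProblems interface):

```julia
# Schematic strict log density: errors if any caller promotes q to Float64,
# so accidental promotion anywhere in the pipeline fails loudly.
function strict_logdensity_and_gradient(q::AbstractVector)
    @assert eltype(q) === Float32 "type promotion detected: got $(eltype(q))"
    ℓ = -sum(abs2, q) / 2   # standard normal log density (up to a constant)
    ∇ℓ = -q
    return ℓ, ∇ℓ
end

ℓ, ∇ℓ = strict_logdensity_and_gradient(randn(Float32, 5))
```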

@tpapp
Owner

tpapp commented Mar 6, 2026

Was this PR written by a LLM? (not a problem per se, but useful info when I review).

@SamuelBrand1
Author

> Was this PR written by a LLM? (not a problem per se, but useful info when I review).

Hi @tpapp . Sorry about delay replying! I was travelling and this fell off my TODO chart!

  1. The top paragraph was written by me, and the "Changes" section was generated by an LLM from the diff in the code base.
  2. I note that the CI failure for Julia 1 was:

```
INFO while testing: mixture of two normals, dimension 3
✘ R̂ = 1.02 ≰ 1.02
NUTS tests with mixtures: Test Failed at /home/runner/work/DynamicHMC.jl/DynamicHMC.jl/test/sample-correctness_utilities.jl:113
Expression: all(maximum(R̂) ≤ R̂_fail)
```

This suggests that Float32 loss of precision has an effect here. I think drawing a few more samples would make the test pass, but I'm not keen to alter your testing framework unilaterally (as opposed to applying your test framework to Float32-based computations). If I have done that anywhere, it's accidental.

@tpapp
Owner

tpapp commented Mar 23, 2026

Thanks for taking the time to make this PR. As I said, I am very open to extending the package in this direction, but I want to think about some things first.

It would be great if you could help me figure out what the goal is. I can think about the following:

  1. Making sure that logdensity calculations run using a given type, eg Float32. That is, whenever logdensity_and_gradient is called, it is called with the appropriate array type, and this array type is preserved across calculations. This should be relatively simple to achieve and requires significantly fewer changes.

  2. Above, plus make all calculations (within this package) run using a given type. Does this make sense on a GPU though? My understanding is that branching etc that happen in these calculations would not be ideal for a GPU.

  3. Above, plus make all reporting (tree statistics etc) use this type. It is unclear to me why this would be necessary, please explain.

@SamuelBrand1
Author

SamuelBrand1 commented Mar 25, 2026

> Thanks for taking the time to make this PR. As I said, I am very open to extending the package in this direction, but I want to think about some things first.
>
> It would be great if you could help me figure out what the goal is. I can think about the following:
>
>   1. Making sure that logdensity calculations run using a given type, eg Float32. That is, whenever logdensity_and_gradient is called, it is called with the appropriate array type, and this array type is preserved across calculations. This should be relatively simple to achieve and requires significantly fewer changes.
>   2. Above, plus make all calculations (within this package) run using a given type. Does this make sense on a GPU though? My understanding is that branching etc that happen in these calculations would not be ideal for a GPU.
>   3. Above, plus make all reporting (tree statistics etc) use this type. It is unclear to me why this would be necessary, please explain.

Hi @tpapp ! The goal is option 1; that is, preserving the user's array type through logdensity_and_gradient calls so that GPU-backed or memory-constrained models work without silent promotion. The additional type changes in the tree logic and statistics are there for float type stability, not because the tree logic itself needs to run on a GPU; I'm open to not bothering there.

Re 2: this could be a precursor to doing the whole computation on the GPU. In JAX/NumPyro, lifting the whole inference onto the GPU sometimes works better and is sometimes slower, depending on context. I'm not sure how much of that discussion carries over to Julia, because the issues there are different (e.g. the desire to avoid Python interpreter overhead as much as possible), but I wouldn't mind giving it a try.

But to reiterate, it's option 1 I'm thinking of.

@tpapp
Owner

tpapp commented Mar 25, 2026

@SamuelBrand1: you can do (1) without making any modifications to DynamicHMC.jl at all, just have the method for logdensity_and_gradient convert to Float32 or whatever you prefer. There may be an issue of promotion within DynamicHMC.jl if it gets back Float32 values, if that is a problem we can deal with it.

I am not sure about the feasibility of (2) though. I am not against it, before making large changes to this package, I would prefer to see a proof of concept demo, so that we know where we are going.

@SamuelBrand1
Author

SamuelBrand1 commented Mar 25, 2026

> @SamuelBrand1: you can do (1) without making any modifications to DynamicHMC.jl at all, just have the method for logdensity_and_gradient convert to Float32 or whatever you prefer. There may be an issue of promotion within DynamicHMC.jl if it gets back Float32 values, if that is a problem we can deal with it.
>
> I am not sure about the feasibility of (2) though. I am not against it, before making large changes to this package, I would prefer to see a proof of concept demo, so that we know where we are going.

Unless I'm missing something obvious (which is very possible!), DynamicHMC's internals promote back to Float64 during warmup: e.g. the log(10) literal in initial_adaptation_state, the DualAveraging defaults, the randn call in rand_p that doesn't carry the element type, and so on, plus after leapfrog steps (I think ϵ always becomes Float64?).

So even if logdensity_and_gradient is fully parametric in the float type, the leapfrog step ends up in Float64 by the next iteration. I think a workaround would have to involve something like re-lifting the primal + gradient calculation to the GPU device at every branch point in the NUTS call, which adds a lot of latency.

This is a hypothetical based on my best understanding, so maybe I should go and build a minimal example of it.

EDIT: thinking a bit more, you don't avoid the device-transfer latency this way either.

@tpapp
Owner

tpapp commented Mar 25, 2026

To make things concrete, I am thinking of a wrapper like

```julia
struct LDTypeWrapper{T,L}
    ℓ::L
end
LDTypeWrapper{T}(ℓ::L) where {T,L} = LDTypeWrapper{T,L}(ℓ)

function LogDensityProblems.logdensity_and_gradient(ℓ::LDTypeWrapper{T}, x) where {T}
    LogDensityProblems.logdensity_and_gradient(ℓ.ℓ, T.(x))
end
```

where the wrapped ℓ is your GPU-implemented log density.

> I think a workaround to that problem would have to involve something like relifting the primal + grad calculation to GPU device at every branch point in the NUTS call which is a lot of latency.

Yes, but can you avoid that for (1) in any case? My understanding is that if DynamicHMC.jl itself is not running on the GPU, you have to transfer to the GPU for every evaluation. To avoid that, you need (2).

Examples welcome, I really want to help with this and I am open to extensions, just want to understand the context.

@SamuelBrand1
Author

SamuelBrand1 commented Mar 25, 2026

That wrapper is a nice workaround!

The way I was thinking about this is that at

```julia
function leapfrog(H::Hamiltonian{<:EuclideanKineticEnergy}, z::PhasePoint, ϵ)
    (; ℓ, κ) = H
    (; p, Q) = z
    @argcheck isfinite(Q.ℓq) "Internal error: leapfrog called from non-finite log density"
    pₘ = p + ϵ/2 * Q.∇ℓq
    q′ = Q.q + ϵ * ∇kinetic_energy(κ, pₘ)
    Q′ = evaluate_ℓ(H.ℓ, q′)
    p′ = pₘ + ϵ/2 * Q′.∇ℓq
    PhasePoint(Q′, p′)
end
```

then lines like `pₘ = p + ϵ/2 * Q.∇ℓq` trigger promotion to a Float64 eltype. The wrapper workaround catches those at the barrier to logdensity_and_gradient, which would work OK, but what I had in mind was that, so long as operations were eltype-stable, they would dispatch to array operations à la https://cuda.juliagpu.org/stable/usage/array/ without triggering promotion to e.g. CuArray{Float64}, or erroring.
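
The promotion being described is easy to reproduce in isolation (a standalone sketch, not package code):

```julia
# A Float64 stepsize silently promotes a Float32 state in the leapfrog update.
p = randn(Float32, 3)
∇ℓq = randn(Float32, 3)
ϵ64 = 0.1     # Float64 literal, as in the warmup defaults
ϵ32 = 0.1f0
@show eltype(p + ϵ64/2 * ∇ℓq)  # Float64 (promoted)
@show eltype(p + ϵ32/2 * ∇ℓq)  # Float32 (preserved)
```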

This PR doesn't avoid GPU transfer on each evaluation, but avoiding the under-the-hood float promotion caused by Float64 literals in the codebase seems (IMO) cleaner than papering over it with a wrapper. That's what this PR is aiming at. It also leaves open the option of doing (2) as well, but YMMV on the benefit of that.

So, I guess it basically comes down to whether going in to remove the Float64 bits is worth it compared to the wrapper approach. I also thought this approach avoided device-transfer latency, but now I don't think so.

EDIT: Demo is inbound after I finish work!

penelopeysm added a commit to TuringLang/Turing.jl that referenced this pull request Mar 26, 2026
…-export `DynamicPPL.set_logprob_type!` (#2794)

Closes #2739.

As a nice by-product of using `rand(ldf)` rather than `vi[:]`, we also
avoid accidentally promoting Float32 to Float64. This means that
(together with TuringLang/DynamicPPL.jl#1328 and
tpapp/DynamicHMC.jl#199) one can do

```julia
julia> using DynamicPPL; DynamicPPL.set_logprob_type!(Float32)
┌ Info: DynamicPPL's log probability type has been set to Float32.
└ Please note you will need to restart your Julia session for this change to take effect.
```

and then after restarting

```julia
julia> using Turing, FlexiChains, DynamicHMC

julia> @model function f()
           x ~ Normal(0.0f0, 1.0f0)
       end
f (generic function with 2 methods)

julia> chn = sample(f(), externalsampler(DynamicHMC.NUTS()), 100; chain_type=VNChain)
Sampling 100%|████████████████████████████████████████████| Time: 0:00:02
FlexiChain (100 iterations, 1 chain)
↓ iter=1:100 | → chain=1:1

Parameter type   VarName
Parameters       x
Extra keys       :logprior, :loglikelihood, :logjoint


julia> eltype(chn[@varname(x)])
Float32

julia> eltype(chn[:logjoint])
Float32
```

(Previously, the values of `x` would be Float32, but logjoint would be
Float64. And if you used MCMCChains, everything would be Float64.)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment


Development

Successfully merging this pull request may close these issues.

Avoiding automatic type promotion from F32 to F64 in NUTS
