Support Float32 and type-stable warmup/NUTS#199
SamuelBrand1 wants to merge 1 commit into tpapp:master from
Conversation
Was this PR written by an LLM? (Not a problem per se, but useful info when I review.)
Hi @tpapp. Sorry about the delay replying! I was travelling and this fell off my TODO chart! Which shows that maybe …
Thanks for taking the time to make this PR. As I said, I am very open to extending the package in this direction, but I want to think about some things first. It would be great if you could help me figure out what the goal is. I can think of the following:
Hi @tpapp! The goal is option 1; that is, preserving the user's array type through the sampling pipeline. Re 2: this could be a precursor step to doing the whole thing on GPU. But to reiterate, it's option 1 I'm thinking of.
@SamuelBrand1: you can do (1) without making any modifications to DynamicHMC.jl at all; just have the method for `logdensity_and_gradient` do the conversion. I am not sure about the feasibility of (2) though. I am not against it, but before making large changes to this package I would prefer to see a proof-of-concept demo, so that we know where we are going.
Unless I'm missing something obvious (which is very possible!), DynamicHMC's internals promote back to Float64 during warmup, e.g. via Float64 literals and defaults in the adaptation code. So even if the log density itself returns Float32, the warmup arithmetic ends up in Float64. This is a hypothetical based on my best understanding here, so maybe I should go and do a minimal example of this. EDIT: thinking a bit more, you don't avoid the device transfer latency this way.
To make things concrete, I am thinking of a wrapper like

```julia
struct LDTypeWrapper{T,L}
    ℓ::L
end

function LogDensityProblems.logdensity_and_gradient(ℓ::LDTypeWrapper{T}, x) where T
    logdensity_and_gradient(ℓ.ℓ, T.(x))
end
```

where the internal log density `ℓ.ℓ` is evaluated at positions converted to `T`.
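A runnable sketch of this wrapper idea; note that the `NormalToy` log density and the outer `LDTypeWrapper{T}(ℓ)` convenience constructor here are hypothetical illustrations, not part of the package:

```julia
using LogDensityProblems
import LogDensityProblems: logdensity_and_gradient, dimension, capabilities

# Hypothetical toy log density: a standard normal in d dimensions.
struct NormalToy
    d::Int
end
dimension(ℓ::NormalToy) = ℓ.d
capabilities(::Type{NormalToy}) = LogDensityProblems.LogDensityOrder{1}()
logdensity_and_gradient(ℓ::NormalToy, x) = (-sum(abs2, x) / 2, -x)

# Wrapper that converts positions to T before evaluating the inner problem.
struct LDTypeWrapper{T,L}
    ℓ::L
end
LDTypeWrapper{T}(ℓ::L) where {T,L} = LDTypeWrapper{T,L}(ℓ)
dimension(ℓ::LDTypeWrapper) = dimension(ℓ.ℓ)
function logdensity_and_gradient(ℓ::LDTypeWrapper{T}, x) where {T}
    logdensity_and_gradient(ℓ.ℓ, T.(x))
end

wrapped = LDTypeWrapper{Float32}(NormalToy(2))
lp, g = logdensity_and_gradient(wrapped, [1.0, 2.0])  # Float64 input
# lp isa Float32 and eltype(g) == Float32: the inner evaluation ran in Float32
```

The conversion happens once per evaluation at the wrapper boundary, which is exactly why it sidesteps (but does not remove) the internal Float64 promotion discussed here.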
Yes, but can you avoid that for (1) in any case? My understanding is that if DynamicHMC.jl itself is not running on the GPU, you have to transfer to the GPU for every evaluation; to avoid that, you need (2). Examples welcome; I really want to help with this and I am open to extensions, I just want to understand the context.
That wrapper is a nice workaround! The way I was thinking about this is that at DynamicHMC.jl/src/hamiltonian.jl lines 273 to 282 (at 021ffac), lines with Float64 literals promote under the hood. This PR doesn't avoid GPU transfer at each compute, but it seems (IMO) a cleaner approach than having a wrapper to work around the under-the-hood float promotion caused by f64 literals in the codebase. That's what this PR is aiming at. This PR does leave open the option of doing (2) as well, but YMMV on the benefit of that. So I guess it basically comes down to whether going in to remove the f64 bits is worth it compared to the wrapper approach. I also thought that this approach avoided device transfer latency, but now I don't think so. EDIT: demo is inbound after I finish work!
…-export `DynamicPPL.set_logprob_type!` (#2794)

Closes #2739. As a nice by-product of using `rand(ldf)` rather than `vi[:]`, we also avoid accidentally promoting Float32 to Float64. This means that (together with TuringLang/DynamicPPL.jl#1328 and tpapp/DynamicHMC.jl#199) one can do

```julia
julia> using DynamicPPL; DynamicPPL.set_logprob_type!(Float32)
┌ Info: DynamicPPL's log probability type has been set to Float32.
└ Please note you will need to restart your Julia session for this change to take effect.
```

and then after restarting

```julia
julia> using Turing, FlexiChains, DynamicHMC

julia> @model function f()
           x ~ Normal(0.0f0, 1.0f0)
       end
f (generic function with 2 methods)

julia> chn = sample(f(), externalsampler(DynamicHMC.NUTS()), 100; chain_type=VNChain)
Sampling 100%|████████████████████████████████████████████| Time: 0:00:02
FlexiChain (100 iterations, 1 chain)
  ↓ iter=1:100 | → chain=1:1
  Parameter type  VarName
  Parameters      x
  Extra keys      :logprior, :loglikelihood, :logjoint

julia> eltype(chn[@varname(x)])
Float32

julia> eltype(chn[:logjoint])
Float32
```

(Previously, the values of `x` would be Float32, but logjoint would be Float64. And if you used MCMCChains, everything would be Float64.)
This PR adds support for non-Float64 floating point types (primarily Float32) throughout the sampling pipeline. When a user provides a Float32 log density problem and a Float32 initial position, the entire warmup and inference run in Float32 with no API changes; the default Float64 behaviour is unchanged. The motivation is that Float32 models (e.g. GPU-based or memory-constrained applications) may want a sampler that preserves their numeric type. Previously, Float64 literals and default parameters in the warmup pipeline caused silent type promotion, making all arithmetic run in Float64 regardless of input type.
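The kind of silent promotion at issue can be seen with plain Base arithmetic; this is an illustration in ordinary Julia, not code from DynamicHMC itself:

```julia
q = ones(Float32, 3)        # Float32 position

# A Float64 literal in an update promotes the whole result:
ϵ = 0.1                     # Float64 default parameter
q2 = q .+ ϵ                 # eltype(q2) == Float64: silent promotion

# Same with an untyped literal like log(10) in adaptation-style code:
μ = log(10) + one(Float32)  # Float64, since log(10) isa Float64

# Keeping constants in the state's element type preserves Float32:
ϵ32 = oftype(one(eltype(q)), 0.1)  # Float32
q3 = q .+ ϵ32               # eltype(q3) == Float32
```

One stray Float64 constant anywhere in the warmup loop is enough to drag every later iteration up to Float64, which is why the fixes below touch each literal and default individually.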
This PR closes #198
The key design decision: `default_warmup_stages()` continues to create Float64-typed stages (no API change). At the two warmup entry points — `warmup(_, ::InitialStepsizeSearch, _)` and `warmup(_, ::TuningNUTS, _)` — the position's element type is detected via `typeof(one(eltype(Q.q)))`, and the adaptation parameters are converted to match using `_oftype`. This means a user simply needs to provide a Float32 initial position and a Float32 log density.

Changes
No API changes

All public functions (`mcmc_with_warmup`, `default_warmup_stages`, `fixed_stepsize_warmup_stages`, etc.) retain their existing signatures and Float64 defaults. Float32 conversion happens automatically inside the warmup pipeline when it encounters a non-Float64 position.

src/hamiltonian.jl

- `GaussianKineticEnergy(N, m⁻¹)`: wrapped `m⁻¹` with `float()` so integer inputs work correctly.
- `rand_p`: uses `randn(rng, eltype(κ.W), ...)` instead of untyped `randn` to match the kinetic energy's element type.

src/NUTS.jl

- `NUTS` struct: parameterized as `NUTS{T<:Real, S}` so `min_Δ` carries its type.
- `TreeStatisticsNUTS` struct: parameterized as `TreeStatisticsNUTS{T<:Real}` so `π` and `acceptance_rate` carry their type.
- `sample_tree`: uses `oftype(π₀, ϵ)` and `oftype(π₀, min_Δ)` to convert the stepsize and divergence threshold to match the log density type.
- `rand_bool_logprob`: uses `typeof(logprob)` in `randexp` to match the log-probability type.

src/stepsize.jl

- `InitialStepsizeSearch` struct: parameterized as `InitialStepsizeSearch{T<:Real}` with `promote_type` in the constructor.
- `initial_adaptation_state`: uses `log(T(10))` instead of `log(10)`, where `T` comes from `DualAveraging{T}`, preventing Float64 promotion of `μ`.
- `adapt_stepsize`: uses `oftype(μ, m)` before `sqrt` and exponentiation to prevent `Int → Float64` promotion via `√m` and `m^(-κ)`.
- New helpers (`_oftype`): small functions that convert `DualAveraging`, `InitialStepsizeSearch`, and `FixedStepsize` parameters to a target float type. No-ops when the types already match.

src/mcmc.jl

- `initialize_warmup_state`: uses `one(eltype(q))` for the default kinetic energy scale, so `GaussianKineticEnergy` matches the position type.
- `warmup` for `InitialStepsizeSearch`: converts stepsize search parameters to match the position's element type via `_oftype` before calling `find_initial_stepsize`.
- `warmup` for `TuningNUTS`: converts `stepsize_adaptation` (e.g. `DualAveraging`) to match the position's element type via `_oftype` before calling `initial_adaptation_state`.
- `TuningNUTS` struct: parameterized as `TuningNUTS{M, D, T<:Real}`.
- `ϵs` vector: uses `Vector{typeof(float(ϵ))}` to match the stepsize type.

test/test_mcmc.jl

Added a "Float32 support" test section with:

- checks that `posterior_matrix`, `logdensities`, tree statistics, and the final `ϵ` are all Float32.
- a check that `eltype(q) === Float32` on every evaluation, which catches any accidental promotion in leapfrog, stepsize adaptation, or metric updates during the full warmup + inference pipeline.
- coverage of the `FixedStepsize` warmup path.
- coverage of `mcmc_keep_warmup` / `mcmc_steps` / `mcmc_next_step`.
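The `_oftype` conversion pattern might look roughly like the following sketch. The `Adaptation` struct here is a hypothetical stand-in for types like `DualAveraging`; the actual PR code and field names may differ:

```julia
# No-op when the scalar already has the target type.
_oftype(::Type{T}, x::T) where {T<:Real} = x

# Convert any other real scalar parameter to the target type.
_oftype(::Type{T}, x::Real) where {T<:Real} = T(x)

# Hypothetical adaptation-parameter container, standing in for
# something like DualAveraging or InitialStepsizeSearch:
struct Adaptation{T<:Real}
    γ::T
    κ::T
end
_oftype(::Type{T}, a::Adaptation) where {T<:Real} =
    Adaptation{T}(T(a.γ), T(a.κ))

# At a warmup entry point: detect the position's element type and
# convert the adaptation parameters to match before using them.
q = ones(Float32, 4)
T = typeof(one(eltype(q)))            # Float32, as in the PR description
a = _oftype(T, Adaptation(0.05, 0.75))
# typeof(a) == Adaptation{Float32}: warmup arithmetic now stays in Float32
```

Because the conversion is a no-op for matching types, the Float64 default path pays essentially nothing, which is what lets the public API keep its Float64 defaults unchanged.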