diff --git a/Cargo.lock b/Cargo.lock index 7fbccf54..363f2cb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -870,6 +870,7 @@ version = "0.1.0-rc.7" dependencies = [ "const-oid 0.10.2", "criterion", + "ctutils", "getrandom 0.4.1", "hex", "hex-literal", diff --git a/ml-dsa/Cargo.toml b/ml-dsa/Cargo.toml index aefad86d..7f804b1f 100644 --- a/ml-dsa/Cargo.toml +++ b/ml-dsa/Cargo.toml @@ -36,6 +36,7 @@ hybrid-array = { version = "0.4", features = ["extra-sizes"] } module-lattice = "0.1" sha3 = { version = "0.11.0-rc.8", default-features = false } signature = { version = "3.0.0-rc.10", default-features = false, features = ["digest"] } +ctutils = { version = "0.4", default-features = false } # optional dependencies const-oid = { version = "0.10", features = ["db"], optional = true } diff --git a/ml-dsa/src/algebra.rs b/ml-dsa/src/algebra.rs index d3a9b02a..fccf62ba 100644 --- a/ml-dsa/src/algebra.rs +++ b/ml-dsa/src/algebra.rs @@ -1,9 +1,12 @@ +use ctutils::{CtEq, CtGt, CtLt}; use hybrid_array::{ ArraySize, typenum::{Shleft, U1, U13, Unsigned}, }; use module_lattice::{Field, Truncate}; +use crate::ct::ct_select; + module_lattice::define_field!(BaseField, u32, u64, u128, 8_380_417); pub(crate) type Int = ::Int; @@ -28,11 +31,9 @@ pub(crate) trait BarrettReduce: Unsigned { let quotient = (x * Self::MULTIPLIER) >> Self::SHIFT; let remainder = x - quotient * m; - if remainder < m { - Truncate::truncate(remainder) - } else { - Truncate::truncate(remainder - m) - } + let r_small: u32 = Truncate::truncate(remainder); + let r_large: u32 = Truncate::truncate(remainder.wrapping_sub(m)); + ct_select!(remainder.ct_lt(&m), r_small, r_large) } } @@ -103,14 +104,15 @@ impl Decompose for Elem { let r_plus = self.clone(); let r0 = r_plus.mod_plus_minus::(); - if r_plus - r0 == Elem::new(BaseField::Q - 1) { - (Elem::new(0), r0 - Elem::new(1)) - } else { - let diff = r_plus - r0; - // Use constant-time division instead of hardware division - let r1 = Elem::new(TwoGamma2::ct_div(diff.0)); - (r1, r0) - } + let diff = r_plus - r0; + let is_edge = diff.0.ct_eq(&(BaseField::Q - 1)); + + // Compute both branches unconditionally + let edge = (Elem::new(0), r0 - Elem::new(1)); + let r1 = Elem::new(TwoGamma2::ct_div(diff.0)); + let normal = (r1, r0); + + ct_select!(is_edge, edge, normal) } } @@ -126,11 +128,8 @@ pub(crate) trait AlgebraExt: Sized { impl AlgebraExt for Elem { fn mod_plus_minus(&self) -> Self { let raw_mod = Elem::new(M::reduce(self.0)); - if raw_mod.0 <= M::U32 >> 1 { - raw_mod - } else { - raw_mod - Elem::new(M::U32) - } + let in_lower_half = !raw_mod.0.ct_gt(&(M::U32 >> 1)); + ct_select!(in_lower_half, raw_mod, raw_mod - Elem::new(M::U32)) } // FIPS 204 defines the infinity norm differently for signed vs. unsigned integers: @@ -142,11 +141,8 @@ impl AlgebraExt for Elem { // the signed integers used in this crate, so we can safely use the unsigned version. However, // since mod_plus_minus is also unsigned, we need to unwrap the "negative" values. fn infinity_norm(&self) -> u32 { - if self.0 <= BaseField::Q >> 1 { - self.0 - } else { - BaseField::Q - self.0 - } + let in_lower_half = !self.0.ct_gt(&(BaseField::Q >> 1)); + ct_select!(in_lower_half, self.0, BaseField::Q - self.0) } // Algorithm 35 Power2Round diff --git a/ml-dsa/src/ct.rs b/ml-dsa/src/ct.rs new file mode 100644 index 00000000..552b2a5d --- /dev/null +++ b/ml-dsa/src/ct.rs @@ -0,0 +1,89 @@ +//! Constant-time selection utilities. +//! +//! Provides a [`ct_select!`] macro and supporting [`CtSelectExt`] trait for +//! branchless conditional selection, preventing timing side-channels from +//! branch prediction on secret-dependent values. +//! +//! Built on the [`ctutils`] crate, which uses the [`cmov`] crate for +//! architecture-specific predication intrinsics (`cmov` on x86-64, +//! `CSEL` on `AArch64`). +//! +//! See: + +use ctutils::Choice; + +use crate::algebra::Elem; + +/// Constant-time conditional selection for types not covered by +/// [`ctutils::CtSelect`] (which cannot be impl'd for foreign types +/// due to the orphan rule). +/// +/// Selects between two values based on a [`Choice`] without branching. +/// When `choice` is `TRUE`, returns `b`; when `FALSE`, returns `a`. +pub(crate) trait CtSelectExt: Sized { + fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self; +} + +impl CtSelectExt for u32 { + #[allow(clippy::inline_always)] + #[inline(always)] + fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self { + ctutils::CtSelect::ct_select(a, b, choice) + } +} + +impl CtSelectExt for u64 { + #[allow(clippy::inline_always)] + #[inline(always)] + fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self { + ctutils::CtSelect::ct_select(a, b, choice) + } +} + +impl CtSelectExt for Elem { + #[allow(clippy::inline_always)] + #[inline(always)] + fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self { + Elem::new(ctutils::CtSelect::ct_select(&a.0, &b.0, choice)) + } +} + +impl CtSelectExt for (Elem, Elem) { + #[allow(clippy::inline_always)] + #[inline(always)] + fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self { + ( + CtSelectExt::ct_select(&a.0, &b.0, choice), + CtSelectExt::ct_select(&a.1, &b.1, choice), + ) + } +} + +/// Branchless conditional select — the Rust equivalent of LLVM's +/// `__builtin_ct_select`. +/// +/// Evaluates both `$if_true` and `$if_false` unconditionally, then +/// selects the result based on `$choice` using predication instructions +/// (no CPU branch). Both expressions must have the same type, which +/// must implement [`CtSelectExt`]. +/// +/// `$choice` must be a [`ctutils::Choice`]. Use `ct_eq`, `ct_gt`, +/// `ct_lt` from the `ctutils` crate to produce [`Choice`] values from +/// constant-time comparisons. +/// +/// # Examples +/// +/// ```ignore +/// use ctutils::CtLt; +/// let result: u32 = ct_select!(a.ct_lt(&b), val_if_true, val_if_false); +/// ``` +macro_rules! ct_select { + ($choice:expr, $if_true:expr, $if_false:expr) => {{ + let if_true = $if_true; + let if_false = $if_false; + let choice: ctutils::Choice = $choice; + $crate::ct::CtSelectExt::ct_select(&if_false, &if_true, choice) + }}; +} + +pub(crate) use ct_select; diff --git a/ml-dsa/src/hint.rs b/ml-dsa/src/hint.rs index e1d497f6..c247b00b 100644 --- a/ml-dsa/src/hint.rs +++ b/ml-dsa/src/hint.rs @@ -1,7 +1,9 @@ use crate::{ algebra::{AlgebraExt, BaseField, Decompose, Elem, Polynomial, Vector}, + ct::ct_select, param::{EncodedHint, SignatureParams}, }; +use ctutils::{Choice, CtEq, CtGt}; use hybrid_array::{ Array, typenum::{U256, Unsigned}, @@ -17,25 +19,25 @@ fn make_hint(z: Elem, r: Elem) -> bool { } /// Algorithm 40 `UseHint`: returns the high bits of `r` adjusted according to hint `h`. +/// +/// All branches are replaced with constant-time selection to avoid +/// leaking information about `r0` through branch timing. #[allow(clippy::integer_division_remainder_used, reason = "params are public")] fn use_hint(h: bool, r: Elem) -> Elem { let m: u32 = (BaseField::Q - 1) / TwoGamma2::U32; let (r1, r0) = r.decompose::(); let gamma2 = TwoGamma2::U32 / 2; - if h { - if r0.0 > 0 && r0.0 <= gamma2 { - Elem::new((r1.0 + 1) % m) - } else if (r0.0 == 0) || (r0.0 >= BaseField::Q - gamma2) { - Elem::new((r1.0 + m - 1) % m) - } else { - // We use the Elem encoding even for signed integers. Since r0 is computed - // mod+- 2*gamma2 (possibly minus 1), it is guaranteed to be in [-gamma2, gamma2]. - unreachable!(); - } - } else { - r1 - } + // Compute both possible hint-adjusted results unconditionally + let r1_inc = Elem::new((r1.0 + 1) % m); + let r1_dec = Elem::new((r1.0 + m - 1) % m); + + // r0 is "positive" when r0 > 0 AND r0 <= gamma2 + let r0_positive = !r0.0.ct_eq(&0) & !r0.0.ct_gt(&gamma2); + let hinted = ct_select!(r0_positive, r1_inc, r1_dec); + + // Apply hint only when h is set + ct_select!(Choice::from_u8_lsb(u8::from(h)), hinted, r1) } #[derive(Clone, PartialEq, Debug)] diff --git a/ml-dsa/src/lib.rs b/ml-dsa/src/lib.rs index 71c5adc8..4d9d6c52 100644 --- a/ml-dsa/src/lib.rs +++ b/ml-dsa/src/lib.rs @@ -38,6 +38,7 @@ mod algebra; mod crypto; +mod ct; mod encode; mod hint; mod ntt;