Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ml-dsa/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ hybrid-array = { version = "0.4", features = ["extra-sizes"] }
module-lattice = "0.1"
sha3 = { version = "0.11.0-rc.8", default-features = false }
signature = { version = "3.0.0-rc.10", default-features = false, features = ["digest"] }
ctutils = { version = "0.4", default-features = false }

# optional dependencies
const-oid = { version = "0.10", features = ["db"], optional = true }
Expand Down
42 changes: 19 additions & 23 deletions ml-dsa/src/algebra.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use ctutils::{CtEq, CtGt, CtLt};
use hybrid_array::{
ArraySize,
typenum::{Shleft, U1, U13, Unsigned},
};
use module_lattice::{Field, Truncate};

use crate::ct::ct_select;

module_lattice::define_field!(BaseField, u32, u64, u128, 8_380_417);

pub(crate) type Int = <BaseField as Field>::Int;
Expand All @@ -28,11 +31,9 @@ pub(crate) trait BarrettReduce: Unsigned {
let quotient = (x * Self::MULTIPLIER) >> Self::SHIFT;
let remainder = x - quotient * m;

if remainder < m {
Truncate::truncate(remainder)
} else {
Truncate::truncate(remainder - m)
}
let r_small: u32 = Truncate::truncate(remainder);
let r_large: u32 = Truncate::truncate(remainder.wrapping_sub(m));
ct_select!(remainder.ct_lt(&m), r_small, r_large)
}
}

Expand Down Expand Up @@ -103,14 +104,15 @@ impl Decompose for Elem {
let r_plus = self.clone();
let r0 = r_plus.mod_plus_minus::<TwoGamma2>();

if r_plus - r0 == Elem::new(BaseField::Q - 1) {
(Elem::new(0), r0 - Elem::new(1))
} else {
let diff = r_plus - r0;
// Use constant-time division instead of hardware division
let r1 = Elem::new(TwoGamma2::ct_div(diff.0));
(r1, r0)
}
let diff = r_plus - r0;
let is_edge = diff.0.ct_eq(&(BaseField::Q - 1));

// Compute both branches unconditionally
let edge = (Elem::new(0), r0 - Elem::new(1));
let r1 = Elem::new(TwoGamma2::ct_div(diff.0));
let normal = (r1, r0);

ct_select!(is_edge, edge, normal)
}
}

Expand All @@ -126,11 +128,8 @@ pub(crate) trait AlgebraExt: Sized {
impl AlgebraExt for Elem {
fn mod_plus_minus<M: Unsigned>(&self) -> Self {
let raw_mod = Elem::new(M::reduce(self.0));
if raw_mod.0 <= M::U32 >> 1 {
raw_mod
} else {
raw_mod - Elem::new(M::U32)
}
let in_lower_half = !raw_mod.0.ct_gt(&(M::U32 >> 1));
ct_select!(in_lower_half, raw_mod, raw_mod - Elem::new(M::U32))
}

// FIPS 204 defines the infinity norm differently for signed vs. unsigned integers:
Expand All @@ -142,11 +141,8 @@ impl AlgebraExt for Elem {
// the signed integers used in this crate, so we can safely use the unsigned version. However,
// since mod_plus_minus is also unsigned, we need to unwrap the "negative" values.
fn infinity_norm(&self) -> u32 {
if self.0 <= BaseField::Q >> 1 {
self.0
} else {
BaseField::Q - self.0
}
let in_lower_half = !self.0.ct_gt(&(BaseField::Q >> 1));
ct_select!(in_lower_half, self.0, BaseField::Q - self.0)
}

// Algorithm 35 Power2Round
Expand Down
89 changes: 89 additions & 0 deletions ml-dsa/src/ct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//! Constant-time selection utilities.
//!
//! Provides a [`ct_select!`] macro and supporting [`CtSelectExt`] trait for
//! branchless conditional selection, preventing timing side-channels from
//! branch prediction on secret-dependent values.
//!
//! Built on the [`ctutils`] crate, which uses the [`cmov`] crate for
//! architecture-specific predication intrinsics (`cmov` on x86-64,
//! `CSEL` on `AArch64`).
//!
//! See: <https://blog.trailofbits.com/2025/12/02/introducing-constant-time-support-for-llvm-to-protect-cryptographic-code/>

use ctutils::Choice;

use crate::algebra::Elem;

/// Constant-time conditional selection for types not covered by
/// [`ctutils::CtSelect`] (which cannot be impl'd for foreign types
/// due to the orphan rule).
///
/// Selects between two values based on a [`Choice`] without branching.
/// When `choice` is `TRUE`, returns `b`; when `FALSE`, returns `a`.
pub(crate) trait CtSelectExt: Sized {
fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self;
}

impl CtSelectExt for u32 {
#[allow(clippy::inline_always)]
#[inline(always)]
fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self {
ctutils::CtSelect::ct_select(a, b, choice)
}
}

impl CtSelectExt for u64 {
#[allow(clippy::inline_always)]
#[inline(always)]
fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self {
ctutils::CtSelect::ct_select(a, b, choice)
}
}

impl CtSelectExt for Elem {
#[allow(clippy::inline_always)]
#[inline(always)]
fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self {
Elem::new(ctutils::CtSelect::ct_select(&a.0, &b.0, choice))
}
}

impl CtSelectExt for (Elem, Elem) {
#[allow(clippy::inline_always)]
#[inline(always)]
fn ct_select(a: &Self, b: &Self, choice: Choice) -> Self {
(
CtSelectExt::ct_select(&a.0, &b.0, choice),
CtSelectExt::ct_select(&a.1, &b.1, choice),
)
}
}

/// Branchless conditional select — the Rust equivalent of LLVM's
/// `__builtin_ct_select`.
///
/// Evaluates both `$if_true` and `$if_false` unconditionally, then
/// selects the result based on `$choice` using predication instructions
/// (no CPU branch). Both expressions must have the same type, which
/// must implement [`CtSelectExt`].
///
/// `$choice` must be a [`ctutils::Choice`]. Use `ct_eq`, `ct_gt`,
/// `ct_lt` from the `ctutils` crate to produce [`Choice`] values from
/// constant-time comparisons.
///
/// # Examples
///
/// ```ignore
/// use ctutils::CtLt;
/// let result: u32 = ct_select!(a.ct_lt(&b), val_if_true, val_if_false);
/// ```
macro_rules! ct_select {
($choice:expr, $if_true:expr, $if_false:expr) => {{
let if_true = $if_true;
let if_false = $if_false;
let choice: ctutils::Choice = $choice;
$crate::ct::CtSelectExt::ct_select(&if_false, &if_true, choice)
}};
}

pub(crate) use ct_select;
28 changes: 15 additions & 13 deletions ml-dsa/src/hint.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::{
algebra::{AlgebraExt, BaseField, Decompose, Elem, Polynomial, Vector},
ct::ct_select,
param::{EncodedHint, SignatureParams},
};
use ctutils::{Choice, CtEq, CtGt};
use hybrid_array::{
Array,
typenum::{U256, Unsigned},
Expand All @@ -17,25 +19,25 @@ fn make_hint<TwoGamma2: Unsigned>(z: Elem, r: Elem) -> bool {
}

/// Algorithm 40 `UseHint`: returns the high bits of `r` adjusted according to hint `h`.
///
/// All branches are replaced with constant-time selection to avoid
/// leaking information about `r0` through branch timing.
#[allow(clippy::integer_division_remainder_used, reason = "params are public")]
fn use_hint<TwoGamma2: Unsigned>(h: bool, r: Elem) -> Elem {
let m: u32 = (BaseField::Q - 1) / TwoGamma2::U32;
let (r1, r0) = r.decompose::<TwoGamma2>();
let gamma2 = TwoGamma2::U32 / 2;

if h {
if r0.0 > 0 && r0.0 <= gamma2 {
Elem::new((r1.0 + 1) % m)
} else if (r0.0 == 0) || (r0.0 >= BaseField::Q - gamma2) {
Elem::new((r1.0 + m - 1) % m)
} else {
// We use the Elem encoding even for signed integers. Since r0 is computed
// mod+- 2*gamma2 (possibly minus 1), it is guaranteed to be in [-gamma2, gamma2].
unreachable!();
}
} else {
r1
}
// Compute both possible hint-adjusted results unconditionally
let r1_inc = Elem::new((r1.0 + 1) % m);
let r1_dec = Elem::new((r1.0 + m - 1) % m);

// r0 is "positive" when r0 > 0 AND r0 <= gamma2
let r0_positive = !r0.0.ct_eq(&0) & !r0.0.ct_gt(&gamma2);
let hinted = ct_select!(r0_positive, r1_inc, r1_dec);

// Apply hint only when h is set
ct_select!(Choice::from_u8_lsb(u8::from(h)), hinted, r1)
}

#[derive(Clone, PartialEq, Debug)]
Expand Down
1 change: 1 addition & 0 deletions ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

mod algebra;
mod crypto;
mod ct;
mod encode;
mod hint;
mod ntt;
Expand Down