Fix Qwen2 KV-cache dtype mismatch in cached decoding #1322

Open
skwh54 wants to merge 2 commits into google:main from skwh54:fix/qwen2-kv-cache-dtype-cast

Conversation

@skwh54 skwh54 commented Mar 27, 2026

Resolves #1321

This PR fixes a dtype mismatch in the Qwen2 cached decoding path.

In tunix/models/qwen2/model.py, value_proj and key_proj are written into
cache['v'] / cache['k'] via jax.lax.dynamic_update_slice(...) without
first matching the cache dtype:

value_proj = jax.lax.dynamic_update_slice(
    cache['v'],
    value_proj,
    slice_indices,
)
key_proj = jax.lax.dynamic_update_slice(
    cache['k'], key_proj, slice_indices
)

jax.lax.dynamic_update_slice(...) requires its operand and update to have
the same dtype, so when the projection dtype differs from the cache dtype
the call raises a dtype mismatch error. This PR casts value_proj to
cache['v'].dtype and key_proj to cache['k'].dtype immediately before the
cache update.
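
For illustration, the cast-before-update pattern can be sketched as below. This is a minimal standalone sketch, not the actual diff in this PR: the helper name `update_cache`, the cache shapes, and the choice of a bfloat16 cache with float32 projections are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def update_cache(cache_entry, proj, slice_indices):
  """Hypothetical helper: writes `proj` into `cache_entry` at `slice_indices`."""
  # Cast to the cache dtype first, because dynamic_update_slice requires the
  # operand and the update to share a dtype.
  proj = proj.astype(cache_entry.dtype)
  return jax.lax.dynamic_update_slice(cache_entry, proj, slice_indices)


# Assumed layout for the example: [batch, cache_len, num_kv_heads, head_dim].
cache = {
    'k': jnp.zeros((1, 8, 2, 4), dtype=jnp.bfloat16),
    'v': jnp.zeros((1, 8, 2, 4), dtype=jnp.bfloat16),
}

# Projections for a single decode step, deliberately in a different dtype.
key_proj = jnp.ones((1, 1, 2, 4), dtype=jnp.float32)
value_proj = jnp.ones((1, 1, 2, 4), dtype=jnp.float32)

# Write the new step at cache position 3 (an arbitrary position for the demo).
slice_indices = (0, 3, 0, 0)

new_k = update_cache(cache['k'], key_proj, slice_indices)
new_v = update_cache(cache['v'], value_proj, slice_indices)
```

In this sketch the updated arrays keep the cache dtype, so later decode steps see a consistent dtype regardless of what the projection layers produce.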

The change is intentionally minimal and only affects the Qwen2 cache update
path in tunix/models/qwen2/model.py.

Reference

Colab Notebook

  • N/A; this PR adds a regression unit test instead.

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

google-cla bot commented Mar 27, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Development

Successfully merging this pull request may close these issues.

Qwen2 KV-cache updates do not cast projections to cache dtype before dynamic_update_slice
