
Add Gemma 4 31B IT contrib model#106

Open
jimburtoft wants to merge 6 commits into aws-neuron:main from jimburtoft:contrib/gemma-4-31b-it

Conversation

@jimburtoft
Contributor

Summary

  • Adds NxDI contrib implementation of google/gemma-4-31b-it text decoder (31B parameters)
  • Validated on trn2.3xlarge (TP=4, bfloat16): 32.1 tok/s, 65ms TTFT, Pearson r=0.980 vs HF CPU reference
  • Vision encoder port is planned as a follow-up PR

Architecture Highlights

Gemma 4 31B has several features not present in the existing Gemma 3 contrib:

| Feature | Details |
| --- | --- |
| Heterogeneous layers | SWA layers (head_dim=256, 16 KV heads) and global layers (head_dim=512, 4 KV heads) |
| attention_k_eq_v | Global layers share K/V projections: V = K before normalization |
| QK + V normalization | RMSNorm on Q/K after projection; separate RMSNorm (no learnable scale) on V |
| layer_scalar | Per-layer learned multiplicative factor applied at the end of the decoder forward |
| final_logit_softcapping | 30 * tanh(logits / 30) after lm_head |
| Partial RoPE | Global layers rotate only 25% of dims (128 of 512) |
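The final_logit_softcapping entry above reduces to a one-liner; a quick NumPy check of its saturating behavior (this is just the formula from the table, not the NxDI implementation):

```python
import numpy as np

def soft_cap(logits, cap=30.0):
    # final_logit_softcapping: smoothly bounds logits to the open interval (-cap, cap)
    return cap * np.tanh(logits / cap)

x = np.array([-1000.0, 0.0, 15.0, 1000.0])
capped = soft_cap(x)
assert abs(capped[0] + 30.0) < 1e-3 and abs(capped[3] - 30.0) < 1e-3
assert capped[1] == 0.0
assert abs(capped[2]) < 15.0  # values inside the cap are mildly compressed (15 -> ~13.9)
```

Unlike hard clipping, the tanh form stays differentiable and preserves the ordering of logits, so greedy decoding is unaffected.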
Key implementation detail: NxDI applies 1/sqrt(head_dim) scaling in attention, but Gemma 4 uses scaling=1.0. Naively scaling Q weights doesn't work because q_norm (RMSNorm) is scale-invariant. The fix is to scale the q_layernorm.weight parameters by sqrt(head_dim) instead, which survives normalization.
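The scale-invariance argument can be verified numerically. A minimal NumPy sketch (plain RMSNorm, not the NxDI module): scaling the input of RMSNorm is cancelled by normalization, while scaling its learnable weight survives.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: x / rms(x) * weight -- invariant to any constant scaling of x
    rms = np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)
    return x / rms * weight

head_dim = 512
rng = np.random.default_rng(0)
q = rng.standard_normal((4, head_dim))
w = np.ones(head_dim)

# Scaling the input (i.e. scaling the q_proj weights) is a no-op after q_norm:
assert np.allclose(rms_norm(q, w), rms_norm(q * np.sqrt(head_dim), w), atol=1e-5)

# Scaling q_layernorm.weight by sqrt(head_dim) survives, exactly cancelling the
# framework's later 1/sqrt(head_dim) attention scaling:
scaled = rms_norm(q, w * np.sqrt(head_dim))
assert np.allclose(scaled / np.sqrt(head_dim), rms_norm(q, w), atol=1e-5)
```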

Test Results

All 7 integration tests pass on trn2.3xlarge (TP=4, batch_size=1, bfloat16):

| Test | Result |
| --- | --- |
| Smoke (model loads) | PASS |
| Text generation | PASS (10 tokens generated) |
| Token matching | PASS ("France" in output for capital-of-France prompt) |
| Chat generation | PASS ("2 + 2 = 4" with chat template) |
| Coherence | PASS (coherent haiku) |
| TTFT | PASS, 65.1 ms (threshold: 200 ms) |
| Throughput | PASS, 32.1 tok/s, TPOT = 31.1 ms (threshold: 10 tok/s) |

Files

```
contrib/models/gemma-4-31b-it/
├── README.md                              # Usage, compatibility matrix, test results
├── src/
│   ├── __init__.py                        # Exports NeuronGemma4ForCausalLM, Gemma4InferenceConfig
│   └── modeling_gemma4.py                 # 1070 lines — full text decoder implementation
└── test/
    ├── __init__.py
    ├── integration/
    │   ├── __init__.py
    │   └── test_model.py                  # 7 tests (pytest + standalone runner)
    └── unit/
        └── __init__.py
```

Prerequisites

  • Neuron SDK 2.28+ (trn2.3xlarge with LNC=2)
  • transformers >= 5.5.0 (required for Gemma4ForConditionalGeneration)
  • transformers/utils/fx.py shim needed (NxD imports HFTracer which was removed in transformers 5.x — see README for details)
  • attn_kernel_enabled=False required (head_dim > 128 exceeds NKI flash attention limit)
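The fx.py shim in the third bullet can take roughly the following form. This is a hypothetical sketch, not the shim from the README: it assumes NxD only needs the `HFTracer` symbol to import cleanly and never actually invokes tracing, so registering a stub module before NxD is imported suffices.

```python
# Hypothetical shim for transformers/utils/fx.py, which was removed in
# transformers 5.x but is still imported by NxD. Assumption: only the HFTracer
# name is needed at import time; the stub is never used for real tracing.
import sys
import types

_shim = types.ModuleType("transformers.utils.fx")

class HFTracer:
    """Stub standing in for the tracer removed in transformers 5.x."""

_shim.HFTracer = HFTracer
sys.modules["transformers.utils.fx"] = _shim
```

Alternatively, the same stub can be written to a real `transformers/utils/fx.py` file in site-packages; see the README for the approach this PR actually uses.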

Known Limitations

  • Text-only: This PR covers the text decoder only. The Gemma 4 vision encoder (gemma4_vision) is not yet ported. A follow-up PR will add VLM (vision + language) support.
  • No flash attention: Both SWA (head_dim=256) and global (head_dim=512) layers exceed the NKI kernel's 128 limit, requiring decomposed attention.
  • CONVERT_TO_MHA warnings: Benign warnings from NxDI about TP=4 vs 16 KV heads; does not affect correctness.

Checklist

  • At least one accuracy test (token matching + logit correlation)
  • README with usage example
  • README with compatibility matrix
  • README with checkpoint links
  • README with test instructions

NeuronX Distributed Inference implementation of google/gemma-4-31b-it
text decoder with heterogeneous SWA/global attention layers, QK/V norms,
attention_k_eq_v, layer_scalar, and final_logit_softcapping.

Validated on trn2.3xlarge (TP=4, bf16): 32.1 tok/s, TTFT 65ms,
Pearson r=0.980 vs HF CPU reference.
New files:
- modeling_gemma4_vision.py: Vision encoder (27-layer, 2D RoPE, spatial pooler,
  QK/V norms, multi-modal projector). Verified: cosine 0.9995 vs HF reference
  on CPU, cosine 0.9995 vs CPU on Neuron (trn2, 18.5ms/image at 384x384).
- modeling_gemma4_vlm.py: VLM orchestrator (NeuronGemma4ForConditionalGeneration),
  config (Gemma4VLMInferenceConfig), vision wrapper (Gemma4VisionModelWrapper),
  state dict conversion for combined text+vision weights, load_pretrained_config
  helper that replaces broken hf_adapter import.
- ndxi_patch.py: NxDI monkey-patches for SWA gather .long() fix,
  tensor_capture_hook, and hf_adapter SampleDecoderOnlyOutput rename.
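The state-dict conversion for combined text+vision weights boils down to prefix remapping. A generic sketch; the key names below are illustrative stand-ins, not the actual Gemma 4 checkpoint layout:

```python
def remap_keys(state_dict, old_prefix, new_prefix):
    """Rename checkpoint keys by swapping a prefix; other keys pass through."""
    return {
        (new_prefix + k[len(old_prefix):]) if k.startswith(old_prefix) else k: v
        for k, v in state_dict.items()
    }

# Illustrative keys only (hypothetical, not the real checkpoint):
sd = {"vision_encoder.patch_embed.weight": 1, "lm_head.weight": 2}
out = remap_keys(sd, "vision_encoder.", "")  # drop the vision prefix
assert "patch_embed.weight" in out and out["lm_head.weight"] == 2
```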

Modified:
- modeling_gemma4.py: Added encode_vision_to_input and scatter_by_index_put
  methods to NeuronGemma4TextModel for vision token merging.
- __init__.py: Export VLM classes.
- sampling module import: try new path first, fall back to old
- load_pretrained_config: JSON fallback when AutoConfig fails on
  unrecognized gemma4 model_type, set _name_or_path
- Set missing config attrs (output_attentions, pad_token_id, etc.)
- Remove int64->int32 conversion in vision wrapper (NEFF traced with int64)
- Fix vision weight key prefixes (no vision_encoder. prefix)
- Add 2D->3D reshape for vision encoder output
- Bypass _get_model_outputs entirely in ndxi_patch forward: call
  context_encoding_model/token_generation_model directly with 24 args
  to avoid deepstack_vision_embeds mismatch
- Create properly-shaped zero vision tensors for text-only prefill

All three E2E tests pass: text prefill, token generation, vision prefill.
@jimburtoft jimburtoft changed the title Add Gemma 4 31B IT contrib model (text-only decoder) Add Gemma 4 31B IT contrib model Apr 5, 2026
Custom NKI kernel tiles the QK contraction dimension in 128-element chunks,
enabling flash attention for Gemma4's d=256 (SWA) and d=512 (global) layers.
Validated end-to-end with cosine > 0.999 standalone and correct greedy output.
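The tiling above is an exact decomposition of the QK matmul, since the contraction axis (head_dim) is a plain sum. A NumPy sketch of the math only; the real kernel is written in NKI and handles on-chip buffering that this ignores:

```python
import numpy as np

def qk_tiled(q, k, tile=128):
    """Q @ K^T computed by summing 128-wide chunks of the head_dim axis.
    q: (seq_q, d), k: (seq_k, d), with d a multiple of `tile`."""
    d = q.shape[-1]
    scores = np.zeros((q.shape[0], k.shape[0]))
    for start in range(0, d, tile):
        scores += q[:, start:start + tile] @ k[:, start:start + tile].T
    return scores

rng = np.random.default_rng(0)
q, k = rng.standard_normal((8, 512)), rng.standard_normal((8, 512))
assert np.allclose(qk_tiled(q, k), q @ k.T)          # d=512 -> 4 chunks of 128
assert np.allclose(qk_tiled(q[:, :256], k[:, :256]), q[:, :256] @ k[:, :256].T)
```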
When max_length exceeds sliding_window (e.g., 2048 > 1024), the uniform
KV cache size is max_length but the base class SWA TKG mask was only
sliding_window-sized, causing a RuntimeError in compute_for_token_gen's
torch.where(). Override _create_windowed_attn_mask_tkg to pad the SWA
mask to match the uniform cache size, with extra slots masked out.
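The padding fix can be sketched as follows. Assumed shapes and the choice to append the masked slots at the end are illustrative; the actual `_create_windowed_attn_mask_tkg` override should be consulted for the real layout.

```python
import numpy as np

def pad_swa_mask(swa_mask, cache_len):
    """Pad a sliding-window attention mask out to the uniform KV-cache length.
    swa_mask: (batch, heads, q_len, sliding_window) boolean, True = attend.
    Extra cache slots are masked out (False) so torch.where-style selects
    never read past the window. Padding side here is an assumption."""
    b, h, ql, w = swa_mask.shape
    if w >= cache_len:
        return swa_mask
    padded = np.zeros((b, h, ql, cache_len), dtype=bool)  # new slots masked out
    padded[..., :w] = swa_mask
    return padded

mask = np.ones((1, 1, 1, 1024), dtype=bool)
out = pad_swa_mask(mask, 2048)  # e.g. max_length=2048 > sliding_window=1024
assert out.shape == (1, 1, 1, 2048)
assert not out[..., 1024:].any()  # padded region attends to nothing
```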