Add Gemma-4 support #109

Open
cszhz wants to merge 6 commits into aws-neuron:main from cszhz:main

Conversation


@cszhz cszhz commented Apr 7, 2026

Description

Adds a NeuronX Distributed Inference (NxDI) implementation of the Gemma-4 model family. All 4 variants (E2B, E4B, 31B, 26B-A4B) are supported by a single unified modeling_gemma4.py, which handles variable head_dim (256/512), Per-Layer Embeddings (PLE), hybrid sliding/full attention, Q-K-V normalization, attention_k_eq_v, MoE routing (128 experts, top-8), and logit softcapping. Audio (Conformer) and vision (ViT) encoder implementations are included for multimodal inference.
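The logit softcapping mentioned above can be sketched as follows. This is a minimal illustration of the tanh-based softcap used in Gemma-style models; the cap constant here is illustrative, not Gemma-4's actual value.

```python
import math

def softcap(logit: float, cap: float = 30.0) -> float:
    """Soft-cap a logit into (-cap, cap) via tanh.

    Near zero the function is approximately the identity; large-magnitude
    logits saturate toward +/-cap instead of growing unbounded.
    The default cap of 30.0 is a placeholder, not Gemma-4's constant.
    """
    return cap * math.tanh(logit / cap)
```

Small logits pass through nearly unchanged, while extreme logits are bounded, which stabilizes the final softmax.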

Model Information

Model Name: Gemma-4 (E2B / E4B / 31B / 26B-A4B)

Model Architecture: Decoder-only transformer (Dense + MoE variants)

Purpose: Text generation (instruction-tuned), multimodal audio/vision understanding
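The MoE routing used by the 26B-A4B variant (128 experts, top-8, per the description) can be sketched as a top-k selection over gate scores. This is a simplified illustration; the real router operates on learned gate logits with softmax-normalized expert weights.

```python
def top_k_experts(gate_logits: list[float], k: int = 8) -> list[int]:
    """Return the indices of the k highest-scoring experts.

    Illustrative sketch of top-k routing: in the 26B-A4B variant the
    gate would produce 128 logits and k would be 8.
    """
    return sorted(range(len(gate_logits)),
                  key=lambda i: gate_logits[i],
                  reverse=True)[:k]
```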

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)

    • Integration tests validating generation correctness, output coherence, and performance
    • Supports variant selection via GEMMA4_VARIANT env var (e2b, e4b, 31b, 26b)
    • Test can compile and run the model on Neuron
  • README.md with the following sections:

    • Usage Example: Compile and inference code examples
    • Compatibility Matrix: Trn2/Trn1/Inf2 table with NxDI version info
    • Example Checkpoints: All 4 HuggingFace checkpoint IDs
    • Testing Instructions: Commands for pytest and direct execution
  • Source Code (src/)

    • modeling_gemma4.py — Main model implementation following NxDI patterns
    • gemma4_audio_encoder.py — Conformer-based audio encoder
    • gemma4_vision_encoder.py — ViT-based vision encoder
    • Properly structured in the contrib folder hierarchy

Optional Components

  • Unit Tests (CPU or Neuron-based)
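The GEMMA4_VARIANT-based selection described in the checklist might look like the sketch below. The checkpoint IDs and function names are hypothetical placeholders; the actual test uses its own lookup.

```python
import os

# Hypothetical variant -> checkpoint mapping; IDs are placeholders,
# not the PR's actual HuggingFace checkpoint IDs.
VARIANT_CHECKPOINTS = {
    "e2b": "google/gemma-4-e2b-it",
    "e4b": "google/gemma-4-e4b-it",
    "31b": "google/gemma-4-31b-it",
    "26b": "google/gemma-4-26b-a4b-it",
}

def resolve_variant(default: str = "e2b") -> str:
    """Read GEMMA4_VARIANT from the environment and validate it."""
    variant = os.environ.get("GEMMA4_VARIANT", default).lower()
    if variant not in VARIANT_CHECKPOINTS:
        raise ValueError(f"Unknown GEMMA4_VARIANT: {variant!r}")
    return variant
```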

Folder Structure

/contrib/models/gemma-4/
    README.md
    /src
        __init__.py
        modeling_gemma4.py
        gemma4_audio_encoder.py
        gemma4_vision_encoder.py
    /test
        /unit
        /integration
            test_model.py
            verify_audio.py
            verify_image.py
            verify_video.py

Testing

How did you test this change?

All 4 variants compiled and tested on trn2.3xlarge (4 Neuron cores, LNC=2) with bfloat16 precision. Each variant validated with text generation prompts for output
coherence and correctness. Multimodal (audio/vision) paths verified separately against HuggingFace CPU (float32) reference outputs.
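The comparison against HuggingFace CPU (float32) reference outputs could be sketched as an elementwise tolerance check like the one below. The tolerance value and helper names are illustrative, not the PR's actual thresholds.

```python
def max_abs_diff(neuron_out: list[float], reference_out: list[float]) -> float:
    """Largest elementwise absolute difference between two output vectors."""
    assert len(neuron_out) == len(reference_out)
    return max(abs(a - b) for a, b in zip(neuron_out, reference_out))

def check_close(neuron_out: list[float], reference_out: list[float],
                atol: float = 1e-2) -> bool:
    """True if Neuron (bfloat16) outputs match the float32 CPU reference
    within an absolute tolerance. atol=1e-2 is a placeholder value."""
    return max_abs_diff(neuron_out, reference_out) <= atol
```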

Test Results:

| Variant | TP | Batch | Seq Len | Throughput | Status |
|---------|----|-------|---------|------------|--------|
| E2B | 2 | 1 | 512 | ~80-120 tok/s | ✅ Validated |
| E4B | 2 | 1 | 512 | ~70-80 tok/s | ✅ Validated |
| 31B | 4 | 1 | 2048 | ~13-23 tok/s | ✅ Validated |
| 26B-A4B | 4 | 1 | 2048 | ~43-60 tok/s | ✅ Validated |

Compatibility

Tested with:

| Item | Version |
|------|---------|
| Neuron SDK | NxDI 0.8.x (neuronx-cc 2.23.x) |
| Instance Type | Trn2 (trn2.3xlarge) |
| PyTorch | 2.9.0 |
| Python | 3.12.3 |

Additional Information

  • attn_kernel_enabled=False is required — NKI attention kernel does not support head_dim > 128 (Gemma-4 uses 256/512).
  • Sliding window attention is disabled at the Neuron level; all layers use full context length.
  • KV sharing (E2B/E4B) is disabled in v1; all layers compute their own KV cache.
  • Weights loaded directly from safetensors (HF AutoModel does not support Gemma-4 in transformers < 5.5).
  • Audio encoder compiles to fixed mel length (2048 frames, ~30s max). Vision encoder runs on CPU.
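The head_dim constraint described above could be enforced with a guard like the sketch below. The function name is hypothetical; it simply reflects the stated rule that the NKI attention kernel cannot be used when head_dim exceeds 128.

```python
def validate_attention_config(head_dim: int, attn_kernel_enabled: bool) -> None:
    """Reject configs that enable the NKI attention kernel with an
    unsupported head_dim (> 128, per the limitation noted above).
    Illustrative guard, not the PR's actual validation code."""
    if attn_kernel_enabled and head_dim > 128:
        raise ValueError(
            f"NKI attention kernel supports head_dim <= 128, got {head_dim}; "
            "set attn_kernel_enabled=False for Gemma-4 (head_dim 256/512)."
        )
```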

Related Issues

N/A

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included
