
Add Gemma4-E2B contrib model (text decoder + VLM)#115

Draft
jimburtoft wants to merge 1 commit into aws-neuron:main from jimburtoft:contrib/gemma-4-E2B

Conversation

@jimburtoft
Contributor

Summary

  • NxDI implementation of google/gemma-4-E2B (2.3B effective params)
  • Text decoder is production-ready with validated accuracy and performance
  • Vision encoder and VLM wrapper included but VLM compilation is blocked by an internal compiler issue (NCC_ITEN404)

Model Architecture

Gemma4-E2B is a decoder-only transformer with several novel features for NxDI:

  • Per-Layer Embeddings (PLE): Shared embedding table [262144, 8960] with per-layer gated projections
  • KV Cache Sharing: 20 of 35 layers reuse K/V from donor layers
  • Heterogeneous Layers: SWA (head_dim=256) and full attention (head_dim=512) with different MLP widths
  • Vision Encoder: 16-layer ViT producing 280 soft tokens per image
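To make the KV cache sharing feature concrete, here is a minimal, self-contained sketch of the idea. The donor map below is illustrative only (the actual Gemma4-E2B layer-to-donor assignment lives in `modeling_gemma4_e2b.py` and may differ); it simply shows how 20 of 35 layers can reuse K/V from earlier layers so that fewer caches are allocated:

```python
# Hypothetical sketch of KV-cache sharing across decoder layers.
# The donor map here is illustrative, not the actual Gemma4-E2B layout.

NUM_LAYERS = 35
NUM_UNIQUE_KV = 15  # 20 of 35 layers reuse K/V from a donor layer

# Example donor map: layers 0..14 own their K/V; layers 15..34 all
# reuse the K/V produced by layer 14.
donor_of = {i: (i if i < NUM_UNIQUE_KV else NUM_UNIQUE_KV - 1)
            for i in range(NUM_LAYERS)}

def kv_cache_slots(donor_map):
    """Number of distinct K/V caches actually allocated."""
    return len(set(donor_map.values()))

print(kv_cache_slots(donor_of))  # 15 caches instead of 35
```

Sharing layers index into their donor's cache at attention time, so cache memory scales with the number of unique K/V producers rather than the total layer count.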

Text Decoder Results (TP=1, batch=1, trn2.3xlarge)

| Metric | Value |
| --- | --- |
| BF16 + KV Sharing Cosine | 0.999999 |
| TTFT (bucket=128) | 27.3 ms |
| TPOT | 10.4 ms |
| Throughput | 96 tok/s |
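As a quick consistency check (not part of the PR's test suite), the reported throughput should be roughly the reciprocal of TPOT at batch=1:

```python
# Decode throughput is approximately 1 / TPOT at batch size 1.
tpot_ms = 10.4                     # from the table above
throughput = 1000.0 / tpot_ms      # tokens per second
print(round(throughput, 1))        # 96.2, consistent with the reported ~96 tok/s
```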

Known Limitation

VLM (text + vision) compilation fails with NCC_ITEN404 in neuronx-cc 2.23. The error occurs in the TensorInitialization tensorizer pass when compiling the context encoding NEFF with vision inputs. Text-only inference is unaffected. The VLM code is included and architecturally complete -- ready to enable once the compiler issue is resolved.

Files

| File | Description | Lines |
| --- | --- | --- |
| modeling_gemma4_e2b.py | Text decoder (PLE, KV sharing, heterogeneous attention) | ~1750 |
| modeling_gemma4_e2b_vlm.py | VLM wrapper (vision + text pipeline) | ~857 |
| modeling_gemma4_vision.py | Vision encoder (config-driven, shared with 31B) | ~770 |
| ndxi_patch.py | NxDI compatibility patches | ~376 |
| test_model.py | Integration tests (text-only, 6 tests) | ~340 |

Testing

```shell
# On a trn2.3xlarge with model weights at /mnt/models/gemma-4-E2B/
pytest contrib/models/gemma-4-E2B/test/integration/test_model.py --capture=tee-sys
```

NxDI implementation of google/gemma-4-E2B (2.3B effective params) with:
- Text decoder: PLE, KV cache sharing, heterogeneous SWA/global attention
- Vision encoder and VLM wrapper (compilation blocked by NCC_ITEN404)
- NxDI 0.7/0.8 compatibility helpers
- Integration tests (text-only)
