NeuronX Distributed Inference implementation of the google/gemma-4-31b-it text decoder, with heterogeneous SWA/global attention layers, QK/V norms, `attention_k_eq_v`, `layer_scalar`, and `final_logit_softcapping`. Validated on trn2.3xlarge (TP=4, bf16): 32.1 tok/s, TTFT 65 ms, Pearson r = 0.980 vs. the HF CPU reference.
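Two of the decoder-side features named above can be sketched minimally. The alternating SWA/global layout shown is a hypothetical stand-in (the real per-layer pattern comes from the model config), while the softcapping formula is the `30 * tanh(logits / 30)` shape described later in this PR:

```python
import math

def layer_attention_type(layer_idx: int, global_period: int = 6) -> str:
    # Hypothetical heterogeneous layout: one global-attention layer per
    # `global_period` layers, the rest sliding-window (SWA). The real
    # assignment is read from the model config, not computed like this.
    return "global" if (layer_idx + 1) % global_period == 0 else "sliding"

def soft_cap(logit: float, cap: float = 30.0) -> float:
    # final_logit_softcapping: cap * tanh(logits / cap), applied after
    # lm_head; squashes every logit into (-cap, cap) smoothly.
    return cap * math.tanh(logit / cap)
```

Softcapping is monotonic, so greedy decoding is unchanged; it only bounds the logit magnitudes before sampling.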
New files:
- `modeling_gemma4_vision.py`: Vision encoder (27 layers, 2D RoPE, spatial pooler, QK/V norms, multi-modal projector). Verified: cosine 0.9995 vs. the HF reference on CPU, and cosine 0.9995 vs. CPU on Neuron (trn2, 18.5 ms/image at 384x384).
- `modeling_gemma4_vlm.py`: VLM orchestrator (`NeuronGemma4ForConditionalGeneration`), config (`Gemma4VLMInferenceConfig`), vision wrapper (`Gemma4VisionModelWrapper`), state dict conversion for combined text+vision weights, and a `load_pretrained_config` helper that replaces the broken `hf_adapter` import.
- `ndxi_patch.py`: NxDI monkey-patches for the SWA gather `.long()` fix, `tensor_capture_hook`, and the `hf_adapter` `SampleDecoderOnlyOutput` rename.

Modified:
- `modeling_gemma4.py`: Added `encode_vision_to_input` and `scatter_by_index_put` methods to `NeuronGemma4TextModel` for vision token merging.
- `__init__.py`: Export VLM classes.
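The `scatter_by_index_put` merge can be pictured with a small framework-agnostic sketch; NumPy fancy indexing stands in here for torch's `index_put_`, and the function name and argument layout are illustrative, not the real method signature:

```python
import numpy as np

def scatter_vision_embeds(input_embeds, vision_embeds, vision_positions):
    # Overwrite the embedding rows at the image-token positions with the
    # projected vision embeddings, leaving text-token rows untouched.
    # input_embeds: (seq_len, hidden), vision_embeds: (n_img_tokens, hidden),
    # vision_positions: (n_img_tokens,) indices of image placeholder tokens.
    out = input_embeds.copy()
    out[vision_positions] = vision_embeds
    return out
```

In the real model the scatter happens on-device inside `NeuronGemma4TextModel`, so the merged sequence flows through the traced decoder unchanged.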
- Sampling module import: try the new path first, fall back to the old one.
- `load_pretrained_config`: JSON fallback when `AutoConfig` fails on the unrecognized `gemma4` `model_type`; set `_name_or_path`.
- Set missing config attrs (`output_attentions`, `pad_token_id`, etc.).
- Remove the int64->int32 conversion in the vision wrapper (the NEFF was traced with int64).
- Fix vision weight key prefixes (no `vision_encoder.` prefix).
- Add a 2D->3D reshape for the vision encoder output.
- Bypass `_get_model_outputs` entirely in the `ndxi_patch` forward: call `context_encoding_model`/`token_generation_model` directly with 24 args to avoid the `deepstack_vision_embeds` mismatch.
- Create properly shaped zero vision tensors for text-only prefill.

All three E2E tests pass: text prefill, token generation, and vision prefill.
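The `load_pretrained_config` fallback might look roughly like this sketch; the specific attribute defaults filled in at the end (e.g. `pad_token_id`) are illustrative assumptions, not the exact set the PR patches:

```python
import json
import os

def load_pretrained_config(model_path):
    # Try AutoConfig first; fall back to reading config.json directly when
    # transformers does not recognize the gemma4 model_type.
    try:
        from transformers import AutoConfig
        cfg_dict = AutoConfig.from_pretrained(model_path).to_dict()
    except Exception:
        with open(os.path.join(model_path, "config.json")) as f:
            cfg_dict = json.load(f)
    cfg_dict["_name_or_path"] = model_path
    # Fill attrs NxDI expects but a raw config.json may omit
    # (illustrative defaults).
    cfg_dict.setdefault("output_attentions", False)
    cfg_dict.setdefault("pad_token_id", 0)
    return cfg_dict
```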
Custom NKI kernel tiles the QK contraction dimension in 128-element chunks, enabling flash attention for Gemma4's d=256 (SWA) and d=512 (global) layers. Validated end-to-end with cosine > 0.999 standalone and correct greedy output.
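The tiling trick relies on matrix multiplication being a sum over the contraction axis, so accumulating 128-wide chunks is exactly equal to the full product. A NumPy sketch of the accumulation (the real kernel is written in NKI against hardware tiles, not NumPy):

```python
import numpy as np

def tiled_qk_scores(q, k, tile=128):
    # Accumulate Q @ K^T over `tile`-wide chunks of the head dimension,
    # mirroring how the kernel splits the contraction axis so each matmul
    # fits the 128-element hardware tile. q: (n_q, d), k: (n_k, d).
    d = q.shape[-1]
    assert d % tile == 0, "head_dim must be a multiple of the tile size"
    scores = np.zeros((q.shape[0], k.shape[0]), dtype=q.dtype)
    for start in range(0, d, tile):
        scores += q[:, start:start + tile] @ k[:, start:start + tile].T
    return scores
```

For d=256 (SWA) this runs two chunks and for d=512 (global) four, with bit-for-bit the same reduction structure either way.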
When `max_length` exceeds `sliding_window` (e.g., 2048 > 1024), the uniform KV cache size is `max_length`, but the base class SWA TKG mask was only `sliding_window`-sized, causing a RuntimeError in `compute_for_token_gen`'s `torch.where()`. Override `_create_windowed_attn_mask_tkg` to pad the SWA mask to match the uniform cache size, with the extra slots masked out.
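The shape fix can be illustrated with a small NumPy sketch on a boolean mask; here the padding is appended at the end as an illustrative assumption, whereas the real override must place it to match the cache layout:

```python
import numpy as np

def pad_swa_mask_to_cache(swa_mask, cache_size):
    # swa_mask: (..., sliding_window) boolean attention mask.
    # Returns (..., cache_size) with the extra cache slots masked out
    # (False), so a torch.where-style select sees matching shapes.
    window = swa_mask.shape[-1]
    pad = np.zeros(swa_mask.shape[:-1] + (cache_size - window,), dtype=bool)
    return np.concatenate([swa_mask, pad], axis=-1)
```

Because the padded slots are all False, they can never be selected, so attention results for positions inside the window are unchanged.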
Summary
Architecture Highlights
Gemma 4 31B has several features not present in the existing Gemma 3 contrib:
- `final_logit_softcapping`: `30 * tanh(logits / 30)` after `lm_head`.
- Attention scaling: Gemma 4 uses `scaling=1.0` instead of the usual `1/sqrt(head_dim)` scaling in attention. Naively scaling the Q weights doesn't work because `q_norm` (RMSNorm) is scale-invariant. The fix is to scale the `q_layernorm.weight` parameters by `sqrt(head_dim)` instead, which survives normalization.

Test Results
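The scale-invariance argument is easy to check numerically. A NumPy sketch of RMSNorm shows both halves: scaling the input (i.e., the Q projection weights) is normalized away, while scaling the norm's own weight survives:

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: x / rms(x) * weight. Any scalar factor on x cancels
    # (up to eps), but a factor on `weight` passes straight through.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight

head_dim = 256
x = np.random.default_rng(0).standard_normal(head_dim)
w = np.ones(head_dim)

# Scaling the input is a no-op: this is why scaling Q weights fails.
assert np.allclose(rms_norm(x * np.sqrt(head_dim), w), rms_norm(x, w))

# Scaling the norm weight survives, so folding sqrt(head_dim) into
# q_layernorm.weight cancels the framework's 1/sqrt(head_dim) factor.
assert np.allclose(rms_norm(x, w * np.sqrt(head_dim)),
                   np.sqrt(head_dim) * rms_norm(x, w))
```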
All 7 integration tests pass on trn2.3xlarge (TP=4, batch_size=1, bfloat16):
Files
Prerequisites
- `transformers >= 5.5.0` (required for `Gemma4ForConditionalGeneration`)
- `transformers/utils/fx.py` shim needed (NxD imports `HFTracer`, which was removed in transformers 5.x; see README for details)
- `attn_kernel_enabled=False` required (head_dim > 128 exceeds the NKI flash attention limit)

Known Limitations
The vision component (`gemma4_vision`) is not yet ported. A follow-up PR will add VLM (vision + language) support.

Checklist