Add MiniMax-M2 and MiMo-V2-Flash model support #119
Open
whn09 wants to merge 11 commits into aws-neuron:main from
Conversation
- MiniMax-M2: Custom MoE (62 layers, 256 experts, top-8, sigmoid router, QK norm, partial RoPE, fused_qkv). TP=64 on trn2.48xlarge.
- MiMo-V2-Flash: Custom MoE (48 layers, 256 experts, top-8, hybrid attention with full + sliding window, asymmetric head dims Q/K=192 V=128, attention sink bias). TP=64, EP=64 on trn2.48xlarge.

Both models include:
- Model implementations in src/neuronx_distributed_inference/models/
- Contrib wrappers following the standard NxDI pattern
- Integration tests
- READMEs with architecture details and usage

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
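The partial RoPE noted for MiniMax-M2 above (rotary_dim=64 on head_dim=128) rotates only the first rotary_dim channels of each head and passes the rest through unchanged. A minimal plain-Python sketch, assuming the common half-split channel pairing and base 10000 — both are illustrative defaults, not values read from the NxDI code:

```python
import math

def partial_rope(vec, pos, rotary_dim=64):
    """Apply rotary position embedding to the first rotary_dim channels
    of a single head vector; channels [rotary_dim:] pass through untouched.
    Illustrative sketch only (half-split pairing, base 10000 assumed)."""
    out = list(vec)
    half = rotary_dim // 2
    for i in range(half):
        # Standard RoPE frequency schedule over the rotated sub-dimension.
        theta = pos * (10000.0 ** (-2.0 * i / rotary_dim))
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + half]
        out[i] = x * c - y * s
        out[i + half] = x * s + y * c
    return out
```

With head_dim=128 and rotary_dim=64, positions only affect the first 64 channels; the trailing 64 channels carry position-independent content, which is what the fused NKI kernel has to reproduce.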
MiMoV2InferenceConfig requires 27 attributes at init time, so the test now checks get_required_attributes() without instantiating the config. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Author
End-to-End Test Results on trn2.48xlarge

Both models have been compiled and tested on a trn2.48xlarge instance (32 NeuronCores, logical_nc_config=2 → 64 logical cores) using NxDI 2.22+ (PyTorch 2.9).

MiMo-V2-Flash (XiaomiMiMo/MiMo-V2-Flash)

Configuration:
Result: ✅ Passed — generates coherent text output

MiniMax-M2 (MiniMax/MiniMax-M2)

Configuration:
Important: Must NOT enable

Result: ✅ Passed — generates correct, coherent output with chat template

Sample outputs:

Test Environment
Benchmark Results: MiMo-V2-Flash & MiniMax-M2 on trn2.48xlarge

Environment
Configuration
Results
Notes
- Add vllm-neuron patch for MiMo/MiniMax architecture support
- Add benchmark scripts for MiMo-V2-Flash and MiniMax-M2 (multiple BS/EP configs)
- Add setup script for vllm-neuron installation and model weight download
- Update READMEs with vLLM serving instructions and patch documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The upstream merge added required attribute validation in InferenceConfig.__init__. The test now provides a proper HF config via load_pretrained_config(hf_config=...). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add vLLM serving performance tables from trn2.48xlarge benchmarks:
- MiMo-V2-Flash: BS=32/EP=64 with c=1/16/32 (up to 302.61 tok/s)
- MiniMax-M2: Config 1 (BS=1/EP=1) and Config 2 (BS=256/EP=64) with c=1/16/32/128/256
- Add note about VLLM_ENGINE_READY_TIMEOUT_S=3600 for large MoE models

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…and router bias

Key improvements over the previous v3 implementation:
- Add nki-library attention_block_tkg kernel override with partial RoPE (rotary_dim=64)
- Fix QK normalization: use Neuron-native RmsNorm per rank instead of a hand-rolled all-reduce (which compiled differently in CE vs TG NEFFs)
- Preserve e_score_correction_bias as an nn.Parameter in RouterTopKWithBias (dropping the bias causes ~75% wrong expert selection, since bias values of ~8-9 dominate sigmoid scores in [0, 1])
- Fix fused QKV state dict ordering: convert_state_dict_to_fused_qkv must run BEFORE the qkv_proj key rename, not after
- Add KV cache rank-3 to rank-4 reshape in the NKI kernel override
- Fix NKI grid syntax for SDK 2.29 (plain int instead of an nc() tuple)
- Drop the V3 suffix from class names (MiniMaxM2InferenceConfig, NeuronMiniMaxM2ForCausalLM)
- Simplify the contrib wrapper to a direct re-export
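The role of the correction bias can be sketched in plain Python. This is an illustrative model of bias-steered selection, not the NxDI code (function name and the score/weight split are assumptions): experts are selected on sigmoid(logit) + bias, while the combine weights use the unbiased sigmoid scores, so dropping the bias changes which experts get picked at all.

```python
import math

def route_top_k(logits, bias, k=8):
    """Illustrative sigmoid router with a per-expert correction bias.
    Selection uses score + bias; combine weights use the raw scores."""
    # Gate scores in (0, 1) from a sigmoid over the router logits.
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    # With bias values around 8-9 against scores in (0, 1), the bias
    # dominates selection -- which is why dropping it mis-routes tokens.
    biased = [s + b for s, b in zip(scores, bias)]
    chosen = sorted(range(len(logits)), key=lambda i: biased[i], reverse=True)[:k]
    # Combine weights come from the unbiased scores, renormalized over top-k.
    total = sum(scores[i] for i in chosen)
    return chosen, {i: scores[i] / total for i in chosen}
```

Because selection runs on the biased scores, an XLA pass that eliminates the bias add (e.g. when the parameter folds to a constant) silently reroutes most tokens, which matches the ~75% wrong-expert figure cited above.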
Pass per-rank QK norm weights to the nki-library kernel's new rmsnorm_QK_flat_* parameters, which normalize Q and K across all heads concatenated (before the head split) rather than per head. The per-rank weight slice is extracted via torch.index_select on rank_util.rank for SPMD-compatible tracing, matching the approach used in the non-NKI code path. This should fix the NKI kernel's output quality degradation: QK norm was previously skipped entirely in the NKI path because the kernel only supported per-head norm.
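The per-rank slice amounts to taking this rank's contiguous block of the flat norm weight. A hypothetical sketch (the real code uses torch.index_select over a rank-dependent index so the slice traces correctly under SPMD; plain slicing stands in for that here, and the function name is an assumption):

```python
def per_rank_norm_weight(flat_weight, rank, tp_degree):
    """Slice the flat [num_heads * head_dim] QK-norm weight down to the
    contiguous block of heads owned by this tensor-parallel rank.
    Illustrative stand-in for the index_select-based traced version."""
    shard = len(flat_weight) // tp_degree
    return flat_weight[rank * shard:(rank + 1) * shard]
```

Using index_select on rank_util.rank instead of Python slicing matters because the rank is a traced tensor, so one compiled graph serves all 64 ranks.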
- Fix _helper_concat_and_delete_qkv to use the 'self_attn.qkv_proj.Wqkv' key path instead of 'self_attn.Wqkv' to match the NxDI model parameter hierarchy
- Add an --enable-nki-attention flag to the inference script for toggling the NKI kernel
- Set top_k=1 in GenerationConfig to match OnDeviceSamplingConfig global_topk=1
- Add the attn_block_tkg_nki_kernel_cache_update flag for in-kernel KV cache update
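The fused-QKV ordering constraint from the commit above can be shown with a toy state dict. This is an illustrative sketch, not the NxDI helper: the key prefix, the function name, and row-wise concatenation are assumptions; the point is that the fuse step looks up q/k/v under their original names, so it must run before any rename that moves them under qkv_proj.

```python
def fuse_then_rename(sd, prefix="model.layers.0.self_attn."):
    """Fuse q/k/v projection weights into a single Wqkv entry under the
    qkv_proj path. Must run on the ORIGINAL key names: renaming first
    would leave the q/k/v lookups below with nothing to find."""
    out = dict(sd)
    q = out.pop(prefix + "q_proj.weight")
    k = out.pop(prefix + "k_proj.weight")
    v = out.pop(prefix + "v_proj.weight")
    # Fused weight lands at self_attn.qkv_proj.Wqkv, matching the NxDI
    # parameter hierarchy noted in the fix above.
    out[prefix + "qkv_proj.Wqkv.weight"] = q + k + v  # row-wise concat (lists here)
    return out
```

In the real checkpoint these are tensors concatenated along the output dimension; lists keep the sketch dependency-free.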
MiniMax-M2: NKI attention kernel, correct QK norm, and router bias
Add imports and MODEL_TYPES entries for MiMo-V2 and MiniMax-M2 to enable vllm-neuron model discovery. Fixes ImportError after PR #7 renamed NeuronMiniMaxM2ForCausalLMV3 to NeuronMiniMaxM2ForCausalLM. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Description
Add NxDI contrib model adapters for MiniMax-M2 and MiMo-V2-Flash, two large-scale Mixture-of-Experts (MoE) models. Both require trn2.48xlarge with 64 logical cores (LNC=2) and feature custom MoE routing, partial RoPE, and architecture-specific attention mechanisms. Includes vLLM-neuron integration with benchmark scripts and performance results.
Model Information
Model Name: MiniMax-M2, MiMo-V2-Flash
Model Architecture: Decoder-only MoE transformer
Purpose: Text generation
Checklist
Required Components
- Accuracy Test (test/integration/test_model.py)
- README.md with the following sections:
- Source Code (src/)
  - models/minimax_m2/ — MiniMax-M2: NKI attention kernel, Neuron-native QK norm, sigmoid router with e_score_correction_bias, fused_qkv, partial RoPE
  - models/mimo_v2/ — MiMo-V2-Flash: hybrid attention (full + sliding window), asymmetric head dims, attention sink bias
  - models/mimo_v2/conversion_script/ — FP8 to BF16 weight converter
  - utils/constants.py — register both models in MODEL_TYPES for vllm-neuron discovery

Optional Components
- test/integration/

Folder Structure
Testing
How did you test this change?
Tested on trn2.48xlarge with TP=64, LNC=2, Neuron SDK 2.22+.
MiniMax-M2:
MiMo-V2-Flash:
Standalone NxDI (BF16, TP=64, EP=64):
vLLM Serving (BS=32, TP=64/EP=64, CB, 900/90 tokens):
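MiMo-V2-Flash's hybrid attention mixes full-attention layers with sliding-window layers. A minimal sketch of the per-layer causal mask in plain Python (the window size and function name are illustrative, not taken from the model config):

```python
def attention_mask(seq_len, sliding_window=None):
    """Build a boolean causal mask. With sliding_window set, query i may
    only attend to keys in (i - sliding_window, i]; otherwise it sees
    every key at or before its own position. Illustrative sketch."""
    mask = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            visible = k <= q  # causal: no attending to the future
            if sliding_window is not None:
                visible = visible and (q - k < sliding_window)
            row.append(visible)
        mask.append(row)
    return mask
```

Full-attention layers would use attention_mask(n) while sliding-window layers use attention_mask(n, w); the attention sink bias mentioned in the feature list sits outside this mask and is not modeled here.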
Compatibility
Tested with:
Additional Information
MiniMax-M2 key features:
- attention_block_tkg with partial RoPE support (rotary_dim=64, head_dim=128) and flat QK RMSNorm fused into the kernel
- RmsNorm.apply (AwsNeuronRmsNorm custom call) instead of hand-rolled PyTorch ops, which compiled into different HLO in CE vs TG NEFFs
- e_score_correction_bias preserved as an nn.Parameter (not register_buffer) with non-uniform init to prevent XLA optimization from eliminating the add operation
- convert_state_dict_to_fused_qkv runs before the key rename to match expected state dict key paths

MiMo-V2-Flash key features:
Related Issues
N/A
vLLM Integration
- (perf_test/ directory)
- (perf_test/vllm-neuron-mimo-minimax.patch)

By submitting this PR, I confirm that: