Add Voxtral Mini 3B contrib model (audio-language) #125

Open
jimburtoft wants to merge 2 commits into aws-neuron:main from jimburtoft:contrib/voxtral-mini-3B
Conversation

@jimburtoft
Contributor

Summary

  • Adds Mistral AI's Voxtral Mini 3B audio-language model to contrib
  • Decomposed pipeline: Whisper-like encoder via torch_neuronx.trace(), Llama LLM via NxDI ImageToTextModelWrapper
  • Supports text-only generation, audio transcription, and audio understanding

Model Details

  • Source: mistralai/Voxtral-Mini-3B-2507 (Apache 2.0)
  • Architecture: 637M audio encoder (32L) + 25M projector + 3.3B Llama LLM (30L)
  • Precision: BF16
  • Instance: trn2.3xlarge (TP=1) and inf2.xlarge (TP=1)

Performance

Instance        Throughput                                TTFT
trn2.3xlarge    58.5 tok/s                                418 ms
inf2.xlarge     28.4 tok/s (text), 27.4 tok/s (audio)     --

Key Technical Contributions

  • Reuses NxDI Llama: VoxtralTextModel extends NeuronLlamaModel -- no custom attention/MLP
  • Scatter-based audio injection: Uses NxDI's scatter_by_index_put for audio token embedding
  • PixtralInferenceConfig reuse: Voxtral's encoder+projector+LLM maps to Pixtral's vision+LLM config pattern
  • Hardware-aware compiler args: Auto-detects trn2 for --lnc=2 flag
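
The scatter-based injection above can be illustrated with a minimal pure-Python sketch. NxDI's actual scatter_by_index_put operates on device tensors via index_put; the function and argument names here are hypothetical, chosen only to show the semantics:

```python
def scatter_audio_embeddings(text_embeds, audio_embeds, is_audio_token):
    """Replace the embeddings at audio placeholder positions with the
    projector's audio embeddings, in order. A pure-Python sketch of the
    index_put-style scatter NxDI performs on device."""
    out = list(text_embeds)
    audio_iter = iter(audio_embeds)
    for pos, flag in enumerate(is_audio_token):
        if flag:
            out[pos] = next(audio_iter)
    return out
```

For example, with placeholder tokens at positions 1 and 2, the two audio embeddings land exactly there while text embeddings are untouched.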

Testing

  • 9 files: README.md, benchmark_encoder.py, src/{__init__,modeling_voxtral,utils/__init__}.py, test/{__init__,integration/{__init__,test_model},unit/__init__}.py
  • Integration tests: model compilation, text generation (correctness + determinism), audio transcription
  • Validated on trn2.3xlarge and inf2.xlarge with SDK 2.28

Mistral AI's Voxtral Mini 3B audio-language model on Neuron
(Trainium2/Inferentia2) using a decomposed pipeline:
- Audio encoder: torch_neuronx.trace() with inline_weights_to_neff=False
- Projector: CPU (25M params)
- LLM backbone: NxDI ImageToTextModelWrapper (Llama 3.3B, TP=1)
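
The encoder compilation step could look roughly like the following. This is a sketch, not the PR's code: the function and path names are illustrative, and only the flags quoted in the commit messages are shown.

```python
def trace_audio_encoder(encoder, example_mel_features, output_path="encoder.pt"):
    """Sketch: compile a Whisper-like audio encoder with torch_neuronx.trace.

    inline_weights_to_neff=False keeps the weights outside the compiled
    NEFF artifact, so they can be reloaded without recompilation.
    """
    import torch
    import torch_neuronx

    traced = torch_neuronx.trace(
        encoder,
        example_mel_features,
        compiler_args=["-O1", "--lnc=2"],  # flags from the commit message (trn2)
        inline_weights_to_neff=False,
    )
    torch.jit.save(traced, output_path)
    return traced
```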

Supports text-only generation, audio transcription, and audio
understanding. Validated at 58.5 tok/s on trn2.3xlarge and
28.4 tok/s on inf2.xlarge.

- Add -O1, tensorizer options, and --lnc=2 to audio encoder trace
  (matched to decoder optimization level)
- Fix neuron-ls trn2 detection: use plain 'neuron-ls' instead of
  '--json-output', whose output does not include the instance type string
- Add benchmark_encoder.py for component-level latency measurement
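
The detection fix and the hardware-aware compiler args might be sketched as below. This is a minimal illustration under stated assumptions: it greps plain neuron-ls text for "trn2" (the JSON mode omits the instance type, per the fix above), and the helper names are hypothetical. The unspecified "tensorizer options" are deliberately left out.

```python
import subprocess

def detect_trn2(neuron_ls_output=None):
    """Return True if plain `neuron-ls` output mentions a trn2 instance.

    Accepts pre-captured output for testing; otherwise runs the CLI.
    """
    if neuron_ls_output is None:
        try:
            neuron_ls_output = subprocess.run(
                ["neuron-ls"], capture_output=True, text=True
            ).stdout
        except FileNotFoundError:
            return False
    return "trn2" in neuron_ls_output

def encoder_compiler_args(on_trn2):
    # Match the decoder's optimization level; --lnc=2 only applies on trn2.
    args = ["-O1"]
    if on_trn2:
        args.append("--lnc=2")
    return args
```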

Benchmark results (SDK 2.28, trn2.3xlarge): encoder 224ms with both
minimal and optimized flags -- SDK 2.28 already fully optimizes the
encoder trace. Projector trace to Neuron saves 2ms (3ms to 1ms).