
Add ZAYA1-base contrib model (MoE with CCA attention)#127

Open
jimburtoft wants to merge 2 commits into aws-neuron:main from jimburtoft:contrib/zaya1-base

Conversation

@jimburtoft
Contributor

Summary

  • Adds Zyphra's ZAYA1-base (8.84B total / 800M active, MoE with CCA attention) to contrib
  • Full NxDI implementation with custom CCA attention, non-linear MLP router, and MoD (Mixture of Depths)
  • Includes vLLM-neuron serving support with near-linear throughput scaling

Model Details

  • Source: Zyphra/ZAYA1-base (Apache 2.0)
  • Architecture: 80 layers alternating attention (CCA) and MoE, 16 real + 1 skip expert, partial RoPE
  • Parameters: 8.84B total / 800M active per token
  • Instance: trn2.3xlarge (TP=2)
  • Prerequisite: Zyphra's custom transformers fork (pip install transformers @ git+https://github.com/Zyphra/transformers.git@zaya)

Performance

| Metric | Neuron trn2 (TP=2) | GPU A10G | Advantage |
|---|---|---|---|
| TKG (batch=1) | 22.3 tok/s | 6.2 tok/s | 3.6x |
| TKG (batch=4) | 86.3 tok/s | 23.9 tok/s | 3.6x |
| TTFT | 75.6 ms | 184.7 ms | 2.4x |
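As a quick sanity check on the table, the per-token decode latencies quoted in the commit message follow directly from the TKG throughput numbers:

```python
# Per-token decode latency implied by the TKG throughput figures above.
def ms_per_token(tokens_per_second: float) -> float:
    return 1000.0 / tokens_per_second

print(round(ms_per_token(22.3), 1))   # batch=1 -> 44.8 ms/token
print(round(ms_per_token(86.3), 2))   # batch=4 -> 11.59 ms/token
```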

vLLM serving: 72.3 tok/s at concurrency=4 with near-linear scaling.

Key Technical Contributions

  • ManualConv1d: Replaces nn.Conv1d to avoid NCC_ITEN404 crash after all-gather
  • SPMDRank: Per-rank tensor extraction for SPMD tracing compatibility
  • CCA state caching: Conv states and previous hidden states persisted via input_output_aliases
  • Static expert dispatch: Mask-based dispatch replacing unsupported torch.bincount
  • SDK 2.29 validated: tests pass at 21.1 tok/s
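To illustrate the ManualConv1d workaround, here is a minimal sketch of a causal depthwise 1-D convolution built from explicit shifted multiply-adds rather than nn.Conv1d. This is a hypothetical reconstruction of the pattern described above, not the contrib code itself; the class name matches the PR, but the shapes and causal-padding convention are assumptions.

```python
import torch

class ManualConv1d(torch.nn.Module):
    """Causal depthwise 1-D convolution without nn.Conv1d, expressed as
    explicit shifted multiply-adds (sketch of the workaround pattern)."""
    def __init__(self, channels: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        # One kernel per channel (depthwise), same layout as Conv1d groups=channels.
        self.weight = torch.nn.Parameter(torch.randn(channels, kernel_size) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, channels, seq_len]; left-pad so the conv stays causal.
        x = torch.nn.functional.pad(x, (self.kernel_size - 1, 0))
        out = torch.zeros_like(x[..., self.kernel_size - 1:])
        for k in range(self.kernel_size):
            # weight[:, k] broadcasts over batch and sequence positions.
            out = out + self.weight[:, k].unsqueeze(0).unsqueeze(-1) * x[..., k : k + out.shape[-1]]
        return out
```

Unrolling the (typically small) kernel into plain elementwise ops keeps the graph free of the Conv1d lowering that reportedly triggered the NCC_ITEN404 crash after all-gather.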

Testing

  • 7 files: README.md, src/{__init__,modeling_zaya}.py, test/{__init__,integration/{__init__,test_model},unit/__init__}.py
  • Integration tests: smoke test, prefill accuracy (" Paris"), batch independence (4 prompts), token matching
  • Validated on trn2.3xlarge with SDK 2.27, 2.28, and 2.29

jimburtoft and others added 2 commits March 11, 2026 19:59
…han GPU)

NxDI contrib for Zyphra/ZAYA1-base -- a novel 800M active / 8.84B total
parameter MoE with Compressed Convolutional Attention (CCA), non-linear
MLP router with Mixture of Depths, and partial RoPE.

Key results on trn2.3xlarge (TP=2, BF16):
- batch=1: 22.3 tok/s (44.8 ms/token, TTFT 75.6ms)
- batch=4: 86.3 tok/s (11.59 ms/token)
- vLLM: 72.3 tok/s at concurrency=4
- 3.6x faster than NVIDIA A10G GPU at all batch sizes

Novel NxDI patterns:
- ManualConv1d replacing nn.Conv1d (avoids NCC_ITEN404 compiler crash)
- CCA conv_state + prev_hs persistence via input_output_aliases
- CTE/TKG branching for extended state management
- Non-linear 3-layer MLP router with EDA and MoD skip expert
- SPMDRank for per-rank extraction during SPMD tracing
- Static XLA-compatible expert dispatch (mask-based, no bincount)
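The mask-based dispatch pattern can be sketched roughly as follows. This is an illustrative reconstruction (function name, top-k handling, and shapes are assumptions, not the contrib implementation): every tensor keeps a fixed shape under tracing, unlike torch.bincount, whose data-dependent behavior is unsupported.

```python
import torch

def static_expert_dispatch(router_logits: torch.Tensor, num_experts: int, top_k: int = 1):
    """Count tokens per expert via a fixed-shape one-hot mask instead of
    torch.bincount, keeping all shapes static for XLA compilation."""
    topk_idx = router_logits.topk(top_k, dim=-1).indices           # [tokens, top_k]
    mask = torch.nn.functional.one_hot(topk_idx, num_experts)      # [tokens, top_k, E]
    expert_mask = mask.sum(dim=1)                                  # [tokens, E]
    tokens_per_expert = expert_mask.sum(dim=0)                     # [E], static shape
    return expert_mask, tokens_per_expert
```

The per-token expert mask can then drive a dense gather/scatter over all experts, trading a little redundant compute for a fully static graph.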

Tests: 12 integration tests (smoke, prefill, generation, performance)
