
Add OpenFold3 contrib model (biomolecular structure prediction) #124

Open

jimburtoft wants to merge 4 commits into aws-neuron:main from jimburtoft:contrib/openfold3

Conversation

@jimburtoft (Contributor)

Summary

  • Adds OpenFold3 (AlphaFold3 reproduction, ~330M params) biomolecular structure prediction to contrib
  • Vanilla torch_neuronx.trace() with weight replacement for 48-layer PairFormer stack
  • Two compilation strategies: monolithic (N<=256) and decomposed (N>256, up to N=2048)
  • Enables protein structure prediction at scales where GPU (A100-40GB/80GB, H100) runs out of memory

Model Details

  • Source: aqlaboratory/openfold-3 (v0.4.0, Apache 2.0)
  • Architecture: 48-layer PairFormer + 4-block MSA + 2-block template embedder + 24-block diffusion transformer
  • Parameters: ~330M (FP32; auto-cast not used because the triangular ops need full precision)
  • Instance: trn2.3xlarge (LNC=2)

Performance

| N    | Neuron (48 layers)  | CPU    | Speedup |
|------|---------------------|--------|---------|
| 128  | 0.3 s               | 2.9 s  | 9.7x    |
| 256  | 3.2 s (monolithic)  | 38.4 s | 12.1x   |
| 512  | 47.9 s (decomposed) | 228 s  | 4.8x    |
| 1024 | 200 s (decomposed)  | 1346 s | 6.7x    |
| 2048 | 920 s (decomposed)  | OOM    | N/A     |

At N=2048, the TriMul intermediate requires ~128 GB -- exceeding all single-GPU HBM sizes (A100-80GB, H100-80GB). Only Neuron decomposition or B200 (192GB) can handle this scale.
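Row chunking is how the decomposed path keeps per-call intermediates bounded: only a chunk of query rows' score tensor is live at a time. A minimal sketch of row-chunked per-row pair attention in plain PyTorch, with illustrative shapes and names only (not the PR's implementation):

```python
import torch

def row_chunked_tri_attention(q, k, v, bias, chunk=128):
    """q, k, v: [n, n, d] pair activations; attention runs along the
    second axis independently for each row. bias: [n, n] pair bias
    added to every row's scores. Chunking the row axis caps the live
    score tensor at [chunk, n, n] instead of [n, n, n]."""
    scale = q.shape[-1] ** -0.5
    outs = []
    for s in range(0, q.shape[0], chunk):
        qc, kc, vc = q[s:s + chunk], k[s:s + chunk], v[s:s + chunk]
        scores = torch.einsum('rnd,rmd->rnm', qc, kc) * scale + bias
        attn = scores.softmax(dim=-1)
        outs.append(torch.einsum('rnm,rmd->rnd', attn, vc))
    return torch.cat(outs, dim=0)
```

The chunk size trades peak memory (a `[chunk, n, n]` score buffer) against the number of compiled-kernel invocations, which is why the PR tunes it per N.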

Key Technical Contributions

  • N-range-aware decomposition: Auto-selects optimal sub-op segmentation strategy per sequence length
  • Fused TriMul: TriMulOut and TriMulIn run in a single trace at small N (30% faster)
  • Chunked MHA: Row-chunked triangular attention for N>1024 (fits in 24GB HBM/core)
  • Weight replacement: One NEFF compiled per component, 47 weight swaps for 48 PairFormer layers
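The weight-replacement pattern (compile one artifact per component, then stream each layer's weights through it) can be sketched with torch.jit.trace as a CPU stand-in; on Neuron, torch_neuronx.trace with weights kept outside the NEFF plus the SDK's weight-replacement call would play the same role. Class and layer names here are illustrative, not the PR's:

```python
import torch

class TinyPairFormerLayer(torch.nn.Module):
    """Stand-in for one PairFormer layer (illustrative only)."""
    def __init__(self, d=8):
        super().__init__()
        self.proj = torch.nn.Linear(d, d)

    def forward(self, z):
        return z + self.proj(z)

# Trace ONE template layer; all layers share this compiled graph.
template = TinyPairFormerLayer()
example = torch.randn(2, 4, 8)
traced = torch.jit.trace(template, example)

# Run 4 layers (48 in the real stack) through the single traced graph,
# swapping in each layer's weights before the call.
layers = [TinyPairFormerLayer() for _ in range(4)]
z = example
for layer in layers:
    traced.load_state_dict(layer.state_dict())  # weight swap, no recompile
    z = traced(z)

# Eager reference: running the layers directly gives the same result.
z_ref = example
for layer in layers:
    z_ref = layer(z_ref)
```

One compilation amortized over the whole stack is what makes 48 layers tractable: 47 weight swaps replace 47 extra compiles.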

Testing

  • 7 files: README.md, src/__init__.py, src/modeling_openfold3.py, test/{__init__,integration/{__init__,test_model},unit/__init__}.py
  • 11 integration tests: 6 monolithic block tests (N=128) + 5 decomposed tests (N=384)
  • All tests use neuron_allclose with cosine similarity validation
  • Validated on trn2.3xlarge with SDK 2.28
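neuron_allclose is the PR's test helper; a minimal stand-in showing the kind of cosine-similarity check such a helper layers on top of elementwise tolerance (assumed behavior, not the actual implementation):

```python
import torch

def cosine_allclose(actual, expected, rtol=1e-3, atol=1e-3, min_cos=0.999):
    """Pass if tensors are elementwise close OR nearly parallel overall.
    The cosine fallback tolerates the small uniform drift that FP32
    accumulation-order differences introduce on accelerators."""
    if torch.allclose(actual, expected, rtol=rtol, atol=atol):
        return True
    a = actual.flatten().double()
    b = expected.flatten().double()
    cos = torch.nn.functional.cosine_similarity(a, b, dim=0).item()
    return cos >= min_cos
```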

…mized chunk size

- Add 4 merged wrapper classes (TriMulFullWrapper, TriMulBmmOutputWrapper, TriAttnFullWrapper, AttnPairBiasFullWrapper) for fewer trace calls
- Rewrite DecomposedPairFormerCompiler with N-range-aware strategy:
  - N<=384: full TriMul + merged attention (7 calls/layer)
  - N<=512: merged BMM+Output + merged attention (9 calls/layer)
  - N<=1024: 3-segment TriMul + 2-segment attention (14 calls/layer)
  - N>1024: 3-segment + chunked attention (14 + 2*ceil(N/128) calls/layer)
- Change CHUNKED_ATTN_CHUNK_SIZE from 64 to 128 (8% faster at N=2048)
- Add threshold constants MERGED_TRIMUL_MAX_N, MERGED_TRIMUL_BMM_OUTPUT_MAX_N, MERGED_ATTN_MAX_N for configurable strategy boundaries
- Update README with N-range strategy table and merged segment speedups
- Update __init__.py exports with new wrappers and constants

Fuse both triangle multiplicative updates (outgoing + incoming) into a single traced model, reducing calls per layer from 7 to 5 at N<=384.

Validated on trn2.3xlarge with SDK 2.28: 11/11 tests pass (503s).
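The N-range dispatch above can be expressed as a small selection function. Threshold constant names come from the commit message; the strategy labels and call-count bookkeeping are illustrative, not the PR's actual API:

```python
import math

# Threshold constants named in the commit message
MERGED_TRIMUL_MAX_N = 384
MERGED_TRIMUL_BMM_OUTPUT_MAX_N = 512
MERGED_ATTN_MAX_N = 1024
CHUNKED_ATTN_CHUNK_SIZE = 128

def select_strategy(n):
    """Map a sequence length N to a decomposition strategy and the
    resulting number of traced-model calls per PairFormer layer."""
    if n <= MERGED_TRIMUL_MAX_N:
        return "full_trimul_merged_attn", 7
    if n <= MERGED_TRIMUL_BMM_OUTPUT_MAX_N:
        return "merged_bmm_output_merged_attn", 9
    if n <= MERGED_ATTN_MAX_N:
        return "three_segment_trimul_two_segment_attn", 14
    # Chunked attention adds 2 calls per row chunk of size 128.
    return "three_segment_trimul_chunked_attn", 14 + 2 * math.ceil(n / CHUNKED_ATTN_CHUNK_SIZE)
```

For example, `select_strategy(2048)` yields the chunked strategy with 14 + 2*16 = 46 calls per layer.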