Add OpenFold3 contrib model (biomolecular structure prediction)#124
Open
jimburtoft wants to merge 4 commits intoaws-neuron:mainfrom
Open
Add OpenFold3 contrib model (biomolecular structure prediction)#124jimburtoft wants to merge 4 commits intoaws-neuron:mainfrom
jimburtoft wants to merge 4 commits intoaws-neuron:mainfrom
Conversation
…mized chunk size - Add 4 merged wrapper classes (TriMulFullWrapper, TriMulBmmOutputWrapper, TriAttnFullWrapper, AttnPairBiasFullWrapper) for fewer trace calls - Rewrite DecomposedPairFormerCompiler with N-range-aware strategy: N<=384: full TriMul + merged attention (7 calls/layer) N<=512: merged BMM+Output + merged attention (9 calls/layer) N<=1024: 3-segment TriMul + 2-segment attention (14 calls/layer) N>1024: 3-segment + chunked attention (14+2*ceil(N/128) calls/layer) - Change CHUNKED_ATTN_CHUNK_SIZE from 64 to 128 (8% faster at N=2048) - Add threshold constants: MERGED_TRIMUL_MAX_N, MERGED_TRIMUL_BMM_OUTPUT_MAX_N, MERGED_ATTN_MAX_N for configurable strategy boundaries - Update README with N-range strategy table and merged segment speedups - Update __init__.py exports with new wrappers and constants
Fuse both triangle multiplicative updates (outgoing + incoming) into a single traced model, reducing calls per layer from 7 to 5 at N<=384. Validated on trn2.3xlarge with SDK 2.28: 11/11 tests pass (503s).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
torch_neuronx.trace()with weight replacement for 48-layer PairFormer stackModel Details
Performance
At N=2048, the TriMul intermediate requires ~128 GB -- exceeding all single-GPU HBM sizes (A100-80GB, H100-80GB). Only Neuron decomposition or B200 (192GB) can handle this scale.
Key Technical Contributions
Testing