From 0e5263c2a498a8a81d0d9f2a1f4d2538cf8406c8 Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Tue, 7 Apr 2026 10:27:14 -0700 Subject: [PATCH] enable flash attention option in cli. enable flash attention for qwen3 and qwen2 model script PiperOrigin-RevId: 895970430 --- examples/rl/grpo/gsm8k/run_qwen3.sh | 2 + .../rl/grpo/gsm8k/run_qwen3_simplereward.sh | 2 + examples/sft/mtnt/run_qwen2.5_0.5b.sh | 1 + tests/models/automodel_test.py | 44 +++++++++++++++++++ tunix/cli/base_config.yaml | 4 ++ tunix/cli/utils/model.py | 4 ++ tunix/models/automodel.py | 10 +++++ 7 files changed, 67 insertions(+) diff --git a/examples/rl/grpo/gsm8k/run_qwen3.sh b/examples/rl/grpo/gsm8k/run_qwen3.sh index 7b4daa849..7f0afafa5 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3.sh @@ -33,6 +33,8 @@ python3 -m tunix.cli.grpo_main \ model_config.model_name=${model_name} \ model_config.model_id=Qwen/${model_name} \ model_config.model_source=huggingface \ + model_config.use_flash_attention=true \ + model_config.flash_attention_block_size=256 \ model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \ model_config.mesh.shape="(2,4)" \ model_config.mesh.axis_names="('fsdp','tp')" \ diff --git a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh index 1a84f156a..4ddca5486 100644 --- a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh @@ -42,6 +42,8 @@ python3 -m tunix.cli.grpo_main \ model_config.model_name=${model_name} \ model_config.model_id=Qwen/${model_name} \ model_config.model_source=huggingface \ + model_config.use_flash_attention=true \ + model_config.flash_attention_block_size=256 \ model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \ model_config.mesh.shape="(2,4)" \ model_config.mesh.axis_names="('fsdp','tp')" \ diff --git a/examples/sft/mtnt/run_qwen2.5_0.5b.sh b/examples/sft/mtnt/run_qwen2.5_0.5b.sh index 
850b4ac19..de652607a 100755 --- a/examples/sft/mtnt/run_qwen2.5_0.5b.sh +++ b/examples/sft/mtnt/run_qwen2.5_0.5b.sh @@ -24,6 +24,7 @@ python3 -m tunix.cli.peft_main \ model_config.mesh.shape="(2,2)" \ model_config.mesh.axis_names="('fsdp','tp')" \ model_config.rng_seed=0 \ + model_config.use_flash_attention=true \ model_config.model_download_path="/tmp/models" \ tokenizer_config.tokenizer_path="Qwen/Qwen2.5-0.5B"\ tokenizer_config.tokenizer_type="huggingface" \ diff --git a/tests/models/automodel_test.py b/tests/models/automodel_test.py index 7b5e3db29..f23606f31 100644 --- a/tests/models/automodel_test.py +++ b/tests/models/automodel_test.py @@ -1,3 +1,4 @@ +import dataclasses from unittest import mock from absl.testing import absltest @@ -244,5 +245,48 @@ def test_from_pretrained_missing_model_path(self, model_source): ) + @mock.patch.object(naming, "ModelNaming", autospec=True) + @mock.patch.object(automodel, "call_model_config", autospec=True) + @mock.patch.object(automodel, "download_model", autospec=True) + @mock.patch.object(automodel, "create_model_from_safe_tensors", autospec=True) + def test_from_pretrained_with_config_overrides( + self, + mock_create_model, + mock_download_model, + mock_call_model_config, + mock_model_naming, + ): + @dataclasses.dataclass + class FakeConfig: + use_flash_attention: bool = False + flash_attention_block_size: int = 1024 + + mock_naming_info = mock.Mock() + mock_naming_info.model_family = "qwen2" + mock_naming_info.model_name = "qwen2.5-0.5b" + mock_model_naming.return_value = mock_naming_info + + mock_call_model_config.return_value = FakeConfig() + mock_download_model.return_value = "fake_path" + mesh = jax.sharding.Mesh(jax.devices(), ("devices",)) + + # Execution + automodel.AutoModel.from_pretrained( + model_id="qwen/Qwen2.5-0.5B", + mesh=mesh, + use_flash_attention=True, + flash_attention_block_size=512, + invalid_param="ignored", + ) + + # Verification + # check that create_model_from_safe_tensors was called with the 
overrides + self.assertTrue(mock_create_model.called) + called_config = mock_create_model.call_args[0][2] + self.assertTrue(called_config.use_flash_attention) + self.assertEqual(called_config.flash_attention_block_size, 512) + self.assertFalse(hasattr(called_config, "invalid_param")) + + if __name__ == "__main__": absltest.main() diff --git a/tunix/cli/base_config.yaml b/tunix/cli/base_config.yaml index 199faff04..299f144af 100644 --- a/tunix/cli/base_config.yaml +++ b/tunix/cli/base_config.yaml @@ -47,6 +47,10 @@ model_config: &base_model_config # Directory used for NNX conversion if downloaded Gemma/Gemma2 from Kaggle source. intermediate_ckpt_dir: "/tmp/intermediate_ckpt/" + + # Flash Attention configuration + use_flash_attention: false + flash_attention_block_size: 1024 ################################# LoRa ################################# # If you want to use LoRa, specify the module_path, rank, and alpha. You can also optionally specify the weight_qtype and tile_size, e.g.: #lora_config: diff --git a/tunix/cli/utils/model.py b/tunix/cli/utils/model.py index 2b4b72a00..98d8b66a5 100644 --- a/tunix/cli/utils/model.py +++ b/tunix/cli/utils/model.py @@ -126,6 +126,10 @@ def create_model( intermediate_ckpt_dir=model_config.get('intermediate_ckpt_dir'), rng_seed=model_config.get('rng_seed', 0), model_path=model_config.get('model_path'), + use_flash_attention=model_config.get('use_flash_attention', False), + flash_attention_block_size=model_config.get( + 'flash_attention_block_size', 1024 + ), ) # Handle Tokenizer Path overrides diff --git a/tunix/models/automodel.py b/tunix/models/automodel.py index 289418035..e8cf233c2 100644 --- a/tunix/models/automodel.py +++ b/tunix/models/automodel.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""AutoModel class for Tunix.""" +import dataclasses import enum import gc import importlib @@ -487,6 +488,15 @@ def from_pretrained( # pick corresponding config based on model version model_params = call_model_config(naming_info.model_name) + # Apply any model config field overrides passed via kwargs (e.g. + # use_flash_attention, flash_attention_block_size). + if dataclasses.is_dataclass(model_params): + valid_fields = {f.name for f in dataclasses.fields(model_params)} + overrides = {k: v for k, v in kwargs.items() if k in valid_fields} + if overrides: + logging.info('Applying model config overrides: %s', overrides) + model_params = dataclasses.replace(model_params, **overrides) + with mesh: model = create_model_from_safe_tensors( naming_info.model_name,