Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/rl/grpo/gsm8k/run_qwen3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ python3 -m tunix.cli.grpo_main \
model_config.model_name=${model_name} \
model_config.model_id=Qwen/${model_name} \
model_config.model_source=huggingface \
model_config.use_flash_attention=true \
model_config.flash_attention_block_size=256 \
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
model_config.mesh.shape="(2,4)" \
model_config.mesh.axis_names="('fsdp','tp')" \
Expand Down
2 changes: 2 additions & 0 deletions examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ python3 -m tunix.cli.grpo_main \
model_config.model_name=${model_name} \
model_config.model_id=Qwen/${model_name} \
model_config.model_source=huggingface \
model_config.use_flash_attention=true \
model_config.flash_attention_block_size=256 \
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
model_config.mesh.shape="(2,4)" \
model_config.mesh.axis_names="('fsdp','tp')" \
Expand Down
1 change: 1 addition & 0 deletions examples/sft/mtnt/run_qwen2.5_0.5b.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ python3 -m tunix.cli.peft_main \
model_config.mesh.shape="(2,2)" \
model_config.mesh.axis_names="('fsdp','tp')" \
model_config.rng_seed=0 \
model_config.use_flash_attention=true \
model_config.model_download_path="/tmp/models" \
tokenizer_config.tokenizer_path="Qwen/Qwen2.5-0.5B"\
tokenizer_config.tokenizer_type="huggingface" \
Expand Down
44 changes: 44 additions & 0 deletions tests/models/automodel_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
from unittest import mock

from absl.testing import absltest
Expand Down Expand Up @@ -244,5 +245,48 @@ def test_from_pretrained_missing_model_path(self, model_source):
)


@mock.patch.object(naming, "ModelNaming", autospec=True)
@mock.patch.object(automodel, "call_model_config", autospec=True)
@mock.patch.object(automodel, "download_model", autospec=True)
@mock.patch.object(automodel, "create_model_from_safe_tensors", autospec=True)
def test_from_pretrained_with_config_overrides(
    self,
    mock_create_model,
    mock_download_model,
    mock_call_model_config,
    mock_model_naming,
):
  """Kwargs matching config fields override them; unknown kwargs are dropped."""

  # Minimal stand-in for a model config dataclass with the two
  # overridable flash-attention fields.
  @dataclasses.dataclass
  class _StubConfig:
    use_flash_attention: bool = False
    flash_attention_block_size: int = 1024

  naming_info = mock.Mock()
  naming_info.model_family = "qwen2"
  naming_info.model_name = "qwen2.5-0.5b"
  mock_model_naming.return_value = naming_info
  mock_call_model_config.return_value = _StubConfig()
  mock_download_model.return_value = "fake_path"

  # Execution: pass two valid override kwargs plus one unknown kwarg.
  automodel.AutoModel.from_pretrained(
      model_id="qwen/Qwen2.5-0.5B",
      mesh=jax.sharding.Mesh(jax.devices(), ("devices",)),
      use_flash_attention=True,
      flash_attention_block_size=512,
      invalid_param="ignored",
  )

  # Verification: the config handed to create_model_from_safe_tensors
  # (third positional argument) carries the overrides, and the unknown
  # kwarg did not leak into it.
  self.assertTrue(mock_create_model.called)
  override_config = mock_create_model.call_args[0][2]
  self.assertTrue(override_config.use_flash_attention)
  self.assertEqual(override_config.flash_attention_block_size, 512)
  self.assertFalse(hasattr(override_config, "invalid_param"))


# Run the absltest runner when this file is executed directly.
if __name__ == "__main__":
  absltest.main()
4 changes: 4 additions & 0 deletions tunix/cli/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ model_config: &base_model_config

# Directory used for NNX conversion if downloaded Gemma/Gemma2 from Kaggle source.
intermediate_ckpt_dir: "/tmp/intermediate_ckpt/"

# Flash Attention configuration
use_flash_attention: false
flash_attention_block_size: 1024
################################# LoRa #################################
# If you want to use LoRA, specify the module_path, rank, and alpha. You can also optionally specify the weight_qtype and tile_size, e.g.:
#lora_config:
Expand Down
4 changes: 4 additions & 0 deletions tunix/cli/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def create_model(
intermediate_ckpt_dir=model_config.get('intermediate_ckpt_dir'),
rng_seed=model_config.get('rng_seed', 0),
model_path=model_config.get('model_path'),
use_flash_attention=model_config.get('use_flash_attention', False),
flash_attention_block_size=model_config.get(
'flash_attention_block_size', 1024
),
)

# Handle Tokenizer Path overrides
Expand Down
10 changes: 10 additions & 0 deletions tunix/models/automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""AutoModel class for Tunix."""

import dataclasses
import enum
import gc
import importlib
Expand Down Expand Up @@ -487,6 +488,15 @@ def from_pretrained(
# pick corresponding config based on model version
model_params = call_model_config(naming_info.model_name)

# Apply any model config field overrides passed via kwargs (e.g.
# use_flash_attention, flash_attention_block_size).
if dataclasses.is_dataclass(model_params):
valid_fields = {f.name for f in dataclasses.fields(model_params)}
overrides = {k: v for k, v in kwargs.items() if k in valid_fields}
if overrides:
logging.info('Applying model config overrides: %s', overrides)
model_params = dataclasses.replace(model_params, **overrides)

with mesh:
model = create_model_from_safe_tensors(
naming_info.model_name,
Expand Down
Loading