Read Ulysses SP attention geometry from the HF text config.

register_with_transformers() previously read num_attention_heads,
head_dim / hidden_size, num_key_value_heads and num_hidden_layers from the
top-level hf_model_config. This patch reads them from
hf_model_config.get_text_config() instead. The added tests assert that a
default Gemma4Config has no top-level num_attention_heads, that the object
returned by get_text_config() carries positive values for all four fields,
and that it is the same object as config.text_config. The tests are skipped
when transformers is not installed or does not provide Gemma4Config.

NOTE(review): this patch had lost its line breaks and leading whitespace;
both are restored below. Indentation of the ulysses_sp.py hunk is assumed
from the call structure -- confirm against the target file, or apply with
"git apply --ignore-whitespace".
NOTE(review): getattr(cfg, "head_dim", fallback) returns None when a config
defines head_dim = None; the fallback is used only when the attribute is
absent. Confirm no supported config does this.
NOTE(review): the added tests exercise transformers only, not
register_with_transformers() itself.

diff --git a/deepspeed/runtime/sequence_parallel/ulysses_sp.py b/deepspeed/runtime/sequence_parallel/ulysses_sp.py
index d59edfa9b6bf..413921c2090c 100644
--- a/deepspeed/runtime/sequence_parallel/ulysses_sp.py
+++ b/deepspeed/runtime/sequence_parallel/ulysses_sp.py
@@ -491,14 +491,19 @@ def register_with_transformers(
         local_seq_length = seq_length // mpu.get_sequence_parallel_world_size()
         global_seq_length = seq_length
 
+        arch_cfg = hf_model_config.get_text_config()
+
         uattn = UlyssesSPAttentionHF(
             attn=core_attn_function,
             batch_size=micro_batch_size,
-            attn_head_count=hf_model_config.num_attention_heads,
-            attn_head_size=getattr(hf_model_config, "head_dim",
-                                   hf_model_config.hidden_size // hf_model_config.num_attention_heads),
-            kv_head_count=hf_model_config.num_key_value_heads,
-            num_hidden_layers=hf_model_config.num_hidden_layers,
+            attn_head_count=arch_cfg.num_attention_heads,
+            attn_head_size=getattr(
+                arch_cfg,
+                "head_dim",
+                arch_cfg.hidden_size // arch_cfg.num_attention_heads,
+            ),
+            kv_head_count=arch_cfg.num_key_value_heads,
+            num_hidden_layers=arch_cfg.num_hidden_layers,
             process_group=mpu.get_sequence_parallel_group(),
             seq_length_is_variable=seq_length_is_variable,
             local_seq_length=local_seq_length,
diff --git a/tests/unit/v1/multimodal/__init__.py b/tests/unit/v1/multimodal/__init__.py
new file mode 100644
index 000000000000..c8d652d4dc49
--- /dev/null
+++ b/tests/unit/v1/multimodal/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
diff --git a/tests/unit/v1/multimodal/test_gemma4_config.py b/tests/unit/v1/multimodal/test_gemma4_config.py
new file mode 100644
index 000000000000..5d1a3ca1ad12
--- /dev/null
+++ b/tests/unit/v1/multimodal/test_gemma4_config.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+
+import pytest
+
+transformers = pytest.importorskip("transformers")
+Gemma4Config = getattr(transformers, "Gemma4Config", None)
+pytestmark = pytest.mark.skipif(Gemma4Config is None, reason="Gemma4Config not available in this transformers version")
+
+
+def test_gemma4_text_config_fallback():
+    config = Gemma4Config()
+    assert not hasattr(config, 'num_attention_heads'), \
+        "Gemma4Config top-level should not have num_attention_heads"
+    arch_cfg = config.get_text_config()
+    assert hasattr(arch_cfg, 'num_attention_heads')
+    assert arch_cfg.num_attention_heads > 0
+    assert hasattr(arch_cfg, 'num_key_value_heads')
+    assert arch_cfg.num_key_value_heads > 0
+    assert hasattr(arch_cfg, 'num_hidden_layers')
+    assert arch_cfg.num_hidden_layers > 0
+    assert hasattr(arch_cfg, 'hidden_size')
+    assert arch_cfg.hidden_size > 0
+
+
+def test_gemma4_text_config_matches_text_config():
+    config = Gemma4Config()
+    arch_cfg = config.get_text_config()
+    assert arch_cfg is config.text_config
+    assert arch_cfg.num_attention_heads == config.text_config.num_attention_heads
+    assert arch_cfg.num_key_value_heads == config.text_config.num_key_value_heads