warning bug in Qwen2DecoderLayer in transformers==4.49 #36361

Kyrie666 opened this issue Feb 24, 2025 · 1 comment · May be fixed by #36377

Kyrie666 commented Feb 24, 2025

System Info

transformers==4.49

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

```python
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if config.sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for {config._attn_implementation}; "
                "unexpected results may be encountered."
            )
```
config.sliding_window is always a number in the config, so the truthiness check passes and the warning fires 100% of the time.
Should the check use config.use_sliding_window instead?
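
For context, a minimal reproduction sketch (not taken from the issue itself): loading a Qwen2-family checkpoint with any non-flash-attention backend triggers the warning, even though use_sliding_window defaults to False in its config. The checkpoint name below is just one example and assumes access to the Hugging Face Hub.

```python
# Hypothetical repro: Qwen2-family configs ship sliding_window as a number
# (e.g. 32768) while use_sliding_window is False, so the truthiness check in
# Qwen2DecoderLayer fires for every backend other than flash_attention_2.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # example checkpoint, chosen for illustration
    attn_implementation="sdpa",    # anything other than "flash_attention_2"
)
# Logs: "Sliding Window Attention is enabled but not implemented for sdpa; ..."
```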

Expected behavior

The warning should only be emitted when sliding-window attention is actually enabled (e.g. when config.use_sliding_window is set), not merely because config.sliding_window holds a number.
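
A minimal sketch of the condition the report seems to suggest, gating the warning on use_sliding_window; this only illustrates the suggestion above and is not necessarily what #36377 implements.

```python
# Sketch of the suggested check: warn only when sliding-window attention is
# actually enabled, not whenever a window size happens to be configured.
if (
    config.use_sliding_window
    and config.sliding_window is not None
    and config._attn_implementation != "flash_attention_2"
):
    logger.warning_once(
        f"Sliding Window Attention is enabled but not implemented for {config._attn_implementation}; "
        "unexpected results may be encountered."
    )
```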


KarthikaRajagopal44 commented Feb 24, 2025

Hey @Kyrie666, I have submitted a PR to fix this issue: #36377
