fixed yapf fails
bmullick-amd committed Jan 30, 2025
1 parent 783092b commit 736f0e6
Showing 1 changed file with 56 additions and 53 deletions.
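
The diff below only reindents call sites so the file passes the project's yapf check. To reproduce that check locally, one option is yapf's Python API — a minimal sketch, assuming yapf is installed and substituting the repository's own style configuration for the placeholder used here:

    # check_yapf.py -- report whether t5.py matches the configured yapf style
    from yapf.yapflib.yapf_api import FormatFile

    # With print_diff=True, FormatFile returns a unified diff instead of the
    # reformatted source; `changed` is True when the file needs reformatting.
    diff, _encoding, changed = FormatFile(
        "vllm/model_executor/models/t5.py",
        style_config="pep8",  # assumption: point this at the repo's real yapf config
        print_diff=True,
    )
    print(diff if changed else "already yapf-clean")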
109 changes: 56 additions & 53 deletions vllm/model_executor/models/t5.py
@@ -354,14 +354,12 @@ def __init__(
             bias=False,
             quant_config=quant_config,
         )
-        self.attn = Attention(
-            self.n_heads,
-            self.inner_dim // self.n_heads,
-            scale=1,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=prefix
-        )
+        self.attn = Attention(self.n_heads,
+                              self.inner_dim // self.n_heads,
+                              scale=1,
+                              cache_config=cache_config,
+                              quant_config=quant_config,
+                              prefix=prefix)

     def forward(
         self,
@@ -396,14 +394,12 @@ def forward(

 class T5LayerSelfAttention(nn.Module):

-    def __init__(
-        self,
-        config,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        has_relative_attention_bias=False,
-        prefix: str = ""
-    ):
+    def __init__(self,
+                 config,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 has_relative_attention_bias=False,
+                 prefix: str = ""):
         super().__init__()
         self.SelfAttention = T5Attention(
             config,
@@ -439,13 +435,12 @@ def __init__(
         prefix: str = "",
     ):
         super().__init__()
-        self.EncDecAttention = T5Attention(
-            config,
-            cache_config,
-            quant_config,
-            has_relative_attention_bias=False,
-            prefix=f"{prefix}.attn")
-
+        self.EncDecAttention = T5Attention(config,
+                                           cache_config,
+                                           quant_config,
+                                           has_relative_attention_bias=False,
+                                           prefix=f"{prefix}.attn")
+
         self.layer_norm = T5LayerNorm(config.d_model,
                                       eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)
@@ -489,9 +484,11 @@ def __init__(
             prefix=f"{prefix}.self_attn",
         )
         if self.is_decoder:
-            self.cross_attn = T5LayerCrossAttention(config, cache_config,
-                                                    quant_config,
-                                                    prefix=f"{prefix}.encoder_attn")
+            self.cross_attn = T5LayerCrossAttention(
+                config,
+                cache_config,
+                quant_config,
+                prefix=f"{prefix}.encoder_attn")
         self.fc = T5LayerFF(config, quant_config)

     def forward(
@@ -584,13 +581,12 @@ def __init__(
         self.is_decoder = config.is_decoder

         self.block = nn.ModuleList([
-            T5Block(
-                config,
-                cache_config,
-                quant_config,
-                has_relative_attention_bias=bool(i == 0),
-                prefix=f"{prefix}.block.{i}"
-            ) for i in range(config.num_layers)
+            T5Block(config,
+                    cache_config,
+                    quant_config,
+                    has_relative_attention_bias=bool(i == 0),
+                    prefix=f"{prefix}.block.{i}")
+            for i in range(config.num_layers)
         ])
         self.final_layer_norm = T5LayerNorm(config.d_model,
                                             eps=config.layer_norm_epsilon)
@@ -629,14 +625,15 @@ class T5Model(nn.Module):
         "encoder.embed_tokens.weight",
         "decoder.embed_tokens.weight",
     ]
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-    # def __init__(
-    #     self,
-    #     config: T5Config,
-    #     cache_config: Optional[CacheConfig] = None,
-    #     quant_config: Optional[QuantizationConfig] = None,
-    #     lora_config: Optional[LoRAConfig] = None,
-    # ):
+        # def __init__(
+        #     self,
+        #     config: T5Config,
+        #     cache_config: Optional[CacheConfig] = None,
+        #     quant_config: Optional[QuantizationConfig] = None,
+        #     lora_config: Optional[LoRAConfig] = None,
+        # ):
         super().__init__()
         # self.shared = nn.Embedding(config.vocab_size, config.d_model)
         config = vllm_config.model_config.hf_config
@@ -658,15 +655,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
         encoder_config.is_encoder_decoder = False
-        self.encoder = T5Stack(encoder_config, cache_config, quant_config,
-                               self.shared, prefix=f"{prefix}.encoder")
+        self.encoder = T5Stack(encoder_config,
+                               cache_config,
+                               quant_config,
+                               self.shared,
+                               prefix=f"{prefix}.encoder")

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
         decoder_config.is_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
-        self.decoder = T5Stack(decoder_config, cache_config, quant_config,
-                               self.shared, prefix=f"{prefix}.decoder")
+        self.decoder = T5Stack(decoder_config,
+                               cache_config,
+                               quant_config,
+                               self.shared,
+                               prefix=f"{prefix}.decoder")

     def forward(
         self,
@@ -703,20 +706,20 @@ class T5ForConditionalGeneration(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

-    # def __init__(
-    #     self,
-    #     config: T5Config,
-    #     cache_config: Optional[CacheConfig] = None,
-    #     quant_config: Optional[QuantizationConfig] = None,
-    #     lora_config: Optional[LoRAConfig] = None,
-    # ):
+        # def __init__(
+        #     self,
+        #     config: T5Config,
+        #     cache_config: Optional[CacheConfig] = None,
+        #     quant_config: Optional[QuantizationConfig] = None,
+        #     lora_config: Optional[LoRAConfig] = None,
+        # ):
         super().__init__()
         config = vllm_config.model_config.hf_config
         lora_config = vllm_config.lora_config
         self.config = config
         self.model_dim = config.d_model
         self.model = T5Model(vllm_config=vllm_config,
-                            prefix=maybe_prefix(prefix, "model"))
+                             prefix=maybe_prefix(prefix, "model"))
         # self.model = T5Model(config,
         #                      cache_config,
         #                      quant_config,
@@ -924,4 +927,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 if shard_id:
                     weight_loader(param, loaded_weight, shard_id)
                 else:
-                    weight_loader(param, loaded_weight)
+                    weight_loader(param, loaded_weight)
