chp06 ipynb notebook extract:

def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    model.eval()

    # Prepare inputs to the model
    input_ids = tokenizer.encode(text)
    supported_context_length = model.pos_emb.weight.shape[0]
    # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake
    # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)

    # Truncate sequences if they are too long
    input_ids = input_ids[:min(max_length, supported_context_length)]

    # Pad sequences to the longest sequence
    input_ids += [pad_token_id] * (max_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension

    # Model inference
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()

    # Return the classified result
    return "spam" if predicted_label == 1 else "not spam"

Hi, if the [...]

Thank you.
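The truncate-then-pad step in the extract above can be isolated to see how the two lengths interact. The following is a minimal sketch, not the notebook's code: `prepare_input_ids` is a hypothetical helper, and clamping the padding target to the supported context length is an assumption about one way to keep the padded sequence within the model's limit.

```python
def prepare_input_ids(input_ids, max_length, supported_context_length, pad_token_id=50256):
    """Truncate and pad token IDs so the result never exceeds the context length.

    Hypothetical helper for illustration; it mirrors the truncate/pad logic of
    classify_review but clamps the target length first (an assumption).
    """
    # Use the context length when no max_length is given, otherwise the smaller of the two
    if max_length is None:
        target_length = supported_context_length
    else:
        target_length = min(max_length, supported_context_length)

    # Truncate sequences that are too long
    truncated = list(input_ids[:target_length])

    # Pad shorter sequences up to the clamped target length
    return truncated + [pad_token_id] * (target_length - len(truncated))


# Toy example with a context length of 4 and a requested max_length of 6
padded = prepare_input_ids([11, 12, 13], max_length=6, supported_context_length=4)
print(padded)  # [11, 12, 13, 50256]
```

With the extract's original logic, the same toy inputs would be padded to `max_length=6`, two tokens beyond the assumed context length of 4.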
Answered by
rasbt
Jan 22, 2025
Replies: 1 comment
That's a good point, and yes, you are right. I added a safeguard like that to the notebook in cell 15 (i.e., the assert part):

CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"

BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
    f"Dataset length {train_dataset.max_length} exceeds model's context "
    f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
    f"`max_length={BASE_CONFIG['context_length']}`"
)
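The safeguard can also be tried in isolation. The sketch below wraps the same comparison in a hypothetical function, `check_max_length`, with plain integers standing in for `train_dataset.max_length` and `BASE_CONFIG["context_length"]`; the function name and the sample values are assumptions for illustration only.

```python
def check_max_length(dataset_max_length, context_length):
    """Raise an AssertionError if the dataset's sequences are longer than the model supports."""
    assert dataset_max_length <= context_length, (
        f"Dataset length {dataset_max_length} exceeds model's context "
        f"length {context_length}. Reinitialize data sets with "
        f"`max_length={context_length}`"
    )


check_max_length(120, 1024)  # passes silently

try:
    check_max_length(2048, 1024)  # illustrative value above the context length
except AssertionError as err:
    print(err)
```

Note that Python skips `assert` statements when run with the `-O` flag, so a check like this is best treated as a development-time guard.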
Answer selected by
rasbt