```python
import numpy as np
import torch
from rouge import Rouge
from tqdm.auto import tqdm

rouge = Rouge()

# `device`, `max_target_length` and `tokenizer` are globals defined elsewhere in the script.
def test_loop(dataloader, model):
    preds, labels = [], []

    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            generated_tokens = model.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
                num_beams=4,
                no_repeat_ngram_size=2,
            )
        # Defensive check: unwrap a tuple result before moving it to the CPU
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]
        generated_tokens = generated_tokens.cpu().numpy()
        label_tokens = batch_data["labels"].cpu().numpy()

        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        # Replace the -100 label padding (ignored by the loss) with the real pad token before decoding
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        # Space-separate every character so ROUGE scores (e.g. Chinese) text at character level
        preds += [' '.join(pred.strip()) for pred in decoded_preds]
        labels += [' '.join(label.strip()) for label in decoded_labels]

    # NOTE: [0] keeps only the scores of the first prediction/reference pair (see discussion below)
    scores = rouge.get_scores(hyps=preds, refs=labels)[0]
    result = {key: value['f'] * 100 for key, value in scores.items()}
    result['avg'] = np.mean(list(result.values()))
    print(f"Rouge1: {result['rouge-1']:>0.2f} Rouge2: {result['rouge-2']:>0.2f} RougeL: {result['rouge-l']:>0.2f}\n")
    return result
```
It's this line: `scores = rouge.get_scores(hyps=preds, refs=labels)[0]`
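Thank you very much! Yes, I got this wrong. I took the ROUGE computation directly from someone else's code without checking it carefully myself. Written this way, when several sentences are evaluated, only the scores of the first sentence are used.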
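To see why `[0]` under-reports, here is a minimal, self-contained sketch that does not use the `rouge` package. The helper `rouge1_scores` is hypothetical: a simplified ROUGE-1 (clipped unigram overlap over whitespace-separated tokens) that returns one `{'rouge-1': {'f', 'p', 'r'}}` dict per hypothesis/reference pair, a list shape assumed here to mirror what `get_scores` returns without `avg=True`. Indexing `[0]` keeps only the first pair and silently drops the rest.

```python
from collections import Counter

def rouge1_scores(hyps, refs):
    """Hypothetical simplified ROUGE-1: one score dict per (hyp, ref) pair."""
    scores = []
    for hyp, ref in zip(hyps, refs):
        hyp_tokens, ref_tokens = hyp.split(), ref.split()
        # Clipped unigram overlap: each token counts at most min(count_hyp, count_ref) times
        overlap = sum((Counter(hyp_tokens) & Counter(ref_tokens)).values())
        p = overlap / len(hyp_tokens) if hyp_tokens else 0.0
        r = overlap / len(ref_tokens) if ref_tokens else 0.0
        f = 2 * p * r / (p + r) if p + r > 0 else 0.0
        scores.append({'rouge-1': {'f': f, 'p': p, 'r': r}})
    return scores

# Character-level tokens, built the same way test_loop does with ' '.join(...)
preds = [' '.join('今天天气很好'), ' '.join('我喜欢猫')]
labels = [' '.join('今天天气很好'), ' '.join('我喜欢狗')]

per_pair = rouge1_scores(preds, labels)
first_only = per_pair[0]['rouge-1']['f']  # what the [0] version reports
mean_f = sum(s['rouge-1']['f'] for s in per_pair) / len(per_pair)  # over every pair
print(f"[0] only: {first_only:.3f}  mean over all pairs: {mean_f:.3f}")
```

Here the first pair is an exact match, so `[0]` reports a perfect score even though the second pair only shares three of its four characters.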
When evaluating multiple sentences, pass the `avg=True` argument so the scores are averaged:

```python
scores = rouge.get_scores(hyps, refs, avg=True)
```
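With `avg=True` the call returns a single dict of per-metric scores instead of a list, which is why the `[0]` index is gone, and the existing `result = {key: value['f'] * 100 ...}` line in `test_loop` keeps working. As a rough sketch of that averaging, assuming (to my understanding, not verified here against the library source) that it takes the arithmetic mean of each of `f`, `p` and `r` across all pairs, the hypothetical helper `average_scores` below averages a list of per-pair dicts. The per-pair values are illustrative placeholders, not real model results.

```python
def average_scores(per_pair):
    """Hypothetical helper: mean of f/p/r for each metric across all pairs."""
    n = len(per_pair)
    return {
        metric: {stat: sum(s[metric][stat] for s in per_pair) / n
                 for stat in ('f', 'p', 'r')}
        for metric in per_pair[0]
    }

# Illustrative per-pair scores (placeholder values, not real results)
per_pair = [
    {'rouge-1': {'f': 0.50, 'p': 0.50, 'r': 0.50},
     'rouge-2': {'f': 0.20, 'p': 0.20, 'r': 0.20},
     'rouge-l': {'f': 0.40, 'p': 0.40, 'r': 0.40}},
    {'rouge-1': {'f': 0.70, 'p': 0.70, 'r': 0.70},
     'rouge-2': {'f': 0.40, 'p': 0.40, 'r': 0.40},
     'rouge-l': {'f': 0.60, 'p': 0.60, 'r': 0.60}},
]

scores = average_scores(per_pair)  # one dict, the shape avg=True is assumed to return
# Same post-processing as test_loop, now covering every pair instead of only the first
result = {key: value['f'] * 100 for key, value in scores.items()}
result['avg'] = sum(result.values()) / len(result)  # mean of the three F-scores
print(result)
```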
I've fixed it in the code. Thanks again!