Pass@k #519

Open · wants to merge 4 commits into main
25 changes: 25 additions & 0 deletions src/lighteval/metrics/metrics.py
@@ -44,6 +44,7 @@
    Faithfulness,
    LoglikelihoodAcc,
    MajAtK,
    PassAtK,
    Recall,
    StringDistance,
    acc_golds_likelihood,
@@ -364,6 +365,30 @@ class Metrics(Enum):
        corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute,
        higher_is_better=True,
    )
    pass_at_1 = SampleLevelMetric(
        metric_name="pass@1:32_samples",
        sample_level_fn=PassAtK(k=1, n=32, strip_strings=True).compute,
        category=MetricCategory.GENERATIVE_SAMPLING,
        use_case=MetricUseCase.REASONING,
        corpus_level_fn=np.mean,
        higher_is_better=True,
    )
    pass_at_10 = SampleLevelMetric(
        metric_name="pass@10:32_samples",
        sample_level_fn=PassAtK(k=10, n=32, strip_strings=True).compute,
        category=MetricCategory.GENERATIVE_SAMPLING,
        use_case=MetricUseCase.REASONING,
        corpus_level_fn=np.mean,
        higher_is_better=True,
    )
    pass_at_100 = SampleLevelMetric(
        metric_name="pass@100:32_samples",
        sample_level_fn=PassAtK(k=100, n=32, strip_strings=True).compute,
        category=MetricCategory.GENERATIVE_SAMPLING,
        use_case=MetricUseCase.REASONING,
        corpus_level_fn=np.mean,
        higher_is_better=True,
    )
    perfect_exact_match = SampleLevelMetric(
        metric_name="perfect_em",
        sample_level_fn=ExactMatches().compute,
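For orientation, here is a minimal usage sketch (golds and predictions are invented for the example) of what these three enum entries compute: each draws n=32 generations per item and they differ only in k, following the pass@{k}:{n}_samples naming used in metric_name above.

# Illustration only, not part of the diff; inputs are made up.
from lighteval.metrics.metrics_sample import PassAtK

# 32 generations for one item, 8 of which match the gold exactly.
predictions = ["42"] * 8 + ["-1"] * 24

pass_at_1 = PassAtK(k=1, n=32, strip_strings=True).compute(golds=["42"], predictions=predictions)
pass_at_10 = PassAtK(k=10, n=32, strip_strings=True).compute(golds=["42"], predictions=predictions)

print(pass_at_1)   # 0.25: a single random draw passes 8/32 of the time
print(pass_at_10)  # ~0.97: at least one of 10 random draws is very likely to pass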
117 changes: 116 additions & 1 deletion src/lighteval/metrics/metrics_sample.py
@@ -26,7 +26,7 @@

import logging
import os
from typing import Callable, Literal
from typing import Callable, Literal, Union

import nltk
import numpy as np
@@ -1043,3 +1043,118 @@ def compute_score(self, pred: str, gold: str) -> int:
        if self.type_exact_match == "suffix":
            return 1 if pred.endswith(gold) else 0
        return 1 if gold == pred else 0


class PassAtK:
    def __init__(
        self,
        k: int,
        n: int = None,
        normalize_gold: Callable = None,
        normalize_pred: Callable = None,
        strip_strings: bool = False,
        sample_scoring_function: Union[Callable[[str, str], float], str] = None,
    ):
"""Computing pass at k
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it as exhaustive/customizable as the other metrics (full exact match for the individual predictions by default, options to normalize strings in case you use it for math evals for ex) but I can remove some options if you feel that's too much complexity

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

having higghly paramitrized metrics is good imo, it does not add that much compleixity here


        Args:
            k (int): Number of attempts allowed (the k in pass@k).
            n (int, optional): Number of samples to generate. Defaults to the number of predictions when not provided.
            normalize_gold (callable, optional): Function to use to normalize the reference strings.
                Defaults to None if no normalization is applied.
            normalize_pred (callable, optional): Function to use to normalize the predicted strings.
                Defaults to None if no normalization is applied.
            strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
            sample_scoring_function (callable or str, optional): Function to use to score each sample.
                Either pass a full function (which should take a string prediction and a string gold, and return a score between 0 and 1),
                a string (any of `prefix`, `suffix` or `full`) to define the type of exact match you want, or nothing to default to `full`.
                `prefix` checks if the prediction starts with the gold,
                `suffix` if the prediction ends with the gold,
                `full` if the prediction and gold are equal.
        """
        self.k = k
        self.n = n
        self.normalize_gold = normalize_gold
        self.normalize_pred = normalize_pred
        self.strip_strings = strip_strings

        # Manage the logic of the per-prediction sample scoring
        if callable(sample_scoring_function):
            self.score_sample = sample_scoring_function
            self.type_exact_match = None
        else:
            if isinstance(sample_scoring_function, str):
                if sample_scoring_function not in ["prefix", "suffix", "full"]:
                    raise ValueError(
                        f"sample_scoring_function must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
                    )
                self.type_exact_match = sample_scoring_function
            else:
                self.type_exact_match = "full"
            self.score_sample = self.default_sample_scoring

    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
"""Computes the metric over a list of golds and predictions for one single item with possibly many samples.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Core logic here

        It applies normalisation (if needed) to the model predictions and gold, computes their per-prediction scores,
        then aggregates the scores over the samples using pass@k.

        Args:
            golds (list[str]): Reference targets
            predictions (list[str]): Predicted strings (the n sampled generations)

        Returns:
            float: Aggregated score over the current sample's items.
        """
        if len(golds) > 1:
            raise Exception("Cannot compute pass@k with several golds")

        if self.n is None:
            self.n = len(predictions)
            logger.warning("n undefined in the pass@k. We assume it's the same as the sample's number of predictions.")
        elif len(predictions) < self.n:
            logger.warning(f"Number of predictions is less than {self.n} for pass@k.")

        gold = self.get_processed_gold(golds[0])

        all_scores = []
        for pred in predictions[: self.n]:
            cur_pred = self.get_processed_pred(pred=pred)
            all_scores.append(self.score_sample(cur_pred, gold))

        return self.pass_at_k(all_scores)

    def get_processed_gold(self, gold: str) -> str:
        if self.strip_strings:
            gold = gold.strip()

        if self.normalize_gold:
            gold = self.normalize_gold(gold)

        return gold

    def get_processed_pred(self, pred: str) -> str:
        if not pred:
            return ""

        if self.strip_strings:
            pred = pred.strip()

        if self.normalize_pred:
            pred = self.normalize_pred(pred)

        return pred

    def default_sample_scoring(self, pred: str, gold: str) -> int:
        if self.type_exact_match == "prefix":
            return 1 if pred.startswith(gold) else 0
        if self.type_exact_match == "suffix":
            return 1 if pred.endswith(gold) else 0
        return 1 if gold == pred else 0

    def pass_at_k(self, all_scores: list[int]) -> float:
        """Algo from https://arxiv.org/pdf/2107.03374"""
Member Author: Pass at K here, literally the one from codex.

        c: int = all_scores.count(1)  # number of passing samples among the n generations
        if self.n - c < self.k:
            # Every draw of k samples must contain at least one passing sample.
            return 1.0

        # Unbiased estimator 1 - C(n - c, k) / C(n, k), computed as a product for numerical stability.
        return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))
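To make the scoring options above concrete, here is a hedged sketch (inputs invented for the example) comparing the default full exact match, the prefix mode, and a fully custom per-sample scorer:

# Illustration only; PassAtK is the class defined above.
from lighteval.metrics.metrics_sample import PassAtK

golds = ["paris"]
preds = ["Paris", "paris is the capital of France", "rome", "  paris"]

# Default "full" exact match, with stripping and lower-casing as normalisation.
strict = PassAtK(k=1, n=4, strip_strings=True, normalize_pred=str.lower)
print(strict.compute(golds=golds, predictions=preds))  # 2/4 samples pass -> pass@1 = 0.5

# "prefix" mode: a prediction only has to start with the gold.
loose = PassAtK(k=1, n=4, strip_strings=True, normalize_pred=str.lower, sample_scoring_function="prefix")
print(loose.compute(golds=golds, predictions=preds))   # 3/4 samples pass -> pass@1 = 0.75

# Fully custom scorer taking (prediction, gold) and returning a score between 0 and 1.
contains = PassAtK(k=2, n=4, sample_scoring_function=lambda pred, gold: float(gold in pred.lower()))
print(contains.compute(golds=golds, predictions=preds))  # 3/4 samples pass and n - c < k -> 1.0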
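Finally, a small numerical check (illustrative, not part of the PR) that the product form implemented in pass_at_k above matches the combinatorial definition 1 - C(n-c, k) / C(n, k) from the Codex paper:

from math import comb

import numpy as np


def pass_at_k(n: int, c: int, k: int) -> float:
    # Unbiased pass@k estimator (arXiv 2107.03374), same logic as PassAtK.pass_at_k above.
    if n - c < k:
        return 1.0
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))


# n=32 samples with c=8 passing, mirroring the 32-sample setting of the new enum entries.
n, c = 32, 8
for k in (1, 10, 100):
    product_form = pass_at_k(n, c, k)
    exact = 1.0 if n - c < k else 1.0 - comb(n - c, k) / comb(n, k)
    print(k, round(product_form, 6), round(exact, 6))  # the two values agree for every k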