Pass@k #519
Open: clefourrier wants to merge 4 commits into main from clem_pass_at_k (+141 −1)
@@ -26,7 +26,7 @@
 
 import logging
 import os
-from typing import Callable, Literal
+from typing import Callable, Literal, Union
 
 import nltk
 import numpy as np
@@ -1043,3 +1043,118 @@ def compute_score(self, pred: str, gold: str) -> int:
         if self.type_exact_match == "suffix":
             return 1 if pred.endswith(gold) else 0
         return 1 if gold == pred else 0
+
+
+class PassAtK:
+    def __init__(
+        self,
+        k: int,
+        n: int = None,
+        normalize_gold: Callable = None,
+        normalize_pred: Callable = None,
+        strip_strings: bool = False,
+        sample_scoring_function: Union[Callable[[str, str], float], str] = None,
+    ):
+        """Computes pass@k.
+
+        Args:
+            k (int): Threshold for the number of successful attempts.
+            n (int, optional): Number of samples generated per item. Defaults to the number of predictions.
+            normalize_gold (callable, optional): Function to use to normalize the reference strings.
+                Defaults to None if no normalization is applied.
+            normalize_pred (callable, optional): Function to use to normalize the predicted strings.
+                Defaults to None if no normalization is applied.
+            strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
+            sample_scoring_function (callable or str, optional): Function to use to score each sample.
+                Either pass the full function (should take a string prediction and a string gold, and return a score between 0 and 1),
+                a string (any of `prefix`, `suffix` or `full`) to define the type of exact match that you want, or nothing to default to "full".
+                `prefix` checks if the prediction starts with the gold,
+                `suffix` if the prediction ends with the gold,
+                `full` if the prediction and gold are equal
+        """
+        self.k = k
+        self.n = n
+        self.normalize_gold = normalize_gold
+        self.normalize_pred = normalize_pred
+        self.strip_strings = strip_strings
+
+        # Manage the logic of per-sample scoring
+        if callable(sample_scoring_function):
+            self.score_sample = sample_scoring_function
+            self.type_exact_match = None
+        else:
+            if isinstance(sample_scoring_function, str):
+                if sample_scoring_function not in ["prefix", "suffix", "full"]:
+                    raise ValueError(
+                        f"sample_scoring_function must be a callable or one of prefix, suffix, or full. Was {sample_scoring_function} instead."
+                    )
+                self.type_exact_match = sample_scoring_function
+            else:
+                self.type_exact_match = "full"
+            self.score_sample = self.default_sample_scoring
+
+    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
+        """Computes the metric over a list of golds and predictions for one single item with possibly many samples.
[Review comment on the line above] Core logic here
+
+        It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
+        then aggregates the scores over the samples using a pass@k.
+
+        Args:
+            golds (list[str]): Reference targets
+            predictions (list[str]): n predicted strings
+
+        Returns:
+            float: Pass@k score aggregated over the current item's samples.
+        """
+        if len(golds) > 1:
+            raise Exception("Cannot compute pass@k with several golds")
+
+        if self.n is None:
+            self.n = len(predictions)
+            logger.warning("n undefined in the pass@k. We assume it's the same as the sample's number of predictions.")
+        elif len(predictions) < self.n:
+            logger.warning(f"Number of predictions is less than {self.n} for pass@k.")
+
+        gold = self.get_processed_gold(golds[0])
+
+        all_scores = []
+        for pred in predictions[: self.n]:
+            cur_pred = self.get_processed_pred(pred=pred)
+            all_scores.append(self.score_sample(cur_pred, gold))
+
+        return self.pass_at_k(all_scores)
+
+    def get_processed_gold(self, gold: str) -> str:
+        if self.strip_strings:
+            gold = gold.strip()
+
+        if self.normalize_gold:
+            gold = self.normalize_gold(gold)
+
+        return gold
+
+    def get_processed_pred(self, pred: str) -> str:
+        if not pred:
+            return ""
+
+        if self.strip_strings:
+            pred = pred.strip()
+
+        if self.normalize_pred:
+            pred = self.normalize_pred(pred)
+
+        return pred
+
+    def default_sample_scoring(self, pred: str, gold: str) -> int:
+        if self.type_exact_match == "prefix":
+            return 1 if pred.startswith(gold) else 0
+        if self.type_exact_match == "suffix":
+            return 1 if pred.endswith(gold) else 0
+        return 1 if gold == pred else 0
+
+    def pass_at_k(self, all_scores: list[int]) -> float:
+        """Algo from https://arxiv.org/pdf/2107.03374"""
[Review comment on the line above] Pass@k here, literally the one from Codex
+
+        c: int = all_scores.count(1)
+        if self.n - c < self.k:
+            return 1.0
+
+        return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))
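
As a sanity check on the estimator above, here is a minimal standalone sketch (not part of the PR) comparing the product form used in `pass_at_k` with the closed form 1 - C(n - c, k) / C(n, k), where n is the number of samples and c the number of correct ones. The function names `pass_at_k_product` and `pass_at_k_closed_form` are hypothetical and only serve this illustration.

```python
import math

import numpy as np


def pass_at_k_product(n: int, c: int, k: int) -> float:
    """Product form, mirroring the logic shown in the diff."""
    # With fewer than k incorrect samples, any draw of k samples contains a correct one.
    if n - c < k:
        return 1.0
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))


def pass_at_k_closed_form(n: int, c: int, k: int) -> float:
    """Closed form: 1 - C(n - c, k) / C(n, k)."""
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)


if __name__ == "__main__":
    # (n, c, k) triples chosen arbitrarily for illustration.
    for n, c, k in [(10, 3, 1), (10, 3, 5), (20, 0, 5), (5, 4, 2)]:
        print(n, c, k, pass_at_k_product(n, c, k), pass_at_k_closed_form(n, c, k))
```

Both forms should agree up to floating-point error; the product form avoids building large binomial coefficients, which matters mostly when working with fixed-width numeric types.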
I made it as exhaustive/customizable as the other metrics (full exact match for the individual predictions by default, plus options to normalize strings in case you use it for math evals, for example), but I can remove some options if you feel that's too much complexity.
Having highly parametrized metrics is good imo; it does not add that much complexity here.
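
To make the options discussed in this thread concrete, here is a minimal standalone sketch of the per-sample scoring step (strip, normalize, then prefix/suffix/full match). It is an illustration under assumptions, not the PR's implementation: `score_sample` and `normalize_number` are hypothetical helpers, and the toy normalizer only hints at what a math-eval normalizer might do.

```python
from typing import Callable, Optional


def score_sample(
    pred: str,
    gold: str,
    mode: str = "full",
    normalize: Optional[Callable[[str], str]] = None,
    strip_strings: bool = False,
) -> int:
    """Return 1 if the prediction matches the gold under the chosen mode, else 0."""
    if mode not in ("prefix", "suffix", "full"):
        raise ValueError(f"mode must be one of prefix, suffix, or full. Was {mode} instead.")
    if strip_strings:
        pred, gold = pred.strip(), gold.strip()
    if normalize is not None:
        # Hypothetical simplification: the same normalizer is applied to both strings.
        pred, gold = normalize(pred), normalize(gold)
    if mode == "prefix":
        return int(pred.startswith(gold))
    if mode == "suffix":
        return int(pred.endswith(gold))
    return int(pred == gold)


def normalize_number(text: str) -> str:
    """Toy normalizer: drop thousands separators and trailing periods."""
    return text.replace(",", "").rstrip(".")


if __name__ == "__main__":
    print(score_sample("1,000.", "1000", normalize=normalize_number))
    print(score_sample("The answer is 42", "42", mode="suffix"))
    print(score_sample(" 42 ", "42", strip_strings=True))
```

The per-sample scores produced this way are what the pass@k aggregation counts as correct (score equal to 1) or incorrect.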