Fix metrics formatting and style (#1591)
Signed-off-by: elronbandel <[email protected]>
elronbandel authored Feb 9, 2025
1 parent d9bf74c commit 98243fb
Showing 4 changed files with 23 additions and 202 deletions.
8 changes: 0 additions & 8 deletions .pre-commit-config.yaml
@@ -8,14 +8,6 @@ repos:
# Run the linter on all files except the specific one
- id: ruff
args: [--fix]
exclude: src/unitxt/metrics.py|examples/evaluate_existing_dataset_no_install.py
# Run the linter on the specific file with the ignore flag
- id: ruff
name: ruff (examples/evaluate_existing_dataset_no_install.py)
files: examples/evaluate_existing_dataset_no_install.py
args: [--fix, --ignore, T201]
# Run the formatter
- id: ruff-format

- repo: https://github.com/Yelp/detect-secrets
rev: v1.5.0
1 change: 1 addition & 0 deletions pyproject.toml
@@ -185,6 +185,7 @@ keep-runtime-typing = true
"tests/*" = ["TID251"]
"utils/*" = ["TID251"]
"src/unitxt/api.py" = ["B904"]
"src/unitxt/metrics.py" = ["C901"]
"src/unitxt/__init__.py" = ["F811", "F401"]
"src/unitxt/metric.py" = ["F811", "F401"]
"src/unitxt/dataset.py" = ["F811", "F401"]
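Context note: C901 is ruff's mccabe "function is too complex" rule, so the added per-file ignore silences complexity warnings for src/unitxt/metrics.py rather than excluding the file from linting, as the removed pre-commit exclude did. A hypothetical Python sketch, not taken from metrics.py, of the kind of branching that trips C901 once the configured limit is exceeded:

# Hypothetical example, not from the repository: every if/elif adds a
# decision point, and mccabe flags the function once the total exceeds
# the configured max-complexity threshold.
def label_score(score, kind, strict, fallback):
    if kind == "numeric":
        if strict:
            return score > 0.9
        if fallback:
            return score > 0.5
        return score > 0.7
    if kind == "text":
        if strict:
            return score == 1.0
        return score >= 0.8
    return bool(fallback)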
212 changes: 20 additions & 192 deletions src/unitxt/metrics.py
@@ -1,4 +1,3 @@
FINQA_HASH = "42430b8613082bb4b85d49210284135d"
import ast
import json
import math
@@ -11,7 +10,18 @@
from collections import Counter, defaultdict
from dataclasses import field
from functools import lru_cache
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Generic,
List,
Literal,
Optional,
Tuple,
TypeVar,
Union,
)

import evaluate
import numpy
@@ -55,6 +65,8 @@
from .type_utils import Type, isoftype, parse_type_string, to_type_string
from .utils import deep_copy, recursive_copy

FINQA_HASH = "42430b8613082bb4b85d49210284135d"

logger = get_logger()
settings = get_settings()

@@ -378,7 +390,6 @@ def bootstrap(self, data: List[Any], score_names: List[str]):
return result


from typing import Generic, TypeVar

IntermediateType = TypeVar("IntermediateType")
PredictionType = TypeVar("PredictionType")
@@ -1779,7 +1790,7 @@ class ExactMatchMM(InstanceMetric):
@staticmethod
@lru_cache(maxsize=10000)
def exact_match(pred, gt):
"""Brought from MMStar"""
"""Brought from MMStar."""
answer = gt.lower().strip().replace("\n", " ")
predict = pred.lower().strip().replace("\n", " ")
try:
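Aside, an illustration only and not part of the diff: the @lru_cache above memoizes repeated (pred, gt) pairs, so identical strings are normalized and compared only once. A minimal standalone sketch of the same pattern:

from functools import lru_cache

def _normalize(s: str) -> str:
    return s.lower().strip().replace("\n", " ")

@lru_cache(maxsize=10000)
def normalized_equal(pred: str, gt: str) -> bool:
    # Arguments must be hashable (strings are); repeated (pred, gt) pairs
    # are answered from the cache instead of being re-normalized.
    return _normalize(pred) == _normalize(gt)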
@@ -1873,183 +1884,6 @@ def levenshtein_distance(s1, s2):
return distances[-1]


class RelaxedCorrectness(GlobalMetric):
main_score = "relaxed_overall"
prediction_type = str # string representation is compared

def compute(
self, references: List[List[str]], predictions: List[str], task_data: List[Dict]
) -> dict:
return_dict = {
self.main_score: [],
"relaxed_human_split": [],
"relaxed_augmented_split": [],
}
for pred, ref, task_data_i in zip(predictions, references, task_data):
print(task_data_i)
type = task_data_i["type"]
score = self.relaxed_correctness(pred, ref[0])
score = 1.0 if score else 0.0
return_dict["relaxed_overall"].append(score)
if type == "human_test":
return_dict["relaxed_human_split"].append(score)
else:
return_dict["relaxed_augmented_split"].append(score)
return_dict = {
key: sum(value) / len(value)
for key, value in return_dict.items()
if len(value) > 0
}
return return_dict

@staticmethod
def _to_float(text: str):
try:
if text.endswith("%"):
# Convert percentages to floats.
return float(text.rstrip("%")) / 100.0
else:
return float(text)
except ValueError:
return None

def relaxed_correctness(
self, prediction, target, max_relative_change: float = 0.05
) -> bool:
"""Calculates relaxed correctness.
The correctness tolerates certain error ratio defined by max_relative_change.
See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
“Following Methani et al. (2020), we use a relaxed accuracy measure for the
numeric answers to allow a minor inaccuracy that may result from the automatic
data extraction process. We consider an answer to be correct if it is within
5% of the gold answer. For non-numeric answers, we still need an exact match
to consider an answer to be correct.”
This function is taken from https://github.com/QwenLM/Qwen-VL/blob/34b4c0ee7b07726371b960911f249fe61b362ca3/eval_mm/evaluate_vqa.py#L113
Args:
target: List of target string.
prediction: List of predicted string.
max_relative_change: Maximum relative change.
Returns:
Whether the prediction was correct given the specified tolerance.
"""
prediction_float = self._to_float(prediction)
target_float = self._to_float(target)
if prediction_float is not None and target_float:
relative_change = abs(prediction_float - target_float) / abs(target_float)
return relative_change <= max_relative_change
else:
return prediction.lower() == target.lower()
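Worked illustration of the tolerance rule described in the docstring above (made-up values, not from the commit): numeric answers are accepted when the relative change is at most 5%, and any non-numeric answer still requires a case-insensitive exact match.

def within_tolerance(pred: float, target: float, tol: float = 0.05) -> bool:
    return abs(pred - target) / abs(target) <= tol

assert within_tolerance(0.96, 1.0)         # "96%" vs "100%": relative change 0.04 -> correct
assert not within_tolerance(14.0, 15.0)    # "14" vs "15": relative change ~0.067 -> incorrect
assert "Paris".lower() == "paris".lower()  # non-numeric: exact match after lowercasing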


class WebsrcSquadF1(GlobalMetric):
main_score = "websrc_squad_f1"
prediction_type = Any # string representation is compared
DOMAINS = [
"auto",
"book",
"camera",
"game",
"jobs",
"movie",
"phone",
"restaurant",
"sports",
"university",
"hotel",
]

def compute(
self,
references: List[List[str]],
predictions: List[str],
task_data: List[Dict],
) -> dict:
"""ANLS image-text accuracy metric."""
evaluation_result = {}
# Group results by domain
subset_to_eval_samples = defaultdict(list)
for pred, ref, task_data_i in zip(predictions, references, task_data):
subset_to_eval_samples[task_data_i["domain"]].append([pred, ref[0]])
# Evaluate each domain
for subset, sub_eval_samples in subset_to_eval_samples.items():
judge_dict, metric_dict = self.evaluate_websrc(sub_eval_samples)
metric_dict.update({"num_example": len(sub_eval_samples)})
evaluation_result[subset] = metric_dict

# Aggregate results for all domains
printable_results = {}
for domain in self.DOMAINS:
if domain not in evaluation_result:
continue
printable_results[domain] = {
"num": int(evaluation_result[domain]["num_example"]),
"f1": round(evaluation_result[domain]["f1"], 3),
}
all_ins_f1 = np.sum(
[
cat_results["f1"] * cat_results["num_example"]
for cat_results in evaluation_result.values()
]
) / sum(
[cat_results["num_example"] for cat_results in evaluation_result.values()]
)
printable_results["Overall"] = {
"num": sum(
[
cat_results["num_example"]
for cat_results in evaluation_result.values()
]
),
"f1": round(all_ins_f1, 3),
}
return {self.main_score: printable_results["Overall"]["f1"]}
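Worked illustration (made-up numbers, not from the commit) of the instance-weighted overall F1 computed above: per-domain F1 scores are combined with each domain weighted by its number of examples.

evaluation_result = {
    "book": {"f1": 0.80, "num_example": 10},
    "game": {"f1": 0.50, "num_example": 30},
}
total = sum(d["num_example"] for d in evaluation_result.values())  # 40
overall = sum(d["f1"] * d["num_example"] for d in evaluation_result.values()) / total
# (0.80 * 10 + 0.50 * 30) / 40 = 23.0 / 40 = 0.575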

def evaluate_websrc(self, samples):
def _normalize_str(string):
# lower it
string = string.lower()

# strip leading and trailing whitespaces
string = string.strip()

return string

def _tokenize(text):
# Regex pattern to match words and isolate punctuation
pattern = r"\w+|[^\w\s]"
tokens = re.findall(pattern, text)
return tokens

def _compute_f1(sa, sb):
sa = _normalize_str(sa)
sb = _normalize_str(sb)

sa = _tokenize(sa)
sb = _tokenize(sb)

sa = set(sa)
sb = set(sb)

if len(sa) == 0 or len(sb) == 0:
return 0.0

comm = sa.intersection(sb)
prec = len(comm) / len(sb)
rec = len(comm) / len(sa)
f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0
return f1

judge_list = []
for sample in samples:
judge_list.append(_compute_f1(sample[1], sample[0]))

f1 = np.mean(judge_list)
return judge_list, {"f1": f1}


class RelaxedCorrectness(GlobalMetric):
main_score = "relaxed_overall"
prediction_type = str # string representation is compared
@@ -2071,12 +1905,11 @@ def compute(
return_dict["relaxed_human_split"].append(score)
else:
return_dict["relaxed_augmented_split"].append(score)
return_dict = {
return {
key: sum(value) / len(value)
for key, value in return_dict.items()
if len(value) > 0
}
return return_dict

@staticmethod
def _to_float(text: str):
@@ -2187,15 +2020,12 @@ def _normalize_str(string):
string = string.lower()

# strip leading and trailing whitespaces
string = string.strip()

return string
return string.strip()

def _tokenize(text):
# Regex pattern to match words and isolate punctuation
pattern = r"\w+|[^\w\s]"
tokens = re.findall(pattern, text)
return tokens
return re.findall(pattern, text)

def _compute_f1(sa, sb):
sa = _normalize_str(sa)
@@ -2213,8 +2043,7 @@ def _compute_f1(sa, sb):
comm = sa.intersection(sb)
prec = len(comm) / len(sb)
rec = len(comm) / len(sa)
f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0
return f1
return 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0

judge_list = []
for sample in samples:
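Worked illustration (made-up strings, not from the commit) of the token-set F1 that _compute_f1 returns: tokens are deduplicated with set(), so precision is measured over the prediction's unique tokens and recall over the reference's.

sa = {"the", "red", "camera"}   # reference tokens (sample[1])
sb = {"red", "camera", "lens"}  # prediction tokens (sample[0])
comm = sa & sb                  # {"red", "camera"}
prec = len(comm) / len(sb)      # 2/3: predicted tokens found in the reference
rec = len(comm) / len(sa)       # 2/3: reference tokens found in the prediction
f1 = 2 * prec * rec / (prec + rec)  # = 2/3, about 0.667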
@@ -3682,8 +3511,7 @@ def prepare(self):
def map_stream(
self, evaluation_inputs_stream: Generator[EvaluationInput, None, None]
):
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
from sentence_transformers import SentenceTransformer, util

if self.model is None:
self.model = SentenceTransformer(
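Aside, a minimal sketch with an assumed checkpoint name, not part of the commit: the consolidated import and the lazy self.model initialization above follow the usual sentence-transformers pattern of loading the encoder once and reusing it.

from sentence_transformers import SentenceTransformer, util

_model = None

def cosine_similarity(a: str, b: str) -> float:
    global _model
    if _model is None:  # load the encoder only on first use, then reuse it
        _model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed checkpoint
    embeddings = _model.encode([a, b], convert_to_tensor=True)
    return float(util.cos_sim(embeddings[0], embeddings[1]))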
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
@@ -161,7 +161,7 @@
"filename": "src/unitxt/metrics.py",
"hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
"is_verified": false,
"line_number": 1,
"line_number": 68,
"is_secret": false
}
],
@@ -184,5 +184,5 @@
}
]
},
"generated_at": "2025-02-09T12:55:56Z"
"generated_at": "2025-02-09T13:52:43Z"
}
