Skip to content

Commit

Permalink
let github find all the splits of all LoadHF
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <[email protected]>
  • Loading branch information
dafnapension committed Feb 11, 2025
1 parent a21a9fb commit 794f039
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion src/unitxt/test_utils/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os.path
import tempfile

from datasets import get_dataset_split_names, load_dataset, load_dataset_builder

from .. import add_to_catalog, register_local_catalog
from ..artifact import fetch_artifact
from ..collections import Collection
Expand Down Expand Up @@ -232,7 +234,7 @@ def test_wrong_predictions(
logger.info("*" * 10)
logger.info(warning_message)


# flake8: noqa: C901
def test_card(
card,
debug=False,
Expand All @@ -244,6 +246,42 @@ def test_card(
full_mismatch_prediction_values=None,
**kwargs,
):
if isinstance(card.loader, LoadHF):
path = card.loader.path
name = card.loader.name
splits = None
logger.critical(f"Starting the search for splits for LoadHF of path {path} and name {name}")
try:
ds_builder = load_dataset_builder(
path, name, trust_remote_code=True
)
dataset_info = ds_builder.info
if dataset_info.splits is not None:
splits = dataset_info.splits
# split names are known before the splits themselves are pulled from HF,
# so we can postpone pulling the splits until they are actually demanded
logger.critical(f"for path {path} and name {name}, splits found by ds_builder_info: {splits}")
except Exception as e:
logger.critical(f"Exception {e} thrown by load_dataset_builder for path {path} and name {name}")
splits = None
if splits is None:
try:
splits = get_dataset_split_names(path=path, config_name=name, trust_remote_code=True)
logger.critical(f"for path {path} and name {name}, splits found by get_dataset_split_names: {splits}")
except Exception as e:
logger.critical(f"Exception {e} thrown by get_dataset_split_names for path {path} and name {name}")
splits = None
if splits is None:
try:
dataset = load_dataset(path=path, name=name, trust_remote_code=True)
splits = sorted(dataset.keys())
logger.critical(f"for path {path} and name {name}, splits found by load_dataset: {splits}")
except Exception as e:
logger.critical(f"Exception {e} thrown by load_dataset for path {path} and name {name}")
splits = None
if splits is None:
logger.critical(f"No splits found for path {path} and name {name}")

"""Tests a given card.
By default, the test goes over all templates defined in the card,
Expand Down

0 comments on commit 794f039

Please sign in to comment.