diff --git a/camel/datasets/base.py b/camel/datasets/base.py
index e085eeb462..30093cc2dc 100644
--- a/camel/datasets/base.py
+++ b/camel/datasets/base.py
@@ -12,14 +12,17 @@
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 
+import json
 import os
 import random
+from pathlib import Path
 from typing import (
     Any,
     Callable,
     Dict,
     List,
     Optional,
+    Sized,
     TypeVar,
     Union,
 )
@@ -322,47 +325,219 @@ def to_pytorch_dataset(
         return dataset
 
 
-class SeedDataset(BaseDataset):
+class SeedDataset(Dataset):
     r"""A dataset containing validated seed examples for data generation.
     Ensures that all items adhere to the DataPoint schema.
 
-    This class is used to initialize a dataset from a list of dictionary items,
-    validating each against the DataPoint schema.
+    This class can initialize from Hugging Face Datasets,
+    PyTorch Datasets, JSON file paths (str or Path), or lists of
+    dictionaries, converting them into a consistent internal format.
     """
 
     def __init__(
         self,
-        data: List[Dict[str, str]],
+        data: Union[HFDataset, Dataset, Path, str, List[Dict[str, Any]]],
         cache_dir: Optional[str] = None,
+        seed: Optional[int] = None,
         min_samples: int = 1,
+        strict: bool = False,
         **kwargs,
     ):
-        r"""Initialize the seed dataset.
+        r"""Initialize the seed dataset and validate integrity.
 
         Args:
-            data (List[Dict[str, str]]): List of dictionary items to create the
-                dataset from.
+            data (Union[HFDataset, Dataset, Path, str,
+                List[Dict[str, Any]]]): Input data, which can be:
+                - A Hugging Face Dataset (HFDataset)
+                - A PyTorch Dataset (torch.utils.data.Dataset)
+                - A Path object or string path to a JSON file
+                - A list of dictionaries with DataPoint-compatible fields
             cache_dir (Optional[str]): Directory to cache dataset files.
                 (default: :obj:`None`)
+            seed (Optional[int]): Seed for reproducible sampling.
+                (default: :obj:`None`)
             min_samples (int): Minimum number of samples required.
                 (default: :obj:`1`)
+            strict (bool): Whether to raise an error on invalid datapoints
+                (True) or skip/filter them (False).
+                (default: :obj:`False`)
             **kwargs: Additional dataset parameters.
 
         Raises:
-            ValueError: If dataset size is less than min_samples or if sample
-                validation fails.
+            TypeError: If the data type is not supported.
+            ValueError: If dataset size is less than min_samples or
+                if sample validation fails.
+            FileNotFoundError: If the JSON file path doesn't exist.
+            json.JSONDecodeError: If the JSON file is invalid.
         """
-        if len(data) < min_samples:
+        # Store all parameters in a metadata dict for compatibility
+        self._cache_dir = str(cache_dir) if cache_dir is not None else None
+        self._metadata = {
+            'cache_dir': self._cache_dir,
+            **kwargs,
+        }
+        self._rng = random.Random(seed)
+        self._strict = strict
+
+        # Type checking and conversion into a list of dicts to have a
+        # consistent internal format. Since a seed dataset should be
+        # small, we can load it entirely into memory.
+        if isinstance(data, HFDataset):
+            self._raw_data = [dict(item) for item in data]
+        elif isinstance(data, Dataset):
+            if not isinstance(data, Sized):
+                raise TypeError(
+                    f"{type(data).__name__} does not implement `__len__()`."
+                )
+            # An explicit check (rather than an assert, which is stripped
+            # under `python -O`) to ensure indexability.
+            if not callable(getattr(data, "__getitem__", None)):
+                raise TypeError(
+                    f"{type(data).__name__} does not support indexing."
+                )
+            self._raw_data = [data[i] for i in range(len(data))]
+        elif isinstance(data, (Path, str)):
+            path = Path(data)
+            if not path.exists():
+                raise FileNotFoundError(f"JSON file not found: {path}")
+            with path.open('r') as f:
+                self._raw_data = json.load(f)
+            if not isinstance(self._raw_data, list):
+                raise ValueError(
+                    "JSON file must contain a list of dictionaries"
+                )
+        elif isinstance(data, list):
+            self._raw_data = data
+        else:
+            raise TypeError(f"Unsupported data type: {type(data)}")
+
+        # Every raw item must be a dict before DataPoint validation.
+        for i, item in enumerate(self._raw_data):
+            if not isinstance(item, dict):
+                raise TypeError(
+                    f"Unsupported data type at index {i}: "
+                    f"expected dict, got {type(item).__name__}"
+                )
+
+        self.data: List[DataPoint] = []
+        self._setup(min_samples)
+        self._length = len(self.data)
+
+    def __len__(self) -> int:
+        r"""Return the size of the dataset."""
+        return self._length
+
+    def __getitem__(self, idx: int) -> DataPoint:
+        r"""Get an item from the dataset.
+
+        Args:
+            idx (int): Index of the item to get.
+
+        Returns:
+            DataPoint: DataPoint from the dataset with the given index.
+
+        Raises:
+            IndexError: If idx is out of bounds.
+        """
+        if idx < 0 or idx >= self._length:
+            raise IndexError(
+                f"Index {idx} out of bounds for dataset of size "
+                f"{self._length}"
+            )
+        return self.data[idx]
+
+    def sample(self) -> DataPoint:
+        r"""Sample a random datapoint from the dataset.
+
+        Returns:
+            DataPoint: A randomly sampled DataPoint.
+
+        Raises:
+            RuntimeError: If the dataset is empty.
+        """
+        if self._length == 0:
+            raise RuntimeError("Dataset is empty, cannot sample.")
+        idx = self._rng.randint(0, self._length - 1)
+        return self[idx]
+
+    def _setup(self, min_samples: int) -> None:
+        r"""Set up the dataset by validating and processing raw data.
+
+        This method:
+        1. Checks if the dataset meets the minimum sample requirement.
+        2. Creates the cache directory if specified.
+        3. Processes raw data into DataPoint objects
+           for validation and consistency.
+
+        In non-strict mode, invalid datapoints are filtered out
+        rather than raising an error.
+
+        Args:
+            min_samples (int): Minimum number of samples required.
+
+        Raises:
+            ValueError: If the dataset size is less than min_samples,
+                if sample validation fails (in strict mode), or if the
+                dataset is smaller than min_samples after filtering
+                invalid datapoints (in non-strict mode).
+            OSError: If cache directory creation fails.
+        """
+        if len(self._raw_data) < min_samples:
             raise ValueError(
-                f"Seed dataset must contain at least {min_samples} samples."
+ f"Dataset must have at least {min_samples} samples, " + f"got {len(self._raw_data)}" ) - super().__init__( - data=data, - cache_dir=cache_dir, - **kwargs, + if self._cache_dir: + try: + os.makedirs(self._cache_dir, exist_ok=True) + logger.debug(f"Created cache directory: {self._cache_dir}") + except OSError as e: + logger.error( + f"Failed to create cache directory {self._cache_dir}: {e}" + ) + raise + + # Process raw data into DataPoint objects for validation purposes + if not self._raw_data: + if min_samples > 0: + raise ValueError("No data provided, but min_samples > 0") + logger.debug("No raw data to process") + return + + def create_datapoint( + item: Dict[str, Any], idx: int + ) -> Optional[DataPoint]: + try: + return DataPoint( + question=item.get('question', ''), + rationale=item.get('rationale', ''), + final_answer=item.get('final_answer', ''), + metadata=item.get('metadata', {}) + if isinstance(item.get('metadata'), dict) + else {}, + difficulty=item.get('difficulty', ''), # Match BaseDataset + # raw_markdown='' if DataPoint supports it + ) + + except ValidationError as e: + if self._strict: + raise ValueError( + f"Sample at index {idx} validation error: {e}" + ) + else: + logger.warning( + f"Skipping invalid sample at index {idx} " + f"due to validation error: {e}" + ) + return None + + raw_data = [ + create_datapoint(item, i) for i, item in enumerate(self._raw_data) + ] + self.data = [dp for dp in raw_data if dp is not None] + logger.debug( + f"Processed {len(raw_data)} data points, of which " + f"{len(self.data)} were valid." ) + @property + def metadata(self) -> Dict[str, Any]: + r"""Get dataset metadata.""" + return self._metadata.copy() + class SyntheticDataset(BaseDataset): r"""A dataset for storing synthetically generated data points. diff --git a/test/datasets/test_base_dataset.py b/test/datasets/test_base_dataset.py index 83db9ad60a..0b38635d84 100644 --- a/test/datasets/test_base_dataset.py +++ b/test/datasets/test_base_dataset.py @@ -12,6 +12,7 @@ # limitations under the License. # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +import json import tempfile from unittest.mock import AsyncMock, MagicMock @@ -19,6 +20,7 @@ import torch from datasets import Dataset as HFDataset from pydantic import ValidationError +from torch.utils.data import Dataset from camel.datasets.base import ( BaseDataset, @@ -192,12 +194,465 @@ def test_base_dataset_metadata(): def test_seed_dataset_init(sample_data): - r"""Test SeedDataset initialization.""" + r"""Test SeedDataset initialization with valid input data.""" dataset = SeedDataset(data=sample_data, min_samples=1) - assert dataset._raw_data == sample_data - - with pytest.raises(ValueError): + assert dataset._raw_data == sample_data, "Raw data should match input list" + assert len(dataset.data) == 2, "Processed data should have 2 items" + assert isinstance( + dataset.data[0], DataPoint + ), "Items should be DataPoint instances" + assert ( + dataset.data[0].question == 'What is 2+2?' 
+ ), "DataPoint content should match input" + with pytest.raises(ValueError) as exc_info: SeedDataset(data=sample_data, min_samples=3) + assert "must have at least 3 samples" in str( + exc_info.value + ), "Should raise ValueError for insufficient samples" + + # Test with an empty dataset when min_samples is 0 + dataset_empty = SeedDataset(data=[], min_samples=0) + assert len(dataset_empty.data) == 0, "Empty dataset should have no items" + + +def test_seed_dataset_strict_mode(): + r"""Test SeedDataset in strict mode where + invalid datapoints raise errors.""" + invalid_data = [ + { + "question": "Incomplete sample", + "rationale": "Some reasoning", + } # Missing 'final_answer' + ] + with pytest.raises(ValueError) as exc_info: + # strict=True should raise an error on the first invalid datapoint + SeedDataset(data=invalid_data, min_samples=1, strict=True) + assert "validation error" in str( + exc_info.value + ), "Strict mode should raise ValueError for invalid datapoint" + + +def test_seed_dataset_non_strict_mode(): + r"""Test SeedDataset in non-strict mode where + invalid datapoints are skipped.""" + + invalid_data = [ + {"question": "Incomplete sample", "rationale": "Some reasoning"} + ] + # strict=False should filter out invalid samples + dataset = SeedDataset(data=invalid_data, min_samples=0, strict=False) + # Expect that the invalid sample is skipped, so + # dataset.data should be empty + assert ( + len(dataset.data) == 0 + ), "Non-strict mode should filter out invalid samples" + + +def test_seed_dataset_init_hf_dataset(): + r"""Test SeedDataset initialization with a mock + IMDB-style Hugging Face Dataset.""" + # Mock IMDB-style data + mock_imdb_data = [ + { + "text": "This movie was absolutely fantastic, " + "a real joy to watch!", + "label": 1, + "rationale": "The reviewer uses positive adjectives like " + "'fantastic' and 'joy'.", + }, + { + "text": "Terrible acting and a boring plot ruined this film.", + "label": 0, + "rationale": "Negative terms like 'terrible' and 'boring' " + "suggest dissatisfaction.", + }, + { + "text": "An incredible cast made this a thrilling experience.", + "label": 1, + "rationale": "Words like 'incredible' and 'thrilling' reflect " + "a positive reaction.", + }, + ] + + hf_dataset = HFDataset.from_list(mock_imdb_data) + + mapped_dataset = hf_dataset.map( + lambda example: { + "question": "What is the sentiment of this review? " + f"{example['text'][:30]}...", + "rationale": example["rationale"], + "final_answer": "positive" + if example["label"] == 1 + else "negative", + } + ) + + # Valid data + dataset = SeedDataset(data=mapped_dataset, min_samples=1, strict=True) + assert len(dataset.data) == 3, "There should be 3 valid data points." + assert isinstance( + dataset.data[0], DataPoint + ), "Items should be DataPoint instances." + assert ( + dataset.data[0].question == mapped_dataset[0]["question"] + ), "Question should match input." + assert ( + dataset.data[0].rationale == mapped_dataset[0]["rationale"] + ), "Rationale should match input." + assert ( + dataset.data[0].final_answer == mapped_dataset[0]["final_answer"] + ), "Final answer should match input." + + # Invalid data + invalid_data_missing = [ + { + "question": "What is the sentiment of this review? 
" + "Missing rationale...", + "final_answer": "positive", + # Missing "rationale" + } + ] + hf_invalid_missing = HFDataset.from_list(invalid_data_missing) + with pytest.raises(ValueError, match="Sample at index 0 validation error"): + SeedDataset(data=hf_invalid_missing, min_samples=1, strict=True) + + empty_data = [] + hf_empty = HFDataset.from_list(empty_data) + with pytest.raises( + ValueError, match="Dataset must have at least 1 samples, got 0" + ): + SeedDataset(data=hf_empty, min_samples=1, strict=True) + + dataset_empty = SeedDataset(data=hf_empty, min_samples=0, strict=True) + assert ( + len(dataset_empty.data) == 0 + ), "Empty dataset should have no valid items." + + non_dict_data = [ + "Not a dictionary", + { + "question": "Valid question", + "rationale": "Valid rationale", + "final_answer": "positive", + }, + ] + with pytest.raises(TypeError, match="Unsupported data type"): + SeedDataset(data=non_dict_data, min_samples=1, strict=True) + + data_with_optional = [ + { + "question": "What is the sentiment of this review? " + "This movie was awesome!...", + "rationale": "Positive sentiment detected.", + "final_answer": "positive", + "difficulty": "medium", + "metadata": {"source": "imdb"}, + } + ] + hf_optional = HFDataset.from_list(data_with_optional) + dataset_optional = SeedDataset( + data=hf_optional, min_samples=1, strict=True + ) + assert ( + dataset_optional.data[0].difficulty == "medium" + ), "Difficulty field should be 'medium'." + assert dataset_optional.data[0].metadata == { + "source": "imdb" + }, "Metadata should match input." + + +def test_seed_dataset_init_pytorch_dataset(): + r"""Test SeedDataset initialization with a + mock IMDB-style PyTorch Dataset.""" + + # Define a reusable PyTorch Dataset class + class MockIMDBDataset(Dataset): + def __init__(self, data_list): + self.data = data_list + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + valid_data = [ + { + "text": "This movie was absolutely fantastic, " + "a real joy to watch!", + "label": 1, + "rationale": "The reviewer uses positive adjectives like " + "'fantastic' and 'joy'.", + }, + { + "text": "Terrible acting and a boring plot ruined this film.", + "label": 0, + "rationale": "Negative terms like 'terrible' and 'boring' " + "suggest dissatisfaction.", + }, + { + "text": "An incredible cast made this a thrilling experience.", + "label": 1, + "rationale": "Words like 'incredible' and 'thrilling' " + "reflect a positive reaction.", + }, + ] + + mapped_data = [ + { + "question": "What is the sentiment of this review? " + f"{item['text'][:30]}...", + "rationale": item["rationale"], + "final_answer": "positive" if item["label"] == 1 else "negative", + } + for item in valid_data + ] + + pytorch_dataset = MockIMDBDataset(mapped_data) + dataset = SeedDataset(data=pytorch_dataset, min_samples=1) + assert len(dataset.data) == 3 + assert isinstance(dataset.data[0], DataPoint) + assert dataset.data[0].question == mapped_data[0]["question"] + assert dataset.data[0].rationale == mapped_data[0]["rationale"] + assert dataset.data[0].final_answer == mapped_data[0]["final_answer"] + + invalid_data_missing = [ + { + "question": "What is the sentiment of this review? 
" + "Missing rationale...", + "final_answer": "positive", + # Missing "rationale" + } + ] + pytorch_invalid_missing = MockIMDBDataset(invalid_data_missing) + with pytest.raises(ValueError, match="Sample at index 0 validation error"): + SeedDataset(data=pytorch_invalid_missing, min_samples=1) + + empty_data = [] + pytorch_empty = MockIMDBDataset(empty_data) + with pytest.raises( + ValueError, match="Dataset must have at least 1 samples, got 0" + ): + SeedDataset(data=pytorch_empty, min_samples=1) + + dataset_empty = SeedDataset(data=pytorch_empty, min_samples=0) + assert len(dataset_empty.data) == 0 + + non_dict_data = [ + "Not a dictionary", + { + "question": "Valid question", + "rationale": "Valid rationale", + "final_answer": "positive", + }, + ] + pytorch_non_dict = MockIMDBDataset(non_dict_data) + with pytest.raises(TypeError, match="Unsupported data type"): + SeedDataset(data=pytorch_non_dict, min_samples=1) + + data_with_optional = [ + { + "question": "What is the sentiment of this review? " + "This movie was awesome!...", + "rationale": "Positive sentiment detected.", + "final_answer": "positive", + "difficulty": "medium", + "metadata": {"source": "imdb"}, + } + ] + pytorch_optional = MockIMDBDataset(data_with_optional) + dataset_optional = SeedDataset(data=pytorch_optional, min_samples=1) + assert dataset_optional.data[0].difficulty == "medium" + assert dataset_optional.data[0].metadata == {"source": "imdb"} + + +def test_seed_dataset_init_list_extended(sample_data): + r"""Test SeedDataset initialization with a list of dictionaries.""" + + data_with_optional = [ + *sample_data, + { + "question": "What is 5-3?", + "rationale": "Subtraction", + "final_answer": "2", + "difficulty": "easy", # Optional field + "metadata": {"topic": "math"}, # Optional field + }, + ] + dataset = SeedDataset(data=data_with_optional, min_samples=1) + assert len(dataset.data) == 3, "Dataset should contain 3 items" + assert ( + dataset.data[2].difficulty == "easy" + ), "Optional difficulty field should be preserved" + assert dataset.data[2].metadata == { + "topic": "math" + }, "Optional metadata field should be preserved" + assert ( + dataset.data[0].question == sample_data[0]["question"] + ), "First item question should match" + assert ( + dataset.data[1].final_answer == sample_data[1]["final_answer"] + ), "Second item final_answer should match" + + invalid_data_missing = [ + {"question": "What is 2+2?", "rationale": "Addition"} + ] + with pytest.raises(ValueError, match="Sample at index 0 validation error"): + SeedDataset(data=invalid_data_missing, min_samples=1) + + invalid_data_type = [ + { + "question": "What is 3+3?", + "rationale": "Addition", + "final_answer": 6, + } + ] + with pytest.raises(ValueError, match="Sample at index 0 validation error"): + SeedDataset(data=invalid_data_type, min_samples=1) + + empty_data = [] + with pytest.raises( + ValueError, match="Dataset must have at least 1 samples, got 0" + ): + SeedDataset(data=empty_data, min_samples=1) + + dataset_empty = SeedDataset(data=empty_data, min_samples=0) + assert ( + len(dataset_empty.data) == 0 + ), "Empty dataset with min_samples=0 should have no items" + + non_dict_data = [ + "Not a dictionary", + { + "question": "What is 4+4?", + "rationale": "Addition", + "final_answer": "8", + }, + ] + with pytest.raises(TypeError, match="Unsupported data type"): + SeedDataset(data=non_dict_data, min_samples=1) + + mixed_data = [ + { + "question": "What is 1+1?", + "rationale": "Addition", + "final_answer": "2", + }, + {"question": "What is 
2+2?"}, + ] + with pytest.raises(ValueError, match="Sample at index 1 validation error"): + SeedDataset(data=mixed_data, min_samples=1) + + +def test_seed_dataset_init_json_file(): + r"""Test SeedDataset initialization with a JSON file path.""" + + sample_data = [ + { + "question": "What is 2+2?", + "rationale": "Addition", + "final_answer": "4", + }, + { + "question": "What is 3×3?", + "rationale": "Multiplication", + "final_answer": "9", + }, + ] + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as temp_file: + json.dump(sample_data, temp_file) + temp_file.flush() + dataset = SeedDataset(data=temp_file.name, min_samples=1) + assert len(dataset.data) == 2, "Should have 2 items from the JSON file" + assert isinstance( + dataset.data[0], DataPoint + ), "Items should be DataPoint instances" + assert ( + dataset.data[0].question == "What is 2+2?" + ), "Question should match the JSON data" + assert ( + dataset.data[1].final_answer == "9" + ), "Final answer should match the JSON data" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as invalid_file: + invalid_file.write("Invalid JSON") + invalid_file.flush() + with pytest.raises(json.JSONDecodeError): + SeedDataset(data=invalid_file.name, min_samples=1) + + invalid_data_missing = [ + { + "question": "What is 2+2?", + "rationale": "Addition", # Missing "final_answer" + } + ] + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as temp_file: + json.dump(invalid_data_missing, temp_file) + temp_file.flush() + with pytest.raises( + ValueError, match="Sample at index 0 validation error" + ): + SeedDataset(data=temp_file.name, min_samples=1) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as temp_file: + json.dump([], temp_file) + temp_file.flush() + with pytest.raises( + ValueError, match="Dataset must have at least 1 samples, got 0" + ): + SeedDataset(data=temp_file.name, min_samples=1) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as temp_file: + json.dump([], temp_file) + temp_file.flush() + dataset_empty = SeedDataset(data=temp_file.name, min_samples=0) + assert ( + len(dataset_empty.data) == 0 + ), "Empty dataset with min_samples=0 should have no items" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as temp_file: + json.dump({"not": "a list"}, temp_file) + temp_file.flush() + with pytest.raises( + ValueError, match="JSON file must contain a list of dictionaries" + ): + SeedDataset(data=temp_file.name, min_samples=1) + + data_with_optional = [ + { + "question": "What is 5-3?", + "rationale": "Subtraction", + "final_answer": "2", + "difficulty": "easy", + "metadata": {"topic": "math"}, + } + ] + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as temp_file: + json.dump(data_with_optional, temp_file) + temp_file.flush() + dataset_optional = SeedDataset(data=temp_file.name, min_samples=1) + assert ( + dataset_optional.data[0].difficulty == "easy" + ), "Optional difficulty field should be preserved" + assert dataset_optional.data[0].metadata == { + "topic": "math" + }, "Optional metadata field should be preserved" + + data_with_extra = [ + { + "question": "What is 4+4?", + "rationale": "Addition", + "final_answer": "8", + "extra_field": "should be ignored", + } + ] + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as temp_file: + json.dump(data_with_extra, temp_file) + temp_file.flush() + dataset_extra = SeedDataset(data=temp_file.name, min_samples=1) + assert ( + "extra_field" not in dataset_extra.data[0].__dict__ + ), "Extra fields should be 
ignored" def test_synthetic_dataset_init():