Refactor Seed Dataset to improve compatibility and simplify usage #1734

Open · wants to merge 17 commits into base: master
205 changes: 190 additions & 15 deletions camel/datasets/base.py
@@ -12,14 +12,17 @@
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import json
import os
import random
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sized,
TypeVar,
Union,
)
@@ -322,47 +325,219 @@ def to_pytorch_dataset(
return dataset


-class SeedDataset(BaseDataset):
class SeedDataset(Dataset):
r"""A dataset containing validated seed examples for data generation.
Ensures that all items adhere to the DataPoint schema.

-This class is used to initialize a dataset from a list of dictionary items,
-validating each against the DataPoint schema.
This class can initialize from Hugging Face Datasets,
PyTorch Datasets, JSON file paths, or lists of dictionaries,
converting them into a consistent internal format.
"""

def __init__(
Collaborator review comment: This __init__ is getting really messy. It would be better to create specific methods to handle the conversion.
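One possible split along these lines (a sketch only: the helper names _to_raw_records, _from_pytorch_dataset, and _from_json_path are illustrative and not existing code; HFDataset, Dataset, Path, and the typing names are assumed to be the module's existing imports shown above) would have __init__ delegate to one converter per input type:

def _to_raw_records(
    self,
    data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
) -> List[Dict[str, Any]]:
    # Dispatch to one small converter per supported input type, so that
    # __init__ reduces to `self._raw_data = self._to_raw_records(data)`.
    if isinstance(data, HFDataset):
        return [dict(item) for item in data]
    if isinstance(data, Dataset):
        # Hypothetical helper; name is illustrative only.
        return self._from_pytorch_dataset(data)
    if isinstance(data, Path):
        # Hypothetical helper; name is illustrative only.
        return self._from_json_path(data)
    if isinstance(data, list):
        return list(data)
    raise TypeError(f"Unsupported data type: {type(data).__name__}")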

self,
-data: List[Dict[str, str]],
data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
cache_dir: Optional[str] = None,
seed: Optional[int] = None,
min_samples: int = 1,
strict: bool = False,
**kwargs,
):
r"""Initialize the seed dataset.
r"""Initialize the seed dataset and validate integrity.

Args:
-data (List[Dict[str, str]]): List of dictionary items to create the
-dataset from.
data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]):
Input data, which can be:
- A Hugging Face Dataset (HFDataset)
- A PyTorch Dataset (torch.utils.data.Dataset)
- A Path object representing the path to a JSON file
- A list of dictionaries with DataPoint-compatible fields
cache_dir (Optional[str]): Directory to cache dataset files.
(default: :obj:`None`)
seed (Optional[int]): Seed for reproducibility.
(default: :obj:`None`)
min_samples (int): Minimum number of samples required.
(default: :obj:`1`)
strict (bool): Whether to raise an error on invalid datapoints
(True) or skip/filter them (False). (default: False)
**kwargs: Additional dataset parameters.

Raises:
-ValueError: If dataset size is less than min_samples or if sample
-validation fails.
TypeError: If the data type is not supported.
ValueError: If dataset size is less than min_samples or
if sample validation fails.
FileNotFoundError: If the JSON file path doesn't exist.
json.JSONDecodeError: If the JSON file is invalid.
"""
-if len(data) < min_samples:

# Store all parameters in metadata dict for compatibility
self._cache_dir = str(cache_dir) if cache_dir is not None else None
self._metadata = {
'cache_dir': self._cache_dir,
**kwargs,
}
self._rng = random.Random(seed)
self._strict = strict

# Type checking and conversion into list of dicts to have a
# consistent internal format. Since Seed Dataset should be
# small, we can load it entirely into memory

if isinstance(data, HFDataset):
self._raw_data = [dict(item) for item in data]
elif isinstance(data, Dataset):
if not isinstance(data, Sized):
raise TypeError(
f"{type(data).__name__} does not implement `__len__()`."
)

# Make MyPy happy by ensuring indexability
assert callable(
getattr(data, "__getitem__", None)
), "Dataset does not support indexing."

self._raw_data = [dict(data[i]) for i in range(len(data))]
elif isinstance(data, Path):
Collaborator review comment: Let's add some more safety features here. The current configuration does not:

  • check if the list contains only dictionaries
  • check if the file is valid JSON before attempting to load
  • handle potential encoding issues
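One way to cover all three points (a sketch; the helper name _from_json_path and the exact error messages are illustrative, and it relies only on the json and pathlib modules the file already imports) is to read with an explicit encoding and validate the decoded structure before accepting it:

def _from_json_path(self, path: Path) -> List[Dict[str, Any]]:
    # Illustrative helper; not part of the current PR.
    if not path.exists():
        raise FileNotFoundError(f"JSON file not found: {path}")
    try:
        # Read with an explicit encoding so platform defaults cannot
        # silently mangle the file contents.
        with path.open("r", encoding="utf-8") as f:
            raw = json.load(f)
    except UnicodeDecodeError as e:
        raise ValueError(f"File {path} is not valid UTF-8: {e}") from e
    except json.JSONDecodeError as e:
        raise ValueError(f"File {path} is not valid JSON: {e}") from e
    if not isinstance(raw, list):
        raise ValueError("JSON file must contain a list of dictionaries")
    if not all(isinstance(item, dict) for item in raw):
        raise ValueError("Every entry in the JSON file must be a dictionary")
    return raw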

if not data.exists():
raise FileNotFoundError(f"JSON file not found: {data}")
with data.open('r') as f:
self._raw_data = json.load(f)
if not isinstance(self._raw_data, list):
raise ValueError(
"JSON file must contain a list of dictionaries"
)
elif isinstance(data, list):
self._raw_data = data if data is not None else []
else:
raise TypeError("Unsupported data type")

self.data: List[DataPoint] = []
self._setup(min_samples)
self._length = len(self.data)

def __len__(self) -> int:
r"""Return the size of the dataset."""
return self._length

def __getitem__(self, idx: int) -> DataPoint:
r"""Get an item from the dataset.

Args:
idx (int): Index of the item to get.

Returns:
DataPoint: DataPoint from the dataset with the given index.

Raises:
IndexError: If idx is out of bounds.
"""
if idx < 0 or idx >= self._length:
raise IndexError(
f"Index {idx} out of bounds for dataset of size {self._length}"
)
return self.data[idx]

def sample(self) -> DataPoint:
r"""Sample a random datapoint from the dataset.

Returns:
DataPoint: A randomly sampled DataPoint.

Raises:
RuntimeError: If the dataset is empty.
"""
if self._length == 0:
raise RuntimeError("Dataset is empty, cannot sample.")
idx = self._rng.randint(0, self._length - 1)
return self[idx]

def _setup(self, min_samples: int) -> None:
r"""Set up the dataset by validating and processing raw data.

This method:
1. Checks if the dataset meets the minimum sample requirement.
2. Creates the cache directory if specified.
3. Processes raw data into DataPoint objects
for validation and consistency.

In non-strict mode, invalid datapoints are filtered out
rather than raising an error.

Args:
min_samples (int): Minimum number of samples required.

Raises:
ValueError: If the dataset size is less than min_samples or
if sample validation fails (in strict mode),
or if the dataset size is smaller than
min_samples after filtering invalid datapoints
(in non-strict mode).
OSError: If cache directory creation fails.
"""
if len(self._raw_data) < min_samples:
raise ValueError(
f"Seed dataset must contain at least {min_samples} samples."
f"Dataset must have at least {min_samples} samples, "
f"got {len(self._raw_data)}"
)

-super().__init__(
-data=data,
-cache_dir=cache_dir,
-**kwargs,
if self._cache_dir:
try:
os.makedirs(self._cache_dir, exist_ok=True)
logger.debug(f"Created cache directory: {self._cache_dir}")
except OSError as e:
logger.error(
f"Failed to create cache directory {self._cache_dir}: {e}"
)
raise

# Process raw data into DataPoint objects for validation purposes
if not self._raw_data:
if min_samples > 0:
raise ValueError("No data provided, but min_samples > 0")
logger.debug("No raw data to process")
return

def create_datapoint(
Collaborator review comment: This method is supposed to validate whether a dict is a valid DataPoint. Providing default values for non-existent keys defeats the whole purpose. Note that it should not fail if there are extra fields.
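A leaner version that leaves the decision entirely to the model (a sketch mirroring the nested helper below, so self, logger, DataPoint, and ValidationError come from the surrounding code; it assumes DataPoint is a pydantic model whose configuration tolerates unknown fields, as pydantic v2 does by default) could look like this:

def create_datapoint(
    item: Dict[str, Any], idx: int
) -> Optional[DataPoint]:
    try:
        # Pass the dict through unchanged: missing required fields fail
        # validation instead of being papered over with defaults, and
        # extra keys are ignored by the model (assumed pydantic config).
        return DataPoint(**item)
    except ValidationError as e:
        if self._strict:
            raise ValueError(
                f"Sample at index {idx} validation error: {e}"
            ) from e
        logger.warning(
            f"Skipping invalid sample at index {idx} "
            f"due to validation error: {e}"
        )
        return None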

item: Dict[str, Any], idx: int
) -> Optional[DataPoint]:
try:
return DataPoint(
question=item.get('question', ''),
rationale=item.get('rationale', ''),
final_answer=item.get('final_answer', ''),
metadata=item.get('metadata', {})
if isinstance(item.get('metadata'), dict)
else {},
difficulty=item.get('difficulty', ''), # Match BaseDataset
# raw_markdown='' if DataPoint supports it
)

except ValidationError as e:
if self._strict:
raise ValueError(
f"Sample at index {idx} validation error: {e}"
)
else:
logger.warning(
f"Skipping invalid sample at index {idx} "
f"due to validation error: {e}"
)
return None

raw_data = [
create_datapoint(item, i) for i, item in enumerate(self._raw_data)
]
self.data = [dp for dp in raw_data if dp is not None]
logger.debug(
f"Processed {len(raw_data)} data points, of which "
f"{len(self.data)} were valid."
)

@property
def metadata(self) -> Dict[str, Any]:
r"""Get dataset metadata."""
return self._metadata.copy()
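For reference, a hypothetical usage of the refactored class (the file name seed_examples.json and the field values are illustrative; the import path assumes SeedDataset remains in camel/datasets/base.py):

from pathlib import Path

from camel.datasets.base import SeedDataset

# Build a seed dataset from an in-memory list of dictionaries.
seed = SeedDataset(
    data=[
        {
            "question": "What is 2 + 2?",
            "rationale": "Add the two numbers.",
            "final_answer": "4",
        }
    ],
    min_samples=1,
    seed=42,
)
point = seed.sample()  # returns a validated DataPoint

# Or build it from a JSON file containing a list of dictionaries.
seed_from_file = SeedDataset(data=Path("seed_examples.json"), strict=True)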


class SyntheticDataset(BaseDataset):
r"""A dataset for storing synthetically generated data points.