Refactor Seed Dataset to improve compatibility and simplify usage #1734

Open · wants to merge 17 commits into master

Changes from 4 commits
183 changes: 167 additions & 16 deletions camel/datasets/base.py
@@ -12,6 +12,7 @@
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import json
import os
import random
from typing import (
@@ -322,46 +323,196 @@ def to_pytorch_dataset(
return dataset


class SeedDataset(BaseDataset):
class SeedDataset(Dataset):
r"""A dataset containing validated seed examples for data generation.
Ensures that all items adhere to the DataPoint schema.

This class is used to initialize a dataset from a list of dictionary items,
validating each against the DataPoint schema.
This class can initialize from Hugging Face Datasets,
PyTorch Datasets, JSON file paths, or lists of dictionaries,
converting them into a consistent internal format.
"""

def __init__(
Collaborator commented:
This __init__ is getting really messy. It would be better to create specific methods to handle the conversion.

self,
data: List[Dict[str, str]],
data: Union[HFDataset, Dataset, str, List[Dict[str, Any]]],
cache_dir: Optional[str] = None,
min_samples: int = 1,
**kwargs,
):
r"""Initialize the seed dataset.
r"""Initialize the seed dataset and validate integrity.

Args:
data (List[Dict[str, str]]): List of dictionary items to create the
dataset from.
data (Union[HFDataset, Dataset, str, List[Dict[str, Any]]]):
Input data, which can be:
- A Hugging Face Dataset (HFDataset)
- A PyTorch Dataset (torch.utils.data.Dataset)
- A string path to a JSON file
- A list of dictionaries with DataPoint-compatible fields
cache_dir (Optional[str]): Directory to cache dataset files.
(default: :obj:`None`)
min_samples (int): Minimum number of samples required.
(default: :obj:`1`)
**kwargs: Additional dataset parameters.

Raises:
ValueError: If dataset size is less than min_samples or if sample
validation fails.
TypeError: If the data type is not supported.
ValueError: If dataset size is less than min_samples or
if sample validation fails.
FileNotFoundError: If the JSON file path doesn't exist.
json.JSONDecodeError: If the JSON file is invalid.
"""
if len(data) < min_samples:

# Store all parameters in metadata dict for compatibility
self._cache_dir = str(cache_dir) if cache_dir is not None else None
self._metadata = {
'cache_dir': self._cache_dir,
**kwargs,
}

# Type checking and conversion into list of dicts

if isinstance(data, HFDataset):
self._raw_data = [dict(item) for item in data]
elif isinstance(data, Dataset):
try:
self._raw_data = [dict(data[i]) for i in range(len(data))]
except (TypeError, KeyError, AttributeError) as e:
raise TypeError(f"Unsupported PyTorch Dataset: {e}")
elif isinstance(data, str):
if not os.path.exists(data):
raise FileNotFoundError(f"JSON file not found: {data}")
with open(data, 'r') as f:
self._raw_data = json.load(f)
if not isinstance(self._raw_data, list):
raise ValueError(
"JSON file must contain a list of dictionaries"
)
elif isinstance(data, list):
self._raw_data = data if data is not None else []
else:
raise TypeError("Unsupported data type")

self.data: List[DataPoint] = []
self._setup(min_samples)

def sample(self) -> DataPoint:
r"""Sample a random datapoint from the dataset.

Returns:
DataPoint: A randomly sampled DataPoint.

Raises:
RuntimeError: If the dataset is empty.
"""
if not self.data:
raise RuntimeError("Dataset is empty, cannot sample.")
idx = random.randint(0, len(self) - 1)
return self[idx]

def _setup(self, min_samples: int) -> None:
r"""Set up the dataset by validating and processing raw data.

This method:
1. Checks if the dataset meets the minimum sample requirement.
2. Creates the cache directory if specified.
3. Processes raw data into DataPoint objects
for validation and consistency.

Args:
min_samples (int): Minimum number of samples required.

Raises:
ValueError: If the dataset size is less than
min_samples or if validation fails.
OSError: If cache directory creation fails.
"""
if len(self._raw_data) < min_samples:
raise ValueError(
f"Seed dataset must contain at least {min_samples} samples."
f"Dataset must have at least {min_samples} samples, "
f"got {len(self._raw_data)}"
)

super().__init__(
data=data,
cache_dir=cache_dir,
**kwargs,
)
if self._cache_dir:
try:
os.makedirs(self._cache_dir, exist_ok=True)
logger.debug(f"Created cache directory: {self._cache_dir}")
except OSError as e:
logger.error(
f"Failed to create cache directory {self._cache_dir}: {e}"
)
raise

# Process raw data into DataPoint objects for validation purposes
if not self._raw_data:
if min_samples > 0:
raise ValueError("No data provided, but min_samples > 0")
logger.debug("No raw data to process")
return

def create_datapoint(item: Dict[str, Any], idx: int) -> DataPoint:
try:
return DataPoint(
question=item.get('question', ''),
rationale=item.get('rationale', ''),
final_answer=item.get('final_answer', ''),
metadata=item.get('metadata', {})
if isinstance(item.get('metadata'), dict)
else {},
difficulty=item.get('difficulty', ''), # Match BaseDataset
# raw_markdown='' if DataPoint supports it
)
except ValidationError as e:
raise ValueError(
f"Sample at index {idx} failed validation: {e}"
) from e

self.data = [
create_datapoint(item, i) for i, item in enumerate(self._raw_data)
]
logger.debug(f"Processed {len(self.data)} data points")

def __len__(self) -> int:
r"""Return the size of the dataset."""
return len(self.data)

def __getitem__(self, idx: int) -> DataPoint:
r"""Get an item from the dataset.

Args:
idx (int): Index of the item to get.

Returns:
DataPoint: DataPoint from the dataset with the given index.

Raises:
IndexError: If idx is out of bounds.
"""
if idx < 0 or idx >= len(self):
raise IndexError(
f"Index {idx} out of bounds for dataset of size {len(self)}"
)
return self.data[idx]

@property
def metadata(self) -> Dict[str, Any]:
r"""Get dataset metadata."""
return self._metadata.copy()


class SyntheticDataset(BaseDataset):
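
For orientation, here is a minimal usage sketch of the refactored SeedDataset from the diff above. The import path and the DataPoint fields (question, rationale, final_answer, metadata) come from the diff; the sample records, the JSON file name, and the cache directory are hypothetical.

    # Usage sketch only; assumes the SeedDataset constructor and DataPoint
    # fields shown in the diff above. Record contents and paths are made up.
    from camel.datasets.base import SeedDataset

    seed_items = [
        {
            "question": "What is 2 + 2?",
            "rationale": "Add the two integers.",
            "final_answer": "4",
        },
    ]

    # From a list of dictionaries.
    dataset = SeedDataset(data=seed_items, min_samples=1)
    print(len(dataset))               # 1
    point = dataset.sample()          # random DataPoint
    print(point.question, point.final_answer)

    # From a JSON file containing a list of such dictionaries
    # (hypothetical path); cache_dir is created by _setup() when given.
    # dataset = SeedDataset(data="seed_data.json", cache_dir=".cache/seed")

    # Fewer samples than min_samples is rejected by _setup().
    try:
        SeedDataset(data=[], min_samples=1)
    except ValueError as err:
        print(f"Rejected: {err}")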
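As a possible answer to the collaborator's note that the type handling in __init__ is getting messy, the conversion could be pulled into small per-type helpers. This is only a sketch against the branches shown in the diff; the helper names and the HFDataset/Dataset import aliases are assumptions, not part of the PR.

    import json
    import os
    from typing import Any, Dict, List, Union

    from datasets import Dataset as HFDataset  # assumed alias, as in base.py
    from torch.utils.data import Dataset       # assumed alias, as in base.py


    def convert_to_records(
        data: Union[HFDataset, Dataset, str, List[Dict[str, Any]]],
    ) -> List[Dict[str, Any]]:
        # Dispatch to one converter per supported input type; mirrors the
        # isinstance chain currently inlined in SeedDataset.__init__.
        if isinstance(data, HFDataset):
            return records_from_hf_dataset(data)
        if isinstance(data, Dataset):
            return records_from_pytorch_dataset(data)
        if isinstance(data, str):
            return records_from_json_path(data)
        if isinstance(data, list):
            return data
        raise TypeError(f"Unsupported data type: {type(data)}")


    def records_from_hf_dataset(data: HFDataset) -> List[Dict[str, Any]]:
        return [dict(item) for item in data]


    def records_from_pytorch_dataset(data: Dataset) -> List[Dict[str, Any]]:
        try:
            return [dict(data[i]) for i in range(len(data))]
        except (TypeError, KeyError, AttributeError) as e:
            raise TypeError(f"Unsupported PyTorch Dataset: {e}")


    def records_from_json_path(path: str) -> List[Dict[str, Any]]:
        if not os.path.exists(path):
            raise FileNotFoundError(f"JSON file not found: {path}")
        with open(path, 'r') as f:
            records = json.load(f)
        if not isinstance(records, list):
            raise ValueError("JSON file must contain a list of dictionaries")
        return records

With helpers like these, __init__ would reduce to assigning self._raw_data = convert_to_records(data) followed by the existing _setup(min_samples) call.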