Refactor Seed Dataset to improve compatibility and simplify usage #1734
base: master
Changes from all commits
@@ -12,14 +12,17 @@
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

+import json
+import os
import random
+from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
+    Sized,
    TypeVar,
    Union,
)

@@ -322,47 +325,219 @@ def to_pytorch_dataset(
        return dataset


-class SeedDataset(BaseDataset):
+class SeedDataset(Dataset):
    r"""A dataset containing validated seed examples for data generation.
    Ensures that all items adhere to the DataPoint schema.

-    This class is used to initialize a dataset from a list of dictionary items,
-    validating each against the DataPoint schema.
+    This class can initialize from Hugging Face Datasets,
+    PyTorch Datasets, JSON file paths, or lists of dictionaries,
+    converting them into a consistent internal format.
    """

    def __init__(
        self,
-        data: List[Dict[str, str]],
+        data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
        cache_dir: Optional[str] = None,
        seed: Optional[int] = None,
        min_samples: int = 1,
        strict: bool = False,
        **kwargs,
    ):
-        r"""Initialize the seed dataset.
+        r"""Initialize the seed dataset and validate integrity.

        Args:
-            data (List[Dict[str, str]]): List of dictionary items to create the
-                dataset from.
+            data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]):
+                Input data, which can be:
+                - A Hugging Face Dataset (HFDataset)
+                - A PyTorch Dataset (torch.utils.data.Dataset)
+                - A Path object representing the path to a JSON file
+                - A list of dictionaries with DataPoint-compatible fields
            cache_dir (Optional[str]): Directory to cache dataset files.
                (default: :obj:`None`)
            seed (Optional[int]): Seed for reproducibility.
                (default: :obj:`None`)
            min_samples (int): Minimum number of samples required.
                (default: :obj:`1`)
            strict (bool): Whether to raise an error on invalid datapoints
                (True) or skip/filter them (False). (default: :obj:`False`)
            **kwargs: Additional dataset parameters.

        Raises:
-            ValueError: If dataset size is less than min_samples or if sample
-                validation fails.
+            TypeError: If the data type is not supported.
+            ValueError: If dataset size is less than min_samples or
+                if sample validation fails.
+            FileNotFoundError: If the JSON file path doesn't exist.
+            json.JSONDecodeError: If the JSON file is invalid.
        """
-        if len(data) < min_samples:

        # Store all parameters in metadata dict for compatibility
        self._cache_dir = str(cache_dir) if cache_dir is not None else None
        self._metadata = {
            'cache_dir': self._cache_dir,
            **kwargs,
        }
        self._rng = random.Random(seed)
        self._strict = strict

        # Type checking and conversion into a list of dicts to keep a
        # consistent internal format. Since a seed dataset should be
        # small, we can load it entirely into memory.

        if isinstance(data, HFDataset):
            self._raw_data = [dict(item) for item in data]
        elif isinstance(data, Dataset):
            if not isinstance(data, Sized):
                raise TypeError(
                    f"{type(data).__name__} does not implement `__len__()`."
                )

            # Make MyPy happy by ensuring indexability
            assert callable(
                getattr(data, "__getitem__", None)
            ), "Dataset does not support indexing."

            self._raw_data = [dict(data[i]) for i in range(len(data))]
        elif isinstance(data, Path):

> Review comment: Let's add some more safety features here. The current configuration does not: […]
            if not data.exists():
                raise FileNotFoundError(f"JSON file not found: {data}")
            with data.open('r') as f:
                self._raw_data = json.load(f)
            if not isinstance(self._raw_data, list):
                raise ValueError(
                    "JSON file must contain a list of dictionaries"
                )
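
Following up on the review comment above (which is truncated in this view), a minimal sketch of checks this branch could grow; the byte cap, the explicit encoding, and the per-item type check are illustrative assumptions, not part of this PR:

```python
# Hedged sketch only: the cap value and the encoding are assumptions.
MAX_SEED_FILE_BYTES = 50 * 1024 * 1024  # hypothetical cap for a "small" seed set

if not data.exists():
    raise FileNotFoundError(f"JSON file not found: {data}")
if data.stat().st_size > MAX_SEED_FILE_BYTES:
    raise ValueError(f"JSON file exceeds {MAX_SEED_FILE_BYTES} bytes: {data}")
with data.open('r', encoding='utf-8') as f:
    loaded = json.load(f)  # may raise json.JSONDecodeError
if not isinstance(loaded, list):
    raise ValueError("JSON file must contain a list of dictionaries")
if not all(isinstance(item, dict) for item in loaded):
    raise ValueError("Every entry in the JSON list must be a dictionary")
self._raw_data = loaded
```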
        elif isinstance(data, list):
            self._raw_data = data if data is not None else []
        else:
            raise TypeError("Unsupported data type")

        self.data: List[DataPoint] = []
        self._setup(min_samples)
        self._length = len(self.data)

    def __len__(self) -> int:
        r"""Return the size of the dataset."""
        return self._length

    def __getitem__(self, idx: int) -> DataPoint:
        r"""Get an item from the dataset.

        Args:
            idx (int): Index of the item to get.

        Returns:
            DataPoint: DataPoint from the dataset with the given index.

        Raises:
            IndexError: If idx is out of bounds.
        """
        if idx < 0 or idx >= self._length:
            raise IndexError(
                f"Index {idx} out of bounds for dataset of size {self._length}"
            )
        return self.data[idx]

    def sample(self) -> DataPoint:
        r"""Sample a random datapoint from the dataset.

        Returns:
            DataPoint: A randomly sampled DataPoint.

        Raises:
            RuntimeError: If the dataset is empty.
        """
        if self._length == 0:
            raise RuntimeError("Dataset is empty, cannot sample.")
        idx = self._rng.randint(0, self._length - 1)
        return self[idx]
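
For orientation, a hedged usage sketch of the access patterns above; the import path, file name, and field values are illustrative assumptions, not taken from this PR:

```python
# Hedged usage sketch; the import path and data values are assumptions.
from pathlib import Path

from camel.datasets import SeedDataset  # hypothetical import path

seeds = [
    {
        "question": "What is 2 + 2?",
        "rationale": "Adding 2 and 2 gives 4.",
        "final_answer": "4",
    },
]

ds = SeedDataset(seeds, seed=42, min_samples=1)
print(len(ds))      # 1
print(ds[0])        # indexed access; raises IndexError if out of bounds
print(ds.sample())  # reproducible random draw via the seeded RNG

ds_from_file = SeedDataset(Path("seed_data.json"), strict=True)
```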

    def _setup(self, min_samples: int) -> None:
        r"""Set up the dataset by validating and processing raw data.

        This method:
        1. Checks if the dataset meets the minimum sample requirement.
        2. Creates the cache directory if specified.
        3. Processes raw data into DataPoint objects
           for validation and consistency.

        In non-strict mode, invalid datapoints are filtered out
        rather than raising an error.

        Args:
            min_samples (int): Minimum number of samples required.

        Raises:
            ValueError: If the dataset size is less than min_samples, if
                sample validation fails in strict mode, or if the dataset
                falls below min_samples after invalid datapoints are
                filtered out in non-strict mode.
            OSError: If cache directory creation fails.
        """
        if len(self._raw_data) < min_samples:
            raise ValueError(
-                f"Seed dataset must contain at least {min_samples} samples."
+                f"Dataset must have at least {min_samples} samples, "
+                f"got {len(self._raw_data)}"
            )

-        super().__init__(
-            data=data,
-            cache_dir=cache_dir,
-            **kwargs,
-        )
        if self._cache_dir:
            try:
                os.makedirs(self._cache_dir, exist_ok=True)
                logger.debug(f"Created cache directory: {self._cache_dir}")
            except OSError as e:
                logger.error(
                    f"Failed to create cache directory {self._cache_dir}: {e}"
                )
                raise

        # Process raw data into DataPoint objects for validation purposes
        if not self._raw_data:
            if min_samples > 0:
                raise ValueError("No data provided, but min_samples > 0")
            logger.debug("No raw data to process")
            return

> Review comment: This method is supposed to validate whether a dict is a valid `DataPoint`. Note that it should not fail if there are extra fields.

        def create_datapoint(
            item: Dict[str, Any], idx: int
        ) -> Optional[DataPoint]:
            try:
                return DataPoint(
                    question=item.get('question', ''),
                    rationale=item.get('rationale', ''),
                    final_answer=item.get('final_answer', ''),
                    metadata=item.get('metadata', {})
                    if isinstance(item.get('metadata'), dict)
                    else {},
                    difficulty=item.get('difficulty', ''),  # Match BaseDataset
                    # raw_markdown='' if DataPoint supports it
                )

            except ValidationError as e:
                if self._strict:
                    raise ValueError(
                        f"Sample at index {idx} validation error: {e}"
                    )
                else:
                    logger.warning(
                        f"Skipping invalid sample at index {idx} "
                        f"due to validation error: {e}"
                    )
                    return None
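
Relating to the review comment above, a hedged sketch of how this nested helper could validate against the schema while tolerating extra fields; it assumes `DataPoint` is a Pydantic v2 model, which this diff does not confirm:

```python
# Hedged sketch; assumes DataPoint is a Pydantic v2 model.
def create_datapoint(
    item: Dict[str, Any], idx: int
) -> Optional[DataPoint]:
    # Drop unknown keys so extra fields never fail validation,
    # even if the model were configured with extra='forbid'.
    known = {k: v for k, v in item.items() if k in DataPoint.model_fields}
    try:
        # model_validate() applies the full DataPoint schema instead of
        # silently defaulting missing fields to ''.
        return DataPoint.model_validate(known)
    except ValidationError as e:
        if self._strict:
            raise ValueError(
                f"Sample at index {idx} validation error: {e}"
            ) from e
        logger.warning(f"Skipping invalid sample at index {idx}: {e}")
        return None
```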

        raw_data = [
            create_datapoint(item, i) for i, item in enumerate(self._raw_data)
        ]
        self.data = [dp for dp in raw_data if dp is not None]
        logger.debug(
            f"Processed {len(raw_data)} data points, of which "
            f"{len(self.data)} were valid."
        )

    @property
    def metadata(self) -> Dict[str, Any]:
        r"""Get dataset metadata."""
        return self._metadata.copy()


class SyntheticDataset(BaseDataset):
    r"""A dataset for storing synthetically generated data points.

> Review comment: This `__init__` is getting really messy. It would be better to create specific methods to handle the conversion.