Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select aggregation functions based on column types #28

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/pandas_query_generator/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class Arguments:
disable_multi_processing: bool
filter: QueryFilter
groupby_aggregation_probability: float
max_aggregation_columns: int
max_groupby_columns: int
max_merges: int
max_projection_columns: int
Expand Down Expand Up @@ -127,6 +128,14 @@ def from_args() -> 'Arguments':
help='Probability of including groupby aggregation operations',
)

parser.add_argument(
'--max-aggregation-columns',
type=int,
required=False,
default=3,
help='Maximum number of columns to aggregate in GROUP BY operations',
)

parser.add_argument(
'--max-groupby-columns',
type=int,
Expand Down
69 changes: 64 additions & 5 deletions src/pandas_query_generator/group_by_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,75 @@
import typing as t
from dataclasses import dataclass
from enum import Enum

from .entity import Property, PropertyDate, PropertyEnum, PropertyFloat, PropertyInt, PropertyString
from .operation import Operation


class AggregationType(str, Enum):
"""Defines supported aggregation types and their compatibility."""

MEAN = 'mean'
SUM = 'sum'
MIN = 'min'
MAX = 'max'
COUNT = 'count'
MODE = 'mode'
NUNIQUE = 'nunique'
FIRST = 'first'
LAST = 'last'

@staticmethod
def compatible_aggregations(property: Property) -> t.List[str]:
"""Get compatible aggregation types for a given property type."""
match property:
case PropertyInt() | PropertyFloat():
return [
AggregationType.MEAN.value,
AggregationType.SUM.value,
AggregationType.MIN.value,
AggregationType.MAX.value,
AggregationType.COUNT.value,
AggregationType.NUNIQUE.value,
AggregationType.FIRST.value,
AggregationType.LAST.value,
]
case PropertyString() | PropertyEnum():
return [
AggregationType.COUNT.value,
AggregationType.MODE.value,
AggregationType.NUNIQUE.value,
AggregationType.FIRST.value,
AggregationType.LAST.value,
]
case PropertyDate():
return [
AggregationType.MIN.value,
AggregationType.MAX.value,
AggregationType.COUNT.value,
AggregationType.NUNIQUE.value,
AggregationType.FIRST.value,
AggregationType.LAST.value,
]


@dataclass
class GroupByAggregation(Operation):
    """
    Represents a group by aggregation operation in a query.

    Attributes:
        group_by_columns: List of columns to group by
        aggregation_columns: Dictionary mapping column names to their aggregation functions
    """

    group_by_columns: t.List[str]
    aggregation_columns: t.Dict[str, str]

    def apply(self, entity: str) -> str:
        """
        Generate the pandas groupby operation string.

        Args:
            entity: Name of the entity the operation is applied to; it is not
                used when building the string.

        Returns:
            A ``.groupby(by=[...]).agg({...})`` suffix in which each
            aggregation column maps to its own aggregation function.
        """
        # repr() quotes every column the same way and stays valid even when a
        # column name itself contains a quote character.
        group_by_columns = ', '.join(f'{col!r}' for col in self.group_by_columns)

        aggregations = ', '.join(
            f'{col!r}: {func!r}' for col, func in self.aggregation_columns.items()
        )

        # Doubled braces emit literal '{' and '}'. This avoids nesting
        # same-type quotes inside the f-string, which only parses on 3.12+.
        return f'.groupby(by=[{group_by_columns}]).agg({{{aggregations}}})'
5 changes: 2 additions & 3 deletions src/pandas_query_generator/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ class Query:
entity (str): The name of the target entity.
operations (List[Operation]): List of operations to apply.
multi_line (bool): Whether to format output across multiple lines.
available_columns (Set[str]): Columns available for operations.
complexity (int): Measure of query complexity based on merge operations.
columns (Set[str]): Columns available for operations.
"""

def __init__(
Expand Down Expand Up @@ -73,7 +72,7 @@ def complexity(self) -> int:
5. GroupBy complexity: Number of grouping columns plus weight of aggregation

Returns:
int: Complexity score for the query
int: Complexity score for the query
"""

def get_merge_complexity(op: Operation) -> int:
Expand Down
67 changes: 51 additions & 16 deletions src/pandas_query_generator/query_builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import random
import typing as t

from .entity import Entity, PropertyDate, PropertyEnum, PropertyFloat, PropertyInt, PropertyString
from .group_by_aggregation import GroupByAggregation
from .entity import (
Entity,
Property,
PropertyDate,
PropertyEnum,
PropertyFloat,
PropertyInt,
PropertyString,
)
from .group_by_aggregation import AggregationType, GroupByAggregation
from .merge import Merge
from .operation import Operation
from .projection import Projection
Expand Down Expand Up @@ -212,6 +220,7 @@ def _generate_merge(self, num_merges: int) -> Operation:

right_query_structure = QueryStructure(
groupby_aggregation_probability=0,
max_aggregation_columns=0,
max_groupby_columns=0,
max_merges=self.max_merges - num_merges,
max_projection_columns=self.query_structure.max_projection_columns,
Expand Down Expand Up @@ -246,25 +255,51 @@ def format_join_columns(columns: str | t.List[str]) -> str:

def _generate_group_by_aggregation(self) -> Operation:
    """
    Generate a type-aware GROUP BY clause for aggregating data.

    This method:
    1. Randomly selects columns to group by
    2. From remaining columns, selects columns to aggregate
    3. For each aggregation column, chooses an appropriate aggregation
       function based on the column's data type

    When every available column is used for grouping, the operation falls
    back to counting the first group-by column. Configured maximums below 1
    are treated as 1 so that at least one column is always selected.

    Returns:
        Operation: A GroupByAggregation operation with type-appropriate aggregations
    """

    def property_for_column(column: str) -> t.Optional[Property]:
        """Return the property for `column` from the first merged entity defining it."""
        for entity_name in self.merge_entities:
            entity = self.schema[entity_name]
            if column in entity.properties:
                return entity.properties[column]
        return None

    # Clamp the configured limit to at least 1: random.randint(1, 0) raises
    # ValueError, which a zero/negative setting would otherwise trigger.
    max_group_by_columns = max(1, self.query_structure.max_groupby_columns)
    num_group_by_columns = random.randint(1, min(max_group_by_columns, len(self.columns)))

    group_by_columns = random.sample(list(self.columns), num_group_by_columns)
    aggregation_candidates = list(self.columns - set(group_by_columns))

    if not aggregation_candidates:
        # Nothing left to aggregate, so count one of the grouping columns.
        return GroupByAggregation(
            group_by_columns=group_by_columns,
            aggregation_columns={group_by_columns[0]: 'count'},
        )

    max_aggregation_columns = max(1, self.query_structure.max_aggregation_columns)
    num_aggregation_columns = random.randint(
        1, min(max_aggregation_columns, len(aggregation_candidates))
    )

    aggregations: t.Dict[str, str] = {}

    for column in random.sample(aggregation_candidates, num_aggregation_columns):
        column_property = property_for_column(column)
        assert column_property is not None, f'no property found for column {column!r}'
        compatible_aggregations = AggregationType.compatible_aggregations(column_property)
        aggregations[column] = random.choice(compatible_aggregations)

    return GroupByAggregation(group_by_columns=group_by_columns, aggregation_columns=aggregations)
2 changes: 2 additions & 0 deletions src/pandas_query_generator/query_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class QueryStructure:
"""

groupby_aggregation_probability: float
max_aggregation_columns: int
max_groupby_columns: int
max_merges: int
max_projection_columns: int
Expand All @@ -31,6 +32,7 @@ def from_args(arguments: Arguments) -> 'QueryStructure':
"""
return QueryStructure(
groupby_aggregation_probability=arguments.groupby_aggregation_probability,
max_aggregation_columns=arguments.max_aggregation_columns,
max_groupby_columns=arguments.max_groupby_columns,
max_merges=arguments.max_merges,
max_projection_columns=arguments.max_projection_columns,
Expand Down
Loading