Skip to content

Commit

Permalink
_dt_in_by and _add_dt_to_df
Browse files Browse the repository at this point in the history
  • Loading branch information
jsmariegaard committed Dec 19, 2023
1 parent 3e83518 commit f4b26d0
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions modelskill/comparison/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def calc_metrics(x):
return pd.Series(row)

# .drop(columns=["x", "y"])
if _dt_in_by(by):
df, by = _add_dt_to_df(df, by)

res = df.groupby(by=by, observed=False).apply(calc_metrics)

Expand All @@ -84,6 +86,52 @@ def calc_metrics(x):
return res


def _dt_in_by(by):
by = [by] if isinstance(by, str) else by
if any(str(by).startswith("dt:") for by in by):
return True
return False


ALLOWED_DT = [
"year",
"quarter",
"month",
"month_name",
"day",
"day_of_year",
"dayofyear",
"day_of_week",
"dayofweek",
"hour",
"minute",
"second",
"weekday",
]


def _add_dt_to_df(df, by):
ser = df.index.to_series()
by = [by] if isinstance(by, str) else by

for j, b in enumerate(by):
if str(b).startswith("dt:"):
dt_str = b.split(":")[1].lower()
if dt_str not in ALLOWED_DT:
raise ValueError(
f"Invalid Pandas dt accessor: {dt_str}. Allowed values are: {ALLOWED_DT}"
)
ser = ser.dt.__getattribute__(dt_str)
if dt_str in df.columns:
raise ValueError(
f"Cannot use datetime attribute {dt_str} as it already exists in the dataframe."
)
df[dt_str] = ser
by[j] = dt_str # remove 'dt:' prefix
by = by[0] if len(by) == 1 else by
return df, by


def _parse_groupby(by, n_models, n_obs, n_var=1):
if by is None:
by = []
Expand Down

0 comments on commit f4b26d0

Please sign in to comment.