Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(row-counts-for-saved-queries): Temporal row counts for saved queries #29164

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
33 changes: 32 additions & 1 deletion posthog/temporal/data_modeling/run_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from posthog.settings.base_variables import TEST
from posthog.temporal.common.base import PostHogWorkflow
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.warehouse.models import DataWarehouseModelPath, DataWarehouseSavedQuery
from posthog.warehouse.models import DataWarehouseModelPath, DataWarehouseSavedQuery, DataWarehouseTable
from posthog.warehouse.util import database_sync_to_async
from posthog.warehouse.data_load.create_table import create_table_from_saved_query
from posthog.temporal.data_imports.util import prepare_s3_files_for_querying
Expand Down Expand Up @@ -337,9 +337,40 @@ async def materialize_model(model_label: str, team: Team) -> tuple[str, DeltaTab
prepare_s3_files_for_querying(saved_query.folder_path, saved_query.name, file_uris)

key, delta_table = tables.popitem()

row_count = await asyncio.to_thread(count_delta_table_rows, delta_table)
await update_table_row_count(saved_query, row_count)

return (key, delta_table)


def count_delta_table_rows(delta_table: DeltaTable) -> int:
"""
Count the number of rows in a Delta table using metadata.
"""
count = 0
for batch in delta_table.to_pyarrow_dataset().to_batches():
count += len(batch)
return count


async def update_table_row_count(saved_query: DataWarehouseSavedQuery, row_count: int) -> None:
"""Update the row count in the DataWarehouseTable record. `saved_query` name is unique per team."""
try:
table = await database_sync_to_async(
DataWarehouseTable.objects.filter(team_id=saved_query.team_id, name=saved_query.name).first
)()

if table:
table.row_count = row_count
await database_sync_to_async(table.save)()
await logger.ainfo("Updated row count for table %s to %d", saved_query.name, row_count)
else:
await logger.awarning("Could not find DataWarehouseTable record for saved query %s", saved_query.name)
except Exception as e:
await logger.aexception("Failed to update row count for table %s: %s", saved_query.name, str(e))


@dlt.source(max_table_nesting=0)
def hogql_table(query: str, team: Team, table_name: str, table_columns: dlt_typing.TTableSchemaColumns):
"""A dlt source representing a HogQL table given by a HogQL query."""
Expand Down
17 changes: 17 additions & 0 deletions posthog/temporal/tests/data_modeling/test_run_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,15 @@ async def test_run_workflow_with_minio_bucket(
expected_events_a = [event for event in all_expected_events if event["distinct_id"] == "a"]
expected_events_b = [event for event in all_expected_events if event["distinct_id"] == "b"]

for query in saved_queries:
await DataWarehouseTable.objects.acreate(
name=query.name,
team=ateam,
format="Delta",
url_pattern=f"s3://{bucket_name}/team_{ateam.pk}_model_{query.id.hex}",
credential=None,
)

workflow_id = str(uuid.uuid4())
inputs = RunWorkflowInputs(team_id=ateam.pk)

Expand Down Expand Up @@ -678,3 +687,11 @@ async def test_run_workflow_with_minio_bucket(
assert sorted(table.to_pylist(), key=lambda d: (d["distinct_id"], d["timestamp"])) == expected_data
assert query.status == DataWarehouseSavedQuery.Status.COMPLETED
assert query.last_run_at == TEST_TIME

# Verify row count was updated in the DataWarehouseTable
warehouse_table = await database_sync_to_async(
DataWarehouseTable.objects.filter(team_id=ateam.pk, name=query.name).first
)()
assert warehouse_table is not None, f"DataWarehouseTable for {query.name} not found"
# Match the 50 page_view events defined above
assert warehouse_table.row_count == len(expected_data), f"Row count for {query.name} not the expected value"
Loading