Begin supporting microbatch models in sample mode
QMalcolm committed Jan 28, 2025
1 parent d97bc4a commit 96e3b7a
Showing 3 changed files with 142 additions and 11 deletions.
42 changes: 36 additions & 6 deletions core/dbt/context/providers.py
@@ -237,23 +237,53 @@ def resolve_limit(self) -> Optional[int]:

def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeFilter]:
event_time_filter = None
sample_mode = bool(self.config.args.sample and self.config.args.sample_window)

# TODO The number of branches here is getting rough. We should consider ways to simplify
# what is going on to make it easier to maintain

# Only do event time filtering if the base node has the necessary event time configs
if (
(isinstance(target.config, NodeConfig) or isinstance(target.config, SourceConfig))
and target.config.event_time
and isinstance(self.model, ModelNode)
):

# Handling of microbatch models
if (
self.model.config.materialized == "incremental"
and self.model.config.incremental_strategy == "microbatch"
and self.manifest.use_microbatch_batches(project_name=self.config.project_name)
and self.model.batch is not None
):
-                event_time_filter = EventTimeFilter(
-                    field_name=target.config.event_time,
-                    start=self.model.batch.event_time_start,
-                    end=self.model.batch.event_time_end,
-                )
-            elif self.config.args.sample and self.config.args.sample_window:
                # Sample mode microbatch models
                if sample_mode:
start = (
self.config.args.sample_window.start
if self.config.args.sample_window.start > self.model.batch.event_time_start
else self.model.batch.event_time_start
)
end = (
self.config.args.sample_window.end
if self.config.args.sample_window.end < self.model.batch.event_time_end
else self.model.batch.event_time_end
)
event_time_filter = EventTimeFilter(
field_name=target.config.event_time,
start=start,
end=end,
)

# Regular microbatch models
else:
event_time_filter = EventTimeFilter(
field_name=target.config.event_time,
start=self.model.batch.event_time_start,
end=self.model.batch.event_time_end,
)

# Sample mode _non_ microbatch models
elif sample_mode:
event_time_filter = EventTimeFilter(
field_name=target.config.event_time,
start=self.config.args.sample_window.start,
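The net effect for microbatch models in sample mode: each batch's event-time bounds are clamped to the sample window, so a batch that only partially overlaps the window gets a partial filter. A minimal standalone sketch of that clamping rule (plain datetimes and a hypothetical helper name, not dbt's internal API):

from datetime import datetime, timezone

def clamp_batch_to_sample_window(
    batch_start: datetime,
    batch_end: datetime,
    sample_start: datetime,
    sample_end: datetime,
) -> tuple[datetime, datetime]:
    # Take the later of the two starts and the earlier of the two ends,
    # mirroring the conditional expressions in resolve_event_time_filter above.
    return max(batch_start, sample_start), min(batch_end, sample_end)

# A daily batch that begins before the sample window starts gets a partial range.
start, end = clamp_batch_to_sample_window(
    batch_start=datetime(2025, 1, 1, tzinfo=timezone.utc),
    batch_end=datetime(2025, 1, 2, tzinfo=timezone.utc),
    sample_start=datetime(2025, 1, 1, 2, 3, tzinfo=timezone.utc),
    sample_end=datetime(2025, 1, 3, 2, 3, tzinfo=timezone.utc),
)
print(start, end)  # 2025-01-01 02:03:00+00:00 2025-01-02 00:00:00+00:00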
13 changes: 11 additions & 2 deletions core/dbt/task/run.py
@@ -556,11 +556,20 @@ def _execute_microbatch_materialization(
context: Dict[str, Any],
materialization_macro: MacroProtocol,
) -> RunResult:
# TODO: This method has gotten a little large. It may be time to break it up into more manageable parts.
event_time_start = getattr(self.config.args, "EVENT_TIME_START", None)
event_time_end = getattr(self.config.args, "EVENT_TIME_END", None)
if getattr(self.config.args, "SAMPLE", None) and getattr(
self.config.args, "SAMPLE_WINDOW", None
):
event_time_start = self.config.args.sample_window.start
event_time_end = self.config.args.sample_window.end

microbatch_builder = MicrobatchBuilder(
model=model,
is_incremental=self._is_incremental(model),
-            event_time_start=getattr(self.config.args, "EVENT_TIME_START", None),
-            event_time_end=getattr(self.config.args, "EVENT_TIME_END", None),
            event_time_start=event_time_start,
            event_time_end=event_time_end,
default_end_time=self.config.invoked_at,
)

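When --sample and --sample-window are both set, the builder now receives the sample window as its event-time bounds, so only batches that overlap the window are generated. A rough sketch of how daily batch boundaries fall out of such a start/end pair (a hypothetical helper, not MicrobatchBuilder itself, which also honors begin, lookback, and the configured batch size):

from datetime import datetime, timedelta, timezone

def daily_batches(event_time_start: datetime, event_time_end: datetime):
    # Truncate the start to midnight, then step one day at a time until
    # the batch start passes the end bound.
    cursor = event_time_start.replace(hour=0, minute=0, second=0, microsecond=0)
    while cursor < event_time_end:
        yield cursor, cursor + timedelta(days=1)
        cursor += timedelta(days=1)

window_start = datetime(2025, 1, 1, 2, 3, tzinfo=timezone.utc)
window_end = datetime(2025, 1, 3, 2, 3, tzinfo=timezone.utc)
for batch_start, batch_end in daily_batches(window_start, window_end):
    print(batch_start.date(), "->", batch_end.date())
# 2025-01-01 -> 2025-01-02
# 2025-01-02 -> 2025-01-03
# 2025-01-03 -> 2025-01-04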
98 changes: 95 additions & 3 deletions tests/functional/sample_mode/test_sample_mode.py
@@ -4,14 +4,16 @@
import pytest
import pytz

from dbt.artifacts.resources.types import BatchSize
from dbt.event_time.sample_window import SampleWindow
from dbt.events.types import JinjaLogInfo
-from dbt.tests.util import relation_from_name, run_dbt
from dbt.materializations.incremental.microbatch import MicrobatchBuilder
from dbt.tests.util import read_file, relation_from_name, run_dbt
from tests.utils import EventCatcher

input_model_sql = """
{{ config(materialized='table', event_time='event_time') }}
-select 1 as id, TIMESTAMP '2020-01-01 11:25:00-0' as event_time
select 1 as id, TIMESTAMP '2020-01-01 01:25:00-0' as event_time
UNION ALL
select 2 as id, TIMESTAMP '2025-01-01 13:47:00-0' as event_time
UNION ALL
@@ -29,9 +29,20 @@
SELECT * FROM {{ ref("input_model") }}
"""

sample_microbatch_model_sql = """
{{ config(materialized='incremental', incremental_strategy='microbatch', event_time='event_time', batch_size='day', lookback=3, begin='2024-12-25', unique_key='id')}}
{% if execute %}
{{ log("batch.event_time_start: "~ model.batch.event_time_start, info=True)}}
{{ log("batch.event_time_end: "~ model.batch.event_time_end, info=True)}}
{% endif %}
SELECT * FROM {{ ref("input_model") }}
"""


-class TestSampleMode:
class BaseSampleMode:
# TODO This is now used in 3 test files, it might be worth turning into a full test utility method
def assert_row_count(self, project, relation_name: str, expected_row_count: int):
relation = relation_from_name(project.adapter, relation_name)
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
@@ -42,6 +42,8 @@ def assert_row_count(self, project, relation_name: str, expected_row_count: int)

assert result[0] == expected_row_count


class TestBasicSampleMode(BaseSampleMode):
@pytest.fixture(scope="class")
def models(self):
return {
@@ -87,3 +87,102 @@ def test_sample_mode(
relation_name="sample_mode_model",
expected_row_count=expected_row_count,
)


class TestMicrobatchSampleMode(BaseSampleMode):
@pytest.fixture(scope="class")
def models(self):
return {
"input_model.sql": input_model_sql,
"sample_microbatch_model.sql": sample_microbatch_model_sql,
}

@pytest.fixture
def event_time_start_catcher(self) -> EventCatcher:
return EventCatcher(event_to_catch=JinjaLogInfo, predicate=lambda event: "batch.event_time_start" in event.info.msg) # type: ignore

@pytest.fixture
def event_time_end_catcher(self) -> EventCatcher:
return EventCatcher(event_to_catch=JinjaLogInfo, predicate=lambda event: "batch.event_time_end" in event.info.msg) # type: ignore

@freezegun.freeze_time("2025-01-03T02:03:0Z")
def test_sample_mode(
self,
project,
event_time_end_catcher: EventCatcher,
event_time_start_catcher: EventCatcher,
):
expected_batches = [
("2025-01-01 00:00:00", "2025-01-02 00:00:00"),
("2025-01-02 00:00:00", "2025-01-03 00:00:00"),
("2025-01-03 00:00:00", "2025-01-04 00:00:00"),
]

# These are different from the expected batches because the sample window might only operate on "part" of a given batch
expected_filters = [
(
"event_time >= '2025-01-01 02:03:00+00:00' and event_time < '2025-01-02 00:00:00+00:00'"
),
(
"event_time >= '2025-01-02 00:00:00+00:00' and event_time < '2025-01-03 00:00:00+00:00'"
),
(
"event_time >= '2025-01-03 00:00:00+00:00' and event_time < '2025-01-03 02:03:00+00:00'"
),
]

_ = run_dbt(
["run", "--sample", "--sample-window=2 day"],
callbacks=[event_time_end_catcher.catch, event_time_start_catcher.catch],
)
assert len(event_time_start_catcher.caught_events) == len(expected_batches)
assert len(event_time_end_catcher.caught_events) == len(expected_batches)

for index in range(len(expected_batches)):
assert expected_batches[index][0] in event_time_start_catcher.caught_events[index].info.msg # type: ignore
assert expected_batches[index][1] in event_time_end_catcher.caught_events[index].info.msg # type: ignore

batch_id = MicrobatchBuilder.format_batch_start(
datetime.fromisoformat(expected_batches[index][0]), BatchSize.day
)
batch_file_name = f"sample_microbatch_model_{batch_id}.sql"
compiled_sql = read_file(
project.project_root,
"target",
"compiled",
"test",
"models",
"sample_microbatch_model",
batch_file_name,
)
assert expected_filters[index] in compiled_sql

# The first row of the "input_model" should be excluded from the sample because
# it falls outside of the filter for the first batch (which is only doing a _partial_ batch selection)
self.assert_row_count(
project=project,
relation_name="sample_microbatch_model",
expected_row_count=2,
)
