Skip to content

Commit

Permalink
test: Test file_format attribute alignment in dc.sd
Browse files Browse the repository at this point in the history
Signed-off-by: JiaWei Jiang <[email protected]>
  • Loading branch information
JiangJiaWei1103 committed Jan 4, 2025
1 parent 4af1f00 commit 46b7b62
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 1 deletion.
47 changes: 46 additions & 1 deletion tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from urllib.parse import urlparse
import uuid
import pytest
import mock
from unittest import mock
from dataclasses import dataclass

from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase
from flytekit.configuration import Config, ImageConfig, SerializationSettings
Expand All @@ -26,6 +27,7 @@
from flytekit.remote.remote import FlyteRemote
from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset
from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient
from flytekit.configuration import PlatformConfig

Expand Down Expand Up @@ -833,3 +835,46 @@ def test_open_ff():
url = urlparse(remote_file_path)
bucket, key = url.netloc, url.path.lstrip("/")
file_transfer.delete_file(bucket=bucket, key=key)


def test_sd_attr():
    """Test correctness of StructuredDataset attributes.

    This test considers only the following condition:
    1. Check StructuredDataset (wrapped in a dataclass) file_format attribute

    We'll make sure uri aligns with the user-specified one in the future.

    The uploaded fixture file is removed in a ``finally`` clause so that a
    failed execution or a failed assertion does not leave it behind in the
    bucket.
    """
    from workflows.basic.sd_attr import wf

    @dataclass
    class DC:
        sd: StructuredDataset

    FILE_FORMAT = "parquet"

    # Upload a file to minio s3 bucket
    file_transfer = SimpleFileTransfer()
    remote_file_path = file_transfer.upload_file(file_type=FILE_FORMAT)

    try:
        # Create a dataclass as the workflow input because `pyflyte run`
        # can't properly handle input arg `dc` as a json str so far
        dc = DC(sd=StructuredDataset(uri=remote_file_path, file_format=FILE_FORMAT))

        remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True)
        wf_exec = remote.execute(
            wf,
            inputs={"dc": dc, "file_format": FILE_FORMAT},
            wait=True,
            version=VERSION,
            image_config=ImageConfig.from_images(IMAGE),
        )
        assert wf_exec.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {wf_exec.closure.phase}"
        assert wf_exec.outputs["o0"].file_format == FILE_FORMAT, (
            f"Workflow output StructuredDataset file_format should align with the user-specified file_format: {FILE_FORMAT}."
        )
    finally:
        # Delete the remote file to free the space, even when the execution
        # or one of the assertions above failed.
        url = urlparse(remote_file_path)
        bucket, key = url.netloc, url.path.lstrip("/")
        file_transfer.delete_file(bucket=bucket, key=key)
68 changes: 68 additions & 0 deletions tests/flytekit/integration/remote/workflows/basic/sd_attr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from dataclasses import dataclass

import pandas as pd
from flytekit import task, workflow
from flytekit.types.structured import StructuredDataset


@dataclass
class DC:
    """Dataclass wrapper holding a single StructuredDataset field."""

    sd: StructuredDataset


@task
def create_dc(uri: str, file_format: str) -> DC:
    """Wrap a StructuredDataset built from ``uri`` and ``file_format`` in a DC.

    Args:
        uri: File URI.
        file_format: File format, e.g., parquet, csv.

    Returns:
        A DC instance whose ``sd`` field is the newly built StructuredDataset.
    """
    dataset = StructuredDataset(uri=uri, file_format=file_format)
    return DC(sd=dataset)


@task
def check_file_format(sd: StructuredDataset, true_file_format: str) -> StructuredDataset:
    """Verify that a StructuredDataset carries the expected file format.

    Both the Python-level ``file_format`` attribute and the format recorded
    in the literal's StructuredDatasetType must equal ``true_file_format``.
    After the checks pass, the dataset, its literal, its literal type and
    its contents (opened as a pandas DataFrame) are printed.

    Args:
        sd: Python native StructuredDataset.
        true_file_format: User-specified file_format.

    Returns:
        The same StructuredDataset that was passed in.
    """
    # First check the attribute exposed on the Python object itself.
    assert sd.file_format == true_file_format, (
        f"StructuredDataset file_format should align with the user-specified file_format: {true_file_format}."
    )

    # Then check the format stored on the underlying literal's type metadata.
    literal_format = sd._literal_sd.metadata.structured_dataset_type.format
    assert literal_format == true_file_format, (
        f"StructuredDatasetType format should align with the user-specified file_format: {true_file_format}."
    )

    # Dump everything to stdout to help debugging.
    print(f">>> SD <<<\n{sd}")
    print(f">>> Literal SD <<<\n{sd._literal_sd}")
    print(f">>> SDT <<<\n{sd._literal_sd.metadata.structured_dataset_type}")
    print(f">>> DF <<<\n{sd.open(pd.DataFrame).all()}")

    return sd


@workflow
def wf(dc: DC, file_format: str) -> StructuredDataset:
    """Run ``check_file_format`` on the StructuredDataset wrapped in ``dc``.

    The expected format is taken from the separate ``file_format`` input
    because using ``dc.sd.file_format`` as the task input fails.

    Args:
        dc: Dataclass wrapping the StructuredDataset to check.
        file_format: User-specified file_format the dataset should carry.

    Returns:
        The StructuredDataset returned by ``check_file_format``.
    """
    return check_file_format(sd=dc.sd, true_file_format=file_format)


if __name__ == "__main__":
    # When executed as a script: build the dataclass input from a parquet
    # file and push it through the workflow, then print the resulting
    # StructuredDataset's file_format.
    # NOTE(review): the path is relative — presumably to the repository root; confirm.
    parquet_format = "parquet"
    parquet_uri = "tests/flytekit/integration/remote/workflows/basic/data/df.parquet"

    wrapped = create_dc(uri=parquet_uri, file_format=parquet_format)
    checked = wf(dc=wrapped, file_format=parquet_format)
    print(checked.file_format)

0 comments on commit 46b7b62

Please sign in to comment.