-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: Test file_format attribute alignment in dc.sd
Signed-off-by: JiaWei Jiang <[email protected]>
- Loading branch information
1 parent
4af1f00
commit 46b7b62
Showing
2 changed files
with
114 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
tests/flytekit/integration/remote/workflows/basic/sd_attr.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from dataclasses import dataclass | ||
|
||
import pandas as pd | ||
from flytekit import task, workflow | ||
from flytekit.types.structured import StructuredDataset | ||
|
||
|
||
@dataclass
class DC:
    """Dataclass wrapper holding a single StructuredDataset attribute."""

    # The wrapped dataset; the workflow below reads it as ``dc.sd``.
    sd: StructuredDataset
|
||
|
||
@task
def create_dc(uri: str, file_format: str) -> DC:
    """Build a ``DC`` whose ``sd`` field is a StructuredDataset.

    Args:
        uri: File URI.
        file_format: File format, e.g., parquet, csv.

    Returns:
        A ``DC`` instance wrapping ``StructuredDataset(uri=uri, file_format=file_format)``.
    """
    return DC(sd=StructuredDataset(uri=uri, file_format=file_format))
|
||
|
||
@task
def check_file_format(sd: StructuredDataset, true_file_format: str) -> StructuredDataset:
    """Check the StructuredDataset ``file_format`` attribute.

    The StructuredDataset ``file_format``, and the ``format`` recorded on the
    StructuredDatasetType of its literal, should both align with what the
    user specified.

    Args:
        sd: Python native StructuredDataset.
        true_file_format: User-specified file_format.

    Returns:
        The input ``sd``, passed through after the checks.

    Raises:
        AssertionError: If either format differs from ``true_file_format``.
    """
    assert sd.file_format == true_file_format, (
        f"StructuredDataset file_format should align with the user-specified file_format: {true_file_format}."
    )

    # Bound only after the first check so the evaluation order matches a
    # straight-line read of the attribute chain.
    literal_sd = sd._literal_sd
    sd_type = literal_sd.metadata.structured_dataset_type
    assert sd_type.format == true_file_format, (
        f"StructuredDatasetType format should align with the user-specified file_format: {true_file_format}."
    )

    # Dump each layer for debugging: the dataset, its literal, the literal's
    # dataset type, and finally the data opened as a pandas DataFrame.
    print(f">>> SD <<<\n{sd}")
    print(f">>> Literal SD <<<\n{literal_sd}")
    print(f">>> SDT <<<\n{sd_type}")
    print(f">>> DF <<<\n{sd.open(pd.DataFrame).all()}")

    return sd
|
||
|
||
@workflow
def wf(dc: DC, file_format: str) -> StructuredDataset:
    """Run ``check_file_format`` on the StructuredDataset held by ``dc``.

    Args:
        dc: Dataclass whose ``sd`` attribute is the dataset to check.
        file_format: User-specified file format that ``dc.sd`` is expected to carry.

    Returns:
        The StructuredDataset returned by ``check_file_format``.
    """
    # The expected format is passed in separately: per the original author's
    # note, using dc.sd.file_format as the task input fails.
    return check_file_format(sd=dc.sd, true_file_format=file_format)
|
||
|
||
if __name__ == "__main__":
    # Inputs for a direct run: a parquet file and its declared format.
    file_format = "parquet"
    uri = "tests/flytekit/integration/remote/workflows/basic/data/df.parquet"

    # Build the dataclass, run the workflow on it, and print the resulting format.
    checked_sd = wf(dc=create_dc(uri=uri, file_format=file_format), file_format=file_format)
    print(checked_sd.file_format)