diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index b66c48443c..4ba06c2008 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -46,6 +46,7 @@ def noop(): ... @dataclass class FlyteFile(SerializableType, os.PathLike, typing.Generic[T], DataClassJSONMixin): path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore + metadata: typing.Optional[dict[str, str]] = None """ Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int exists for Flyte's Integer type) we need to create one so that users can express that their tasks take @@ -158,18 +159,24 @@ def t2() -> flytekit_typing.FlyteFile["csv"]: return "/tmp/local_file.csv" """ - def _serialize(self) -> typing.Dict[str, str]: + def _serialize(self) -> typing.Dict[str, typing.Any]: lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) - return {"path": lv.scalar.blob.uri} + out = {"path": lv.scalar.blob.uri} + if lv.metadata: + out["metadata"] = lv.metadata + return out @classmethod def _deserialize(cls, value) -> "FlyteFile": return FlyteFilePathTransformer().dict_to_flyte_file(dict_obj=value, expected_python_type=cls) @model_serializer - def serialize_flyte_file(self) -> Dict[str, str]: + def serialize_flyte_file(self) -> Dict[str, typing.Any]: lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) - return {"path": lv.scalar.blob.uri} + out = {"path": lv.scalar.blob.uri} + if lv.metadata: + out["metadata"] = lv.metadata + return out @model_validator(mode="after") def deserialize_flyte_file(self, info) -> "FlyteFile": @@ -188,7 +195,8 @@ def deserialize_flyte_file(self, info) -> "FlyteFile": ), uri=self.path, ) - ) + ), + metadata=self.metadata, ), type(self), ) @@ -281,6 +289,7 @@ def __init__( path: typing.Union[str, os.PathLike], downloader: typing.Callable = noop, remote_path: typing.Optional[typing.Union[os.PathLike, str, bool]] = None, + metadata: typing.Optional[dict[str, str]] = None, ): """ FlyteFile's init method. @@ -295,6 +304,7 @@ def __init__( # Make this field public, so that the dataclass transformer can set a value for it # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298 self.path = path + self.metadata = metadata self._downloader = downloader self._downloaded = False self._remote_path = remote_path @@ -538,7 +548,9 @@ async def async_to_literal( # If the object has a remote source, then we just convert it back. This means that if someone is just # going back and forth between a FlyteFile Python value and a Blob Flyte IDL value, we don't do anything. if python_val._remote_source is not None: - return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source))) + return Literal( + scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source)), metadata=python_val.metadata + ) # If the user specified the remote_path to be False, that means no matter what, do not upload. Also if the # path given is already a remote path, say https://www.google.com, the concept of uploading to the Flyte @@ -593,10 +605,15 @@ async def async_to_literal( else: remote_path = await ctx.file_access.async_put_raw_data(source_path, **headers) # If the source path is a local file, the remote path will be a remote storage path. - return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path))))) + return Literal( + scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path)))), + metadata=getattr(python_val, "metadata", None), + ) # If not uploading, then we can only take the original source path as the uri. else: - return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=source_path))) + return Literal( + scalar=Scalar(blob=Blob(metadata=meta, uri=source_path)), metadata=getattr(python_val, "metadata", None) + ) @staticmethod def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, str]: @@ -608,6 +625,7 @@ def dict_to_flyte_file( self, dict_obj: typing.Dict[str, str], expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] ) -> FlyteFile: path = dict_obj.get("path", None) + metadata = dict_obj.get("metadata", None) if path is None: raise ValueError("FlyteFile's path should not be None") @@ -624,7 +642,8 @@ def dict_to_flyte_file( ), uri=path, ) - ) + ), + metadata=metadata, ), expected_python_type, ) @@ -704,6 +723,7 @@ async def async_to_python_value( try: uri = lv.scalar.blob.uri + metadata = lv.metadata except AttributeError: raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") @@ -718,7 +738,7 @@ async def async_to_python_value( # In this condition, we still return a FlyteFile instance, but it's a simple one that has no downloading tricks # Using is instead of issubclass because FlyteFile does actually subclass it if expected_python_type is os.PathLike: - return FlyteFile(uri) + return FlyteFile(path=uri, metadata=metadata) # Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type expected_python_type = get_underlying_type(expected_python_type) @@ -730,7 +750,7 @@ async def async_to_python_value( # This is a local file path, like /usr/local/my_file, don't mess with it. Certainly, downloading it doesn't # make any sense. if not ctx.file_access.is_remote(uri): - return expected_python_type(uri) # type: ignore + return expected_python_type(path=uri, metadata=metadata) # type: ignore # For the remote case, return an FlyteFile object that can download local_path = ctx.file_access.get_random_local_path(uri) @@ -738,7 +758,7 @@ async def async_to_python_value( _downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False) expected_format = FlyteFilePathTransformer.get_format(expected_python_type) - ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader) + ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader, metadata=metadata) ff._remote_source = uri return ff diff --git a/tests/flytekit/integration/remote/workflows/basic/flytefile.py b/tests/flytekit/integration/remote/workflows/basic/flytefile.py index f25b77d907..16ff237c66 100644 --- a/tests/flytekit/integration/remote/workflows/basic/flytefile.py +++ b/tests/flytekit/integration/remote/workflows/basic/flytefile.py @@ -1,15 +1,16 @@ +from typing import Optional from flytekit import task, workflow from flytekit.types.file import FlyteFile @task -def create_ff(file_path: str) -> FlyteFile: +def create_ff(file_path: str, info: str) -> FlyteFile: """Create a FlyteFile.""" - return FlyteFile(path=file_path) + return FlyteFile(path=file_path, metadata={"info": info}) @task -def read_ff(ff: FlyteFile) -> None: +def read_ff(ff: FlyteFile, info: Optional[str] = None) -> None: """Read input FlyteFile. This can be used in the case in which a FlyteFile is created @@ -19,6 +20,11 @@ def read_ff(ff: FlyteFile) -> None: content = f.read() print(f"FILE CONTENT | {content}") + if info: + assert ff.metadata["info"] == info + else: + assert ff.metadata is None + @task def create_and_read_ff(file_path: str) -> FlyteFile: @@ -41,9 +47,9 @@ def create_and_read_ff(file_path: str) -> FlyteFile: @workflow -def wf(remote_file_path: str) -> None: - ff_1 = create_ff(file_path=remote_file_path) - read_ff(ff=ff_1) +def wf(remote_file_path: str, info: str = "abc") -> None: + ff_1 = create_ff(file_path=remote_file_path, info=info) + read_ff(ff=ff_1, info=info) ff_2 = create_and_read_ff(file_path=remote_file_path) read_ff(ff=ff_2) diff --git a/tests/flytekit/unit/types/file/test_file.py b/tests/flytekit/unit/types/file/test_file.py index 5187aa061a..14e14a8813 100644 --- a/tests/flytekit/unit/types/file/test_file.py +++ b/tests/flytekit/unit/types/file/test_file.py @@ -1,9 +1,10 @@ import tempfile from pathlib import Path from typing import Optional +from dataclasses import dataclass import pytest -from flytekit import task, workflow +from flytekit import task, workflow, current_context from flytekit.types.file import FlyteFile @@ -70,3 +71,59 @@ def _verify_msg(ff: FlyteFile) -> None: ff_4 = wf(source_path=source_path, use_pathlike_src_path=True, remote_path=remote_path) _verify_msg(ff_4) + + +def test_metadata(): + + @task + def create_file() -> FlyteFile: + ctx = current_context() + wd = Path(ctx.working_directory) + new_file = wd / "my_file.txt" + + content = "hello there" + new_file.write_text(content) + return FlyteFile(path=new_file, metadata={"length": str(len(content))}) + + @task + def read_metadata(file: FlyteFile) -> Optional[dict]: + return file.metadata + + @workflow + def wf() -> Optional[dict]: + file = create_file() + return read_metadata(file=file) + + output = wf() + assert output["length"] == "11" + + +@dataclass +class SimpleDC: + file: FlyteFile + + +def test_metadata_with_dataclass(): + @task + def create_dc() -> SimpleDC: + ctx = current_context() + wd = Path(ctx.working_directory) + my_file = wd / "file.txt" + my_file.write_text("hello there!") + return SimpleDC(file=FlyteFile(path=my_file, metadata={"HELLO": "WORLD"})) + + + @task + def get_metadata(dc: SimpleDC) -> dict: + if dc.file.metadata: + return dc.file.metadata + else: + return {} + + @workflow + def wf() -> dict: + dc = create_dc() + return get_metadata(dc=dc) + + output = wf() + assert output["HELLO"] == "WORLD"