Skip to content

Commit

Permalink
Enable using positional args for workflow with default value (#3149)
Browse files Browse the repository at this point in the history
* fix: take out value in input_kwargs if in args

so that in flyte_entity_call_handler the kwargs will not contain args

Signed-off-by: machichima <[email protected]>

* test: for positional args with default value

Signed-off-by: machichima <[email protected]>

* test: positional args mix with default value

Signed-off-by: machichima <[email protected]>

* test: passing extra args

Signed-off-by: machichima <[email protected]>

* test: move test wf with defualt val under standard test positional args

Signed-off-by: machichima <[email protected]>

---------

Signed-off-by: machichima <[email protected]>
  • Loading branch information
machichima authored Feb 22, 2025
1 parent acc09a5 commit 93c87c3
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 1 deletion.
14 changes: 13 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,19 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis
Workflow needs to fill in default arguments before invoking the call handler.
"""
# Get default arguments and override with kwargs passed in
input_kwargs = self.python_interface.default_inputs_as_kwargs
interface = self.python_interface
input_kwargs = interface.default_inputs_as_kwargs

if len(args) > len(interface.inputs):
raise AssertionError(
f"Received more arguments than expected in function '{self.name}'. Expected {len(interface.inputs)} but got {len(args)}"
)
if len(input_kwargs) != 0:
for _, input_name in zip(args, interface.inputs.keys()):
if input_name in input_kwargs:
# delete the default value if provide args
del input_kwargs[input_name]

input_kwargs.update(kwargs)
ctx = FlyteContext.current_context()
# todo: remove this conditional once context manager is thread safe
Expand Down
81 changes: 81 additions & 0 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,87 @@ def wf_mixed_positional_and_keyword_args() -> int:
assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret

def test_positional_args_workflow_with_default_value():
arg1 = 5
arg2 = 6
default_arg1 = 1
default_arg2 = 2
ret = 17
ret_arg2_default = 9

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def sub_wf(x: int = default_arg1, y: int = default_arg2) -> int:
return t1(x=x, y=y)

@workflow
def wf_pure_positional_args() -> int:
return sub_wf(arg1, arg2)

@workflow
def wf_mixed_positional_and_keyword_args() -> int:
return sub_wf(arg1, y=arg2)

@workflow
def wf_mixed_positional_and_default_value() -> int:
return sub_wf(arg1)

wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args)
wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args)
wf_mixed_positional_and_default_value_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_default_value)

arg1_binding = Scalar(primitive=Primitive(integer=arg1))
arg2_binding = Scalar(primitive=Primitive(integer=arg2))
default_arg2_binding = Scalar(primitive=Primitive(integer=default_arg2))
output_type = LiteralType(simple=SimpleType.INTEGER)

assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_mixed_positional_and_default_value_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_default_value_spec.template.nodes[0].inputs[1].binding.value == default_arg2_binding
assert wf_mixed_positional_and_default_value_spec.template.interface.outputs["o0"].type == output_type

assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret
assert wf_mixed_positional_and_default_value() == ret_arg2_default


def test_positional_args_workflow_extra_args_or_kwargs():
arg1 = 5
arg2 = 6
ret = 17

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def sub_wf(x: int, y: int) -> int:
return t1(x=x, y=y)

@workflow
def wf_pure_args_extra_args() -> int:
return sub_wf(arg1, arg2, arg2)

@workflow
def wf_mixed_positional_and_keyword_args_extra_args() -> int:
return sub_wf(arg1, arg2, y=arg2)

with pytest.raises(AssertionError, match="Received more arguments than expected in function 'tests.flytekit.unit.core.test_serialization.sub_wf'. Expected 2 but got 3"):
wf_pure_args_extra_args()

with pytest.raises(AssertionError, match="Got multiple values for argument 'y' in function 'tests.flytekit.unit.core.test_serialization.sub_wf'"):
wf_mixed_positional_and_keyword_args_extra_args()

def test_positional_args_chained_tasks():
@task
def t1(x: int, y: int) -> int:
Expand Down

0 comments on commit 93c87c3

Please sign in to comment.