Skip to content

Commit

Permalink
wrap agent outputs in a literalmap (#3143)
Browse files Browse the repository at this point in the history
* enclose agent outputs in a literalmap

Signed-off-by: Samhita Alla <[email protected]>

* add type check

Signed-off-by: Samhita Alla <[email protected]>

* fix assert check

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
samhita-alla authored Feb 19, 2025
1 parent e19fedc commit 806ff20
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,20 +122,22 @@ async def do(
)
)
with context_manager.FlyteContextManager.with_context(builder) as new_ctx:
outputs = {
"result": TypeEngine.to_literal(
new_ctx,
truncated_result if truncated_result else result,
Annotated[dict, kwtypes(allow_pickle=True)],
TypeEngine.to_literal_type(dict),
),
"idempotence_token": TypeEngine.to_literal(
new_ctx,
idempotence_token,
str,
TypeEngine.to_literal_type(str),
),
}
outputs = LiteralMap(
{
"result": TypeEngine.to_literal(
new_ctx,
truncated_result if truncated_result else result,
Annotated[dict, kwtypes(allow_pickle=True)],
TypeEngine.to_literal_type(dict),
),
"idempotence_token": TypeEngine.to_literal(
new_ctx,
idempotence_token,
str,
TypeEngine.to_literal_type(str),
),
}
)

return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs)

Expand Down
5 changes: 4 additions & 1 deletion plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
@mock.patch(
"flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call",
)
async def test_agent(mock_boto_call, mock_return_value):
async def test_agent(mock_boto_call, mock_return_value, request):
mock_boto_call.return_value = mock_return_value[0]

agent = AgentRegistry.get_agent("boto")
Expand Down Expand Up @@ -159,6 +159,9 @@ async def test_agent(mock_boto_call, mock_return_value):

assert resource.phase == TaskExecution.SUCCEEDED

if request.node.callspec.indices["mock_return_value"] in (0, 1):
assert isinstance(resource.outputs, literals.LiteralMap)

if mock_return_value[0][0]:
outputs = literal_map_string_repr(resource.outputs)
if "pickle_check" in mock_return_value[0][0]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def get(
result = retrieved_result.to_dict()

ctx = FlyteContextManager.current_context()
outputs = {"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))}
outputs = LiteralMap({"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))})

return Resource(phase=flyte_phase, outputs=outputs, message=message)

Expand Down
6 changes: 3 additions & 3 deletions plugins/flytekit-openai/tests/openai_batch/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import json
from datetime import timedelta
from unittest import mock
from unittest.mock import AsyncMock
import json
import msgpack
import base64

import pytest
from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.openai.batch.agent import BatchEndpointMetadata
Expand Down Expand Up @@ -157,6 +156,7 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context):
mock_retrieve.return_value = batch_retrieve_result
resource = await agent.get(metadata)
assert resource.phase == TaskExecution.SUCCEEDED
assert isinstance(resource.outputs, literals.LiteralMap)

outputs = literal_map_string_repr(resource.outputs)

Expand Down

0 comments on commit 806ff20

Please sign in to comment.