From 07a225646255e6473bb8d3845660c6e580d9d807 Mon Sep 17 00:00:00 2001 From: Piotr Rudnik Date: Sun, 22 Dec 2024 15:06:11 +0100 Subject: [PATCH] genai: Fix handling of optional arrays in tool input --- .../langchain_google_genai/_function_utils.py | 10 ++++++++++ libs/genai/tests/unit_tests/test_function_utils.py | 14 ++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/libs/genai/langchain_google_genai/_function_utils.py b/libs/genai/langchain_google_genai/_function_utils.py index d7227268..8d0699c3 100644 --- a/libs/genai/langchain_google_genai/_function_utils.py +++ b/libs/genai/langchain_google_genai/_function_utils.py @@ -314,6 +314,16 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]: if properties_item.get("type_") == glm.Type.ARRAY and v.get("items"): properties_item["items"] = _get_items_from_schema_any(v.get("items")) + elif properties_item.get("type_") == glm.Type.ARRAY and v.get("anyOf"): + types_with_items = [t for t in v.get("anyOf") if t.get("items")] + if len(types_with_items) > 1: + len_types = len(types_with_items) + logger.warning( + "Only first value for 'anyOf' key is supported in array types." + f"Got {len_types} types, using first one: {types_with_items[0]}" + ) + items = types_with_items[0]['items'] + properties_item["items"] = _get_items_from_schema_any(items) if properties_item.get("type_") == glm.Type.OBJECT and v.get("properties"): properties_item["properties"] = _get_properties_from_schema_any( diff --git a/libs/genai/tests/unit_tests/test_function_utils.py b/libs/genai/tests/unit_tests/test_function_utils.py index 536c2a99..07d8609a 100644 --- a/libs/genai/tests/unit_tests/test_function_utils.py +++ b/libs/genai/tests/unit_tests/test_function_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Generator, Optional, Union +from typing import Any, Generator, List, Optional, Union from unittest.mock import MagicMock, patch import google.ai.generativelanguage as glm @@ -6,7 +6,7 @@ from langchain_core.documents import Document from langchain_core.tools import InjectedToolArg, tool from langchain_core.utils.function_calling import convert_to_openai_tool -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Annotated from langchain_google_genai._function_utils import ( @@ -309,3 +309,13 @@ class MyModel(BaseModel): gapic_tool = convert_to_genai_function_declarations([MyModel]) tool_dict = tool_to_dict(gapic_tool) assert gapic_tool == convert_to_genai_function_declarations([tool_dict]) + + +def test_tool_input_can_have_optional_arrays() -> None: + class ExampleToolInput(BaseModel): + numbers: Optional[List[str]] = Field() + + gapic_tool = convert_to_genai_function_declarations([ExampleToolInput]) + properties = gapic_tool.function_declarations[0].parameters.properties + assert properties.get('numbers').items.type_ == 1 +