Skip to content

Commit

Permalink
Use a more general name (#31)
Browse files Browse the repository at this point in the history
Fixed tool call to work with vLLM and OCI DataScience Model Deployment API

Co-authored-by: Anup Ojah <[email protected]>
  • Loading branch information
hinthornw and aojah1 authored Jan 31, 2025
1 parent 0c27fe5 commit 27c4115
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
8 changes: 4 additions & 4 deletions tests/unit_tests/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ async def test_e2e_existing_schema_policy_behavior(strict_mode):
class MyRecognizedSchema(BaseModel):
"""A recognized schema that the pipeline can handle."""

user_id: str
notes: str
user_id: str # type: ignore
notes: str # type: ignore

# Our existing data includes 2 top-level keys: recognized, unknown
existing_schemas = {
Expand Down Expand Up @@ -543,8 +543,8 @@ async def test_e2e_existing_schema_policy_tuple_behavior(strict_mode):
class MyRecognizedSchema(BaseModel):
"""A recognized schema that the pipeline can handle."""

user_id: str
notes: str
user_id: str # type: ignore
notes: str # type: ignore

existing_schemas = [
(
Expand Down
12 changes: 9 additions & 3 deletions trustcall/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,11 @@ def __init__(
existing_schema_policy: bool | Literal["ignore"] = True,
):
new_tools: list = [PatchDoc]
tool_choice = PatchDoc.__name__ if not enable_deletes else "any"
tool_choice = (
{"type": "function", "function": {"name": "PatchDoc"}}
if not enable_deletes
else "any"
)
if enable_inserts: # Also let the LLM know that we can extract NEW schemas.
tools_ = [
schema
Expand Down Expand Up @@ -1052,7 +1056,7 @@ def _tear_down(

async def ainvoke(
self, state: ExtendedExtractState, config: RunnableConfig
) -> dict:
) -> Command[Literal["sync", "__end__"]]:
"""Generate a JSONPatch to correct the validation error and heal the tool call.
Assumptions:
Expand All @@ -1075,7 +1079,9 @@ async def ainvoke(
goto=("sync",),
)

def invoke(self, state: ExtendedExtractState, config: RunnableConfig) -> dict:
def invoke(
self, state: ExtendedExtractState, config: RunnableConfig
) -> Command[Literal["sync", "__end__"]]:
try:
msg = self.bound.invoke(state.messages, config)
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 27c4115

Please sign in to comment.