Skip to content

Commit

Permalink
Add support for deletions (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Jan 31, 2025
1 parent d13151f commit 18b6e46
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies = [
"jsonpatch<2.0,>=1.33",
]
name = "trustcall"
version = "0.0.30"
version = "0.0.31"
description = "Tenacious & trustworthy tool calling built on LangGraph."
readme = "README.md"

Expand Down
2 changes: 1 addition & 1 deletion tests/evals/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __iter__(self):
"model_name",
[
"gpt-4o",
"gpt-4o-mini",
# "gpt-4o-mini",
# "gpt-3.5-turbo",
"claude-3-5-sonnet-20240620",
# "accounts/fireworks/models/firefunction-v2",
Expand Down
58 changes: 58 additions & 0 deletions tests/unit_tests/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class FakeExtractionModel(SimpleChatModel):
backup_responses: List[AIMessage] = []
i: int = 0
bound_count: int = 0
bound: Optional["FakeExtractionModel"] = None
tools: list = []

def _call(
Expand Down Expand Up @@ -78,6 +79,7 @@ def bind_tools(self, tools: list, **kwargs: Any) -> "FakeExtractionModel": # ty
backup_responses=backup_responses,
tools=tools,
i=self.i,
bound=self,
**kwargs,
)

Expand Down Expand Up @@ -651,3 +653,59 @@ class MyRecognizedSchema(BaseModel):
assert len(recognized_responses) == 1
recognized_item = recognized_responses[0]
assert recognized_item.notes == "updated notes"


@pytest.mark.asyncio
@pytest.mark.parametrize("enable_inserts", [True, False])
async def test_enable_deletes_flow(enable_inserts: bool) -> None:
class MySchema(BaseModel):
"""Schema for recognized docs."""

data: str

existing_docs = [
("Doc1", "MySchema", {"data": "contents of doc1"}),
("Doc2", "MySchema", {"data": "contents of doc2"}),
]

remove_doc_call_id = str(uuid.uuid4())
remove_message = AIMessage(
content="I want to remove Doc1",
tool_calls=[
{
"id": remove_doc_call_id,
"name": "RemoveDoc", # This is recognized only if enable_deletes=True
"args": {"json_doc_id": "Doc1"},
}
],
)

fake_llm = FakeExtractionModel(
responses=[remove_message], backup_responses=[remove_message] * 3
)

extractor = create_extractor(
llm=fake_llm,
tools=[MySchema],
enable_inserts=enable_inserts,
enable_deletes=True,
)

# Invoke the pipeline with some dummy "system" prompt and existing docs
result = await extractor.ainvoke(
{
"messages": [("system", "System instructions: handle doc removal.")],
"existing": existing_docs,
}
)

# The pipeline always returns final "messages" in result["messages"].
# Because "RemoveDoc" isn't a recognized schema in the final output,
# we won't see it among result["responses"] either way.
assert len(result["messages"]) == 1
final_ai_msg = result["messages"][0]
assert isinstance(final_ai_msg, AIMessage)

assert len(final_ai_msg.tool_calls) == 1
assert len(result["responses"]) == 1
assert result["responses"][0].__repr_name__() == "RemoveDoc" # type: ignore
6 changes: 4 additions & 2 deletions tests/unit_tests/test_strict_existing.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def test_validate_existing_strictness(
if isinstance(coerced, dict):
assert all(k in tools or k == "__any__" for k in coerced)
elif isinstance(coerced, list):
assert all(s.schema_name in tools or s.schema_name == "__any__" for s in coerced)
assert all(
s.schema_name in tools or s.schema_name == "__any__" for s in coerced
)
elif existing_schema_policy is False:
pass
pass
Loading

0 comments on commit 18b6e46

Please sign in to comment.