Skip to content

Commit

Permalink
Extend is_traceable to support class call (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Dec 21, 2023
1 parent 407ebef commit 43fd528
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
10 changes: 9 additions & 1 deletion python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,17 @@ def get_run_tree_context() -> Optional[run_trees.RunTree]:
return _PARENT_RUN_TREE.get()


def _is_traceable_function(func: Callable) -> bool:
return getattr(func, "__langsmith_traceable__", False)


def is_traceable_function(func: Callable) -> bool:
"""Check if a function is @traceable decorated."""
return getattr(func, "__langsmith_traceable__", False)
return (
_is_traceable_function(func)
or (isinstance(func, functools.partial) and _is_traceable_function(func.func))
or (hasattr(func, "__call__") and _is_traceable_function(func.__call__))
)


def _get_inputs(
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.0.72"
version = "0.0.73"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <[email protected]>"]
license = "MIT"
Expand Down
50 changes: 49 additions & 1 deletion python/tests/unit_tests/test_run_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import functools
import inspect
from typing import Any

from langsmith.run_helpers import _get_inputs, as_runnable, traceable
from langsmith.run_helpers import (
_get_inputs,
as_runnable,
is_traceable_function,
traceable,
)


def test__get_inputs_with_no_args() -> None:
Expand Down Expand Up @@ -207,3 +213,45 @@ async def my_function(a, b, d):
]
)
assert result == [6, 7]


def test_is_traceable_function() -> None:
@traceable()
def my_function(a: int, b: int, d: int) -> int:
return a + b + d

assert is_traceable_function(my_function)


def test_is_traceable_partial_function() -> None:
@traceable()
def my_function(a: int, b: int, d: int) -> int:
return a + b + d

partial_function = functools.partial(my_function, 1, 2)

assert is_traceable_function(partial_function)


def test_is_not_traceable_function() -> None:
def my_function(a: int, b: int, d: int) -> int:
return a + b + d

assert not is_traceable_function(my_function)


def test_is_traceable_class_call() -> None:
class Foo:
@traceable()
def __call__(self, a: int, b: int) -> None:
pass

assert is_traceable_function(Foo())


def test_is_not_traceable_class_call() -> None:
class Foo:
def __call__(self, a: int, b: int) -> None:
pass

assert not is_traceable_function(Foo())

0 comments on commit 43fd528

Please sign in to comment.