Skip to content

Commit

Permalink
core: Add ruff rules PYI
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jan 25, 2025
1 parent dbb6b7b commit 1c92168
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 44 deletions.
14 changes: 5 additions & 9 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SkipValidation,
model_validator,
)
from typing_extensions import Self

from langchain_core._api import deprecated
from langchain_core.load import Serializable
Expand Down Expand Up @@ -455,11 +456,6 @@ async def aformat(self, **kwargs: Any) -> BaseMessage:
)


_StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)


class _TextTemplateParam(TypedDict, total=False):
text: Union[str, dict]

Expand Down Expand Up @@ -487,13 +483,13 @@ def get_lc_namespace(cls) -> list[str]:

@classmethod
def from_template(
cls: type[_StringImageMessagePromptTemplateT],
cls: type[Self],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: PromptTemplateFormat = "f-string",
*,
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
) -> Self:
"""Create a class from a string template.
Args:
Expand Down Expand Up @@ -581,11 +577,11 @@ def from_template(

@classmethod
def from_template_file(
cls: type[_StringImageMessagePromptTemplateT],
cls: type[Self],
template_file: Union[str, Path],
input_variables: list[str],
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
) -> Self:
"""Create a class from a template file.
Args:
Expand Down
14 changes: 8 additions & 6 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4139,7 +4139,7 @@ def get_output_schema(
module_name=module,
)

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
return self._transform == other._transform
Expand Down Expand Up @@ -4514,7 +4514,7 @@ def get_graph(self, config: RunnableConfig | None = None) -> Graph:

return graph

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"):
return self.func == other.func
Expand Down Expand Up @@ -5779,22 +5779,24 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:


class _RunnableCallableSync(Protocol[Input, Output]):
def __call__(self, __in: Input, *, config: RunnableConfig) -> Output: ...
def __call__(self, _in: Input, /, *, config: RunnableConfig) -> Output: ...


class _RunnableCallableAsync(Protocol[Input, Output]):
def __call__(self, __in: Input, *, config: RunnableConfig) -> Awaitable[Output]: ...
def __call__(
self, _in: Input, /, *, config: RunnableConfig
) -> Awaitable[Output]: ...


class _RunnableCallableIterator(Protocol[Input, Output]):
def __call__(
self, __in: Iterator[Input], *, config: RunnableConfig
self, _in: Iterator[Input], /, *, config: RunnableConfig
) -> Iterator[Output]: ...


class _RunnableCallableAsyncIterator(Protocol[Input, Output]):
def __call__(
self, __in: AsyncIterator[Input], *, config: RunnableConfig
self, _in: AsyncIterator[Input], /, *, config: RunnableConfig
) -> AsyncIterator[Output]: ...


Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/runnables/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def __radd__(self, other: AddableDict) -> AddableDict:
class SupportsAdd(Protocol[_T_contra, _T_co]):
"""Protocol for objects that support addition."""

def __add__(self, __x: _T_contra) -> _T_co: ...
def __add__(self, x: _T_contra, /) -> _T_co: ...


Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
Expand Down
14 changes: 12 additions & 2 deletions libs/core/langchain_core/utils/aiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ class NoLock:
async def __aenter__(self) -> None:
pass

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
return False


Expand Down Expand Up @@ -220,7 +225,12 @@ def __iter__(self) -> Iterator[AsyncIterator[T]]:
async def __aenter__(self) -> "Tee[T]":
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
await self.aclose()
return False

Expand Down
15 changes: 13 additions & 2 deletions libs/core/langchain_core/utils/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Generator, Iterable, Iterator
from contextlib import AbstractContextManager
from itertools import islice
from types import TracebackType
from typing import (
Any,
Generic,
Expand All @@ -22,7 +23,12 @@ class NoLock:
def __enter__(self) -> None:
pass

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
return False


Expand Down Expand Up @@ -166,7 +172,12 @@ def __iter__(self) -> Iterator[Iterator[T]]:
def __enter__(self) -> "Tee[T]":
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
self.close()
return False

Expand Down
25 changes: 11 additions & 14 deletions libs/core/langchain_core/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,7 @@ def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...

def get_fields(
model: Union[
BaseModelV2,
BaseModelV1,
type[BaseModelV2],
type[BaseModelV1],
],
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
Expand Down Expand Up @@ -488,28 +483,30 @@ def _create_root_model_cached(

@lru_cache(maxsize=256)
def _create_model_cached(
__model_name: str,
model_name: str,
/,
**field_definitions: Any,
) -> type[BaseModel]:
return _create_model_base(
__model_name,
model_name,
__config__=_SchemaConfig,
**_remap_field_definitions(field_definitions),
)


def create_model(
__model_name: str,
__module_name: Optional[str] = None,
model_name: str,
module_name: Optional[str] = None,
/,
**field_definitions: Any,
) -> type[BaseModel]:
"""Create a pydantic model with the given field definitions.
Please use create_model_v2 instead of this function.
Args:
__model_name: The name of the model.
__module_name: The name of the module where the model is defined.
model_name: The name of the model.
module_name: The name of the module where the model is defined.
This is used by Pydantic to resolve any forward references.
**field_definitions: The field definitions for the model.
Expand All @@ -521,8 +518,8 @@ def create_model(
kwargs["root"] = field_definitions.pop("__root__")

return create_model_v2(
__model_name,
module_name=__module_name,
model_name,
module_name=module_name,
field_definitions=field_definitions,
**kwargs,
)
Expand Down
2 changes: 1 addition & 1 deletion libs/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ python = ">=3.12.4"
[tool.poetry.extras]

[tool.ruff.lint]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "PYI", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",]
ignore = [ "COM812", "UP007", "S110", "S112",]

[tool.coverage.run]
Expand Down
10 changes: 5 additions & 5 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5109,22 +5109,22 @@ async def test_ainvoke_on_returned_runnable() -> None:
be runthroughaasync path (issue #13407).
"""

def idchain_sync(__input: dict) -> bool:
def idchain_sync(_input: dict, /) -> bool:
return False

async def idchain_async(__input: dict) -> bool:
async def idchain_async(_input: dict, /) -> bool:
return True

idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async)

def func(__input: dict) -> Runnable:
def func(_input: dict, /) -> Runnable:
return idchain

assert await RunnableLambda(func).ainvoke({})


def test_invoke_stream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
def idchain_sync(_input: dict, /) -> bool:
return False

chain = RunnablePassthrough.assign(urls=idchain_sync)
Expand All @@ -5144,7 +5144,7 @@ def idchain_sync(__input: dict) -> bool:


async def test_ainvoke_astream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
def idchain_sync(_input: dict, /) -> bool:
return False

chain = RunnablePassthrough.assign(urls=idchain_sync)
Expand Down
2 changes: 1 addition & 1 deletion libs/core/tests/unit_tests/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class AnyStr(str):
__slots__ = ()

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, str)


Expand Down
2 changes: 1 addition & 1 deletion libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2343,7 +2343,7 @@ class Bar(ToolOutputMixin):
def __init__(self, x: int) -> None:
self.x = x

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.x == other.x

@tool
Expand Down
4 changes: 2 additions & 2 deletions libs/core/tests/unit_tests/utils/test_function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,12 +969,12 @@ class Tool(typed_dict):
)
def test_convert_union_type_py_39() -> None:
@tool
def magic_function(input: int | float) -> str:
def magic_function(input: int | str) -> str:
"""Compute a magic function."""

result = convert_to_openai_function(magic_function)
assert result["parameters"]["properties"]["input"] == {
"anyOf": [{"type": "integer"}, {"type": "number"}]
"anyOf": [{"type": "integer"}, {"type": "string"}]
}


Expand Down

0 comments on commit 1c92168

Please sign in to comment.