Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Add ruff rules PYI #29335

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SkipValidation,
model_validator,
)
from typing_extensions import Self

from langchain_core._api import deprecated
from langchain_core.load import Serializable
Expand Down Expand Up @@ -455,11 +456,6 @@ async def aformat(self, **kwargs: Any) -> BaseMessage:
)


_StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)


class _TextTemplateParam(TypedDict, total=False):
text: Union[str, dict]

Expand Down Expand Up @@ -487,13 +483,13 @@ def get_lc_namespace(cls) -> list[str]:

@classmethod
def from_template(
cls: type[_StringImageMessagePromptTemplateT],
cls: type[Self],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: PromptTemplateFormat = "f-string",
*,
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
) -> Self:
"""Create a class from a string template.

Args:
Expand Down Expand Up @@ -581,11 +577,11 @@ def from_template(

@classmethod
def from_template_file(
cls: type[_StringImageMessagePromptTemplateT],
cls: type[Self],
template_file: Union[str, Path],
input_variables: list[str],
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
) -> Self:
"""Create a class from a template file.

Args:
Expand Down
14 changes: 8 additions & 6 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4139,7 +4139,7 @@ def get_output_schema(
module_name=module,
)

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
return self._transform == other._transform
Expand Down Expand Up @@ -4514,7 +4514,7 @@ def get_graph(self, config: RunnableConfig | None = None) -> Graph:

return graph

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"):
return self.func == other.func
Expand Down Expand Up @@ -5779,22 +5779,24 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:


class _RunnableCallableSync(Protocol[Input, Output]):
def __call__(self, __in: Input, *, config: RunnableConfig) -> Output: ...
def __call__(self, _in: Input, /, *, config: RunnableConfig) -> Output: ...


class _RunnableCallableAsync(Protocol[Input, Output]):
def __call__(self, __in: Input, *, config: RunnableConfig) -> Awaitable[Output]: ...
def __call__(
self, _in: Input, /, *, config: RunnableConfig
) -> Awaitable[Output]: ...


class _RunnableCallableIterator(Protocol[Input, Output]):
def __call__(
self, __in: Iterator[Input], *, config: RunnableConfig
self, _in: Iterator[Input], /, *, config: RunnableConfig
) -> Iterator[Output]: ...


class _RunnableCallableAsyncIterator(Protocol[Input, Output]):
def __call__(
self, __in: AsyncIterator[Input], *, config: RunnableConfig
self, _in: AsyncIterator[Input], /, *, config: RunnableConfig
) -> AsyncIterator[Output]: ...


Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/runnables/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def __radd__(self, other: AddableDict) -> AddableDict:
class SupportsAdd(Protocol[_T_contra, _T_co]):
"""Protocol for objects that support addition."""

def __add__(self, __x: _T_contra) -> _T_co: ...
def __add__(self, x: _T_contra, /) -> _T_co: ...


Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
Expand Down
14 changes: 12 additions & 2 deletions libs/core/langchain_core/utils/aiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ class NoLock:
async def __aenter__(self) -> None:
pass

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
return False


Expand Down Expand Up @@ -220,7 +225,12 @@ def __iter__(self) -> Iterator[AsyncIterator[T]]:
async def __aenter__(self) -> "Tee[T]":
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
await self.aclose()
return False

Expand Down
15 changes: 13 additions & 2 deletions libs/core/langchain_core/utils/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Generator, Iterable, Iterator
from contextlib import AbstractContextManager
from itertools import islice
from types import TracebackType
from typing import (
Any,
Generic,
Expand All @@ -22,7 +23,12 @@ class NoLock:
def __enter__(self) -> None:
pass

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
return False


Expand Down Expand Up @@ -166,7 +172,12 @@ def __iter__(self) -> Iterator[Iterator[T]]:
def __enter__(self) -> "Tee[T]":
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
self.close()
return False

Expand Down
25 changes: 11 additions & 14 deletions libs/core/langchain_core/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,7 @@ def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...

def get_fields(
model: Union[
BaseModelV2,
BaseModelV1,
type[BaseModelV2],
type[BaseModelV1],
],
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
Expand Down Expand Up @@ -488,28 +483,30 @@ def _create_root_model_cached(

@lru_cache(maxsize=256)
def _create_model_cached(
__model_name: str,
model_name: str,
/,
**field_definitions: Any,
) -> type[BaseModel]:
return _create_model_base(
__model_name,
model_name,
__config__=_SchemaConfig,
**_remap_field_definitions(field_definitions),
)


def create_model(
__model_name: str,
__module_name: Optional[str] = None,
model_name: str,
module_name: Optional[str] = None,
/,
**field_definitions: Any,
) -> type[BaseModel]:
"""Create a pydantic model with the given field definitions.

Please use create_model_v2 instead of this function.

Args:
__model_name: The name of the model.
__module_name: The name of the module where the model is defined.
model_name: The name of the model.
module_name: The name of the module where the model is defined.
This is used by Pydantic to resolve any forward references.
**field_definitions: The field definitions for the model.

Expand All @@ -521,8 +518,8 @@ def create_model(
kwargs["root"] = field_definitions.pop("__root__")

return create_model_v2(
__model_name,
module_name=__module_name,
model_name,
module_name=module_name,
field_definitions=field_definitions,
**kwargs,
)
Expand Down
2 changes: 1 addition & 1 deletion libs/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ python = ">=3.12.4"
[tool.poetry.extras]

[tool.ruff.lint]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "PYI", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",]
ignore = [ "COM812", "UP007", "S110", "S112",]

[tool.coverage.run]
Expand Down
10 changes: 5 additions & 5 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5109,22 +5109,22 @@ async def test_ainvoke_on_returned_runnable() -> None:
be runthroughaasync path (issue #13407).
"""

def idchain_sync(__input: dict) -> bool:
def idchain_sync(_input: dict, /) -> bool:
return False

async def idchain_async(__input: dict) -> bool:
async def idchain_async(_input: dict, /) -> bool:
return True

idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async)

def func(__input: dict) -> Runnable:
def func(_input: dict, /) -> Runnable:
return idchain

assert await RunnableLambda(func).ainvoke({})


def test_invoke_stream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
def idchain_sync(_input: dict, /) -> bool:
return False

chain = RunnablePassthrough.assign(urls=idchain_sync)
Expand All @@ -5144,7 +5144,7 @@ def idchain_sync(__input: dict) -> bool:


async def test_ainvoke_astream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
def idchain_sync(_input: dict, /) -> bool:
return False

chain = RunnablePassthrough.assign(urls=idchain_sync)
Expand Down
2 changes: 1 addition & 1 deletion libs/core/tests/unit_tests/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class AnyStr(str):
__slots__ = ()

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, str)


Expand Down
2 changes: 1 addition & 1 deletion libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2343,7 +2343,7 @@ class Bar(ToolOutputMixin):
def __init__(self, x: int) -> None:
self.x = x

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.x == other.x

@tool
Expand Down
4 changes: 2 additions & 2 deletions libs/core/tests/unit_tests/utils/test_function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,12 +969,12 @@ class Tool(typed_dict):
)
def test_convert_union_type_py_39() -> None:
@tool
def magic_function(input: int | float) -> str:
def magic_function(input: int | str) -> str:
"""Compute a magic function."""

result = convert_to_openai_function(magic_function)
assert result["parameters"]["properties"]["input"] == {
"anyOf": [{"type": "integer"}, {"type": "number"}]
"anyOf": [{"type": "integer"}, {"type": "string"}]
}


Expand Down
Loading