core: Add ruff rules RET
cbornet committed Jan 25, 2025
1 parent 24c1586 commit c6d2284
Showing 17 changed files with 22 additions and 42 deletions.
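Almost every change below is the same mechanical fix for the newly enabled flake8-return (RET) rules: RET504 flags a temporary variable that is assigned and then immediately returned, and the fix is to return the expression directly. A minimal sketch of the pattern, using illustrative names rather than code from this commit:

```python
ALL_COLORS = ["blue", "yellow", "pink", "green"]


# Before: flagged by RET504 (unnecessary assignment before `return`).
def usable_colors_verbose(excluded: set[str]) -> list[str]:
    colors = [c for c in ALL_COLORS if c not in excluded]
    return colors


# After: return the expression directly.
def usable_colors(excluded: set[str]) -> list[str]:
    return [c for c in ALL_COLORS if c not in excluded]
```

Behavior is unchanged; the rule only removes an intermediate name that adds no information.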
12 changes: 4 additions & 8 deletions libs/core/langchain_core/language_models/llms.py
@@ -249,8 +249,7 @@ def update_cache(
prompt = prompts[missing_prompt_idxs[i]]
if llm_cache is not None:
llm_cache.update(prompt, llm_string, result)
llm_output = new_results.llm_output
return llm_output
return new_results.llm_output


async def aupdate_cache(
@@ -283,8 +282,7 @@ async def aupdate_cache(
prompt = prompts[missing_prompt_idxs[i]]
if llm_cache:
await llm_cache.aupdate(prompt, llm_string, result)
llm_output = new_results.llm_output
return llm_output
return new_results.llm_output


class BaseLLM(BaseLanguageModel[str], ABC):
@@ -953,10 +951,9 @@ def generate(
callback_managers, prompts, run_name_list, run_ids_list
)
]
output = self._generate_helper(
return self._generate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
return output
if len(missing_prompts) > 0:
run_managers = [
callback_managers[idx].on_llm_start(
@@ -1201,14 +1198,13 @@ async def agenerate(
]
)
run_managers = [r[0] for r in run_managers] # type: ignore[misc]
output = await self._agenerate_helper(
return await self._agenerate_helper(
prompts,
stop,
run_managers, # type: ignore[arg-type]
bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
return output
if len(missing_prompts) > 0:
run_managers = await asyncio.gather(
*[
1 change: 1 addition & 0 deletions libs/core/langchain_core/output_parsers/json.py
@@ -53,6 +53,7 @@ def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
return pydantic_object.model_json_schema()
if issubclass(pydantic_object, pydantic.v1.BaseModel):
return pydantic_object.schema()
return None

def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
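The `return None` added in output_parsers/json.py above is the other RET fix in this commit: RET503 requires a function that returns a value on some paths to end with an explicit return rather than falling off the end. A minimal sketch of that shape, again with illustrative names:

```python
from typing import Optional


# Before: flagged by RET503 (implicit None at the end of a value-returning function).
def first_even_implicit(values: list[int]) -> Optional[int]:
    for value in values:
        if value % 2 == 0:
            return value
    # Falls off the end here, implicitly returning None.


# After: make the fallthrough return explicit.
def first_even(values: list[int]) -> Optional[int]:
    for value in values:
        if value % 2 == 0:
            return value
    return None
```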
3 changes: 1 addition & 2 deletions libs/core/langchain_core/prompts/chat.py
@@ -629,8 +629,7 @@ def input_variables(self) -> list[str]:
List of input variable names.
"""
prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
input_variables = [iv for prompt in prompts for iv in prompt.input_variables]
return input_variables
return [iv for prompt in prompts for iv in prompt.input_variables]

def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
6 changes: 2 additions & 4 deletions libs/core/langchain_core/prompts/few_shot.py
@@ -389,12 +389,11 @@ def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
messages = [
return [
message
for example in examples
for message in self.example_prompt.format_messages(**example)
]
return messages

async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Async format kwargs into a list of messages.
@@ -411,12 +410,11 @@ async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
messages = [
return [
message
for example in examples
for message in await self.example_prompt.aformat_messages(**example)
]
return messages

def format(self, **kwargs: Any) -> str:
"""Format the prompt with inputs generating a string.
3 changes: 1 addition & 2 deletions libs/core/langchain_core/prompts/string.py
@@ -99,8 +99,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
# noqa for insecure warning elsewhere
env = Environment() # noqa: S701
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
return meta.find_undeclared_variables(ast)


def mustache_formatter(template: str, /, **kwargs: Any) -> str:
3 changes: 1 addition & 2 deletions libs/core/langchain_core/retrievers.py
@@ -206,8 +206,7 @@ def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
default_retriever_name = default_retriever_name[:-9]
default_retriever_name = default_retriever_name.lower()

ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
return ls_params
return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)

def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
3 changes: 1 addition & 2 deletions libs/core/langchain_core/runnables/base.py
@@ -509,10 +509,9 @@ def config_schema(
if field_name in [i for i in include if i != "configurable"]
},
}
model = create_model_v2( # type: ignore[call-overload]
return create_model_v2( # type: ignore[call-overload]
self.get_name("Config"), field_definitions=all_fields
)
return model

def get_config_jsonschema(
self, *, include: Optional[Sequence[str]] = None
3 changes: 1 addition & 2 deletions libs/core/langchain_core/runnables/history.py
@@ -399,8 +399,7 @@ def get_input_schema(
@property
@override
def OutputType(self) -> type[Output]:
output_type = self._history_chain.OutputType
return output_type
return self._history_chain.OutputType

def get_output_schema(
self, config: Optional[RunnableConfig] = None
3 changes: 1 addition & 2 deletions libs/core/langchain_core/tracers/base.py
@@ -180,11 +180,10 @@ def on_retry(
Returns:
The run.
"""
llm_run = self._llm_run_with_retry_event(
return self._llm_run_with_retry_event(
retry_state=retry_state,
run_id=run_id,
)
return llm_run

def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for an LLM run.
3 changes: 1 addition & 2 deletions libs/core/langchain_core/tracers/stdout.py
@@ -84,13 +84,12 @@ def get_breadcrumbs(self, run: Run) -> str:
A string with the breadcrumbs of the run.
"""
parents = self.get_parents(run)[::-1]
string = " > ".join(
return " > ".join(
f"{parent.run_type}:{parent.name}"
if i != len(parents) - 1
else f"{parent.run_type}:{parent.name}"
for i, parent in enumerate(parents + [run])
)
return string

# logging methods
def _on_chain_start(self, run: Run) -> None:
3 changes: 1 addition & 2 deletions libs/core/langchain_core/utils/input.py
@@ -26,8 +26,7 @@ def get_color_mapping(
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
return {item: colors[i % len(colors)] for i, item in enumerate(items)}


def get_colored_text(text: str, color: str) -> str:
4 changes: 1 addition & 3 deletions libs/core/langchain_core/utils/json.py
@@ -26,15 +26,13 @@ def _custom_parser(multiline_string: str) -> str:
if isinstance(multiline_string, (bytes, bytearray)):
multiline_string = multiline_string.decode()

multiline_string = re.sub(
return re.sub(
r'("action_input"\:\s*")(.*?)(")',
_replace_new_line,
multiline_string,
flags=re.DOTALL,
)

return multiline_string


# Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py
# MIT License
6 changes: 2 additions & 4 deletions libs/core/langchain_core/vectorstores/in_memory.py
@@ -381,23 +381,21 @@ def similarity_search_with_score(
**kwargs: Any,
) -> list[tuple[Document, float]]:
embedding = self.embedding.embed_query(query)
docs = self.similarity_search_with_score_by_vector(
return self.similarity_search_with_score_by_vector(
embedding,
k,
**kwargs,
)
return docs

async def asimilarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> list[tuple[Document, float]]:
embedding = await self.embedding.aembed_query(query)
docs = self.similarity_search_with_score_by_vector(
return self.similarity_search_with_score_by_vector(
embedding,
k,
**kwargs,
)
return docs

def similarity_search_by_vector(
self,
3 changes: 1 addition & 2 deletions libs/core/langchain_core/vectorstores/utils.py
@@ -57,8 +57,7 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:

x = np.array(x, dtype=np.float32)
y = np.array(y, dtype=np.float32)
z = 1 - np.array(simd.cdist(x, y, metric="cosine"))
return z
return 1 - np.array(simd.cdist(x, y, metric="cosine"))
except ImportError:
logger.debug(
"Unable to import simsimd, defaulting to NumPy implementation. If you want "
2 changes: 1 addition & 1 deletion libs/core/pyproject.toml
@@ -44,7 +44,7 @@ python = ">=3.12.4"
[tool.poetry.extras]

[tool.ruff.lint]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "UP", "W", "YTT",]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RET", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "UP", "W", "YTT",]
ignore = [ "COM812", "UP007", "S110", "S112",]

[tool.coverage.run]
@@ -17,12 +17,11 @@
def selector() -> LengthBasedExampleSelector:
"""Get length based selector to use in tests."""
prompts = PromptTemplate(input_variables=["question"], template="{question}")
selector = LengthBasedExampleSelector(
return LengthBasedExampleSelector(
examples=EXAMPLES,
example_prompt=prompts,
max_length=30,
)
return selector


def test_selector_valid(selector: LengthBasedExampleSelector) -> None:
3 changes: 1 addition & 2 deletions libs/core/tests/unit_tests/tracers/test_base_tracer.py
@@ -599,8 +599,7 @@ def test_tracer_nested_runs_on_error() -> None:

def _get_mock_client() -> Client:
mock_session = MagicMock()
client = Client(session=mock_session, api_key="test")
return client
return Client(session=mock_session, api_key="test")


def test_traceable_to_tracing() -> None:
