From c6d22846e2d1a63036378634938543f796ca28a7 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sat, 25 Jan 2025 19:10:40 +0100 Subject: [PATCH] core: Add ruff rules RET --- libs/core/langchain_core/language_models/llms.py | 12 ++++-------- libs/core/langchain_core/output_parsers/json.py | 1 + libs/core/langchain_core/prompts/chat.py | 3 +-- libs/core/langchain_core/prompts/few_shot.py | 6 ++---- libs/core/langchain_core/prompts/string.py | 3 +-- libs/core/langchain_core/retrievers.py | 3 +-- libs/core/langchain_core/runnables/base.py | 3 +-- libs/core/langchain_core/runnables/history.py | 3 +-- libs/core/langchain_core/tracers/base.py | 3 +-- libs/core/langchain_core/tracers/stdout.py | 3 +-- libs/core/langchain_core/utils/input.py | 3 +-- libs/core/langchain_core/utils/json.py | 4 +--- libs/core/langchain_core/vectorstores/in_memory.py | 6 ++---- libs/core/langchain_core/vectorstores/utils.py | 3 +-- libs/core/pyproject.toml | 2 +- .../test_length_based_example_selector.py | 3 +-- .../tests/unit_tests/tracers/test_base_tracer.py | 3 +-- 17 files changed, 22 insertions(+), 42 deletions(-) diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 7db54409356561..9222f5e8d6f3af 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -249,8 +249,7 @@ def update_cache( prompt = prompts[missing_prompt_idxs[i]] if llm_cache is not None: llm_cache.update(prompt, llm_string, result) - llm_output = new_results.llm_output - return llm_output + return new_results.llm_output async def aupdate_cache( @@ -283,8 +282,7 @@ async def aupdate_cache( prompt = prompts[missing_prompt_idxs[i]] if llm_cache: await llm_cache.aupdate(prompt, llm_string, result) - llm_output = new_results.llm_output - return llm_output + return new_results.llm_output class BaseLLM(BaseLanguageModel[str], ABC): @@ -953,10 +951,9 @@ def generate( callback_managers, prompts, run_name_list, run_ids_list ) ] - output = self._generate_helper( + return self._generate_helper( prompts, stop, run_managers, bool(new_arg_supported), **kwargs ) - return output if len(missing_prompts) > 0: run_managers = [ callback_managers[idx].on_llm_start( @@ -1201,14 +1198,13 @@ async def agenerate( ] ) run_managers = [r[0] for r in run_managers] # type: ignore[misc] - output = await self._agenerate_helper( + return await self._agenerate_helper( prompts, stop, run_managers, # type: ignore[arg-type] bool(new_arg_supported), **kwargs, # type: ignore[arg-type] ) - return output if len(missing_prompts) > 0: run_managers = await asyncio.gather( *[ diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index bc38d90a264f46..e426fcd5006890 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -53,6 +53,7 @@ def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]: return pydantic_object.model_json_schema() if issubclass(pydantic_object, pydantic.v1.BaseModel): return pydantic_object.schema() + return None def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index b9657044fa83f0..0a3066353d6a0f 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -629,8 +629,7 @@ def input_variables(self) -> list[str]: List of input variable names. """ prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt] - input_variables = [iv for prompt in prompts for iv in prompt.input_variables] - return input_variables + return [iv for prompt in prompts for iv in prompt.input_variables] def format(self, **kwargs: Any) -> BaseMessage: """Format the prompt template. diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py index 921b0b065567ce..06589b57cb7848 100644 --- a/libs/core/langchain_core/prompts/few_shot.py +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -389,12 +389,11 @@ def format_messages(self, **kwargs: Any) -> list[BaseMessage]: {k: e[k] for k in self.example_prompt.input_variables} for e in examples ] # Format the examples. - messages = [ + return [ message for example in examples for message in self.example_prompt.format_messages(**example) ] - return messages async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: """Async format kwargs into a list of messages. @@ -411,12 +410,11 @@ async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: {k: e[k] for k in self.example_prompt.input_variables} for e in examples ] # Format the examples. - messages = [ + return [ message for example in examples for message in await self.example_prompt.aformat_messages(**example) ] - return messages def format(self, **kwargs: Any) -> str: """Format the prompt with inputs generating a string. diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 61fc1d35d4f0b4..b0de417ab1ff81 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -99,8 +99,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]: # noqa for insecure warning elsewhere env = Environment() # noqa: S701 ast = env.parse(template) - variables = meta.find_undeclared_variables(ast) - return variables + return meta.find_undeclared_variables(ast) def mustache_formatter(template: str, /, **kwargs: Any) -> str: diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 0f70c6e492718b..e929208875be35 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -206,8 +206,7 @@ def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams: default_retriever_name = default_retriever_name[:-9] default_retriever_name = default_retriever_name.lower() - ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name) - return ls_params + return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name) def invoke( self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 291f0a270b951c..bad6101e710127 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -509,10 +509,9 @@ def config_schema( if field_name in [i for i in include if i != "configurable"] }, } - model = create_model_v2( # type: ignore[call-overload] + return create_model_v2( # type: ignore[call-overload] self.get_name("Config"), field_definitions=all_fields ) - return model def get_config_jsonschema( self, *, include: Optional[Sequence[str]] = None diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 6022d9c1a32128..bd13022c0d4099 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -399,8 +399,7 @@ def get_input_schema( @property @override def OutputType(self) -> type[Output]: - output_type = self._history_chain.OutputType - return output_type + return self._history_chain.OutputType def get_output_schema( self, config: Optional[RunnableConfig] = None diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index f3ae965f6025cc..5b9fea090adeb9 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -180,11 +180,10 @@ def on_retry( Returns: The run. """ - llm_run = self._llm_run_with_retry_event( + return self._llm_run_with_retry_event( retry_state=retry_state, run_id=run_id, ) - return llm_run def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: """End a trace for an LLM run. diff --git a/libs/core/langchain_core/tracers/stdout.py b/libs/core/langchain_core/tracers/stdout.py index 3643724f9e2b46..98364da3bb9497 100644 --- a/libs/core/langchain_core/tracers/stdout.py +++ b/libs/core/langchain_core/tracers/stdout.py @@ -84,13 +84,12 @@ def get_breadcrumbs(self, run: Run) -> str: A string with the breadcrumbs of the run. """ parents = self.get_parents(run)[::-1] - string = " > ".join( + return " > ".join( f"{parent.run_type}:{parent.name}" if i != len(parents) - 1 else f"{parent.run_type}:{parent.name}" for i, parent in enumerate(parents + [run]) ) - return string # logging methods def _on_chain_start(self, run: Run) -> None: diff --git a/libs/core/langchain_core/utils/input.py b/libs/core/langchain_core/utils/input.py index 05e9bbcfcedc36..afa3bf758c722e 100644 --- a/libs/core/langchain_core/utils/input.py +++ b/libs/core/langchain_core/utils/input.py @@ -26,8 +26,7 @@ def get_color_mapping( colors = list(_TEXT_COLOR_MAPPING.keys()) if excluded_colors is not None: colors = [c for c in colors if c not in excluded_colors] - color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)} - return color_mapping + return {item: colors[i % len(colors)] for i, item in enumerate(items)} def get_colored_text(text: str, color: str) -> str: diff --git a/libs/core/langchain_core/utils/json.py b/libs/core/langchain_core/utils/json.py index 8aedfaf339b711..7eb7ec9e7602a5 100644 --- a/libs/core/langchain_core/utils/json.py +++ b/libs/core/langchain_core/utils/json.py @@ -26,15 +26,13 @@ def _custom_parser(multiline_string: str) -> str: if isinstance(multiline_string, (bytes, bytearray)): multiline_string = multiline_string.decode() - multiline_string = re.sub( + return re.sub( r'("action_input"\:\s*")(.*?)(")', _replace_new_line, multiline_string, flags=re.DOTALL, ) - return multiline_string - # Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py # MIT License diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py index ab32c7cdacbe94..3cb6f310895d79 100644 --- a/libs/core/langchain_core/vectorstores/in_memory.py +++ b/libs/core/langchain_core/vectorstores/in_memory.py @@ -381,23 +381,21 @@ def similarity_search_with_score( **kwargs: Any, ) -> list[tuple[Document, float]]: embedding = self.embedding.embed_query(query) - docs = self.similarity_search_with_score_by_vector( + return self.similarity_search_with_score_by_vector( embedding, k, **kwargs, ) - return docs async def asimilarity_search_with_score( self, query: str, k: int = 4, **kwargs: Any ) -> list[tuple[Document, float]]: embedding = await self.embedding.aembed_query(query) - docs = self.similarity_search_with_score_by_vector( + return self.similarity_search_with_score_by_vector( embedding, k, **kwargs, ) - return docs def similarity_search_by_vector( self, diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index add5e2ef4ab4f3..242e3e4a06aeed 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -57,8 +57,7 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: x = np.array(x, dtype=np.float32) y = np.array(y, dtype=np.float32) - z = 1 - np.array(simd.cdist(x, y, metric="cosine")) - return z + return 1 - np.array(simd.cdist(x, y, metric="cosine")) except ImportError: logger.debug( "Unable to import simsimd, defaulting to NumPy implementation. If you want " diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index a14282b8476dc7..c8ef33eedeb5d5 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -44,7 +44,7 @@ python = ">=3.12.4" [tool.poetry.extras] [tool.ruff.lint] -select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "UP", "W", "YTT",] +select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RET", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "UP", "W", "YTT",] ignore = [ "COM812", "UP007", "S110", "S112",] [tool.coverage.run] diff --git a/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py b/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py index db3a88160dbbdf..63f406d1b27fb5 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py +++ b/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py @@ -17,12 +17,11 @@ def selector() -> LengthBasedExampleSelector: """Get length based selector to use in tests.""" prompts = PromptTemplate(input_variables=["question"], template="{question}") - selector = LengthBasedExampleSelector( + return LengthBasedExampleSelector( examples=EXAMPLES, example_prompt=prompts, max_length=30, ) - return selector def test_selector_valid(selector: LengthBasedExampleSelector) -> None: diff --git a/libs/core/tests/unit_tests/tracers/test_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_base_tracer.py index 80d7a9297492c4..aaa34a662f2c8d 100644 --- a/libs/core/tests/unit_tests/tracers/test_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_base_tracer.py @@ -599,8 +599,7 @@ def test_tracer_nested_runs_on_error() -> None: def _get_mock_client() -> Client: mock_session = MagicMock() - client = Client(session=mock_session, api_key="test") - return client + return Client(session=mock_session, api_key="test") def test_traceable_to_tracing() -> None: