From 07ecfd056cd2f1f8d1ba0a2a729b57b3852cddfa Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sat, 25 Jan 2025 19:16:56 +0100 Subject: [PATCH] core: Add ruff rules RET --- libs/core/langchain_core/_api/deprecation.py | 3 +- libs/core/langchain_core/agents.py | 18 +- .../langchain_core/beta/runnables/context.py | 16 +- .../document_loaders/langsmith.py | 9 +- libs/core/langchain_core/documents/base.py | 10 +- .../example_selectors/length_based.py | 5 +- .../example_selectors/semantic_similarity.py | 3 +- libs/core/langchain_core/indexing/api.py | 15 +- .../langchain_core/language_models/base.py | 6 +- .../language_models/chat_models.py | 45 +++-- .../langchain_core/language_models/llms.py | 33 ++-- libs/core/langchain_core/load/dump.py | 9 +- libs/core/langchain_core/load/load.py | 11 +- libs/core/langchain_core/messages/ai.py | 2 +- libs/core/langchain_core/messages/base.py | 18 +- libs/core/langchain_core/messages/chat.py | 5 +- libs/core/langchain_core/messages/tool.py | 33 ++-- libs/core/langchain_core/messages/utils.py | 75 ++++----- .../langchain_core/output_parsers/base.py | 52 +++--- .../langchain_core/output_parsers/json.py | 28 +-- .../output_parsers/openai_functions.py | 18 +- .../output_parsers/openai_tools.py | 8 +- .../langchain_core/output_parsers/pydantic.py | 7 +- .../core/langchain_core/output_parsers/xml.py | 3 +- .../langchain_core/outputs/chat_generation.py | 13 +- .../core/langchain_core/outputs/generation.py | 7 +- libs/core/langchain_core/prompts/chat.py | 74 ++++---- libs/core/langchain_core/prompts/few_shot.py | 20 +-- .../prompts/few_shot_with_templates.py | 10 +- libs/core/langchain_core/prompts/image.py | 13 +- libs/core/langchain_core/prompts/prompt.py | 10 +- libs/core/langchain_core/prompts/string.py | 3 +- .../core/langchain_core/prompts/structured.py | 5 +- libs/core/langchain_core/retrievers.py | 3 +- libs/core/langchain_core/runnables/base.py | 159 +++++++----------- .../langchain_core/runnables/configurable.py | 33 ++-- .../langchain_core/runnables/fallbacks.py | 5 +- libs/core/langchain_core/runnables/graph.py | 8 +- .../langchain_core/runnables/graph_mermaid.py | 11 +- libs/core/langchain_core/runnables/history.py | 33 ++-- .../langchain_core/runnables/passthrough.py | 12 +- libs/core/langchain_core/runnables/utils.py | 9 +- libs/core/langchain_core/tools/base.py | 117 ++++++------- libs/core/langchain_core/tools/convert.py | 86 +++++----- libs/core/langchain_core/tracers/base.py | 3 +- libs/core/langchain_core/tracers/core.py | 14 +- .../langchain_core/tracers/event_stream.py | 2 +- libs/core/langchain_core/tracers/stdout.py | 3 +- libs/core/langchain_core/utils/_merge.py | 23 ++- libs/core/langchain_core/utils/env.py | 15 +- .../langchain_core/utils/function_calling.py | 44 +++-- libs/core/langchain_core/utils/input.py | 3 +- libs/core/langchain_core/utils/json.py | 4 +- libs/core/langchain_core/utils/json_schema.py | 5 +- libs/core/langchain_core/utils/mustache.py | 6 +- libs/core/langchain_core/utils/pydantic.py | 31 ++-- libs/core/langchain_core/utils/strings.py | 7 +- libs/core/langchain_core/utils/utils.py | 38 ++--- libs/core/langchain_core/vectorstores/base.py | 32 ++-- .../langchain_core/vectorstores/in_memory.py | 6 +- libs/core/pyproject.toml | 2 +- .../test_length_based_example_selector.py | 3 +- .../unit_tests/prompts/test_structured.py | 5 +- .../tests/unit_tests/runnables/test_graph.py | 3 +- .../unit_tests/runnables/test_runnable.py | 27 ++- .../runnables/test_runnable_events_v1.py | 6 +- .../runnables/test_runnable_events_v2.py 
| 6 +- .../runnables/test_tracing_interops.py | 33 ++-- .../unit_tests/tracers/test_base_tracer.py | 3 +- 69 files changed, 603 insertions(+), 784 deletions(-) diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index a2cccaa2be3a3..8c7b0940cb5d1 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -443,8 +443,7 @@ def warn_deprecated( f"{removal}" ) raise NotImplementedError(msg) - else: - removal = f"in {removal}" + removal = f"in {removal}" if not message: message = "" diff --git a/libs/core/langchain_core/agents.py b/libs/core/langchain_core/agents.py index 3a7fea6d328c5..c86b2dc3a99e8 100644 --- a/libs/core/langchain_core/agents.py +++ b/libs/core/langchain_core/agents.py @@ -172,8 +172,7 @@ def _convert_agent_action_to_messages( """ if isinstance(agent_action, AgentActionMessageLog): return agent_action.message_log - else: - return [AIMessage(content=agent_action.log)] + return [AIMessage(content=agent_action.log)] def _convert_agent_observation_to_messages( @@ -192,14 +191,13 @@ def _convert_agent_observation_to_messages( """ if isinstance(agent_action, AgentActionMessageLog): return [_create_function_message(agent_action, observation)] - else: - content = observation - if not isinstance(observation, str): - try: - content = json.dumps(observation, ensure_ascii=False) - except Exception: - content = str(observation) - return [HumanMessage(content=content)] + content = observation + if not isinstance(observation, str): + try: + content = json.dumps(observation, ensure_ascii=False) + except Exception: + content = str(observation) + return [HumanMessage(content=content)] def _create_function_message( diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index 2be721387cbcd..41b4c67401030 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -56,11 +56,10 @@ def _key_from_id(id_: str) -> str: wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1] if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET): return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)] - elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET): + if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET): return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)] - else: - msg = f"Invalid context config id {id_}" - raise ValueError(msg) + msg = f"Invalid context config id {id_}" + raise ValueError(msg) def _config_with_context( @@ -190,8 +189,7 @@ def invoke( configurable = config.get("configurable", {}) if isinstance(self.key, list): return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)} - else: - return configurable[self.ids[0]]() + return configurable[self.ids[0]]() async def ainvoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any @@ -201,8 +199,7 @@ async def ainvoke( if isinstance(self.key, list): values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids)) return dict(zip(self.key, values)) - else: - return await configurable[self.ids[0]]() + return await configurable[self.ids[0]]() SetValue = Union[ @@ -397,5 +394,4 @@ def setter( def _print_keys(keys: Union[str, Sequence[str]]) -> str: if isinstance(keys, str): return f"'{keys}'" - else: - return ", ".join(f"'{k}'" for k in keys) + return ", ".join(f"'{k}'" for k in keys) diff --git a/libs/core/langchain_core/document_loaders/langsmith.py 
b/libs/core/langchain_core/document_loaders/langsmith.py index 39fda02af5792..107194b0ee6f7 100644 --- a/libs/core/langchain_core/document_loaders/langsmith.py +++ b/libs/core/langchain_core/document_loaders/langsmith.py @@ -122,8 +122,7 @@ def lazy_load(self) -> Iterator[Document]: def _stringify(x: Union[str, dict]) -> str: if isinstance(x, str): return x - else: - try: - return json.dumps(x, indent=2) - except Exception: - return str(x) + try: + return json.dumps(x, indent=2) + except Exception: + return str(x) diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index 2adfe1a718397..63a8cfeb3d658 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -45,8 +45,7 @@ class BaseMedia(Serializable): def cast_id_to_str(cls, id_value: Any) -> Optional[str]: if id_value is not None: return str(id_value) - else: - return id_value + return id_value class Blob(BaseMedia): @@ -163,9 +162,9 @@ def as_bytes(self) -> bytes: """Read data as bytes.""" if isinstance(self.data, bytes): return self.data - elif isinstance(self.data, str): + if isinstance(self.data, str): return self.data.encode(self.encoding) - elif self.data is None and self.path: + if self.data is None and self.path: with open(str(self.path), "rb") as f: return f.read() else: @@ -306,5 +305,4 @@ def __str__(self) -> str: # a more general solution of formatting content directly inside the prompts. if self.metadata: return f"page_content='{self.page_content}' metadata={self.metadata}" - else: - return f"page_content='{self.page_content}'" + return f"page_content='{self.page_content}'" diff --git a/libs/core/langchain_core/example_selectors/length_based.py b/libs/core/langchain_core/example_selectors/length_based.py index 792e6317cfd8c..ec9566d75aca2 100644 --- a/libs/core/langchain_core/example_selectors/length_based.py +++ b/libs/core/langchain_core/example_selectors/length_based.py @@ -79,9 +79,8 @@ def select_examples(self, input_variables: dict[str, str]) -> list[dict]: new_length = remaining_length - self.example_text_lengths[i] if new_length < 0: break - else: - examples.append(self.examples[i]) - remaining_length = new_length + examples.append(self.examples[i]) + remaining_length = new_length i += 1 return examples diff --git a/libs/core/langchain_core/example_selectors/semantic_similarity.py b/libs/core/langchain_core/example_selectors/semantic_similarity.py index b27122ec36d4e..30cbba5c2de8b 100644 --- a/libs/core/langchain_core/example_selectors/semantic_similarity.py +++ b/libs/core/langchain_core/example_selectors/semantic_similarity.py @@ -54,8 +54,7 @@ def _example_to_text( ) -> str: if input_keys: return " ".join(sorted_values({key: example[key] for key in input_keys})) - else: - return " ".join(sorted_values(example)) + return " ".join(sorted_values(example)) def _documents_to_examples(self, documents: list[Document]) -> list[dict]: # Get the examples from the metadata. 
diff --git a/libs/core/langchain_core/indexing/api.py b/libs/core/langchain_core/indexing/api.py index 11343d17f7184..3ee2027003225 100644 --- a/libs/core/langchain_core/indexing/api.py +++ b/libs/core/langchain_core/indexing/api.py @@ -152,16 +152,15 @@ def _get_source_id_assigner( """Get the source id from the document.""" if source_id_key is None: return lambda doc: None - elif isinstance(source_id_key, str): + if isinstance(source_id_key, str): return lambda doc: doc.metadata[source_id_key] - elif callable(source_id_key): + if callable(source_id_key): return source_id_key - else: - msg = ( - f"source_id_key should be either None, a string or a callable. " - f"Got {source_id_key} of type {type(source_id_key)}." - ) - raise ValueError(msg) + msg = ( + f"source_id_key should be either None, a string or a callable. " + f"Got {source_id_key} of type {type(source_id_key)}." + ) + raise ValueError(msg) def _deduplicate_in_order( diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 12445f2560f53..bf8ee6e3f97bc 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -141,8 +141,7 @@ def set_verbose(cls, verbose: Optional[bool]) -> bool: """ if verbose is None: return _get_verbosity() - else: - return verbose + return verbose @property @override @@ -349,8 +348,7 @@ def get_token_ids(self, text: str) -> list[int]: """ if self.custom_get_token_ids is not None: return self.custom_get_token_ids(text) - else: - return _get_token_ids_default_method(text) + return _get_token_ids_default_method(text) def get_num_tokens(self, text: str) -> int: """Get the number of tokens present in the text. diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 6aaaf7d4ca80a..20618120789d3 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -259,16 +259,15 @@ def OutputType(self) -> Any: def _convert_input(self, input: LanguageModelInput) -> PromptValue: if isinstance(input, PromptValue): return input - elif isinstance(input, str): + if isinstance(input, str): return StringPromptValue(text=input) - elif isinstance(input, Sequence): + if isinstance(input, Sequence): return ChatPromptValue(messages=convert_to_messages(input)) - else: - msg = ( - f"Invalid input type {type(input)}. " - "Must be a PromptValue, str, or list of BaseMessages." - ) - raise ValueError(msg) # noqa: TRY004 + msg = ( + f"Invalid input type {type(input)}. " + "Must be a PromptValue, str, or list of BaseMessages." 
+ ) + raise ValueError(msg) # noqa: TRY004 def invoke( self, @@ -565,10 +564,9 @@ def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> st _cleanup_llm_representation(serialized_repr, 1) llm_string = json.dumps(serialized_repr, sort_keys=True) return llm_string + "---" + param_string - else: - params = self._get_invocation_params(stop=stop, **kwargs) - params = {**params, **kwargs} - return str(sorted(params.items())) + params = self._get_invocation_params(stop=stop, **kwargs) + params = {**params, **kwargs} + return str(sorted(params.items())) def generate( self, @@ -1024,9 +1022,8 @@ def __call__( ).generations[0][0] if isinstance(generation, ChatGeneration): return generation.message - else: - msg = "Unexpected generation type" - raise ValueError(msg) # noqa: TRY004 + msg = "Unexpected generation type" + raise ValueError(msg) # noqa: TRY004 async def _call_async( self, @@ -1041,9 +1038,8 @@ async def _call_async( generation = result.generations[0][0] if isinstance(generation, ChatGeneration): return generation.message - else: - msg = "Unexpected generation type" - raise ValueError(msg) # noqa: TRY004 + msg = "Unexpected generation type" + raise ValueError(msg) # noqa: TRY004 @deprecated("0.1.7", alternative="invoke", removal="1.0") def call_as_llm( @@ -1059,9 +1055,8 @@ def predict( result = self([HumanMessage(content=text)], stop=_stop, **kwargs) if isinstance(result.content, str): return result.content - else: - msg = "Cannot use predict when output is not a string." - raise ValueError(msg) # noqa: TRY004 + msg = "Cannot use predict when output is not a string." + raise ValueError(msg) # noqa: TRY004 @deprecated("0.1.7", alternative="invoke", removal="1.0") def predict_messages( @@ -1084,9 +1079,8 @@ async def apredict( ) if isinstance(result.content, str): return result.content - else: - msg = "Cannot use predict when output is not a string." - raise ValueError(msg) # noqa: TRY004 + msg = "Cannot use predict when output is not a string." 
+ raise ValueError(msg) # noqa: TRY004 @deprecated("0.1.7", alternative="ainvoke", removal="1.0") async def apredict_messages( @@ -1259,8 +1253,7 @@ class AnswerWithJustification(BaseModel): [parser_none], exception_key="parsing_error" ) return RunnableMap(raw=llm) | parser_with_fallback - else: - return llm | output_parser + return llm | output_parser class SimpleChatModel(BaseChatModel): diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 4ba16f516965d..fb6863238a294 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -249,8 +249,7 @@ def update_cache( prompt = prompts[missing_prompt_idxs[i]] if llm_cache is not None: llm_cache.update(prompt, llm_string, result) - llm_output = new_results.llm_output - return llm_output + return new_results.llm_output async def aupdate_cache( @@ -283,8 +282,7 @@ async def aupdate_cache( prompt = prompts[missing_prompt_idxs[i]] if llm_cache: await llm_cache.aupdate(prompt, llm_string, result) - llm_output = new_results.llm_output - return llm_output + return new_results.llm_output class BaseLLM(BaseLanguageModel[str], ABC): @@ -328,16 +326,15 @@ def OutputType(self) -> type[str]: def _convert_input(self, input: LanguageModelInput) -> PromptValue: if isinstance(input, PromptValue): return input - elif isinstance(input, str): + if isinstance(input, str): return StringPromptValue(text=input) - elif isinstance(input, Sequence): + if isinstance(input, Sequence): return ChatPromptValue(messages=convert_to_messages(input)) - else: - msg = ( - f"Invalid input type {type(input)}. " - "Must be a PromptValue, str, or list of BaseMessages." - ) - raise ValueError(msg) # noqa: TRY004 + msg = ( + f"Invalid input type {type(input)}. " + "Must be a PromptValue, str, or list of BaseMessages." 
+ ) + raise ValueError(msg) # noqa: TRY004 def _get_ls_params( self, @@ -447,8 +444,7 @@ def batch( except Exception as e: if return_exceptions: return cast(list[str], [e for _ in inputs]) - else: - raise + raise else: batches = [ inputs[i : i + max_concurrency] @@ -493,8 +489,7 @@ async def abatch( except Exception as e: if return_exceptions: return cast(list[str], [e for _ in inputs]) - else: - raise + raise else: batches = [ inputs[i : i + max_concurrency] @@ -960,10 +955,9 @@ def generate( callback_managers, prompts, run_name_list, run_ids_list ) ] - output = self._generate_helper( + return self._generate_helper( prompts, stop, run_managers, bool(new_arg_supported), **kwargs ) - return output if len(missing_prompts) > 0: run_managers = [ callback_managers[idx].on_llm_start( @@ -1208,14 +1202,13 @@ async def agenerate( ] ) run_managers = [r[0] for r in run_managers] # type: ignore[misc] - output = await self._agenerate_helper( + return await self._agenerate_helper( prompts, stop, run_managers, # type: ignore[arg-type] bool(new_arg_supported), **kwargs, # type: ignore[arg-type] ) - return output if len(missing_prompts) > 0: run_managers = await asyncio.gather( *[ diff --git a/libs/core/langchain_core/load/dump.py b/libs/core/langchain_core/load/dump.py index 00fae99d5287f..1ae28918fe5f9 100644 --- a/libs/core/langchain_core/load/dump.py +++ b/libs/core/langchain_core/load/dump.py @@ -16,8 +16,7 @@ def default(obj: Any) -> Any: """ if isinstance(obj, Serializable): return obj.to_json() - else: - return to_json_not_implemented(obj) + return to_json_not_implemented(obj) def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str: @@ -43,14 +42,12 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str: if pretty: indent = kwargs.pop("indent", 2) return json.dumps(obj, default=default, indent=indent, **kwargs) - else: - return json.dumps(obj, default=default, **kwargs) + return json.dumps(obj, default=default, **kwargs) except TypeError: if pretty: indent = kwargs.pop("indent", 2) return json.dumps(to_json_not_implemented(obj), indent=indent, **kwargs) - else: - return json.dumps(to_json_not_implemented(obj), **kwargs) + return json.dumps(to_json_not_implemented(obj), **kwargs) def dumpd(obj: Any) -> Any: diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index ff991789f4528..b1b6a911993e9 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -94,11 +94,10 @@ def __call__(self, value: dict[str, Any]) -> Any: [key] = value["id"] if key in self.secrets_map: return self.secrets_map[key] - else: - if self.secrets_from_env and key in os.environ and os.environ[key]: - return os.environ[key] - msg = f'Missing key "{key}" in load(secrets_map)' - raise KeyError(msg) + if self.secrets_from_env and key in os.environ and os.environ[key]: + return os.environ[key] + msg = f'Missing key "{key}" in load(secrets_map)' + raise KeyError(msg) if ( value.get("lc") == 1 @@ -127,7 +126,7 @@ def __call__(self, value: dict[str, Any]) -> Any: msg = f"Invalid namespace: {value}" raise ValueError(msg) # Has explicit import path. 
- elif mapping_key in self.import_mappings: + if mapping_key in self.import_mappings: import_path = self.import_mappings[mapping_key] # Split into module and name import_dir, name = import_path[:-1], import_path[-1] diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index c317bf099c782..680c02a83c7fa 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -396,7 +396,7 @@ def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None: def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore if isinstance(other, AIMessageChunk): return add_ai_message_chunks(self, other) - elif isinstance(other, (list, tuple)) and all( + if isinstance(other, (list, tuple)) and all( isinstance(o, AIMessageChunk) for o in other ): return add_ai_message_chunks(self, *other) diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 9eab1ed431af2..b7e60546abb1e 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -61,8 +61,7 @@ class BaseMessage(Serializable): def cast_id_to_str(cls, id_value: Any) -> Optional[str]: if id_value is not None: return str(id_value) - else: - return id_value + return id_value def __init__( self, content: Union[str, list[Union[str, dict]]], **kwargs: Any @@ -206,7 +205,7 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore self.response_metadata, other.response_metadata ), ) - elif isinstance(other, list) and all( + if isinstance(other, list) and all( isinstance(o, BaseMessageChunk) for o in other ): content = merge_content(self.content, *(o.content for o in other)) @@ -222,13 +221,12 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore additional_kwargs=additional_kwargs, response_metadata=response_metadata, ) - else: - msg = ( - 'unsupported operand type(s) for +: "' - f"{self.__class__.__name__}" - f'" and "{other.__class__.__name__}"' - ) - raise TypeError(msg) + msg = ( + 'unsupported operand type(s) for +: "' + f"{self.__class__.__name__}" + f'" and "{other.__class__.__name__}"' + ) + raise TypeError(msg) def message_to_dict(message: BaseMessage) -> dict: diff --git a/libs/core/langchain_core/messages/chat.py b/libs/core/langchain_core/messages/chat.py index 73aafd8834ee6..cb15617979f8a 100644 --- a/libs/core/langchain_core/messages/chat.py +++ b/libs/core/langchain_core/messages/chat.py @@ -62,7 +62,7 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore ), id=self.id, ) - elif isinstance(other, BaseMessageChunk): + if isinstance(other, BaseMessageChunk): return self.__class__( role=self.role, content=merge_content(self.content, other.content), @@ -74,5 +74,4 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore ), id=self.id, ) - else: - return super().__add__(other) + return super().__add__(other) diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 5c14ae045af3e..51831d462ccf0 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -295,25 +295,24 @@ def default_tool_parser( for raw_tool_call in raw_tool_calls: if "function" not in raw_tool_call: continue - else: - function_name = raw_tool_call["function"]["name"] - try: - function_args = json.loads(raw_tool_call["function"]["arguments"]) - parsed = tool_call( - name=function_name or "", - args=function_args or {}, + function_name = raw_tool_call["function"]["name"] 
+ try: + function_args = json.loads(raw_tool_call["function"]["arguments"]) + parsed = tool_call( + name=function_name or "", + args=function_args or {}, + id=raw_tool_call.get("id"), + ) + tool_calls.append(parsed) + except json.JSONDecodeError: + invalid_tool_calls.append( + invalid_tool_call( + name=function_name, + args=raw_tool_call["function"]["arguments"], id=raw_tool_call.get("id"), + error=None, ) - tool_calls.append(parsed) - except json.JSONDecodeError: - invalid_tool_calls.append( - invalid_tool_call( - name=function_name, - args=raw_tool_call["function"]["arguments"], - id=raw_tool_call.get("id"), - error=None, - ) - ) + ) return tool_calls, invalid_tool_calls diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index f7d823bf80a11..e3e862f768e29 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -50,14 +50,13 @@ def _get_type(v: Any) -> str: """Get the type associated with the object for serialization purposes.""" if isinstance(v, dict) and "type" in v: return v["type"] - elif hasattr(v, "type"): + if hasattr(v, "type"): return v.type - else: - msg = ( - f"Expected either a dictionary with a 'type' key or an object " - f"with a 'type' attribute. Instead got type {type(v)}." - ) - raise TypeError(msg) + msg = ( + f"Expected either a dictionary with a 'type' key or an object " + f"with a 'type' attribute. Instead got type {type(v)}." + ) + raise TypeError(msg) AnyMessage = Annotated[ @@ -137,33 +136,32 @@ def _message_from_dict(message: dict) -> BaseMessage: _type = message["type"] if _type == "human": return HumanMessage(**message["data"]) - elif _type == "ai": + if _type == "ai": return AIMessage(**message["data"]) - elif _type == "system": + if _type == "system": return SystemMessage(**message["data"]) - elif _type == "chat": + if _type == "chat": return ChatMessage(**message["data"]) - elif _type == "function": + if _type == "function": return FunctionMessage(**message["data"]) - elif _type == "tool": + if _type == "tool": return ToolMessage(**message["data"]) - elif _type == "remove": + if _type == "remove": return RemoveMessage(**message["data"]) - elif _type == "AIMessageChunk": + if _type == "AIMessageChunk": return AIMessageChunk(**message["data"]) - elif _type == "HumanMessageChunk": + if _type == "HumanMessageChunk": return HumanMessageChunk(**message["data"]) - elif _type == "FunctionMessageChunk": + if _type == "FunctionMessageChunk": return FunctionMessageChunk(**message["data"]) - elif _type == "ToolMessageChunk": + if _type == "ToolMessageChunk": return ToolMessageChunk(**message["data"]) - elif _type == "SystemMessageChunk": + if _type == "SystemMessageChunk": return SystemMessageChunk(**message["data"]) - elif _type == "ChatMessageChunk": + if _type == "ChatMessageChunk": return ChatMessageChunk(**message["data"]) - else: - msg = f"Got unexpected message type: {_type}" - raise ValueError(msg) + msg = f"Got unexpected message type: {_type}" + raise ValueError(msg) def messages_from_dict(messages: Sequence[dict]) -> list[BaseMessage]: @@ -379,8 +377,7 @@ def wrapped( if messages is not None: return func(messages, **kwargs) - else: - return RunnableLambda(partial(func, **kwargs), name=func.__name__) + return RunnableLambda(partial(func, **kwargs), name=func.__name__) wrapped.__doc__ = func.__doc__ return wrapped @@ -456,8 +453,6 @@ def filter_messages( or (exclude_ids and msg.id in exclude_ids) ): continue - else: - pass # default to inclusion when no 
inclusion criteria given. if ( @@ -864,7 +859,7 @@ def list_token_counter(messages: Sequence[BaseMessage]) -> int: partial_strategy="first" if allow_partial else None, end_on=end_on, ) - elif strategy == "last": + if strategy == "last": return _last_max_tokens( messages, max_tokens=max_tokens, @@ -875,9 +870,8 @@ def list_token_counter(messages: Sequence[BaseMessage]) -> int: end_on=end_on, text_splitter=text_splitter_fn, ) - else: - msg = f"Unrecognized {strategy=}. Supported strategies are 'last' and 'first'." - raise ValueError(msg) + msg = f"Unrecognized {strategy=}. Supported strategies are 'last' and 'first'." + raise ValueError(msg) def convert_to_openai_messages( @@ -1204,8 +1198,7 @@ def convert_to_openai_messages( if is_single: return oai_messages[0] - else: - return oai_messages + return oai_messages def _first_max_tokens( @@ -1317,8 +1310,7 @@ def _last_max_tokens( ) if swapped_system: return reversed_[:1] + reversed_[1:][::-1] - else: - return reversed_[::-1] + return reversed_[::-1] _MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = { @@ -1388,19 +1380,18 @@ def _bytes_to_b64_str(bytes_: bytes) -> str: def _get_message_openai_role(message: BaseMessage) -> str: if isinstance(message, AIMessage): return "assistant" - elif isinstance(message, HumanMessage): + if isinstance(message, HumanMessage): return "user" - elif isinstance(message, ToolMessage): + if isinstance(message, ToolMessage): return "tool" - elif isinstance(message, SystemMessage): + if isinstance(message, SystemMessage): return message.additional_kwargs.get("__openai_role__", "system") - elif isinstance(message, FunctionMessage): + if isinstance(message, FunctionMessage): return "function" - elif isinstance(message, ChatMessage): + if isinstance(message, ChatMessage): return message.role - else: - msg = f"Unknown BaseMessage type {message.__class__}." - raise ValueError(msg) # noqa: TRY004 + msg = f"Unknown BaseMessage type {message.__class__}." 
+ raise ValueError(msg) # noqa: TRY004 def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]: diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 9d080cef300bc..33d104d035491 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -94,13 +94,12 @@ def invoke( config, run_type="parser", ) - else: - return self._call_with_config( - lambda inner_input: self.parse_result([Generation(text=inner_input)]), - input, - config, - run_type="parser", - ) + return self._call_with_config( + lambda inner_input: self.parse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) async def ainvoke( self, @@ -117,13 +116,12 @@ async def ainvoke( config, run_type="parser", ) - else: - return await self._acall_with_config( - lambda inner_input: self.aparse_result([Generation(text=inner_input)]), - input, - config, - run_type="parser", - ) + return await self._acall_with_config( + lambda inner_input: self.aparse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) class BaseOutputParser( @@ -198,13 +196,12 @@ def invoke( config, run_type="parser", ) - else: - return self._call_with_config( - lambda inner_input: self.parse_result([Generation(text=inner_input)]), - input, - config, - run_type="parser", - ) + return self._call_with_config( + lambda inner_input: self.parse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) async def ainvoke( self, @@ -221,13 +218,12 @@ async def ainvoke( config, run_type="parser", ) - else: - return await self._acall_with_config( - lambda inner_input: self.aparse_result([Generation(text=inner_input)]), - input, - config, - run_type="parser", - ) + return await self._acall_with_config( + lambda inner_input: self.aparse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) def parse_result(self, result: list[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index 18c1257a9adf3..e426fcd500689 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -51,8 +51,9 @@ def _diff(self, prev: Optional[Any], next: Any) -> Any: def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]: if issubclass(pydantic_object, pydantic.BaseModel): return pydantic_object.model_json_schema() - elif issubclass(pydantic_object, pydantic.v1.BaseModel): + if issubclass(pydantic_object, pydantic.v1.BaseModel): return pydantic_object.schema() + return None def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. @@ -104,19 +105,18 @@ def get_format_instructions(self) -> str: """ if self.pydantic_object is None: return "Return a JSON object." - else: - # Copy schema to avoid altering original Pydantic schema. - schema = dict(self._get_schema(self.pydantic_object).items()) - - # Remove extraneous fields. - reduced_schema = schema - if "title" in reduced_schema: - del reduced_schema["title"] - if "type" in reduced_schema: - del reduced_schema["type"] - # Ensure json in context is well-formed with double quotes. 
- schema_str = json.dumps(reduced_schema, ensure_ascii=False) - return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str) + # Copy schema to avoid altering original Pydantic schema. + schema = dict(self._get_schema(self.pydantic_object).items()) + + # Remove extraneous fields. + reduced_schema = schema + if "title" in reduced_schema: + del reduced_schema["title"] + if "type" in reduced_schema: + del reduced_schema["type"] + # Ensure json in context is well-formed with double quotes. + schema_str = json.dumps(reduced_schema, ensure_ascii=False) + return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str) @property def _type(self) -> str: diff --git a/libs/core/langchain_core/output_parsers/openai_functions.py b/libs/core/langchain_core/output_parsers/openai_functions.py index 118090bc69bb2..e979ba6e0cc10 100644 --- a/libs/core/langchain_core/output_parsers/openai_functions.py +++ b/libs/core/langchain_core/output_parsers/openai_functions.py @@ -97,9 +97,8 @@ def parse_result(self, result: list[Generation], *, partial: bool = False) -> An except KeyError as exc: if partial: return None - else: - msg = f"Could not parse function call: {exc}" - raise OutputParserException(msg) from exc + msg = f"Could not parse function call: {exc}" + raise OutputParserException(msg) from exc try: if partial: try: @@ -107,13 +106,12 @@ def parse_result(self, result: list[Generation], *, partial: bool = False) -> An return parse_partial_json( function_call["arguments"], strict=self.strict ) - else: - return { - **function_call, - "arguments": parse_partial_json( - function_call["arguments"], strict=self.strict - ), - } + return { + **function_call, + "arguments": parse_partial_json( + function_call["arguments"], strict=self.strict + ), + } except json.JSONDecodeError: return None else: diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index da1d8609fe565..6c3d66966c3d3 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -239,10 +239,9 @@ def parse_result(self, result: list[Generation], *, partial: bool = False) -> An ) if self.return_id: return single_result - elif single_result: + if single_result: return single_result["args"] - else: - return None + return None parsed_result = [res for res in parsed_result if res["type"] == self.key_name] if not self.return_id: parsed_result = [res["args"] for res in parsed_result] @@ -298,5 +297,4 @@ def parse_result(self, result: list[Generation], *, partial: bool = False) -> An raise if self.first_tool_only: return pydantic_objects[0] if pydantic_objects else None - else: - return pydantic_objects + return pydantic_objects diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index a23bdc24aa1de..9bbd4e70150f1 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -26,12 +26,11 @@ def _parse_obj(self, obj: dict) -> TBaseModel: try: if issubclass(self.pydantic_object, pydantic.BaseModel): return self.pydantic_object.model_validate(obj) - elif issubclass(self.pydantic_object, pydantic.v1.BaseModel): + if issubclass(self.pydantic_object, pydantic.v1.BaseModel): return self.pydantic_object.parse_obj(obj) - else: - msg = f"Unsupported model version for PydanticOutputParser: \ + msg = f"Unsupported model version for PydanticOutputParser: \ {self.pydantic_object.__class__}" - raise 
OutputParserException(msg) + raise OutputParserException(msg) except (pydantic.ValidationError, pydantic.v1.ValidationError) as e: raise self._parser_exception(e, obj) from e else: # pydantic v1 diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 4fef349c498fc..4e41e5ed995d9 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -280,5 +280,4 @@ def nested_element(path: list[str], elem: ET.Element) -> Any: """ if len(path) == 0: return AddableDict({elem.tag: elem.text}) - else: - return AddableDict({path[0]: [nested_element(path[1:], elem)]}) + return AddableDict({path[0]: [nested_element(path[1:], elem)]}) diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py index d40e1fd5362b4..2766d12abf3b5 100644 --- a/libs/core/langchain_core/outputs/chat_generation.py +++ b/libs/core/langchain_core/outputs/chat_generation.py @@ -56,11 +56,9 @@ def set_text(self) -> Self: if isinstance(block, str): text = block break - elif isinstance(block, dict) and "text" in block: + if isinstance(block, dict) and "text" in block: text = block["text"] break - else: - pass else: pass self.text = text @@ -103,7 +101,7 @@ def __add__( message=self.message + other.message, generation_info=generation_info or None, ) - elif isinstance(other, list) and all( + if isinstance(other, list) and all( isinstance(x, ChatGenerationChunk) for x in other ): generation_info = merge_dicts( @@ -114,8 +112,5 @@ def __add__( message=self.message + [chunk.message for chunk in other], generation_info=generation_info or None, ) - else: - msg = ( - f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" - ) - raise TypeError(msg) + msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + raise TypeError(msg) diff --git a/libs/core/langchain_core/outputs/generation.py b/libs/core/langchain_core/outputs/generation.py index ce2dd36390281..8c6148f61b0fa 100644 --- a/libs/core/langchain_core/outputs/generation.py +++ b/libs/core/langchain_core/outputs/generation.py @@ -63,8 +63,5 @@ def __add__(self, other: GenerationChunk) -> GenerationChunk: text=self.text + other.text, generation_info=generation_info or None, ) - else: - msg = ( - f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" - ) - raise TypeError(msg) + msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + raise TypeError(msg) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 11924e3328d68..93540d79e535f 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -517,7 +517,7 @@ def from_template( partial_variables=partial_variables, ) return cls(prompt=prompt, **kwargs) - elif isinstance(template, list): + if isinstance(template, list): if (partial_variables is not None) and len(partial_variables) > 0: msg = "Partial variables are not supported for list of templates." raise ValueError(msg) @@ -575,9 +575,8 @@ def from_template( msg = f"Invalid template: {tmpl}" raise ValueError(msg) return cls(prompt=prompt, **kwargs) - else: - msg = f"Invalid template: {template}" - raise ValueError(msg) # noqa: TRY004 + msg = f"Invalid template: {template}" + raise ValueError(msg) # noqa: TRY004 @classmethod def from_template_file( @@ -630,8 +629,7 @@ def input_variables(self) -> list[str]: List of input variable names. 
""" prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt] - input_variables = [iv for prompt in prompts for iv in prompt.input_variables] - return input_variables + return [iv for prompt in prompts for iv in prompt.input_variables] def format(self, **kwargs: Any) -> BaseMessage: """Format the prompt template. @@ -647,19 +645,18 @@ def format(self, **kwargs: Any) -> BaseMessage: return self._msg_class( content=text, additional_kwargs=self.additional_kwargs ) - else: - content: list = [] - for prompt in self.prompt: - inputs = {var: kwargs[var] for var in prompt.input_variables} - if isinstance(prompt, StringPromptTemplate): - formatted: Union[str, ImageURL] = prompt.format(**inputs) - content.append({"type": "text", "text": formatted}) - elif isinstance(prompt, ImagePromptTemplate): - formatted = prompt.format(**inputs) - content.append({"type": "image_url", "image_url": formatted}) - return self._msg_class( - content=content, additional_kwargs=self.additional_kwargs - ) + content: list = [] + for prompt in self.prompt: + inputs = {var: kwargs[var] for var in prompt.input_variables} + if isinstance(prompt, StringPromptTemplate): + formatted: Union[str, ImageURL] = prompt.format(**inputs) + content.append({"type": "text", "text": formatted}) + elif isinstance(prompt, ImagePromptTemplate): + formatted = prompt.format(**inputs) + content.append({"type": "image_url", "image_url": formatted}) + return self._msg_class( + content=content, additional_kwargs=self.additional_kwargs + ) async def aformat(self, **kwargs: Any) -> BaseMessage: """Async format the prompt template. @@ -675,19 +672,18 @@ async def aformat(self, **kwargs: Any) -> BaseMessage: return self._msg_class( content=text, additional_kwargs=self.additional_kwargs ) - else: - content: list = [] - for prompt in self.prompt: - inputs = {var: kwargs[var] for var in prompt.input_variables} - if isinstance(prompt, StringPromptTemplate): - formatted: Union[str, ImageURL] = await prompt.aformat(**inputs) - content.append({"type": "text", "text": formatted}) - elif isinstance(prompt, ImagePromptTemplate): - formatted = await prompt.aformat(**inputs) - content.append({"type": "image_url", "image_url": formatted}) - return self._msg_class( - content=content, additional_kwargs=self.additional_kwargs - ) + content: list = [] + for prompt in self.prompt: + inputs = {var: kwargs[var] for var in prompt.input_variables} + if isinstance(prompt, StringPromptTemplate): + formatted: Union[str, ImageURL] = await prompt.aformat(**inputs) + content.append({"type": "text", "text": formatted}) + elif isinstance(prompt, ImagePromptTemplate): + formatted = await prompt.aformat(**inputs) + content.append({"type": "image_url", "image_url": formatted}) + return self._msg_class( + content=content, additional_kwargs=self.additional_kwargs + ) def pretty_repr(self, html: bool = False) -> str: """Human-readable representation. 
@@ -1040,19 +1036,18 @@ def __add__(self, other: Any) -> ChatPromptTemplate: # Allow for easy combining if isinstance(other, ChatPromptTemplate): return ChatPromptTemplate(messages=self.messages + other.messages) # type: ignore[call-arg] - elif isinstance( + if isinstance( other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate) ): return ChatPromptTemplate(messages=self.messages + [other]) # type: ignore[call-arg] - elif isinstance(other, (list, tuple)): + if isinstance(other, (list, tuple)): _other = ChatPromptTemplate.from_messages(other) return ChatPromptTemplate(messages=self.messages + _other.messages) # type: ignore[call-arg] - elif isinstance(other, str): + if isinstance(other, str): prompt = HumanMessagePromptTemplate.from_template(other) return ChatPromptTemplate(messages=self.messages + [prompt]) # type: ignore[call-arg] - else: - msg = f"Unsupported operand type for +: {type(other)}" - raise NotImplementedError(msg) + msg = f"Unsupported operand type for +: {type(other)}" + raise NotImplementedError(msg) @model_validator(mode="before") @classmethod @@ -1322,8 +1317,7 @@ def __getitem__( start, stop, step = index.indices(len(self.messages)) messages = self.messages[start:stop:step] return ChatPromptTemplate.from_messages(messages) - else: - return self.messages[index] + return self.messages[index] def __len__(self) -> int: """Get the length of the chat template.""" diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py index 5794a91332531..06589b57cb784 100644 --- a/libs/core/langchain_core/prompts/few_shot.py +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -85,11 +85,10 @@ def _get_examples(self, **kwargs: Any) -> list[dict]: """ if self.examples is not None: return self.examples - elif self.example_selector is not None: + if self.example_selector is not None: return self.example_selector.select_examples(kwargs) - else: - msg = "One of 'examples' and 'example_selector' should be provided" - raise ValueError(msg) + msg = "One of 'examples' and 'example_selector' should be provided" + raise ValueError(msg) async def _aget_examples(self, **kwargs: Any) -> list[dict]: """Async get the examples to use for formatting the prompt. @@ -105,11 +104,10 @@ async def _aget_examples(self, **kwargs: Any) -> list[dict]: """ if self.examples is not None: return self.examples - elif self.example_selector is not None: + if self.example_selector is not None: return await self.example_selector.aselect_examples(kwargs) - else: - msg = "One of 'examples' and 'example_selector' should be provided" - raise ValueError(msg) + msg = "One of 'examples' and 'example_selector' should be provided" + raise ValueError(msg) class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): @@ -391,12 +389,11 @@ def format_messages(self, **kwargs: Any) -> list[BaseMessage]: {k: e[k] for k in self.example_prompt.input_variables} for e in examples ] # Format the examples. - messages = [ + return [ message for example in examples for message in self.example_prompt.format_messages(**example) ] - return messages async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: """Async format kwargs into a list of messages. @@ -413,12 +410,11 @@ async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: {k: e[k] for k in self.example_prompt.input_variables} for e in examples ] # Format the examples. 
- messages = [ + return [ message for example in examples for message in await self.example_prompt.aformat_messages(**example) ] - return messages def format(self, **kwargs: Any) -> str: """Format the prompt with inputs generating a string. diff --git a/libs/core/langchain_core/prompts/few_shot_with_templates.py b/libs/core/langchain_core/prompts/few_shot_with_templates.py index da53d5b8c59d1..7a32146f4bdd5 100644 --- a/libs/core/langchain_core/prompts/few_shot_with_templates.py +++ b/libs/core/langchain_core/prompts/few_shot_with_templates.py @@ -97,18 +97,16 @@ def template_is_valid(self) -> Self: def _get_examples(self, **kwargs: Any) -> list[dict]: if self.examples is not None: return self.examples - elif self.example_selector is not None: + if self.example_selector is not None: return self.example_selector.select_examples(kwargs) - else: - raise ValueError + raise ValueError async def _aget_examples(self, **kwargs: Any) -> list[dict]: if self.examples is not None: return self.examples - elif self.example_selector is not None: + if self.example_selector is not None: return await self.example_selector.aselect_examples(kwargs) - else: - raise ValueError + raise ValueError def format(self, **kwargs: Any) -> str: """Format the prompt with the inputs. diff --git a/libs/core/langchain_core/prompts/image.py b/libs/core/langchain_core/prompts/image.py index 334e2b85aee1c..cdcdc3eab503f 100644 --- a/libs/core/langchain_core/prompts/image.py +++ b/libs/core/langchain_core/prompts/image.py @@ -107,14 +107,13 @@ def format( if not url: msg = "Must provide url." raise ValueError(msg) - elif not isinstance(url, str): + if not isinstance(url, str): msg = "url must be a string." - raise ValueError(msg) - else: - output: ImageURL = {"url": url} - if detail: - # Don't check literal values here: let the API check them - output["detail"] = detail # type: ignore[typeddict-item] + raise ValueError(msg) # noqa: TRY004 + output: ImageURL = {"url": url} + if detail: + # Don't check literal values here: let the API check them + output["detail"] = detail # type: ignore[typeddict-item] return output async def aformat(self, **kwargs: Any) -> ImageURL: diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 37f7eda64acff..e7ec5d157ca2f 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -149,8 +149,7 @@ def __add__(self, other: Any) -> PromptTemplate: if k in partial_variables: msg = "Cannot have same variable partialed twice." 
raise ValueError(msg) - else: - partial_variables[k] = v + partial_variables[k] = v return PromptTemplate( template=template, input_variables=input_variables, @@ -158,12 +157,11 @@ def __add__(self, other: Any) -> PromptTemplate: template_format="f-string", validate_template=validate_template, ) - elif isinstance(other, str): + if isinstance(other, str): prompt = PromptTemplate.from_template(other) return self + prompt - else: - msg = f"Unsupported operand type for +: {type(other)}" - raise NotImplementedError(msg) + msg = f"Unsupported operand type for +: {type(other)}" + raise NotImplementedError(msg) @property def _prompt_type(self) -> str: diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 61fc1d35d4f0b..b0de417ab1ff8 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -99,8 +99,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]: # noqa for insecure warning elsewhere env = Environment() # noqa: S701 ast = env.parse(template) - variables = meta.find_undeclared_variables(ast) - return variables + return meta.find_undeclared_variables(ast) def mustache_formatter(template: str, /, **kwargs: Any) -> str: diff --git a/libs/core/langchain_core/prompts/structured.py b/libs/core/langchain_core/prompts/structured.py index 1f38695760be6..4464ac93f3ff9 100644 --- a/libs/core/langchain_core/prompts/structured.py +++ b/libs/core/langchain_core/prompts/structured.py @@ -154,6 +154,5 @@ def pipe( *others[1:], name=name, ) - else: - msg = "Structured prompts need to be piped to a language model." - raise NotImplementedError(msg) + msg = "Structured prompts need to be piped to a language model." + raise NotImplementedError(msg) diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index d1a5cb8cf93f4..77d34dc082b84 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -206,8 +206,7 @@ def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams: default_retriever_name = default_retriever_name[:-9] default_retriever_name = default_retriever_name.lower() - ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name) - return ls_params + return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name) def invoke( self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index bf805c3a4f07c..075a22e8fb424 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -267,10 +267,8 @@ def get_name( if suffix: if name_[0].isupper(): return name_ + suffix.title() - else: - return name_ + "_" + suffix.lower() - else: - return name_ + return name_ + "_" + suffix.lower() + return name_ @property def InputType(self) -> type[Input]: # noqa: N802 @@ -511,10 +509,9 @@ def config_schema( if field_name in [i for i in include if i != "configurable"] }, } - model = create_model_v2( # type: ignore[call-overload] + return create_model_v2( # type: ignore[call-overload] self.get_name("Config"), field_definitions=all_fields ) - return model def get_config_jsonschema( self, *, include: Optional[Sequence[str]] = None @@ -2027,8 +2024,7 @@ def _batch_with_config( run_manager.on_chain_error(e) if return_exceptions: return cast(list[Output], [e for _ in input]) - else: - raise + raise else: first_exception: 
Optional[Exception] = None for run_manager, out in zip(run_managers, output): @@ -2039,8 +2035,7 @@ def _batch_with_config( run_manager.on_chain_end(out) if return_exceptions or first_exception is None: return cast(list[Output], output) - else: - raise first_exception + raise first_exception async def _abatch_with_config( self, @@ -2103,8 +2098,7 @@ async def _abatch_with_config( ) if return_exceptions: return cast(list[Output], [e for _ in input]) - else: - raise + raise else: first_exception: Optional[Exception] = None coros: list[Awaitable[None]] = [] @@ -2117,8 +2111,7 @@ async def _abatch_with_config( await asyncio.gather(*coros) if return_exceptions or first_exception is None: return cast(list[Output], output) - else: - raise first_exception + raise first_exception def _transform_stream_with_config( self, @@ -2582,7 +2575,7 @@ def _seq_input_schema( first = steps[0] if len(steps) == 1: return first.get_input_schema(config) - elif isinstance(first, RunnableAssign): + if isinstance(first, RunnableAssign): next_input_schema = _seq_input_schema(steps[1:], config) if not issubclass(next_input_schema, RootModel): # it's a dict as expected @@ -2608,7 +2601,7 @@ def _seq_output_schema( last = steps[-1] if len(steps) == 1: return last.get_input_schema(config) - elif isinstance(last, RunnableAssign): + if isinstance(last, RunnableAssign): mapper_output_schema = last.mapper.get_output_schema(config) prev_output_schema = _seq_output_schema(steps[:-1], config) if not issubclass(prev_output_schema, RootModel): @@ -2639,11 +2632,10 @@ def _seq_output_schema( if k in last.keys }, ) - else: - field = prev_output_schema.model_fields[last.keys] - return create_model_v2( # type: ignore[call-overload] - "RunnableSequenceOutput", root=(field.annotation, field.default) - ) + field = prev_output_schema.model_fields[last.keys] + return create_model_v2( # type: ignore[call-overload] + "RunnableSequenceOutput", root=(field.annotation, field.default) + ) return last.get_output_schema(config) @@ -2948,14 +2940,13 @@ def __or__( other.last, name=self.name or other.name, ) - else: - return RunnableSequence( - self.first, - *self.middle, - self.last, - coerce_to_runnable(other), - name=self.name, - ) + return RunnableSequence( + self.first, + *self.middle, + self.last, + coerce_to_runnable(other), + name=self.name, + ) def __ror__( self, @@ -2976,14 +2967,13 @@ def __ror__( self.last, name=other.name or self.name, ) - else: - return RunnableSequence( - coerce_to_runnable(other), - self.first, - *self.middle, - self.last, - name=self.name, - ) + return RunnableSequence( + coerce_to_runnable(other), + self.first, + *self.middle, + self.last, + name=self.name, + ) def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any @@ -3178,8 +3168,7 @@ def batch( rm.on_chain_error(e) if return_exceptions: return cast(list[Output], [e for _ in inputs]) - else: - raise + raise else: first_exception: Optional[Exception] = None for run_manager, out in zip(run_managers, inputs): @@ -3190,8 +3179,7 @@ def batch( run_manager.on_chain_end(out) if return_exceptions or first_exception is None: return cast(list[Output], inputs) - else: - raise first_exception + raise first_exception async def abatch( self, @@ -3306,8 +3294,7 @@ async def abatch( await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) if return_exceptions: return cast(list[Output], [e for _ in inputs]) - else: - raise + raise else: first_exception: Optional[Exception] = None coros: list[Awaitable[None]] = [] @@ -3320,8 +3307,7 @@ 
async def abatch( await asyncio.gather(*coros) if return_exceptions or first_exception is None: return cast(list[Output], inputs) - else: - raise first_exception + raise first_exception def _transform( self, @@ -3757,8 +3743,7 @@ async def _ainvoke_step( return await asyncio.create_task( # type: ignore step.ainvoke(input, child_config), context=context ) - else: - return await asyncio.create_task(step.ainvoke(input, child_config)) + return await asyncio.create_task(step.ainvoke(input, child_config)) # gather results from all steps try: @@ -4067,10 +4052,9 @@ def InputType(self) -> Any: first_param = next(iter(params.values()), None) if first_param and first_param.annotation != inspect.Parameter.empty: return getattr(first_param.annotation, "__args__", (Any,))[0] - else: - return Any except ValueError: return Any + return Any def get_input_schema( self, config: Optional[RunnableConfig] = None @@ -4143,12 +4127,10 @@ def __eq__(self, other: Any) -> bool: if isinstance(other, RunnableGenerator): if hasattr(self, "_transform") and hasattr(other, "_transform"): return self._transform == other._transform - elif hasattr(self, "_atransform") and hasattr(other, "_atransform"): + if hasattr(self, "_atransform") and hasattr(other, "_atransform"): return self._atransform == other._atransform - else: - return False - else: return False + return False def __repr__(self) -> str: return f"RunnableGenerator({self.name})" @@ -4359,10 +4341,9 @@ def InputType(self) -> Any: first_param = next(iter(params.values()), None) if first_param and first_param.annotation != inspect.Parameter.empty: return first_param.annotation - else: - return Any except ValueError: return Any + return Any def get_input_schema( self, config: Optional[RunnableConfig] = None @@ -4387,16 +4368,15 @@ def get_input_schema( fields = {item[1:-1]: (Any, ...) for item in items} # It's a dict, lol return create_model_v2(self.get_name("Input"), field_definitions=fields) - else: - module = getattr(func, "__module__", None) - return create_model_v2( - self.get_name("Input"), - root=list[Any], - # To create the schema, we need to provide the module - # where the underlying function is defined. - # This allows pydantic to resolve type annotations appropriately. - module_name=module, - ) + module = getattr(func, "__module__", None) + return create_model_v2( + self.get_name("Input"), + root=list[Any], + # To create the schema, we need to provide the module + # where the underlying function is defined. + # This allows pydantic to resolve type annotations appropriately. + module_name=module, + ) if self.InputType != Any: return super().get_input_schema(config) @@ -4428,10 +4408,9 @@ def OutputType(self) -> Any: ): return getattr(sig.return_annotation, "__args__", (Any,))[0] return sig.return_annotation - else: - return Any except ValueError: return Any + return Any def get_output_schema( self, config: Optional[RunnableConfig] = None @@ -4518,12 +4497,10 @@ def __eq__(self, other: Any) -> bool: if isinstance(other, RunnableLambda): if hasattr(self, "func") and hasattr(other, "func"): return self.func == other.func - elif hasattr(self, "afunc") and hasattr(other, "afunc"): + if hasattr(self, "afunc") and hasattr(other, "afunc"): return self.afunc == other.afunc - else: - return False - else: return False + return False def __repr__(self) -> str: """A string representation of this Runnable.""" @@ -4716,12 +4693,8 @@ def invoke( self._config(config, self.func), **kwargs, ) - else: - msg = ( - "Cannot invoke a coroutine function synchronously." 
- "Use `ainvoke` instead." - ) - raise TypeError(msg) + msg = "Cannot invoke a coroutine function synchronously.Use `ainvoke` instead." + raise TypeError(msg) async def ainvoke( self, @@ -5755,7 +5728,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: ) return wrapper - elif config_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + if config_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: idx = list(inspect.signature(attr).parameters).index("config") @wraps(attr) @@ -5764,14 +5737,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: argsl = list(args) argsl[idx] = merge_configs(self.config, argsl[idx]) return attr(*argsl, **kwargs) - else: - return attr( - *args, - config=merge_configs( - self.config, kwargs.pop("config", None) - ), - **kwargs, - ) + return attr( + *args, + config=merge_configs(self.config, kwargs.pop("config", None)), + **kwargs, + ) return wrapper @@ -5826,18 +5796,17 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: """ if isinstance(thing, Runnable): return thing - elif is_async_generator(thing) or inspect.isgeneratorfunction(thing): + if is_async_generator(thing) or inspect.isgeneratorfunction(thing): return RunnableGenerator(thing) - elif callable(thing): + if callable(thing): return RunnableLambda(cast(Callable[[Input], Output], thing)) - elif isinstance(thing, dict): + if isinstance(thing, dict): return cast(Runnable[Input, Output], RunnableParallel(thing)) - else: - msg = ( - f"Expected a Runnable, callable or dict." - f"Instead got an unsupported type: {type(thing)}" - ) - raise TypeError(msg) + msg = ( + f"Expected a Runnable, callable or dict." + f"Instead got an unsupported type: {type(thing)}" + ) + raise TypeError(msg) @overload diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index b59d0239fb1f3..c86cefc36138b 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -291,8 +291,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper - else: - return attr + return attr class RunnableConfigurableFields(DynamicRunnable[Input, Output]): @@ -447,8 +446,7 @@ def _prepare( self.default.__class__(**{**init_params, **configurable}), config, ) - else: - return (self.default, config) + return (self.default, config) RunnableConfigurableFields.model_rebuild() @@ -626,15 +624,13 @@ def _prepare( # return the chosen alternative if which == self.default_key: return (self.default, config) - elif which in self.alternatives: + if which in self.alternatives: alt = self.alternatives[which] if isinstance(alt, Runnable): return (alt, config) - else: - return (alt(), config) - else: - msg = f"Unknown alternative: {which}" - raise ValueError(msg) + return (alt(), config) + msg = f"Unknown alternative: {which}" + raise ValueError(msg) def _strremoveprefix(s: str, prefix: str) -> str: @@ -703,12 +699,11 @@ def make_options_spec( default=spec.default, is_shared=spec.is_shared, ) - else: - return ConfigurableFieldSpec( - id=spec.id, - name=spec.name, - description=spec.description or description, - annotation=Sequence[enum], # type: ignore[valid-type] - default=spec.default, - is_shared=spec.is_shared, - ) + return ConfigurableFieldSpec( + id=spec.id, + name=spec.name, + description=spec.description or description, + annotation=Sequence[enum], # type: ignore[valid-type] + default=spec.default, + is_shared=spec.is_shared, + ) diff --git a/libs/core/langchain_core/runnables/fallbacks.py 
b/libs/core/langchain_core/runnables/fallbacks.py index f932ce3589e00..b05f6560107c6 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -650,7 +650,6 @@ def _is_runnable_type(type_: Any) -> bool: origin = getattr(type_, "__origin__", None) if inspect.isclass(origin): return issubclass(origin, Runnable) - elif origin is typing.Union: + if origin is typing.Union: return all(_is_runnable_type(t) for t in type_.__args__) - else: - return False + return False diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 84b86994dbf0f..25bd3908ed2f9 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -189,10 +189,7 @@ def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str: if not is_uuid(id): return id - elif isinstance(data, Runnable): - data_str = data.get_name() - else: - data_str = data.__name__ + data_str = data.get_name() if isinstance(data, Runnable) else data.__name__ return data_str if not data_str.startswith("Runnable") else data_str[8:] @@ -439,8 +436,7 @@ def _get_node_id(node_id: str) -> str: label = unique_labels[node_id] if is_uuid(node_id): return label - else: - return node_id + return node_id return Graph( nodes={ diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index c9b3025bc857d..2173da50637c5 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -341,9 +341,8 @@ def _render_mermaid_using_api( file.write(response.content) return img_bytes - else: - msg = ( - f"Failed to render the graph using the Mermaid.INK API. " - f"Status code: {response.status_code}." - ) - raise ValueError(msg) + msg = ( + f"Failed to render the graph using the Mermaid.INK API. " + f"Status code: {response.status_code}." + ) + raise ValueError(msg) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index d2040a3f3fb6f..6485967d69d11 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -399,8 +399,7 @@ def get_input_schema( @property @override def OutputType(self) -> type[Output]: - output_type = self._history_chain.OutputType - return output_type + return self._history_chain.OutputType def get_output_schema( self, config: Optional[RunnableConfig] = None @@ -461,10 +460,10 @@ def _get_input_messages( return [HumanMessage(content=input_val)] # If value is a single message, convert to a list - elif isinstance(input_val, BaseMessage): + if isinstance(input_val, BaseMessage): return [input_val] # If value is a list or tuple... - elif isinstance(input_val, (list, tuple)): + if isinstance(input_val, (list, tuple)): # Handle empty case if len(input_val) == 0: return list(input_val) @@ -476,12 +475,11 @@ def _get_input_messages( raise ValueError(msg) return input_val[0] return list(input_val) - else: - msg = ( - f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. " - f"Got {input_val}." - ) - raise ValueError(msg) # noqa: TRY004 + msg = ( + f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. " + f"Got {input_val}." 
+ ) + raise ValueError(msg) # noqa: TRY004 def _get_output_messages( self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] @@ -508,16 +506,15 @@ def _get_output_messages( return [AIMessage(content=output_val)] # If value is a single message, convert to a list - elif isinstance(output_val, BaseMessage): + if isinstance(output_val, BaseMessage): return [output_val] - elif isinstance(output_val, (list, tuple)): + if isinstance(output_val, (list, tuple)): return list(output_val) - else: - msg = ( - f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. " - f"Got {output_val}." - ) - raise ValueError(msg) # noqa: TRY004 + msg = ( + f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. " + f"Got {output_val}." + ) + raise ValueError(msg) # noqa: TRY004 def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]: hist: BaseChatMessageHistory = config["configurable"]["message_history"] diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index b0da175ae3291..a46a87382589e 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -442,7 +442,7 @@ def get_output_schema( return create_model_v2( # type: ignore[call-overload] "RunnableAssignOutput", field_definitions=fields ) - elif not issubclass(map_output_schema, RootModel): + if not issubclass(map_output_schema, RootModel): # ie. only map output is a dict # ie. input type is either unknown or inferred incorrectly return map_output_schema @@ -712,12 +712,10 @@ def _pick(self, input: dict[str, Any]) -> Any: if isinstance(self.keys, str): return input.get(self.keys) - else: - picked = {k: input.get(k) for k in self.keys if k in input} - if picked: - return AddableDict(picked) - else: - return None + picked = {k: input.get(k) for k in self.keys if k in input} + if picked: + return AddableDict(picked) + return None def _invoke( self, diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 75063c7db58bb..5cbfc39944e1b 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -433,11 +433,10 @@ def get_function_nonlocals(func: Callable) -> list[Any]: for part in kk.split(".")[1:]: if vv is None: break - else: - try: - vv = getattr(vv, part) - except AttributeError: - break + try: + vv = getattr(vv, part) + except AttributeError: + break else: values.append(vv) except (SyntaxError, TypeError, OSError, SystemError): diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index f0833cdc24a14..99b29178c4dfb 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -471,8 +471,7 @@ def get_input_schema( """ if self.args_schema is not None: return self.args_schema - else: - return create_schema_from_function(self.name, self._run) + return create_schema_from_function(self.name, self._run) def invoke( self, @@ -511,56 +510,52 @@ def _parse_input( else: input_args.parse_obj({key_: tool_input}) return tool_input - else: - if input_args is not None: - if issubclass(input_args, BaseModel): - for k, v in get_all_basemodel_annotations(input_args).items(): - if ( - _is_injected_arg_type(v, injected_type=InjectedToolCallId) - and k not in tool_input - ): - if tool_call_id is None: - msg = ( - "When tool includes an InjectedToolCallId " - "argument, tool must always be invoked with a full " - 
"model ToolCall of the form: {'args': {...}, " - "'name': '...', 'type': 'tool_call', " - "'tool_call_id': '...'}" - ) - raise ValueError(msg) - tool_input[k] = tool_call_id - result = input_args.model_validate(tool_input) - result_dict = result.model_dump() - elif issubclass(input_args, BaseModelV1): - for k, v in get_all_basemodel_annotations(input_args).items(): - if ( - _is_injected_arg_type(v, injected_type=InjectedToolCallId) - and k not in tool_input - ): - if tool_call_id is None: - msg = ( - "When tool includes an InjectedToolCallId " - "argument, tool must always be invoked with a full " - "model ToolCall of the form: {'args': {...}, " - "'name': '...', 'type': 'tool_call', " - "'tool_call_id': '...'}" - ) - raise ValueError(msg) - tool_input[k] = tool_call_id - result = input_args.parse_obj(tool_input) - result_dict = result.dict() - else: - msg = ( - "args_schema must be a Pydantic BaseModel, " - f"got {self.args_schema}" - ) - raise NotImplementedError(msg) - return { - k: getattr(result, k) - for k, v in result_dict.items() - if k in tool_input - } - return tool_input + if input_args is not None: + if issubclass(input_args, BaseModel): + for k, v in get_all_basemodel_annotations(input_args).items(): + if ( + _is_injected_arg_type(v, injected_type=InjectedToolCallId) + and k not in tool_input + ): + if tool_call_id is None: + msg = ( + "When tool includes an InjectedToolCallId " + "argument, tool must always be invoked with a full " + "model ToolCall of the form: {'args': {...}, " + "'name': '...', 'type': 'tool_call', " + "'tool_call_id': '...'}" + ) + raise ValueError(msg) + tool_input[k] = tool_call_id + result = input_args.model_validate(tool_input) + result_dict = result.model_dump() + elif issubclass(input_args, BaseModelV1): + for k, v in get_all_basemodel_annotations(input_args).items(): + if ( + _is_injected_arg_type(v, injected_type=InjectedToolCallId) + and k not in tool_input + ): + if tool_call_id is None: + msg = ( + "When tool includes an InjectedToolCallId " + "argument, tool must always be invoked with a full " + "model ToolCall of the form: {'args': {...}, " + "'name': '...', 'type': 'tool_call', " + "'tool_call_id': '...'}" + ) + raise ValueError(msg) + tool_input[k] = tool_call_id + result = input_args.parse_obj(tool_input) + result_dict = result.dict() + else: + msg = ( + f"args_schema must be a Pydantic BaseModel, got {self.args_schema}" + ) + raise NotImplementedError(msg) + return { + k: getattr(result, k) for k, v in result_dict.items() if k in tool_input + } + return tool_input @model_validator(mode="before") @classmethod @@ -613,8 +608,7 @@ def _to_args_and_kwargs( # pass as a positional argument. 
if isinstance(tool_input, str): return (tool_input,), {} - else: - return (), tool_input + return (), tool_input def run( self, @@ -957,10 +951,9 @@ def _is_message_content_block(obj: Any) -> bool: """Check for OpenAI or Anthropic format tool message content blocks.""" if isinstance(obj, str): return True - elif isinstance(obj, dict): + if isinstance(obj, dict): return obj.get("type", None) in ("text", "image_url", "image", "json") - else: - return False + return False def _stringify(content: Any) -> str: @@ -1091,17 +1084,15 @@ def _replace_type_vars( if isinstance(type_, TypeVar): if type_ in generic_map: return generic_map[type_] - elif default_to_bound: + if default_to_bound: return type_.__bound__ or Any - else: - return type_ - elif (origin := get_origin(type_)) and (args := get_args(type_)): + return type_ + if (origin := get_origin(type_)) and (args := get_args(type_)): new_args = tuple( _replace_type_vars(arg, generic_map, default_to_bound) for arg in args ) return _py_38_safe_origin(origin)[new_args] # type: ignore[index] - else: - return type_ + return type_ class BaseToolkit(BaseModel, ABC): diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index bb8b85f5558cc..439fcf046b236 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -295,14 +295,14 @@ def invoke_wrapper( msg = "Name must be a string for tool constructor" raise ValueError(msg) return _create_tool_factory(name_or_callable)(runnable) - elif name_or_callable is not None: + if name_or_callable is not None: if callable(name_or_callable) and hasattr(name_or_callable, "__name__"): # Used as a decorator without parameters # @tool # def my_tool(): # pass return _create_tool_factory(name_or_callable.__name__)(name_or_callable) - elif isinstance(name_or_callable, str): + if isinstance(name_or_callable, str): # Used with a new name for the tool # @tool("search") # def my_tool(): @@ -314,24 +314,23 @@ def invoke_wrapper( # def my_tool(): # pass return _create_tool_factory(name_or_callable) - else: - msg = ( - f"The first argument must be a string or a callable with a __name__ " - f"for tool decorator. Got {type(name_or_callable)}" - ) - raise ValueError(msg) - else: - # Tool is used as a decorator with parameters specified - # @tool(parse_docstring=True) - # def my_tool(): - # pass - def _partial(func: Union[Callable, Runnable]) -> BaseTool: - """Partial function that takes a callable and returns a tool.""" - name_ = func.get_name() if isinstance(func, Runnable) else func.__name__ - tool_factory = _create_tool_factory(name_) - return tool_factory(func) + msg = ( + f"The first argument must be a string or a callable with a __name__ " + f"for tool decorator. 
Got {type(name_or_callable)}" + ) + raise ValueError(msg) + + # Tool is used as a decorator with parameters specified + # @tool(parse_docstring=True) + # def my_tool(): + # pass + def _partial(func: Union[Callable, Runnable]) -> BaseTool: + """Partial function that takes a callable and returns a tool.""" + name_ = func.get_name() if isinstance(func, Runnable) else func.__name__ + tool_factory = _create_tool_factory(name_) + return tool_factory(func) - return _partial + return _partial def _get_description_from_runnable(runnable: Runnable) -> str: @@ -393,31 +392,30 @@ def convert_runnable_to_tool( coroutine=runnable.ainvoke, description=description, ) - else: - async def ainvoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any - ) -> Any: - return await runnable.ainvoke(kwargs, config={"callbacks": callbacks}) - - def invoke_wrapper(callbacks: Optional[Callbacks] = None, **kwargs: Any) -> Any: - return runnable.invoke(kwargs, config={"callbacks": callbacks}) - - if ( - arg_types is None - and schema.get("type") == "object" - and schema.get("properties") - ): - args_schema = runnable.input_schema - else: - args_schema = _get_schema_from_runnable_and_arg_types( - runnable, name, arg_types=arg_types - ) + async def ainvoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return await runnable.ainvoke(kwargs, config={"callbacks": callbacks}) - return StructuredTool.from_function( - name=name, - func=invoke_wrapper, - coroutine=ainvoke_wrapper, - description=description, - args_schema=args_schema, + def invoke_wrapper(callbacks: Optional[Callbacks] = None, **kwargs: Any) -> Any: + return runnable.invoke(kwargs, config={"callbacks": callbacks}) + + if ( + arg_types is None + and schema.get("type") == "object" + and schema.get("properties") + ): + args_schema = runnable.input_schema + else: + args_schema = _get_schema_from_runnable_and_arg_types( + runnable, name, arg_types=arg_types ) + + return StructuredTool.from_function( + name=name, + func=invoke_wrapper, + coroutine=ainvoke_wrapper, + description=description, + args_schema=args_schema, + ) diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index f3ae965f6025c..5b9fea090adeb 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -180,11 +180,10 @@ def on_retry( Returns: The run. """ - llm_run = self._llm_run_with_retry_event( + return self._llm_run_with_retry_event( retry_state=retry_state, run_id=run_id, ) - return llm_run def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: """End a trace for an LLM run. 
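The tools/convert.py and tracers/base.py hunks above show the two shapes that ruff's RET family targets throughout this patch: RET505 (a superfluous `else`/`elif` after a branch that returns or raises) and RET504 (a temporary assigned and immediately returned). A minimal standalone sketch of the RET505 rewrite, using hypothetical names rather than code from this repository:

    # Before: RET505 flags the elif/else, since each prior branch already returns
    def describe(n: int) -> str:
        if n < 0:
            return "negative"
        elif n == 0:
            return "zero"
        else:
            return "positive"

    # After: early returns; the final case needs no guard at all
    def describe(n: int) -> str:
        if n < 0:
            return "negative"
        if n == 0:
            return "zero"
        return "positive"

Behavior is unchanged: once a branch ends in `return` or `raise`, the code after it only runs when that condition was false, so the `else` adds indentation without adding meaning.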
diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py index d3544df04e3dc..6a706188f489b 100644 --- a/libs/core/langchain_core/tracers/core.py +++ b/libs/core/langchain_core/tracers/core.py @@ -333,25 +333,23 @@ def _get_chain_inputs(self, inputs: Any) -> Any: """Get the inputs for a chain run.""" if self._schema_format in ("original", "original+chat"): return inputs if isinstance(inputs, dict) else {"input": inputs} - elif self._schema_format == "streaming_events": + if self._schema_format == "streaming_events": return { "input": inputs, } - else: - msg = f"Invalid format: {self._schema_format}" - raise ValueError(msg) + msg = f"Invalid format: {self._schema_format}" + raise ValueError(msg) def _get_chain_outputs(self, outputs: Any) -> Any: """Get the outputs for a chain run.""" if self._schema_format in ("original", "original+chat"): return outputs if isinstance(outputs, dict) else {"output": outputs} - elif self._schema_format == "streaming_events": + if self._schema_format == "streaming_events": return { "output": outputs, } - else: - msg = f"Invalid format: {self._schema_format}" - raise ValueError(msg) + msg = f"Invalid format: {self._schema_format}" + raise ValueError(msg) def _complete_chain_run( self, diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index b7f3db8595d61..fe3c83c4a62fe 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -77,7 +77,7 @@ def _assign_name(name: Optional[str], serialized: Optional[dict[str, Any]]) -> s if serialized is not None: if "name" in serialized: return serialized["name"] - elif "id" in serialized: + if "id" in serialized: return serialized["id"][-1] return "Unnamed" diff --git a/libs/core/langchain_core/tracers/stdout.py b/libs/core/langchain_core/tracers/stdout.py index 3643724f9e2b4..98364da3bb949 100644 --- a/libs/core/langchain_core/tracers/stdout.py +++ b/libs/core/langchain_core/tracers/stdout.py @@ -84,13 +84,12 @@ def get_breadcrumbs(self, run: Run) -> str: A string with the breadcrumbs of the run. """ parents = self.get_parents(run)[::-1] - string = " > ".join( + return " > ".join( f"{parent.run_type}:{parent.name}" if i != len(parents) - 1 else f"{parent.run_type}:{parent.name}" for i, parent in enumerate(parents + [run]) ) - return string # logging methods def _on_chain_start(self, run: Run) -> None: diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index be8f829387454..18be54d62ded7 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -83,7 +83,7 @@ def merge_lists(left: Optional[list], *others: Optional[list]) -> Optional[list] for other in others: if other is None: continue - elif merged is None: + if merged is None: merged = other.copy() else: for e in other: @@ -126,23 +126,22 @@ def merge_obj(left: Any, right: Any) -> Any: """ if left is None or right is None: return left if left is not None else right - elif type(left) is not type(right): + if type(left) is not type(right): msg = ( f"left and right are of different types. Left type: {type(left)}. Right " f"type: {type(right)}." 
) raise TypeError(msg) - elif isinstance(left, str): + if isinstance(left, str): return left + right - elif isinstance(left, dict): + if isinstance(left, dict): return merge_dicts(left, right) - elif isinstance(left, list): + if isinstance(left, list): return merge_lists(left, right) - elif left == right: + if left == right: return left - else: - msg = ( - f"Unable to merge {left=} and {right=}. Both must be of type str, dict, or " - f"list, or else be two equal objects." - ) - raise ValueError(msg) + msg = ( + f"Unable to merge {left=} and {right=}. Both must be of type str, dict, or " + f"list, or else be two equal objects." + ) + raise ValueError(msg) diff --git a/libs/core/langchain_core/utils/env.py b/libs/core/langchain_core/utils/env.py index ac018b0208a70..88b68389d8d8e 100644 --- a/libs/core/langchain_core/utils/env.py +++ b/libs/core/langchain_core/utils/env.py @@ -70,12 +70,11 @@ def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: """ if env_key in os.environ and os.environ[env_key]: return os.environ[env_key] - elif default is not None: + if default is not None: return default - else: - msg = ( - f"Did not find {key}, please add an environment variable" - f" `{env_key}` which contains it, or pass" - f" `{key}` as a named parameter." - ) - raise ValueError(msg) + msg = ( + f"Did not find {key}, please add an environment variable" + f" `{env_key}` which contains it, or pass" + f" `{key}` as a named parameter." + ) + raise ValueError(msg) diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index cdcd3dc864964..ff312cf74f1a1 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -222,9 +222,9 @@ def _convert_any_typed_dicts_to_pydantic( if type_ in visited: return visited[type_] - elif depth >= _MAX_TYPED_DICT_RECURSION: + if depth >= _MAX_TYPED_DICT_RECURSION: return type_ - elif is_typeddict(type_): + if is_typeddict(type_): typed_dict = type_ docstring = inspect.getdoc(typed_dict) annotations_ = typed_dict.__annotations__ @@ -248,7 +248,7 @@ def _convert_any_typed_dicts_to_pydantic( f"type {type(field_desc)}." ) raise ValueError(msg) - elif arg_desc := arg_descriptions.get(arg): + if arg_desc := arg_descriptions.get(arg): field_kwargs["description"] = arg_desc else: pass @@ -265,15 +265,14 @@ def _convert_any_typed_dicts_to_pydantic( model.__doc__ = description visited[typed_dict] = model return model - elif (origin := get_origin(type_)) and (type_args := get_args(type_)): + if (origin := get_origin(type_)) and (type_args := get_args(type_)): subscriptable_origin = _py_38_safe_origin(origin) type_args = tuple( _convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited) for arg in type_args # type: ignore[index] ) return subscriptable_origin[type_args] # type: ignore[index] - else: - return type_ + return type_ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: @@ -292,23 +291,22 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: return _convert_pydantic_to_openai_function( tool.tool_call_schema, name=tool.name, description=tool.description ) - else: - return { - "name": tool.name, - "description": tool.description, - "parameters": { - # This is a hack to get around the fact that some tools - # do not expose an args_schema, and expect an argument - # which is a string. - # And Open AI does not support an array type for the - # parameters. 
- "properties": { - "__arg1": {"title": "__arg1", "type": "string"}, - }, - "required": ["__arg1"], - "type": "object", + return { + "name": tool.name, + "description": tool.description, + "parameters": { + # This is a hack to get around the fact that some tools + # do not expose an args_schema, and expect an argument + # which is a string. + # And Open AI does not support an array type for the + # parameters. + "properties": { + "__arg1": {"title": "__arg1", "type": "string"}, }, - } + "required": ["__arg1"], + "type": "object", + }, + } format_tool_to_openai_function = deprecated( @@ -634,7 +632,7 @@ def _parse_google_docstring( if block.startswith("Args:"): args_block = block break - elif block.startswith(("Returns:", "Example:")): + if block.startswith(("Returns:", "Example:")): # Don't break in case Args come after past_descriptors = True elif not past_descriptors: diff --git a/libs/core/langchain_core/utils/input.py b/libs/core/langchain_core/utils/input.py index 05e9bbcfcedc3..afa3bf758c722 100644 --- a/libs/core/langchain_core/utils/input.py +++ b/libs/core/langchain_core/utils/input.py @@ -26,8 +26,7 @@ def get_color_mapping( colors = list(_TEXT_COLOR_MAPPING.keys()) if excluded_colors is not None: colors = [c for c in colors if c not in excluded_colors] - color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)} - return color_mapping + return {item: colors[i % len(colors)] for i, item in enumerate(items)} def get_colored_text(text: str, color: str) -> str: diff --git a/libs/core/langchain_core/utils/json.py b/libs/core/langchain_core/utils/json.py index 8aedfaf339b71..7eb7ec9e7602a 100644 --- a/libs/core/langchain_core/utils/json.py +++ b/libs/core/langchain_core/utils/json.py @@ -26,15 +26,13 @@ def _custom_parser(multiline_string: str) -> str: if isinstance(multiline_string, (bytes, bytearray)): multiline_string = multiline_string.decode() - multiline_string = re.sub( + return re.sub( r'("action_input"\:\s*")(.*?)(")', _replace_new_line, multiline_string, flags=re.DOTALL, ) - return multiline_string - # Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py # MIT License diff --git a/libs/core/langchain_core/utils/json_schema.py b/libs/core/langchain_core/utils/json_schema.py index 38fab589909b9..0d525ab9c2e93 100644 --- a/libs/core/langchain_core/utils/json_schema.py +++ b/libs/core/langchain_core/utils/json_schema.py @@ -56,13 +56,12 @@ def _dereference_refs_helper( else: obj_out[k] = v return obj_out - elif isinstance(obj, list): + if isinstance(obj, list): return [ _dereference_refs_helper(el, full_schema, skip_keys, processed_refs) for el in obj ] - else: - return obj + return obj def _infer_skip_keys( diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index ee2ed8f2528f8..b1ca78c3c510c 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -81,8 +81,7 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: # Then the next tag could be a standalone # Otherwise it can't be return padding.isspace() or padding == "" - else: - return False + return False def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: @@ -104,8 +103,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: return on_newline[0].isspace() or not on_newline[0] # If we're a tag can't be a standalone - else: - return False + return 
False def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]: diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 6c4b4cb8c2fd7..418d7da120d50 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -85,7 +85,7 @@ def is_pydantic_v1_subclass(cls: type) -> bool: """Check if the installed Pydantic version is 1.x-like.""" if PYDANTIC_MAJOR_VERSION == 1: return True - elif PYDANTIC_MAJOR_VERSION == 2: + if PYDANTIC_MAJOR_VERSION == 2: from pydantic.v1 import BaseModel as BaseModelV1 if issubclass(cls, BaseModelV1): @@ -331,7 +331,7 @@ def _create_subset_model( descriptions=descriptions, fn_description=fn_description, ) - elif PYDANTIC_MAJOR_VERSION == 2: + if PYDANTIC_MAJOR_VERSION == 2: from pydantic.v1 import BaseModel as BaseModelV1 if issubclass(model, BaseModelV1): @@ -342,17 +342,15 @@ def _create_subset_model( descriptions=descriptions, fn_description=fn_description, ) - else: - return _create_subset_model_v2( - name, - model, - field_names, - descriptions=descriptions, - fn_description=fn_description, - ) - else: - msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}" - raise NotImplementedError(msg) + return _create_subset_model_v2( + name, + model, + field_names, + descriptions=descriptions, + fn_description=fn_description, + ) + msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}" + raise NotImplementedError(msg) if PYDANTIC_MAJOR_VERSION == 2: @@ -384,11 +382,10 @@ def get_fields( if hasattr(model, "model_fields"): return model.model_fields # type: ignore - elif hasattr(model, "__fields__"): + if hasattr(model, "__fields__"): return model.__fields__ # type: ignore - else: - msg = f"Expected a Pydantic model. Got {type(model)}" - raise TypeError(msg) + msg = f"Expected a Pydantic model. Got {type(model)}" + raise TypeError(msg) elif PYDANTIC_MAJOR_VERSION == 1: from pydantic import BaseModel as BaseModelV1_ diff --git a/libs/core/langchain_core/utils/strings.py b/libs/core/langchain_core/utils/strings.py index e7a79761f9a01..4be7697dfad70 100644 --- a/libs/core/langchain_core/utils/strings.py +++ b/libs/core/langchain_core/utils/strings.py @@ -12,12 +12,11 @@ def stringify_value(val: Any) -> str: """ if isinstance(val, str): return val - elif isinstance(val, dict): + if isinstance(val, dict): return "\n" + stringify_dict(val) - elif isinstance(val, list): + if isinstance(val, list): return "\n".join(stringify_value(v) for v in val) - else: - return str(val) + return str(val) def stringify_dict(data: dict) -> str: diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 6cfbc9ceaf8bb..640745ff117b1 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -391,16 +391,14 @@ def get_from_env_fn() -> Optional[str]: if isinstance(default, (str, type(None))): return default - else: - if error_message: - raise ValueError(error_message) - else: - msg = ( - f"Did not find {key}, please add an environment variable" - f" `{key}` which contains it, or pass" - f" `{key}` as a named parameter." - ) - raise ValueError(msg) + if error_message: + raise ValueError(error_message) + msg = ( + f"Did not find {key}, please add an environment variable" + f" `{key}` which contains it, or pass" + f" `{key}` as a named parameter." 
+ ) + raise ValueError(msg) return get_from_env_fn @@ -453,17 +451,15 @@ def get_secret_from_env() -> Optional[SecretStr]: return SecretStr(os.environ[key]) if isinstance(default, str): return SecretStr(default) - elif default is None: + if default is None: return None - else: - if error_message: - raise ValueError(error_message) - else: - msg = ( - f"Did not find {key}, please add an environment variable" - f" `{key}` which contains it, or pass" - f" `{key}` as a named parameter." - ) - raise ValueError(msg) + if error_message: + raise ValueError(error_message) + msg = ( + f"Did not find {key}, please add an environment variable" + f" `{key}` which contains it, or pass" + f" `{key}` as a named parameter." + ) + raise ValueError(msg) return get_secret_from_env diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py index b154a14b98191..617912b762d82 100644 --- a/libs/core/langchain_core/vectorstores/base.py +++ b/libs/core/langchain_core/vectorstores/base.py @@ -340,20 +340,19 @@ def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]: """ if search_type == "similarity": return self.similarity_search(query, **kwargs) - elif search_type == "similarity_score_threshold": + if search_type == "similarity_score_threshold": docs_and_similarities = self.similarity_search_with_relevance_scores( query, **kwargs ) return [doc for doc, _ in docs_and_similarities] - elif search_type == "mmr": + if search_type == "mmr": return self.max_marginal_relevance_search(query, **kwargs) - else: - msg = ( - f"search_type of {search_type} not allowed. Expected " - "search_type to be 'similarity', 'similarity_score_threshold'" - " or 'mmr'." - ) - raise ValueError(msg) + msg = ( + f"search_type of {search_type} not allowed. Expected " + "search_type to be 'similarity', 'similarity_score_threshold'" + " or 'mmr'." + ) + raise ValueError(msg) async def asearch( self, query: str, search_type: str, **kwargs: Any @@ -375,19 +374,18 @@ async def asearch( """ if search_type == "similarity": return await self.asimilarity_search(query, **kwargs) - elif search_type == "similarity_score_threshold": + if search_type == "similarity_score_threshold": docs_and_similarities = await self.asimilarity_search_with_relevance_scores( query, **kwargs ) return [doc for doc, _ in docs_and_similarities] - elif search_type == "mmr": + if search_type == "mmr": return await self.amax_marginal_relevance_search(query, **kwargs) - else: - msg = ( - f"search_type of {search_type} not allowed. Expected " - "search_type to be 'similarity', 'similarity_score_threshold' or 'mmr'." - ) - raise ValueError(msg) + msg = ( + f"search_type of {search_type} not allowed. Expected " + "search_type to be 'similarity', 'similarity_score_threshold' or 'mmr'." 
+ ) + raise ValueError(msg) @abstractmethod def similarity_search( diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py index ab32c7cdacbe9..3cb6f310895d7 100644 --- a/libs/core/langchain_core/vectorstores/in_memory.py +++ b/libs/core/langchain_core/vectorstores/in_memory.py @@ -381,23 +381,21 @@ def similarity_search_with_score( **kwargs: Any, ) -> list[tuple[Document, float]]: embedding = self.embedding.embed_query(query) - docs = self.similarity_search_with_score_by_vector( + return self.similarity_search_with_score_by_vector( embedding, k, **kwargs, ) - return docs async def asimilarity_search_with_score( self, query: str, k: int = 4, **kwargs: Any ) -> list[tuple[Document, float]]: embedding = await self.embedding.aembed_query(query) - docs = self.similarity_search_with_score_by_vector( + return self.similarity_search_with_score_by_vector( embedding, k, **kwargs, ) - return docs def similarity_search_by_vector( self, diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index f9b731c730163..3e18e17f1089d 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -44,7 +44,7 @@ python = ">=3.12.4" [tool.poetry.extras] [tool.ruff.lint] -select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",] +select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RET", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",] ignore = [ "COM812", "UP007", "S110", "S112",] [tool.coverage.run] diff --git a/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py b/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py index db3a88160dbbd..63f406d1b27fb 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py +++ b/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py @@ -17,12 +17,11 @@ def selector() -> LengthBasedExampleSelector: """Get length based selector to use in tests.""" prompts = PromptTemplate(input_variables=["question"], template="{question}") - selector = LengthBasedExampleSelector( + return LengthBasedExampleSelector( examples=EXAMPLES, example_prompt=prompts, max_length=30, ) - return selector def test_selector_valid(selector: LengthBasedExampleSelector) -> None: diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index 921ff68a0fed8..ff7dc8e8bdc37 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -18,9 +18,8 @@ def _fake_runnable( ) -> Union[BaseModel, dict]: if isclass(schema) and is_basemodel_subclass(schema): return schema(name="yo", value=value) - else: - params = cast(dict, schema)["parameters"] - return {k: 1 if k != "value" else value for k, v in params.items()} + params = cast(dict, schema)["parameters"] + return {k: 1 if k != "value" else value for k, v in params.items()} class FakeStructuredChatModel(FakeListChatModel): diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 58e5749a0d557..38c65fecdd848 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ 
b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -219,8 +219,7 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: def conditional_str_parser(input: str) -> Runnable: if input == "a": return str_parser - else: - return xml_parser + return xml_parser sequence: Runnable = ( prompt diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 1bd21b9d5605a..37a5c35e460f2 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -2845,11 +2845,10 @@ async def test_higher_order_lambda_runnable( def router(input: dict[str, Any]) -> Runnable: if input["key"] == "math": return itemgetter("input") | math_chain - elif input["key"] == "english": + if input["key"] == "english": return itemgetter("input") | english_chain - else: - msg = f"Unknown key: {input['key']}" - raise ValueError(msg) + msg = f"Unknown key: {input['key']}" + raise ValueError(msg) chain: Runnable = input_map | router assert dumps(chain, pretty=True) == snapshot @@ -2901,11 +2900,10 @@ def router(input: dict[str, Any]) -> Runnable: async def arouter(input: dict[str, Any]) -> Runnable: if input["key"] == "math": return itemgetter("input") | math_chain - elif input["key"] == "english": + if input["key"] == "english": return itemgetter("input") | english_chain - else: - msg = f"Unknown key: {input['key']}" - raise ValueError(msg) + msg = f"Unknown key: {input['key']}" + raise ValueError(msg) achain: Runnable = input_map | arouter math_spy = mocker.spy(math_chain.__class__, "ainvoke") @@ -3737,8 +3735,7 @@ def test_recursive_lambda() -> None: def _simple_recursion(x: int) -> Union[int, Runnable]: if x < 10: return RunnableLambda(lambda *args: _simple_recursion(x + 1)) - else: - return x + return x runnable = RunnableLambda(_simple_recursion) assert runnable.invoke(5) == 10 @@ -3752,11 +3749,10 @@ def _lambda(x: int) -> Union[int, Runnable]: if x == 1: msg = "x is 1" raise ValueError(msg) - elif x == 2: + if x == 2: msg = "x is 2" raise RuntimeError(msg) - else: - return x + return x _lambda_mock = mocker.Mock(side_effect=_lambda) runnable = RunnableLambda(_lambda_mock) @@ -3817,11 +3813,10 @@ def _lambda(x: int) -> Union[int, Runnable]: if x == 1: msg = "x is 1" raise ValueError(msg) - elif x == 2: + if x == 2: msg = "x is 2" raise RuntimeError(msg) - else: - return x + return x _lambda_mock = mocker.Mock(side_effect=_lambda) runnable = RunnableLambda(_lambda_mock) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 59c5e765e2361..8cf9f96bc9e4b 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -545,8 +545,7 @@ async def test_astream_events_from_model() -> None: def i_dont_stream(input: Any, config: RunnableConfig) -> Any: if sys.version_info >= (3, 11): return model.invoke(input) - else: - return model.invoke(input, config) + return model.invoke(input, config) events = await _collect_events(i_dont_stream.astream_events("hello", version="v1")) _assert_events_equal_allow_superset_metadata( @@ -670,8 +669,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: if sys.version_info >= (3, 11): return await model.ainvoke(input) - else: - return await model.ainvoke(input, config) + return await 
model.ainvoke(input, config) events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1")) _assert_events_equal_allow_superset_metadata( diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 698a4c4ddab85..dff30ee125f92 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -610,8 +610,7 @@ async def test_astream_with_model_in_chain() -> None: def i_dont_stream(input: Any, config: RunnableConfig) -> Any: if sys.version_info >= (3, 11): return model.invoke(input) - else: - return model.invoke(input, config) + return model.invoke(input, config) events = await _collect_events(i_dont_stream.astream_events("hello", version="v2")) _assert_events_equal_allow_superset_metadata( @@ -719,8 +718,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: if sys.version_info >= (3, 11): return await model.ainvoke(input) - else: - return await model.ainvoke(input, config) + return await model.ainvoke(input, config) events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2")) _assert_events_equal_allow_superset_metadata( diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 3409d04f23401..b96bc936e856a 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -354,23 +354,22 @@ def parent(a: int) -> int: parent_id_map[n] = matching_post.get("parent_run_id") i += len(name) continue - else: - assert posts[i]["name"] == name - dotted_order = posts[i]["dotted_order"] - if prev_dotted_order is not None and not str( - expected_parents[name] - ).startswith("RunnableParallel"): - assert dotted_order > prev_dotted_order, ( - f"{name} not after {name_order[i - 1]}" - ) - prev_dotted_order = dotted_order - if name in dotted_order_map: - msg = f"Duplicate name {name}" - raise ValueError(msg) - dotted_order_map[name] = dotted_order - id_map[name] = posts[i]["id"] - parent_id_map[name] = posts[i].get("parent_run_id") - i += 1 + assert posts[i]["name"] == name + dotted_order = posts[i]["dotted_order"] + if prev_dotted_order is not None and not str(expected_parents[name]).startswith( + "RunnableParallel" + ): + assert dotted_order > prev_dotted_order, ( + f"{name} not after {name_order[i - 1]}" + ) + prev_dotted_order = dotted_order + if name in dotted_order_map: + msg = f"Duplicate name {name}" + raise ValueError(msg) + dotted_order_map[name] = dotted_order + id_map[name] = posts[i]["id"] + parent_id_map[name] = posts[i].get("parent_run_id") + i += 1 # Now check the dotted orders for name, parent_ in expected_parents.items(): diff --git a/libs/core/tests/unit_tests/tracers/test_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_base_tracer.py index 80d7a9297492c..aaa34a662f2c8 100644 --- a/libs/core/tests/unit_tests/tracers/test_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_base_tracer.py @@ -599,8 +599,7 @@ def test_tracer_nested_runs_on_error() -> None: def _get_mock_client() -> Client: mock_session = MagicMock() - client = Client(session=mock_session, api_key="test") - return client + return Client(session=mock_session, api_key="test") def test_traceable_to_tracing() -> None:
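The `_get_mock_client` rewrite just above is the complementary RET504 pattern (unnecessary assignment before `return`): a name bound once and returned on the next line carries no extra information, so the expression is returned directly. A minimal sketch with hypothetical names, not code from this repository:

    # Before: RET504 flags `client`, which exists only to be returned
    def make_client(api_key: str) -> dict:
        client = {"api_key": api_key, "retries": 3}
        return client

    # After: return the expression itself
    def make_client(api_key: str) -> dict:
        return {"api_key": api_key, "retries": 3}

A temporary is still worth keeping when its name documents a non-obvious intermediate value or when it is reused; RET504 only fires when the binding is returned immediately and used nowhere else.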