core: Add ruff rules RET
cbornet committed Jan 25, 2025
1 parent dbb6b7b commit 07ecfd0
Showing 69 changed files with 603 additions and 784 deletions.
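
This commit enables ruff's RET rules (the flake8-return set) for langchain_core. Two rules account for most of the changes below: RET505, which flags a superfluous `else` after a branch that returns, and RET504, which flags a variable that is assigned and then immediately returned. The sketch below is illustrative only, not code from this diff; `stringify_before` and `stringify_after` are hypothetical helpers.

from typing import Union

# Before: both RET patterns that this commit cleans up.
def stringify_before(value: Union[str, int]) -> str:
    if isinstance(value, str):
        return value
    else:  # RET505: `else` is redundant after a returning branch
        result = str(value)  # RET504: assigned, then immediately returned
        return result

# After: the style applied throughout this commit.
def stringify_after(value: Union[str, int]) -> str:
    if isinstance(value, str):
        return value
    return str(value)

Enabling the rules is typically a one-line configuration change, e.g. adding "RET" to the `select` list under `[tool.ruff.lint]` in pyproject.toml.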
3 changes: 1 addition & 2 deletions libs/core/langchain_core/_api/deprecation.py
@@ -443,8 +443,7 @@ def warn_deprecated(
                 f"{removal}"
             )
             raise NotImplementedError(msg)
-        else:
-            removal = f"in {removal}"
+        removal = f"in {removal}"
 
     if not message:
         message = ""
18 changes: 8 additions & 10 deletions libs/core/langchain_core/agents.py
@@ -172,8 +172,7 @@ def _convert_agent_action_to_messages(
     """
     if isinstance(agent_action, AgentActionMessageLog):
         return agent_action.message_log
-    else:
-        return [AIMessage(content=agent_action.log)]
+    return [AIMessage(content=agent_action.log)]
 
 
 def _convert_agent_observation_to_messages(
@@ -192,14 +191,13 @@ def _convert_agent_observation_to_messages(
     """
     if isinstance(agent_action, AgentActionMessageLog):
         return [_create_function_message(agent_action, observation)]
-    else:
-        content = observation
-        if not isinstance(observation, str):
-            try:
-                content = json.dumps(observation, ensure_ascii=False)
-            except Exception:
-                content = str(observation)
-        return [HumanMessage(content=content)]
+    content = observation
+    if not isinstance(observation, str):
+        try:
+            content = json.dumps(observation, ensure_ascii=False)
+        except Exception:
+            content = str(observation)
+    return [HumanMessage(content=content)]
 
 
 def _create_function_message(
16 changes: 6 additions & 10 deletions libs/core/langchain_core/beta/runnables/context.py
@@ -56,11 +56,10 @@ def _key_from_id(id_: str) -> str:
     wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
     if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
         return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
-    elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
+    if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
         return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
-    else:
-        msg = f"Invalid context config id {id_}"
-        raise ValueError(msg)
+    msg = f"Invalid context config id {id_}"
+    raise ValueError(msg)
 
 
 def _config_with_context(
@@ -190,8 +189,7 @@ def invoke(
         configurable = config.get("configurable", {})
         if isinstance(self.key, list):
             return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
-        else:
-            return configurable[self.ids[0]]()
+        return configurable[self.ids[0]]()
 
     async def ainvoke(
         self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -201,8 +199,7 @@ async def ainvoke(
         if isinstance(self.key, list):
             values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
             return dict(zip(self.key, values))
-        else:
-            return await configurable[self.ids[0]]()
+        return await configurable[self.ids[0]]()
 
 
 SetValue = Union[
@@ -397,5 +394,4 @@ def setter(
 def _print_keys(keys: Union[str, Sequence[str]]) -> str:
     if isinstance(keys, str):
         return f"'{keys}'"
-    else:
-        return ", ".join(f"'{k}'" for k in keys)
+    return ", ".join(f"'{k}'" for k in keys)
9 changes: 4 additions & 5 deletions libs/core/langchain_core/document_loaders/langsmith.py
@@ -122,8 +122,7 @@ def lazy_load(self) -> Iterator[Document]:
 def _stringify(x: Union[str, dict]) -> str:
     if isinstance(x, str):
         return x
-    else:
-        try:
-            return json.dumps(x, indent=2)
-        except Exception:
-            return str(x)
+    try:
+        return json.dumps(x, indent=2)
+    except Exception:
+        return str(x)
10 changes: 4 additions & 6 deletions libs/core/langchain_core/documents/base.py
@@ -45,8 +45,7 @@ class BaseMedia(Serializable):
     def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
         if id_value is not None:
             return str(id_value)
-        else:
-            return id_value
+        return id_value
 
 
 class Blob(BaseMedia):
@@ -163,9 +162,9 @@ def as_bytes(self) -> bytes:
         """Read data as bytes."""
         if isinstance(self.data, bytes):
             return self.data
-        elif isinstance(self.data, str):
+        if isinstance(self.data, str):
             return self.data.encode(self.encoding)
-        elif self.data is None and self.path:
+        if self.data is None and self.path:
             with open(str(self.path), "rb") as f:
                 return f.read()
         else:
@@ -306,5 +305,4 @@ def __str__(self) -> str:
         # a more general solution of formatting content directly inside the prompts.
         if self.metadata:
             return f"page_content='{self.page_content}' metadata={self.metadata}"
-        else:
-            return f"page_content='{self.page_content}'"
+        return f"page_content='{self.page_content}'"
5 changes: 2 additions & 3 deletions libs/core/langchain_core/example_selectors/length_based.py
@@ -79,9 +79,8 @@ def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
             new_length = remaining_length - self.example_text_lengths[i]
             if new_length < 0:
                 break
-            else:
-                examples.append(self.examples[i])
-                remaining_length = new_length
+            examples.append(self.examples[i])
+            remaining_length = new_length
             i += 1
         return examples
 
3 changes: 1 addition & 2 deletions libs/core/langchain_core/example_selectors/semantic_similarity.py
@@ -54,8 +54,7 @@ def _example_to_text(
     ) -> str:
         if input_keys:
             return " ".join(sorted_values({key: example[key] for key in input_keys}))
-        else:
-            return " ".join(sorted_values(example))
+        return " ".join(sorted_values(example))
 
     def _documents_to_examples(self, documents: list[Document]) -> list[dict]:
         # Get the examples from the metadata.
15 changes: 7 additions & 8 deletions libs/core/langchain_core/indexing/api.py
@@ -152,16 +152,15 @@ def _get_source_id_assigner(
     """Get the source id from the document."""
     if source_id_key is None:
         return lambda doc: None
-    elif isinstance(source_id_key, str):
+    if isinstance(source_id_key, str):
         return lambda doc: doc.metadata[source_id_key]
-    elif callable(source_id_key):
+    if callable(source_id_key):
         return source_id_key
-    else:
-        msg = (
-            f"source_id_key should be either None, a string or a callable. "
-            f"Got {source_id_key} of type {type(source_id_key)}."
-        )
-        raise ValueError(msg)
+    msg = (
+        f"source_id_key should be either None, a string or a callable. "
+        f"Got {source_id_key} of type {type(source_id_key)}."
+    )
+    raise ValueError(msg)
 
 
 def _deduplicate_in_order(
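
The `_key_from_id` and `_get_source_id_assigner` hunks above show the `elif` variant of RET505: when every branch of a chain returns, each `elif` flattens to a plain `if`, and a trailing `else` that only raises becomes a fall-through raise. A reduced sketch, using a hypothetical `classify` helper rather than code from this diff:

# Before: RET505 flags the elif and else that follow returning branches.
def classify_before(x: object) -> str:
    if x is None:
        return "none"
    elif isinstance(x, str):
        return "string"
    else:
        raise TypeError(f"unsupported type {type(x)}")

# After: successive guard clauses, with the raise as the fall-through.
def classify_after(x: object) -> str:
    if x is None:
        return "none"
    if isinstance(x, str):
        return "string"
    raise TypeError(f"unsupported type {type(x)}")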
6 changes: 2 additions & 4 deletions libs/core/langchain_core/language_models/base.py
@@ -141,8 +141,7 @@ def set_verbose(cls, verbose: Optional[bool]) -> bool:
         """
         if verbose is None:
             return _get_verbosity()
-        else:
-            return verbose
+        return verbose
 
     @property
     @override
@@ -349,8 +348,7 @@ def get_token_ids(self, text: str) -> list[int]:
         """
         if self.custom_get_token_ids is not None:
             return self.custom_get_token_ids(text)
-        else:
-            return _get_token_ids_default_method(text)
+        return _get_token_ids_default_method(text)
 
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text.
45 changes: 19 additions & 26 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -259,16 +259,15 @@ def OutputType(self) -> Any:
     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
             return input
-        elif isinstance(input, str):
+        if isinstance(input, str):
             return StringPromptValue(text=input)
-        elif isinstance(input, Sequence):
+        if isinstance(input, Sequence):
             return ChatPromptValue(messages=convert_to_messages(input))
-        else:
-            msg = (
-                f"Invalid input type {type(input)}. "
-                "Must be a PromptValue, str, or list of BaseMessages."
-            )
-            raise ValueError(msg)  # noqa: TRY004
+        msg = (
+            f"Invalid input type {type(input)}. "
+            "Must be a PromptValue, str, or list of BaseMessages."
+        )
+        raise ValueError(msg)  # noqa: TRY004
 
     def invoke(
         self,
@@ -565,10 +564,9 @@ def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
             _cleanup_llm_representation(serialized_repr, 1)
             llm_string = json.dumps(serialized_repr, sort_keys=True)
             return llm_string + "---" + param_string
-        else:
-            params = self._get_invocation_params(stop=stop, **kwargs)
-            params = {**params, **kwargs}
-            return str(sorted(params.items()))
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        params = {**params, **kwargs}
+        return str(sorted(params.items()))
 
     def generate(
         self,
@@ -1024,9 +1022,8 @@ def __call__(
         ).generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
-        else:
-            msg = "Unexpected generation type"
-            raise ValueError(msg)  # noqa: TRY004
+        msg = "Unexpected generation type"
+        raise ValueError(msg)  # noqa: TRY004
 
     async def _call_async(
         self,
@@ -1041,9 +1038,8 @@ async def _call_async(
         generation = result.generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
-        else:
-            msg = "Unexpected generation type"
-            raise ValueError(msg)  # noqa: TRY004
+        msg = "Unexpected generation type"
+        raise ValueError(msg)  # noqa: TRY004
 
     @deprecated("0.1.7", alternative="invoke", removal="1.0")
     def call_as_llm(
@@ -1059,9 +1055,8 @@ def predict(
         result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
         if isinstance(result.content, str):
             return result.content
-        else:
-            msg = "Cannot use predict when output is not a string."
-            raise ValueError(msg)  # noqa: TRY004
+        msg = "Cannot use predict when output is not a string."
+        raise ValueError(msg)  # noqa: TRY004
 
     @deprecated("0.1.7", alternative="invoke", removal="1.0")
     def predict_messages(
@@ -1084,9 +1079,8 @@ async def apredict(
         )
         if isinstance(result.content, str):
             return result.content
-        else:
-            msg = "Cannot use predict when output is not a string."
-            raise ValueError(msg)  # noqa: TRY004
+        msg = "Cannot use predict when output is not a string."
+        raise ValueError(msg)  # noqa: TRY004
 
     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
     async def apredict_messages(
@@ -1259,8 +1253,7 @@ class AnswerWithJustification(BaseModel):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser
+        return llm | output_parser
 
 
 class SimpleChatModel(BaseChatModel):
33 changes: 13 additions & 20 deletions libs/core/langchain_core/language_models/llms.py
@@ -249,8 +249,7 @@ def update_cache(
         prompt = prompts[missing_prompt_idxs[i]]
         if llm_cache is not None:
             llm_cache.update(prompt, llm_string, result)
-    llm_output = new_results.llm_output
-    return llm_output
+    return new_results.llm_output
 
 
 async def aupdate_cache(
@@ -283,8 +282,7 @@ async def aupdate_cache(
         prompt = prompts[missing_prompt_idxs[i]]
         if llm_cache:
             await llm_cache.aupdate(prompt, llm_string, result)
-    llm_output = new_results.llm_output
-    return llm_output
+    return new_results.llm_output
 
 
 class BaseLLM(BaseLanguageModel[str], ABC):
@@ -328,16 +326,15 @@ def OutputType(self) -> type[str]:
     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
             return input
-        elif isinstance(input, str):
+        if isinstance(input, str):
             return StringPromptValue(text=input)
-        elif isinstance(input, Sequence):
+        if isinstance(input, Sequence):
             return ChatPromptValue(messages=convert_to_messages(input))
-        else:
-            msg = (
-                f"Invalid input type {type(input)}. "
-                "Must be a PromptValue, str, or list of BaseMessages."
-            )
-            raise ValueError(msg)  # noqa: TRY004
+        msg = (
+            f"Invalid input type {type(input)}. "
+            "Must be a PromptValue, str, or list of BaseMessages."
+        )
+        raise ValueError(msg)  # noqa: TRY004
 
     def _get_ls_params(
         self,
@@ -447,8 +444,7 @@ def batch(
             except Exception as e:
                 if return_exceptions:
                     return cast(list[str], [e for _ in inputs])
-                else:
-                    raise
+                raise
         else:
             batches = [
                 inputs[i : i + max_concurrency]
@@ -493,8 +489,7 @@ async def abatch(
             except Exception as e:
                 if return_exceptions:
                     return cast(list[str], [e for _ in inputs])
-                else:
-                    raise
+                raise
         else:
            batches = [
                 inputs[i : i + max_concurrency]
@@ -960,10 +955,9 @@ def generate(
                     callback_managers, prompts, run_name_list, run_ids_list
                 )
             ]
-            output = self._generate_helper(
+            return self._generate_helper(
                 prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
-            return output
         if len(missing_prompts) > 0:
             run_managers = [
                 callback_managers[idx].on_llm_start(
@@ -1208,14 +1202,13 @@ async def agenerate(
                 ]
             )
             run_managers = [r[0] for r in run_managers]  # type: ignore[misc]
-            output = await self._agenerate_helper(
+            return await self._agenerate_helper(
                 prompts,
                 stop,
                 run_managers,  # type: ignore[arg-type]
                 bool(new_arg_supported),
                 **kwargs,  # type: ignore[arg-type]
             )
-            return output
        if len(missing_prompts) > 0:
            run_managers = await asyncio.gather(
                *[
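
Besides dropping `else` branches, the llms.py hunks above also apply RET504 in `update_cache`, `aupdate_cache`, `generate`, and `agenerate`: a temporary that exists only to be returned is folded into the `return` itself. In miniature, with a hypothetical `total_tokens` helper:

# RET504 before: the intermediate name adds no information.
def total_tokens_before(counts: list[int]) -> int:
    total = sum(counts)  # assigned...
    return total  # ...and immediately returned

# RET504 after: behavior is identical.
def total_tokens_after(counts: list[int]) -> int:
    return sum(counts)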
58 of the 69 changed files are not shown.
