(Community): Adding Structured Support for ChatPerplexity #29361

Open
wants to merge 14 commits into master
116 changes: 112 additions & 4 deletions libs/community/langchain_community/chat_models/perplexity.py
@@ -3,19 +3,23 @@
from __future__ import annotations

import logging
from operator import itemgetter
from typing import (
    Any,
    Dict,
    Iterator,
    List,
    Literal,
    Mapping,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
@@ -34,17 +38,78 @@
    SystemMessageChunk,
    ToolMessageChunk,
)
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
from langchain_core.utils import from_env, get_pydantic_field_names
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import (
    PydanticBaseModel,
    TypeBaseModel,
    is_basemodel_subclass,
)
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self

_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]

logger = logging.getLogger(__name__)


def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and is_basemodel_subclass(obj)


def _convert_to_openai_response_format(
    schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
) -> Union[Dict, TypeBaseModel]:
    if isinstance(schema, type) and is_basemodel_subclass(schema):
        return schema

    if (
        isinstance(schema, dict)
        and "json_schema" in schema
        and schema.get("type") == "json_schema"
    ):
        response_format = schema
    elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
        response_format = {"type": "json_schema", "json_schema": schema}
    else:
        if strict is None:
            if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):
                strict = schema["strict"]
            else:
                strict = False
        function = convert_to_openai_function(schema, strict=strict)
        function["schema"] = function.pop("parameters")
        response_format = {"type": "json_schema", "json_schema": function}

    if strict is not None and strict is not response_format["json_schema"].get(
        "strict"
    ):
        msg = (
            f"Output schema already has 'strict' value set to "
            f"{schema['json_schema']['strict']} but 'strict' also passed in to "
            f"with_structured_output as {strict}. Please make sure that "
            f"'strict' is only specified in one place."
        )
        raise ValueError(msg)
    return response_format


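For illustration, a minimal sketch of how this helper behaves; the Citation model and the dict schema below are hypothetical examples, not part of the PR.

from pydantic import BaseModel

class Citation(BaseModel):
    # Hypothetical schema used only for illustration.
    title: str
    url: str

# A Pydantic class is passed through unchanged.
assert _convert_to_openai_response_format(Citation) is Citation

# A plain {"name": ..., "schema": ...} dict is wrapped in the json_schema envelope.
fmt = _convert_to_openai_response_format(
    {
        "name": "citation",
        "schema": {"type": "object", "properties": {"title": {"type": "string"}}},
    }
)
# fmt == {"type": "json_schema", "json_schema": {"name": "citation", "schema": {...}}}
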
@chain
def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
    if ai_msg.additional_kwargs.get("parsed"):
        return ai_msg.additional_kwargs["parsed"]
    else:
        raise ValueError(
            "Structured Output response does not have a 'parsed' field. "
            f"Received message:\n\n{ai_msg}"
        )

Collaborator (on the "parsed" lookup above): Is a BaseModel instance getting populated under "parsed" in .additional_kwargs?
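
In relation to the reviewer's question, this is the contract the parser assumes, sketched with a hypothetical payload; whether ChatPerplexity actually populates additional_kwargs["parsed"] (and with a BaseModel instance) is exactly what is being asked.

from langchain_core.messages import AIMessage

# Hypothetical message shaped the way the parser expects it: something upstream
# must already have attached the parsed object under additional_kwargs["parsed"].
msg = AIMessage(
    content='{"title": "Example", "url": "https://example.com"}',
    additional_kwargs={"parsed": {"title": "Example", "url": "https://example.com"}},
)

# The @chain decorator makes the parser a Runnable, so it is called via .invoke():
_oai_structured_outputs_parser.invoke(msg)  # returns the object stored under "parsed"
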
class ChatPerplexity(BaseChatModel):
"""`Perplexity AI` Chat models API.

@@ -282,3 +347,46 @@ def _invocation_params(self) -> Mapping[str, Any]:
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "perplexitychat"

    def with_structured_output(
        self,
        schema: Optional[_DictOrPydanticClass] = None,
        *,
        method: Literal["json_schema"] = "json_schema",
        include_raw: bool = False,
        strict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: The output schema, given as a Pydantic class, an OpenAI
                function/tool schema, or a JSON Schema dict.
            method: The method for steering model generation. Currently only
                "json_schema" is supported.
            include_raw: If True, return a dict with keys "raw" (the AIMessage),
                "parsed", and "parsing_error" instead of only the parsed output.
            strict: Whether to enforce the schema strictly when converting a
                non-Pydantic schema to an OpenAI-style response format.

        Returns:
            A Runnable that takes the same inputs as the chat model and returns
            outputs formatted to match the given schema.
        """
        if method == "json_schema":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'json_schema'. "
                    "Received None."
                )
            is_pydantic_schema = _is_pydantic_class(schema)
            if is_pydantic_schema:
                response_format = schema.model_json_schema()  # type: ignore[union-attr]
            else:
                response_format = _convert_to_openai_response_format(
                    schema, strict=strict
                )
            llm = self.bind(response_format=response_format)
            output_parser = JsonOutputParser()
        else:
            raise ValueError(
                "Unrecognized method argument. Expected 'json_schema'. "
                f"Received: '{method}'."
            )

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser
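
A usage sketch under stated assumptions: AnswerWithSources is an illustrative schema, the model name is a placeholder, and ChatPerplexity is expected to pick up a Perplexity API key from the environment (or via pplx_api_key).

from typing import List

from pydantic import BaseModel

from langchain_community.chat_models import ChatPerplexity


class AnswerWithSources(BaseModel):
    # Illustrative schema; the field names are not taken from the PR.
    answer: str
    sources: List[str]


# Placeholder model name; requires a Perplexity API key to be configured.
llm = ChatPerplexity(model="sonar", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithSources)

result = structured_llm.invoke("Who proposed the transformer architecture?")
# With this change, the schema is bound via response_format and the reply is
# parsed by JsonOutputParser, so `result` is a dict matching the schema's fields
# rather than an AnswerWithSources instance.

With include_raw=True the returned Runnable would instead yield a dict with "raw", "parsed", and "parsing_error" keys, falling back to parsed=None if JSON parsing fails.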