Skip to content

Commit

Permalink
update testing
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Oct 16, 2024
1 parent 87029c2 commit c17d025
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 33 deletions.
2 changes: 1 addition & 1 deletion backend/retrieval_graph/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class AgentConfiguration(BaseConfiguration):
)

response_model: str = field(
default="anthropic/claude-3-5-sonnet-20240620",
default="openai/gpt-4o-mini",
metadata={
"description": "The language model used for generating responses. Should be in the form: provider/model-name."
},
Expand Down
53 changes: 33 additions & 20 deletions backend/retrieval_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
conducting research, and formulating responses.
"""

from typing import Any, Literal, TypedDict, cast
from typing import Any, Literal, Type, TypedDict, cast

from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
Expand All @@ -33,6 +33,10 @@ async def analyze_and_route_query(
Returns:
dict[str, Router]: A dictionary containing the 'router' key with the classification result (classification type and logic).
"""
# allow skipping the router for testing
if state.router and state.router["logic"]:
return {"router": state.router}

configuration = AgentConfiguration.from_runnable_config(config)
model = load_chat_model(configuration.query_model)
messages = [
Expand Down Expand Up @@ -207,22 +211,31 @@ async def respond(


# Define the graph
# The state schema is AgentState; callers supply InputState as input, and
# AgentConfiguration is registered as the config schema.
builder = StateGraph(AgentState, input=InputState, config_schema=AgentConfiguration)
# Register each step function as a node (the node name is taken from the
# function, as the string names used in the edges below show).
builder.add_node(analyze_and_route_query)
builder.add_node(ask_for_more_info)
builder.add_node(respond_to_general_query)
builder.add_node(conduct_research)
builder.add_node(create_research_plan)
builder.add_node(respond)

# Every run starts at the router; route_query picks the next node.
builder.add_edge(START, "analyze_and_route_query")
builder.add_conditional_edges("analyze_and_route_query", route_query)
# Planning always leads into research; check_finished decides what follows
# each research step.
builder.add_edge("create_research_plan", "conduct_research")
builder.add_conditional_edges("conduct_research", check_finished)
# These three nodes are terminal.
builder.add_edge("ask_for_more_info", END)
builder.add_edge("respond_to_general_query", END)
builder.add_edge("respond", END)

# Compile into a graph object that you can invoke and deploy.
graph = builder.compile()
graph.name = "RetrievalGraph"


def make_graph(*, input_schema: Type[Any]):
    """Build and compile the retrieval graph for a given input schema.

    Args:
        input_schema: Schema class handed to ``StateGraph`` as its ``input``.

    Returns:
        The compiled graph, with its ``name`` set to ``"RetrievalGraph"``.
    """
    workflow = StateGraph(
        AgentState, input=input_schema, config_schema=AgentConfiguration
    )

    # Register the step functions as nodes, in a fixed order.
    for step in (
        analyze_and_route_query,
        ask_for_more_info,
        respond_to_general_query,
        conduct_research,
        create_research_plan,
        respond,
    ):
        workflow.add_node(step)

    # Entry point and routing.
    workflow.add_edge(START, "analyze_and_route_query")
    workflow.add_conditional_edges("analyze_and_route_query", route_query)
    workflow.add_edge("create_research_plan", "conduct_research")
    workflow.add_conditional_edges("conduct_research", check_finished)

    # Nodes that finish the run.
    for terminal in ("ask_for_more_info", "respond_to_general_query", "respond"):
        workflow.add_edge(terminal, END)

    # Compile into a graph object that you can invoke and deploy.
    compiled = workflow.compile()
    compiled.name = "RetrievalGraph"
    return compiled


# Module-level graph instance, built with InputState as the input schema.
graph = make_graph(input_schema=InputState)
34 changes: 22 additions & 12 deletions backend/tests/evals/test_e2e.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any

import pandas as pd
Expand All @@ -6,10 +7,11 @@
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langsmith.evaluation import EvaluationResults, evaluate
from langsmith.evaluation import EvaluationResults, aevaluate
from langsmith.schemas import Example, Run

from backend.retrieval_graph.graph import graph
from backend.retrieval_graph.graph import make_graph
from backend.retrieval_graph.state import AgentState, Router
from backend.utils import format_docs

DATASET_NAME = "chat-langchain-qa"
Expand Down Expand Up @@ -141,10 +143,16 @@ def evaluate_qa_context(run: Run, example: Example) -> dict:

# Run evaluation

# TODO: this is a hack to allow for skipping the router for testing. Add testing for individual components.
# The graph is built with AgentState as the input schema (the application
# module builds its graph with InputState) — presumably so that run_graph can
# supply a pre-filled "router" entry in the input; confirm against the state
# definitions.
graph = make_graph(input_schema=AgentState)

def run_graph(inputs: dict[str, Any]) -> dict[str, Any]:
results = graph.invoke(
{"messages": [("human", inputs["question"])]},

async def run_graph(inputs: dict[str, Any]) -> dict[str, Any]:
    """Invoke the graph asynchronously for a single dataset example.

    Args:
        inputs: Example inputs; the ``"question"`` entry becomes the human
            message sent to the graph.

    Returns:
        The value produced by ``graph.ainvoke`` for this request.
    """
    human_turn = ("human", inputs["question"])
    # A router value with non-empty logic is supplied up front so the
    # routing step can be bypassed during evaluation.
    preset_router = Router(type="langchain", logic="The question is about LangChain")
    return await graph.ainvoke(
        {"messages": [human_turn], "router": preset_router}
    )

Expand All @@ -162,13 +170,15 @@ def convert_single_example_results(evaluation_results: EvaluationResults):
# NOTE: this is more of a regression test
def test_scores_regression():
# test most commonly used model
experiment_results = evaluate(
lambda inputs: run_graph(inputs),
data=DATASET_NAME,
evaluators=[evaluate_retrieval_recall, evaluate_qa, evaluate_qa_context],
experiment_prefix=EXPERIMENT_PREFIX,
metadata={"judge_model_name": JUDGE_MODEL_NAME},
max_concurrency=4,
experiment_results = asyncio.run(
aevaluate(
run_graph,
data=DATASET_NAME,
evaluators=[evaluate_retrieval_recall, evaluate_qa, evaluate_qa_context],
experiment_prefix=EXPERIMENT_PREFIX,
metadata={"judge_model_name": JUDGE_MODEL_NAME},
max_concurrency=4,
)
)
experiment_result_df = pd.DataFrame(
convert_single_example_results(result["evaluation_results"])
Expand Down

0 comments on commit c17d025

Please sign in to comment.