-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtoolkit.py
175 lines (144 loc) · 6.4 KB
/
toolkit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from typing import List, Optional, Type, Sequence, Dict, Any, Union, Tuple
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field, BaseModel
from langchain_core.tools import BaseToolkit
from langchain_community.tools import BaseTool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.callbacks import (
CallbackManagerForToolRun,
)
from snowflake.connector import SnowflakeConnection
from sqlalchemy.engine import Result
class _InfoSQLDatabaseToolInput(BaseModel):
    # Argument schema for the table-schema tool (wired in through
    # ``args_schema`` on InfoSnowflakeTableTool).

    # Required field (the ``...`` default marks it as mandatory). The tool
    # receives all requested tables as one comma-separated string.
    table_names: str = Field(
        default=...,
        description=(
            "A comma-separated list of the table names for which to return "
            "the schema. "
            "Example input: 'table1, table2, table3'"
        ),
    )
class InfoSnowflakeTableTool(BaseTool):
    """Tool that reports the schema of one or more tables.

    For every requested table a ``DESCRIBE TABLE`` statement is executed
    through ``db.run`` and the results are concatenated into one string.
    """

    name: str = "sql_db_schema"
    # NOTE(review): the description mentions sample rows, but ``_run`` only
    # issues DESCRIBE TABLE; confirm whether sample rows should be added or
    # the wording adjusted.
    description: str = "Get the schema and sample rows for the specified SQL tables."
    args_schema: Type[BaseModel] = _InfoSQLDatabaseToolInput
    # Database handle used to run the DESCRIBE statements; excluded from
    # serialization.
    db: SQLDatabase = Field(exclude=True)

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list.

        Args:
            table_names: Comma-separated table names, e.g.
                ``"table1, table2, table3"``. Whitespace around each name is
                ignored and empty entries are skipped.
            run_manager: Optional callback manager; not used here.

        Returns:
            One ``"Schema for table <name>: <schema>"`` line per table, each
            terminated by a newline. An empty string if no table name was
            supplied.
        """
        # The documented input format puts a space after each comma, so strip
        # every entry; drop empty ones (trailing comma, empty input) instead
        # of sending a malformed "DESCRIBE TABLE " statement.
        names = [name.strip() for name in table_names.split(",")]

        sections = []
        for name in names:
            if not name:
                continue
            # NOTE(review): the table name is interpolated directly into the
            # SQL text; it comes from the agent, so treat it as untrusted.
            schema = self.db.run(f"DESCRIBE TABLE {name}")
            sections.append(f"Schema for table {name}: {schema}\n")
        return "".join(sections)
class _QuerySQLCheckerToolInput(BaseModel):
    # Argument schema for the query-checker tool (wired in through
    # ``args_schema`` on QuerySQLCheckerTool).

    # Required field. The description is shown to the model that fills in the
    # tool arguments; the previous wording ("A detailed and SQL query ...")
    # was ungrammatical.
    query: str = Field(..., description="A detailed SQL query to be checked.")
class QuerySQLCheckerTool(BaseTool):
    """Uses the Snowflake Arctic model to check if a query is correct.

    The candidate query is embedded in a review prompt, and that prompt is
    sent to ``SNOWFLAKE.CORTEX.COMPLETE`` by running a SELECT through
    ``db.run``.
    """

    # Prompt sent to the model; ``{query}`` and ``{dialect}`` are filled in
    # by ``_run``.
    template: str = """
{query}
Double check the {dialect} query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
Output the final SQL query only.
SQL Query: """
    name: str = "sql_db_query_checker"
    description: str = """
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with sql_db_query!
"""
    args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
    # Database handle used to call the Cortex function; excluded from
    # serialization.
    db: SQLDatabase = Field(exclude=True)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query.

        Args:
            query: The SQL text to be reviewed.
            run_manager: Optional callback manager; not used here.

        Returns:
            Whatever ``db.run`` returns for the
            ``SNOWFLAKE.CORTEX.COMPLETE`` call.
        """
        prompt = self.template.format(query=query, dialect=self.db.dialect)
        # The prompt is embedded in a single-quoted SQL string constant, where
        # backslash is the escape character. Escape backslashes FIRST so that
        # an existing "\'" (or a trailing "\") in the query cannot neutralise
        # the quote escaping and terminate the literal early. Double quotes
        # need no escaping inside a single-quoted constant. Escaping the whole
        # prompt (not just the query) also covers the dialect value.
        escaped_prompt = prompt.replace("\\", "\\\\").replace("'", "\\'")
        statement = (
            "SELECT SNOWFLAKE.CORTEX.COMPLETE('snowflake-arctic', "
            f"'{escaped_prompt}');"
        )
        return self.db.run(statement)
class _QuerySQLDataBaseToolInput(BaseModel):
    # Argument schema for the query-execution tool (wired in through
    # ``args_schema`` on QuerySQLDataBaseTool).

    # Required field holding the SQL text to execute.
    query: str = Field(
        default=...,
        description="A detailed and correct SQL query.",
    )
class QuerySQLDataBaseTool(BaseTool):
    """Tool for querying a SQL database.

    Executes the query on a cursor obtained from the Snowflake connection and
    returns the fetched rows together with the query id.
    """

    name: str = "sql_db_query"
    description: str = """
Execute a SQL query against the database and get back the result and query_id.
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
"""
    args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput
    # Connection the query is executed on; excluded from serialization.
    con: SnowflakeConnection = Field(exclude=True)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Union[str, Sequence[Dict[str, Any]], Result], Optional[str]]:
        """Execute the query, return the results and query_id; or an error message.

        Args:
            query: SQL text to execute.
            run_manager: Optional callback manager; not used here.

        Returns:
            ``(rows, query_id)`` on success, where ``rows`` is the result of
            ``fetchall()`` and ``query_id`` is the cursor's ``sfqid``.
            ``("Error: <message>", None)`` if anything raised.
        """
        try:
            cursor = self.con.cursor()
            try:
                results = cursor.execute(query).fetchall()
                # Read the query id before the cursor is closed.
                return results, cursor.sfqid
            finally:
                # Always release the cursor, including when execute() or
                # fetchall() raises; previously it was leaked on failure.
                cursor.close()
        except Exception as e:
            # Deliberately broad: the error text is handed back to the agent
            # so it can rewrite the query instead of aborting the run.
            return f"Error: {e}", None
class AgentToolkit(BaseToolkit):
    """Toolkit bundling the query, schema and query-checker tools.

    ``db`` is handed to the schema and checker tools and ``con`` to the
    query-execution tool. ``llm`` is declared as a field but is not referenced
    by ``get_tools``.
    """

    db: SQLDatabase = Field(exclude=True)
    llm: BaseLanguageModel = Field(exclude=True)
    con: SnowflakeConnection = Field(exclude=True)

    class Config:
        """Pydantic settings: allow field types pydantic cannot validate."""

        arbitrary_types_allowed = True

    @property
    def dialect(self) -> str:
        """SQL dialect reported by the underlying ``db``."""
        return self.db.dialect

    def get_tools(self) -> List[BaseTool]:
        """Build the toolkit's tools.

        Returns:
            The tools in the order: query execution, table schema, query
            checker. Each tool's description is overridden here; the query
            and checker descriptions refer to another tool by its ``name``.
        """
        # Built first: the query tool's description refers to this tool's name.
        schema_tool = InfoSnowflakeTableTool(
            db=self.db,
            description=(
                "Input to this tool is a comma-separated list of tables, output is the "
                "schema and sample rows for those tables. "
                "Example Input: table1, table2, table3"
            ),
        )

        run_tool = QuerySQLDataBaseTool(
            con=self.con,
            description=(
                "Input to this tool is a detailed and correct SQL query, output is a "
                "result and query_id from the database. If the query is not correct, an error message "
                "will be returned. If an error is returned, rewrite the query, check the "
                "query, and try again. If you encounter an issue with Unknown column "
                f"'xxxx' in 'field list', use {schema_tool.name} "
                "to query the correct table fields."
            ),
        )

        # The checker's description points the agent at the query tool's name.
        checker_tool = QuerySQLCheckerTool(
            db=self.db,
            description=(
                "Use this tool to double check if your query is correct before executing "
                "it. Always use this tool before executing a query with "
                f"{run_tool.name}!"
            ),
        )

        return [run_tool, schema_tool, checker_tool]