Skip to content

Commit

Permalink
Cognee integration update (#17482)
Browse files Browse the repository at this point in the history
  • Loading branch information
dexters1 authored Jan 11, 2025
1 parent ea3daed commit 76a7539
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def example_graph_rag_cognee():
graph_db_provider="networkx",
vector_db_provider="lancedb",
relational_db_provider="sqlite",
db_name="cognee_db",
relational_db_name="cognee_db",
)

# Add data to cognee
Expand All @@ -50,12 +50,22 @@ async def example_graph_rag_cognee():
await cogneeRAG.process_data("test")

# Answer prompt based on knowledge graph
search_results = await cogneeRAG.search("person")
print("\n\nExtracted sentences are:\n")
search_results = await cogneeRAG.search(
"Tell me who are the people mentioned?"
)
print("\n\nAnswer based on knowledge graph:\n")
for result in search_results:
print(f"{result}\n")

# Answer prompt based on RAG
search_results = await cogneeRAG.rag_search(
"Tell me who are the people mentioned?"
)
print("\n\nAnswer based on RAG:\n")
for result in search_results:
print(f"{result}\n")

# Search for related nodes
# Search for related nodes in graph
search_results = await cogneeRAG.get_related_nodes("person")
print("\n\nRelated nodes are:\n")
for result in search_results:
Expand All @@ -65,3 +75,11 @@ async def example_graph_rag_cognee():
if __name__ == "__main__":
asyncio.run(example_graph_rag_cognee())
```

## Supported databases

**Relational databases:** SQLite, PostgreSQL

**Vector databases:** LanceDB, PGVector, Qdrant, Weaviate

**Graph databases:** Neo4j, NetworkX
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,25 @@ class CogneeGraphRAG(GraphRAG):
This enables the system to retrieve more precise and structured information about an entity, its relationships, and its properties.
Attributes:
llm_api_key: str: Api key for desired llm.
llm_provider: str: Provider for desired llm.
llm_model: str: Model for desired llm.
graph_db_provider: str: The graph database provider.
vector_db_provider: str: The vector database provider.
relational_db_provider: str: The relational database provider.
db_name: str: The name of the databases.
llm_api_key: str: API key for desired LLM.
llm_provider: str: Provider for desired LLM (default: "openai").
llm_model: str: Model for desired LLM (default: "gpt-4o-mini").
graph_db_provider: str: The graph database provider (default: "networkx").
Supported providers: "neo4j", "networkx".
graph_database_url: str: URL for the graph database.
graph_database_username: str: Username for accessing the graph database.
graph_database_password: str: Password for accessing the graph database.
vector_db_provider: str: The vector database provider (default: "lancedb").
Supported providers: "lancedb", "pgvector", "qdrant", "weaviate".
vector_db_url: str: URL for the vector database.
vector_db_key: str: API key for accessing the vector database.
relational_db_provider: str: The relational database provider (default: "sqlite").
Supported providers: "sqlite", "postgres".
relational_db_name: str: The name of the relational database (default: "cognee_db").
relational_db_host: str: Host for the relational database.
relational_db_port: str: Port for the relational database.
relational_db_username: str: Username for the relational database.
relational_db_password: str: Password for the relational database.
"""

def __init__(
Expand All @@ -32,9 +44,18 @@ def __init__(
llm_provider: str = "openai",
llm_model: str = "gpt-4o-mini",
graph_db_provider: str = "networkx",
graph_database_url: str = "",
graph_database_username: str = "",
graph_database_password: str = "",
vector_db_provider: str = "lancedb",
vector_db_url: str = "",
vector_db_key: str = "",
relational_db_provider: str = "sqlite",
db_name: str = "cognee_db",
relational_db_name: str = "cognee_db",
relational_db_host: str = "",
relational_db_port: str = "",
relational_db_username: str = "",
relational_db_password: str = "",
) -> None:
cognee.config.set_llm_config(
{
Expand All @@ -44,11 +65,33 @@ def __init__(
}
)

cognee.config.set_vector_db_config({"vector_db_provider": vector_db_provider})
cognee.config.set_vector_db_config(
{
"vector_db_url": vector_db_url,
"vector_db_key": vector_db_key,
"vector_db_provider": vector_db_provider,
}
)
cognee.config.set_relational_db_config(
{"db_provider": relational_db_provider, "db_name": db_name}
{
"db_path": "",
"db_name": relational_db_name,
"db_host": relational_db_host,
"db_port": relational_db_port,
"db_username": relational_db_username,
"db_password": relational_db_password,
"db_provider": relational_db_provider,
}
)

cognee.config.set_graph_db_config(
{
"graph_database_provider": graph_db_provider,
"graph_database_url": graph_database_url,
"graph_database_username": graph_database_username,
"graph_database_password": graph_database_password,
}
)
cognee.config.set_graph_database_provider(graph_db_provider)

data_directory_path = str(
pathlib.Path(
Expand Down Expand Up @@ -119,6 +162,17 @@ async def get_graph_url(self, graphistry_password, graphistry_username) -> str:
print(graph_url)
return graph_url

async def rag_search(self, query: str) -> list:
    """Answer a query from the data chunks most relevant to it.

    Looks up cognee's default user and runs a ``COMPLETION``-type
    ``cognee.search`` on that user's behalf.

    Args:
        query (str): The query string.

    Returns:
        list: The result of the awaited ``cognee.search`` call.
    """
    default_user = await cognee.modules.users.methods.get_default_user()
    completion_type = cognee.api.v1.search.SearchType.COMPLETION
    return await cognee.search(completion_type, query, default_user)

async def search(self, query: str) -> list:
"""Search the graph for relevant information based on a query.
Expand All @@ -127,7 +181,7 @@ async def search(self, query: str) -> list:
"""
user = await cognee.modules.users.methods.get_default_user()
return await cognee.search(
cognee.api.v1.search.SearchType.SUMMARIES, query, user
cognee.api.v1.search.SearchType.GRAPH_COMPLETION, query, user
)

async def get_related_nodes(self, node_id: str) -> list:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-graph-rag-cognee"
readme = "README.md"
version = "0.1.0"
version = "0.1.1"

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
cognee = "^0.1.20"
cognee = {extras = ["neo4j", "postgres", "qdrant", "weaviate"], version = "^0.1.21"}
httpx = "~=0.27.0"
llama-index-core = "^0.12.5"
pytest-cov = "^6.0.0"
Expand All @@ -40,7 +40,7 @@ pytest-cov = "^6.0.0"
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pre-commit = "^4.0.0"
pylint = "2.15.10"
pytest = "8.2"
pytest-asyncio = "^0.25.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ async def test_add_data(monkeypatch):
graph_db_provider="networkx",
vector_db_provider="lancedb",
relational_db_provider="sqlite",
db_name="cognee_db",
relational_db_name="cognee_db",
)

# Mock logging to graphistry
async def mock_add_return(add, dataset_name):
return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async def test_get_graph_url(monkeypatch):
graph_db_provider="networkx",
vector_db_provider="lancedb",
relational_db_provider="sqlite",
db_name="cognee_db",
relational_db_name="cognee_db",
)

# Mock logging to graphistry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_graph_rag_cognee():
graph_db_provider="networkx",
vector_db_provider="lancedb",
relational_db_provider="sqlite",
db_name="cognee_db",
relational_db_name="cognee_db",
)

# Add data to cognee
Expand All @@ -39,11 +39,20 @@ async def test_graph_rag_cognee():
await cogneeRAG.process_data("test")

# Answer prompt based on knowledge graph
search_results = await cogneeRAG.search("person")
search_results = await cogneeRAG.search("Tell me who are the people mentioned?")

assert len(search_results) > 0, "No search results found"

print("\n\nExtracted sentences are:\n")
print("\n\nAnswer based on knowledge graph:\n")
for result in search_results:
print(f"{result}\n")

# Answer prompt based on RAG
search_results = await cogneeRAG.rag_search("Tell me who are the people mentioned?")

assert len(search_results) > 0, "No search results found"

print("\n\nAnswer based on RAG:\n")
for result in search_results:
print(f"{result}\n")

Expand Down

0 comments on commit 76a7539

Please sign in to comment.