Skip to content

Commit

Permalink
Merge branch 'dev' into dev
Browse files: browse the repository at this point in the history
  • Loading branch information
tjbck authored Oct 17, 2024
2 parents 79c834d + 1869232 commit c9c7985
Show file tree
Hide file tree
Showing 89 changed files with 3,406 additions and 1,176 deletions.
4 changes: 2 additions & 2 deletions backend/open_webui/apps/retrieval/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,8 +709,8 @@ def save_docs_to_vector_db(
if overwrite:
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
log.info(f"deleting existing collection {collection_name}")

if add is False:
elif add is False:
log.info(f"collection {collection_name} already exists, overwrite is False and add is False")
return True

log.info(f"adding to collection {collection_name}")
Expand Down
11 changes: 6 additions & 5 deletions backend/open_webui/apps/retrieval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ def get_rag_context(
extracted_collections.extend(collection_names)

if context:
if "data" in file:
del file["data"]
relevant_contexts.append({**context, "file": file})

contexts = []
Expand All @@ -401,11 +403,8 @@ def get_rag_context(
]
)
)

contexts.append(
(", ".join(file_names) + ":\n\n")
if file_names
else ""
((", ".join(file_names) + ":\n\n") if file_names else "")
+ "\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
Expand All @@ -423,7 +422,9 @@ def get_rag_context(
except Exception as e:
log.exception(e)

print(contexts, citations)
print("contexts", contexts)
print("citations", citations)

return contexts, citations


Expand Down
2 changes: 2 additions & 0 deletions backend/open_webui/apps/webui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from open_webui.apps.webui.routers import (
auths,
chats,
folders,
configs,
files,
functions,
Expand Down Expand Up @@ -110,6 +111,7 @@
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(folders.router, prefix="/folders", tags=["folders"])

app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
Expand Down
113 changes: 102 additions & 11 deletions backend/open_webui/apps/webui/models/chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Chat(Base):
pinned = Column(Boolean, default=False, nullable=True)

meta = Column(JSON, server_default="{}")
folder_id = Column(Text, nullable=True)


class ChatModel(BaseModel):
Expand All @@ -51,6 +52,7 @@ class ChatModel(BaseModel):
pinned: Optional[bool] = False

meta: dict = {}
folder_id: Optional[str] = None


####################
Expand All @@ -61,10 +63,12 @@ class ChatModel(BaseModel):
class ChatForm(BaseModel):
    """Model with a single required ``chat`` field holding a dict."""

    chat: dict


class ChatTitleMessagesForm(BaseModel):
    """Model pairing a ``title`` string with a list of message dicts."""

    title: str
    messages: list[dict]


class ChatTitleForm(BaseModel):
    """Model with a single required ``title`` string field."""

    title: str

Expand All @@ -80,6 +84,7 @@ class ChatResponse(BaseModel):
archived: bool
pinned: Optional[bool] = False
meta: dict = {}
folder_id: Optional[str] = None


class ChatTitleIdResponse(BaseModel):
Expand Down Expand Up @@ -252,14 +257,18 @@ def get_chat_list_by_user_id(
limit: int = 50,
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
if not include_archived:
query = query.filter_by(archived=False)
all_chats = (
query.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all()
)

query = query.order_by(Chat.updated_at.desc())

if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)

all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]

def get_chat_title_id_list_by_user_id(
Expand All @@ -270,7 +279,9 @@ def get_chat_title_id_list_by_user_id(
limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))

if not include_archived:
query = query.filter_by(archived=False)

Expand Down Expand Up @@ -361,7 +372,7 @@ def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, pinned=True)
.filter_by(user_id=user_id, pinned=True, archived=False)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
Expand All @@ -387,9 +398,25 @@ def get_chats_by_user_id_and_search_text(
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
search_text = search_text.lower().strip()

if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)

search_text_words = search_text.split(" ")

# search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags
tag_ids = [
word.replace("tag:", "").replace(" ", "_").lower()
for word in search_text_words
if word.startswith("tag:")
]

search_text_words = [
word for word in search_text_words if not word.startswith("tag:")
]

search_text = " ".join(search_text_words)

with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id)

Expand Down Expand Up @@ -418,6 +445,26 @@ def get_chats_by_user_id_and_search_text(
)
).params(search_text=search_text)
)

# Check if there are any tags to filter, it should have all the tags
if tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_each(Chat.meta, '$.tags') AS tag
WHERE tag.value = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)

elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search
query = query.filter(
Expand All @@ -436,6 +483,25 @@ def get_chats_by_user_id_and_search_text(
)
).params(search_text=search_text)
)

# Check if there are any tags to filter, it should have all the tags
if tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_array_elements_text(Chat.meta->'tags') AS tag
WHERE tag = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
Expand All @@ -444,9 +510,34 @@ def get_chats_by_user_id_and_search_text(
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()

print(len(all_chats))

# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]

def get_chats_by_folder_id_and_user_id(
    self, folder_id: str, user_id: str
) -> list[ChatModel]:
    """Return the chats whose ``folder_id`` and ``user_id`` both match.

    Args:
        folder_id: Value compared against ``Chat.folder_id``.
        user_id: Value compared against ``Chat.user_id``.

    Returns:
        One validated ``ChatModel`` per matching row; an empty list when
        nothing matches. No ordering is applied to the query.
    """
    with get_db() as session:
        matching_rows = (
            session.query(Chat)
            .filter_by(folder_id=folder_id, user_id=user_id)
            .all()
        )
        validated = []
        for row in matching_rows:
            validated.append(ChatModel.model_validate(row))
        return validated

def update_chat_folder_id_by_id_and_user_id(
    self, id: str, user_id: str, folder_id: str
) -> Optional[ChatModel]:
    """Assign the chat ``id`` to ``folder_id`` on behalf of ``user_id``.

    Args:
        id: Primary key of the chat to update.
        user_id: Owner the chat must belong to. A chat owned by a
            different user is left untouched.
        folder_id: Value written to ``Chat.folder_id``.

    Returns:
        The refreshed chat as a ``ChatModel``, or ``None`` when the chat
        does not exist, is not owned by ``user_id``, or the database
        operation raises.
    """
    try:
        with get_db() as db:
            chat = db.get(Chat, id)
            # The lookup is by primary key only, so ownership has to be
            # checked explicitly; otherwise any caller could move another
            # user's chat into a folder.
            if chat is None or chat.user_id != user_id:
                return None

            chat.folder_id = folder_id
            # Bump the modification timestamp (epoch seconds) so the move
            # is reflected in ``updated_at``-based ordering.
            chat.updated_at = int(time.time())
            db.commit()
            db.refresh(chat)
            return ChatModel.model_validate(chat)
    except Exception:
        # Best-effort contract kept from the original: any failure is
        # reported to the caller as ``None`` rather than propagated.
        return None

def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
with get_db() as db:
chat = db.get(Chat, id)
Expand Down Expand Up @@ -498,7 +589,7 @@ def add_chat_tag_by_id_and_user_id_and_tag_name(
if tag_id not in chat.meta.get("tags", []):
chat.meta = {
**chat.meta,
"tags": chat.meta.get("tags", []) + [tag_id],
"tags": list(set(chat.meta.get("tags", []) + [tag_id])),
}

db.commit()
Expand All @@ -509,7 +600,7 @@ def add_chat_tag_by_id_and_user_id_and_tag_name(

def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
with get_db() as db: # Assuming `get_db()` returns a session object
query = db.query(Chat).filter_by(user_id=user_id)
query = db.query(Chat).filter_by(user_id=user_id, archived=False)

# Normalize the tag_name for consistency
tag_id = tag_name.replace(" ", "_").lower()
Expand Down Expand Up @@ -555,7 +646,7 @@ def delete_tag_by_id_and_user_id_and_tag_name(
tags = [tag for tag in tags if tag != tag_id]
chat.meta = {
**chat.meta,
"tags": tags,
"tags": list(set(tags)),
}
db.commit()
return True
Expand Down
Loading

0 comments on commit c9c7985

Please sign in to comment.