Skip to content

Commit

Permalink
Update query params to pass in index name instead of page_type
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Morse committed Dec 22, 2023
1 parent c329255 commit f1fa5e8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 34 deletions.
71 changes: 40 additions & 31 deletions src/wagtail_vector_index/consumers.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
import asyncio
import logging
from typing import Type
from typing import Any

from channels.generic.http import AsyncHttpConsumer
from django import forms
from django.apps import apps
from django.core.exceptions import ValidationError
from django.http import QueryDict

# Define type instead of importing directly to prevent AppRegistryNotReady errors
VectorIndexType = Type["wagtail_vector_index.index.VectorIndex"] # noqa

logger = logging.Logger(__name__)


class WagtailVectorIndexQueryParamsForm(forms.Form):
"""Provides a form for validating query parameters."""

query = forms.CharField(max_length=255, required=True)
page_type = forms.CharField(max_length=255, required=True)
index = forms.CharField(max_length=255, required=True)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

from wagtail_vector_index.index import get_vector_indexes

self.indexes = get_vector_indexes()

def clean_index(self):
index = self.cleaned_data["index"]
if index not in self.indexes:
raise forms.ValidationError("Invalid index. Please choose a valid index.")
return index


class WagtailVectorIndexSSEConsumer(AsyncHttpConsumer):
Expand All @@ -31,10 +41,10 @@ class WagtailVectorIndexSSEConsumer(AsyncHttpConsumer):
Note:
This consumer expects the following query parameters in the URL:
- 'query': The search query.
- 'page_type': The type of Wagtail page to search.
- 'index': The vector index to perform the query with.
Example URL:
"/chat-query-sse/?query=example&page_type=news.NewsPage"
"/chat-query-sse/?query=example&index=news.NewsPage"
"""

async def handle(self, body: bytes) -> None:
Expand All @@ -56,34 +66,33 @@ async def handle(self, body: bytes) -> None:

# Validate query parameters
form = WagtailVectorIndexQueryParamsForm(query_dict)
if form.is_valid():
query = form.cleaned_data["query"]
page_type = form.cleaned_data["page_type"]

# Get a model class by its name
page_model = apps.get_model(page_type)
vector_index = page_model.get_vector_index()

try:
# Process and reply to prompt
await self.process_prompt(query, vector_index)
except Exception:
logging.exception(
"Unexpected error in WagtailVectorIndexSSEConsumer"
)
payload = (
"data: Error processing request, Please try again later. \n\n"
)
await self.send_body(payload.encode("utf-8"), more_body=True)
if not form.is_valid():
# Ignore "TRY301 Abstract `raise` to an inner function"
# So we can insure the event-stream is closed and no other code is executed
raise ValidationError("Invalid query parameters.") # noqa: TRY301
query = form.cleaned_data["query"]
index = form.cleaned_data["index"]

vector_index = form.indexes.get(index)

if vector_index:
await self.process_prompt(query, vector_index)

except (ValueError, UnicodeDecodeError, KeyError, LookupError, AttributeError):
payload = "data: Error processing request. \n\n"
await self.send_body(payload.encode("utf-8"), more_body=True)
except ValidationError:
await self.error_response()

except Exception:
logging.exception("Unexpected error in WagtailVectorIndexSSEConsumer")
await self.error_response()

# Finish the response
await self.send_body(b"")

async def process_prompt(self, query: str, vector_index: VectorIndexType) -> None:
async def error_response(self) -> None:
payload = "data: Error processing request, Please try again later. \n\n"
await self.send_body(payload.encode("utf-8"), more_body=True)

async def process_prompt(self, query: str, vector_index: Any) -> None:
"""
Processes the incoming prompt and sends SSE updates.
Expand Down
7 changes: 4 additions & 3 deletions src/wagtail_vector_index/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ async def aquery(
similar_documents = await sync_to_async(self.backend_index.similarity_search)(
query_embedding
)
sources = await sync_to_async(self.object_type.bulk_from_documents)(
similar_documents
)
sources = []
# sources = await sync_to_async(self.object_type.bulk_from_documents)(
# similar_documents
# )
merged_context = await get_metadata_from_documents_async(similar_documents)

prompt = (
Expand Down

0 comments on commit f1fa5e8

Please sign in to comment.