From 4dd5ef39e8ca803cc8bcee5277c09d152ee67c56 Mon Sep 17 00:00:00 2001 From: Anash03 Date: Fri, 31 Jan 2025 19:55:13 +0530 Subject: [PATCH] added pagination for more than 50 docs retrieval --- .../retrievers/azure_ai_search.py | 80 ++++++++++++++----- 1 file changed, 61 insertions(+), 19 deletions(-) diff --git a/libs/community/langchain_community/retrievers/azure_ai_search.py b/libs/community/langchain_community/retrievers/azure_ai_search.py index 01549ff3bd012..59d901e89cd7c 100644 --- a/libs/community/langchain_community/retrievers/azure_ai_search.py +++ b/libs/community/langchain_community/retrievers/azure_ai_search.py @@ -123,7 +123,7 @@ def validate_environment(cls, values: Dict) -> Any: ) return values - def _build_search_url(self, query: str) -> str: + def _build_search_url(self, query: str, skip: int = 0) -> str: url_suffix = get_from_env("", "AZURE_AI_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX) if url_suffix in self.service_name and "https://" in self.service_name: base_url = f"{self.service_name}/" @@ -139,9 +139,13 @@ def _build_search_url(self, query: str) -> str: # pass to Azure to throw a specific error base_url = self.service_name endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}" - top_param = f"&$top={self.top_k}" if self.top_k else "" + batch_size = self.top_k if self.top_k is not None else 1000 + + top_param = f"&$top={batch_size}" filter_param = f"&$filter={self.filter}" if self.filter else "" - return base_url + endpoint_path + f"&search={query}" + top_param + filter_param + skip_param = f"&$skip={skip}" + count_param = "&$count=true" + return base_url + endpoint_path + f"&search={query}" + top_param + skip_param + filter_param + count_param @property def _headers(self) -> Dict[str, str]: @@ -151,26 +155,64 @@ def _headers(self) -> Dict[str, str]: } def _search(self, query: str) -> List[dict]: - search_url = self._build_search_url(query) - response = requests.get(search_url, headers=self._headers) - if response.status_code != 200: - raise Exception(f"Error in search request: {response}") - - return json.loads(response.text)["value"] - + all_results = [] + skip = 0 + + while True: + search_url = self._build_search_url(query, skip) + response = requests.get(search_url, headers=self._headers) + if response.status_code != 200: + raise Exception(f"Error in search request: {response}") + + response_json = json.loads(response.text) + current_results = response_json.get('value', []) + + all_results.extend(current_results) + + total_results = response_json.get('@odata.count', 0) + if len(all_results) >= total_results or not current_results: + break + + skip += len(current_results) + + return all_results + async def _asearch(self, query: str) -> List[dict]: - search_url = self._build_search_url(query) + all_results = [] + skip = 0 + if not self.aiosession: async with aiohttp.ClientSession() as session: - async with session.get(search_url, headers=self._headers) as response: - response_json = await response.json() + while True: + search_url = self._build_search_url(query, skip) + async with session.get(search_url, headers=self._headers) as response: + response_json = await response.json() + + current_results = response_json.get('value', []) + all_results.extend(current_results) + + total_results = response_json.get('@odata.count', 0) + if len(all_results) >= total_results or not current_results: + break + + skip += len(current_results) else: - async with self.aiosession.get( - search_url, headers=self._headers - ) as response: - response_json = await response.json() - - return response_json["value"] + async with self.aiosession: + while True: + search_url = self._build_search_url(query, skip) + async with self.aiosession.get(search_url, headers=self._headers) as response: + response_json = await response.json() + + current_results = response_json.get('value', []) + all_results.extend(current_results) + + total_results = response_json.get('@odata.count', 0) + if len(all_results) >= total_results or not current_results: + break + + skip += len(current_results) + + return all_results def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun