added pagination for retrieving more than 50 docs
Anash3 committed Jan 31, 2025
1 parent dbb6b7b commit 4dd5ef3
Showing 1 changed file with 61 additions and 19 deletions.
libs/community/langchain_community/retrievers/azure_ai_search.py (80 changes: 61 additions & 19 deletions)
@@ -123,7 +123,7 @@ def validate_environment(cls, values: Dict) -> Any:
             )
         return values
 
-    def _build_search_url(self, query: str) -> str:
+    def _build_search_url(self, query: str, skip: int = 0) -> str:
         url_suffix = get_from_env("", "AZURE_AI_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX)
         if url_suffix in self.service_name and "https://" in self.service_name:
             base_url = f"{self.service_name}/"
@@ -139,9 +139,13 @@ def _build_search_url(self, query: str) -> str:
             # pass to Azure to throw a specific error
             base_url = self.service_name
         endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
-        top_param = f"&$top={self.top_k}" if self.top_k else ""
+        batch_size = self.top_k if self.top_k is not None else 1000
+
+        top_param = f"&$top={batch_size}"
         filter_param = f"&$filter={self.filter}" if self.filter else ""
-        return base_url + endpoint_path + f"&search={query}" + top_param + filter_param
+        skip_param = f"&$skip={skip}"
+        count_param = "&$count=true"
+        return base_url + endpoint_path + f"&search={query}" + top_param + skip_param + filter_param + count_param
 
     @property
     def _headers(self) -> Dict[str, str]:
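
Before the main pagination hunk below, a quick illustration of the query string this method now assembles. This is a sketch, not code from the commit: the service, index, and api-version values are placeholders, and the real method derives every piece from the retriever's service_name, index_name, api_version, top_k, and filter settings.

base_url = "https://my-service.search.windows.net/"            # placeholder
endpoint_path = "indexes/my-index/docs?api-version=2023-11-01"  # placeholder
query, batch_size, skip = "langchain", 1000, 0

search_url = (
    base_url
    + endpoint_path
    + f"&search={query}"     # the search text
    + f"&$top={batch_size}"  # page size: top_k when set, else 1000
    + f"&$skip={skip}"       # offset, advanced by the pagination loop
    + "&$count=true"         # makes the service report '@odata.count'
)
# -> https://my-service.search.windows.net/indexes/my-index/docs
#    ?api-version=2023-11-01&search=langchain&$top=1000&$skip=0&$count=true

Requesting $count=true is what lets the new loops know when to stop: each response then carries the total number of matches alongside the current page.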
@@ -151,26 +151,64 @@ def _headers(self) -> Dict[str, str]:
         }
 
     def _search(self, query: str) -> List[dict]:
-        search_url = self._build_search_url(query)
-        response = requests.get(search_url, headers=self._headers)
-        if response.status_code != 200:
-            raise Exception(f"Error in search request: {response}")
-
-        return json.loads(response.text)["value"]
+        all_results = []
+        skip = 0
+
+        while True:
+            search_url = self._build_search_url(query, skip)
+            response = requests.get(search_url, headers=self._headers)
+            if response.status_code != 200:
+                raise Exception(f"Error in search request: {response}")
+
+            response_json = json.loads(response.text)
+            current_results = response_json.get('value', [])
+
+            all_results.extend(current_results)
+
+            total_results = response_json.get('@odata.count', 0)
+            if len(all_results) >= total_results or not current_results:
+                break
+
+            skip += len(current_results)
+
+        return all_results
 
     async def _asearch(self, query: str) -> List[dict]:
-        search_url = self._build_search_url(query)
+        all_results = []
+        skip = 0
+
         if not self.aiosession:
             async with aiohttp.ClientSession() as session:
-                async with session.get(search_url, headers=self._headers) as response:
-                    response_json = await response.json()
+                while True:
+                    search_url = self._build_search_url(query, skip)
+                    async with session.get(search_url, headers=self._headers) as response:
+                        response_json = await response.json()
+
+                    current_results = response_json.get('value', [])
+                    all_results.extend(current_results)
+
+                    total_results = response_json.get('@odata.count', 0)
+                    if len(all_results) >= total_results or not current_results:
+                        break
+
+                    skip += len(current_results)
         else:
-            async with self.aiosession.get(
-                search_url, headers=self._headers
-            ) as response:
-                response_json = await response.json()
-
-        return response_json["value"]
+            async with self.aiosession:
+                while True:
+                    search_url = self._build_search_url(query, skip)
+                    async with self.aiosession.get(search_url, headers=self._headers) as response:
+                        response_json = await response.json()
+
+                    current_results = response_json.get('value', [])
+                    all_results.extend(current_results)
+
+                    total_results = response_json.get('@odata.count', 0)
+                    if len(all_results) >= total_results or not current_results:
+                        break
+
+                    skip += len(current_results)
+
+        return all_results
 
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
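
Viewed outside the diff, the new _search flow reduces to the standalone sketch below. Hedged assumptions: build_url is a hypothetical stand-in for _build_search_url(query, skip), and headers stands in for the retriever's _headers property.

import requests

def fetch_all_results(build_url, headers) -> list:
    """Page through results until the '@odata.count' total is reached."""
    all_results = []
    skip = 0
    while True:
        # build_url(skip) is a hypothetical stand-in for
        # _build_search_url(query, skip).
        response = requests.get(build_url(skip), headers=headers)
        if response.status_code != 200:
            raise Exception(f"Error in search request: {response}")

        payload = response.json()
        current = payload.get("value", [])
        all_results.extend(current)

        total = payload.get("@odata.count", 0)
        # Stop once every reported match has been collected, or when the
        # service returns an empty page (guards against a missing count).
        if len(all_results) >= total or not current:
            break
        skip += len(current)
    return all_results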

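The async path follows the same loop shape; a condensed aiohttp sketch under the same assumptions:

import aiohttp

async def afetch_all_results(build_url, headers) -> list:
    """Async twin of the pagination loop, using a fresh client session."""
    all_results = []
    skip = 0
    async with aiohttp.ClientSession() as session:
        while True:
            async with session.get(build_url(skip), headers=headers) as response:
                payload = await response.json()
            current = payload.get("value", [])
            all_results.extend(current)
            total = payload.get("@odata.count", 0)
            if len(all_results) >= total or not current:
                break
            skip += len(current)
    return all_results

One design note on the committed else branch: it enters "async with self.aiosession:", which closes the caller-supplied session once the loop finishes, so callers who intend to reuse that session may be surprised.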