Skip to content

Commit

Permalink
python: add method to determine responsible target for an object
Browse files Browse the repository at this point in the history
- Add method in 'Smap' to map objects to targets
- Implement Xoshiro256-inspired hash function for consistent hashing
- Add `in_maint_or_decomm` method for Snode to handle node states

Signed-off-by: Abhishek Gaikwad <[email protected]>
  • Loading branch information
gaikwadabhishek committed Jan 29, 2025
1 parent 6514aa1 commit b61478e
Show file tree
Hide file tree
Showing 13 changed files with 382 additions and 27 deletions.
3 changes: 2 additions & 1 deletion python/aistore/common_requirements
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ pytest==7.4.4
PyYAML==6.0.2
requests==2.32.3
typing_extensions==4.12.2
webdataset==0.2.100
webdataset==0.2.100
xxhash==3.5.0
2 changes: 1 addition & 1 deletion python/aistore/sdk/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,7 @@ def object(self, obj_name: str, props: ObjectProps = None) -> Object:
Returns:
The object created.
"""
details = BucketDetails(self.name, self.provider, self.qparam)
details = BucketDetails(self.name, self.provider, self.qparam, self.get_path())
return Object(
client=self.client, bck_details=details, name=obj_name, props=props
)
Expand Down
5 changes: 4 additions & 1 deletion python/aistore/sdk/const.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved.
#

# Standard Header Keys
Expand Down Expand Up @@ -171,3 +171,6 @@

# Ref: https://www.rfc-editor.org/rfc/rfc7233#section-2.1
BYTE_RANGE_PREFIX_LENGTH = 6

# Custom seed (MLCG32)
XX_HASH_SEED = 1103515245
12 changes: 8 additions & 4 deletions python/aistore/sdk/obj/object.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
#

import warnings
Expand All @@ -8,7 +8,7 @@
from io import BufferedWriter
from pathlib import Path
from typing import Dict

import os
from requests import Response
from requests.structures import CaseInsensitiveDict

Expand Down Expand Up @@ -55,6 +55,7 @@ class BucketDetails:
name: str
provider: Provider
qparams: Dict[str, str]
path: str


class Object:
Expand Down Expand Up @@ -129,7 +130,7 @@ def head(self) -> CaseInsensitiveDict:
self._props = ObjectProps(headers)
return headers

# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments,too-many-locals
def get_reader(
self,
archive_config: ArchiveConfig = None,
Expand All @@ -139,6 +140,7 @@ def get_reader(
writer: BufferedWriter = None,
latest: bool = False,
byte_range: str = None,
direct: bool = False,
) -> ObjectReader:
"""
Creates and returns an ObjectReader with access to object contents and optionally writes to a provided writer.
Expand All @@ -153,6 +155,8 @@ def get_reader(
latest (bool, optional): GET the latest object version from the associated remote bucket
byte_range (str, optional): Byte range in RFC 7233 format for single-range requests (e.g., "bytes=0-499",
"bytes=500-", "bytes=-500"). See: https://www.rfc-editor.org/rfc/rfc7233#section-2.1.
direct (bool, optional): If True, the object content is read directly from the target node,
bypassing the proxy
Returns:
An ObjectReader which can be iterated over to stream chunks of object content or used to read all content
Expand Down Expand Up @@ -197,13 +201,13 @@ def get_reader(
int(byte_range_l) if byte_range_l else None,
int(byte_range_r) if byte_range_r else None,
)

obj_client = ObjectClient(
request_client=self._client,
path=self._object_path,
params=params,
headers=headers,
byte_range=byte_range_tuple,
uname=os.path.join(self._bck_details.path, self.name) if direct else None,
)

obj_reader = ObjectReader(
Expand Down
81 changes: 66 additions & 15 deletions python/aistore/sdk/obj/object_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aistore.sdk.const import HTTP_METHOD_GET, HTTP_METHOD_HEAD, HEADER_RANGE
from aistore.sdk.obj.object_attributes import ObjectAttributes
from aistore.sdk.request_client import RequestClient
from aistore.sdk.errors import ErrObjNotFound


class ObjectClient:
Expand All @@ -21,6 +22,7 @@ class ObjectClient:
params (Dict[str, str]): Query parameters for the request
headers (Optional[Dict[str, str]]): HTTP request headers
byte_range (Optional[Tuple[Optional[int], Optional[int]]): Tuple representing the byte range
uname (Optional[str]): Unique (namespaced) name of the object (used for determining the target node)
"""

# pylint: disable=too-many-arguments
Expand All @@ -31,47 +33,96 @@ def __init__(
params: Dict[str, str],
headers: Optional[Dict[str, str]] = None,
byte_range: Optional[Tuple[Optional[int], Optional[int]]] = (None, None),
uname: Optional[str] = None,
):
self._request_client = request_client
self._request_path = path
self._request_params = params
self._request_headers = headers
self._byte_range = byte_range
self._uname = uname
if uname:
self._initialize_target_client()

def _initialize_target_client(self, force: bool = False):
    """
    Re-point this client at the target node selected for the object.

    The cluster map is obtained from the current request client, the map is
    asked which target is responsible for `self._uname`, and
    `self._request_client` is replaced by a clone whose base URL is that
    target's public direct URL.

    Args:
        force (bool): Passed through to `get_smap`. Defaults to False.
    """
    cluster_map = self._request_client.get_smap(force)
    owner = cluster_map.get_target_for_object(self._uname)
    direct_url = owner.public_net.direct_url
    # From here on, every request issued by this ObjectClient goes to the target.
    self._request_client = self._request_client.clone(base_url=direct_url)

def _retry_with_new_smap(self, method: str, **kwargs):
    """
    Send the request again, re-resolving the target node first when possible.

    If this client was created with a `uname`, the target client is rebuilt
    with `force=True` before the request is repeated; otherwise the request
    is simply repeated on the current client.

    Args:
        method (str): HTTP method (e.g., GET, HEAD).
        **kwargs: Keyword arguments forwarded unchanged to the request client.

    Returns:
        requests.Response: The response object from the retried request.
    """
    has_uname = bool(self._uname)
    if has_uname:
        # force=True asks for a freshly fetched smap before picking the target.
        self._initialize_target_client(force=True)

    client = self._request_client
    return client.request(method, **kwargs)

def get(self, stream: bool, offset: Optional[int] = None) -> requests.Response:
    """
    Fetch object content from AIS, applying an optional offset.

    When the client was created with a `uname` and the first attempt fails
    with `ErrObjNotFound`, the target node is re-resolved and the request is
    sent once more.

    Args:
        stream (bool): If True, stream the response content.
        offset (int, optional): Byte offset for reading the object. Defaults to None.

    Returns:
        requests.Response: The response object containing the content.

    Raises:
        ErrObjNotFound: If the object is not found and cannot be retried.
        requests.RequestException: For network-related errors.
        Exception: For any unexpected failures.
    """
    headers = self._request_headers.copy() if self._request_headers else {}

    if offset:
        start, end = self._byte_range
        if start is not None:
            # Explicit start position: move it forward by the offset.
            start += offset
        elif end is not None:
            # Only the right-hand value is set: reduce it by the offset.
            end -= offset
        else:
            # No range configured: read from the offset onwards.
            start = offset

        # A missing (or zero) bound is rendered as an empty string, e.g. "bytes=500-".
        headers[HEADER_RANGE] = f"bytes={start or ''}-{end or ''}"

    request_kwargs = {
        "path": self._request_path,
        "params": self._request_params,
        "stream": stream,
        "headers": headers,
    }

    try:
        resp = self._request_client.request(HTTP_METHOD_GET, **request_kwargs)
        resp.raise_for_status()
        return resp
    except ErrObjNotFound:
        if not self._uname:
            raise
        # Re-resolve the target from a refreshed smap and try once more.
        resp = self._retry_with_new_smap(HTTP_METHOD_GET, **request_kwargs)
        # Apply the same status check as the first attempt.
        resp.raise_for_status()
        return resp

def head(self) -> ObjectAttributes:
"""
Expand Down
43 changes: 43 additions & 0 deletions python/aistore/sdk/request_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
HEADER_LOCATION,
STATUS_REDIRECT_PERM,
STATUS_REDIRECT_TMP,
URL_PATH_DAEMON,
WHAT_SMAP,
QPARAM_WHAT,
HTTP_METHOD_GET,
)
from aistore.sdk.session_manager import SessionManager
from aistore.sdk.utils import parse_ais_error, handle_errors, decode_response
from aistore.version import __version__ as sdk_version
from aistore.sdk.types import Smap

T = TypeVar("T")

Expand Down Expand Up @@ -52,6 +57,8 @@ def __init__(
self._token = token
self._timeout = timeout
self._error_handler = error_handler
# smap is used to calculate the target node for a given object
self._smap = None

@property
def base_url(self):
Expand Down Expand Up @@ -101,6 +108,42 @@ def token(self, token: str):
"""
self._token = token

def get_smap(self, force_update: bool = False) -> "Smap":
    """
    Return the cluster map (smap), fetching it when needed.

    The map is requested from the daemon endpoint the first time it is asked
    for, and again whenever `force_update` is True; otherwise the value
    stored on this client is returned as-is.

    Args:
        force_update (bool): If True, ignore the stored map and fetch a new one.

    Returns:
        Smap: The stored (possibly just refreshed) cluster map.
    """
    cached = self._smap
    if cached and not force_update:
        return cached

    query = {QPARAM_WHAT: WHAT_SMAP}
    self._smap = self.request_deserialize(
        HTTP_METHOD_GET,
        path=URL_PATH_DAEMON,
        res_model=Smap,
        params=query,
    )
    return self._smap

def clone(self, base_url: Optional[str] = None) -> "RequestClient":
    """
    Build a new RequestClient that shares this client's settings.

    The session manager, timeout, token and error handler are reused; only
    the endpoint may differ.

    Args:
        base_url (Optional[str]): Base URL for the new client. When omitted
            (or empty), this client's own base URL is used.

    Returns:
        RequestClient: A new instance with the same settings but an optional different base URL.
    """
    target_url = base_url if base_url else self._base_url

    # Make sure the URL handed to the new client carries the "/v1" suffix.
    if not target_url.endswith("/v1"):
        target_url = urljoin(target_url, "v1")

    shared_settings = {
        "session_manager": self._session_manager,
        "timeout": self._timeout,
        "token": self._token,
        "error_handler": self._error_handler,
    }
    return RequestClient(endpoint=target_url, **shared_settings)

def request_deserialize(
self, method: str, path: str, res_model: Type[T], **kwargs
) -> T:
Expand Down
42 changes: 41 additions & 1 deletion python/aistore/sdk/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved.
#

from __future__ import annotations
Expand All @@ -24,6 +24,8 @@
AIS_LOCATION,
AIS_MIRROR_COPIES,
)
from aistore.sdk.utils import get_digest, xoshiro256_hash
from aistore.sdk.errors import AISError


# pylint: disable=too-few-public-methods,unused-variable,missing-function-docstring,too-many-lines
Expand Down Expand Up @@ -58,6 +60,10 @@ class Snode(BaseModel):
intra_control_net: NetInfo = None
intra_data_net: NetInfo = None
flags: int = 0
id_digest: int = 0

def in_maint_or_decomm(self) -> bool:
    """
    Report whether flag bit 2 or flag bit 3 is set on this node.

    NOTE(review): judging by the method name, these two bits presumably mark
    the "maintenance" and "decommission" states; confirm against the
    server-side flag definitions.

    Returns:
        bool: True if at least one of the two bits is set in `self.flags`.
    """
    excluded_mask = (1 << 2) | (1 << 3)
    return bool(self.flags & excluded_mask)


class Smap(BaseModel):
Expand All @@ -72,6 +78,40 @@ class Smap(BaseModel):
uuid: str = ""
creation_time: str = ""

def get_target_for_object(self, uname: str) -> Snode:
    """
    Pick the target node responsible for an object.

    The object's `uname` is reduced to a digest; for every eligible target
    that digest is XOR-ed with the node's `id_digest` and hashed, and the
    node with the largest hash value wins (the first one seen is kept on a
    tie). Nodes for which `in_maint_or_decomm()` is true are not considered.

    Args:
        uname (str): Fully qualified (namespaced) object name (e.g., f"{bck.get_path()}{obj.name}").

    Returns:
        Snode: The assigned target node.

    Raises:
        AISError: If no suitable target node is found.
    """
    digest = get_digest(uname)

    winner = None
    best_score = -1
    eligible = (node for node in self.tmap.values() if not node.in_maint_or_decomm())
    for node in eligible:
        # Per-node score: hash of the node digest combined with the object digest.
        score = xoshiro256_hash(node.id_digest ^ digest)
        if score > best_score:
            winner, best_score = node, score

    if winner is not None:
        return winner

    raise AISError(
        500, f"No available targets in the map. Total nodes: {len(self.tmap)}"
    )


class BucketEntry(msgspec.Struct):
"""
Expand Down
Loading

0 comments on commit b61478e

Please sign in to comment.