Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fallback to GET if HEAD doesn't work #9

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
== 0.4.0 ==

* Added fallback behaviour to try GET for content length when HEAD isn't available
* Added `no_head_request` option
* Added `session_args` to pass kwargs to the constructor of `aiohttp.ClientSession`
* Made unit tests slightly faster when generating data

== 0.3.0 ==

* Addition of asyncio compatible interface for use in python versions 3.6 and above
Expand Down
39 changes: 30 additions & 9 deletions httpio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ class HTTPIOError(IOBaseError):


class SyncHTTPIOFile(BufferedIOBase):
def __init__(self, url, block_size=-1, **kwargs):
def __init__(self, url, block_size=-1, no_head_request=False, **kwargs):
super(SyncHTTPIOFile, self).__init__()
self.url = url
self.block_size = block_size
self.no_head_request = no_head_request

self._kwargs = kwargs
self._cursor = 0
Expand All @@ -57,18 +58,38 @@ def __enter__(self):
self.open()
return super(SyncHTTPIOFile, self).__enter__()

def _check_ranges_set_length(self, response):
try:
self.length = int(response.headers['Content-Length'])
except KeyError:
raise HTTPIOError("Server does not report content length")
if response.headers.get('Accept-Ranges', '').lower() != 'bytes':
raise HTTPIOError("Server does not accept 'Range' headers")

def _check_file_headers_set_length(self, getter):
pass

def open(self):
self._assert_not_closed()
if not self._closing and self._session is None:
self._session = requests.Session()
response = self._session.head(self.url, **self._kwargs)
response.raise_for_status()
try:
self.length = int(response.headers['Content-Length'])
except KeyError:
raise HTTPIOError("Server does not report content length")
if response.headers.get('Accept-Ranges', '').lower() != 'bytes':
raise HTTPIOError("Server does not accept 'Range' headers")

if not self.no_head_request:
response = self._session.head(self.url, **self._kwargs)

# In some cases, notably including AWS S3 presigned URLs, it's only possible to GET the URL and HEAD
# isn't supported. In these cases we skip raising an exception and fall through to the `no_head_request`
# behaviour instead
if response.status_code != 405 and response.status_code != 403:
response.raise_for_status()
self._check_ranges_set_length(response)
return

# GET the URL with stream=True to avoid downloading the full response: exiting the context manager will
# close the connection
with self._session.get(self.url, stream=True, **self._kwargs) as response:
response.raise_for_status()
self._check_ranges_set_length(response)

def close(self):
self._closing = True
Expand Down
30 changes: 26 additions & 4 deletions httpio_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,19 @@ class AsyncHTTPIOFile(object):
"""An asynchronous equivalent to httpio.HTTPIOFile.
Sadly this class cannot descend from that one for technical reasons.
"""
def __init__(self, url, block_size=-1, **kwargs):
def __init__(self, url, block_size=-1, no_head_request=False, session_args={}, **kwargs):
"""
:param url: The URL of the file to open
:param block_size: The cache block size, or `-1` to disable caching.
:param no_head_request: Don't make a HEAD request to check the file size, use a GET instead
:param session_args: Additional kwargs to pass when creating aiohttp.ClientSession (e.g. trust_env)
:param kwargs: Additional arguments to pass to `session.head` and `session.get`
"""
super(AsyncHTTPIOFile, self).__init__()
self.url = url
self.block_size = block_size
self.no_head_request = no_head_request
self.session_args = session_args

self._kwargs = kwargs
self._cursor = 0
Expand Down Expand Up @@ -70,11 +74,25 @@ async def open(self):
be coroutines this class needs this as a separate coroutine"""

if self._session is None:
self._session = await aiohttp.ClientSession().__aenter__()
async with self._session.head(self.url, **self._kwargs) as response:
self._session = await aiohttp.ClientSession(**self.session_args).__aenter__()

if not self.no_head_request:
async with self._session.head(self.url, **self._kwargs) as response:
# In some cases, notably including AWS S3 presigned URLs, it's only possible to GET the URL and HEAD
# isn't supported. In these cases we skip raising an exception and fall through to the
# `no_head_request` behaviour instead
if response.status != 405 and response.status != 403:
response.raise_for_status()
self.length = int(response.headers.get('content-length', None))
self.closed = False
return

async with self._session.get(self.url, **self._kwargs) as response:
response.raise_for_status()
self.length = int(response.headers.get('content-length', None))
self.closed = False
# Note that not reading the response body will cause the underlying connection to be closed before the
# server sends the file

async def __aenter__(self):
await self.open()
Expand Down Expand Up @@ -291,7 +309,11 @@ class AsyncHTTPIOFileContextManagerMixin (object):
"""This is a mixin for HTTPIOFile to make it act as an async context manager via the AsyncHTTPIOFile class"""

async def __aenter__(self):
self.__acontextmanager = AsyncHTTPIOFile(self.url, self.block_size, **self._kwargs)
self.__acontextmanager = AsyncHTTPIOFile(self.url,
self.block_size,
no_head_request=self.no_head_request,
**self._kwargs)

return await self.__acontextmanager.__aenter__()

async def __aexit__(self, exc_type, exc, tb):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name='httpio',
version='0.3.0',
version='0.4.0',
author='Barney Gale',
author_email='[email protected]',
url='https://github.com/barneygale/httpio',
Expand Down
15 changes: 15 additions & 0 deletions tests/random_source_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os
import random

from six import int2byte

# 8 MB of random data for the HTTP requests to return.  os.urandom builds each
# buffer in a single call rather than joining eight million one-byte strings,
# which keeps importing this module cheap for every test run.
_DATA_SIZE = 8 * 1024 * 1024

DATA = os.urandom(_DATA_SIZE)

# A second, independently generated buffer of the same size.
OTHER_DATA = os.urandom(_DATA_SIZE)

# A few short text lines; the last one has no trailing newline.
ASCII_LINES = ["Line0\n",
               "Line the first\n",
               "Line Returns\n",
               "Line goes forth"]
ASCII_DATA = b''.join(line.encode('ascii') for line in ASCII_LINES)
54 changes: 33 additions & 21 deletions tests/test_async_httpio.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,23 @@
import asyncio
from unittest import TestCase
from unittest import mock

from httpio import HTTPIOFile

import mock
import random
import re
import warnings

from io import SEEK_CUR, SEEK_END

from random_source_data import DATA, OTHER_DATA, ASCII_DATA, ASCII_LINES


def async_func(f):
    """Wrap the plain callable *f* in a coroutine function.

    Awaiting a call to the returned coroutine function invokes ``f`` with the
    same positional and keyword arguments and produces its return value.
    """
    async def _call_through(*call_args, **call_kwargs):
        result = f(*call_args, **call_kwargs)
        return result

    return _call_through


# 8 MB of random data for the HTTP requests to return
DATA = bytes(random.randint(0, 0xFF)
for _ in range(0, 8*1024*1024))

OTHER_DATA = bytes(random.randint(0, 0xFF)
for _ in range(0, 8*1024*1024))

ASCII_LINES = ["Line0\n",
"Line the first\n",
"Line Returns\n",
"Line goes forth"]
ASCII_DATA = b''.join(line.encode('ascii') for line in ASCII_LINES)


IOBaseError = OSError


Expand Down Expand Up @@ -94,17 +81,18 @@ def setUp(self):

self.data_source = DATA
self.error_code = None
self.head_error_code = None

def _head(url, **kwargs):
m = AsyncContextManagerMock()
if self.error_code is None:
m.async_context_object.status_code = 204
if self.error_code is None and self.head_error_code is None:
m.async_context_object.status = 204
m.async_context_object.headers = {'content-length':
len(self.data_source),
'Accept-Ranges':
'bytes'}
else:
m.async_context_object.status_code = self.error_code
m.async_context_object.status = self.error_code or self.head_error_code
m.async_context_object.raise_for_status = mock.MagicMock(side_effect=HTTPException)
return m

Expand All @@ -122,14 +110,14 @@ def _get(*args, **kwargs):

if self.error_code is None:
return AsyncContextManagerMock(
async_context_object=mock.MagicMock(status_code=200,
async_context_object=mock.MagicMock(status=200,
read=mock.MagicMock(
side_effect=async_func(
lambda: self.data_source[start:end]))))
else:
return AsyncContextManagerMock(
async_context_object=mock.MagicMock(
status_code=self.error_code,
status=self.error_code,
raise_for_status=mock.MagicMock(side_effect=HTTPException)))
self.session.get.side_effect = _get

Expand Down Expand Up @@ -301,3 +289,27 @@ async def test_random_access(self):
async def test_seekable(self):
async with HTTPIOFile('http://www.example.com/test/', 1024) as io:
self.assertTrue(await io.seekable())

@async_test
async def test_ignores_head_error_when_no_head_request_set(self):
"""If the no_head_request flag is set, an error returned by HEAD should be ignored"""
self.head_error_code = 404
async with HTTPIOFile('http://www.example.com/test/', 1024, no_head_request=True):
pass

@async_test
async def test_throws_exception_when_get_returns_error_when_no_head_request_set(self):
    """With no_head_request set, opening the file should raise HTTPException when error_code is 404"""
    self.error_code = 404
    with self.assertRaises(HTTPException):
        async with HTTPIOFile('http://www.example.com/test/', 1024, no_head_request=True):
            pass

@async_test
async def test_retries_with_get_when_head_returns_403(self):
"""Test data can be read when the GET works but the HEAD request returns 403

This happens when given an S3 pre-signed URL, because they only support one method
"""
self.head_error_code = 403
async with HTTPIOFile('http://www.example.com/test/', 1024):
pass
63 changes: 41 additions & 22 deletions tests/test_sync_httpio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,11 @@
from io import SEEK_CUR, SEEK_END

import mock
import random
import re

from six import int2byte, PY3
from six import PY3


# 8 MB of random data for the HTTP requests to return
DATA = b''.join(int2byte(random.randint(0, 0xFF))
for _ in range(0, 8*1024*1024))

OTHER_DATA = b''.join(int2byte(random.randint(0, 0xFF))
for _ in range(0, 8*1024*1024))

ASCII_LINES = ["Line0\n",
"Line the first\n",
"Line Returns\n",
"Line goes forth"]
ASCII_DATA = b''.join(line.encode('ascii') for line in ASCII_LINES)
from random_source_data import DATA, OTHER_DATA, ASCII_DATA, ASCII_LINES


# The expected exception from unimplemented IOBase operations
Expand All @@ -51,16 +38,17 @@ def setUp(self):

self.data_source = DATA
self.error_code = None
self.head_error_code = None

def _head(url, **kwargs):
if self.error_code is None:
if self.error_code is None and self.head_error_code is None:
return mock.MagicMock(status_code=204,
headers={'Content-Length':
len(self.data_source),
'Accept-Ranges':
'bytes'})
else:
return mock.MagicMock(status_code=self.error_code,
return mock.MagicMock(status_code=self.head_error_code or self.error_code,
raise_for_status=mock.MagicMock(
side_effect=HTTPException))

Expand All @@ -77,12 +65,22 @@ def _get(url, **kwargs):
end = int(m.group(2)) + 1

if self.error_code is not None:
return mock.MagicMock(status_code=self.error_code,
raise_for_status=mock.MagicMock(
side_effect=HTTPException))
response_mock = mock.MagicMock(status_code=self.error_code,
raise_for_status=mock.MagicMock(side_effect=HTTPException))
response_mock.__enter__.return_value = response_mock

return response_mock
else:
return mock.MagicMock(status_code=200,
content=self.data_source[start:end])
content_length = (end or len(self.data_source)) - (start or 0)

response_mock = mock.MagicMock(status_code=200,
headers={
'Content-Length': content_length,
'Accept-Ranges': 'bytes'},
content=self.data_source[start:end])

response_mock.__enter__.return_value = response_mock
return response_mock

self.session.get.side_effect = _get

Expand Down Expand Up @@ -272,6 +270,27 @@ def test_writelines(self):
with self.assertRaises(IOBaseError):
io.writelines([line.encode('ascii') for line in ASCII_LINES])

def test_ignores_head_error_when_no_head_request_set(self):
"""If the no_head_request flag is set, an error returned by HEAD should be ignored"""
self.head_error_code = 404
with HTTPIOFile('http://www.example.com/test/', 1024, no_head_request=True):
pass

def test_throws_exception_when_get_returns_error_when_no_head_request_set(self):
    """With no_head_request set, opening the file should raise HTTPException when error_code is 404"""
    self.error_code = 404
    with self.assertRaises(HTTPException):
        with HTTPIOFile('http://www.example.com/test/', 1024, no_head_request=True):
            pass

def test_retries_with_get_when_head_returns_403(self):
"""Test data can be read when the GET works but the HEAD request returns 403

This happens when given an S3 pre-signed URL, because they only support one method
"""
self.head_error_code = 403
with HTTPIOFile('http://www.example.com/test/', 1024):
pass


if __name__ == "__main__":
unittest.main()