Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add streaming decompression for ZSTD_CONTENTSIZE_UNKNOWN case #707

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 128 additions & 5 deletions numcodecs/zstd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from cpython.buffer cimport PyBUF_ANY_CONTIGUOUS, PyBUF_WRITEABLE
from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING


from .compat_ext cimport Buffer
from .compat_ext import Buffer
from .compat import ensure_contiguous_ndarray
from .abc import Codec

from libc.stdlib cimport malloc, realloc, free

cdef extern from "zstd.h":

Expand All @@ -22,6 +22,23 @@ cdef extern from "zstd.h":
struct ZSTD_CCtx_s:
pass
ctypedef ZSTD_CCtx_s ZSTD_CCtx

struct ZSTD_DStream_s:
pass
ctypedef ZSTD_DStream_s ZSTD_DStream

struct ZSTD_inBuffer_s:
const void* src
size_t size
size_t pos
ctypedef ZSTD_inBuffer_s ZSTD_inBuffer

struct ZSTD_outBuffer_s:
void* dst
size_t size
size_t pos
ctypedef ZSTD_outBuffer_s ZSTD_outBuffer

cdef enum ZSTD_cParameter:
ZSTD_c_compressionLevel=100
ZSTD_c_checksumFlag=201
Expand All @@ -37,12 +54,20 @@ cdef extern from "zstd.h":
size_t dstCapacity,
const void* src,
size_t srcSize) nogil

size_t ZSTD_decompress(void* dst,
size_t dstCapacity,
const void* src,
size_t compressedSize) nogil

size_t ZSTD_decompressStream(ZSTD_DStream* zds,
ZSTD_outBuffer* output,
ZSTD_inBuffer* input) nogil

size_t ZSTD_DStreamOutSize() nogil
ZSTD_DStream* ZSTD_createDStream() nogil
size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil
size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil

cdef long ZSTD_CONTENTSIZE_UNKNOWN
cdef long ZSTD_CONTENTSIZE_ERROR
unsigned long long ZSTD_getFrameContentSize(const void* src,
Expand All @@ -56,7 +81,7 @@ cdef extern from "zstd.h":

unsigned ZSTD_isError(size_t code) nogil

const char* ZSTD_getErrorName(size_t code)
const char* ZSTD_getErrorName(size_t code) nogil


VERSION_NUMBER = ZSTD_versionNumber()
Expand Down Expand Up @@ -156,7 +181,8 @@ def decompress(source, dest=None):
source : bytes-like
Compressed data. Can be any object supporting the buffer protocol.
dest : array-like, optional
Object to decompress into.
Object to decompress into. If the content size is unknown, the
length of dest must match the decompressed size.

Returns
-------
Expand All @@ -180,9 +206,12 @@ def decompress(source, dest=None):

# determine uncompressed size
dest_size = ZSTD_getFrameContentSize(source_ptr, source_size)
if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_UNKNOWN or dest_size == ZSTD_CONTENTSIZE_ERROR:
if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR:
raise RuntimeError('Zstd decompression error: invalid input data')

if dest_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
return stream_decompress(source_buffer)

# setup destination buffer
if dest is None:
# allocate memory
Expand All @@ -192,6 +221,8 @@ def decompress(source, dest=None):
arr = ensure_contiguous_ndarray(dest)
dest_buffer = Buffer(arr, PyBUF_ANY_CONTIGUOUS | PyBUF_WRITEABLE)
dest_ptr = dest_buffer.ptr
if dest_size == ZSTD_CONTENTSIZE_UNKNOWN:
dest_size = dest_buffer.nbytes
if dest_buffer.nbytes < dest_size:
raise ValueError('destination buffer too small; expected at least %s, '
'got %s' % (dest_size, dest_buffer.nbytes))
Expand All @@ -217,6 +248,98 @@ def decompress(source, dest=None):

return dest

cdef stream_decompress(Buffer source_buffer):
"""Decompress data of unknown size

Parameters
----------
source : Buffer
Compressed data buffer

Returns
-------
dest : bytes
Object containing decompressed data.
"""

cdef:
char *source_ptr
void *dest_ptr
void *new_dst
Buffer dest_buffer = None
size_t source_size, dest_size, decompressed_size
size_t DEST_GROWTH_SIZE, status
ZSTD_inBuffer input
ZSTD_outBuffer output
ZSTD_DStream *zds

# Recommended size for output buffer, guaranteed to flush at least
# one completely block in all circumstances
DEST_GROWTH_SIZE = ZSTD_DStreamOutSize();

source_ptr = source_buffer.ptr
source_size = source_buffer.nbytes

# unknown content size, guess it is twice the size as the source
dest_size = source_size * 2

if dest_size < DEST_GROWTH_SIZE:
# minimum dest_size is DEST_GROWTH_SIZE
dest_size = DEST_GROWTH_SIZE

dest_ptr = malloc(dest_size)
zds = ZSTD_createDStream()

try:

with nogil:

status = ZSTD_initDStream(zds)
if ZSTD_isError(status):
error = ZSTD_getErrorName(status)
ZSTD_freeDStream(zds);
raise RuntimeError('Zstd stream decompression error on ZSTD_initDStream: %s' % error)

input = ZSTD_inBuffer(source_ptr, source_size, 0)
output = ZSTD_outBuffer(dest_ptr, dest_size, 0)

# Initialize to 1 to force a loop iteration
status = 1
while(status > 0 or input.pos < input.size):
# Possible returned values of ZSTD_decompressStream:
# 0: frame is completely decoded and fully flushed
# error (<0)
# >0: suggested next input size
status = ZSTD_decompressStream(zds, &output, &input)

if ZSTD_isError(status):
error = ZSTD_getErrorName(status)
raise RuntimeError('Zstd stream decompression error on ZSTD_decompressStream: %s' % error)

# There is more to decompress, grow the buffer
if status > 0 and output.pos == output.size:
new_size = output.size + DEST_GROWTH_SIZE

if new_size < output.size or new_size < DEST_GROWTH_SIZE:
raise RuntimeError('Zstd stream decompression error: output buffer overflow')

new_dst = realloc(output.dst, new_size)

if new_dst == NULL:
# output.dst freed in finally block
raise RuntimeError('Zstd stream decompression error on realloc: could not expand output buffer')

output.dst = new_dst
output.size = new_size

# Copy the output to a bytes object
dest = PyBytes_FromStringAndSize(<char *>output.dst, output.pos)

finally:
ZSTD_freeDStream(zds)
free(output.dst)

return dest

class Zstd(Codec):
"""Codec providing compression using Zstandard.
Expand Down
Loading