diff --git a/src/main/java/io/airlift/compress/zlib/InflateDecompressor.java b/src/main/java/io/airlift/compress/zlib/InflateDecompressor.java new file mode 100644 index 00000000..74eeb826 --- /dev/null +++ b/src/main/java/io/airlift/compress/zlib/InflateDecompressor.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.airlift.compress.zlib; + +import io.airlift.compress.Decompressor; +import io.airlift.compress.MalformedInputException; + +import java.nio.Buffer; +import java.nio.ByteBuffer; + +import static io.airlift.compress.zlib.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; + +public class InflateDecompressor + implements Decompressor +{ + @Override + public int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) + throws MalformedInputException + { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; + long inputLimit = inputAddress + inputLength; + long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; + long outputLimit = outputAddress + maxOutputLength; + + return InflateRawDecompressor.decompress(input, inputAddress, inputLimit, output, outputAddress, outputLimit); + } + + @Override + public void decompress(ByteBuffer inputBuffer, ByteBuffer outputBuffer) + throws MalformedInputException + { + // Java 9+ added an overload of various methods in ByteBuffer. When compiling with Java 11+ and targeting Java 8 bytecode + // the resulting signatures are invalid for JDK 8, so accesses below result in NoSuchMethodError. Accessing the + // methods through the interface class works around the problem + // Sidenote: we can't target "javac --release 8" because Unsafe is not available in the signature data for that profile + Buffer input = inputBuffer; + Buffer output = outputBuffer; + + Object inputBase; + long inputAddress; + long inputLimit; + if (input.isDirect()) { + inputBase = null; + long address = getAddress(input); + inputAddress = address + input.position(); + inputLimit = address + input.limit(); + } + else if (input.hasArray()) { + inputBase = input.array(); + inputAddress = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.position(); + inputLimit = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.limit(); + } + else { + throw new IllegalArgumentException("Unsupported input ByteBuffer implementation " + input.getClass().getName()); + } + + Object outputBase; + long outputAddress; + long outputLimit; + if (output.isDirect()) { + outputBase = null; + long address = getAddress(output); + outputAddress = address + output.position(); + outputLimit = address + output.limit(); + } + else if (output.hasArray()) { + outputBase = output.array(); + outputAddress = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.position(); + outputLimit = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.limit(); + } + else { + throw new IllegalArgumentException("Unsupported output ByteBuffer implementation " + output.getClass().getName()); + } + + // HACK: Assure JVM does not collect Slice wrappers while decompressing, since the + // collection may trigger freeing of the underlying memory resulting in a segfault + // There is no other known way to signal to the JVM that an object should not be + // collected in a block, and technically, the JVM is allowed to eliminate these locks. + synchronized (input) { + synchronized (output) { + int written = InflateRawDecompressor.decompress(inputBase, inputAddress, inputLimit, outputBase, outputAddress, outputLimit); + output.position(output.position() + written); + } + } + } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } +} diff --git a/src/main/java/io/airlift/compress/zlib/InflateRawDecompressor.java b/src/main/java/io/airlift/compress/zlib/InflateRawDecompressor.java new file mode 100644 index 00000000..63b7d446 --- /dev/null +++ b/src/main/java/io/airlift/compress/zlib/InflateRawDecompressor.java @@ -0,0 +1,256 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.airlift.compress.zlib; + +import io.airlift.compress.MalformedInputException; +import io.airlift.compress.zlib.InflateTables.CodeType; +import io.airlift.compress.zlib.InflateTables.InflateTable; + +import static io.airlift.compress.zlib.InflateTables.END_OF_BLOCK; +import static io.airlift.compress.zlib.InflateTables.ENOUGH_DISTANCES; +import static io.airlift.compress.zlib.InflateTables.ENOUGH_LENGTHS; +import static io.airlift.compress.zlib.InflateTables.INVALID_CODE; +import static io.airlift.compress.zlib.InflateTables.buildCodeTable; +import static io.airlift.compress.zlib.InflateTables.extractBits; +import static io.airlift.compress.zlib.InflateTables.extractOp; +import static io.airlift.compress.zlib.InflateTables.extractValue; +import static java.lang.Math.toIntExact; + +// This implementation is based on zlib by Jean-loup Gailly and Mark Adler +public final class InflateRawDecompressor +{ + private static final int NON_COMPRESSED = 0; + private static final int FIXED_HUFFMAN = 1; + private static final int DYNAMIC_HUFFMAN = 2; + + private static final int MAX_LENGTH_CODES = 286; // max number of literal/length codes + private static final int MAX_DISTANCE_CODES = 30; // max number of distance codes + + private InflateRawDecompressor() {} + + public static int decompress(Object inputBase, long inputAddress, long inputLimit, Object outputBase, long outputAddress, long outputLimit) + throws MalformedInputException + { + InputReader reader = new InputReader(inputBase, inputAddress, inputLimit); + OutputWriter writer = new OutputWriter(outputBase, outputAddress, outputLimit); + + boolean last; + do { + last = reader.bits(1) == 1; + int type = reader.bits(2); + + switch (type) { + case NON_COMPRESSED: + nonCompressed(reader, writer); + break; + case FIXED_HUFFMAN: + fixedHuffman(reader, writer); + break; + case DYNAMIC_HUFFMAN: + dynamicHuffman(reader, writer); + break; + default: + throw new MalformedInputException(reader.offset(), "Invalid block type: " + type); + } + } + while (!last); + + if (reader.available() > 0) { + throw new MalformedInputException(reader.offset(), "Output buffer too small"); + } + + return toIntExact(writer.offset()); + } + + private static void nonCompressed(InputReader reader, OutputWriter writer) + { + reader.clear(); + + int lsb = reader.readByte(); + int msb = reader.readByte(); + + int checkLsb = reader.readByte(); + int checkMsb = reader.readByte(); + + if ((lsb != (~checkLsb & 0xFF)) || (msb != (~checkMsb & 0xFF))) { + throw new MalformedInputException(reader.offset(), "Block length does not match complement"); + } + + int length = (msb << 8) | lsb; + + writer.copyInput(reader, length); + } + + private static void fixedHuffman(InputReader reader, OutputWriter writer) + { + inflate(InflateTables.FIXED_TABLE, reader, writer); + } + + private static final short[] CODE_LENGTHS_ORDER = { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15, + }; + + private static void dynamicHuffman(InputReader reader, OutputWriter writer) + { + int lengthSize = reader.bits(5) + 257; + int distanceSize = reader.bits(5) + 1; + int codeSize = reader.bits(4) + 4; + if (lengthSize > MAX_LENGTH_CODES) { + throw new MalformedInputException(reader.offset(), "Length count is too large: " + lengthSize); + } + if (distanceSize > MAX_DISTANCE_CODES) { + throw new MalformedInputException(reader.offset(), "Distance count is too large: " + distanceSize); + } + + short[] codeLengths = new short[19]; + for (int i = 0; i < codeSize; i++) { + codeLengths[CODE_LENGTHS_ORDER[i]] = (short) reader.bits(3); + } + + int[] codeCode = new int[388]; + int codeBits = buildCodeTable(CodeType.CODES, codeLengths, 0, 19, 7, codeCode); + + short[] lengths = new short[MAX_LENGTH_CODES + MAX_DISTANCE_CODES]; + + int index = 0; + while (index < (lengthSize + distanceSize)) { + int code = codeCode[reader.peek(codeBits)]; + reader.skip(extractBits(code)); + short value = extractValue(code); + + if (value < 16) { + lengths[index] = value; + index++; + continue; + } + + short length = 0; + int copy; + if (value == 16) { + if (index == 0) { + throw new MalformedInputException(reader.offset(), "No previous length for repeat"); + } + length = lengths[index - 1]; + copy = reader.bits(2) + 3; + } + else if (value == 17) { + copy = reader.bits(3) + 3; + } + else { + copy = reader.bits(7) + 11; + } + + if ((index + copy) > (lengthSize + distanceSize)) { + throw new MalformedInputException(reader.offset(), "Too many lengths for repeat"); + } + + while (copy > 0) { + lengths[index] = length; + index++; + copy--; + } + } + + if (lengths[256] == 0) { + throw new MalformedInputException(reader.offset(), "Missing end-of-block code"); + } + + int[] lengthCode = new int[ENOUGH_LENGTHS]; + int lengthBits = buildCodeTable(CodeType.LENGTHS, lengths, 0, lengthSize, 9, lengthCode); + + int[] distanceCode = new int[ENOUGH_DISTANCES]; + int distanceBits = buildCodeTable(CodeType.DISTANCES, lengths, lengthSize, distanceSize, 6, distanceCode); + + InflateTable table = new InflateTable(lengthCode, lengthBits, distanceCode, distanceBits); + + inflate(table, reader, writer); + } + + private static void inflate(InflateTable table, InputReader reader, OutputWriter writer) + { + int tableLengthBits = table.lengthBits; + int tableLengthMask = table.lengthMask; + int[] lengths = table.lengthCode; + + int tableDistanceBits = table.distanceBits; + int tableDistanceMask = table.distanceMask; + int[] distances = table.distanceCode; + + // decode literals and length/distances until end-of-block + while (true) { + int lengthIndex = reader.peek(tableLengthBits, tableLengthMask); + + while (true) { + int packedLength = lengths[lengthIndex]; + reader.skip(extractBits(packedLength)); + int lengthOp = extractOp(packedLength); + int length = extractValue(packedLength); + + if (lengthOp == 0) { + // literal + writer.writeByte(reader, (byte) length); + break; + } + + if ((lengthOp & 0b0001_0000) != 0) { + // length base + int lengthBits = lengthOp & 0b1111; + if (lengthBits > 0) { + length += reader.bits(lengthBits); + } + + int distanceIndex = reader.peek(tableDistanceBits, tableDistanceMask); + while (true) { + int packedDistance = distances[distanceIndex]; + reader.skip(extractBits(packedDistance)); + int distanceOp = extractOp(packedDistance); + int distance = extractValue(packedDistance); + + if ((distanceOp & 0b0001_0000) != 0) { + // distance base + int distanceBits = distanceOp & 0b1111; + if (distanceBits > 0) { + distance += reader.bits(distanceBits); + } + writer.copyOutput(reader, distance, length); + break; + } + + if ((distanceOp & INVALID_CODE) == 0) { + // second level distance code + distanceIndex = distance + reader.peek(distanceOp); + continue; + } + + throw new MalformedInputException(reader.offset(), "Invalid distance code"); + } + break; + } + + if ((lengthOp & INVALID_CODE) == 0) { + // second level length code + lengthIndex = length + reader.peek(lengthOp); + continue; + } + + if ((lengthOp & END_OF_BLOCK) != 0) { + // end-of-block + return; + } + + throw new MalformedInputException(reader.offset(), "Invalid length/literal code"); + } + } + } +} diff --git a/src/main/java/io/airlift/compress/zlib/InflateTables.java b/src/main/java/io/airlift/compress/zlib/InflateTables.java new file mode 100644 index 00000000..06d72e3a --- /dev/null +++ b/src/main/java/io/airlift/compress/zlib/InflateTables.java @@ -0,0 +1,338 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.airlift.compress.zlib; + +import static io.airlift.compress.zlib.InputReader.mask; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.util.Arrays.fill; + +// This implementation is based on zlib by Jean-loup Gailly and Mark Adler +final class InflateTables +{ + // Length codes 257..285 base + private static final short[] LENGTH_CODES_BASE = { + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0, + }; + + // Length codes 257..285 extra + private static final short[] LENGTH_CODES_EXTRA = { + 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 18, 18, 18, 18, + 19, 19, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 16, 194, 65, + }; + + // Distance codes 0..29 base + private static final short[] DISTANCE_CODES_BASE = { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, + 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, + 8193, 12289, 16385, 24577, 0, 0, + }; + + // Distance codes 0..29 extra + private static final short[] DISTANCE_CODES_EXTRA = { + 16, 16, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, + 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, + 28, 28, 29, 29, 64, 64, + }; + + public static final int MAX_BITS = 15; // max bits in a code + public static final int MAX_COUNTS = MAX_BITS + 1; // max number of counts + + public static final int ENOUGH_LENGTHS = 852; // max size of lengths dynamic table + public static final int ENOUGH_DISTANCES = 592; // max size of distances dynamic table + + public static final byte END_OF_BLOCK = 0b0010_0000; + public static final byte INVALID_CODE = 0b0100_0000; + + public static final InflateTable FIXED_TABLE; + + static { + int[] lengthFixed = new int[512]; + int[] distanceFixed = new int[32]; + + short[] lengths = new short[288]; + fill(lengths, 0, 144, (short) 8); + fill(lengths, 144, 256, (short) 9); + fill(lengths, 256, 280, (short) 7); + fill(lengths, 280, 288, (short) 8); + buildCodeTable(CodeType.LENGTHS, lengths, 0, 288, 9, lengthFixed); + + lengths = new short[32]; + fill(lengths, (short) 5); + buildCodeTable(CodeType.DISTANCES, lengths, 0, 32, 5, distanceFixed); + + FIXED_TABLE = new InflateTable(lengthFixed, 9, distanceFixed, 5); + } + + private InflateTables() {} + + public static int buildCodeTable(CodeType type, short[] lengths, int offset, int codes, int bits, int[] table) + { + short[] counts = new short[MAX_COUNTS]; + short[] offsets = new short[MAX_COUNTS]; + + // accumulate lengths for codes -- assumes lengths all in [0, MAX_BITS] + for (int i = 0; i < codes; i++) { + counts[lengths[offset + i]]++; + } + + // bound code lengths, force root to be within code lengths + int root = bits; + int max; + for (max = MAX_BITS; max >= 1; max--) { + if (counts[max] != 0) { + break; + } + } + root = min(root, max); + + if (max == 0) { + // no symbols to code at all + // make a table to force an error + table[0] = tableEntry(INVALID_CODE, (byte) 1, (short) 0); + table[1] = table[0]; + + // no symbols, but wait for decoding to report error + return 1; + } + + int min; + for (min = 1; min < max; min++) { + if (counts[min] != 0) { + break; + } + } + root = max(root, min); + + // check for an over-subscribed or incomplete set of lengths + int left = 1; + for (int length = 1; length < MAX_COUNTS; length++) { + left <<= 1; + left -= counts[length]; + if (left < 0) { + throw new IllegalArgumentException("over-subscribed lengths"); + } + } + if ((left > 0) && ((type == CodeType.CODES) || (max != 1))) { + throw new IllegalArgumentException("incomplete set of lengths"); + } + + // generate offsets into symbol table for each length for sorting + offsets[1] = 0; + for (short length = 1; length < MAX_BITS; length++) { + offsets[length + 1] = (short) (offsets[length] + counts[length]); + } + + // sort symbols by length, by symbol order within each length + short[] symbols = new short[codes]; + for (short symbol = 0; symbol < codes; symbol++) { + short length = lengths[offset + symbol]; + if (length != 0) { + symbols[offsets[length]] = symbol; + offsets[length]++; + } + } + + // set up for code type + short[] base; + short[] extra; + int match; + switch (type) { + case CODES: + base = new short[0]; + extra = new short[0]; + match = 20; + break; + case LENGTHS: + base = LENGTH_CODES_BASE; + extra = LENGTH_CODES_EXTRA; + match = 257; + break; + case DISTANCES: + base = DISTANCE_CODES_BASE; + extra = DISTANCE_CODES_EXTRA; + match = 0; + break; + default: + throw new AssertionError(); + } + + // initialize state for loop + int huffman = 0; // starting code + int symbol = 0; // starting code symbol + int length = min; // starting code length + int next = 0; // current table to fill in + int current = root; // current table index bits + int drop = 0; // current bits to drop from code for index + int low = -1; // trigger new sub-table when length > root + int used = 1 << root; // used root table entries + int mask = used - 1; // mask for comparing low + + // check available table space + checkAvailableSpace(type, used); + + // process all codes and make table entries + while (true) { + // create table entry + byte hereBits = (byte) (length - drop); + byte hereOp; + short hereValue; + short value = symbols[symbol]; + if ((value + 1) < match) { + hereOp = 0; + hereValue = value; + } + else if (value >= match) { + hereOp = (byte) extra[value - match]; + hereValue = base[value - match]; + } + else { + hereOp = END_OF_BLOCK | INVALID_CODE; + hereValue = 0; + } + + // replicate for those indices with low length bits equal to huffman + int increment = 1 << (length - drop); + int fill = 1 << current; + min = fill; // save offset to next table + do { + fill -= increment; + int index = next + (huffman >> drop) + fill; + table[index] = tableEntry(hereOp, hereBits, hereValue); + } + while (fill != 0); + + // backwards increment the length-bit code huffman + increment = 1 << (length - 1); + while ((huffman & increment) != 0) { + increment >>= 1; + } + if (increment == 0) { + huffman = 0; + } + else { + huffman &= increment - 1; + huffman += increment; + } + + // go to next symbol, update count, length + symbol++; + counts[length]--; + if (counts[length] == 0) { + if (length == max) { + break; + } + length = lengths[offset + symbols[symbol]]; + } + + // create new sub-table if needed + if ((length > root) && ((huffman & mask) != low)) { + // if first time, transition to sub-tables + if (drop == 0) { + drop = root; + } + + // increment past last table (min is 1 << current) + next += min; + + // determine length of next table + current = length - drop; + left = 1 << current; + while ((current + drop) < max) { + left -= counts[current + drop]; + if (left <= 0) { + break; + } + current++; + left <<= 1; + } + + // check for enough space + used += 1 << current; + checkAvailableSpace(type, used); + + // point entry in root table to sub-table + low = huffman & mask; + table[low] = tableEntry((byte) current, (byte) root, (short) next); + } + } + + // fill in remaining table entry if code is incomplete (guaranteed to have + // at most one remaining entry, since if the code is incomplete, the + // maximum code length that was allowed to get this far is one bit) + if (huffman != 0) { + table[next + huffman] = tableEntry(INVALID_CODE, (byte) (length - drop), (short) 0); + } + + return root; + } + + private static void checkAvailableSpace(CodeType type, int used) + { + if ((type == CodeType.LENGTHS) && (used > ENOUGH_LENGTHS)) { + throw new IllegalArgumentException("too many lengths"); + } + if ((type == CodeType.DISTANCES) && (used > ENOUGH_DISTANCES)) { + throw new IllegalArgumentException("too many distances"); + } + } + + public static byte extractOp(int packed) + { + return (byte) ((packed >> 24) & 0xFF); + } + + public static byte extractBits(int packed) + { + return (byte) ((packed >> 16) & 0xFF); + } + + public static short extractValue(int packed) + { + return (short) (packed & 0xFFFF); + } + + private static int tableEntry(byte op, byte bits, short value) + { + return ((op & 0xFF) << 24) | ((bits & 0xFF) << 16) | (value & 0xFFFF); + } + + @SuppressWarnings("PublicField") + public static class InflateTable + { + public final int[] lengthCode; + public final int lengthBits; + public final int lengthMask; + public final int[] distanceCode; + public final int distanceBits; + public final int distanceMask; + + @SuppressWarnings("AssignmentOrReturnOfFieldWithMutableType") + public InflateTable(int[] lengthCode, int lengthBits, int[] distanceCode, int distanceBits) + { + this.lengthCode = lengthCode; + this.lengthBits = lengthBits; + this.lengthMask = mask(lengthBits); + this.distanceCode = distanceCode; + this.distanceBits = distanceBits; + this.distanceMask = mask(distanceBits); + } + } + + public enum CodeType + { + CODES, LENGTHS, DISTANCES + } +} diff --git a/src/main/java/io/airlift/compress/zlib/InputReader.java b/src/main/java/io/airlift/compress/zlib/InputReader.java new file mode 100644 index 00000000..5b8bc674 --- /dev/null +++ b/src/main/java/io/airlift/compress/zlib/InputReader.java @@ -0,0 +1,147 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.airlift.compress.zlib; + +import io.airlift.compress.MalformedInputException; + +import java.util.StringJoiner; + +import static io.airlift.compress.zlib.UnsafeUtil.UNSAFE; + +final class InputReader +{ + private final Object inputBase; + private final long inputAddress; + private final long inputLimit; + + private long inputPosition; + private int bitCount; + private int bitBuffer; + + public InputReader(Object inputBase, long inputAddress, long inputLimit) + { + checkArgument(inputAddress >= 0, "inputAddress is negative"); + checkArgument(inputAddress <= inputLimit, "inputAddress exceeds inputLimit"); + if (inputAddress == inputLimit) { + throw new MalformedInputException(0, "Input is empty"); + } + + this.inputBase = inputBase; + this.inputAddress = inputAddress; + this.inputLimit = inputLimit; + this.inputPosition = inputAddress; + } + + public int bits(int need) + { + int accumulator = bitBuffer; + + while (bitCount < need) { + int octet = readByte(); + accumulator |= octet << bitCount; + bitCount += 8; + } + + bitBuffer = accumulator >> need; + bitCount -= need; + + return accumulator & mask(need); + } + + public int peek(int need) + { + return peek(need, mask(need)); + } + + public int peek(int need, int mask) + { + while (bitCount < need) { + if (available() == 0) { + return bitBuffer & mask(bitCount); + } + bitBuffer |= readByte() << bitCount; + bitCount += 8; + } + + return bitBuffer & mask; + } + + public void skip(int need) + { + bits(need); + } + + public int readByte() + { + if (inputPosition >= inputLimit) { + throw new MalformedInputException(offset(), "Input is truncated"); + } + + int octet = UNSAFE.getByte(inputBase, inputPosition) & 0xFF; + inputPosition++; + + return octet; + } + + public long offset() + { + return inputPosition - inputAddress; + } + + public long available() + { + return inputLimit - inputPosition; + } + + public void clear() + { + if (bitCount >= 8) { + throw new MalformedInputException(offset(), "Too many partial bits: " + bitCount); + } + + bitCount = 0; + bitBuffer = 0; + } + + public void copyMemory(Object outputBase, long outputPosition, int length) + { + if (available() < length) { + throw new MalformedInputException(offset(), "Input is truncated"); + } + + UNSAFE.copyMemory(inputBase, inputPosition, outputBase, outputPosition, length); + inputPosition += length; + } + + @Override + public String toString() + { + return new StringJoiner(", ", getClass().getSimpleName() + "{", "}") + .add("offset=" + offset()) + .add("available=" + available()) + .toString(); + } + + public static int mask(int bits) + { + return (1 << bits) - 1; + } + + private static void checkArgument(boolean condition, String message) + { + if (!condition) { + throw new IllegalArgumentException(message); + } + } +} diff --git a/src/main/java/io/airlift/compress/zlib/OutputWriter.java b/src/main/java/io/airlift/compress/zlib/OutputWriter.java new file mode 100644 index 00000000..0190c057 --- /dev/null +++ b/src/main/java/io/airlift/compress/zlib/OutputWriter.java @@ -0,0 +1,171 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.airlift.compress.zlib; + +import io.airlift.compress.MalformedInputException; + +import java.util.StringJoiner; + +import static io.airlift.compress.zlib.UnsafeUtil.UNSAFE; + +final class OutputWriter +{ + private static final int SIZE_OF_INT = 4; + private static final int SIZE_OF_LONG = 8; + + private static final int[] DEC_32_TABLE = {4, 1, 2, 1, 4, 4, 4, 4}; + private static final int[] DEC_64_TABLE = {0, 0, 0, -1, 0, 1, 2, 3}; + + private final Object outputBase; + private final long outputAddress; + private final long outputLimit; + private final long fastOutputLimit; + + private long outputPosition; + + public OutputWriter(Object outputBase, long outputAddress, long outputLimit) + { + checkArgument(outputAddress >= 0, "outputAddress is negative"); + checkArgument(outputAddress <= outputLimit, "outputAddress exceeds outputLimit"); + + this.outputBase = outputBase; + this.outputAddress = outputAddress; + this.outputLimit = outputLimit; + this.fastOutputLimit = outputLimit - SIZE_OF_LONG; + this.outputPosition = outputAddress; + } + + public long offset() + { + return outputPosition - outputAddress; + } + + public long available() + { + return outputLimit - outputPosition; + } + + public void writeByte(InputReader reader, byte value) + { + if (available() == 0) { + throw new MalformedInputException(reader.offset(), "Output buffer is too small"); + } + + UNSAFE.putByte(outputBase, outputPosition, value); + outputPosition++; + } + + public void copyInput(InputReader reader, int length) + { + if (length > available()) { + throw new MalformedInputException(reader.offset(), "Output buffer too small"); + } + + reader.copyMemory(outputBase, outputPosition, length); + outputPosition += length; + } + + public void copyOutput(InputReader reader, int distance, int length) + { + if (length > available()) { + throw new MalformedInputException(reader.offset(), "Output buffer too small"); + } + if (distance > offset()) { + throw new MalformedInputException(reader.offset(), "Distance is too far back"); + } + + long matchAddress = outputPosition - distance; + + if (distance >= length) { + UNSAFE.copyMemory(outputBase, matchAddress, outputBase, outputPosition, length); + outputPosition += length; + return; + } + + long matchOutputLimit = outputPosition + length; + + if (outputPosition > fastOutputLimit) { + // slow match copy + while (outputPosition < matchOutputLimit) { + UNSAFE.putByte(outputBase, outputPosition, UNSAFE.getByte(outputBase, matchAddress)); + matchAddress++; + outputPosition++; + } + return; + } + + // copy repeated sequence + if (distance < SIZE_OF_LONG) { + // 8 bytes apart so that we can copy long-at-a-time below + int increment32 = DEC_32_TABLE[distance]; + int decrement64 = DEC_64_TABLE[distance]; + + UNSAFE.putByte(outputBase, outputPosition, UNSAFE.getByte(outputBase, matchAddress)); + UNSAFE.putByte(outputBase, outputPosition + 1, UNSAFE.getByte(outputBase, matchAddress + 1)); + UNSAFE.putByte(outputBase, outputPosition + 2, UNSAFE.getByte(outputBase, matchAddress + 2)); + UNSAFE.putByte(outputBase, outputPosition + 3, UNSAFE.getByte(outputBase, matchAddress + 3)); + outputPosition += SIZE_OF_INT; + matchAddress += increment32; + + UNSAFE.putInt(outputBase, outputPosition, UNSAFE.getInt(outputBase, matchAddress)); + outputPosition += SIZE_OF_INT; + matchAddress -= decrement64; + } + else { + UNSAFE.putLong(outputBase, outputPosition, UNSAFE.getLong(outputBase, matchAddress)); + matchAddress += SIZE_OF_LONG; + outputPosition += SIZE_OF_LONG; + } + + if (matchOutputLimit > fastOutputLimit) { + while (outputPosition < fastOutputLimit) { + UNSAFE.putLong(outputBase, outputPosition, UNSAFE.getLong(outputBase, matchAddress)); + matchAddress += SIZE_OF_LONG; + outputPosition += SIZE_OF_LONG; + } + + while (outputPosition < matchOutputLimit) { + UNSAFE.putByte(outputBase, outputPosition, UNSAFE.getByte(outputBase, matchAddress)); + matchAddress++; + outputPosition++; + } + } + else { + while (outputPosition < matchOutputLimit) { + UNSAFE.putLong(outputBase, outputPosition, UNSAFE.getLong(outputBase, matchAddress)); + matchAddress += SIZE_OF_LONG; + outputPosition += SIZE_OF_LONG; + } + } + + // correction in case we over-copied + outputPosition = matchOutputLimit; + } + + @Override + public String toString() + { + return new StringJoiner(", ", getClass().getSimpleName() + "{", "}") + .add("offset=" + offset()) + .add("available=" + available()) + .toString(); + } + + private static void checkArgument(boolean condition, String message) + { + if (!condition) { + throw new IllegalArgumentException(message); + } + } +} diff --git a/src/main/java/io/airlift/compress/zlib/UnsafeUtil.java b/src/main/java/io/airlift/compress/zlib/UnsafeUtil.java new file mode 100644 index 00000000..c3211152 --- /dev/null +++ b/src/main/java/io/airlift/compress/zlib/UnsafeUtil.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.airlift.compress.zlib; + +import io.airlift.compress.IncompatibleJvmException; +import sun.misc.Unsafe; + +import java.lang.reflect.Field; +import java.nio.Buffer; +import java.nio.ByteOrder; + +import static java.lang.String.format; + +final class UnsafeUtil +{ + public static final Unsafe UNSAFE; + private static final long ADDRESS_OFFSET; + + private UnsafeUtil() {} + + static { + ByteOrder order = ByteOrder.nativeOrder(); + if (!order.equals(ByteOrder.LITTLE_ENDIAN)) { + throw new IncompatibleJvmException(format("zlib requires a little endian platform (found %s)", order)); + } + + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + UNSAFE = (Unsafe) theUnsafe.get(null); + } + catch (Exception e) { + throw new IncompatibleJvmException("zlib requires access to sun.misc.Unsafe"); + } + + try { + // fetch the address field for direct buffers + ADDRESS_OFFSET = UNSAFE.objectFieldOffset(Buffer.class.getDeclaredField("address")); + } + catch (NoSuchFieldException e) { + throw new IncompatibleJvmException("zlib requires access to java.nio.Buffer raw address field"); + } + } + + public static long getAddress(Buffer buffer) + { + if (!buffer.isDirect()) { + throw new IllegalArgumentException("buffer is not direct"); + } + + return UNSAFE.getLong(buffer, ADDRESS_OFFSET); + } +} diff --git a/src/test/java/io/airlift/compress/benchmark/Algorithm.java b/src/test/java/io/airlift/compress/benchmark/Algorithm.java index d76f48ac..77912458 100644 --- a/src/test/java/io/airlift/compress/benchmark/Algorithm.java +++ b/src/test/java/io/airlift/compress/benchmark/Algorithm.java @@ -38,6 +38,7 @@ import io.airlift.compress.thirdparty.XerialSnappyDecompressor; import io.airlift.compress.thirdparty.ZstdJniCompressor; import io.airlift.compress.thirdparty.ZstdJniDecompressor; +import io.airlift.compress.zlib.InflateDecompressor; import io.airlift.compress.zstd.ZstdCompressor; import io.airlift.compress.zstd.ZstdDecompressor; import net.jpountz.lz4.LZ4Factory; @@ -50,6 +51,7 @@ public enum Algorithm airlift_lz4(new Lz4Decompressor(), new Lz4Compressor()), airlift_snappy(new SnappyDecompressor(), new SnappyCompressor()), airlift_lzo(new LzoDecompressor(), new LzoCompressor()), + airlift_zlib(new InflateDecompressor(), new JdkDeflateCompressor()), airlift_zstd(new ZstdDecompressor(), new ZstdCompressor()), airlift_lz4_stream(new Lz4Codec(), new Lz4Compressor()), diff --git a/src/test/java/io/airlift/compress/zlib/TestZlib.java b/src/test/java/io/airlift/compress/zlib/TestZlib.java index 7430cd6d..c7d8aead 100644 --- a/src/test/java/io/airlift/compress/zlib/TestZlib.java +++ b/src/test/java/io/airlift/compress/zlib/TestZlib.java @@ -31,7 +31,7 @@ protected Compressor getCompressor() @Override protected Decompressor getDecompressor() { - return new JdkInflateDecompressor(); + return new InflateDecompressor(); } @Override