Commit e1d2ddd5 authored by Nick Terrell's avatar Nick Terrell Committed by Facebook Github Bot

Add zstd streaming interface

Summary:
* Add streaming interface to the `ZstdCodec`
* Implement `ZstdCodec::doCompress()` and `ZstdCodec::doUncompress()` using the streaming interface.
  [fbgs CodecType::ZSTD](https://fburl.com/pr8chg64) and check that no caller requires thread-safety.

Reviewed By: yfeldblum

Differential Revision: D5026558

fbshipit-source-id: 61faa25c71f5aef06ca2d7e0700f43214353c650
parent 74560278
......@@ -40,6 +40,7 @@
#endif
#if FOLLY_HAVE_LIBZSTD
#define ZSTD_STATIC_LINKING_ONLY
#include <zstd.h>
#endif
......@@ -1584,13 +1585,24 @@ std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(
#ifdef FOLLY_HAVE_LIBZSTD
namespace {
void zstdFreeCStream(ZSTD_CStream* zcs) {
ZSTD_freeCStream(zcs);
}
void zstdFreeDStream(ZSTD_DStream* zds) {
ZSTD_freeDStream(zds);
}
}
/**
* ZSTD compression
*/
class ZSTDCodec final : public Codec {
class ZSTDStreamCodec final : public StreamCodec {
public:
static std::unique_ptr<Codec> create(int level, CodecType);
explicit ZSTDCodec(int level, CodecType type);
static std::unique_ptr<Codec> createCodec(int level, CodecType);
static std::unique_ptr<StreamCodec> createStream(int level, CodecType);
explicit ZSTDStreamCodec(int level, CodecType type);
std::vector<std::string> validPrefixes() const override;
bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
......@@ -1599,29 +1611,61 @@ class ZSTDCodec final : public Codec {
private:
bool doNeedsUncompressedLength() const override;
uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
std::unique_ptr<IOBuf> doUncompress(
const IOBuf* data,
Optional<uint64_t> uncompressedLength) override;
Optional<uint64_t> doGetUncompressedLength(
IOBuf const* data,
Optional<uint64_t> uncompressedLength) const override;
void doResetStream() override;
bool doCompressStream(
ByteRange& input,
MutableByteRange& output,
StreamCodec::FlushOp flushOp) override;
bool doUncompressStream(
ByteRange& input,
MutableByteRange& output,
StreamCodec::FlushOp flushOp) override;
void resetCStream();
void resetDStream();
bool tryBlockCompress(ByteRange& input, MutableByteRange& output) const;
bool tryBlockUncompress(ByteRange& input, MutableByteRange& output) const;
int level_;
bool needReset_{true};
std::unique_ptr<
ZSTD_CStream,
folly::static_function_deleter<ZSTD_CStream, &zstdFreeCStream>>
cstream_{nullptr};
std::unique_ptr<
ZSTD_DStream,
folly::static_function_deleter<ZSTD_DStream, &zstdFreeDStream>>
dstream_{nullptr};
};
static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528;
std::vector<std::string> ZSTDCodec::validPrefixes() const {
std::vector<std::string> ZSTDStreamCodec::validPrefixes() const {
return {prefixToStringLE(kZSTDMagicLE)};
}
bool ZSTDCodec::canUncompress(const IOBuf* data, Optional<uint64_t>) const {
bool ZSTDStreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
const {
return dataStartsWithLE(data, kZSTDMagicLE);
}
std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
return std::make_unique<ZSTDCodec>(level, type);
std::unique_ptr<Codec> ZSTDStreamCodec::createCodec(int level, CodecType type) {
return make_unique<ZSTDStreamCodec>(level, type);
}
std::unique_ptr<StreamCodec> ZSTDStreamCodec::createStream(
int level,
CodecType type) {
return make_unique<ZSTDStreamCodec>(level, type);
}
ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
ZSTDStreamCodec::ZSTDStreamCodec(int level, CodecType type)
: StreamCodec(type) {
DCHECK(type == CodecType::ZSTD);
switch (level) {
case COMPRESSION_LEVEL_FASTEST:
......@@ -1641,11 +1685,12 @@ ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
level_ = level;
}
bool ZSTDCodec::doNeedsUncompressedLength() const {
bool ZSTDStreamCodec::doNeedsUncompressedLength() const {
return false;
}
uint64_t ZSTDCodec::doMaxCompressedLength(uint64_t uncompressedLength) const {
uint64_t ZSTDStreamCodec::doMaxCompressedLength(
uint64_t uncompressedLength) const {
return ZSTD_compressBound(uncompressedLength);
}
......@@ -1657,163 +1702,158 @@ void zstdThrowIfError(size_t rc) {
to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
}
std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
// Support earlier versions of the codec (working with a single IOBuf,
// and using ZSTD_decompress which requires ZSTD frame to contain size,
// which isn't populated by streaming API).
if (!data->isChained()) {
auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
const auto rc = ZSTD_compress(
out->writableData(),
out->capacity(),
data->data(),
data->length(),
level_);
zstdThrowIfError(rc);
out->append(rc);
return out;
}
auto zcs = ZSTD_createCStream();
SCOPE_EXIT {
ZSTD_freeCStream(zcs);
};
auto rc = ZSTD_initCStream(zcs, level_);
zstdThrowIfError(rc);
Cursor cursor(data);
auto result =
IOBuf::createCombined(maxCompressedLength(cursor.totalLength()));
ZSTD_outBuffer out;
out.dst = result->writableTail();
out.size = result->capacity();
out.pos = 0;
for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
ZSTD_inBuffer in;
in.src = buffer.data();
in.size = buffer.size();
for (in.pos = 0; in.pos != in.size;) {
rc = ZSTD_compressStream(zcs, &out, &in);
zstdThrowIfError(rc);
Optional<uint64_t> ZSTDStreamCodec::doGetUncompressedLength(
IOBuf const* data,
Optional<uint64_t> uncompressedLength) const {
// Read decompressed size from frame if available in first IOBuf.
auto const decompressedSize =
ZSTD_getDecompressedSize(data->data(), data->length());
if (decompressedSize != 0) {
if (uncompressedLength && *uncompressedLength != decompressedSize) {
throw std::runtime_error("ZSTD: invalid uncompressed length");
}
cursor.skip(in.size);
buffer = cursor.peekBytes();
uncompressedLength = decompressedSize;
}
return uncompressedLength;
}
rc = ZSTD_endStream(zcs, &out);
zstdThrowIfError(rc);
CHECK_EQ(rc, 0);
void ZSTDStreamCodec::doResetStream() {
needReset_ = true;
}
result->append(out.pos);
return result;
bool ZSTDStreamCodec::tryBlockCompress(
ByteRange& input,
MutableByteRange& output) const {
DCHECK(needReset_);
// We need to know that we have enough output space to use block compression
if (output.size() < ZSTD_compressBound(input.size())) {
return false;
}
size_t const length = ZSTD_compress(
output.data(), output.size(), input.data(), input.size(), level_);
zstdThrowIfError(length);
input.uncheckedAdvance(input.size());
output.uncheckedAdvance(length);
return true;
}
static std::unique_ptr<IOBuf> zstdUncompressBuffer(
const IOBuf* data,
Optional<uint64_t> uncompressedLength) {
// Check preconditions
DCHECK(!data->isChained());
DCHECK(uncompressedLength.hasValue());
auto uncompressed = IOBuf::create(*uncompressedLength);
const auto decompressedSize = ZSTD_decompress(
uncompressed->writableTail(),
uncompressed->tailroom(),
data->data(),
data->length());
zstdThrowIfError(decompressedSize);
if (decompressedSize != uncompressedLength) {
throw std::runtime_error("ZSTD: invalid uncompressed length");
void ZSTDStreamCodec::resetCStream() {
if (!cstream_) {
cstream_.reset(ZSTD_createCStream());
if (!cstream_) {
throw std::bad_alloc{};
}
}
uncompressed->append(decompressedSize);
return uncompressed;
// Advanced API usage works for all supported versions of zstd.
// Required to set contentSizeFlag.
auto params = ZSTD_getParams(level_, uncompressedLength().value_or(0), 0);
params.fParams.contentSizeFlag = uncompressedLength().hasValue();
zstdThrowIfError(ZSTD_initCStream_advanced(
cstream_.get(), nullptr, 0, params, uncompressedLength().value_or(0)));
}
static std::unique_ptr<IOBuf> zstdUncompressStream(
const IOBuf* data,
Optional<uint64_t> uncompressedLength) {
auto zds = ZSTD_createDStream();
bool ZSTDStreamCodec::doCompressStream(
ByteRange& input,
MutableByteRange& output,
StreamCodec::FlushOp flushOp) {
if (needReset_) {
// If we are given all the input in one chunk try to use block compression
if (flushOp == StreamCodec::FlushOp::END &&
tryBlockCompress(input, output)) {
return true;
}
resetCStream();
needReset_ = false;
}
ZSTD_inBuffer in = {input.data(), input.size(), 0};
ZSTD_outBuffer out = {output.data(), output.size(), 0};
SCOPE_EXIT {
ZSTD_freeDStream(zds);
input.uncheckedAdvance(in.pos);
output.uncheckedAdvance(out.pos);
};
auto rc = ZSTD_initDStream(zds);
zstdThrowIfError(rc);
ZSTD_outBuffer out{};
ZSTD_inBuffer in{};
auto outputSize = uncompressedLength.value_or(ZSTD_DStreamOutSize());
IOBufQueue queue(IOBufQueue::cacheChainLength());
Cursor cursor(data);
for (rc = 0;;) {
if (in.pos == in.size) {
auto buffer = cursor.peekBytes();
in.src = buffer.data();
in.size = buffer.size();
in.pos = 0;
cursor.skip(in.size);
if (rc > 1 && in.size == 0) {
throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
}
}
if (out.pos == out.size) {
if (out.pos != 0) {
queue.postallocate(out.pos);
}
auto buffer = queue.preallocate(outputSize, outputSize);
out.dst = buffer.first;
out.size = buffer.second;
out.pos = 0;
outputSize = ZSTD_DStreamOutSize();
if (flushOp == StreamCodec::FlushOp::NONE || !input.empty()) {
zstdThrowIfError(ZSTD_compressStream(cstream_.get(), &out, &in));
}
if (in.pos == in.size && flushOp != StreamCodec::FlushOp::NONE) {
size_t rc;
switch (flushOp) {
case StreamCodec::FlushOp::FLUSH:
rc = ZSTD_flushStream(cstream_.get(), &out);
break;
case StreamCodec::FlushOp::END:
rc = ZSTD_endStream(cstream_.get(), &out);
break;
default:
throw std::invalid_argument("ZSTD: invalid FlushOp");
}
rc = ZSTD_decompressStream(zds, &out, &in);
zstdThrowIfError(rc);
if (rc == 0) {
break;
return true;
}
}
if (out.pos != 0) {
queue.postallocate(out.pos);
}
if (in.pos != in.size || !cursor.isAtEnd()) {
throw std::runtime_error("ZSTD: junk after end of data");
}
if (uncompressedLength && queue.chainLength() != *uncompressedLength) {
throw std::runtime_error("ZSTD: invalid uncompressed length");
}
return false;
}
return queue.move();
bool ZSTDStreamCodec::tryBlockUncompress(
ByteRange& input,
MutableByteRange& output) const {
DCHECK(needReset_);
#if ZSTD_VERSION_NUMBER < 10104
// We require ZSTD_findFrameCompressedSize() to perform this optimization.
return false;
#else
// We need to know the uncompressed length and have enough output space.
if (!uncompressedLength() || output.size() < *uncompressedLength()) {
return false;
}
size_t const compressedLength =
ZSTD_findFrameCompressedSize(input.data(), input.size());
zstdThrowIfError(compressedLength);
size_t const length = ZSTD_decompress(
output.data(), *uncompressedLength(), input.data(), compressedLength);
zstdThrowIfError(length);
DCHECK_EQ(length, *uncompressedLength());
input.uncheckedAdvance(compressedLength);
output.uncheckedAdvance(length);
return true;
#endif
}
std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
const IOBuf* data,
Optional<uint64_t> uncompressedLength) {
{
// Read decompressed size from frame if available in first IOBuf.
const auto decompressedSize =
ZSTD_getDecompressedSize(data->data(), data->length());
if (decompressedSize != 0) {
if (uncompressedLength && *uncompressedLength != decompressedSize) {
throw std::runtime_error("ZSTD: invalid uncompressed length");
}
uncompressedLength = decompressedSize;
void ZSTDStreamCodec::resetDStream() {
if (!dstream_) {
dstream_.reset(ZSTD_createDStream());
if (!dstream_) {
throw std::bad_alloc{};
}
}
// Faster to decompress using ZSTD_decompress() if we can.
if (uncompressedLength && !data->isChained()) {
return zstdUncompressBuffer(data, uncompressedLength);
zstdThrowIfError(ZSTD_initDStream(dstream_.get()));
}
bool ZSTDStreamCodec::doUncompressStream(
ByteRange& input,
MutableByteRange& output,
StreamCodec::FlushOp flushOp) {
if (needReset_) {
// If we are given all the input in one chunk try to use block uncompression
if (flushOp == StreamCodec::FlushOp::END &&
tryBlockUncompress(input, output)) {
return true;
}
resetDStream();
needReset_ = false;
}
// Fall back to slower streaming decompression.
return zstdUncompressStream(data, uncompressedLength);
ZSTD_inBuffer in = {input.data(), input.size(), 0};
ZSTD_outBuffer out = {output.data(), output.size(), 0};
SCOPE_EXIT {
input.uncheckedAdvance(in.pos);
output.uncheckedAdvance(out.pos);
};
size_t const rc = ZSTD_decompressStream(dstream_.get(), &out, &in);
zstdThrowIfError(rc);
return rc == 0;
}
#endif // FOLLY_HAVE_LIBZSTD
#endif // FOLLY_HAVE_LIBZSTD
#if FOLLY_HAVE_LIBBZ2
......@@ -2229,7 +2269,7 @@ constexpr Factory
#endif
#if FOLLY_HAVE_LIBZSTD
{ZSTDCodec::create, nullptr},
{ZSTDStreamCodec::createCodec, ZSTDStreamCodec::createStream},
#else
{},
#endif
......
......@@ -34,6 +34,10 @@
#include <folly/io/IOBufQueue.h>
#include <folly/portability/GTest.h>
#if FOLLY_HAVE_LIBZSTD
#include <zstd.h>
#endif
namespace folly { namespace io { namespace test {
class DataHolder : private boost::noncopyable {
......@@ -1084,6 +1088,31 @@ TEST(CheckCompatibleTest, ZlibIsPrefix) {
EXPECT_THROW_IF_DEBUG(
getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument);
}
#if FOLLY_HAVE_LIBZSTD
TEST(ZstdTest, BackwardCompatible) {
auto codec = getCodec(CodecType::ZSTD);
{
auto const data = IOBuf::wrapBuffer(randomDataHolder.data(size_t(1) << 20));
auto compressed = codec->compress(data.get());
compressed->coalesce();
EXPECT_EQ(
data->length(),
ZSTD_getDecompressedSize(compressed->data(), compressed->length()));
}
{
auto const data =
IOBuf::wrapBuffer(randomDataHolder.data(size_t(100) << 20));
auto compressed = codec->compress(data.get());
compressed->coalesce();
EXPECT_EQ(
data->length(),
ZSTD_getDecompressedSize(compressed->data(), compressed->length()));
}
}
#endif
}}} // namespaces
int main(int argc, char *argv[]) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment