265 lines
7.5 KiB
C++
265 lines
7.5 KiB
C++
|
//
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
|
||
|
|
||
|
#include "cow_decompress.h"
|
||
|
|
||
|
#include <utility>
|
||
|
|
||
|
#include <android-base/logging.h>
|
||
|
#include <brotli/decode.h>
|
||
|
#include <zlib.h>
|
||
|
|
||
|
namespace android {
|
||
|
namespace snapshot {
|
||
|
|
||
|
class NoDecompressor final : public IDecompressor {
|
||
|
public:
|
||
|
bool Decompress(size_t) override;
|
||
|
};
|
||
|
|
||
|
bool NoDecompressor::Decompress(size_t) {
|
||
|
size_t stream_remaining = stream_->Size();
|
||
|
while (stream_remaining) {
|
||
|
size_t buffer_size = stream_remaining;
|
||
|
uint8_t* buffer = reinterpret_cast<uint8_t*>(sink_->GetBuffer(buffer_size, &buffer_size));
|
||
|
if (!buffer) {
|
||
|
LOG(ERROR) << "Could not acquire buffer from sink";
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
// Read until we can fill the buffer.
|
||
|
uint8_t* buffer_pos = buffer;
|
||
|
size_t bytes_to_read = std::min(buffer_size, stream_remaining);
|
||
|
while (bytes_to_read) {
|
||
|
size_t read;
|
||
|
if (!stream_->Read(buffer_pos, bytes_to_read, &read)) {
|
||
|
return false;
|
||
|
}
|
||
|
if (!read) {
|
||
|
LOG(ERROR) << "Stream ended prematurely";
|
||
|
return false;
|
||
|
}
|
||
|
if (!sink_->ReturnData(buffer_pos, read)) {
|
||
|
LOG(ERROR) << "Could not return buffer to sink";
|
||
|
return false;
|
||
|
}
|
||
|
buffer_pos += read;
|
||
|
bytes_to_read -= read;
|
||
|
stream_remaining -= read;
|
||
|
}
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
std::unique_ptr<IDecompressor> IDecompressor::Uncompressed() {
|
||
|
return std::unique_ptr<IDecompressor>(new NoDecompressor());
|
||
|
}
|
||
|
|
||
|
// Read chunks of the COW and incrementally stream them to the decoder.
|
||
|
class StreamDecompressor : public IDecompressor {
|
||
|
public:
|
||
|
bool Decompress(size_t output_bytes) override;
|
||
|
|
||
|
virtual bool Init() = 0;
|
||
|
virtual bool DecompressInput(const uint8_t* data, size_t length) = 0;
|
||
|
virtual bool Done() = 0;
|
||
|
|
||
|
protected:
|
||
|
bool GetFreshBuffer();
|
||
|
|
||
|
size_t output_bytes_;
|
||
|
size_t stream_remaining_;
|
||
|
uint8_t* output_buffer_ = nullptr;
|
||
|
size_t output_buffer_remaining_ = 0;
|
||
|
};
|
||
|
|
||
|
// Size of the stack buffer used to read compressed input in
// StreamDecompressor::Decompress(), and the upper bound on the output buffer
// size requested from the sink in StreamDecompressor::GetFreshBuffer().
static constexpr size_t kChunkSize = 4096;
|
||
|
|
||
|
bool StreamDecompressor::Decompress(size_t output_bytes) {
|
||
|
if (!Init()) {
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
stream_remaining_ = stream_->Size();
|
||
|
output_bytes_ = output_bytes;
|
||
|
|
||
|
uint8_t chunk[kChunkSize];
|
||
|
while (stream_remaining_) {
|
||
|
size_t read = std::min(stream_remaining_, sizeof(chunk));
|
||
|
if (!stream_->Read(chunk, read, &read)) {
|
||
|
return false;
|
||
|
}
|
||
|
if (!read) {
|
||
|
LOG(ERROR) << "Stream ended prematurely";
|
||
|
return false;
|
||
|
}
|
||
|
if (!DecompressInput(chunk, read)) {
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
stream_remaining_ -= read;
|
||
|
|
||
|
if (stream_remaining_ && Done()) {
|
||
|
LOG(ERROR) << "Decompressor terminated early";
|
||
|
return false;
|
||
|
}
|
||
|
}
|
||
|
if (!Done()) {
|
||
|
LOG(ERROR) << "Decompressor expected more bytes";
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
bool StreamDecompressor::GetFreshBuffer() {
|
||
|
size_t request_size = std::min(output_bytes_, kChunkSize);
|
||
|
output_buffer_ =
|
||
|
reinterpret_cast<uint8_t*>(sink_->GetBuffer(request_size, &output_buffer_remaining_));
|
||
|
if (!output_buffer_) {
|
||
|
LOG(ERROR) << "Could not acquire buffer from sink";
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
class GzDecompressor final : public StreamDecompressor {
|
||
|
public:
|
||
|
~GzDecompressor();
|
||
|
|
||
|
bool Init() override;
|
||
|
bool DecompressInput(const uint8_t* data, size_t length) override;
|
||
|
bool Done() override { return ended_; }
|
||
|
|
||
|
private:
|
||
|
z_stream z_ = {};
|
||
|
bool ended_ = false;
|
||
|
};
|
||
|
|
||
|
bool GzDecompressor::Init() {
|
||
|
if (int rv = inflateInit(&z_); rv != Z_OK) {
|
||
|
LOG(ERROR) << "inflateInit returned error code " << rv;
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
// Releases the zlib inflate state held in z_. The result of inflateEnd() is
// not checked.
GzDecompressor::~GzDecompressor() {
    inflateEnd(&z_);
}
|
||
|
|
||
|
bool GzDecompressor::DecompressInput(const uint8_t* data, size_t length) {
|
||
|
z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
|
||
|
z_.avail_in = length;
|
||
|
|
||
|
while (z_.avail_in) {
|
||
|
// If no more output buffer, grab a new buffer.
|
||
|
if (z_.avail_out == 0) {
|
||
|
if (!GetFreshBuffer()) {
|
||
|
return false;
|
||
|
}
|
||
|
z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
|
||
|
z_.avail_out = output_buffer_remaining_;
|
||
|
}
|
||
|
|
||
|
// Remember the position of the output buffer so we can call ReturnData.
|
||
|
auto avail_out = z_.avail_out;
|
||
|
|
||
|
// Decompress.
|
||
|
int rv = inflate(&z_, Z_NO_FLUSH);
|
||
|
if (rv != Z_OK && rv != Z_STREAM_END) {
|
||
|
LOG(ERROR) << "inflate returned error code " << rv;
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
size_t returned = avail_out - z_.avail_out;
|
||
|
if (!sink_->ReturnData(output_buffer_, returned)) {
|
||
|
LOG(ERROR) << "Could not return buffer to sink";
|
||
|
return false;
|
||
|
}
|
||
|
output_buffer_ += returned;
|
||
|
output_buffer_remaining_ -= returned;
|
||
|
|
||
|
if (rv == Z_STREAM_END) {
|
||
|
if (z_.avail_in) {
|
||
|
LOG(ERROR) << "Gz stream ended prematurely";
|
||
|
return false;
|
||
|
}
|
||
|
ended_ = true;
|
||
|
return true;
|
||
|
}
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
std::unique_ptr<IDecompressor> IDecompressor::Gz() {
|
||
|
return std::unique_ptr<IDecompressor>(new GzDecompressor());
|
||
|
}
|
||
|
|
||
|
class BrotliDecompressor final : public StreamDecompressor {
|
||
|
public:
|
||
|
~BrotliDecompressor();
|
||
|
|
||
|
bool Init() override;
|
||
|
bool DecompressInput(const uint8_t* data, size_t length) override;
|
||
|
bool Done() override { return BrotliDecoderIsFinished(decoder_); }
|
||
|
|
||
|
private:
|
||
|
BrotliDecoderState* decoder_ = nullptr;
|
||
|
};
|
||
|
|
||
|
bool BrotliDecompressor::Init() {
|
||
|
decoder_ = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
// Destroys the brotli decoder, if one was created.
BrotliDecompressor::~BrotliDecompressor() {
    if (decoder_ == nullptr) {
        return;
    }
    BrotliDecoderDestroyInstance(decoder_);
}
|
||
|
|
||
|
bool BrotliDecompressor::DecompressInput(const uint8_t* data, size_t length) {
|
||
|
size_t available_in = length;
|
||
|
const uint8_t* next_in = data;
|
||
|
|
||
|
bool needs_more_output = false;
|
||
|
while (available_in || needs_more_output) {
|
||
|
if (!output_buffer_remaining_ && !GetFreshBuffer()) {
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
auto output_buffer = output_buffer_;
|
||
|
auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in,
|
||
|
&output_buffer_remaining_, &output_buffer_, nullptr);
|
||
|
if (r == BROTLI_DECODER_RESULT_ERROR) {
|
||
|
LOG(ERROR) << "brotli decode failed";
|
||
|
return false;
|
||
|
}
|
||
|
if (!sink_->ReturnData(output_buffer, output_buffer_ - output_buffer)) {
|
||
|
return false;
|
||
|
}
|
||
|
needs_more_output = (r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT);
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
std::unique_ptr<IDecompressor> IDecompressor::Brotli() {
|
||
|
return std::unique_ptr<IDecompressor>(new BrotliDecompressor());
|
||
|
}
|
||
|
|
||
|
} // namespace snapshot
|
||
|
} // namespace android
|