This is an automated email from the ASF dual-hosted git repository. martinzink pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/nifi-minifi-cpp.git
commit 97011df5a63e739b922212651672d0de9e709076 Author: Gabor Gyimesi <[email protected]> AuthorDate: Wed Oct 15 14:04:49 2025 +0200 MINIFICPP-2565 Add compression support for site to site communication Depends on #1966 Closes #1974 Signed-off-by: Martin Zink <[email protected]> --- SITE_TO_SITE.md | 8 +- core-framework/include/io/CRCStream.h | 4 + core-framework/src/io/ZlibStream.cpp | 2 +- .../features/MiNiFi_integration_test_driver.py | 6 +- docker/test/integration/features/s2s.feature | 85 ++++++++ docker/test/integration/features/steps/steps.py | 14 ++ docker/test/integration/minifi/core/InputPort.py | 4 + docker/test/integration/minifi/core/OutputPort.py | 4 + .../Minifi_flow_json_serializer.py | 1 + .../Minifi_flow_yaml_serializer.py | 1 + libminifi/include/sitetosite/CompressionConsts.h | 28 +++ .../include/sitetosite/CompressionInputStream.h | 53 +++++ .../include/sitetosite/CompressionOutputStream.h | 48 +++++ libminifi/include/sitetosite/SiteToSiteClient.h | 6 +- .../src/sitetosite/CompressionInputStream.cpp | 154 +++++++++++++++ .../src/sitetosite/CompressionOutputStream.cpp | 139 +++++++++++++ libminifi/src/sitetosite/HttpSiteToSiteClient.cpp | 5 +- libminifi/src/sitetosite/RawSiteToSiteClient.cpp | 3 +- libminifi/src/sitetosite/SiteToSiteClient.cpp | 115 ++++++++--- .../test/unit/SiteToSiteCompressionStreamTests.cpp | 215 +++++++++++++++++++++ libminifi/test/unit/SiteToSiteTests.cpp | 174 ++++++++++++----- 21 files changed, 976 insertions(+), 93 deletions(-) diff --git a/SITE_TO_SITE.md b/SITE_TO_SITE.md index 0869a3acc..830aa5ff7 100644 --- a/SITE_TO_SITE.md +++ b/SITE_TO_SITE.md @@ -72,7 +72,7 @@ Remote Process Groups: - id: de7cc09a-0196-1000-2c63-ee6b4319ffb6 # this is the instance id of the input port created in NiFi name: nifi-inputport max concurrent tasks: 1 - use compression: false # currently not supported and ignored in MiNiFi C++ + use compression: true batch size: size: 10 MB count: 10 @@ -112,6 +112,11 @@ Remote Processing Groups: id: 
22d38f35-4d25-4e68-878c-f46f46d5781c max concurrent tasks: 1 name: from_nifi + use compression: true + batch size: + size: 10 MB + count: 10 + duration: 30 sec id: 20ed42b0-d41e-4add-9e6d-8777223370b8 name: RemoteProcessGroup timeout: 30 sec @@ -123,7 +128,6 @@ Notes on the configuration: - In the MiNiFi C++ configuration, in yaml configuration the remote input and output ports' `id` field, and in json configuration the ports' `identifier`, `instanceIdentifier`, and `targetId` fields should be set to the instance id of the input and output ports created in NiFi (`de7cc09a-0196-1000-2c63-ee6b4319ffb6` in the examples). - Connections from the remote output port to the processor should use the `undefined` relationship -- `useCompression` can be set, but it is currently not supported in MiNiFi C++ so it will be set to false in the site-to-site messages - the `url` field (`targetUri` or `targetUris` in JSON) field in the remote process group should be set to the NiFi instance's URL, this can also use comma separated list of URLs if the remote process group is configured to use multiple NiFi nodes ## Additional examples diff --git a/core-framework/include/io/CRCStream.h b/core-framework/include/io/CRCStream.h index 3cf1d844f..50fc877be 100644 --- a/core-framework/include/io/CRCStream.h +++ b/core-framework/include/io/CRCStream.h @@ -66,6 +66,10 @@ class CRCStreamBase : public virtual StreamImpl { crc_ = crc32(0L, Z_NULL, 0); } + void setCrc(uint64_t crc) { + crc_ = gsl::narrow<uLong>(crc); + } + protected: uLong crc_ = 0; StreamType* child_stream_ = nullptr; diff --git a/core-framework/src/io/ZlibStream.cpp b/core-framework/src/io/ZlibStream.cpp index 843c92140..9d8912c7a 100644 --- a/core-framework/src/io/ZlibStream.cpp +++ b/core-framework/src/io/ZlibStream.cpp @@ -181,7 +181,7 @@ size_t ZlibDecompressStream::write(const uint8_t* value, size_t size) { return STREAM_ERROR; } const auto output_size = outputBuffer_.size() - strm_.avail_out; - logger_->log_trace("deflate 
produced {} B of output data", output_size); + logger_->log_trace("inflate produced {} B of output data", output_size); if (output_->write(gsl::make_span(outputBuffer_).subspan(0, output_size)) != output_size) { logger_->log_error("Failed to write to underlying stream"); state_ = ZlibStreamState::ERRORED; diff --git a/docker/test/integration/features/MiNiFi_integration_test_driver.py b/docker/test/integration/features/MiNiFi_integration_test_driver.py index 32f4f709f..455a8b5ce 100644 --- a/docker/test/integration/features/MiNiFi_integration_test_driver.py +++ b/docker/test/integration/features/MiNiFi_integration_test_driver.py @@ -155,17 +155,19 @@ class MiNiFi_integration_test: raise Exception("Trying to fetch unknown node: \"%s\"" % name) @staticmethod - def generate_input_port_for_remote_process_group(remote_process_group, name): + def generate_input_port_for_remote_process_group(remote_process_group, name, use_compression=False): input_port_node = InputPort(name, remote_process_group) # Generate an MD5 hash unique to the remote process group id input_port_node.set_uuid(uuid.uuid3(remote_process_group.get_uuid(), "input_port")) + input_port_node.set_use_compression(use_compression) return input_port_node @staticmethod - def generate_output_port_for_remote_process_group(remote_process_group, name): + def generate_output_port_for_remote_process_group(remote_process_group, name, use_compression=False): output_port_node = OutputPort(name, remote_process_group) # Generate an MD5 hash unique to the remote process group id output_port_node.set_uuid(uuid.uuid3(remote_process_group.get_uuid(), "output_port")) + output_port_node.set_use_compression(use_compression) return output_port_node def add_test_data(self, path, test_data, file_name=None): diff --git a/docker/test/integration/features/s2s.feature b/docker/test/integration/features/s2s.feature index cf3a38988..9b3321d4f 100644 --- a/docker/test/integration/features/s2s.feature +++ 
b/docker/test/integration/features/s2s.feature @@ -285,3 +285,88 @@ Feature: Sending data from MiNiFi-C++ to NiFi using S2S protocol Then a flowfile with the content "test" is placed in the monitored directory in less than 90 seconds And the Minifi logs do not contain the following message: "ProcessSession rollback" after 1 seconds + + Scenario: A MiNiFi instance produces and transfers data to a NiFi instance via s2s using compression + Given a GetFile processor with the "Input Directory" property set to "/tmp/input" + And a file with the content "test" is present in "/tmp/input" + And a RemoteProcessGroup node with name "RemoteProcessGroup" is opened on "http://nifi-${feature_id}:8080/nifi" + And an input port using compression with name "to_nifi" is created on the RemoteProcessGroup named "RemoteProcessGroup" + And the "success" relationship of the GetFile processor is connected to the to_nifi + + And a NiFi flow is receiving data from the RemoteProcessGroup named "RemoteProcessGroup" in an input port named "from-minifi" which has the same id as the port named "to_nifi" + And a PutFile processor with the "Directory" property set to "/tmp/output" in the "nifi" flow + And the "success" relationship of the from-minifi is connected to the PutFile + + When both instances start up + Then a flowfile with the content "test" is placed in the monitored directory in less than 90 seconds + And the Minifi logs do not contain the following message: "ProcessSession rollback" after 1 seconds + + Scenario: A MiNiFi instance produces and transfers data to a NiFi instance via s2s using compression in YAML config + Given a MiNiFi CPP server with yaml config + And a GetFile processor with the "Input Directory" property set to "/tmp/input" + And a file with the content "test" is present in "/tmp/input" + And a RemoteProcessGroup node with name "RemoteProcessGroup" is opened on "http://nifi-${feature_id}:8080/nifi" + And an input port using compression with name "to_nifi" is created on 
the RemoteProcessGroup named "RemoteProcessGroup" + And the "success" relationship of the GetFile processor is connected to the to_nifi + + And a NiFi flow is receiving data from the RemoteProcessGroup named "RemoteProcessGroup" in an input port named "from-minifi" which has the same id as the port named "to_nifi" + And a PutFile processor with the "Directory" property set to "/tmp/output" in the "nifi" flow + And the "success" relationship of the from-minifi is connected to the PutFile + + When both instances start up + Then a flowfile with the content "test" is placed in the monitored directory in less than 90 seconds + And the Minifi logs do not contain the following message: "ProcessSession rollback" after 1 seconds + + Scenario: A NiFi instance produces and transfers data to a MiNiFi instance via s2s using compression + Given a file with the content "test" is present in "/tmp/input" + And a RemoteProcessGroup node with name "RemoteProcessGroup" is opened on "http://nifi-${feature_id}:8080/nifi" + And an output port using compression with name "from_nifi" is created on the RemoteProcessGroup named "RemoteProcessGroup" + And "from_nifi" port is a start node + And a PutFile processor with the "Directory" property set to "/tmp/output" + And the output port "from_nifi" is connected to the PutFile processor + + And a GetFile processor with the "Input Directory" property set to "/tmp/input" in the "nifi" flow using the "nifi" engine + And a NiFi flow is sending data to an output port named "to-minifi-in-nifi" with the id of the port named "from_nifi" from the RemoteProcessGroup named "RemoteProcessGroup" + And the "success" relationship of the GetFile is connected to the to-minifi-in-nifi + + When both instances start up + + Then a flowfile with the content "test" is placed in the monitored directory in less than 90 seconds + And the Minifi logs do not contain the following message: "ProcessSession rollback" after 1 seconds + + Scenario: A NiFi instance produces and 
transfers data to a MiNiFi instance via s2s using compression and HTTP protocol + Given a file with the content "test" is present in "/tmp/input" + And a RemoteProcessGroup node with name "RemoteProcessGroup" is opened on "http://nifi-${feature_id}:8080/nifi" with transport protocol set to "HTTP" + And an output port using compression with name "from_nifi" is created on the RemoteProcessGroup named "RemoteProcessGroup" + And "from_nifi" port is a start node + And a PutFile processor with the "Directory" property set to "/tmp/output" + And the output port "from_nifi" is connected to the PutFile processor + + And a GetFile processor with the "Input Directory" property set to "/tmp/input" in the "nifi" flow using the "nifi" engine + And a NiFi flow is sending data to an output port named "to-minifi-in-nifi" with the id of the port named "from_nifi" from the RemoteProcessGroup named "RemoteProcessGroup" + And the "success" relationship of the GetFile is connected to the to-minifi-in-nifi + + When both instances start up + + Then a flowfile with the content "test" is placed in the monitored directory in less than 90 seconds + And the Minifi logs do not contain the following message: "ProcessSession rollback" after 1 seconds + + Scenario: A NiFi instance produces and transfers data to a MiNiFi instance via s2s using compression with YAML config and SSL config defined in minifi.properties + Given a MiNiFi CPP server with yaml config + And a file with the content "test" is present in "/tmp/input" + And a RemoteProcessGroup node with name "RemoteProcessGroup" is opened on "https://nifi-${feature_id}:8443/nifi" + And an output port using compression with name "from_nifi" is created on the RemoteProcessGroup named "RemoteProcessGroup" + And "from_nifi" port is a start node + And a PutFile processor with the "Directory" property set to "/tmp/output" + And the output port "from_nifi" is connected to the PutFile processor + And SSL properties are set in MiNiFi + + And SSL is 
enabled in NiFi flow + And a GetFile processor with the "Input Directory" property set to "/tmp/input" in the "nifi" flow using the "nifi" engine + And a NiFi flow is sending data to an output port named "to-minifi-in-nifi" with the id of the port named "from_nifi" from the RemoteProcessGroup named "RemoteProcessGroup" + And the "success" relationship of the GetFile is connected to the to-minifi-in-nifi + + When both instances start up + + Then a flowfile with the content "test" is placed in the monitored directory in less than 90 seconds + And the Minifi logs do not contain the following message: "ProcessSession rollback" after 1 seconds diff --git a/docker/test/integration/features/steps/steps.py b/docker/test/integration/features/steps/steps.py index c7bb827d0..a0b1b532f 100644 --- a/docker/test/integration/features/steps/steps.py +++ b/docker/test/integration/features/steps/steps.py @@ -229,6 +229,13 @@ def step_impl(context, port_name, rpg_name): context.test.add_node(input_port_node) +@given("an input port using compression with name \"{port_name}\" is created on the RemoteProcessGroup named \"{rpg_name}\"") +def step_impl(context, port_name, rpg_name): + remote_process_group = context.test.get_remote_process_group_by_name(rpg_name) + input_port_node = context.test.generate_input_port_for_remote_process_group(remote_process_group, port_name, True) + context.test.add_node(input_port_node) + + @given("an output port with name \"{port_name}\" is created on the RemoteProcessGroup named \"{rpg_name}\"") def step_impl(context, port_name, rpg_name): remote_process_group = context.test.get_remote_process_group_by_name(rpg_name) @@ -236,6 +243,13 @@ def step_impl(context, port_name, rpg_name): context.test.add_node(input_port_node) +@given("an output port using compression with name \"{port_name}\" is created on the RemoteProcessGroup named \"{rpg_name}\"") +def step_impl(context, port_name, rpg_name): + remote_process_group = 
context.test.get_remote_process_group_by_name(rpg_name) + input_port_node = context.test.generate_output_port_for_remote_process_group(remote_process_group, port_name, True) + context.test.add_node(input_port_node) + + @given("the output port \"{port_name}\" is connected to the {destination_name} processor") def step_impl(context, port_name, destination_name): destination = context.test.get_node_by_name(destination_name) diff --git a/docker/test/integration/minifi/core/InputPort.py b/docker/test/integration/minifi/core/InputPort.py index e37320a5f..968ef76dc 100644 --- a/docker/test/integration/minifi/core/InputPort.py +++ b/docker/test/integration/minifi/core/InputPort.py @@ -22,9 +22,13 @@ class InputPort(Connectable): super(InputPort, self).__init__(name=name) self.remote_process_group = remote_process_group + self.use_compression = False self.properties = {} if self.remote_process_group: self.properties = self.remote_process_group.properties def id_for_connection(self): return self.instance_id + + def set_use_compression(self, use_compression: bool): + self.use_compression = use_compression diff --git a/docker/test/integration/minifi/core/OutputPort.py b/docker/test/integration/minifi/core/OutputPort.py index c42255444..5186ac330 100644 --- a/docker/test/integration/minifi/core/OutputPort.py +++ b/docker/test/integration/minifi/core/OutputPort.py @@ -22,9 +22,13 @@ class OutputPort(Connectable): super(OutputPort, self).__init__(name=name) self.remote_process_group = remote_process_group + self.use_compression = False self.properties = {} if self.remote_process_group: self.properties = self.remote_process_group.properties def id_for_connection(self): return self.instance_id + + def set_use_compression(self, use_compression: bool): + self.use_compression = use_compression diff --git a/docker/test/integration/minifi/flow_serialization/Minifi_flow_json_serializer.py b/docker/test/integration/minifi/flow_serialization/Minifi_flow_json_serializer.py index 
ed260cef3..4a64a982e 100644 --- a/docker/test/integration/minifi/flow_serialization/Minifi_flow_json_serializer.py +++ b/docker/test/integration/minifi/flow_serialization/Minifi_flow_json_serializer.py @@ -106,6 +106,7 @@ class Minifi_flow_json_serializer: res_group['inputPorts'].append({ 'identifier': str(connectable.instance_id), 'name': connectable.name, + 'useCompression': connectable.use_compression, 'properties': connectable.properties }) diff --git a/docker/test/integration/minifi/flow_serialization/Minifi_flow_yaml_serializer.py b/docker/test/integration/minifi/flow_serialization/Minifi_flow_yaml_serializer.py index a2e1b6997..4c951a051 100644 --- a/docker/test/integration/minifi/flow_serialization/Minifi_flow_yaml_serializer.py +++ b/docker/test/integration/minifi/flow_serialization/Minifi_flow_yaml_serializer.py @@ -107,6 +107,7 @@ class Minifi_flow_yaml_serializer: res_group['Input Ports'].append({ 'id': str(connectable.instance_id), 'name': connectable.name, + 'use compression': connectable.use_compression, 'max concurrent tasks': 1, 'Properties': connectable.properties }) diff --git a/libminifi/include/sitetosite/CompressionConsts.h b/libminifi/include/sitetosite/CompressionConsts.h new file mode 100644 index 000000000..a962e80a6 --- /dev/null +++ b/libminifi/include/sitetosite/CompressionConsts.h @@ -0,0 +1,28 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include <array> +#include <cstddef> + +#include "io/InputStream.h" +#include "io/BufferStream.h" + +namespace org::apache::nifi::minifi::sitetosite { + +// Maximum number of uncompressed bytes carried by a single site-to-site compression chunk +inline constexpr size_t COMPRESSION_BUFFER_SIZE = 65536; +// Marker written before the size header of every compressed chunk +inline constexpr std::array<char, 4> SYNC_BYTES = { 'S', 'Y', 'N', 'C' }; + +} // namespace org::apache::nifi::minifi::sitetosite diff --git a/libminifi/include/sitetosite/CompressionInputStream.h b/libminifi/include/sitetosite/CompressionInputStream.h new file mode 100644 index 000000000..7985a6d58 --- /dev/null +++ b/libminifi/include/sitetosite/CompressionInputStream.h @@ -0,0 +1,53 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
+ */ +#pragma once + +#include "io/InputStream.h" +#include "io/BufferStream.h" +#include "CompressionConsts.h" +#include "core/logging/LoggerFactory.h" + +namespace org::apache::nifi::minifi::sitetosite { + +class CompressionInputStream : public io::InputStreamImpl { + public: + explicit CompressionInputStream(io::InputStream& internal_stream) + : internal_stream_(internal_stream) { + } + + using io::InputStream::read; + size_t read(std::span<std::byte> out_buffer) override; + void close() override; + void resetBuffer() { + buffer_offset_ = 0; + buffered_data_length_ = 0; + eof_ = false; + } + + private: + size_t decompressData(); + + bool eof_{false}; + io::InputStream& internal_stream_; + std::vector<std::byte> buffer_{COMPRESSION_BUFFER_SIZE}; + size_t buffer_offset_{0}; + size_t buffered_data_length_{0}; + std::shared_ptr<core::logging::Logger> logger_ = core::logging::LoggerFactory<CompressionInputStream>::getLogger(); +}; + +} // namespace org::apache::nifi::minifi::sitetosite diff --git a/libminifi/include/sitetosite/CompressionOutputStream.h b/libminifi/include/sitetosite/CompressionOutputStream.h new file mode 100644 index 000000000..225b83728 --- /dev/null +++ b/libminifi/include/sitetosite/CompressionOutputStream.h @@ -0,0 +1,48 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "io/OutputStream.h" +#include "io/BaseStream.h" +#include "CompressionConsts.h" +#include "core/logging/LoggerFactory.h" + +namespace org::apache::nifi::minifi::sitetosite { + +class CompressionOutputStream : public io::StreamImpl, public virtual io::OutputStreamImpl { + public: + explicit CompressionOutputStream(io::OutputStream& internal_stream) + : internal_stream_(internal_stream) { + } + + using io::OutputStream::write; + size_t write(const uint8_t *value, size_t len) override; + void close() override; + void flush(); + + private: + size_t compressAndWrite(); + + bool was_data_written_{false}; + size_t buffer_offset_{0}; + io::OutputStream& internal_stream_; + std::vector<std::byte> buffer_{COMPRESSION_BUFFER_SIZE}; + std::shared_ptr<core::logging::Logger> logger_ = core::logging::LoggerFactory<CompressionOutputStream>::getLogger(); +}; + +} // namespace org::apache::nifi::minifi::sitetosite diff --git a/libminifi/include/sitetosite/SiteToSiteClient.h b/libminifi/include/sitetosite/SiteToSiteClient.h index cfbb59a26..a947671c5 100644 --- a/libminifi/include/sitetosite/SiteToSiteClient.h +++ b/libminifi/include/sitetosite/SiteToSiteClient.h @@ -137,7 +137,7 @@ class SiteToSiteClient { virtual bool writeResponse(const std::shared_ptr<Transaction> &transaction, const SiteToSiteResponse& response); bool initializeSend(const std::shared_ptr<Transaction>& transaction); - bool writeAttributesInSendTransaction(const std::shared_ptr<Transaction>& transaction, const std::map<std::string, std::string>& attributes); + bool writeAttributesInSendTransaction(io::OutputStream& stream, const std::string& transaction_id_str, const std::map<std::string, std::string>& attributes); void finalizeSendTransaction(const std::shared_ptr<Transaction>& transaction, uint64_t sent_bytes); bool sendPacket(const DataPacket& packet); bool sendFlowFile(const 
std::shared_ptr<Transaction>& transaction, core::FlowFile& flow_file, core::ProcessSession& session); @@ -187,8 +187,8 @@ class SiteToSiteClient { bool completeReceive(const std::shared_ptr<Transaction>& transaction, const utils::Identifier& transaction_id); bool completeSend(const std::shared_ptr<Transaction>& transaction, const utils::Identifier& transaction_id, core::ProcessContext& context); - bool readFlowFileHeaderData(const std::shared_ptr<Transaction>& transaction, SiteToSiteClient::ReceiveFlowFileHeaderResult& result); - std::optional<ReceiveFlowFileHeaderResult> receiveFlowFileHeader(const std::shared_ptr<Transaction>& transaction); + bool readFlowFileHeaderData(io::InputStream& stream, const std::string& transaction_id, SiteToSiteClient::ReceiveFlowFileHeaderResult& result); + std::optional<ReceiveFlowFileHeaderResult> receiveFlowFileHeader(io::InputStream& stream, const std::shared_ptr<Transaction>& transaction); std::pair<uint64_t, uint64_t> readFlowFiles(const std::shared_ptr<Transaction>& transaction, core::ProcessSession& session); std::shared_ptr<core::logging::Logger> logger_{core::logging::LoggerFactory<SiteToSiteClient>::getLogger()}; diff --git a/libminifi/src/sitetosite/CompressionInputStream.cpp b/libminifi/src/sitetosite/CompressionInputStream.cpp new file mode 100644 index 000000000..5bc32f1c4 --- /dev/null +++ b/libminifi/src/sitetosite/CompressionInputStream.cpp @@ -0,0 +1,154 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sitetosite/CompressionInputStream.h" + +#include <algorithm> +#include "io/ZlibStream.h" + +namespace org::apache::nifi::minifi::sitetosite { + +size_t CompressionInputStream::decompressData() { + if (eof_) { + return 0; + } + + std::vector<std::byte> local_buffer(COMPRESSION_BUFFER_SIZE); + auto ret = internal_stream_.read(std::span(local_buffer).subspan(0, SYNC_BYTES.size())); + if (ret != SYNC_BYTES.size() || + !std::equal(SYNC_BYTES.begin(), SYNC_BYTES.end(), local_buffer.begin(), [](char sync_char, std::byte read_byte) { return static_cast<std::byte>(sync_char) == read_byte;})) { + logger_->log_error("Failed to read sync bytes or sync bytes do not match"); + return io::STREAM_ERROR; + } + + uint32_t original_size = 0; + ret = internal_stream_.read(original_size); + if (io::isError(ret) || ret != 4) { + logger_->log_error("Failed to read original size, ret: {}", ret); + return io::STREAM_ERROR; + } + + uint32_t compressed_size = 0; + ret = internal_stream_.read(compressed_size); + if (io::isError(ret) || ret != 4) { + logger_->log_error("Failed to read compressed size, ret: {}", ret); + return io::STREAM_ERROR; + } + + if (compressed_size == 0 && original_size != 0) { + logger_->log_error("Compressed size is 0 but original size is not"); + return io::STREAM_ERROR; + } + + if (compressed_size > COMPRESSION_BUFFER_SIZE) { + logger_->log_error("Compressed size exceeds buffer size"); + return io::STREAM_ERROR; + } + + if (original_size > COMPRESSION_BUFFER_SIZE) { + logger_->log_error("Original size exceeds buffer size"); + return 
io::STREAM_ERROR; + } + + ret = internal_stream_.read(std::span(local_buffer).subspan(0, compressed_size)); + if (io::isError(ret) || ret != compressed_size) { + logger_->log_error("Failed to read compressed data, ret: {}", ret); + return io::STREAM_ERROR; + } + + if (compressed_size != 0) { + io::BufferStream decompressed_data_stream; + io::ZlibDecompressStream zlib_stream{gsl::make_not_null(&decompressed_data_stream), io::ZlibCompressionFormat::ZLIB}; + ret = zlib_stream.write(std::span(local_buffer).subspan(0, compressed_size)); + if (io::isError(ret)) { + logger_->log_error("Failed to write compressed data to zlib stream, ret: {}", ret); + return ret; + } + zlib_stream.close(); + gsl_Assert(zlib_stream.isFinished()); + + ret = decompressed_data_stream.read(std::span(buffer_).subspan(0, original_size)); + if (io::isError(ret) || ret != original_size) { + logger_->log_error("Failed to read decompressed data, ret: {}", ret); + return io::STREAM_ERROR; + } + } + + uint8_t end_byte = 0; + ret = internal_stream_.read(end_byte); + if (io::isError(ret) || ret != 1) { + logger_->log_error("Failed to read end byte, ret: {}", ret); + return io::STREAM_ERROR; + } + + // If end_byte is 0, it indicates EOF, if it is 1, it indicates more data will follow + if (end_byte == 0) { + eof_ = true; + } else if (end_byte != 1) { + logger_->log_error("End byte is not 0 or 1, received: {}", end_byte); + return io::STREAM_ERROR; + } + + buffered_data_length_ = original_size; + buffer_offset_ = 0; + return original_size; +} + +size_t CompressionInputStream::read(std::span<std::byte> out_buffer) { + if (eof_ && buffered_data_length_ == 0) { + return 0; + } + + std::span<std::byte> remaining_output = out_buffer; + size_t total_bytes_read = 0; + + while (!remaining_output.empty()) { + if (buffered_data_length_ == 0 || buffered_data_length_ == buffer_offset_) { + auto ret = decompressData(); + if (io::isError(ret)) { + return io::STREAM_ERROR; + } + } + + const auto bytes_available = 
buffered_data_length_ - buffer_offset_; + if (bytes_available == 0) { + break; + } + + const auto bytes_to_copy = std::min(bytes_available, remaining_output.size()); + const auto source_data = std::span<const std::byte>(buffer_).subspan(buffer_offset_, bytes_to_copy); + + std::ranges::copy(source_data, remaining_output.begin()); + + buffer_offset_ += bytes_to_copy; + total_bytes_read += bytes_to_copy; + remaining_output = remaining_output.subspan(bytes_to_copy); + + if (buffer_offset_ == buffered_data_length_) { + buffer_offset_ = 0; + buffered_data_length_ = 0; + } + } + + return total_bytes_read; +} + +void CompressionInputStream::close() { + internal_stream_.close(); +} + +} // namespace org::apache::nifi::minifi::sitetosite diff --git a/libminifi/src/sitetosite/CompressionOutputStream.cpp b/libminifi/src/sitetosite/CompressionOutputStream.cpp new file mode 100644 index 000000000..31e1e4388 --- /dev/null +++ b/libminifi/src/sitetosite/CompressionOutputStream.cpp @@ -0,0 +1,139 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "sitetosite/CompressionOutputStream.h" + +#include <algorithm> + +#include "io/ZlibStream.h" +#include "io/StreamPipe.h" +#include "io/BufferStream.h" +#include "core/logging/LoggerFactory.h" + +namespace org::apache::nifi::minifi::sitetosite { + +size_t CompressionOutputStream::write(const uint8_t *value, size_t len) { + if (value == nullptr || len == 0) { + return 0; + } + + const std::span<const std::byte> input_data{reinterpret_cast<const std::byte*>(value), len}; + auto remaining_data = input_data; + size_t total_bytes_written = 0; + + while (!remaining_data.empty()) { + const auto free_spaces_left_in_buffer = buffer_.size() - buffer_offset_; + const auto bytes_to_write = std::min(remaining_data.size(), free_spaces_left_in_buffer); + + const auto chunk_to_write = remaining_data.subspan(0, bytes_to_write); + const auto buffer_destination = std::span<std::byte>(buffer_).subspan(buffer_offset_, bytes_to_write); + + std::ranges::copy(chunk_to_write, buffer_destination.begin()); + + total_bytes_written += bytes_to_write; + remaining_data = remaining_data.subspan(bytes_to_write); + buffer_offset_ += bytes_to_write; + + gsl_Assert(buffer_offset_ <= buffer_.size()); + + if (buffer_offset_ == buffer_.size()) { + auto ret = compressAndWrite(); + if (io::isError(ret)) { + return ret; + } + } + } + + return total_bytes_written; +} + +size_t CompressionOutputStream::compressAndWrite() { + if (was_data_written_) { + // Write a continue byte to indicate that there is more data to follow + auto ret = internal_stream_.write(static_cast<uint8_t>(1)); + if (io::isError(ret)) { + logger_->log_error("Failed to write continue byte before compression: {}", ret); + return ret; + } + } + auto ret = internal_stream_.write(reinterpret_cast<const uint8_t *>(SYNC_BYTES.data()), SYNC_BYTES.size()); + if (io::isError(ret)) { + logger_->log_error("Failed to write sync bytes before compression: {}", ret); + return ret; + } + + io::BufferStream buffer_stream; + { + 
io::ZlibCompressStream zlib_stream{gsl::make_not_null(&buffer_stream), io::ZlibCompressionFormat::ZLIB, Z_BEST_SPEED}; + ret = zlib_stream.write(gsl::make_span(buffer_).subspan(0, buffer_offset_)); + if (io::isError(ret)) { + logger_->log_error("Failed to write data to zlib stream: {}", ret); + return ret; + } + gsl_Assert(buffer_offset_ == ret); + zlib_stream.close(); + gsl_Assert(zlib_stream.isFinished()); + } + + // Write the original size of the data before compression + ret = internal_stream_.write(gsl::narrow<uint32_t>(buffer_offset_)); + if (io::isError(ret)) { + logger_->log_error("Failed to write original size before compression: {}", ret); + return ret; + } + + // Write the compressed size of the data + ret = internal_stream_.write(gsl::narrow<uint32_t>(buffer_stream.size())); + if (io::isError(ret)) { + return ret; + } + + // Write the compressed data + ret = internal::pipe(buffer_stream, internal_stream_); + if (io::isError(ret)) { + return ret; + } + + buffer_offset_ = 0; + was_data_written_ = true; + return ret; +} + +void CompressionOutputStream::flush() { + if (buffer_offset_ > 0) { + auto ret = compressAndWrite(); + if (io::isError(ret)) { + logger_->log_error("Flush failed when compressing data: {}", ret); + return; + } + } + if (was_data_written_) { + was_data_written_ = false; + auto ret = internal_stream_.write(static_cast<uint8_t>(0)); + if (io::isError(ret)) { + logger_->log_error("Flush failed when writing on internal stream: {}", ret); + return; + } + } +} + +void CompressionOutputStream::close() { + flush(); + internal_stream_.close(); +} + +} // namespace org::apache::nifi::minifi::sitetosite diff --git a/libminifi/src/sitetosite/HttpSiteToSiteClient.cpp b/libminifi/src/sitetosite/HttpSiteToSiteClient.cpp index 8cd99377f..b8fcbaf86 100644 --- a/libminifi/src/sitetosite/HttpSiteToSiteClient.cpp +++ b/libminifi/src/sitetosite/HttpSiteToSiteClient.cpp @@ -149,7 +149,7 @@ std::shared_ptr<Transaction> 
HttpSiteToSiteClient::createTransaction(TransferDir setSiteToSiteHeaders(*transaction_client); peer_->setStream(std::make_unique<http::HttpStream>(transaction_client)); - logger_->log_debug("Created transaction id -{}-", transaction->getUUID().to_string()); + logger_->log_debug("Created transaction id -{}-", transaction->getUUIDStr()); known_transactions_[transaction->getUUID()] = transaction; return transaction; } @@ -375,8 +375,7 @@ void HttpSiteToSiteClient::deleteTransaction(const utils::Identifier& transactio void HttpSiteToSiteClient::setSiteToSiteHeaders(minifi::http::HTTPClient& client) { client.setRequestHeader(PROTOCOL_VERSION_HEADER, "1"); - // TODO(lordgamez): send use_compression_ boolean value when compression support is added - client.setRequestHeader(HANDSHAKE_PROPERTY_USE_COMPRESSION, "false"); + client.setRequestHeader(HANDSHAKE_PROPERTY_USE_COMPRESSION, use_compression_ ? "true" : "false"); if (timeout_.load() > 0ms) { client.setRequestHeader(HANDSHAKE_PROPERTY_REQUEST_EXPIRATION, std::to_string(timeout_.load().count())); } diff --git a/libminifi/src/sitetosite/RawSiteToSiteClient.cpp b/libminifi/src/sitetosite/RawSiteToSiteClient.cpp index 41ed72cdc..1f87b26dc 100644 --- a/libminifi/src/sitetosite/RawSiteToSiteClient.cpp +++ b/libminifi/src/sitetosite/RawSiteToSiteClient.cpp @@ -148,8 +148,7 @@ bool RawSiteToSiteClient::handShake() { } std::map<std::string, std::string> properties; - // TODO(lordgamez): send use_compression_ boolean value when compression support is added - properties[std::string(magic_enum::enum_name(HandshakeProperty::GZIP))] = "false"; + properties[std::string(magic_enum::enum_name(HandshakeProperty::GZIP))] = use_compression_ ? 
"true" : "false"; properties[std::string(magic_enum::enum_name(HandshakeProperty::PORT_IDENTIFIER))] = port_id_.to_string(); properties[std::string(magic_enum::enum_name(HandshakeProperty::REQUEST_EXPIRATION_MILLIS))] = std::to_string(timeout_.load().count()); if (current_version_ >= 5) { diff --git a/libminifi/src/sitetosite/SiteToSiteClient.cpp b/libminifi/src/sitetosite/SiteToSiteClient.cpp index eb68b9636..79bc0221b 100644 --- a/libminifi/src/sitetosite/SiteToSiteClient.cpp +++ b/libminifi/src/sitetosite/SiteToSiteClient.cpp @@ -24,6 +24,8 @@ #include "utils/gsl.h" #include "utils/Enum.h" +#include "sitetosite/CompressionOutputStream.h" +#include "sitetosite/CompressionInputStream.h" namespace org::apache::nifi::minifi::sitetosite { @@ -93,12 +95,14 @@ void SiteToSiteClient::deleteTransaction(const utils::Identifier& transaction_id bool SiteToSiteClient::writeResponse(const std::shared_ptr<Transaction>& /*transaction*/, const SiteToSiteResponse& response) { const ResponseCodeContext* response_code_context = getResponseCodeContext(response.code); if (!response_code_context) { + logger_->log_error("Site2Site write response failed: invalid response code context for code {}", magic_enum::enum_underlying(response.code)); return false; } const std::array<uint8_t, 3> code_sequence = { CODE_SEQUENCE_VALUE_1, CODE_SEQUENCE_VALUE_2, magic_enum::enum_underlying(response.code) }; const auto ret = peer_->write(code_sequence.data(), 3); if (ret != 3) { + logger_->log_error("Site2Site write response failed: failed to write code sequence, expected 3 bytes, got {}", ret); return false; } @@ -420,23 +424,22 @@ bool SiteToSiteClient::initializeSend(const std::shared_ptr<Transaction>& transa return true; } -bool SiteToSiteClient::writeAttributesInSendTransaction(const std::shared_ptr<Transaction>& transaction, const std::map<std::string, std::string>& attributes) { - auto transaction_id = transaction->getUUID(); - if (const auto ret = 
transaction->getStream().write(gsl::narrow<uint32_t>(attributes.size())); ret != 4) { +bool SiteToSiteClient::writeAttributesInSendTransaction(io::OutputStream& stream, const std::string& transaction_id_str, const std::map<std::string, std::string>& attributes) { + if (const auto ret = stream.write(gsl::narrow<uint32_t>(attributes.size())); ret != 4) { logger_->log_error("Failed to write number of attributes!"); return false; } return std::ranges::all_of(attributes, [&](const auto& attribute) { - if (const auto ret = transaction->getStream().write(attribute.first, true); ret == 0 || io::isError(ret)) { + if (const auto ret = stream.write(attribute.first, true); ret == 0 || io::isError(ret)) { logger_->log_error("Failed to write attribute key {}!", attribute.first); return false; } - if (const auto ret = transaction->getStream().write(attribute.second, true); ret == 0 || io::isError(ret)) { + if (const auto ret = stream.write(attribute.second, true); ret == 0 || io::isError(ret)) { logger_->log_error("Failed to write attribute value {}!", attribute.second); return false; } - logger_->log_debug("Site2Site transaction {} send attribute key {} value {}", transaction_id.to_string(), attribute.first, attribute.second); + logger_->log_debug("Site2Site transaction {} send attribute key {} value {}", transaction_id_str, attribute.first, attribute.second); return true; }); } @@ -447,7 +450,7 @@ void SiteToSiteClient::finalizeSendTransaction(const std::shared_ptr<Transaction transaction->setState(TransactionState::DATA_EXCHANGED); transaction->addBytes(sent_bytes); - logger_->log_info("Site to Site transaction {} sent flow {} flow records, with total size {}", transaction->getUUID().to_string(), transaction->getTotalTransfers(), transaction->getBytes()); + logger_->log_info("Site to Site transaction {} sent flow {} flow records, with total size {}", transaction->getUUIDStr(), transaction->getTotalTransfers(), transaction->getBytes()); } bool 
SiteToSiteClient::sendFlowFile(const std::shared_ptr<Transaction>& transaction, core::FlowFile& flow_file, core::ProcessSession& session) { @@ -455,8 +458,16 @@ bool SiteToSiteClient::sendFlowFile(const std::shared_ptr<Transaction>& transact return false; } + std::unique_ptr<CompressionOutputStream> compression_stream; + std::unique_ptr<io::CRCStream<io::OutputStream>> compression_wrapper_crc_stream; + if (use_compression_) { + compression_stream = std::make_unique<CompressionOutputStream>(transaction->getStream()); + compression_wrapper_crc_stream = std::make_unique<io::CRCStream<io::OutputStream>>(gsl::make_not_null(compression_stream.get())); + } + io::OutputStream& stream = use_compression_ ? static_cast<io::OutputStream&>(*compression_wrapper_crc_stream) : static_cast<io::OutputStream&>(transaction->getStream()); + auto attributes = flow_file.getAttributes(); - if (!writeAttributesInSendTransaction(transaction, attributes)) { + if (!writeAttributesInSendTransaction(stream, transaction->getUUIDStr(), attributes)) { return false; } @@ -472,14 +483,14 @@ bool SiteToSiteClient::sendFlowFile(const std::shared_ptr<Transaction>& transact uint64_t len = 0; if (flowfile_has_content) { len = flow_file.getSize(); - const auto ret = transaction->getStream().write(len); + const auto ret = stream.write(len); if (ret != 8) { logger_->log_debug("Failed to write content size!"); return false; } if (flow_file.getSize() > 0) { - auto read_result = session.read(flow_file, [&transaction](const std::shared_ptr<io::InputStream>& input_stream) -> int64_t { - return internal::pipe(*input_stream, transaction->getStream()); + auto read_result = session.read(flow_file, [&stream](const std::shared_ptr<io::InputStream>& input_stream) -> int64_t { + return internal::pipe(*input_stream, stream); }); if (flow_file.getSize() != gsl::narrow<uint64_t>(read_result)) { logger_->log_debug("Mismatched sizes {} {}", flow_file.getSize(), read_result); @@ -489,13 +500,19 @@ bool 
SiteToSiteClient::sendFlowFile(const std::shared_ptr<Transaction>& transact logger_->log_trace("Flowfile empty {}", flow_file.getResourceClaim()->getContentFullPath()); } } else { - const auto ret = transaction->getStream().write(len); // Indicate zero length + const auto ret = stream.write(len); // Indicate zero length if (ret != 8) { logger_->log_debug("Failed to write content size (0)!"); return false; } } + if (compression_stream) { + // Update the CRC value to use the uncompressed stream CRC + compression_stream->flush(); + transaction->getStream().setCrc(compression_wrapper_crc_stream->getCRC()); + } + finalizeSendTransaction(transaction, len); return true; } @@ -504,45 +521,61 @@ bool SiteToSiteClient::sendPacket(const DataPacket& packet) { if (!initializeSend(packet.transaction)) { return false; } - auto transaction = packet.transaction; - if (!writeAttributesInSendTransaction(transaction, packet.attributes)) { + + std::unique_ptr<CompressionOutputStream> compression_stream; + std::unique_ptr<io::CRCStream<io::OutputStream>> compression_wrapper_crc_stream; + if (use_compression_) { + compression_stream = std::make_unique<CompressionOutputStream>(packet.transaction->getStream()); + compression_wrapper_crc_stream = std::make_unique<io::CRCStream<io::OutputStream>>(gsl::make_not_null(compression_stream.get())); + } + io::OutputStream& stream = use_compression_ ? 
static_cast<io::OutputStream&>(*compression_wrapper_crc_stream) : static_cast<io::OutputStream&>(transaction->getStream()); + + if (!writeAttributesInSendTransaction(stream, transaction->getUUIDStr(), packet.attributes)) { return false; } uint64_t len = 0; if (!packet.payload.empty()) { len = packet.payload.length(); - if (const auto ret = transaction->getStream().write(len); ret != 8) { + if (const auto ret = stream.write(len); ret != 8) { logger_->log_debug("Failed to write payload size!"); return false; } - if (const auto ret = transaction->getStream().write(reinterpret_cast<const uint8_t*>(packet.payload.c_str()), gsl::narrow<size_t>(len)); ret != gsl::narrow<size_t>(len)) { + if (const auto ret = stream.write(reinterpret_cast<const uint8_t*>(packet.payload.c_str()), gsl::narrow<size_t>(len)); ret != gsl::narrow<size_t>(len)) { logger_->log_debug("Failed to write payload!"); return false; } } + if (compression_stream) { + // Update the CRC value to use the uncompressed stream CRC + compression_stream->flush(); + transaction->getStream().setCrc(compression_wrapper_crc_stream->getCRC()); + } + finalizeSendTransaction(transaction, len); return true; } -bool SiteToSiteClient::readFlowFileHeaderData(const std::shared_ptr<Transaction>& transaction, SiteToSiteClient::ReceiveFlowFileHeaderResult& result) { +bool SiteToSiteClient::readFlowFileHeaderData(io::InputStream& stream, const std::string& transaction_id_str, SiteToSiteClient::ReceiveFlowFileHeaderResult& result) { uint32_t num_attributes = 0; - if (const auto ret = transaction->getStream().read(num_attributes); ret == 0 || io::isError(ret) || num_attributes > MAX_NUM_ATTRIBUTES) { + if (const auto ret = stream.read(num_attributes); ret == 0 || io::isError(ret) || num_attributes > MAX_NUM_ATTRIBUTES) { + logger_->log_error("Site2Site failed to read number of attributes with return code {}, or number of attributes is invalid: {}", ret, num_attributes); return false; } - const auto transaction_id_str = 
transaction->getUUID().to_string(); logger_->log_debug("Site2Site transaction {} receives {} attributes", transaction_id_str, num_attributes); for (uint64_t i = 0; i < num_attributes; i++) { std::string key; std::string value; - if (const auto ret = transaction->getStream().read(key, true); ret == 0 || io::isError(ret)) { + if (const auto ret = stream.read(key, true); ret == 0 || io::isError(ret)) { + logger_->log_error("Site2Site transaction {} failed to read attribute key", transaction_id_str); return false; } - if (const auto ret = transaction->getStream().read(value, true); ret == 0 || io::isError(ret)) { + if (const auto ret = stream.read(value, true); ret == 0 || io::isError(ret)) { + logger_->log_error("Site2Site transaction {} failed to read attribute value for key {}", transaction_id_str, key); return false; } @@ -551,7 +584,8 @@ bool SiteToSiteClient::readFlowFileHeaderData(const std::shared_ptr<Transaction> } uint64_t len = 0; - if (const auto ret = transaction->getStream().read(len); ret == 0 || io::isError(ret)) { + if (const auto ret = stream.read(len); ret == 0 || io::isError(ret)) { + logger_->log_error("Site2Site transaction {} failed to read flow file data size", transaction_id_str); return false; } @@ -559,7 +593,7 @@ bool SiteToSiteClient::readFlowFileHeaderData(const std::shared_ptr<Transaction> return true; } -std::optional<SiteToSiteClient::ReceiveFlowFileHeaderResult> SiteToSiteClient::receiveFlowFileHeader(const std::shared_ptr<Transaction>& transaction) { +std::optional<SiteToSiteClient::ReceiveFlowFileHeaderResult> SiteToSiteClient::receiveFlowFileHeader(io::InputStream& stream, const std::shared_ptr<Transaction>& transaction) { if (peer_state_ != PeerState::READY) { bootstrap(); } @@ -568,7 +602,7 @@ std::optional<SiteToSiteClient::ReceiveFlowFileHeaderResult> SiteToSiteClient::r return std::nullopt; } - const auto transaction_id_str = transaction->getUUID().to_string(); + const auto transaction_id_str = transaction->getUUIDStr(); if 
(transaction->getState() != TransactionState::TRANSACTION_STARTED && transaction->getState() != TransactionState::DATA_EXCHANGED) { logger_->log_warn("Site2Site transaction {} is not at started or exchanged state", transaction_id_str); return std::nullopt; @@ -586,7 +620,7 @@ std::optional<SiteToSiteClient::ReceiveFlowFileHeaderResult> SiteToSiteClient::r } if (transaction->getCurrentTransfers() > 0) { - // if we already has transfer before, check to see whether another one is available + // if we already have transferred a flow file before, check to see whether another one is available auto response = readResponse(transaction); if (!response) { return std::nullopt; @@ -611,7 +645,7 @@ std::optional<SiteToSiteClient::ReceiveFlowFileHeaderResult> SiteToSiteClient::r return result; } - if (!readFlowFileHeaderData(transaction, result)) { + if (!readFlowFileHeaderData(stream, transaction_id_str, result)) { logger_->log_error("Site2Site transaction {} failed to read flow file header data", transaction_id_str); return std::nullopt; } @@ -637,12 +671,21 @@ std::optional<SiteToSiteClient::ReceiveFlowFileHeaderResult> SiteToSiteClient::r std::pair<uint64_t, uint64_t> SiteToSiteClient::readFlowFiles(const std::shared_ptr<Transaction>& transaction, core::ProcessSession& session) { uint64_t transfers = 0; uint64_t bytes = 0; + + std::unique_ptr<CompressionInputStream> compression_stream; + std::unique_ptr<io::CRCStream<io::InputStream>> compression_wrapper_crc_stream; + if (use_compression_) { + compression_stream = std::make_unique<CompressionInputStream>(transaction->getStream()); + compression_wrapper_crc_stream = std::make_unique<io::CRCStream<io::InputStream>>(gsl::make_not_null(compression_stream.get())); + } + io::InputStream& stream = use_compression_ ? 
static_cast<io::InputStream&>(*compression_wrapper_crc_stream) : static_cast<io::InputStream&>(transaction->getStream()); + while (true) { auto start_time = std::chrono::steady_clock::now(); - auto receive_header_result = receiveFlowFileHeader(transaction); + auto receive_header_result = receiveFlowFileHeader(stream, transaction); if (!receive_header_result) { - throw Exception(SITE2SITE_EXCEPTION, "Receive Failed " + transaction->getUUID().to_string()); + throw Exception(SITE2SITE_EXCEPTION, "Receive Failed " + transaction->getUUIDStr()); } if (receive_header_result->eof) { @@ -664,12 +707,12 @@ std::pair<uint64_t, uint64_t> SiteToSiteClient::readFlowFiles(const std::shared_ } if (receive_header_result->flow_file_data_size > 0) { - session.write(flow_file, [&receive_header_result, &transaction](const std::shared_ptr<io::OutputStream>& output_stream) -> int64_t { + session.write(flow_file, [&receive_header_result, &stream](const std::shared_ptr<io::OutputStream>& output_stream) -> int64_t { uint64_t len = receive_header_result->flow_file_data_size; std::array<std::byte, utils::configuration::DEFAULT_BUFFER_SIZE> buffer{}; while (len > 0) { const auto size = std::min(len, uint64_t{utils::configuration::DEFAULT_BUFFER_SIZE}); - const auto ret = transaction->getStream().read(std::as_writable_bytes(std::span(buffer).subspan(0, size))); + const auto ret = stream.read(std::as_writable_bytes(std::span(buffer).subspan(0, size))); if (ret != size) { return -1; } @@ -690,9 +733,19 @@ std::pair<uint64_t, uint64_t> SiteToSiteClient::readFlowFiles(const std::shared_ std::string details = "urn:nifi:" + source_identifier + "Remote Host=" + peer_->getHostName(); session.getProvenanceReporter()->receive(*flow_file, transitUri, source_identifier, details, std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time)); session.transfer(flow_file, relation); - // receive the transfer for the flow record + bytes += receive_header_result->flow_file_data_size; transfers++; 
+ + if (compression_stream) { + // Non-compressed response codes are written between flow files, so we need to reset the compression stream buffer for the next flow file + compression_stream->resetBuffer(); + } + } + + if (compression_stream) { + // Update the CRC value to use the uncompressed stream CRC + transaction->getStream().setCrc(compression_wrapper_crc_stream->getCRC()); } return {transfers, bytes}; diff --git a/libminifi/test/unit/SiteToSiteCompressionStreamTests.cpp b/libminifi/test/unit/SiteToSiteCompressionStreamTests.cpp new file mode 100644 index 000000000..cb269068b --- /dev/null +++ b/libminifi/test/unit/SiteToSiteCompressionStreamTests.cpp @@ -0,0 +1,215 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "unit/Catch.h" +#include "unit/TestBase.h" +#include "sitetosite/CompressionOutputStream.h" +#include "sitetosite/CompressionInputStream.h" +#include "io/BufferStream.h" +#include "io/ZlibStream.h" + +namespace org::apache::nifi::minifi::test { + +void verifySyncBytes(io::BufferStream& buffer_stream) { + std::vector<std::byte> data_buffer; + data_buffer.resize(4); + buffer_stream.read(std::span(data_buffer)); + REQUIRE(std::to_integer<char>(data_buffer[0]) == sitetosite::SYNC_BYTES[0]); + REQUIRE(std::to_integer<char>(data_buffer[1]) == sitetosite::SYNC_BYTES[1]); + REQUIRE(std::to_integer<char>(data_buffer[2]) == sitetosite::SYNC_BYTES[2]); + REQUIRE(std::to_integer<char>(data_buffer[3]) == sitetosite::SYNC_BYTES[3]); +} + +void verifyOriginalSize(io::BufferStream& buffer_stream, uint32_t expected_size) { + uint32_t original_size = 0; + buffer_stream.read(original_size); + REQUIRE(original_size == expected_size); +} + +void verifyCompressedData(io::BufferStream& compressed_buffer_stream, uint32_t expected_size) { + uint32_t compressed_size = 0; + compressed_buffer_stream.read(compressed_size); + + std::vector<std::byte> compressed_data_buffer; + compressed_data_buffer.resize(compressed_size); + compressed_buffer_stream.read(std::span(compressed_data_buffer)); + + io::BufferStream decompressed_data_stream; + io::ZlibDecompressStream decompressor(gsl::make_not_null(&decompressed_data_stream), io::ZlibCompressionFormat::ZLIB); + decompressor.write(compressed_data_buffer); + REQUIRE(decompressor.isFinished()); + + std::vector<std::byte> decompressed_data_buffer; + decompressed_data_buffer.resize(expected_size); + decompressed_data_stream.read(std::span(decompressed_data_buffer)); + + for (size_t i = 0; i < decompressed_data_buffer.size(); i += 4) { + uint32_t value = (static_cast<uint32_t>(std::to_integer<uint8_t>(decompressed_data_buffer[i])) << 24) + | (static_cast<uint32_t>(std::to_integer<uint8_t>(decompressed_data_buffer[i + 1])) << 16) + | 
(static_cast<uint32_t>(std::to_integer<uint8_t>(decompressed_data_buffer[i + 2])) << 8) + | static_cast<uint32_t>(std::to_integer<uint8_t>(decompressed_data_buffer[i + 3])); + + REQUIRE(value == 42); +} +} + +void verifyContinueByte(io::BufferStream& buffer_stream) { + uint8_t closing_byte = 2; + buffer_stream.read(closing_byte); + REQUIRE(closing_byte == 1); +} + +void verifyClosingByte(io::BufferStream& buffer_stream) { + uint8_t closing_byte = 2; + buffer_stream.read(closing_byte); + REQUIRE(closing_byte == 0); +} + +void verifyCompressedChunks(io::BufferStream& buffer_stream, uint32_t expected_size) { + bool first_chunk = true; + uint32_t size_processed = 0; + while (size_processed < expected_size) { + uint32_t current_size_to_read = 0; + if (expected_size - size_processed > sitetosite::COMPRESSION_BUFFER_SIZE) { + current_size_to_read = sitetosite::COMPRESSION_BUFFER_SIZE; + } else { + current_size_to_read = expected_size - size_processed; + } + if (first_chunk) { + first_chunk = false; + } else { + verifyContinueByte(buffer_stream); + } + verifySyncBytes(buffer_stream); + verifyOriginalSize(buffer_stream, current_size_to_read); + verifyCompressedData(buffer_stream, current_size_to_read); + size_processed += current_size_to_read; + } + verifyClosingByte(buffer_stream); +} + +TEST_CASE("Write empty output stream", "[CompressionOutputStream]") { + io::BufferStream buffer_stream; + sitetosite::CompressionOutputStream output_stream(buffer_stream); + output_stream.close(); + REQUIRE(buffer_stream.size() == 0); +} + +TEST_CASE("Write a 4 byte integer and flush", "[CompressionOutputStream]") { + io::BufferStream buffer_stream; + sitetosite::CompressionOutputStream output_stream(buffer_stream); + CHECK(output_stream.write(static_cast<uint32_t>(42)) == 4); + output_stream.flush(); + verifyCompressedChunks(buffer_stream, 4); +} + +TEST_CASE("Write a single chunk of compressed data and flush on close", "[CompressionOutputStream]") { + io::BufferStream buffer_stream; + 
sitetosite::CompressionOutputStream output_stream(buffer_stream); + for (size_t i = 0; i < 10000; ++i) { + CHECK(output_stream.write(static_cast<uint32_t>(42)) == 4); + } + REQUIRE(buffer_stream.size() == 0); + output_stream.close(); + REQUIRE(buffer_stream.size() > 0); + + verifyCompressedChunks(buffer_stream, 40000); +} + +TEST_CASE("Write 2 chunks of compressed data and flush on demand", "[CompressionOutputStream]") { + io::BufferStream buffer_stream; + sitetosite::CompressionOutputStream output_stream(buffer_stream); + for (size_t i = 0; i < 10000; ++i) { + CHECK(output_stream.write(static_cast<uint32_t>(42)) == 4); + } + REQUIRE(buffer_stream.size() == 0); + for (size_t i = 0; i < 10000; ++i) { + CHECK(output_stream.write(static_cast<uint32_t>(42)) == 4); + } + + // Automatically compress data when buffer is full + REQUIRE(buffer_stream.size() > 0); + output_stream.close(); + + verifyCompressedChunks(buffer_stream, 80000); +} + +TEST_CASE("Write 3 chunks of compressed data and flush on demand", "[CompressionOutputStream]") { + io::BufferStream buffer_stream; + sitetosite::CompressionOutputStream output_stream(buffer_stream); + for (size_t i = 0; i < 10000; ++i) { + CHECK(output_stream.write(static_cast<uint32_t>(42)) == 4); + } + REQUIRE(buffer_stream.size() == 0); + for (size_t i = 0; i < 10000; ++i) { + CHECK(output_stream.write(static_cast<uint32_t>(42)) == 4); + } + + // Automatically compress data when buffer is full + REQUIRE(buffer_stream.size() > 0); + for (size_t i = 0; i < 20000; ++i) { + CHECK(output_stream.write(static_cast<uint32_t>(42)) == 4); + } + output_stream.close(); + + verifyCompressedChunks(buffer_stream, 160000); +} + +TEST_CASE("Read single 4 byte integer compressed", "[CompressionOutputStream]") { + io::BufferStream buffer_stream; + sitetosite::CompressionOutputStream output_stream(buffer_stream); + CHECK(output_stream.write(static_cast<uint32_t>(42)) == 4); + output_stream.flush(); + sitetosite::CompressionInputStream 
input_stream(buffer_stream); + uint32_t read_byte{}; + CHECK(input_stream.read(read_byte) == 4); + CHECK(read_byte == 42); +} + +TEST_CASE("Read large number of bytes compressed", "[CompressionOutputStream]") { + io::BufferStream buffer_stream; + sitetosite::CompressionOutputStream output_stream(buffer_stream); + for (size_t i = 0; i < 10000; ++i) { + CHECK(output_stream.write(static_cast<uint32_t>(42)) == 4); + } + output_stream.flush(); + sitetosite::CompressionInputStream input_stream(buffer_stream); + for (size_t i = 0; i < 10000; ++i) { + uint32_t read_byte{}; + CHECK(input_stream.read(read_byte) == 4); + CHECK(read_byte == 42); + } +} + +TEST_CASE("Read large number of bytes that uses multiple buffers", "[CompressionOutputStream]") { + io::BufferStream buffer_stream; + sitetosite::CompressionOutputStream output_stream(buffer_stream); + uint32_t count = 0; + while (buffer_stream.size() + 100 < sitetosite::COMPRESSION_BUFFER_SIZE) { + ++count; + CHECK(output_stream.write(count) == 4); + } + output_stream.flush(); + + sitetosite::CompressionInputStream input_stream(buffer_stream); + for (size_t i = 1; i <= count; ++i) { + uint32_t read_byte{}; + CHECK(input_stream.read(read_byte) == 4); + CHECK(read_byte == i); + } +} + +} // namespace org::apache::nifi::minifi::test diff --git a/libminifi/test/unit/SiteToSiteTests.cpp b/libminifi/test/unit/SiteToSiteTests.cpp index e9b6c948d..7480e2504 100644 --- a/libminifi/test/unit/SiteToSiteTests.cpp +++ b/libminifi/test/unit/SiteToSiteTests.cpp @@ -30,6 +30,8 @@ #include "unit/SiteToSiteHelper.h" #include "unit/DummyProcessor.h" #include "unit/ProvenanceTestHelper.h" +#include "catch2/generators/catch_generators.hpp" +#include "io/ZlibStream.h" namespace org::apache::nifi::minifi::test { @@ -56,6 +58,11 @@ class SiteToSiteClientTestAccessor { } }; +void initializeLogging() { + LogTestController::getInstance().setTrace<sitetosite::RawSiteToSiteClient>(); + 
LogTestController::getInstance().setTrace<sitetosite::SiteToSitePeer>(); +} + void initializeMockBootstrapResponses(SiteToSiteResponder& collector) { const char resource_ok_code = magic_enum::enum_underlying(sitetosite::ResourceNegotiationStatusCode::RESOURCE_OK); std::string resp_code; @@ -74,11 +81,12 @@ void initializeMockBootstrapResponses(SiteToSiteResponder& collector) { collector.push_response(resp_code); } -void verifyBootstrapMessages(sitetosite::RawSiteToSiteClient& protocol, SiteToSiteResponder& collector) { +void verifyBootstrapMessages(sitetosite::RawSiteToSiteClient& protocol, SiteToSiteResponder& collector, bool use_compression) { protocol.setUseCompression(false); protocol.setBatchDuration(std::chrono::milliseconds(100)); protocol.setBatchCount(5); protocol.setTimeout(std::chrono::milliseconds(20000)); + protocol.setUseCompression(use_compression); minifi::utils::Identifier fake_uuid = minifi::utils::Identifier::parse("C56A4180-65AA-42EC-A945-5FD21DEC0538").value(); protocol.setPortId(fake_uuid); @@ -105,7 +113,7 @@ void verifyBootstrapMessages(sitetosite::RawSiteToSiteClient& protocol, SiteToSi collector.get_next_client_response(); REQUIRE(collector.get_next_client_response() == "GZIP"); collector.get_next_client_response(); - REQUIRE(collector.get_next_client_response() == "false"); + REQUIRE(collector.get_next_client_response() == (use_compression ? 
"true" : "false")); collector.get_next_client_response(); REQUIRE(collector.get_next_client_response() == "PORT_IDENTIFIER"); collector.get_next_client_response(); @@ -132,6 +140,7 @@ void verifySendResponses(SiteToSiteResponder& collector, const std::vector<std:: } TEST_CASE("TestSetPortId", "[S2S]") { + initializeLogging(); auto peer = gsl::make_not_null(std::make_unique<sitetosite::SiteToSitePeer>(std::make_unique<org::apache::nifi::minifi::io::BufferStream>(), "fake_host", 65433, "")); sitetosite::RawSiteToSiteClient protocol(std::move(peer)); auto fake_uuid = minifi::utils::Identifier::parse("c56a4180-65aa-42ec-a945-5fd21dec0538").value(); @@ -140,6 +149,7 @@ TEST_CASE("TestSetPortId", "[S2S]") { } TEST_CASE("TestSiteToSiteVerifySend using data packet", "[S2S]") { + initializeLogging(); auto collector = std::make_unique<SiteToSiteResponder>(); auto collector_ptr = collector.get(); @@ -148,14 +158,23 @@ TEST_CASE("TestSiteToSiteVerifySend using data packet", "[S2S]") { auto peer = gsl::make_not_null(std::make_unique<sitetosite::SiteToSitePeer>(std::move(collector), "fake_host", 65433, "")); sitetosite::RawSiteToSiteClient protocol(std::move(peer)); + bool use_compression = false; std::vector<std::string> expected_responses; std::string payload = "Test MiNiFi payload"; - expected_responses.push_back(""); // attribute count 0 - expected_responses.push_back(""); // payload length - expected_responses.push_back(payload); + SECTION("Do not use compression") { + use_compression = false; + expected_responses.push_back(""); // attribute count 0 + expected_responses.push_back(""); // payload length + expected_responses.push_back(payload); + } + + SECTION("Use compression") { + use_compression = true; + expected_responses.push_back("SYNC"); + } - verifyBootstrapMessages(protocol, *collector_ptr); + verifyBootstrapMessages(protocol, *collector_ptr, use_compression); // start to send the stuff auto transaction = SiteToSiteClientTestAccessor::createTransaction(protocol, 
sitetosite::TransferDirection::SEND); @@ -170,6 +189,7 @@ TEST_CASE("TestSiteToSiteVerifySend using data packet", "[S2S]") { } TEST_CASE("TestSiteToSiteVerifySend using flowfile data", "[S2S]") { + initializeLogging(); auto collector = std::make_unique<SiteToSiteResponder>(); auto collector_ptr = collector.get(); @@ -178,20 +198,29 @@ TEST_CASE("TestSiteToSiteVerifySend using flowfile data", "[S2S]") { auto peer = gsl::make_not_null(std::make_unique<sitetosite::SiteToSitePeer>(std::move(collector), "fake_host", 65433, "")); sitetosite::RawSiteToSiteClient protocol(std::move(peer)); + bool use_compression = false; std::vector<std::string> expected_responses; std::string payload = "Test MiNiFi payload"; - expected_responses.push_back(""); // attribute count - expected_responses.push_back(""); // attribute key length - expected_responses.push_back("filename"); - expected_responses.push_back(""); // attribute value length - expected_responses.push_back("myfile"); - expected_responses.push_back(""); // attribute key length - expected_responses.push_back("flow.id"); - expected_responses.push_back(""); // attribute value length - expected_responses.push_back("test"); - expected_responses.push_back(""); // payload length - expected_responses.push_back(payload); + SECTION("Do not use compression") { + use_compression = false; + expected_responses.push_back(""); // attribute count + expected_responses.push_back(""); // attribute key length + expected_responses.push_back("filename"); + expected_responses.push_back(""); // attribute value length + expected_responses.push_back("myfile"); + expected_responses.push_back(""); // attribute key length + expected_responses.push_back("flow.id"); + expected_responses.push_back(""); // attribute value length + expected_responses.push_back("test"); + expected_responses.push_back(""); // payload length + expected_responses.push_back(payload); + } + + SECTION("Use compression") { + use_compression = true; + 
expected_responses.push_back("SYNC"); + } protocol.setBatchDuration(std::chrono::milliseconds(100)); protocol.setBatchCount(5); @@ -200,7 +229,7 @@ TEST_CASE("TestSiteToSiteVerifySend using flowfile data", "[S2S]") { auto fake_uuid = minifi::utils::Identifier::parse("C56A4180-65AA-42EC-A945-5FD21DEC0538").value(); protocol.setPortId(fake_uuid); - verifyBootstrapMessages(protocol, *collector_ptr); + verifyBootstrapMessages(protocol, *collector_ptr, use_compression); // start to send the stuff auto transaction = SiteToSiteClientTestAccessor::createTransaction(protocol, sitetosite::TransferDirection::SEND); @@ -233,6 +262,7 @@ TEST_CASE("TestSiteToSiteVerifySend using flowfile data", "[S2S]") { } TEST_CASE("TestSiteToSiteVerifyNegotiationFail", "[S2S]") { + initializeLogging(); auto collector = std::make_unique<SiteToSiteResponder>(); const char negotiated_abort_code = magic_enum::enum_underlying(sitetosite::ResourceNegotiationStatusCode::NEGOTIATED_ABORT); @@ -252,10 +282,14 @@ TEST_CASE("TestSiteToSiteVerifyNegotiationFail", "[S2S]") { REQUIRE_FALSE(SiteToSiteClientTestAccessor::bootstrap(protocol)); } -void initializeMockRemoteClientReceiveDataResponses(SiteToSiteResponder& collector) { - collector.push_response("R"); - collector.push_response("C"); - collector.push_response(std::string{static_cast<char>(magic_enum::enum_underlying(sitetosite::ResponseCode::MORE_DATA))}); +void initializeMockRemoteClientReceiveDataResponses(SiteToSiteResponder& collector, bool use_compression) { + auto addResponseCode = [&collector](sitetosite::ResponseCode code) { + collector.push_response("R"); + collector.push_response("C"); + collector.push_response(std::string{static_cast<char>(magic_enum::enum_underlying(code))}); + }; + + addResponseCode(sitetosite::ResponseCode::MORE_DATA); auto addUInt32 = [&collector](uint32_t number) { std::string result(4, '\0'); @@ -264,14 +298,6 @@ void initializeMockRemoteClientReceiveDataResponses(SiteToSiteResponder& collect } 
collector.push_response(result); }; - const uint32_t number_of_attributes = 1; - addUInt32(number_of_attributes); - std::string attribute_key = "attribute_key"; - addUInt32(gsl::narrow<uint32_t>(attribute_key.size())); - collector.push_response(attribute_key); - std::string attribute_value = "attribute_value"; - addUInt32(gsl::narrow<uint32_t>(attribute_value.size())); - collector.push_response(attribute_value); auto addUInt64 = [&collector](uint64_t number) { std::string result(8, '\0'); for (std::size_t i = 0; i < 8; ++i) { @@ -279,29 +305,77 @@ void initializeMockRemoteClientReceiveDataResponses(SiteToSiteResponder& collect } collector.push_response(result); }; - std::string payload = "data"; - addUInt64(payload.size()); - collector.push_response("data"); - collector.push_response("R"); - collector.push_response("C"); - const char resource_code_finish_transaction = magic_enum::enum_underlying(sitetosite::ResponseCode::FINISH_TRANSACTION); - std::string resp_code; - resp_code.insert(resp_code.begin(), resource_code_finish_transaction); - collector.push_response(resp_code); + const uint32_t number_of_attributes = 1; + const std::string attribute_key = "attribute_key"; + const std::string attribute_value = "attribute_value"; + const std::string payload = "data"; + + for (size_t i = 0; i < 2; ++i) { + if (!use_compression) { + addUInt32(number_of_attributes); + + addUInt32(gsl::narrow<uint32_t>(attribute_key.size())); + collector.push_response(attribute_key); + + addUInt32(gsl::narrow<uint32_t>(attribute_value.size())); + collector.push_response(attribute_value); + + addUInt64(payload.size()); + collector.push_response("data"); + } else { + collector.push_response("SYNC"); + + io::BufferStream buffer_stream; + buffer_stream.write(number_of_attributes); + buffer_stream.write(gsl::narrow<uint32_t>(attribute_key.size())); + buffer_stream.write(reinterpret_cast<const uint8_t*>(attribute_key.data()), attribute_key.size()); + 
buffer_stream.write(gsl::narrow<uint32_t>(attribute_value.size())); + buffer_stream.write(reinterpret_cast<const uint8_t*>(attribute_value.data()), attribute_value.size()); + buffer_stream.write(gsl::narrow<uint64_t>(payload.size())); + buffer_stream.write(reinterpret_cast<const uint8_t*>(payload.data()), payload.size()); + auto original_size = buffer_stream.size(); + + io::BufferStream compressed_stream; + io::ZlibCompressStream compression_stream(gsl::make_not_null(&compressed_stream), io::ZlibCompressionFormat::ZLIB, Z_BEST_SPEED); + internal::pipe(buffer_stream, compression_stream); + compression_stream.close(); + std::vector<std::byte> compressed_data; + auto compressed_size = compressed_stream.size(); + compressed_data.resize(compressed_size); + compressed_stream.read(compressed_data); + std::string compressed_data_str(reinterpret_cast<const char*>(compressed_data.data()), compressed_data.size()); + + addUInt32(gsl::narrow<uint32_t>(original_size)); + addUInt32(gsl::narrow<uint32_t>(compressed_size)); + collector.push_response(compressed_data_str); + + std::string compression_ending_byte; + compression_ending_byte.insert(compression_ending_byte.begin(), 0); + collector.push_response(compression_ending_byte); + } + + if (i == 0) { + addResponseCode(sitetosite::ResponseCode::CONTINUE_TRANSACTION); + } else { + addResponseCode(sitetosite::ResponseCode::FINISH_TRANSACTION); + } + } } -TEST_CASE("Test receiving flow file through site to site", "[S2S]") { +TEST_CASE("Test receiving multiple flow files through site to site", "[S2S]") { + initializeLogging(); auto collector = std::make_unique<SiteToSiteResponder>(); auto collector_ptr = collector.get(); + const auto use_compression = GENERATE(false, true); initializeMockBootstrapResponses(*collector); - initializeMockRemoteClientReceiveDataResponses(*collector); + initializeMockRemoteClientReceiveDataResponses(*collector, use_compression); auto peer = 
gsl::make_not_null(std::make_unique<sitetosite::SiteToSitePeer>(std::move(collector), "fake_host", 65433, "")); sitetosite::RawSiteToSiteClient protocol(std::move(peer)); - verifyBootstrapMessages(protocol, *collector_ptr); + verifyBootstrapMessages(protocol, *collector_ptr, use_compression); std::shared_ptr<sitetosite::Transaction> transaction; transaction = SiteToSiteClientTestAccessor::createTransaction(protocol, sitetosite::TransferDirection::RECEIVE); @@ -326,13 +400,15 @@ TEST_CASE("Test receiving flow file through site to site", "[S2S]") { SiteToSiteClientTestAccessor::readFlowFiles(protocol, transaction, *session); session->commit(); std::set<std::shared_ptr<core::FlowFile>> expired; - auto flow_file = outgoing_connection->poll(expired); - auto attributes = flow_file->getAttributes(); - REQUIRE(attributes.size() == 3); - CHECK(attributes["attribute_key"] == "attribute_value"); - CHECK(attributes.contains("filename")); - CHECK(attributes["flow.id"] == "test"); - CHECK(test_plan->getContent(flow_file) == "data"); + for (size_t i = 0; i < 2; ++i) { + auto flow_file = outgoing_connection->poll(expired); + auto attributes = flow_file->getAttributes(); + REQUIRE(attributes.size() == 3); + CHECK(attributes["attribute_key"] == "attribute_value"); + CHECK(attributes.contains("filename")); + CHECK(attributes["flow.id"] == "test"); + CHECK(test_plan->getContent(flow_file) == "data"); + } } } // namespace org::apache::nifi::minifi::test
