This is an automated email from the ASF dual-hosted git repository. lordgamez pushed a commit to branch MINIFICPP-2152_BASE in repository https://gitbox.apache.org/repos/asf/nifi-minifi-cpp.git
commit 22b1ef4c42d2d3dbd3d98d390441a5c2d72bdc9e Author: Martin Zink <martinz...@apache.org> AuthorDate: Wed Jun 7 11:16:19 2023 +0200 MINIFICPP-2131 Refactored GetTCP --- PROCESSORS.md | 22 +- .../standard-processors/processors/GetTCP.cpp | 430 ++++++++-------- extensions/standard-processors/processors/GetTCP.h | 155 +++--- .../standard-processors/processors/PutTCP.cpp | 22 +- extensions/standard-processors/processors/PutTCP.h | 31 +- .../standard-processors/processors/TailFile.cpp | 30 +- .../standard-processors/processors/TailFile.h | 2 +- .../tests/integration/SecureSocketGetTCPTest.cpp | 2 +- .../standard-processors/tests/unit/GetTCPTests.cpp | 548 +++++++++------------ libminifi/include/utils/StringUtils.h | 5 + libminifi/include/utils/net/AsioCoro.h | 10 +- libminifi/include/utils/net/AsioSocketUtils.h | 50 +- .../utils/net/{AsioSocketUtils.h => Message.h} | 25 +- libminifi/include/utils/net/Server.h | 19 +- libminifi/src/utils/StringUtils.cpp | 20 + libminifi/src/utils/net/AsioSocketUtils.cpp | 12 +- libminifi/src/utils/net/TcpServer.cpp | 6 +- libminifi/test/resources/TestC2Metrics.yml | 9 +- libminifi/test/resources/TestGetTCPSecure.yml | 6 +- .../test/resources/TestGetTCPSecureEmptyPass.yml | 9 +- .../resources/TestGetTCPSecureWithFilePass.yml | 6 +- .../test/resources/TestGetTCPSecureWithPass.yml | 8 +- .../test/resources/TestSameProcessorMetrics.yml | 17 +- libminifi/test/resources/encrypted.cn.pass | 2 +- libminifi/test/unit/StringUtilsTests.cpp | 15 + 25 files changed, 694 insertions(+), 767 deletions(-) diff --git a/PROCESSORS.md b/PROCESSORS.md index c3ebccb10..506444db9 100644 --- a/PROCESSORS.md +++ b/PROCESSORS.md @@ -1136,16 +1136,16 @@ Establishes a TCP Server that defines and retrieves one or more byte messages fr In the list below, the names of required properties appear in bold. Any other properties (not in bold) are considered optional. The table also indicates any default values, and whether a property supports the NiFi Expression Language. -| Name | Default Value | Allowable Values | Description | -|----------------------------|---------------|------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| **endpoint-list** | | | A comma delimited list of the endpoints to connect to. The format should be <server_address>:<port>. | -| concurrent-handler-count | 1 | | Number of concurrent handlers for this session | -| reconnect-interval | 5 s | | The number of seconds to wait before attempting to reconnect to the endpoint. | -| Stay Connected | true | | Determines if we keep the same socket despite having no data | -| receive-buffer-size | 16 MB | | The size of the buffer to receive data in. Default 16384 (16MB). | -| SSL Context Service | | | SSL Context Service Name | -| connection-attempt-timeout | 3 | | Maximum number of connection attempts before attempting backup hosts, if configured | -| end-of-message-byte | 13 | | Byte value which denotes end of message. Must be specified as integer within the valid byte range (-128 thru 127). For example, '13' = Carriage return and '10' = New line. Default '13'. | +| Name | Default Value | Allowable Values | Description | +|-------------------------------|---------------|------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| **Endpoint List** | | | A comma delimited list of the endpoints to connect to. The format should be <server_address>:<port>. | +| SSL Context Service | | | SSL Context Service Name | +| Message Delimiter | \n | | Character that denotes the end of the message. | +| **Max Size of Message Queue** | 10000 | | Maximum number of messages allowed to be buffered before processing them when the processor is triggered. If the buffer is full, the message is ignored. If set to zero the buffer is unlimited. | +| Maximum Message Size | | | Optional size of the buffer to receive data in. | +| **Max Batch Size** | 500 | | The maximum number of messages to process at a time. | +| **Timeout** | 1s | | The timeout for connecting to and communicating with the destination.<br/>**Supports Expression Language: true** | +| **Reconnection Interval** | 1 min | | The duration to wait before attempting to reconnect to the endpoints.<br/>**Supports Expression Language: true** | ### Relationships @@ -2779,5 +2779,3 @@ In the list below, the names of required properties appear in bold. Any other pr |---------|-----------------------------------------| | success | All files are routed to success | | failure | Failed files are transferred to failure | - - diff --git a/extensions/standard-processors/processors/GetTCP.cpp b/extensions/standard-processors/processors/GetTCP.cpp index 4eed2d13d..fee12dc3b 100644 --- a/extensions/standard-processors/processors/GetTCP.cpp +++ b/extensions/standard-processors/processors/GetTCP.cpp @@ -17,275 +17,287 @@ */ #include "GetTCP.h" -#ifndef WIN32 -#include <dirent.h> -#endif #include <cinttypes> -#include <future> #include <memory> -#include <mutex> #include <thread> -#include <utility> -#include <vector> #include <string> -#include "io/ClientSocket.h" +#include <asio/read_until.hpp> +#include <asio/detached.hpp> +#include "utils/net/AsioCoro.h" #include "io/StreamFactory.h" #include "utils/gsl.h" #include "utils/StringUtils.h" -#include "utils/TimeUtil.h" #include "core/ProcessContext.h" #include "core/ProcessSession.h" #include "core/ProcessSessionFactory.h" #include "core/PropertyBuilder.h" #include "core/Resource.h" -namespace org::apache::nifi::minifi::processors { +using namespace std::literals::chrono_literals; -const char *DataHandler::SOURCE_ENDPOINT_ATTRIBUTE = "source.endpoint"; +namespace org::apache::nifi::minifi::processors { const core::Property GetTCP::EndpointList( - core::PropertyBuilder::createProperty("endpoint-list")->withDescription("A comma delimited list of the endpoints to connect to. The format should be <server_address>:<port>.")->isRequired(true) - ->build()); - -const core::Property GetTCP::ConcurrentHandlers( - core::PropertyBuilder::createProperty("concurrent-handler-count")->withDescription("Number of concurrent handlers for this session")->withDefaultValue<int>(1)->build()); + core::PropertyBuilder::createProperty("Endpoint List") + ->withDescription("A comma delimited list of the endpoints to connect to. The format should be <server_address>:<port>.") + ->isRequired(true)->build()); -const core::Property GetTCP::ReconnectInterval( - core::PropertyBuilder::createProperty("reconnect-interval")->withDescription("The number of seconds to wait before attempting to reconnect to the endpoint.") - ->withDefaultValue<core::TimePeriodValue>("5 s")->build()); - -const core::Property GetTCP::ReceiveBufferSize( - core::PropertyBuilder::createProperty("receive-buffer-size")->withDescription("The size of the buffer to receive data in. Default 16384 (16MB).")->withDefaultValue<core::DataSizeValue>("16 MB") +const core::Property GetTCP::SSLContextService( + core::PropertyBuilder::createProperty("SSL Context Service") + ->withDescription("SSL Context Service Name") + ->asType<minifi::controllers::SSLContextService>()->build()); + +const core::Property GetTCP::MessageDelimiter( + core::PropertyBuilder::createProperty("Message Delimiter")->withDescription( + "Character that denotes the end of the message.") + ->withDefaultValue("\\n")->build()); + +const core::Property GetTCP::MaxQueueSize( + core::PropertyBuilder::createProperty("Max Size of Message Queue") + ->withDescription("Maximum number of messages allowed to be buffered before processing them when the processor is triggered. " + "If the buffer is full, the message is ignored. If set to zero the buffer is unlimited.") + ->withDefaultValue<uint64_t>(10000) + ->isRequired(true) ->build()); -const core::Property GetTCP::SSLContextService( - core::PropertyBuilder::createProperty("SSL Context Service")->withDescription("SSL Context Service Name")->asType<minifi::controllers::SSLContextService>()->build()); +const core::Property GetTCP::MaxBatchSize( + core::PropertyBuilder::createProperty("Max Batch Size") + ->withDescription("The maximum number of messages to process at a time.") + ->withDefaultValue<uint64_t>(500) + ->isRequired(true) + ->build()); -const core::Property GetTCP::StayConnected( - core::PropertyBuilder::createProperty("Stay Connected")->withDescription("Determines if we keep the same socket despite having no data")->withDefaultValue<bool>(true)->build()); +const core::Property GetTCP::MaxMessageSize( + core::PropertyBuilder::createProperty("Maximum Message Size") + ->withDescription("Optional size of the buffer to receive data in.")->build()); -const core::Property GetTCP::ConnectionAttemptLimit( - core::PropertyBuilder::createProperty("connection-attempt-timeout")->withDescription("Maximum number of connection attempts before attempting backup hosts, if configured")->withDefaultValue<int>( - 3)->build()); +const core::Property GetTCP::Timeout = core::PropertyBuilder::createProperty("Timeout") + ->withDescription("The timeout for connecting to and communicating with the destination.") + ->withDefaultValue<core::TimePeriodValue>("1s") + ->isRequired(true) + ->supportsExpressionLanguage(true) + ->build(); -const core::Property GetTCP::EndOfMessageByte( - core::PropertyBuilder::createProperty("end-of-message-byte")->withDescription( - "Byte value which denotes end of message. Must be specified as integer within the valid byte range (-128 thru 127). For example, '13' = Carriage return and '10' = New line. Default '13'.") - ->withDefaultValue("13")->build()); +const core::Property GetTCP::ReconnectInterval = core::PropertyBuilder::createProperty("Reconnection Interval") + ->withDescription("The duration to wait before attempting to reconnect to the endpoints.") + ->withDefaultValue<core::TimePeriodValue>("1 min") + ->isRequired(true) + ->supportsExpressionLanguage(true) + ->build(); const core::Relationship GetTCP::Success("success", "All files are routed to success"); const core::Relationship GetTCP::Partial("partial", "Indicates an incomplete message as a result of encountering the end of message byte trigger"); -int16_t DataHandler::handle(const std::string& source, uint8_t *message, size_t size, bool partial) { - std::shared_ptr<core::ProcessSession> my_session = sessionFactory_->createSession(); - std::shared_ptr<core::FlowFile> flowFile = my_session->create(); - - my_session->writeBuffer(flowFile, gsl::make_span(reinterpret_cast<const std::byte*>(message), size)); - - my_session->putAttribute(flowFile, SOURCE_ENDPOINT_ATTRIBUTE, source); - - if (partial) { - my_session->transfer(flowFile, GetTCP::Partial); - } else { - my_session->transfer(flowFile, GetTCP::Success); - } - - my_session->commit(); - - return 0; -} void GetTCP::initialize() { setSupportedProperties(properties()); setSupportedRelationships(relationships()); } -void GetTCP::onSchedule(const std::shared_ptr<core::ProcessContext> &context, const std::shared_ptr<core::ProcessSessionFactory> &sessionFactory) { - std::string value; - if (context->getProperty(EndpointList.getName(), value)) { - endpoints = utils::StringUtils::split(value, ","); +void GetTCP::onSchedule(const std::shared_ptr<core::ProcessContext>& context, const std::shared_ptr<core::ProcessSessionFactory>&) { + std::vector<utils::net::ConnectionId> connections_to_make; + if (auto endpoint_list_str = context->getProperty(EndpointList)) { + for (const auto& endpoint_str : utils::StringUtils::splitAndTrim(*endpoint_list_str, ",")) { + auto hostname_service_pair = utils::StringUtils::splitAndTrim(endpoint_str, ":"); + if (hostname_service_pair.size() != 2) { + logger_->log_error("%s endpoint is invalid, expected {hostname}:{service} format", endpoint_str); + continue; + } + connections_to_make.emplace_back(hostname_service_pair[0], hostname_service_pair[1]); + } } - int handlers = 0; - if (context->getProperty(ConcurrentHandlers.getName(), handlers)) { - concurrent_handlers_ = handlers; - } + if (connections_to_make.empty()) + throw Exception(PROCESS_SCHEDULE_EXCEPTION, "No valid endpoint in Endpoint List property"); - stay_connected_ = true; - if (context->getProperty(StayConnected.getName(), value)) { - stay_connected_ = utils::StringUtils::toBool(value).value_or(true); + char delimiter = '\n'; + if (auto delimiter_str = context->getProperty(MessageDelimiter)) { + auto parsed_delimiter = utils::StringUtils::parseCharacter(*delimiter_str); + if (!parsed_delimiter || !parsed_delimiter->has_value()) + throw Exception(PROCESS_SCHEDULE_EXCEPTION, fmt::format("Invalid delimiter: {} (it must be a single (escaped or not) character", *delimiter_str)); + delimiter = **parsed_delimiter; } - int connects = 0; - if (context->getProperty(ConnectionAttemptLimit.getName(), connects)) { - connection_attempt_limit_ = connects; + std::optional<asio::ssl::context> ssl_context_; + if (auto context_name = context->getProperty(SSLContextService)) { + if (auto controller_service = context->getControllerService(*context_name)) { + if (auto ssl_context_service = std::dynamic_pointer_cast<minifi::controllers::SSLContextService>(context->getControllerService(*context_name))) { + ssl_context_ = utils::net::getSslContext(*ssl_context_service); + } else { + throw Exception(PROCESS_SCHEDULE_EXCEPTION, *context_name + " is not an SSL Context Service"); + } + } else { + throw Exception(PROCESS_SCHEDULE_EXCEPTION, "Invalid controller service: " + *context_name); + } } - context->getProperty(ReceiveBufferSize.getName(), receive_buffer_size_); - if (context->getProperty(EndOfMessageByte.getName(), value)) { - logger_->log_trace("EOM is passed in as %s", value); - int64_t byteValue = 0; - core::Property::StringToInt(value, byteValue); - endOfMessageByte = static_cast<std::byte>(byteValue & 0xFF); - } + std::optional<size_t> max_queue_size = context->getProperty<uint64_t>(MaxQueueSize); + std::optional<size_t> max_message_size = context->getProperty<uint64_t>(MaxMessageSize); - logger_->log_trace("EOM is defined as %i", static_cast<int>(endOfMessageByte)); + if (auto max_batch_size = context->getProperty<uint64_t>(MaxBatchSize)) { + max_batch_size_ = *max_batch_size; + } - if (auto reconnect_interval = context->getProperty<core::TimePeriodValue>(ReconnectInterval)) { - reconnect_interval_ = reconnect_interval->getMilliseconds(); - logger_->log_debug("Reconnect interval is %" PRId64 " ms", reconnect_interval_.count()); - } else { - logger_->log_debug("Reconnect interval using default value of %" PRId64 " ms", reconnect_interval_.count()); + asio::steady_timer::duration timeout_duration = 1s; + if (auto timeout_value = context->getProperty<core::TimePeriodValue>(Timeout)) { + timeout_duration = timeout_value->getMilliseconds(); } - handler_ = std::make_unique<DataHandler>(sessionFactory); - - f_ex = [&] { - std::unique_ptr<io::Socket> socket_ptr; - // reuse the byte buffer. - std::vector<std::byte> buffer; - int reconnects = 0; - do { - if ( socket_ring_buffer_.try_dequeue(socket_ptr) ) { - buffer.resize(receive_buffer_size_); - const auto size_read = socket_ptr->read(buffer, false); - if (!io::isError(size_read)) { - if (size_read != 0) { - // determine cut location - size_t startLoc = 0; - for (size_t i = 0; i < size_read; i++) { - if (buffer.at(i) == endOfMessageByte && i > 0) { - if (i-startLoc > 0) { - handler_->handle(socket_ptr->getHostname(), reinterpret_cast<uint8_t*>(buffer.data())+startLoc, (i-startLoc), true); - } - startLoc = i; - } - } - if (startLoc > 0) { - logger_->log_trace("Starting at %i, ending at %i", startLoc, size_read); - if (size_read-startLoc > 0) { - handler_->handle(socket_ptr->getHostname(), reinterpret_cast<uint8_t*>(buffer.data())+startLoc, (size_read-startLoc), true); - } - } else { - logger_->log_trace("Handling at %i, ending at %i", startLoc, size_read); - if (size_read > 0) { - handler_->handle(socket_ptr->getHostname(), reinterpret_cast<uint8_t*>(buffer.data()), size_read, false); - } - } - reconnects = 0; - } - socket_ring_buffer_.enqueue(std::move(socket_ptr)); - } else if (size_read == static_cast<size_t>(-2) && stay_connected_) { - if (++reconnects > connection_attempt_limit_) { - logger_->log_info("Too many reconnects, exiting thread"); - socket_ptr->close(); - return -1; - } - logger_->log_info("Sleeping for %" PRId64 " msec before attempting to reconnect", int64_t{reconnect_interval_.count()}); - std::this_thread::sleep_for(reconnect_interval_); - socket_ring_buffer_.enqueue(std::move(socket_ptr)); - } else { - socket_ptr->close(); - std::this_thread::sleep_for(reconnect_interval_); - logger_->log_info("Read response returned a -1 from socket, exiting thread"); - return -1; - } - } else { - std::this_thread::sleep_for(reconnect_interval_); - logger_->log_info("Could not use socket, exiting thread"); - return -1; - } - }while (running_); - logger_->log_debug("Ending private thread"); - return 0; - }; - - if (context->getProperty(SSLContextService.getName(), value)) { - std::shared_ptr<core::controller::ControllerService> service = context->getControllerService(value); - if (nullptr != service) { - ssl_service_ = std::static_pointer_cast<minifi::controllers::SSLContextService>(service); - } + asio::steady_timer::duration reconnection_interval = 1min; + if (auto reconnect_interval_value = context->getProperty<core::TimePeriodValue>(ReconnectInterval)) { + reconnection_interval = reconnect_interval_value->getMilliseconds(); } - client_thread_pool_.setMaxConcurrentTasks(concurrent_handlers_); - client_thread_pool_.start(); - running_ = true; + client_.emplace(delimiter, timeout_duration, reconnection_interval, std::move(ssl_context_), max_queue_size, max_message_size, std::move(connections_to_make), logger_); + client_thread_ = std::thread([this]() { client_->run(); }); // NOLINT } void GetTCP::notifyStop() { - running_ = false; - // await threads to shutdown. - client_thread_pool_.shutdown(); - std::unique_ptr<io::Socket> socket_ptr; - while (socket_ring_buffer_.size_approx() > 0) { - socket_ring_buffer_.try_dequeue(socket_ptr); + if (client_) + client_->stop(); +} + +void GetTCP::transferAsFlowFile(const utils::net::Message& message, core::ProcessSession& session) { + auto flow_file = session.create(); + session.writeBuffer(flow_file, message.message_data); + flow_file->setAttribute("tcp.port", std::to_string(message.server_port)); + flow_file->setAttribute("tcp.sender", message.sender_address.to_string()); + if (message.is_partial) + session.transfer(flow_file, Partial); + else + session.transfer(flow_file, Success); +} + +void GetTCP::onTrigger(const std::shared_ptr<core::ProcessContext>&, const std::shared_ptr<core::ProcessSession>& session) { + gsl_Expects(session && max_batch_size_ > 0); + size_t logs_processed = 0; + while (!client_->queueEmpty() && logs_processed < max_batch_size_) { + utils::net::Message received_message; + if (!client_->tryDequeue(received_message)) + break; + transferAsFlowFile(received_message, *session); + ++logs_processed; } } -void GetTCP::onTrigger(const std::shared_ptr<core::ProcessContext> &context, const std::shared_ptr<core::ProcessSession>& /*session*/) { - // Perform directory list - std::lock_guard<std::mutex> lock(mutex_); - // check if the futures are valid. If they've terminated remove it from the map. - - for (auto &initEndpoint : endpoints) { - std::vector<std::string> hostAndPort = utils::StringUtils::split(initEndpoint, ":"); - auto realizedHost = hostAndPort.at(0); -#ifdef WIN32 - if ("localhost" == realizedHost) { - realizedHost = org::apache::nifi::minifi::io::Socket::getMyHostName(); + +GetTCP::TcpClient::TcpClient(char delimiter, + asio::steady_timer::duration timeout_duration, + asio::steady_timer::duration reconnection_interval, + std::optional<asio::ssl::context> ssl_context, + std::optional<size_t> max_queue_size, + std::optional<size_t> max_message_size, + std::vector<utils::net::ConnectionId> connections, + std::shared_ptr<core::logging::Logger> logger) + : delimiter_(delimiter), + timeout_duration_(timeout_duration), + reconnection_interval_(reconnection_interval), + ssl_context_(std::move(ssl_context)), + max_queue_size_(max_queue_size), + max_message_size_(max_message_size), + connections_(std::move(connections)), + logger_(std::move(logger)) { +} + +GetTCP::TcpClient::~TcpClient() { + stop(); +} + + +void GetTCP::TcpClient::run() { + gsl_Expects(!connections_.empty()); + for (const auto& connection_id : connections_) { + asio::co_spawn(io_context_, doReceiveFrom(connection_id), asio::detached); // NOLINT + } + io_context_.run(); +} + +void GetTCP::TcpClient::stop() { + io_context_.stop(); +} + +bool GetTCP::TcpClient::queueEmpty() const { + return concurrent_queue_.empty(); +} + +bool GetTCP::TcpClient::tryDequeue(utils::net::Message& received_message) { + return concurrent_queue_.tryDequeue(received_message); +} + +asio::awaitable<std::error_code> GetTCP::TcpClient::readLoop(auto& socket) { + std::string read_message; + bool last_was_partial = false; + bool current_is_partial = false; + while (true) { + { + last_was_partial = current_is_partial; + current_is_partial = false; + } + auto dynamic_buffer = max_message_size_ ? asio::dynamic_buffer(read_message, *max_message_size_) : asio::dynamic_buffer(read_message); + auto [read_error, bytes_read] = co_await asio::async_read_until(socket, dynamic_buffer, delimiter_, utils::net::use_nothrow_awaitable); // NOLINT + + if (*max_message_size_ && read_error == asio::error::not_found) { + current_is_partial = true; + bytes_read = *max_message_size_; + } else if (read_error) { + logger_->log_error("Error during read %s", read_error.message()); + co_return read_error; } -#endif - if (hostAndPort.size() != 2) { + + if (bytes_read == 0) continue; + + if (!max_queue_size_ || max_queue_size_ > concurrent_queue_.size()) { + utils::net::Message message{read_message.substr(0, bytes_read), utils::net::IpProtocol::TCP, socket.lowest_layer().remote_endpoint().address(), socket.lowest_layer().remote_endpoint().port()}; + if (last_was_partial || current_is_partial) + message.is_partial = true; + concurrent_queue_.enqueue(std::move(message)); + } else { + logger_->log_warn("Queue is full. TCP message ignored."); } + read_message.erase(0, bytes_read); + } +} - auto portStr = hostAndPort.at(1); - auto endpoint = utils::StringUtils::join_pack(realizedHost, ":", portStr); - - auto endPointFuture = live_clients_.find(endpoint); - // does not exist - if (endPointFuture == live_clients_.end()) { - logger_->log_info("creating endpoint for %s", endpoint); - if (hostAndPort.size() == 2) { - logger_->log_debug("Opening another socket to %s:%s is secure %d", realizedHost, portStr, (ssl_service_ != nullptr)); - std::unique_ptr<io::Socket> socket = - ssl_service_ != nullptr ? stream_factory_->createSecureSocket(realizedHost, std::stoi(portStr), ssl_service_) : stream_factory_->createSocket(realizedHost, std::stoi(portStr)); - if (!socket) { - logger_->log_error("Could not create socket during initialization for %s", endpoint); - continue; - } - socket->setNonBlocking(); - if (socket->initialize() != -1) { - logger_->log_debug("Enqueueing socket into ring buffer %s:%s", realizedHost, portStr); - socket_ring_buffer_.enqueue(std::move(socket)); - } else { - logger_->log_error("Could not create socket during initialization for %s", endpoint); +template<class SocketType> +asio::awaitable<std::error_code> GetTCP::TcpClient::doReceiveFromEndpoint(const asio::ip::tcp::endpoint& endpoint, SocketType& socket) { + auto [connection_error] = co_await utils::net::asyncOperationWithTimeout(socket.lowest_layer().async_connect(endpoint, utils::net::use_nothrow_awaitable), timeout_duration_); // NOLINT + if (connection_error) + co_return connection_error; + auto [handshake_error] = co_await utils::net::handshake<SocketType>(socket, timeout_duration_); + if (handshake_error) + co_return handshake_error; + co_return co_await readLoop(socket); +} + +asio::awaitable<void> GetTCP::TcpClient::doReceiveFrom(const utils::net::ConnectionId& connection_id) { + while (true) { + asio::ip::tcp::resolver resolver(io_context_); + auto [resolve_error, resolve_result] = co_await utils::net::asyncOperationWithTimeout( // NOLINT + resolver.async_resolve(connection_id.getHostname(), connection_id.getService(), utils::net::use_nothrow_awaitable), timeout_duration_); + if (resolve_error) { + logger_->log_error("Error during resolution: %s", resolve_error.message()); + co_await utils::net::async_wait(reconnection_interval_); + continue; + } + + std::error_code last_error; + for (const auto& endpoint : resolve_result) { + if (ssl_context_) { + utils::net::SslSocket ssl_socket{io_context_, *ssl_context_}; + last_error = co_await doReceiveFromEndpoint<utils::net::SslSocket>(endpoint, ssl_socket); + if (last_error) continue; - } } else { - logger_->log_error("Could not create socket for %s", endpoint); - } - auto* future = new std::future<int>(); - std::unique_ptr<utils::AfterExecute<int>> after_execute = std::unique_ptr<utils::AfterExecute<int>>(new SocketAfterExecute(running_, endpoint, &live_clients_, &mutex_)); - utils::Worker<int> functor(f_ex, "workers", std::move(after_execute)); - client_thread_pool_.execute(std::move(functor), *future); - live_clients_[endpoint] = future; - } else { - if (!endPointFuture->second->valid()) { - delete endPointFuture->second; - auto* future = new std::future<int>(); - std::unique_ptr<utils::AfterExecute<int>> after_execute = std::unique_ptr<utils::AfterExecute<int>>(new SocketAfterExecute(running_, endpoint, &live_clients_, &mutex_)); - utils::Worker<int> functor(f_ex, "workers", std::move(after_execute)); - client_thread_pool_.execute(std::move(functor), *future); - live_clients_[endpoint] = future; - } else { - logger_->log_debug("Thread still running for %s", endPointFuture->first); - // we have a thread corresponding to this. + utils::net::TcpSocket tcp_socket(io_context_); + last_error = co_await doReceiveFromEndpoint<utils::net::TcpSocket>(endpoint, tcp_socket); + if (last_error) + continue; } } + logger_->log_error("Error connecting to %s:%s due to %s", connection_id.getHostname().data(), connection_id.getService().data(), last_error.message()); + co_await utils::net::async_wait(reconnection_interval_); } - logger_->log_debug("Updating endpoint"); - context->yield(); } REGISTER_RESOURCE(GetTCP, Processor); diff --git a/extensions/standard-processors/processors/GetTCP.h b/extensions/standard-processors/processors/GetTCP.h index 7e0dd03fd..771e8af91 100644 --- a/extensions/standard-processors/processors/GetTCP.h +++ b/extensions/standard-processors/processors/GetTCP.h @@ -23,6 +23,8 @@ #include <utility> #include <vector> #include <atomic> +#include <asio/io_context.hpp> +#include "utils/Literals.h" #include "../core/state/nodes/MetricsBase.h" #include "FlowFileRecord.h" @@ -36,64 +38,11 @@ #include "controllers/SSLContextService.h" #include "utils/gsl.h" #include "utils/Export.h" +#include "utils/net/AsioSocketUtils.h" +#include "utils/net/Message.h" namespace org::apache::nifi::minifi::processors { -class SocketAfterExecute : public utils::AfterExecute<int> { - public: - explicit SocketAfterExecute(std::atomic<bool> &running, std::string endpoint, std::map<std::string, std::future<int>*> *list, std::mutex *mutex) - : running_(running.load()), - endpoint_(std::move(endpoint)), - mutex_(mutex), - list_(list) { - } - - SocketAfterExecute(const SocketAfterExecute&) = delete; - SocketAfterExecute(SocketAfterExecute&&) = delete; - - SocketAfterExecute& operator=(const SocketAfterExecute&) = delete; - SocketAfterExecute& operator=(SocketAfterExecute&&) = delete; - - ~SocketAfterExecute() override = default; - - bool isFinished(const int &result) override { - if (result == -1 || result == 0 || !running_) { - std::lock_guard<std::mutex> lock(*mutex_); - list_->erase(endpoint_); - return true; - } else { - return false; - } - } - bool isCancelled(const int& /*result*/) override { - return !running_; - } - - std::chrono::steady_clock::duration wait_time() override { - // wait 500ms - return std::chrono::milliseconds(500); - } - - protected: - std::atomic<bool> running_; - std::string endpoint_; - std::mutex *mutex_; - std::map<std::string, std::future<int>*> *list_; -}; - -class DataHandler { - public: - DataHandler(std::shared_ptr<core::ProcessSessionFactory> sessionFactory) // NOLINT - : sessionFactory_(std::move(sessionFactory)) { - } - static const char *SOURCE_ENDPOINT_ATTRIBUTE; - - int16_t handle(const std::string& source, uint8_t *message, size_t size, bool partial); - - private: - std::shared_ptr<core::ProcessSessionFactory> sessionFactory_; -}; - class GetTCP : public core::Processor { public: explicit GetTCP(std::string name, const utils::Identifier& uuid = {}) @@ -101,30 +50,35 @@ class GetTCP : public core::Processor { } ~GetTCP() override { - // thread pool must be shut down first before members it is using are destructed, otherwise segfault is possible - client_thread_pool_.shutdown(); + if (client_) { + client_->stop(); + } + if (client_thread_.joinable()) { + client_thread_.join(); + } + client_.reset(); } EXTENSIONAPI static constexpr const char* Description = "Establishes a TCP Server that defines and retrieves one or more byte messages from clients"; EXTENSIONAPI static const core::Property EndpointList; - EXTENSIONAPI static const core::Property ConcurrentHandlers; - EXTENSIONAPI static const core::Property ReconnectInterval; - EXTENSIONAPI static const core::Property StayConnected; - EXTENSIONAPI static const core::Property ReceiveBufferSize; EXTENSIONAPI static const core::Property SSLContextService; - EXTENSIONAPI static const core::Property ConnectionAttemptLimit; - EXTENSIONAPI static const core::Property EndOfMessageByte; + EXTENSIONAPI static const core::Property MessageDelimiter; + EXTENSIONAPI static const core::Property MaxQueueSize; + EXTENSIONAPI static const core::Property MaxMessageSize; + EXTENSIONAPI static const core::Property MaxBatchSize; + EXTENSIONAPI static const core::Property Timeout; + EXTENSIONAPI static const core::Property ReconnectInterval; static auto properties() { return std::array{ EndpointList, - ConcurrentHandlers, - ReconnectInterval, - StayConnected, - ReceiveBufferSize, SSLContextService, - ConnectionAttemptLimit, - EndOfMessageByte + MessageDelimiter, + MaxQueueSize, + MaxMessageSize, + MaxBatchSize, + Timeout, + ReconnectInterval }; } @@ -148,28 +102,55 @@ class GetTCP : public core::Processor { throw std::logic_error{"GetTCP::onTrigger(ProcessContext*, ProcessSession*) is unimplemented"}; } void initialize() override; - - protected: void notifyStop() override; private: - std::function<int()> f_ex; - std::atomic<bool> running_{false}; - std::unique_ptr<DataHandler> handler_; - std::vector<std::string> endpoints; - std::map<std::string, std::future<int>*> live_clients_; - moodycamel::ConcurrentQueue<std::unique_ptr<io::Socket>> socket_ring_buffer_; - bool stay_connected_{true}; - uint16_t concurrent_handlers_{2}; - std::byte endOfMessageByte{13}; - std::chrono::milliseconds reconnect_interval_{5000}; - uint64_t receive_buffer_size_{16 * 1024 * 1024}; - uint16_t connection_attempt_limit_{3}; - // Mutex for ensuring clients are running - std::mutex mutex_; - std::shared_ptr<minifi::controllers::SSLContextService> ssl_service_; + static void transferAsFlowFile(const utils::net::Message& message, core::ProcessSession& session); + + class TcpClient { + public: + TcpClient(char delimiter, + asio::steady_timer::duration timeout_duration, + asio::steady_timer::duration reconnection_interval, + std::optional<asio::ssl::context> ssl_context, + std::optional<size_t> max_queue_size, + std::optional<size_t> max_message_size, + std::vector<utils::net::ConnectionId> connections, + std::shared_ptr<core::logging::Logger> logger); + + ~TcpClient(); + + void run(); + void stop(); + + bool queueEmpty() const; + bool tryDequeue(utils::net::Message& received_message); + + private: + asio::awaitable<void> doReceiveFrom(const utils::net::ConnectionId& connection_id); + + template<class SocketType> + asio::awaitable<std::error_code> doReceiveFromEndpoint(const asio::ip::tcp::endpoint& endpoint, SocketType& socket); + + asio::awaitable<std::error_code> readLoop(auto& socket); + + utils::ConcurrentQueue<utils::net::Message> concurrent_queue_; + asio::io_context io_context_; + + char delimiter_; + asio::steady_timer::duration timeout_duration_; + asio::steady_timer::duration reconnection_interval_; + std::optional<asio::ssl::context> ssl_context_; + std::optional<size_t> max_queue_size_; + std::optional<size_t> max_message_size_; + std::vector<utils::net::ConnectionId> connections_; + std::shared_ptr<core::logging::Logger> logger_; + }; + + std::optional<TcpClient> client_; + size_t max_batch_size_{500}; + std::thread client_thread_; std::shared_ptr<core::logging::Logger> logger_ = core::logging::LoggerFactory<GetTCP>::getLogger(uuid_); - utils::ThreadPool<int> client_thread_pool_; }; } // namespace org::apache::nifi::minifi::processors diff --git a/extensions/standard-processors/processors/PutTCP.cpp b/extensions/standard-processors/processors/PutTCP.cpp index 31cb4eeec..495ec1618 100644 --- a/extensions/standard-processors/processors/PutTCP.cpp +++ b/extensions/standard-processors/processors/PutTCP.cpp @@ -162,20 +162,11 @@ void PutTCP::onSchedule(core::ProcessContext* const context, core::ProcessSessio } namespace { -template<class SocketType> -asio::awaitable<std::tuple<std::error_code>> handshake(SocketType&, asio::steady_timer::duration) { - co_return std::error_code(); -} - -template<> -asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket, asio::steady_timer::duration timeout_duration) { - co_return co_await asyncOperationWithTimeout(socket.async_handshake(HandshakeType::client, use_nothrow_awaitable), timeout_duration); // NOLINT -} template<class SocketType> class ConnectionHandler : public ConnectionHandlerBase { public: - ConnectionHandler(detail::ConnectionId connection_id, + ConnectionHandler(utils::net::ConnectionId connection_id, std::chrono::milliseconds timeout, std::shared_ptr<core::logging::Logger> logger, std::optional<size_t> max_size_of_socket_send_buffer, @@ -212,11 +203,11 @@ class ConnectionHandler : public ConnectionHandlerBase { SocketType createNewSocket(asio::io_context& io_context_); - detail::ConnectionId connection_id_; + utils::net::ConnectionId connection_id_; std::optional<SocketType> socket_; std::optional<steady_clock::time_point> last_used_; - std::chrono::milliseconds timeout_duration_; + asio::steady_timer::duration timeout_duration_; std::shared_ptr<core::logging::Logger> logger_; std::optional<size_t> max_size_of_socket_send_buffer_; @@ -247,7 +238,7 @@ asio::awaitable<std::error_code> ConnectionHandler<SocketType>::establishNewConn last_error = connection_error; continue; } - auto [handshake_error] = co_await handshake(socket, timeout_duration_); + auto [handshake_error] = co_await utils::net::handshake(socket, timeout_duration_); if (handshake_error) { core::logging::LOG_DEBUG(logger_) << "Handshake with " << endpoint.endpoint() << " failed due to " << handshake_error.message(); last_error = handshake_error; @@ -266,7 +257,8 @@ template<class SocketType> if (hasUsableSocket()) co_return std::error_code(); tcp::resolver resolver(io_context); - auto [resolve_error, resolve_result] = co_await asyncOperationWithTimeout(resolver.async_resolve(connection_id_.getHostname(), connection_id_.getPort(), use_nothrow_awaitable), timeout_duration_); + auto [resolve_error, resolve_result] = co_await asyncOperationWithTimeout( + resolver.async_resolve(connection_id_.getHostname(), connection_id_.getService(), use_nothrow_awaitable), timeout_duration_); if (resolve_error) co_return resolve_error; co_return co_await establishNewConnection(resolve_result, io_context); @@ -328,7 +320,7 @@ void PutTCP::onTrigger(core::ProcessContext* context, core::ProcessSession* cons return; } - auto connection_id = detail::ConnectionId(std::move(hostname), std::move(port)); + auto connection_id = utils::net::ConnectionId(std::move(hostname), std::move(port)); std::shared_ptr<ConnectionHandlerBase> handler; if (!connections_ || !connections_->contains(connection_id)) { if (ssl_context_) diff --git a/extensions/standard-processors/processors/PutTCP.h b/extensions/standard-processors/processors/PutTCP.h index 58ae28f94..113a38d2f 100644 --- a/extensions/standard-processors/processors/PutTCP.h +++ b/extensions/standard-processors/processors/PutTCP.h @@ -31,39 +31,12 @@ #include "utils/expected.h" #include "utils/StringUtils.h" // for string <=> on libc++ +#include "utils/net/AsioSocketUtils.h" #include <asio/io_context.hpp> #include <asio/awaitable.hpp> #include <asio/ssl/context.hpp> -namespace org::apache::nifi::minifi::processors::detail { - -class ConnectionId { - public: - ConnectionId(std::string hostname, std::string port) : hostname_(std::move(hostname)), port_(std::move(port)) {} - - auto operator<=>(const ConnectionId&) const = default; - - [[nodiscard]] std::string_view getHostname() const { return hostname_; } - [[nodiscard]] std::string_view getPort() const { return port_; } - - private: - std::string hostname_; - std::string port_; -}; -} // namespace org::apache::nifi::minifi::processors::detail - -namespace std { -template<> -struct hash<org::apache::nifi::minifi::processors::detail::ConnectionId> { - size_t operator()(const org::apache::nifi::minifi::processors::detail::ConnectionId& connection_id) const { - return org::apache::nifi::minifi::utils::hash_combine( - std::hash<std::string_view>{}(connection_id.getHostname()), - std::hash<std::string_view>{}(connection_id.getPort())); - } -}; -} // namespace std - namespace org::apache::nifi::minifi::processors { class ConnectionHandlerBase { public: @@ -128,7 +101,7 @@ class PutTCP final : public core::Processor { std::vector<std::byte> delimiter_; asio::io_context io_context_; - std::optional<std::unordered_map<detail::ConnectionId, std::shared_ptr<ConnectionHandlerBase>>> connections_; + std::optional<std::unordered_map<utils::net::ConnectionId, std::shared_ptr<ConnectionHandlerBase>>> connections_; std::optional<std::chrono::milliseconds> idle_connection_expiration_; std::optional<size_t> max_size_of_socket_send_buffer_; std::chrono::milliseconds timeout_duration_ = std::chrono::seconds(15); diff --git a/extensions/standard-processors/processors/TailFile.cpp b/extensions/standard-processors/processors/TailFile.cpp index 97bcfa04a..147766592 100644 --- a/extensions/standard-processors/processors/TailFile.cpp +++ b/extensions/standard-processors/processors/TailFile.cpp @@ -49,6 +49,7 @@ #include "core/PropertyBuilder.h" #include "core/Resource.h" #include "utils/RegexUtils.h" +#include "utils/expected.h" namespace org::apache::nifi::minifi::processors { @@ -173,19 +174,6 @@ uint64_t readOptionalUint64(const Container &container, const Key &key) { } } -// the delimiter is the first character of the input, allowing some escape sequences -std::string parseDelimiter(const std::string &input) { - if (input.empty()) return ""; - if (input[0] != '\\') return std::string{ input[0] }; - if (input.size() == std::size_t{1}) return "\\"; - switch (input[1]) { - case 'r': return "\r"; - case 't': return "\t"; - case 'n': return "\n"; - default: return std::string{ input[1] }; - } -} - std::map<std::filesystem::path, TailState> update_keys_in_legacy_states(const std::map<std::filesystem::path, TailState> &legacy_tail_states) { std::map<std::filesystem::path, TailState> new_tail_states; for (const auto &key_value_pair : legacy_tail_states) { @@ -343,10 +331,11 @@ void TailFile::onSchedule(const std::shared_ptr<core::ProcessContext> &context, throw Exception(PROCESSOR_EXCEPTION, "Failed to get StateManager"); } - std::string value; - - if (context->getProperty(Delimiter.getName(), value)) { - delimiter_ = parseDelimiter(value); + if (auto delimiter_str = context->getProperty(Delimiter)) { + auto parsed_delimiter = utils::StringUtils::parseCharacter(*delimiter_str); + if (!parsed_delimiter) + throw Exception(PROCESS_SCHEDULE_EXCEPTION, fmt::format("Invalid delimiter: {} (it must be a single character, whether escaped or not)", *delimiter_str)); + delimiter_ = *parsed_delimiter; } std::string file_name_str; @@ -788,12 +777,11 @@ void TailFile::processSingleFile(const std::shared_ptr<core::ProcessSession> &se if (extension.starts_with('.')) extension.erase(extension.begin()); - if (!delimiter_.empty()) { - char delim = delimiter_[0]; - logger_->log_trace("Looking for delimiter 0x%X", delim); + if (delimiter_) { + logger_->log_trace("Looking for delimiter 0x%X", *delimiter_); std::size_t num_flow_files = 0; - FileReaderCallback file_reader{full_file_name, state.position_, delim, state.checksum_}; + FileReaderCallback file_reader{full_file_name, state.position_, *delimiter_, state.checksum_}; TailState state_copy{state}; while (file_reader.hasMoreToRead() && (!batch_size_ || *batch_size_ > num_flow_files)) { diff --git a/extensions/standard-processors/processors/TailFile.h b/extensions/standard-processors/processors/TailFile.h index 286ec6b8b..705717dd2 100644 --- a/extensions/standard-processors/processors/TailFile.h +++ b/extensions/standard-processors/processors/TailFile.h @@ -200,7 +200,7 @@ class TailFile : public core::Processor { static const char *POSITION_STR; static const int BUFFER_SIZE = 512; - std::string delimiter_; // Delimiter for the data incoming from the tailed file. + std::optional<char> delimiter_; // Delimiter for the data incoming from the tailed file. core::StateManager* state_manager_ = nullptr; std::map<std::filesystem::path, TailState> tail_states_; Mode tail_mode_ = Mode::UNDEFINED; diff --git a/extensions/standard-processors/tests/integration/SecureSocketGetTCPTest.cpp b/extensions/standard-processors/tests/integration/SecureSocketGetTCPTest.cpp index 4dc638c5b..68492d5bb 100644 --- a/extensions/standard-processors/tests/integration/SecureSocketGetTCPTest.cpp +++ b/extensions/standard-processors/tests/integration/SecureSocketGetTCPTest.cpp @@ -77,7 +77,7 @@ class SecureSocketTest : public IntegrationBase { void runAssertions() override { using org::apache::nifi::minifi::utils::verifyLogLinePresenceInPollTime; - assert(verifyLogLinePresenceInPollTime(std::chrono::milliseconds(wait_time_), "SSL socket connect success")); + assert(verifyLogLinePresenceInPollTime(std::chrono::milliseconds(wait_time_), "Accepted on")); isRunning_ = false; server_socket_.reset(); assert(verifyLogLinePresenceInPollTime(std::chrono::milliseconds(wait_time_), "send succeed 20")); diff --git a/extensions/standard-processors/tests/unit/GetTCPTests.cpp b/extensions/standard-processors/tests/unit/GetTCPTests.cpp index a267b67a5..8d4b88a0c 100644 --- a/extensions/standard-processors/tests/unit/GetTCPTests.cpp +++ b/extensions/standard-processors/tests/unit/GetTCPTests.cpp @@ -1,5 +1,4 @@ /** - * * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -15,391 +14,286 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include <utility> -#include <memory> #include <string> -#include <vector> -#include <set> -#include "unit/ProvenanceTestHelper.h" -#include "TestBase.h" -#include "Catch.h" -#include "RandomServerSocket.h" -#include "Scheduling.h" -#include "LogAttribute.h" -#include "GetTCP.h" -#include "core/Core.h" -#include "core/FlowFile.h" -#include "core/Processor.h" -#include "core/ProcessContext.h" -#include "core/ProcessSession.h" -#include "core/ProcessorNode.h" -#include "core/reporting/SiteToSiteProvenanceReportingTask.h" - -TEST_CASE("GetTCPWithoutEOM", "[GetTCP1]") { - TestController testController; - std::vector<uint8_t> buffer; - for (auto c : "Hello World\nHello Warld\nGoodByte Cruel world") { - buffer.push_back(c); - } - std::shared_ptr<core::ContentRepository> content_repo = std::make_shared<core::repository::VolatileContentRepository>(); - - content_repo->initialize(std::make_shared<minifi::Configure>()); - - std::shared_ptr<org::apache::nifi::minifi::io::StreamFactory> stream_factory = minifi::io::StreamFactory::getInstance(std::make_shared<minifi::Configure>()); - org::apache::nifi::minifi::io::RandomServerSocket server(org::apache::nifi::minifi::io::Socket::getMyHostName()); - - LogTestController::getInstance().setDebug<minifi::processors::LogAttribute>(); - LogTestController::getInstance().setDebug<minifi::processors::GetTCP>(); - LogTestController::getInstance().setTrace<minifi::io::Socket>(); - - std::shared_ptr<core::Repository> repo = std::make_shared<TestRepository>(); - - auto processor = std::make_unique<org::apache::nifi::minifi::processors::GetTCP>("gettcpexample"); - - auto logAttribute = std::make_unique<org::apache::nifi::minifi::processors::LogAttribute>("logattribute"); - - processor->setStreamFactory(stream_factory); - processor->initialize(); - - utils::Identifier processoruuid = processor->getUUID(); - REQUIRE(processoruuid); - - utils::Identifier logattribute_uuid = logAttribute->getUUID(); - REQUIRE(logattribute_uuid); - - REQUIRE(processoruuid.to_string() != logattribute_uuid.to_string()); - - auto connection = std::make_unique<minifi::Connection>(repo, content_repo, "gettcpexampleConnection"); - connection->addRelationship(core::Relationship("success", "description")); - - auto connection2 = std::make_unique<minifi::Connection>(repo, content_repo, "logattribute"); - connection2->addRelationship(core::Relationship("success", "description")); - - // link the connections so that we can test results at the end for this - connection->setSource(processor.get()); - - // link the connections so that we can test results at the end for this - connection->setDestination(logAttribute.get()); - connection2->setSource(logAttribute.get()); - - connection2->setSourceUUID(logattribute_uuid); - connection->setSourceUUID(processoruuid); - connection->setDestinationUUID(logattribute_uuid); - - processor->addConnection(connection.get()); - logAttribute->addConnection(connection.get()); - logAttribute->addConnection(connection2.get()); - - auto node = std::make_shared<core::ProcessorNode>(processor.get()); - auto node2 = std::make_shared<core::ProcessorNode>(logAttribute.get()); - auto context = std::make_shared<core::ProcessContext>(node, nullptr, repo, repo, content_repo); - auto context2 = std::make_shared<core::ProcessContext>(node2, nullptr, repo, repo, content_repo); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndpointList, org::apache::nifi::minifi::io::Socket::getMyHostName() + ":" + std::to_string(server.getPort())); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ReconnectInterval, "200 msec"); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ConnectionAttemptLimit, "10"); - auto session = std::make_shared<core::ProcessSession>(context); - auto session2 = std::make_shared<core::ProcessSession>(context2); - - REQUIRE(processor->getName() == "gettcpexample"); - - std::shared_ptr<core::FlowFile> record; - processor->setScheduledState(core::ScheduledState::RUNNING); - - std::shared_ptr<core::ProcessSessionFactory> factory = std::make_shared<core::ProcessSessionFactory>(context); - processor->onSchedule(context, factory); - processor->onTrigger(context, session); - server.write(buffer, buffer.size()); - std::this_thread::sleep_for(std::chrono::seconds(2)); - - logAttribute->initialize(); - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - std::shared_ptr<core::ProcessSessionFactory> factory2 = std::make_shared<core::ProcessSessionFactory>(context2); - logAttribute->onSchedule(context2, factory2); - logAttribute->onTrigger(context2, session2); +#include "Catch.h" +#include "processors/GetTCP.h" +#include "SingleProcessorTestController.h" +#include "Utils.h" +#include "utils/net/AsioCoro.h" +#include "utils/net/AsioSocketUtils.h" +#include "controllers/SSLContextService.h" +#include "range/v3/algorithm/contains.hpp" +#include "utils/gsl.h" - auto reporter = session->getProvenanceReporter(); - auto records = reporter->getEvents(); - record = session->get(); - REQUIRE(record == nullptr); - REQUIRE(records.empty()); +using GetTCP = org::apache::nifi::minifi::processors::GetTCP; - processor->incrementActiveTasks(); - processor->setScheduledState(core::ScheduledState::RUNNING); - processor->onTrigger(context, session); - reporter = session->getProvenanceReporter(); +using namespace std::literals::chrono_literals; - session->commit(); +namespace org::apache::nifi::minifi::test { - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); +void check_for_attributes(core::FlowFile& flow_file, uint16_t port) { + CHECK(std::to_string(port) == flow_file.getAttribute("tcp.port")); + const auto local_addresses = {"127.0.0.1", "::ffff:127.0.0.1", "::1"}; + CHECK(ranges::contains(local_addresses, flow_file.getAttribute("tcp.sender"))); +} - REQUIRE(true == LogTestController::getInstance().contains("Reconnect interval is 200 ms")); - REQUIRE(true == LogTestController::getInstance().contains("Size:45 Offset:0")); +minifi::utils::net::SslData createSslDataForServer() { + const std::filesystem::path executable_dir = minifi::utils::file::FileUtils::get_executable_dir(); + minifi::utils::net::SslData ssl_data; + ssl_data.ca_loc = (executable_dir / "resources" / "ca_A.crt").string(); + ssl_data.cert_loc = (executable_dir / "resources" / "localhost_by_A.pem").string(); + ssl_data.key_loc = (executable_dir / "resources" / "localhost_by_A.pem").string(); + return ssl_data; +} - LogTestController::getInstance().reset(); +void addSslContextServiceTo(SingleProcessorTestController& controller) { + auto ssl_context_service = controller.plan->addController("SSLContextService", "SSLContextService"); + LogTestController::getInstance().setTrace<GetTCP>(); + const auto executable_dir = minifi::utils::file::FileUtils::get_executable_dir(); + REQUIRE(controller.plan->setProperty(ssl_context_service, controllers::SSLContextService::CACertificate.getName(), (executable_dir / "resources" / "ca_A.crt").string())); + REQUIRE(controller.plan->setProperty(ssl_context_service, controllers::SSLContextService::ClientCertificate.getName(), (executable_dir / "resources" / "alice_by_A.pem").string())); + REQUIRE(controller.plan->setProperty(ssl_context_service, controllers::SSLContextService::PrivateKey.getName(), (executable_dir / "resources" / "alice_by_A.pem").string())); + ssl_context_service->enable(); } -TEST_CASE("GetTCPWithOEM", "[GetTCP2]") { - std::vector<uint8_t> buffer; - for (auto c : "Hello World\nHello Warld\nGoodByte Cruel world") { - buffer.push_back(c); +class TcpTestServer { + public: + void run() { + server_thread_ = std::thread([&]() { + asio::co_spawn(io_context_, listenAndSendMessages(), asio::detached); + io_context_.run(); + }); } - std::shared_ptr<core::ContentRepository> content_repo = std::make_shared<core::repository::VolatileContentRepository>(); - - content_repo->initialize(std::make_shared<minifi::Configure>()); - - std::shared_ptr<org::apache::nifi::minifi::io::StreamFactory> stream_factory = minifi::io::StreamFactory::getInstance(std::make_shared<minifi::Configure>()); - - TestController testController; - - org::apache::nifi::minifi::io::RandomServerSocket server(org::apache::nifi::minifi::io::Socket::getMyHostName()); - - LogTestController::getInstance().setDebug<minifi::processors::LogAttribute>(); - LogTestController::getInstance().setTrace<core::repository::VolatileContentRepository >(); - LogTestController::getInstance().setTrace<minifi::processors::GetTCP>(); - LogTestController::getInstance().setTrace<core::ConfigurableComponent>(); - LogTestController::getInstance().setTrace<minifi::io::Socket>(); - - std::shared_ptr<core::Repository> repo = std::make_shared<TestRepository>(); - - std::shared_ptr<core::Processor> processor = std::make_shared<org::apache::nifi::minifi::processors::GetTCP>("gettcpexample"); - std::shared_ptr<core::Processor> logAttribute = std::make_shared<org::apache::nifi::minifi::processors::LogAttribute>("logattribute"); - - processor->setStreamFactory(stream_factory); - processor->initialize(); - - utils::Identifier processoruuid = processor->getUUID(); - REQUIRE(processoruuid); - - utils::Identifier logattribute_uuid = logAttribute->getUUID(); - REQUIRE(logattribute_uuid); - - auto connection = std::make_unique<minifi::Connection>(repo, content_repo, "gettcpexampleConnection"); - connection->addRelationship(core::Relationship("partial", "description")); - - auto connection2 = std::make_unique<minifi::Connection>(repo, content_repo, "logattribute"); - connection2->addRelationship(core::Relationship("partial", "description")); - - // link the connections so that we can test results at the end for this - connection->setSource(processor.get()); - - // link the connections so that we can test results at the end for this - connection->setDestination(logAttribute.get()); - - connection2->setSource(logAttribute.get()); - - connection2->setSourceUUID(logattribute_uuid); - connection->setSourceUUID(processoruuid); - connection->setDestinationUUID(logattribute_uuid); - - processor->addConnection(connection.get()); - logAttribute->addConnection(connection.get()); - logAttribute->addConnection(connection2.get()); - - auto node = std::make_shared<core::ProcessorNode>(processor.get()); - auto node2 = std::make_shared<core::ProcessorNode>(logAttribute.get()); - auto context = std::make_shared<core::ProcessContext>(node, nullptr, repo, repo, content_repo); - auto context2 = std::make_shared<core::ProcessContext>(node2, nullptr, repo, repo, content_repo); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndpointList, org::apache::nifi::minifi::io::Socket::getMyHostName() + ":" + std::to_string(server.getPort())); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ReconnectInterval, "200 msec"); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ConnectionAttemptLimit, "10"); - // we're using new lines above - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndOfMessageByte, "10"); - auto session = std::make_shared<core::ProcessSession>(context); - auto session2 = std::make_shared<core::ProcessSession>(context2); - - - REQUIRE(processor->getName() == "gettcpexample"); - - std::shared_ptr<core::FlowFile> record; - processor->setScheduledState(core::ScheduledState::RUNNING); - - std::shared_ptr<core::ProcessSessionFactory> factory = std::make_shared<core::ProcessSessionFactory>(context); - processor->onSchedule(context, factory); - processor->onTrigger(context, session); - server.write(buffer, buffer.size()); - std::this_thread::sleep_for(std::chrono::seconds(2)); + void queueMessage(std::string message) { + messages_to_send_.enqueue(std::move(message)); + } - logAttribute->initialize(); - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - std::shared_ptr<core::ProcessSessionFactory> factory2 = std::make_shared<core::ProcessSessionFactory>(context2); - logAttribute->onSchedule(context2, factory2); - logAttribute->onTrigger(context2, session2); + void enableSSL() { + const std::filesystem::path executable_dir = minifi::utils::file::FileUtils::get_executable_dir(); - auto reporter = session->getProvenanceReporter(); - auto records = reporter->getEvents(); - record = session->get(); - REQUIRE(record == nullptr); - REQUIRE(records.empty()); + asio::ssl::context ssl_context(asio::ssl::context::tls_server); + ssl_context.set_options(asio::ssl::context::default_workarounds | asio::ssl::context::single_dh_use | asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); + ssl_context.set_password_callback([key_pw = "Password12"](std::size_t&, asio::ssl::context_base::password_purpose&) { return key_pw; }); + ssl_context.use_certificate_file((executable_dir / "resources" / "localhost_by_A.pem").string(), asio::ssl::context::pem); + ssl_context.use_private_key_file((executable_dir / "resources" / "localhost_by_A.pem").string(), asio::ssl::context::pem); + ssl_context.load_verify_file((executable_dir / "resources" / "ca_A.crt").string()); + ssl_context.set_verify_mode(asio::ssl::verify_peer); - processor->incrementActiveTasks(); - processor->setScheduledState(core::ScheduledState::RUNNING); - processor->onTrigger(context, session); - reporter = session->getProvenanceReporter(); + ssl_context_ = std::move(ssl_context); + } - session->commit(); + uint16_t getPort() const { + return port_; + } - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); + ~TcpTestServer() { + io_context_.stop(); + if (server_thread_.joinable()) + server_thread_.join(); + } - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); + private: + asio::awaitable<void> sendMessages(auto& socket) { + while (true) { + std::string message_to_send; + if (!messages_to_send_.tryDequeue(message_to_send)) { + co_await minifi::utils::net::async_wait(10ms); + continue; + } + co_await asio::async_write(socket, asio::buffer(message_to_send), minifi::utils::net::use_nothrow_awaitable); + } + } - REQUIRE(true == LogTestController::getInstance().contains("Reconnect interval is 200 ms")); - REQUIRE(true == LogTestController::getInstance().contains("Size:11 Offset:0")); - REQUIRE(true == LogTestController::getInstance().contains("Size:12 Offset:0")); - REQUIRE(true == LogTestController::getInstance().contains("Size:22 Offset:0")); + asio::awaitable<void> secureSession(asio::ip::tcp::socket socket) { + gsl_Expects(ssl_context_); + minifi::utils::net::SslSocket ssl_socket(std::move(socket), *ssl_context_); + auto [handshake_error] = co_await ssl_socket.async_handshake(minifi::utils::net::HandshakeType::server, minifi::utils::net::use_nothrow_awaitable); + if (handshake_error) { + co_return; + } + co_await sendMessages(ssl_socket); + } - LogTestController::getInstance().reset(); -} + asio::awaitable<void> insecureSession(asio::ip::tcp::socket socket) { + co_await sendMessages(socket); + } -TEST_CASE("GetTCPWithOnlyOEM", "[GetTCP3]") { - std::vector<uint8_t> buffer; - for (auto c : "\n") { - buffer.push_back(c); + asio::awaitable<void> listenAndSendMessages() { + asio::ip::tcp::acceptor acceptor(io_context_, asio::ip::tcp::endpoint(asio::ip::tcp::v6(), port_)); + if (port_ == 0) + port_ = acceptor.local_endpoint().port(); + while (true) { + auto [accept_error, socket] = co_await acceptor.async_accept(minifi::utils::net::use_nothrow_awaitable); + if (accept_error) { + co_return; + } + if (ssl_context_) + co_spawn(io_context_, secureSession(std::move(socket)), asio::detached); + else + co_spawn(io_context_, insecureSession(std::move(socket)), asio::detached); + } } - std::shared_ptr<core::ContentRepository> content_repo = std::make_shared<core::repository::VolatileContentRepository>(); + std::optional<asio::ssl::context> ssl_context_; + minifi::utils::ConcurrentQueue<std::string> messages_to_send_; + std::atomic<uint16_t> port_ = 0; + std::thread server_thread_; + asio::io_context io_context_; +}; - content_repo->initialize(std::make_shared<minifi::Configure>()); +TEST_CASE("GetTCP test with delimiter", "[GetTCP]") { + const auto get_tcp = std::make_shared<GetTCP>("GetTCP"); + SingleProcessorTestController controller{get_tcp}; + LogTestController::getInstance().setTrace<GetTCP>(); + REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "2")); - std::shared_ptr<org::apache::nifi::minifi::io::StreamFactory> stream_factory = minifi::io::StreamFactory::getInstance(std::make_shared<minifi::Configure>()); - TestController testController; + TcpTestServer tcp_test_server; - LogTestController::getInstance().setDebug<minifi::io::Socket>(); + SECTION("No SSL") {} - org::apache::nifi::minifi::io::RandomServerSocket server(org::apache::nifi::minifi::io::Socket::getMyHostName()); + SECTION("SSL") { + addSslContextServiceTo(controller); + tcp_test_server.enableSSL(); + REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService")); + } - LogTestController::getInstance().setDebug<minifi::processors::LogAttribute>(); + tcp_test_server.queueMessage("Hello\n"); + tcp_test_server.run(); + REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return tcp_test_server.getPort() != 0; }, 20ms)); - LogTestController::getInstance().setDebug<minifi::processors::GetTCP>(); + REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{}", tcp_test_server.getPort()))); + controller.plan->scheduleProcessor(get_tcp); - std::shared_ptr<core::Repository> repo = std::make_shared<TestRepository>(); + ProcessorTriggerResult result; + REQUIRE(controller.triggerUntil({{GetTCP::Success, 1}}, result, 1s, 50ms)); + CHECK(controller.plan->getContent(result.at(GetTCP::Success)[0]) == "Hello\n"); - std::shared_ptr<core::Processor> processor = std::make_shared<org::apache::nifi::minifi::processors::GetTCP>("gettcpexample"); + check_for_attributes(*result.at(GetTCP::Success)[0], tcp_test_server.getPort()); +} - std::shared_ptr<core::Processor> logAttribute = std::make_shared<org::apache::nifi::minifi::processors::LogAttribute>("logattribute"); +TEST_CASE("GetTCP test with too large message", "[GetTCP]") { + const auto get_tcp = std::make_shared<GetTCP>("GetTCP"); + SingleProcessorTestController controller{get_tcp}; + LogTestController::getInstance().setTrace<GetTCP>(); + REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "2")); + REQUIRE(get_tcp->setProperty(GetTCP::MaxMessageSize, "10")); + REQUIRE(get_tcp->setProperty(GetTCP::MessageDelimiter, "\r")); - processor->setStreamFactory(stream_factory); - processor->initialize(); + TcpTestServer tcp_test_server; - utils::Identifier processoruuid = processor->getUUID(); - REQUIRE(processoruuid); + SECTION("No SSL") {} - utils::Identifier logattribute_uuid = logAttribute->getUUID(); - REQUIRE(logattribute_uuid); + SECTION("SSL") { + addSslContextServiceTo(controller); + tcp_test_server.enableSSL(); + REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService")); + } - auto connection = std::make_unique<minifi::Connection>(repo, content_repo, "gettcpexampleConnection"); - connection->addRelationship(core::Relationship("success", "description")); + tcp_test_server.queueMessage("abcdefghijklmnopqrstuvwxyz\rBye\r"); + tcp_test_server.run(); - auto connection2 = std::make_unique<minifi::Connection>(repo, content_repo, "logattribute"); - connection2->addRelationship(core::Relationship("success", "description")); + REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return tcp_test_server.getPort() != 0; }, 20ms)); - // link the connections so that we can test results at the end for this - connection->setSource(processor.get()); + REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{}", tcp_test_server.getPort()))); + controller.plan->scheduleProcessor(get_tcp); - // link the connections so that we can test results at the end for this - connection->setDestination(logAttribute.get()); + ProcessorTriggerResult result; + REQUIRE(controller.triggerUntil({{GetTCP::Success, 1}}, result, 1s, 50ms)); + REQUIRE(result.at(GetTCP::Partial).size() == 3); + REQUIRE(result.at(GetTCP::Success).size() == 1); + CHECK(controller.plan->getContent(result.at(GetTCP::Partial)[0]) == "abcdefghij"); + CHECK(controller.plan->getContent(result.at(GetTCP::Partial)[1]) == "klmnopqrst"); + CHECK(controller.plan->getContent(result.at(GetTCP::Partial)[2]) == "uvwxyz\r"); + CHECK(controller.plan->getContent(result.at(GetTCP::Success)[0]) == "Bye\r"); - connection2->setSource(logAttribute.get()); + check_for_attributes(*result.at(GetTCP::Partial)[0], tcp_test_server.getPort()); + check_for_attributes(*result.at(GetTCP::Partial)[1], tcp_test_server.getPort()); + check_for_attributes(*result.at(GetTCP::Partial)[2], tcp_test_server.getPort()); + check_for_attributes(*result.at(GetTCP::Success)[0], tcp_test_server.getPort()); +} - connection2->setSourceUUID(logattribute_uuid); - connection->setSourceUUID(processoruuid); - connection->setDestinationUUID(logattribute_uuid); +TEST_CASE("GetTCP test multiple endpoints", "[GetTCP]") { + const auto get_tcp = std::make_shared<GetTCP>("GetTCP"); + SingleProcessorTestController controller{get_tcp}; + LogTestController::getInstance().setTrace<GetTCP>(); + REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "2")); - processor->addConnection(connection.get()); - logAttribute->addConnection(connection.get()); - logAttribute->addConnection(connection2.get()); + TcpTestServer server_1; + TcpTestServer server_2; - auto node = std::make_shared<core::ProcessorNode>(processor.get()); - auto node2 = std::make_shared<core::ProcessorNode>(logAttribute.get()); - auto context = std::make_shared<core::ProcessContext>(node, nullptr, repo, repo, content_repo); - auto context2 = std::make_shared<core::ProcessContext>(node2, nullptr, repo, repo, content_repo); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndpointList, org::apache::nifi::minifi::io::Socket::getMyHostName() + ":" + std::to_string(server.getPort())); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ReconnectInterval, "200 msec"); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ConnectionAttemptLimit, "10"); - // we're using new lines above - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndOfMessageByte, "10"); - auto session = std::make_shared<core::ProcessSession>(context); - auto session2 = std::make_shared<core::ProcessSession>(context2); + SECTION("No SSL") {} + SECTION("SSL") { + addSslContextServiceTo(controller); + server_1.enableSSL(); + server_2.enableSSL(); + REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService")); + } - REQUIRE(processor->getName() == "gettcpexample"); + server_1.queueMessage("abcdefghijklmnopqrstuvwxyz\nBye\n"); + server_1.run(); - std::shared_ptr<core::FlowFile> record; - processor->setScheduledState(core::ScheduledState::RUNNING); + server_2.queueMessage("012345678901234567890\nAuf Wiedersehen\n"); + server_2.run(); - std::shared_ptr<core::ProcessSessionFactory> factory = std::make_shared<core::ProcessSessionFactory>(context); - processor->onSchedule(context, factory); - processor->onTrigger(context, session); - server.write(buffer, buffer.size()); - std::this_thread::sleep_for(std::chrono::seconds(2)); + REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return server_1.getPort() != 0 && server_2.getPort() != 0; }, 20ms)); - logAttribute->initialize(); - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - std::shared_ptr<core::ProcessSessionFactory> factory2 = std::make_shared<core::ProcessSessionFactory>(context2); - logAttribute->onSchedule(context2, factory2); - logAttribute->onTrigger(context2, session2); + REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{},localhost:{}", server_1.getPort(), server_2.getPort()))); + controller.plan->scheduleProcessor(get_tcp); - auto reporter = session->getProvenanceReporter(); - auto records = reporter->getEvents(); - record = session->get(); - REQUIRE(record == nullptr); - REQUIRE(records.empty()); + ProcessorTriggerResult result; + CHECK(controller.triggerUntil({{GetTCP::Success, 4}}, result, 1s, 50ms)); + CHECK(result.at(GetTCP::Success).size() == 4); - processor->incrementActiveTasks(); - processor->setScheduledState(core::ScheduledState::RUNNING); - processor->onTrigger(context, session); - reporter = session->getProvenanceReporter(); + std::vector<std::string> success_flow_file_contents; + for (const auto& flow_file: result.at(GetTCP::Success)) { + success_flow_file_contents.push_back(controller.plan->getContent(flow_file)); + } - session->commit(); + CHECK(ranges::contains(success_flow_file_contents, "abcdefghijklmnopqrstuvwxyz\n")); + CHECK(ranges::contains(success_flow_file_contents, "Bye\n")); + CHECK(ranges::contains(success_flow_file_contents, "012345678901234567890\n")); + CHECK(ranges::contains(success_flow_file_contents, "Auf Wiedersehen\n")); +} - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); +TEST_CASE("GetTCP max queue and max batch size test", "[GetTCP]") { + const auto get_tcp = std::make_shared<GetTCP>("GetTCP"); + SingleProcessorTestController controller{get_tcp}; + LogTestController::getInstance().setTrace<GetTCP>(); + REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "10")); + REQUIRE(get_tcp->setProperty(GetTCP::MaxQueueSize, "50")); - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); + TcpTestServer server; - REQUIRE(true == LogTestController::getInstance().contains("Reconnect interval is 200 ms")); - REQUIRE(true == LogTestController::getInstance().contains("Size:2 Offset:0")); - LogTestController::getInstance().reset(); -} + SECTION("No SSL") {} + SECTION("SSL") { + addSslContextServiceTo(controller); + server.enableSSL(); + REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService")); + } -TEST_CASE("GetTCPEmptyNoConnect", "[GetTCP3]") { - TestController testController; - LogTestController::getInstance().setDebug<minifi::processors::LogAttribute>(); - LogTestController::getInstance().setDebug<minifi::processors::GetTCP>(); - LogTestController::getInstance().setTrace<minifi::io::Socket>(); + LogTestController::getInstance().setWarn<GetTCP>(); - std::shared_ptr<TestPlan> plan = testController.createPlan(); - std::shared_ptr<core::Processor> getfile = plan->addProcessor("GetTCP", "gettcpexample"); + for (auto i = 0; i < 100; ++i) { + server.queueMessage("some_message\n"); + } - plan->addProcessor("LogAttribute", "logattribute", core::Relationship("success", "description"), true); + server.run(); - plan->setProperty(getfile, org::apache::nifi::minifi::processors::GetTCP::EndpointList.getName(), org::apache::nifi::minifi::io::Socket::getMyHostName() + ":9182"); - plan->setProperty(getfile, org::apache::nifi::minifi::processors::GetTCP::ReconnectInterval.getName(), "200 msec"); - plan->setProperty(getfile, org::apache::nifi::minifi::processors::GetTCP::ConnectionAttemptLimit.getName(), "10"); - // we're using new lines above - plan->setProperty(getfile, org::apache::nifi::minifi::processors::GetTCP::EndOfMessageByte.getName(), "10"); + REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return server.getPort() != 0; }, 20ms)); - TestController::runSession(plan, false); - auto records = plan->getProvenanceRecords(); - std::shared_ptr<core::FlowFile> record = plan->getCurrentFlowFile(); - REQUIRE(record == nullptr); - REQUIRE(records.empty()); + REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{}", server.getPort()))); + controller.plan->scheduleProcessor(get_tcp); - REQUIRE(true == LogTestController::getInstance().contains("Reconnect interval is 200 ms")); - REQUIRE(true == LogTestController::getInstance().contains("Could not create socket during initialization for " + org::apache::nifi::minifi::io::Socket::getMyHostName() + ":9182")); - LogTestController::getInstance().reset(); + CHECK(utils::countLogOccurrencesUntil("Queue is full. TCP message ignored.", 50, 300ms, 50ms)); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).empty()); } +} // namespace org::apache::nifi::minifi::test diff --git a/libminifi/include/utils/StringUtils.h b/libminifi/include/utils/StringUtils.h index 2e6f7285c..ea1f85976 100644 --- a/libminifi/include/utils/StringUtils.h +++ b/libminifi/include/utils/StringUtils.h @@ -34,6 +34,7 @@ #endif #include "utils/FailurePolicy.h" #include "utils/gsl.h" +#include "utils/expected.h" #include "utils/meta/detected.h" #include "range/v3/view/transform.hpp" #include "range/v3/range/conversion.hpp" @@ -490,6 +491,10 @@ class StringUtils { static bool matchesSequence(std::string_view str, const std::vector<std::string>& patterns); static bool splitToValueAndUnit(std::string_view input, int64_t& value, std::string& unit); + + struct ParseError {}; + + static nonstd::expected<std::optional<char>, ParseError> parseCharacter(const std::string &input); }; } // namespace org::apache::nifi::minifi::utils diff --git a/libminifi/include/utils/net/AsioCoro.h b/libminifi/include/utils/net/AsioCoro.h index 5c2e5268b..55a3a4cbc 100644 --- a/libminifi/include/utils/net/AsioCoro.h +++ b/libminifi/include/utils/net/AsioCoro.h @@ -35,10 +35,6 @@ namespace org::apache::nifi::minifi::utils::net { constexpr auto use_nothrow_awaitable = asio::experimental::as_tuple(asio::use_awaitable); -using HandshakeType = asio::ssl::stream_base::handshake_type; -using TcpSocket = asio::ip::tcp::socket; -using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>; - #if defined(__GNUC__) && __GNUC__ < 11 // [coroutines] unexpected 'warning: statement has no effect [-Wunused-value]' // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=96749 @@ -52,19 +48,17 @@ inline asio::awaitable<void> async_wait(asio::steady_timer& timer) { #pragma GCC diagnostic pop #endif // defined(__GNUC__) && __GNUC__ < 11 -namespace detail { -inline asio::awaitable<void> timeout(std::chrono::steady_clock::duration duration) { +inline asio::awaitable<void> async_wait(std::chrono::steady_clock::duration duration) { asio::steady_timer timer(co_await asio::this_coro::executor); // NOLINT timer.expires_after(duration); co_await async_wait(timer); } -} // namespace detail template<class... Types> asio::awaitable<std::tuple<std::error_code, Types...>> asyncOperationWithTimeout(asio::awaitable<std::tuple<std::error_code, Types...>>&& async_operation, std::chrono::steady_clock::duration timeout_duration) { using asio::experimental::awaitable_operators::operator||; - auto operation_result = co_await(std::move(async_operation) || detail::timeout(timeout_duration)); + auto operation_result = co_await(std::move(async_operation) || async_wait(timeout_duration)); // NOLINT if (operation_result.index() == 1) { std::tuple<std::error_code, Types...> result; std::get<0>(result) = asio::error::timed_out; diff --git a/libminifi/include/utils/net/AsioSocketUtils.h b/libminifi/include/utils/net/AsioSocketUtils.h index dc397acd3..2cdcaa266 100644 --- a/libminifi/include/utils/net/AsioSocketUtils.h +++ b/libminifi/include/utils/net/AsioSocketUtils.h @@ -17,11 +17,59 @@ #pragma once +#include <string> +#include <utility> +#include <tuple> + #include "asio/ssl.hpp" +#include "asio/ip/tcp.hpp" + +#include "AsioCoro.h" +#include "utils/Hash.h" +#include "utils/StringUtils.h" // for string <=> on libc++ #include "controllers/SSLContextService.h" + namespace org::apache::nifi::minifi::utils::net { -asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service, asio::ssl::context::method ssl_context_method = asio::ssl::context::tls_client); +using HandshakeType = asio::ssl::stream_base::handshake_type; +using TcpSocket = asio::ip::tcp::socket; +using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>; + +class ConnectionId { + public: + ConnectionId(std::string hostname, std::string port) : hostname_(std::move(hostname)), service_(std::move(port)) {} + ConnectionId(const ConnectionId& connection_id) = default; + ConnectionId(ConnectionId&& connection_id) = default; + + auto operator<=>(const ConnectionId&) const = default; + [[nodiscard]] std::string_view getHostname() const { return hostname_; } + [[nodiscard]] std::string_view getService() const { return service_; } + + private: + std::string hostname_; + std::string service_; +}; + +template<class SocketType> +asio::awaitable<std::tuple<std::error_code>> handshake(SocketType&, asio::steady_timer::duration) = delete; +template<> +asio::awaitable<std::tuple<std::error_code>> handshake(TcpSocket&, asio::steady_timer::duration); +template<> +asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket, asio::steady_timer::duration); + + +asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service, asio::ssl::context::method ssl_context_method = asio::ssl::context::tls_client); } // namespace org::apache::nifi::minifi::utils::net + +namespace std { +template<> +struct hash<org::apache::nifi::minifi::utils::net::ConnectionId> { + size_t operator()(const org::apache::nifi::minifi::utils::net::ConnectionId& connection_id) const { + return org::apache::nifi::minifi::utils::hash_combine( + std::hash<std::string_view>{}(connection_id.getHostname()), + std::hash<std::string_view>{}(connection_id.getService())); + } +}; +} // namespace std diff --git a/libminifi/include/utils/net/AsioSocketUtils.h b/libminifi/include/utils/net/Message.h similarity index 61% copy from libminifi/include/utils/net/AsioSocketUtils.h copy to libminifi/include/utils/net/Message.h index dc397acd3..2cc05f717 100644 --- a/libminifi/include/utils/net/AsioSocketUtils.h +++ b/libminifi/include/utils/net/Message.h @@ -14,14 +14,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include "asio/ssl.hpp" -#include "controllers/SSLContextService.h" +#include <string> +#include <utility> + +#include "IpProtocol.h" +#include "asio/ts/internet.hpp" namespace org::apache::nifi::minifi::utils::net { -asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service, asio::ssl::context::method ssl_context_method = asio::ssl::context::tls_client); +struct Message { + public: + Message() = default; + Message(std::string message_data, IpProtocol protocol, asio::ip::address sender_address, asio::ip::port_type server_port) + : message_data(std::move(message_data)), + protocol(protocol), + server_port(server_port), + sender_address(std::move(sender_address)) { + } + + bool is_partial = false; + std::string message_data; + IpProtocol protocol; + asio::ip::port_type server_port; + asio::ip::address sender_address; +}; } // namespace org::apache::nifi::minifi::utils::net diff --git a/libminifi/include/utils/net/Server.h b/libminifi/include/utils/net/Server.h index e84815c3b..b36936ef6 100644 --- a/libminifi/include/utils/net/Server.h +++ b/libminifi/include/utils/net/Server.h @@ -25,30 +25,13 @@ #include "utils/MinifiConcurrentQueue.h" #include "core/logging/Logger.h" #include "asio/ts/buffer.hpp" -#include "asio/ts/internet.hpp" #include "asio/awaitable.hpp" #include "asio/co_spawn.hpp" #include "asio/detached.hpp" -#include "IpProtocol.h" +#include "Message.h" namespace org::apache::nifi::minifi::utils::net { -struct Message { - public: - Message() = default; - Message(std::string message_data, IpProtocol protocol, asio::ip::address sender_address, asio::ip::port_type server_port) - : message_data(std::move(message_data)), - protocol(protocol), - server_port(server_port), - sender_address(std::move(sender_address)) { - } - - std::string message_data; - IpProtocol protocol; - asio::ip::port_type server_port; - asio::ip::address sender_address; -}; - class Server { public: virtual void run() { diff --git a/libminifi/src/utils/StringUtils.cpp b/libminifi/src/utils/StringUtils.cpp index 84b7a531b..b6788ebc3 100644 --- a/libminifi/src/utils/StringUtils.cpp +++ b/libminifi/src/utils/StringUtils.cpp @@ -517,4 +517,24 @@ bool StringUtils::splitToValueAndUnit(std::string_view input, int64_t& value, st return true; } +nonstd::expected<std::optional<char>, StringUtils::ParseError> StringUtils::parseCharacter(const std::string &input) { + if (input.empty()) { return std::nullopt; } + if (input.size() == 1) { return input[0]; } + + if (input.size() == 2 && input.starts_with('\\')) { + switch (input[1]) { + case '0': return '\0'; // Null + case 'a': return '\a'; // Bell + case 'b': return '\b'; // Backspace + case 't': return '\t'; // Horizontal Tab + case 'n': return '\n'; // Line Feed + case 'v': return '\v'; // Vertical Tab + case 'f': return '\f'; // Form Feed + case 'r': return '\r'; // Carriage Return + default: return input[1]; + } + } + return nonstd::make_unexpected(ParseError{}); +} + } // namespace org::apache::nifi::minifi::utils diff --git a/libminifi/src/utils/net/AsioSocketUtils.cpp b/libminifi/src/utils/net/AsioSocketUtils.cpp index 10b6af63e..bd01b4534 100644 --- a/libminifi/src/utils/net/AsioSocketUtils.cpp +++ b/libminifi/src/utils/net/AsioSocketUtils.cpp @@ -16,9 +16,20 @@ */ #include "utils/net/AsioSocketUtils.h" +#include "controllers/SSLContextService.h" namespace org::apache::nifi::minifi::utils::net { +template<> +asio::awaitable<std::tuple<std::error_code>> handshake(TcpSocket&, asio::steady_timer::duration) { + co_return std::error_code(); +} + +template<> +asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket, asio::steady_timer::duration timeout_duration) { + co_return co_await asyncOperationWithTimeout(socket.async_handshake(HandshakeType::client, use_nothrow_awaitable), timeout_duration); // NOLINT +} + asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service, asio::ssl::context::method ssl_context_method) { asio::ssl::context ssl_context(ssl_context_method); ssl_context.set_options(asio::ssl::context::default_workarounds | asio::ssl::context::single_dh_use | asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); @@ -31,5 +42,4 @@ asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_conte ssl_context.use_private_key_file(private_key_file.string(), asio::ssl::context::pem); return ssl_context; } - } // namespace org::apache::nifi::minifi::utils::net diff --git a/libminifi/src/utils/net/TcpServer.cpp b/libminifi/src/utils/net/TcpServer.cpp index b1fa06b20..c443bf347 100644 --- a/libminifi/src/utils/net/TcpServer.cpp +++ b/libminifi/src/utils/net/TcpServer.cpp @@ -16,6 +16,9 @@ */ #include "utils/net/TcpServer.h" #include "utils/net/AsioCoro.h" +#include "utils/net/AsioSocketUtils.h" + +using namespace std::literals::chrono_literals; namespace org::apache::nifi::minifi::utils::net { @@ -27,7 +30,8 @@ asio::awaitable<void> TcpServer::doReceive() { auto [accept_error, socket] = co_await acceptor.async_accept(use_nothrow_awaitable); if (accept_error) { logger_->log_error("Error during accepting new connection: %s", accept_error.message()); - break; + co_await utils::net::async_wait(1s); + continue; } if (ssl_data_) co_spawn(io_context_, secureSession(std::move(socket)), asio::detached); diff --git a/libminifi/test/resources/TestC2Metrics.yml b/libminifi/test/resources/TestC2Metrics.yml index ea3b7eb74..6a0af5e4c 100644 --- a/libminifi/test/resources/TestC2Metrics.yml +++ b/libminifi/test/resources/TestC2Metrics.yml @@ -31,10 +31,10 @@ Processors: run duration nanos: 0 auto-terminated relationships list: Properties: - endpoint-list: localhost:8776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:8776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute @@ -66,4 +66,3 @@ Connections: Controller Services: [] Remote Processing Groups: - diff --git a/libminifi/test/resources/TestGetTCPSecure.yml b/libminifi/test/resources/TestGetTCPSecure.yml index ecf56fea7..15618931b 100644 --- a/libminifi/test/resources/TestGetTCPSecure.yml +++ b/libminifi/test/resources/TestGetTCPSecure.yml @@ -32,10 +32,8 @@ Processors: auto-terminated relationships list: Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:8776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:8776 + Message Delimiter: d - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute diff --git a/libminifi/test/resources/TestGetTCPSecureEmptyPass.yml b/libminifi/test/resources/TestGetTCPSecureEmptyPass.yml index f9dd9a0d6..dfcaa2fdd 100644 --- a/libminifi/test/resources/TestGetTCPSecureEmptyPass.yml +++ b/libminifi/test/resources/TestGetTCPSecureEmptyPass.yml @@ -32,10 +32,10 @@ Processors: auto-terminated relationships list: Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:29776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:29776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute @@ -88,4 +88,3 @@ Controller Services: - value: nifi-cert.pem Remote Processing Groups: - diff --git a/libminifi/test/resources/TestGetTCPSecureWithFilePass.yml b/libminifi/test/resources/TestGetTCPSecureWithFilePass.yml index c48aa2567..0e68be782 100644 --- a/libminifi/test/resources/TestGetTCPSecureWithFilePass.yml +++ b/libminifi/test/resources/TestGetTCPSecureWithFilePass.yml @@ -32,10 +32,8 @@ Processors: auto-terminated relationships list: Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:18776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:18776 + Message Delimiter: \r - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute diff --git a/libminifi/test/resources/TestGetTCPSecureWithPass.yml b/libminifi/test/resources/TestGetTCPSecureWithPass.yml index c11d76fe6..0393eb6a9 100644 --- a/libminifi/test/resources/TestGetTCPSecureWithPass.yml +++ b/libminifi/test/resources/TestGetTCPSecureWithPass.yml @@ -32,10 +32,10 @@ Processors: auto-terminated relationships list: Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:28776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:28776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute diff --git a/libminifi/test/resources/TestSameProcessorMetrics.yml b/libminifi/test/resources/TestSameProcessorMetrics.yml index 2c842b8b8..4b5a4d69d 100644 --- a/libminifi/test/resources/TestSameProcessorMetrics.yml +++ b/libminifi/test/resources/TestSameProcessorMetrics.yml @@ -60,10 +60,10 @@ Processors: - partial Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:8776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:8776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - name: GetTCP2 id: 2438e3c8-015a-1000-79ca-83af40ec1996 class: org.apache.nifi.processors.standard.GetTCP @@ -78,10 +78,10 @@ Processors: - partial Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:8776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:8776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute @@ -120,4 +120,3 @@ Connections: flowfile expiration: 60 sec Remote Processing Groups: - diff --git a/libminifi/test/resources/encrypted.cn.pass b/libminifi/test/resources/encrypted.cn.pass index 9dd74dac1..cdbe4f5d3 100644 --- a/libminifi/test/resources/encrypted.cn.pass +++ b/libminifi/test/resources/encrypted.cn.pass @@ -1 +1 @@ -VsVTmHBzixyA9UfTCttRYXus1oMpIxO6jmDXrNrOp5w +VsVTmHBzixyA9UfTCttRYXus1oMpIxO6jmDXrNrOp5w \ No newline at end of file diff --git a/libminifi/test/unit/StringUtilsTests.cpp b/libminifi/test/unit/StringUtilsTests.cpp index 2b64b5af5..b71ba5998 100644 --- a/libminifi/test/unit/StringUtilsTests.cpp +++ b/libminifi/test/unit/StringUtilsTests.cpp @@ -594,4 +594,19 @@ TEST_CASE("StringUtils::splitToValueAndUnit tests") { } } +TEST_CASE("StringUtils::parseCharacter tests") { + CHECK(StringUtils::parseCharacter("a") == 'a'); + CHECK(StringUtils::parseCharacter("\\n") == '\n'); + CHECK(StringUtils::parseCharacter("\\t") == '\t'); + CHECK(StringUtils::parseCharacter("\\r") == '\r'); + CHECK(StringUtils::parseCharacter("\\s") == 's'); + CHECK(StringUtils::parseCharacter("\\'") == '\''); + CHECK(StringUtils::parseCharacter("\\") == '\\'); + CHECK(StringUtils::parseCharacter("\\?") == '\?'); + + CHECK_FALSE(StringUtils::parseCharacter("abc").has_value()); + CHECK_FALSE(StringUtils::parseCharacter("\\nd").has_value()); + CHECK(StringUtils::parseCharacter("") == std::nullopt); +} + // NOLINTEND(readability-container-size-empty)