fgerlits commented on code in PR #1457: URL: https://github.com/apache/nifi-minifi-cpp/pull/1457#discussion_r1067152515
########## extensions/standard-processors/processors/NetworkListenerProcessor.cpp: ########## @@ -66,16 +66,16 @@ void NetworkListenerProcessor::startTcpServer(const core::ProcessContext& contex auto options = readServerOptions(context); std::string ssl_value; + std::optional<utils::net::SslServerOptions> ssl_options; if (context.getProperty(ssl_context_property.getName(), ssl_value) && !ssl_value.empty()) { auto ssl_data = utils::net::getSslData(context, ssl_context_property, logger_); if (!ssl_data || !ssl_data->isValid()) { throw Exception(PROCESSOR_EXCEPTION, "SSL Context Service is set, but no valid SSL data was found!"); } - auto client_auth = utils::parseEnumProperty<utils::net::SslServer::ClientAuthOption>(context, client_auth_property); - server_ = std::make_unique<utils::net::SslServer>(options.max_queue_size, options.port, logger_, *ssl_data, client_auth); - } else { - server_ = std::make_unique<utils::net::TcpServer>(options.max_queue_size, options.port, logger_); + auto client_auth = utils::parseEnumProperty<utils::net::ClientAuthOption>(context, client_auth_property); + ssl_options.emplace(utils::net::SslServerOptions{std::move(*ssl_data), client_auth}); Review Comment: minor, but this constructs `SslServerOptions` first, then moves it; you can get rid of the move by changing it to ```suggestion ssl_options.emplace(std::move(*ssl_data), client_auth); ``` ########## extensions/standard-processors/tests/unit/ListenSyslogTests.cpp: ########## @@ -249,132 +248,155 @@ void check_parsed_attributes(const core::FlowFile& flow_file, const ValidRFC3164 CHECK(original_message.msg_ == flow_file.getAttribute("syslog.msg")); } -TEST_CASE("ListenSyslog without parsing test", "[ListenSyslog][NetworkListenerProcessor]") { +uint16_t schedule_on_random_port(SingleProcessorTestController& controller, const std::shared_ptr<ListenSyslog>& listen_syslog) { + REQUIRE(listen_syslog->setProperty(ListenSyslog::Port, "0")); + controller.plan->scheduleProcessor(listen_syslog); + 
uint16_t port = listen_syslog->getPort(); + auto deadline = std::chrono::steady_clock::now() + 200ms; + while (port == 0 && deadline > std::chrono::steady_clock::now()) { + std::this_thread::sleep_for(20ms); + port = listen_syslog->getPort(); + } + REQUIRE(port != 0); + return port; +} Review Comment: can we use `utils::scheduleProcessorOnRandomPort()` instead of this? ########## libminifi/test/Utils.h: ########## @@ -183,33 +188,54 @@ bool sendMessagesViaSSL(const std::vector<std::string_view>& contents, asio::error_code err; socket.lowest_layer().connect(remote_endpoint, err); if (err) { - return false; + return err; } socket.handshake(asio::ssl::stream_base::client, err); if (err) { - return false; + return err; } for (auto& content : contents) { std::string tcp_message(content); tcp_message += '\n'; asio::write(socket, asio::buffer(tcp_message, tcp_message.size()), err); if (err) { - return false; + return err; } } - return true; + return std::error_code(); } #ifdef WIN32 inline std::error_code hide_file(const std::filesystem::path& file_name) { - const bool success = SetFileAttributesA(file_name.string().c_str(), FILE_ATTRIBUTE_HIDDEN); - if (!success) { - // note: All possible documented error codes from GetLastError are in [0;15999] at the time of writing. - // The below casting is safe in [0;std::numeric_limits<int>::max()], int max is guaranteed to be at least 32767 - return { static_cast<int>(GetLastError()), std::system_category() }; - } - return {}; + const bool success = SetFileAttributesA(file_name.string().c_str(), FILE_ATTRIBUTE_HIDDEN); + if (!success) { + // note: All possible documented error codes from GetLastError are in [0;15999] at the time of writing. 
+ // The below casting is safe in [0;std::numeric_limits<int>::max()], int max is guaranteed to be at least 32767 + return { static_cast<int>(GetLastError()), std::system_category() }; } + return {}; +} #endif /* WIN32 */ +template<typename T> +concept NetworkingProcessor = std::derived_from<T, minifi::core::Processor> + && requires(T x) { + {T::Port} -> std::convertible_to<core::Property>; + {x.getPort()} -> std::convertible_to<uint16_t>; + }; // NOLINT(readability/braces) + +template<NetworkingProcessor T> +uint16_t scheduleProcessorOnRandomPort(const std::shared_ptr<TestPlan>& test_plan, const std::shared_ptr<T>& processor) { + REQUIRE(processor->setProperty(T::Port, "0")); + test_plan->scheduleProcessor(processor); + uint16_t port = processor->getPort(); + auto deadline = std::chrono::steady_clock::now() + 200ms; + while (port == 0 && deadline > std::chrono::steady_clock::now()) { + std::this_thread::sleep_for(20ms); + port = processor->getPort(); + } + REQUIRE(port != 0); + return port; Review Comment: this could be rewritten to use `verifyEventHappenedInPollTime`, too ########## extensions/standard-processors/processors/PutTCP.cpp: ########## @@ -160,339 +178,147 @@ void PutTCP::onSchedule(core::ProcessContext* const context, core::ProcessSessio } namespace { +template<class SocketType> +asio::awaitable<std::tuple<std::error_code>> handshake(SocketType&, asio::steady_timer::duration) { + co_return std::error_code(); +} + +template<> +asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket, asio::steady_timer::duration timeout_duration) { + co_return co_await asyncOperationWithTimeout(socket.async_handshake(HandshakeType::client, use_nothrow_awaitable), timeout_duration); // NOLINT +} + template<class SocketType> class ConnectionHandler : public ConnectionHandlerBase { public: ConnectionHandler(detail::ConnectionId connection_id, std::chrono::milliseconds timeout, std::shared_ptr<core::logging::Logger> logger, std::optional<size_t> 
max_size_of_socket_send_buffer, - std::shared_ptr<controllers::SSLContextService> ssl_context_service) + std::optional<asio::ssl::context>& ssl_context) : connection_id_(std::move(connection_id)), - timeout_(timeout), + timeout_duration_(timeout), logger_(std::move(logger)), max_size_of_socket_send_buffer_(max_size_of_socket_send_buffer), - ssl_context_service_(std::move(ssl_context_service)) { + ssl_context_(ssl_context) { } ~ConnectionHandler() override = default; - nonstd::expected<void, std::error_code> sendData(const std::shared_ptr<io::InputStream>& flow_file_content_stream, const std::vector<std::byte>& delimiter) override; + asio::awaitable<std::error_code> sendStreamWithDelimiter(const std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>& delimiter, asio::io_context& io_context_) override; private: - nonstd::expected<std::shared_ptr<SocketType>, std::error_code> getSocket(); - [[nodiscard]] bool hasBeenUsedIn(std::chrono::milliseconds dur) const override { - return last_used_ && *last_used_ >= (std::chrono::steady_clock::now() - dur); + return last_used_ && *last_used_ >= (steady_clock::now() - dur); } void reset() override { last_used_.reset(); socket_.reset(); - io_context_.reset(); - last_error_.clear(); - deadline_.expires_at(asio::steady_timer::time_point::max()); } - void checkDeadline(std::error_code error_code, SocketType* socket); - void startConnect(tcp::resolver::results_type::iterator endpoint_iter, const std::shared_ptr<SocketType>& socket); - - void handleConnect(std::error_code error, - tcp::resolver::results_type::iterator endpoint_iter, - const std::shared_ptr<SocketType>& socket); - void handleConnectionSuccess(const tcp::resolver::results_type::iterator& endpoint_iter, - const std::shared_ptr<SocketType>& socket); - void handleHandshake(std::error_code error, - const tcp::resolver::results_type::iterator& endpoint_iter, - const std::shared_ptr<SocketType>& socket); - - void handleWrite(std::error_code error, - 
std::size_t bytes_written, - const std::shared_ptr<io::InputStream>& flow_file_content_stream, - const std::vector<std::byte>& delimiter, - const std::shared_ptr<SocketType>& socket); + [[nodiscard]] bool hasBeenUsed() const override { return last_used_.has_value(); } + [[nodiscard]] asio::awaitable<std::error_code> setupUsableSocket(asio::io_context& io_context); + [[nodiscard]] bool hasUsableSocket() const { return socket_ && socket_->lowest_layer().is_open(); } - void handleDelimiterWrite(std::error_code error, std::size_t bytes_written, const std::shared_ptr<SocketType>& socket); + asio::awaitable<std::error_code> establishNewConnection(const tcp::resolver::results_type& endpoints, asio::io_context& io_context_); + asio::awaitable<std::error_code> send(const std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>& delimiter); - nonstd::expected<std::shared_ptr<SocketType>, std::error_code> establishConnection(const tcp::resolver::results_type& resolved_query); - - [[nodiscard]] bool hasBeenUsed() const override { return last_used_.has_value(); } + SocketType createNewSocket(asio::io_context& io_context_); detail::ConnectionId connection_id_; - std::optional<std::chrono::steady_clock::time_point> last_used_; - asio::io_context io_context_; - std::error_code last_error_; - asio::steady_timer deadline_{io_context_}; - std::chrono::milliseconds timeout_; - std::shared_ptr<SocketType> socket_; + std::optional<SocketType> socket_; + + std::optional<steady_clock::time_point> last_used_; + std::chrono::milliseconds timeout_duration_; std::shared_ptr<core::logging::Logger> logger_; std::optional<size_t> max_size_of_socket_send_buffer_; - std::shared_ptr<controllers::SSLContextService> ssl_context_service_; - - nonstd::expected<tcp::resolver::results_type, std::error_code> resolveHostname(); - nonstd::expected<void, std::error_code> sendDataToSocket(const std::shared_ptr<SocketType>& socket, - const std::shared_ptr<io::InputStream>& 
flow_file_content_stream, - const std::vector<std::byte>& delimiter); + std::optional<asio::ssl::context>& ssl_context_; Review Comment: Why is this a non-const reference-to-optional? That is a bit of a mind-bending type. Could it be a bare pointer to `asio::ssl::context`? ########## extensions/standard-processors/tests/unit/PutTCPTests.cpp: ########## @@ -238,27 +220,22 @@ class PutTCPTestFixture { const std::shared_ptr<PutTCP> put_tcp_ = std::make_shared<PutTCP>("PutTCP"); test::SingleProcessorTestController controller_{put_tcp_}; - std::mt19937 random_engine_{std::random_device{}()}; // NOLINT: "Missing space before { [whitespace/braces] [5]" - // most systems use ports 32768 - 65535 as ephemeral ports, so avoid binding to those - class Server { public: Server() = default; - void startTCPServer(uint16_t port) { - gsl_Expects(!listener_ && !server_thread_.joinable()); - listener_ = std::make_unique<SessionAwareTcpServer>(std::nullopt, port, core::logging::LoggerFactory<utils::net::Server>::getLogger()); - server_thread_ = std::thread([this]() { listener_->run(); }); - } - - void startSSLServer(uint16_t port) { + uint16_t startTCPServer(std::optional<utils::net::SslServerOptions> ssl_server_options) { gsl_Expects(!listener_ && !server_thread_.joinable()); - listener_ = std::make_unique<SessionAwareSslServer>(std::nullopt, - port, - core::logging::LoggerFactory<utils::net::Server>::getLogger(), - createSslDataForServer(), - utils::net::SslServer::ClientAuthOption::REQUIRED); + listener_ = std::make_unique<CancellableTcpServer>(std::nullopt, 0, core::logging::LoggerFactory<utils::net::Server>::getLogger(), std::move(ssl_server_options)); server_thread_ = std::thread([this]() { listener_->run(); }); + uint16_t port = listener_->getPort(); + auto deadline = std::chrono::steady_clock::now() + 200ms; + while (port == 0 && deadline > std::chrono::steady_clock::now()) { + std::this_thread::sleep_for(20ms); + port = listener_->getPort(); + } + REQUIRE(port != 0); + return 
port; Review Comment: I think ```suggestion REQUIRE(utils::verifyEventHappenedInPollTime(200ms, [this] { return listener_->getPort() != 0; })); return listener_->getPort(); ``` would be nicer ########## extensions/standard-processors/tests/unit/ListenSyslogTests.cpp: ########## @@ -480,41 +504,44 @@ TEST_CASE("ListenSyslog max queue and max batch size test", "[ListenSyslog][Netw } TEST_CASE("Test ListenSyslog via TCP with SSL connection", "[ListenSyslog][NetworkListenerProcessor]") { - asio::ip::tcp::endpoint endpoint; - SECTION("sending through IPv4", "[IPv4]") { - endpoint = asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), SYSLOG_PORT); - } - SECTION("sending through IPv6", "[IPv6]") { - if (utils::isIPv6Disabled()) - return; - endpoint = asio::ip::tcp::endpoint(asio::ip::address_v6::loopback(), SYSLOG_PORT); - } const auto listen_syslog = std::make_shared<ListenSyslog>("ListenSyslog"); - SingleProcessorTestController controller{listen_syslog}; + auto ssl_context_service = controller.plan->addController("SSLContextService", "SSLContextService"); const auto executable_dir = minifi::utils::file::FileUtils::get_executable_dir(); REQUIRE(controller.plan->setProperty(ssl_context_service, controllers::SSLContextService::CACertificate.getName(), (executable_dir / "resources" / "ca_A.crt").string())); REQUIRE(controller.plan->setProperty(ssl_context_service, controllers::SSLContextService::ClientCertificate.getName(), (executable_dir / "resources" / "localhost_by_A.pem").string())); REQUIRE(controller.plan->setProperty(ssl_context_service, controllers::SSLContextService::PrivateKey.getName(), (executable_dir / "resources" / "localhost_by_A.pem").string())); + ssl_context_service->enable(); + LogTestController::getInstance().setTrace<ListenSyslog>(); - REQUIRE(listen_syslog->setProperty(ListenSyslog::Port, std::to_string(SYSLOG_PORT))); REQUIRE(listen_syslog->setProperty(ListenSyslog::MaxBatchSize, "2")); 
REQUIRE(listen_syslog->setProperty(ListenSyslog::ParseMessages, "false")); REQUIRE(listen_syslog->setProperty(ListenSyslog::ProtocolProperty, "TCP")); REQUIRE(listen_syslog->setProperty(ListenSyslog::SSLContextService, "SSLContextService")); - ssl_context_service->enable(); - controller.plan->scheduleProcessor(listen_syslog); - REQUIRE(utils::sendMessagesViaSSL({rfc5424_logger_example_1}, endpoint, executable_dir / "resources" / "ca_A.crt")); - REQUIRE(utils::sendMessagesViaSSL({invalid_syslog}, endpoint, executable_dir / "resources" / "ca_A.crt")); + + auto port = schedule_on_random_port(controller, listen_syslog); + + asio::ip::tcp::endpoint endpoint; + SECTION("sending through IPv4", "[IPv4]") { + endpoint = asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), port); + } + SECTION("sending through IPv6", "[IPv6]") { + if (utils::isIPv6Disabled()) + return; + endpoint = asio::ip::tcp::endpoint(asio::ip::address_v6::loopback(), port); + } + + CHECK_THAT(utils::sendMessagesViaSSL({rfc5424_logger_example_1}, endpoint, (executable_dir / "resources" / "ca_A.crt").string()), MatchesSuccess()); + CHECK_THAT(utils::sendMessagesViaSSL({invalid_syslog}, endpoint, (executable_dir / "resources" / "ca_A.crt").string()), MatchesSuccess()); Review Comment: the parameter type is `path`, so the `.string()`s are not needed ########## libminifi/src/utils/net/TcpServer.cpp: ########## @@ -15,53 +15,73 @@ * limitations under the License. 
*/ #include "utils/net/TcpServer.h" +#include "utils/net/AsioCoro.h" namespace org::apache::nifi::minifi::utils::net { -TcpSession::TcpSession(asio::io_context& io_context, utils::ConcurrentQueue<Message>& concurrent_queue, std::optional<size_t> max_queue_size, std::shared_ptr<core::logging::Logger> logger) - : concurrent_queue_(concurrent_queue), - max_queue_size_(max_queue_size), - socket_(io_context), - logger_(std::move(logger)) { +asio::awaitable<void> TcpServer::doReceive() { + asio::ip::tcp::acceptor acceptor(io_context_, asio::ip::tcp::endpoint(asio::ip::tcp::v6(), port_)); + if (port_ == 0) + port_ = acceptor.local_endpoint().port(); + while (true) { + auto [accept_error, socket] = co_await acceptor.async_accept(use_nothrow_awaitable); + if (accept_error) { + logger_->log_error("Error during accepting new connection: %s", accept_error.message()); + break; + } + if (ssl_data_) + co_spawn(io_context_, secureSession(std::move(socket)), asio::detached); + else + co_spawn(io_context_, insecureSession(std::move(socket)), asio::detached); + } } -asio::ip::tcp::socket& TcpSession::getSocket() { - return socket_; -} +asio::awaitable<void> TcpServer::readLoop(auto& socket) { + std::string read_message; + while (true) { + auto [read_error, bytes_read] = co_await asio::async_read_until(socket, asio::dynamic_buffer(read_message), '\n', use_nothrow_awaitable); // NOLINT + if (read_error || bytes_read == 0) + co_return; -void TcpSession::start() { - asio::async_read_until(socket_, - buffer_, - '\n', - [self = shared_from_this()](const auto& error_code, size_t) -> void { - self->handleReadUntilNewLine(error_code); - }); + if (!max_queue_size_ || max_queue_size_ > concurrent_queue_.size()) + concurrent_queue_.enqueue(Message(read_message.substr(0, bytes_read - 1), IpProtocol::TCP, socket.lowest_layer().remote_endpoint().address(), socket.lowest_layer().local_endpoint().port())); Review Comment: Is it possible that `read_message[bytes_read - 1]` is not a `\n` (e.g.,
if we got to the end of the stream)? In that case, throwing away this character would not be the right thing to do. ########## extensions/standard-processors/processors/PutTCP.cpp: ########## @@ -114,6 +114,21 @@ void PutTCP::initialize() { void PutTCP::notifyStop() {} +namespace { +asio::ssl::context getSslContext(const std::shared_ptr<controllers::SSLContextService>& ssl_context_service) { + gsl_Expects(ssl_context_service); Review Comment: we could change the parameter type to `const SSLContextService&`, then this assertion would not be needed (since we already check for nullness at the calling site) ########## extensions/standard-processors/tests/unit/ListenTcpTests.cpp: ########## @@ -194,31 +194,115 @@ TEST_CASE("Test ListenTCP with SSL connection", "[ListenTCP][NetworkListenerProc expected_successful_messages = {"test_message_1", "another_message"}; for (const auto& message : expected_successful_messages) { - REQUIRE(utils::sendMessagesViaSSL({message}, endpoint, executable_dir / "resources" / "ca_A.crt", ssl_data)); + CHECK_THAT(utils::sendMessagesViaSSL({message}, endpoint, executable_dir / "resources" / "ca_A.crt", ssl_data), MatchesSuccess()); } } SECTION("Required certificate not provided") { + ssl_context_service->enable(); + REQUIRE(controller.plan->setProperty(listen_tcp, ListenTCP::ClientAuth.getName(), "REQUIRED")); + port = utils::scheduleProcessorOnRandomPort(controller.plan, listen_tcp); SECTION("sending through IPv4", "[IPv4]") { - endpoint = asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), PORT); + endpoint = asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), port); } SECTION("sending through IPv6", "[IPv6]") { if (utils::isIPv6Disabled()) return; - endpoint = asio::ip::tcp::endpoint(asio::ip::address_v6::loopback(), PORT); + endpoint = asio::ip::tcp::endpoint(asio::ip::address_v6::loopback(), port); } - REQUIRE(controller.plan->setProperty(listen_tcp, ListenTCP::ClientAuth.getName(), "REQUIRED")); - 
ssl_context_service->enable(); - controller.plan->scheduleProcessor(listen_tcp); - REQUIRE_FALSE(utils::sendMessagesViaSSL({"test_message_1"}, endpoint, executable_dir / "resources" / "ca_A.crt")); + auto send_error = utils::sendMessagesViaSSL({"test_message_1"}, endpoint, executable_dir / "resources" / "ca_A.crt"); + CHECK(send_error); Review Comment: could this be changed to `CHECK_THAT(send_error, MatchesError())` (possibly with a specific error code)? ########## libminifi/test/Catch.h: ########## @@ -40,4 +40,46 @@ struct StringMaker<std::nullopt_t> { return "std::nullopt"; } }; + +template <> +struct StringMaker<std::error_code> { + static std::string convert(const std::error_code& error_code) { + return fmt::format("std::error_code(value:{}, message:{})", error_code.value(), error_code.message()); + } +}; } // namespace Catch + +namespace org::apache::nifi::minifi::test { +struct MatchesSuccess : Catch::MatcherBase<std::error_code> { + MatchesSuccess() = default; + + bool match(const std::error_code& err) const override { + return err.value() == 0; + } + + std::string describe() const override { + return fmt::format("== {}", Catch::StringMaker<std::error_code>::convert(std::error_code{})); + } +}; + +struct MatchesError : Catch::MatcherBase<std::error_code> { + explicit MatchesError(std::optional<std::error_code> expected_error = std::nullopt) + : Catch::MatcherBase<std::error_code>(), + expected_error_(expected_error) { + } + + bool match(const std::error_code& err) const override { + if (expected_error_) + return err.value() == expected_error_->value(); Review Comment: I would do ```suggestion return err == *expected_error_; ``` here, so the category gets compared, too ########## libminifi/test/Catch.h: ########## @@ -40,4 +40,46 @@ struct StringMaker<std::nullopt_t> { return "std::nullopt"; } }; + +template <> +struct StringMaker<std::error_code> { + static std::string convert(const std::error_code& error_code) { + return 
fmt::format("std::error_code(value:{}, message:{})", error_code.value(), error_code.message()); Review Comment: it could be useful to include the category as well, as the `operator<<` overload for `std::error_code` does: https://en.cppreference.com/w/cpp/error/error_code/operator_ltlt ########## extensions/standard-processors/tests/unit/PutTCPTests.cpp: ########## @@ -141,16 +120,16 @@ class PutTCPTestFixture { } size_t getNumberOfActiveSessions(std::optional<uint16_t> port = std::nullopt) { - if (auto session_aware_listener = dynamic_cast<ISessionAwareServer*>(getListener(port))) { - return session_aware_listener->getNumberOfSessions() - 1; // There is always one inactive session waiting for a new connection + if (auto session_aware_listener = dynamic_cast<CancellableTcpServer*>(getListener(port))) { + return session_aware_listener->getNumberOfSessions(); } return -1; } void closeActiveConnections() { for (auto& [port, server] : listeners_) { - if (auto session_aware_listener = dynamic_cast<ISessionAwareServer*>(server.listener_.get())) { - session_aware_listener->closeSessions(); + if (auto session_aware_listener = dynamic_cast<CancellableTcpServer*>(getListener(port))) { + session_aware_listener->cancelEverything(); Review Comment: Nitpicking, but the "session_aware_" prefix in these variable names doesn't make much sense now; I would rename them to `server` or `listener`, or something like that. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@nifi.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org