This is an automated email from the ASF dual-hosted git repository. szaszm pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/nifi-minifi-cpp.git
commit 69558ffe92277d2784e2ea10dad9b64bda1e3c28 Author: Gabor Gyimesi <gamezb...@gmail.com> AuthorDate: Thu Aug 17 11:26:03 2023 +0200 MINIFICPP-2137 Rewrite MiNiFi Controller to use asio Closes #1595 Signed-off-by: Marton Szasz <sza...@apache.org> --- controller/Controller.cpp | 283 ++++++++++++++------- controller/Controller.h | 71 ++---- controller/MiNiFiController.cpp | 80 +++--- controller/tests/ControllerTests.cpp | 91 ++----- .../standard-processors/processors/GetTCP.cpp | 2 +- .../standard-processors/processors/PutTCP.cpp | 2 +- .../standard-processors/tests/unit/GetTCPTests.cpp | 4 +- libminifi/include/c2/ControllerSocketProtocol.h | 39 ++- libminifi/include/io/AsioStream.h | 81 ++++++ libminifi/include/utils/net/AsioSocketUtils.h | 4 +- libminifi/src/c2/ControllerSocketProtocol.cpp | 158 ++++++++---- libminifi/src/io/InputStream.cpp | 32 +-- libminifi/src/utils/net/AsioSocketUtils.cpp | 6 +- libminifi/src/utils/net/TcpServer.cpp | 16 +- libminifi/test/unit/NetUtilsTest.cpp | 16 +- 15 files changed, 518 insertions(+), 367 deletions(-) diff --git a/controller/Controller.cpp b/controller/Controller.cpp index a56802595..63db43c9d 100644 --- a/controller/Controller.cpp +++ b/controller/Controller.cpp @@ -20,195 +20,286 @@ #include "io/BufferStream.h" #include "c2/C2Payload.h" +#include "io/AsioStream.h" +#include "asio/ssl/context.hpp" +#include "asio/ssl/stream.hpp" +#include "asio/connect.hpp" +#include "core/logging/Logger.h" +#include "utils/net/AsioSocketUtils.h" namespace org::apache::nifi::minifi::controller { -bool sendSingleCommand(std::unique_ptr<io::Socket> socket, uint8_t op, const std::string& value) { - if (socket->initialize() < 0) { +namespace { + +class ClientConnection { + public: + explicit ClientConnection(const ControllerSocketData& socket_data) { + if (socket_data.ssl_context_service) { + connectTcpSocketOverSsl(socket_data); + } else { + connectTcpSocket(socket_data); + } + } + + [[nodiscard]] io::BaseStream* getStream() const { + return 
stream_.get(); + } + + private: + void connectTcpSocketOverSsl(const ControllerSocketData& socket_data) { + auto ssl_context = utils::net::getSslContext(*socket_data.ssl_context_service); + asio::ssl::stream<asio::ip::tcp::socket> socket(io_context_, ssl_context); + + asio::ip::tcp::resolver resolver(io_context_); + asio::error_code err; + asio::ip::tcp::resolver::results_type endpoints = resolver.resolve(socket_data.host, std::to_string(socket_data.port), err); + if (err) { + logger_->log_error("Resolving host '%s' on port '%s' failed with the following message: '%s'", socket_data.host, std::to_string(socket_data.port), err.message()); + return; + } + + asio::connect(socket.lowest_layer(), endpoints, err); + if (err) { + logger_->log_error("Connecting to host '%s' on port '%s' failed with the following message: '%s'", socket_data.host, std::to_string(socket_data.port), err.message()); + return; + } + socket.handshake(asio::ssl::stream_base::client, err); + if (err) { + logger_->log_error("SSL handshake failed while connecting to host '%s' on port '%s' with the following message: '%s'", socket_data.host, std::to_string(socket_data.port), err.message()); + return; + } + stream_ = std::make_unique<io::AsioStream<asio::ssl::stream<asio::ip::tcp::socket>>>(std::move(socket)); + } + + void connectTcpSocket(const ControllerSocketData& socket_data) { + asio::ip::tcp::socket socket(io_context_); + + asio::ip::tcp::resolver resolver(io_context_); + asio::error_code err; + asio::ip::tcp::resolver::results_type endpoints = resolver.resolve(socket_data.host, std::to_string(socket_data.port), err); + if (err) { + logger_->log_error("Resolving host '%s' on port '%s' failed with the following message: '%s'", socket_data.host, std::to_string(socket_data.port), err.message()); + return; + } + + asio::connect(socket, endpoints, err); + if (err) { + logger_->log_error("Connecting to host '%s' on port '%s' failed with the following message: '%s'", socket_data.host,
std::to_string(socket_data.port), err.message()); + return; + } + stream_ = std::make_unique<io::AsioStream<asio::ip::tcp::socket>>(std::move(socket)); + } + + asio::io_context io_context_; + std::unique_ptr<io::BaseStream> stream_; + std::shared_ptr<core::logging::Logger> logger_{core::logging::LoggerFactory<ClientConnection>::getLogger()}; +}; + +} // namespace + + +bool sendSingleCommand(const ControllerSocketData& socket_data, uint8_t op, const std::string& value) { + ClientConnection connection(socket_data); + auto connection_stream = connection.getStream(); + if (!connection_stream) { return false; } - io::BufferStream stream; - stream.write(&op, 1); - stream.write(value); - return socket->write(stream.getBuffer()) == stream.size(); + io::BufferStream buffer; + buffer.write(&op, 1); + buffer.write(value); + return connection_stream->write(buffer.getBuffer()) == buffer.size(); } -bool stopComponent(std::unique_ptr<io::Socket> socket, const std::string& component) { - return sendSingleCommand(std::move(socket), static_cast<uint8_t>(c2::Operation::stop), component); +bool stopComponent(const ControllerSocketData& socket_data, const std::string& component) { + return sendSingleCommand(socket_data, static_cast<uint8_t>(c2::Operation::stop), component); } -bool startComponent(std::unique_ptr<io::Socket> socket, const std::string& component) { - return sendSingleCommand(std::move(socket), static_cast<uint8_t>(c2::Operation::start), component); +bool startComponent(const ControllerSocketData& socket_data, const std::string& component) { + return sendSingleCommand(socket_data, static_cast<uint8_t>(c2::Operation::start), component); } -bool clearConnection(std::unique_ptr<io::Socket> socket, const std::string& connection) { - return sendSingleCommand(std::move(socket), static_cast<uint8_t>(c2::Operation::clear), connection); +bool clearConnection(const ControllerSocketData& socket_data, const std::string& connection) { + return sendSingleCommand(socket_data, 
static_cast<uint8_t>(c2::Operation::clear), connection); } -int updateFlow(std::unique_ptr<io::Socket> socket, std::ostream &out, const std::string& file) { - if (socket->initialize() < 0) { - return -1; +bool updateFlow(const ControllerSocketData& socket_data, std::ostream &out, const std::string& file) { + ClientConnection connection(socket_data); + auto connection_stream = connection.getStream(); + if (!connection_stream) { + return false; } auto op = static_cast<uint8_t>(c2::Operation::update); - io::BufferStream stream; - stream.write(&op, 1); - stream.write("flow"); - stream.write(file); - if (io::isError(socket->write(stream.getBuffer()))) { - return -1; + io::BufferStream buffer; + buffer.write(&op, 1); + buffer.write("flow"); + buffer.write(file); + if (io::isError(connection_stream->write(buffer.getBuffer()))) { + return false; } // read the response uint8_t resp = 0; - socket->read(resp); + connection_stream->read(resp); if (resp == static_cast<uint8_t>(c2::Operation::describe)) { uint16_t connections = 0; - socket->read(connections); + connection_stream->read(connections); out << connections << " are full" << std::endl; for (int i = 0; i < connections; i++) { std::string fullcomponent; - socket->read(fullcomponent); + connection_stream->read(fullcomponent); out << fullcomponent << " is full" << std::endl; } } - return 0; + return true; } -int getFullConnections(std::unique_ptr<io::Socket> socket, std::ostream &out) { - if (socket->initialize() < 0) { - return -1; +bool getFullConnections(const ControllerSocketData& socket_data, std::ostream &out) { + ClientConnection connection(socket_data); + auto connection_stream = connection.getStream(); + if (!connection_stream) { + return false; } auto op = static_cast<uint8_t>(c2::Operation::describe); - io::BufferStream stream; - stream.write(&op, 1); - stream.write("getfull"); - if (io::isError(socket->write(stream.getBuffer()))) { - return -1; + io::BufferStream buffer; + buffer.write(&op, 1); + 
buffer.write("getfull"); + if (io::isError(connection_stream->write(buffer.getBuffer()))) { + return false; } // read the response uint8_t resp = 0; - socket->read(resp); + connection_stream->read(resp); if (resp == static_cast<uint8_t>(c2::Operation::describe)) { uint16_t connections = 0; - socket->read(connections); + connection_stream->read(connections); out << connections << " are full" << std::endl; for (int i = 0; i < connections; i++) { std::string fullcomponent; - socket->read(fullcomponent); + connection_stream->read(fullcomponent); out << fullcomponent << " is full" << std::endl; } } - return 0; + return true; } -int getConnectionSize(std::unique_ptr<io::Socket> socket, std::ostream &out, const std::string& connection) { - if (socket->initialize() < 0) { - return -1; +bool getConnectionSize(const ControllerSocketData& socket_data, std::ostream &out, const std::string& connection) { + ClientConnection client_connection(socket_data); + auto connection_stream = client_connection.getStream(); + if (!connection_stream) { + return false; } auto op = static_cast<uint8_t>(c2::Operation::describe); - io::BufferStream stream; - stream.write(&op, 1); - stream.write("queue"); - stream.write(connection); - if (io::isError(socket->write(stream.getBuffer()))) { - return -1; + io::BufferStream buffer; + buffer.write(&op, 1); + buffer.write("queue"); + buffer.write(connection); + if (io::isError(connection_stream->write(buffer.getBuffer()))) { + return false; } // read the response uint8_t resp = 0; - socket->read(resp); + connection_stream->read(resp); if (resp == static_cast<uint8_t>(c2::Operation::describe)) { std::string size; - socket->read(size); + connection_stream->read(size); out << "Size/Max of " << connection << " " << size << std::endl; } - return 0; + return true; } -int listComponents(std::unique_ptr<io::Socket> socket, std::ostream &out, bool show_header) { - if (socket->initialize() < 0) { - return -1; +bool listComponents(const ControllerSocketData& 
socket_data, std::ostream &out, bool show_header) { + ClientConnection connection(socket_data); + auto connection_stream = connection.getStream(); + if (!connection_stream) { + return false; } - io::BufferStream stream; + io::BufferStream buffer; auto op = static_cast<uint8_t>(c2::Operation::describe); - stream.write(&op, 1); - stream.write("components"); - if (io::isError(socket->write(stream.getBuffer()))) { - return -1; + buffer.write(&op, 1); + buffer.write("components"); + if (io::isError(connection_stream->write(buffer.getBuffer()))) { + return false; } uint16_t responses = 0; - socket->read(op); - socket->read(responses); + connection_stream->read(op); + connection_stream->read(responses); if (show_header) out << "Components:" << std::endl; for (int i = 0; i < responses; i++) { std::string name; - socket->read(name, false); + connection_stream->read(name, false); std::string status; - socket->read(status, false); + connection_stream->read(status, false); out << name << ", running: " << status << std::endl; } - return 0; + return true; } -int listConnections(std::unique_ptr<io::Socket> socket, std::ostream &out, bool show_header) { - if (socket->initialize() < 0) { - return -1; +bool listConnections(const ControllerSocketData& socket_data, std::ostream &out, bool show_header) { + ClientConnection connection(socket_data); + auto connection_stream = connection.getStream(); + if (!connection_stream) { + return false; } - io::BufferStream stream; + io::BufferStream buffer; auto op = static_cast<uint8_t>(c2::Operation::describe); - stream.write(&op, 1); - stream.write("connections"); - if (io::isError(socket->write(stream.getBuffer()))) { - return -1; + buffer.write(&op, 1); + buffer.write("connections"); + if (io::isError(connection_stream->write(buffer.getBuffer()))) { + return false; } uint16_t responses = 0; - socket->read(op); - socket->read(responses); + connection_stream->read(op); + connection_stream->read(responses); if (show_header) out << "Connection 
Names:" << std::endl; for (int i = 0; i < responses; i++) { std::string name; - socket->read(name, false); + connection_stream->read(name, false); out << name << std::endl; } - return 0; + return true; } -int printManifest(std::unique_ptr<io::Socket> socket, std::ostream &out) { - if (socket->initialize() < 0) { - return -1; +bool printManifest(const ControllerSocketData& socket_data, std::ostream &out) { + ClientConnection connection(socket_data); + auto connection_stream = connection.getStream(); + if (!connection_stream) { + return false; } - io::BufferStream stream; + io::BufferStream buffer; auto op = static_cast<uint8_t>(c2::Operation::describe); - stream.write(&op, 1); - stream.write("manifest"); - if (io::isError(socket->write(stream.getBuffer()))) { - return -1; + buffer.write(&op, 1); + buffer.write("manifest"); + if (io::isError(connection_stream->write(buffer.getBuffer()))) { + return false; } - socket->read(op); + connection_stream->read(op); std::string manifest; - socket->read(manifest, true); + connection_stream->read(manifest, true); out << manifest << std::endl; - return 0; + return true; } -int getJstacks(std::unique_ptr<io::Socket> socket, std::ostream &out) { - if (socket->initialize() < 0) { - return -1; +bool getJstacks(const ControllerSocketData& socket_data, std::ostream &out) { + ClientConnection connection(socket_data); + auto connection_stream = connection.getStream(); + if (!connection_stream) { + return false; } - io::BufferStream stream; + io::BufferStream buffer; auto op = static_cast<uint8_t>(c2::Operation::describe); - stream.write(&op, 1); - stream.write("jstack"); - if (io::isError(socket->write(stream.getBuffer()))) { - return -1; + buffer.write(&op, 1); + buffer.write("jstack"); + if (io::isError(connection_stream->write(buffer.getBuffer()))) { + return false; } - socket->read(op); + connection_stream->read(op); std::string manifest; - socket->read(manifest, true); + connection_stream->read(manifest, true); out << manifest << 
std::endl; - return 0; + return true; } } // namespace org::apache::nifi::minifi::controller diff --git a/controller/Controller.h b/controller/Controller.h index 0a6c1c484..3dac89737 100644 --- a/controller/Controller.h +++ b/controller/Controller.h @@ -20,61 +20,26 @@ #include <memory> #include <string> -#include "io/ClientSocket.h" +#include "controllers/SSLContextService.h" namespace org::apache::nifi::minifi::controller { -/** - * Sends a single argument comment - * @param socket socket unique ptr. - * @param op operation to perform - * @param value value to send - */ -bool sendSingleCommand(std::unique_ptr<io::Socket> socket, uint8_t op, const std::string& value); - -/** - * Stops a stopped component - * @param socket socket unique ptr. - * @param op operation to perform - */ -bool stopComponent(std::unique_ptr<io::Socket> socket, const std::string& component); - -/** - * Starts a previously stopped component. - * @param socket socket unique ptr. - * @param op operation to perform - */ -bool startComponent(std::unique_ptr<io::Socket> socket, const std::string& component); - -/** - * Clears a connection queue. - * @param socket socket unique ptr. - * @param op operation to perform - */ -bool clearConnection(std::unique_ptr<io::Socket> socket, const std::string& connection); - -/** - * Updates the flow to the provided file - */ -int updateFlow(std::unique_ptr<io::Socket> socket, std::ostream &out, const std::string& file); - -/** - * Lists connections which are full - * @param socket socket ptr - */ -int getFullConnections(std::unique_ptr<io::Socket> socket, std::ostream &out); - -/** - * Prints the connection size for the provided connection. - * @param socket socket ptr - * @param connection connection whose size will be returned. 
- */ -int getConnectionSize(std::unique_ptr<io::Socket> socket, std::ostream &out, const std::string& connection); - -int listComponents(std::unique_ptr<io::Socket> socket, std::ostream &out, bool show_header = true); -int listConnections(std::unique_ptr<io::Socket> socket, std::ostream &out, bool show_header = true); -int printManifest(std::unique_ptr<io::Socket> socket, std::ostream &out); - -int getJstacks(std::unique_ptr<io::Socket> socket, std::ostream &out); +struct ControllerSocketData { + std::string host = "localhost"; + int port = -1; + std::shared_ptr<minifi::controllers::SSLContextService> ssl_context_service; +}; + +bool sendSingleCommand(const ControllerSocketData& socket_data, uint8_t op, const std::string& value); +bool stopComponent(const ControllerSocketData& socket_data, const std::string& component); +bool startComponent(const ControllerSocketData& socket_data, const std::string& component); +bool clearConnection(const ControllerSocketData& socket_data, const std::string& connection); +bool updateFlow(const ControllerSocketData& socket_data, std::ostream &out, const std::string& file); +bool getFullConnections(const ControllerSocketData& socket_data, std::ostream &out); +bool getConnectionSize(const ControllerSocketData& socket_data, std::ostream &out, const std::string& connection); +bool listComponents(const ControllerSocketData& socket_data, std::ostream &out, bool show_header = true); +bool listConnections(const ControllerSocketData& socket_data, std::ostream &out, bool show_header = true); +bool printManifest(const ControllerSocketData& socket_data, std::ostream &out); +bool getJstacks(const ControllerSocketData& socket_data, std::ostream &out); } // namespace org::apache::nifi::minifi::controller diff --git a/controller/MiNiFiController.cpp b/controller/MiNiFiController.cpp index e66afbdf9..b17d60c1c 100644 --- a/controller/MiNiFiController.cpp +++ b/controller/MiNiFiController.cpp @@ -104,20 +104,15 @@ int main(int argc, char **argv) { 
log_properties->loadConfigureFile(DEFAULT_LOG_PROPERTIES_FILE); minifi::core::logging::LoggerConfiguration::getConfiguration().initialize(log_properties); - std::shared_ptr<minifi::controllers::SSLContextService> secure_context; + minifi::controller::ControllerSocketData socket_data; try { - secure_context = getSSLContextService(configuration); + socket_data.ssl_context_service = getSSLContextService(configuration); } catch(const minifi::Exception& ex) { logger->log_error(ex.what()); exit(1); } auto stream_factory_ = minifi::io::StreamFactory::getInstance(configuration); - std::string host = "localhost"; - std::string port_str; - std::string ca_cert; - int port = -1; - cxxopts::Options options("MiNiFiController", "MiNiFi local agent controller"); options.positional_help("[optional args]").show_positional_help(); @@ -147,20 +142,19 @@ int main(int argc, char **argv) { } if (result.count("host")) { - host = result["host"].as<std::string>(); + socket_data.host = result["host"].as<std::string>(); } else { - configuration->get(minifi::Configure::controller_socket_host, host); + configuration->get(minifi::Configure::controller_socket_host, socket_data.host); } + std::string port_str; if (result.count("port")) { - port = result["port"].as<int>(); - } else { - if (port == -1 && configuration->get(minifi::Configure::controller_socket_port, port_str)) { - port = std::stoi(port_str); - } + socket_data.port = result["port"].as<int>(); + } else if (socket_data.port == -1 && configuration->get(minifi::Configure::controller_socket_port, port_str)) { + socket_data.port = std::stoi(port_str); } - if ((minifi::IsNullOrEmpty(host) && port == -1)) { + if ((minifi::IsNullOrEmpty(socket_data.host) && socket_data.port == -1)) { std::cout << "MiNiFi Controller is disabled" << std::endl; exit(0); } @@ -171,38 +165,30 @@ int main(int argc, char **argv) { if (result.count("stop") > 0) { auto& components = result["stop"].as<std::vector<std::string>>(); for (const auto& component : components) 
{ - auto socket = secure_context != nullptr ? stream_factory_->createSecureSocket(host, port, secure_context) : stream_factory_->createSocket(host, port); - if (minifi::controller::stopComponent(std::move(socket), component)) + if (minifi::controller::stopComponent(socket_data, component)) std::cout << component << " requested to stop" << std::endl; else - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } } if (result.count("start") > 0) { auto& components = result["start"].as<std::vector<std::string>>(); for (const auto& component : components) { - auto socket = secure_context != nullptr ? stream_factory_->createSecureSocket(host, port, secure_context) : stream_factory_->createSocket(host, port); - if (minifi::controller::startComponent(std::move(socket), component)) + if (minifi::controller::startComponent(socket_data, component)) std::cout << component << " requested to start" << std::endl; else - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } } if (result.count("c") > 0) { auto& components = result["c"].as<std::vector<std::string>>(); for (const auto& connection : components) { - auto socket = secure_context != nullptr ? stream_factory_->createSecureSocket(host, port, secure_context) - : stream_factory_->createSocket(host, port); - if (minifi::controller::clearConnection(std::move(socket), connection)) { - std::cout << "Sent clear command to " << connection << ". Size before clear operation sent: " << std::endl; - socket = secure_context != nullptr ? 
stream_factory_->createSecureSocket(host, port, secure_context) - : stream_factory_->createSocket(host, port); - if (minifi::controller::getConnectionSize(std::move(socket), std::cout, connection) < 0) - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + if (minifi::controller::clearConnection(socket_data, connection)) { + std::cout << "Sent clear command to " << connection << "." << std::endl; } else { - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } } } @@ -210,45 +196,39 @@ int main(int argc, char **argv) { if (result.count("getsize") > 0) { auto& components = result["getsize"].as<std::vector<std::string>>(); for (const auto& component : components) { - auto socket = secure_context != nullptr ? stream_factory_->createSecureSocket(host, port, secure_context) : stream_factory_->createSocket(host, port); - if (minifi::controller::getConnectionSize(std::move(socket), std::cout, component) < 0) - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + if (!minifi::controller::getConnectionSize(socket_data, std::cout, component)) + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } } if (result.count("l") > 0) { auto& option = result["l"].as<std::string>(); - auto socket = secure_context != nullptr ? 
stream_factory_->createSecureSocket(host, port, secure_context) : stream_factory_->createSocket(host, port); if (option == "components") { - if (minifi::controller::listComponents(std::move(socket), std::cout, show_headers) < 0) - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + if (!minifi::controller::listComponents(socket_data, std::cout, show_headers)) + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } else if (option == "connections") { - if (minifi::controller::listConnections(std::move(socket), std::cout, show_headers) < 0) - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + if (!minifi::controller::listConnections(socket_data, std::cout, show_headers)) + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } } if (result.count("getfull") > 0) { - auto socket = secure_context != nullptr ? stream_factory_->createSecureSocket(host, port, secure_context) : stream_factory_->createSocket(host, port); - if (minifi::controller::getFullConnections(std::move(socket), std::cout) < 0) - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + if (!minifi::controller::getFullConnections(socket_data, std::cout)) + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } if (result.count("updateflow") > 0) { auto& flow_file = result["updateflow"].as<std::string>(); - auto socket = secure_context != nullptr ? 
stream_factory_->createSecureSocket(host, port, secure_context) : stream_factory_->createSocket(host, port); - if (minifi::controller::updateFlow(std::move(socket), std::cout, flow_file) < 0) - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + if (!minifi::controller::updateFlow(socket_data, std::cout, flow_file)) + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } if (result.count("manifest") > 0) { - auto socket = secure_context != nullptr ? stream_factory_->createSecureSocket(host, port, secure_context) : stream_factory_->createSocket(host, port); - if (minifi::controller::printManifest(std::move(socket), std::cout) < 0) - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + if (!minifi::controller::printManifest(socket_data, std::cout)) + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } if (result.count("jstack") > 0) { - auto socket = secure_context != nullptr ? 
stream_factory_->createSecureSocket(host, port, secure_context) : stream_factory_->createSocket(host, port); - if (minifi::controller::getJstacks(std::move(socket), std::cout) < 0) - std::cout << "Could not connect to remote host " << host << ":" << port << std::endl; + if (!minifi::controller::getJstacks(socket_data, std::cout)) + std::cout << "Could not connect to remote host " << socket_data.host << ":" << socket_data.port << std::endl; } } catch (const std::exception &exc) { // catch anything thrown within try block that derives from std::exception diff --git a/controller/tests/ControllerTests.cpp b/controller/tests/ControllerTests.cpp index 82f90e1bc..d0e55d1cb 100644 --- a/controller/tests/ControllerTests.cpp +++ b/controller/tests/ControllerTests.cpp @@ -24,7 +24,6 @@ #include "TestBase.h" #include "Catch.h" -#include "io/ClientSocket.h" #include "core/Processor.h" #include "Controller.h" #include "c2/ControllerSocketProtocol.h" @@ -245,8 +244,7 @@ class ControllerTestFixture { ControllerTestFixture() : configuration_(std::make_shared<minifi::Configure>()), controller_(std::make_shared<TestStateController>()), - update_sink_(std::make_unique<TestUpdateSink>(controller_)), - stream_factory_(minifi::io::StreamFactory::getInstance(configuration_)) { + update_sink_(std::make_unique<TestUpdateSink>(controller_)) { configuration_->set(minifi::Configure::controller_socket_host, "localhost"); configuration_->set(minifi::Configure::controller_socket_port, "9997"); configuration_->set(minifi::Configure::nifi_security_client_certificate, (minifi::utils::file::FileUtils::get_executable_dir() / "resources" / "minifi-cpp-flow.crt").string()); @@ -257,6 +255,8 @@ class ControllerTestFixture { ssl_context_service_ = std::make_shared<controllers::SSLContextService>("SSLContextService", configuration_); ssl_context_service_->onEnable(); controller_service_provider_ = std::make_unique<TestControllerServiceProvider>(ssl_context_service_); + controller_socket_data_.host = 
"localhost"; + controller_socket_data_.port = 9997; } void initalizeControllerSocket(const std::shared_ptr<c2::ControllerSocketReporter>& reporter = nullptr) { @@ -270,27 +270,24 @@ class ControllerTestFixture { controller_socket_protocol_->initialize(); } - std::unique_ptr<minifi::io::Socket> createSocket() { + void setConnectionType(ConnectionType connection_type) { + connection_type_ = connection_type; if (connection_type_ == ConnectionType::UNSECURE) { - return stream_factory_->createSocket("localhost", 9997); + controller_socket_data_.ssl_context_service = nullptr; } else { - return stream_factory_->createSecureSocket("localhost", 9997, ssl_context_service_); + controller_socket_data_.ssl_context_service = ssl_context_service_; } } - void setConnectionType(ConnectionType connection_type) { - connection_type_ = connection_type; - } - protected: ConnectionType connection_type_ = ConnectionType::UNSECURE; std::shared_ptr<minifi::Configure> configuration_; std::shared_ptr<TestStateController> controller_; std::unique_ptr<TestUpdateSink> update_sink_; - std::shared_ptr<minifi::io::StreamFactory> stream_factory_; std::unique_ptr<minifi::c2::ControllerSocketProtocol> controller_socket_protocol_; std::shared_ptr<controllers::SSLContextService> ssl_context_service_; std::unique_ptr<TestControllerServiceProvider> controller_service_provider_; + minifi::controller::ControllerSocketData controller_socket_data_; }; TEST_CASE_METHOD(ControllerTestFixture, "Test listComponents", "[controllerTests]") { @@ -308,28 +305,19 @@ TEST_CASE_METHOD(ControllerTestFixture, "Test listComponents", "[controllerTests initalizeControllerSocket(); - { - auto socket = createSocket(); - minifi::controller::startComponent(std::move(socket), "TestStateController"); - } + minifi::controller::startComponent(controller_socket_data_, "TestStateController"); using org::apache::nifi::minifi::utils::verifyEventHappenedInPollTime; REQUIRE(verifyEventHappenedInPollTime(5s, [&] { return 
controller_->isRunning(); }, 20ms)); - { - auto socket = createSocket(); - minifi::controller::stopComponent(std::move(socket), "TestStateController"); - } + minifi::controller::stopComponent(controller_socket_data_, "TestStateController"); REQUIRE(verifyEventHappenedInPollTime(5s, [&] { return !controller_->isRunning(); }, 20ms)); - { - auto socket = createSocket(); - std::stringstream ss; - minifi::controller::listComponents(std::move(socket), ss); + std::stringstream ss; + minifi::controller::listComponents(controller_socket_data_, ss); - REQUIRE(ss.str() == "Components:\nTestStateController, running: false\n"); - } + REQUIRE(ss.str() == "Components:\nTestStateController, running: false\n"); } TEST_CASE_METHOD(ControllerTestFixture, "TestClear", "[controllerTests]") { @@ -347,17 +335,13 @@ TEST_CASE_METHOD(ControllerTestFixture, "TestClear", "[controllerTests]") { initalizeControllerSocket(); - { - auto socket = createSocket(); - minifi::controller::startComponent(std::move(socket), "TestStateController"); - } + minifi::controller::startComponent(controller_socket_data_, "TestStateController"); using org::apache::nifi::minifi::utils::verifyEventHappenedInPollTime; REQUIRE(verifyEventHappenedInPollTime(5s, [&] { return controller_->isRunning(); }, 20ms)); for (auto i = 0; i < 3; ++i) { - auto socket = createSocket(); - minifi::controller::clearConnection(std::move(socket), "connection"); + minifi::controller::clearConnection(controller_socket_data_, "connection"); } REQUIRE(verifyEventHappenedInPollTime(5s, [&] { return 3 == update_sink_->clear_calls; }, 20ms)); @@ -377,21 +361,13 @@ TEST_CASE_METHOD(ControllerTestFixture, "TestUpdate", "[controllerTests]") { } initalizeControllerSocket(); - - { - auto socket = createSocket(); - minifi::controller::startComponent(std::move(socket), "TestStateController"); - } + minifi::controller::startComponent(controller_socket_data_, "TestStateController"); using org::apache::nifi::minifi::utils::verifyEventHappenedInPollTime; 
REQUIRE(verifyEventHappenedInPollTime(5s, [&] { return controller_->isRunning(); }, 20ms)); std::stringstream ss; - - { - auto socket = createSocket(); - minifi::controller::updateFlow(std::move(socket), ss, "connection"); - } + minifi::controller::updateFlow(controller_socket_data_, ss, "connection"); REQUIRE(verifyEventHappenedInPollTime(5s, [&] { return 1 == update_sink_->update_calls; }, 20ms)); REQUIRE(0 == update_sink_->clear_calls); @@ -413,30 +389,26 @@ TEST_CASE_METHOD(ControllerTestFixture, "Test connection getters on empty flow", initalizeControllerSocket(); { - auto socket = createSocket(); std::stringstream connection_stream; - minifi::controller::getConnectionSize(std::move(socket), connection_stream, "con1"); + minifi::controller::getConnectionSize(controller_socket_data_, connection_stream, "con1"); CHECK(connection_stream.str() == "Size/Max of con1 not found\n"); } { std::stringstream connection_stream; - auto socket = createSocket(); - minifi::controller::getFullConnections(std::move(socket), connection_stream); + minifi::controller::getFullConnections(controller_socket_data_, connection_stream); CHECK(connection_stream.str() == "0 are full\n"); } { std::stringstream connection_stream; - auto socket = createSocket(); - minifi::controller::listConnections(std::move(socket), connection_stream); + minifi::controller::listConnections(controller_socket_data_, connection_stream); CHECK(connection_stream.str() == "Connection Names:\n"); } { std::stringstream connection_stream; - auto socket = createSocket(); - minifi::controller::listConnections(std::move(socket), connection_stream, false); + minifi::controller::listConnections(controller_socket_data_, connection_stream, false); CHECK(connection_stream.str().empty()); } } @@ -459,29 +431,25 @@ TEST_CASE_METHOD(ControllerTestFixture, "Test connection getters", "[controllerT { std::stringstream connection_stream; - auto socket = createSocket(); - minifi::controller::getConnectionSize(std::move(socket), 
connection_stream, "conn"); + minifi::controller::getConnectionSize(controller_socket_data_, connection_stream, "conn"); CHECK(connection_stream.str() == "Size/Max of conn not found\n"); } { std::stringstream connection_stream; - auto socket = createSocket(); - minifi::controller::getConnectionSize(std::move(socket), connection_stream, "con1"); + minifi::controller::getConnectionSize(controller_socket_data_, connection_stream, "con1"); CHECK(connection_stream.str() == "Size/Max of con1 1 / 2\n"); } { std::stringstream connection_stream; - auto socket = createSocket(); - minifi::controller::getFullConnections(std::move(socket), connection_stream); + minifi::controller::getFullConnections(controller_socket_data_, connection_stream); CHECK(connection_stream.str() == "1 are full\ncon2 is full\n"); } { std::stringstream connection_stream; - auto socket = createSocket(); - minifi::controller::listConnections(std::move(socket), connection_stream); + minifi::controller::listConnections(controller_socket_data_, connection_stream); auto lines = minifi::utils::StringUtils::splitRemovingEmpty(connection_stream.str(), "\n"); CHECK(lines.size() == 3); CHECK(ranges::find(lines, "Connection Names:") != ranges::end(lines)); @@ -491,8 +459,7 @@ TEST_CASE_METHOD(ControllerTestFixture, "Test connection getters", "[controllerT { std::stringstream connection_stream; - auto socket = createSocket(); - minifi::controller::listConnections(std::move(socket), connection_stream, false); + minifi::controller::listConnections(controller_socket_data_, connection_stream, false); auto lines = minifi::utils::StringUtils::splitRemovingEmpty(connection_stream.str(), "\n"); CHECK(lines.size() == 2); CHECK(ranges::find(lines, "con1") != ranges::end(lines)); @@ -519,8 +486,7 @@ TEST_CASE_METHOD(ControllerTestFixture, "Test manifest getter", "[controllerTest initalizeControllerSocket(reporter); std::stringstream manifest_stream; - auto socket = createSocket(); - 
minifi::controller::printManifest(std::move(socket), manifest_stream); + minifi::controller::printManifest(controller_socket_data_, manifest_stream); REQUIRE(manifest_stream.str().find("\"agentType\": \"cpp\",") != std::string::npos); } @@ -543,8 +509,7 @@ TEST_CASE_METHOD(ControllerTestFixture, "Test jstack getter", "[controllerTests] initalizeControllerSocket(reporter); std::stringstream jstack_stream; - auto socket = createSocket(); - minifi::controller::getJstacks(std::move(socket), jstack_stream); + minifi::controller::getJstacks(controller_socket_data_, jstack_stream); std::string expected_trace = "trace1 -- bt line 1 for trace1\n" "trace1 -- bt line 2 for trace1\n" "trace2 -- bt line 1 for trace2\n" diff --git a/extensions/standard-processors/processors/GetTCP.cpp b/extensions/standard-processors/processors/GetTCP.cpp index 002e86848..c3389edb2 100644 --- a/extensions/standard-processors/processors/GetTCP.cpp +++ b/extensions/standard-processors/processors/GetTCP.cpp @@ -76,7 +76,7 @@ std::optional<asio::ssl::context> GetTCP::parseSSLContext(core::ProcessContext& if (auto context_name = context.getProperty(SSLContextService)) { if (auto controller_service = context.getControllerService(*context_name)) { if (auto ssl_context_service = std::dynamic_pointer_cast<minifi::controllers::SSLContextService>(context.getControllerService(*context_name))) { - ssl_context = utils::net::getClientSslContext(*ssl_context_service); + ssl_context = utils::net::getSslContext(*ssl_context_service); } else { throw Exception(PROCESS_SCHEDULE_EXCEPTION, *context_name + " is not an SSL Context Service"); } diff --git a/extensions/standard-processors/processors/PutTCP.cpp b/extensions/standard-processors/processors/PutTCP.cpp index 08c4226ac..ece033a52 100644 --- a/extensions/standard-processors/processors/PutTCP.cpp +++ b/extensions/standard-processors/processors/PutTCP.cpp @@ -85,7 +85,7 @@ void PutTCP::onSchedule(core::ProcessContext* const context, core::ProcessSessio if 
(context->getProperty(SSLContextService, context_name) && !IsNullOrEmpty(context_name)) { if (auto controller_service = context->getControllerService(context_name)) { if (auto ssl_context_service = std::dynamic_pointer_cast<minifi::controllers::SSLContextService>(context->getControllerService(context_name))) { - ssl_context_ = utils::net::getClientSslContext(*ssl_context_service); + ssl_context_ = utils::net::getSslContext(*ssl_context_service); } else { throw Exception(PROCESS_SCHEDULE_EXCEPTION, context_name + " is not an SSL Context Service"); } diff --git a/extensions/standard-processors/tests/unit/GetTCPTests.cpp b/extensions/standard-processors/tests/unit/GetTCPTests.cpp index 48c5b1640..e60d02dbe 100644 --- a/extensions/standard-processors/tests/unit/GetTCPTests.cpp +++ b/extensions/standard-processors/tests/unit/GetTCPTests.cpp @@ -72,8 +72,8 @@ class TcpTestServer { void enableSSL() { const std::filesystem::path executable_dir = minifi::utils::file::FileUtils::get_executable_dir(); - asio::ssl::context ssl_context(asio::ssl::context::tls_server); - ssl_context.set_options(asio::ssl::context::default_workarounds | asio::ssl::context::single_dh_use | asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); + asio::ssl::context ssl_context(asio::ssl::context::tlsv12_server); + ssl_context.set_options(minifi::utils::net::MINIFI_SSL_OPTIONS); ssl_context.set_password_callback([key_pw = "Password12"](std::size_t&, asio::ssl::context_base::password_purpose&) { return key_pw; }); ssl_context.use_certificate_file((executable_dir / "resources" / "localhost_by_A.pem").string(), asio::ssl::context::pem); ssl_context.use_private_key_file((executable_dir / "resources" / "localhost.key").string(), asio::ssl::context::pem); diff --git a/libminifi/include/c2/ControllerSocketProtocol.h b/libminifi/include/c2/ControllerSocketProtocol.h index 67bb47ef3..b01cee75d 100644 --- a/libminifi/include/c2/ControllerSocketProtocol.h +++ 
b/libminifi/include/c2/ControllerSocketProtocol.h @@ -31,6 +31,9 @@ #include "core/controller/ControllerServiceProvider.h" #include "ControllerSocketReporter.h" #include "utils/MinifiConcurrentQueue.h" +#include "asio/ip/tcp.hpp" +#include "asio/ssl/context.hpp" +#include "utils/net/AsioCoro.h" namespace org::apache::nifi::minifi::c2 { @@ -42,27 +45,35 @@ class ControllerSocketProtocol { public: ControllerSocketProtocol(core::controller::ControllerServiceProvider& controller, state::StateMonitor& update_sink, std::shared_ptr<Configure> configuration, const std::shared_ptr<ControllerSocketReporter>& controller_socket_reporter); + ~ControllerSocketProtocol(); void initialize(); private: - void handleStart(io::BaseStream *stream); - void handleStop(io::BaseStream *stream); - void handleClear(io::BaseStream *stream); - void handleUpdate(io::BaseStream *stream); - void writeQueueSizesResponse(io::BaseStream *stream); - void writeComponentsResponse(io::BaseStream *stream); - void writeConnectionsResponse(io::BaseStream *stream); - void writeGetFullResponse(io::BaseStream *stream); - void writeManifestResponse(io::BaseStream *stream); - void writeJstackResponse(io::BaseStream *stream); - void handleDescribe(io::BaseStream *stream); - void handleCommand(io::BaseStream *stream); + void handleStart(io::BaseStream &stream); + void handleStop(io::BaseStream &stream); + void handleClear(io::BaseStream &stream); + void handleUpdate(io::BaseStream &stream); + void writeQueueSizesResponse(io::BaseStream &stream); + void writeComponentsResponse(io::BaseStream &stream); + void writeConnectionsResponse(io::BaseStream &stream); + void writeGetFullResponse(io::BaseStream &stream); + void writeManifestResponse(io::BaseStream &stream); + void writeJstackResponse(io::BaseStream &stream); + void handleDescribe(io::BaseStream &stream); + asio::awaitable<void> handleCommand(std::unique_ptr<io::BaseStream> stream); + asio::awaitable<void> handshakeAndHandleCommand(asio::ip::tcp::socket&& 
socket, std::shared_ptr<minifi::controllers::SSLContextService> ssl_context_service); std::string getJstack(); + asio::awaitable<void> startAccept(); + asio::awaitable<void> startAcceptSsl(std::shared_ptr<minifi::controllers::SSLContextService> ssl_context_service); + void stopListener(); core::controller::ControllerServiceProvider& controller_; state::StateMonitor& update_sink_; - std::unique_ptr<io::BaseServerSocket> server_socket_; - std::shared_ptr<minifi::io::StreamFactory> stream_factory_; + + asio::io_context io_context_; + std::unique_ptr<asio::ip::tcp::acceptor> acceptor_; + std::thread server_thread_; + std::weak_ptr<ControllerSocketReporter> controller_socket_reporter_; std::shared_ptr<Configure> configuration_; std::shared_ptr<core::logging::Logger> logger_ = core::logging::LoggerFactory<ControllerSocketProtocol>::getLogger(); diff --git a/libminifi/include/io/AsioStream.h b/libminifi/include/io/AsioStream.h new file mode 100644 index 000000000..55bae80cb --- /dev/null +++ b/libminifi/include/io/AsioStream.h @@ -0,0 +1,81 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include <string> +#include <memory> +#include <utility> + +#include "BaseStream.h" +#include "core/logging/LoggerFactory.h" +#include "asio/ts/internet.hpp" +#include "asio/read.hpp" +#include "asio/write.hpp" +#include "io/validation.h" + +namespace org::apache::nifi::minifi::io { + +template<typename AsioSocketStreamType> +class AsioStream : public io::BaseStream { + public: + explicit AsioStream(AsioSocketStreamType&& stream) : stream_(std::move(stream)) {} + + size_t read(std::span<std::byte> target_buffer) override; + size_t write(const uint8_t *source_buffer, size_t size) override; + + private: + AsioSocketStreamType stream_; + + std::shared_ptr<core::logging::Logger> logger_ = core::logging::LoggerFactory<AsioStream<AsioSocketStreamType>>::getLogger(); +}; + +template<typename AsioSocketStreamType> +size_t AsioStream<AsioSocketStreamType>::read(std::span<std::byte> target_buffer) { + if (target_buffer.empty()) { + return 0; + } + + asio::error_code err; + auto read_bytes = stream_.read_some(asio::buffer(target_buffer.data(), target_buffer.size()), err); + if (err) { + return STREAM_ERROR; + } + + return read_bytes; +} + +template<typename AsioSocketStreamType> +size_t AsioStream<AsioSocketStreamType>::write(const uint8_t *source_buffer, size_t size) { + if (size == 0) { + return 0; + } + + if (IsNullOrEmpty(source_buffer)) { + return STREAM_ERROR; + } + + asio::error_code err; + auto bytes_written = asio::write(stream_, asio::buffer(source_buffer, size), asio::transfer_exactly(size), err); + if (err || bytes_written != size) { + return STREAM_ERROR; + } + + return bytes_written; +} + +} // namespace org::apache::nifi::minifi::io diff --git a/libminifi/include/utils/net/AsioSocketUtils.h b/libminifi/include/utils/net/AsioSocketUtils.h index a9173e082..fcd3c48d0 100644 --- a/libminifi/include/utils/net/AsioSocketUtils.h +++ b/libminifi/include/utils/net/AsioSocketUtils.h @@ -36,6 +36,8 @@ using HandshakeType = 
asio::ssl::stream_base::handshake_type; using TcpSocket = asio::ip::tcp::socket; using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>; +constexpr auto MINIFI_SSL_OPTIONS = asio::ssl::context::default_workarounds | asio::ssl::context::single_dh_use; + class ConnectionId { public: ConnectionId(std::string hostname, std::string port) : hostname_(std::move(hostname)), service_(std::move(port)) {} @@ -60,7 +62,7 @@ template<> asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket, asio::steady_timer::duration); -asio::ssl::context getClientSslContext(const controllers::SSLContextService& ssl_context_service); +asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service, asio::ssl::context::method ssl_context_method = asio::ssl::context::tlsv12_client); } // namespace org::apache::nifi::minifi::utils::net namespace std { diff --git a/libminifi/src/c2/ControllerSocketProtocol.cpp b/libminifi/src/c2/ControllerSocketProtocol.cpp index 0dcc354c9..94ee18acf 100644 --- a/libminifi/src/c2/ControllerSocketProtocol.cpp +++ b/libminifi/src/c2/ControllerSocketProtocol.cpp @@ -27,6 +27,10 @@ #include "utils/StringUtils.h" #include "c2/C2Payload.h" #include "properties/Configuration.h" +#include "io/AsioStream.h" +#include "asio/ssl/stream.hpp" +#include "asio/detached.hpp" +#include "utils/net/AsioSocketUtils.h" namespace org::apache::nifi::minifi::c2 { @@ -66,7 +70,60 @@ ControllerSocketProtocol::ControllerSocketProtocol(core::controller::ControllerS configuration_(std::move(configuration)), socket_restart_processor_(update_sink_) { gsl_Expects(configuration_); - stream_factory_ = minifi::io::StreamFactory::getInstance(configuration_); +} + +ControllerSocketProtocol::~ControllerSocketProtocol() { + stopListener(); +} + +void ControllerSocketProtocol::stopListener() { + io_context_.stop(); + if (acceptor_) { + acceptor_->close(); + } + if (server_thread_.joinable()) { + server_thread_.join(); + } + io_context_.restart(); +} + 
+asio::awaitable<void> ControllerSocketProtocol::startAccept() { + while (true) { + auto [accept_error, socket] = co_await acceptor_->async_accept(utils::net::use_nothrow_awaitable); + if (accept_error) { + logger_->log_error("Controller socket accept failed with the following message: '%s'", accept_error.message()); + continue; + } + auto stream = std::make_unique<io::AsioStream<asio::ip::tcp::socket>>(std::move(socket)); + co_spawn(io_context_, handleCommand(std::move(stream)), asio::detached); + } +} + +asio::awaitable<void> ControllerSocketProtocol::handshakeAndHandleCommand(asio::ip::tcp::socket&& socket, std::shared_ptr<minifi::controllers::SSLContextService> ssl_context_service) { + asio::ssl::context ssl_context = utils::net::getSslContext(*ssl_context_service, asio::ssl::context::tlsv12_server); + ssl_context.set_options(utils::net::MINIFI_SSL_OPTIONS); + asio::ssl::stream<asio::ip::tcp::socket> ssl_socket(std::move(socket), ssl_context); + + auto [handshake_error] = co_await ssl_socket.async_handshake(utils::net::HandshakeType::server, utils::net::use_nothrow_awaitable); + if (handshake_error) { + logger_->log_error("Controller socket handshake failed with the following message: '%s'", handshake_error.message()); + co_return; + } + + auto stream = std::make_unique<io::AsioStream<asio::ssl::stream<asio::ip::tcp::socket>>>(std::move(ssl_socket)); + co_return co_await handleCommand(std::move(stream)); +} + +asio::awaitable<void> ControllerSocketProtocol::startAcceptSsl(std::shared_ptr<minifi::controllers::SSLContextService> ssl_context_service) { + while (true) { // NOLINT(clang-analyzer-core.NullDereference) suppressing asio library linter warning + auto [accept_error, socket] = co_await acceptor_->async_accept(utils::net::use_nothrow_awaitable); + if (accept_error) { + logger_->log_error("Controller socket accept failed with the following message: '%s'", accept_error.message()); + continue; + } + + co_spawn(io_context_, 
handshakeAndHandleCommand(std::move(socket), ssl_context_service), asio::detached); + } } void ControllerSocketProtocol::initialize() { @@ -95,40 +152,29 @@ void ControllerSocketProtocol::initialize() { configuration_->get(Configuration::controller_socket_host, host); std::string port; + stopListener(); if (configuration_->get(Configuration::controller_socket_port, port)) { - if (nullptr != secure_context) { -#ifdef OPENSSL_SUPPORT - // if there is no openssl support we won't be using SSL - auto tls_context = std::make_shared<io::TLSContext>(configuration_, secure_context); - server_socket_ = std::unique_ptr<io::BaseServerSocket>(new io::TLSServerSocket(tls_context, host, std::stoi(port), 2)); -#else - server_socket_ = std::unique_ptr<io::BaseServerSocket>(new io::ServerSocket(nullptr, host, std::stoi(port), 2)); -#endif - } else { - server_socket_ = std::unique_ptr<io::BaseServerSocket>(new io::ServerSocket(nullptr, host, std::stoi(port), 2)); - } - // if we have a localhost hostname and we did not manually specify any.interface we will - // bind only to the loopback adapter + // if we have a localhost hostname and we did not manually specify any.interface we will bind only to the loopback adapter if ((host == "localhost" || host == "127.0.0.1" || host == "::") && !any_interface) { - server_socket_->initialize(true); + acceptor_ = std::make_unique<asio::ip::tcp::acceptor>(io_context_, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), std::stoi(port))); } else { - server_socket_->initialize(false); + acceptor_ = std::make_unique<asio::ip::tcp::acceptor>(io_context_, asio::ip::tcp::endpoint(asio::ip::tcp::v4(), std::stoi(port))); } - auto check = [this]() -> bool { - return update_sink_.isRunning(); - }; - - auto handler = [this](io::BaseStream *stream) { - handleCommand(stream); - }; - server_socket_->registerCallback(check, handler); + if (secure_context) { + co_spawn(io_context_, startAcceptSsl(std::move(secure_context)), asio::detached); + } else { + 
co_spawn(io_context_, startAccept(), asio::detached); + } + server_thread_ = std::thread([this] { + io_context_.run(); + }); } } -void ControllerSocketProtocol::handleStart(io::BaseStream *stream) { +void ControllerSocketProtocol::handleStart(io::BaseStream &stream) { std::string component_str; - const auto size = stream->read(component_str); + const auto size = stream.read(component_str); if (!io::isError(size)) { if (component_str == "FlowController") { // Starting flow controller resets socket @@ -143,9 +189,9 @@ void ControllerSocketProtocol::handleStart(io::BaseStream *stream) { } } -void ControllerSocketProtocol::handleStop(io::BaseStream *stream) { +void ControllerSocketProtocol::handleStop(io::BaseStream &stream) { std::string component_str; - const auto size = stream->read(component_str); + const auto size = stream.read(component_str); if (!io::isError(size)) { update_sink_.executeOnComponent(component_str, [](state::StateController& component) { component.stop(); @@ -155,18 +201,18 @@ void ControllerSocketProtocol::handleStop(io::BaseStream *stream) { } } -void ControllerSocketProtocol::handleClear(io::BaseStream *stream) { +void ControllerSocketProtocol::handleClear(io::BaseStream &stream) { std::string connection; - const auto size = stream->read(connection); + const auto size = stream.read(connection); if (!io::isError(size)) { update_sink_.clearConnection(connection); } } -void ControllerSocketProtocol::handleUpdate(io::BaseStream *stream) { +void ControllerSocketProtocol::handleUpdate(io::BaseStream &stream) { std::string what; { - const auto size = stream->read(what); + const auto size = stream.read(what); if (io::isError(size)) { logger_->log_debug("Connection broke"); return; @@ -175,7 +221,7 @@ void ControllerSocketProtocol::handleUpdate(io::BaseStream *stream) { if (what == "flow") { std::string ff_loc; { - const auto size = stream->read(ff_loc); + const auto size = stream.read(ff_loc); if (io::isError(size)) { logger_->log_debug("Connection 
broke"); return; @@ -188,9 +234,9 @@ void ControllerSocketProtocol::handleUpdate(io::BaseStream *stream) { } } -void ControllerSocketProtocol::writeQueueSizesResponse(io::BaseStream *stream) { +void ControllerSocketProtocol::writeQueueSizesResponse(io::BaseStream &stream) { std::string connection; - const auto size_ = stream->read(connection); + const auto size_ = stream.read(connection); if (io::isError(size_)) { logger_->log_debug("Connection broke"); return; @@ -209,10 +255,10 @@ void ControllerSocketProtocol::writeQueueSizesResponse(io::BaseStream *stream) { auto op = static_cast<uint8_t>(Operation::describe); resp.write(&op, 1); resp.write(response.str()); - stream->write(resp.getBuffer()); + stream.write(resp.getBuffer()); } -void ControllerSocketProtocol::writeComponentsResponse(io::BaseStream *stream) { +void ControllerSocketProtocol::writeComponentsResponse(io::BaseStream &stream) { std::vector<std::pair<std::string, bool>> components; update_sink_.executeOnAllComponents([&components](state::StateController& component) { components.emplace_back(component.getComponentName(), component.isRunning()); @@ -226,10 +272,10 @@ void ControllerSocketProtocol::writeComponentsResponse(io::BaseStream *stream) { resp.write(is_running ? 
"true" : "false"); } - stream->write(resp.getBuffer()); + stream.write(resp.getBuffer()); } -void ControllerSocketProtocol::writeConnectionsResponse(io::BaseStream *stream) { +void ControllerSocketProtocol::writeConnectionsResponse(io::BaseStream &stream) { io::BufferStream resp; auto op = static_cast<uint8_t>(Operation::describe); resp.write(&op, 1); @@ -243,10 +289,10 @@ void ControllerSocketProtocol::writeConnectionsResponse(io::BaseStream *stream) for (const auto &connection : connections) { resp.write(connection, false); } - stream->write(resp.getBuffer()); + stream.write(resp.getBuffer()); } -void ControllerSocketProtocol::writeGetFullResponse(io::BaseStream *stream) { +void ControllerSocketProtocol::writeGetFullResponse(io::BaseStream &stream) { io::BufferStream resp; auto op = static_cast<uint8_t>(Operation::describe); resp.write(&op, 1); @@ -260,10 +306,10 @@ void ControllerSocketProtocol::writeGetFullResponse(io::BaseStream *stream) { for (const auto &connection : full_connections) { resp.write(connection, false); } - stream->write(resp.getBuffer()); + stream.write(resp.getBuffer()); } -void ControllerSocketProtocol::writeManifestResponse(io::BaseStream *stream) { +void ControllerSocketProtocol::writeManifestResponse(io::BaseStream &stream) { io::BufferStream resp; auto op = static_cast<uint8_t>(Operation::describe); resp.write(&op, 1); @@ -272,7 +318,7 @@ void ControllerSocketProtocol::writeManifestResponse(io::BaseStream *stream) { manifest = controller_socket_reporter->getAgentManifest(); } resp.write(manifest, true); - stream->write(resp.getBuffer()); + stream.write(resp.getBuffer()); } std::string ControllerSocketProtocol::getJstack() { @@ -289,7 +335,7 @@ std::string ControllerSocketProtocol::getJstack() { return result.str(); } -void ControllerSocketProtocol::writeJstackResponse(io::BaseStream *stream) { +void ControllerSocketProtocol::writeJstackResponse(io::BaseStream &stream) { io::BufferStream resp; auto op = 
static_cast<uint8_t>(Operation::describe); resp.write(&op, 1); @@ -298,12 +344,12 @@ void ControllerSocketProtocol::writeJstackResponse(io::BaseStream *stream) { jstack_response = getJstack(); } resp.write(jstack_response, true); - stream->write(resp.getBuffer()); + stream.write(resp.getBuffer()); } -void ControllerSocketProtocol::handleDescribe(io::BaseStream *stream) { +void ControllerSocketProtocol::handleDescribe(io::BaseStream &stream) { std::string what; - const auto size = stream->read(what); + const auto size = stream.read(what); if (io::isError(size)) { logger_->log_debug("Connection broke"); return; @@ -325,34 +371,34 @@ void ControllerSocketProtocol::handleDescribe(io::BaseStream *stream) { } } -void ControllerSocketProtocol::handleCommand(io::BaseStream *stream) { +asio::awaitable<void> ControllerSocketProtocol::handleCommand(std::unique_ptr<io::BaseStream> stream) { uint8_t head; if (stream->read(head) != 1) { logger_->log_debug("Connection broke"); - return; + co_return; } if (socket_restart_processor_.isSocketRestarting()) { logger_->log_debug("Socket restarting, dropping command"); - return; + co_return; } auto op = static_cast<Operation>(head); switch (op) { case Operation::start: - handleStart(stream); + handleStart(*stream); break; case Operation::stop: - handleStop(stream); + handleStop(*stream); break; case Operation::clear: - handleClear(stream); + handleClear(*stream); break; case Operation::update: - handleUpdate(stream); + handleUpdate(*stream); break; case Operation::describe: - handleDescribe(stream); + handleDescribe(*stream); break; default: logger_->log_error("Unhandled C2 operation: %s", std::to_string(head)); diff --git a/libminifi/src/io/InputStream.cpp b/libminifi/src/io/InputStream.cpp index bf21a76d9..03aacc389 100644 --- a/libminifi/src/io/InputStream.cpp +++ b/libminifi/src/io/InputStream.cpp @@ -21,11 +21,7 @@ #include "io/InputStream.h" #include "utils/gsl.h" -namespace org { -namespace apache { -namespace nifi { -namespace 
minifi { -namespace io { +namespace org::apache::nifi::minifi::io { size_t InputStream::read(bool &value) { uint8_t buf = 0; @@ -71,18 +67,24 @@ size_t InputStream::read(std::string &str, bool widen) { return length_return; } - std::vector<std::byte> buffer(string_length); - const auto read_return = read(buffer); - if (read_return != string_length) { - return read_return; + str.clear(); + str.reserve(string_length); + + auto bytes_to_read = string_length; + auto zero_return_retry_count = 0; + while (bytes_to_read > 0) { + std::vector<std::byte> buffer(bytes_to_read); + const auto read_return = read(buffer); + if (io::isError(read_return)) + return read_return; + if (read_return == 0 && ++zero_return_retry_count > 3) { + return STREAM_ERROR; + } + bytes_to_read -= read_return; + str.append(std::string(reinterpret_cast<const char*>(buffer.data()), read_return)); } - str = std::string(reinterpret_cast<const char*>(buffer.data()), string_length); return length_return + string_length; } -} /* namespace io */ -} /* namespace minifi */ -} /* namespace nifi */ -} /* namespace apache */ -} /* namespace org */ +} // namespace org::apache::nifi::minifi::io diff --git a/libminifi/src/utils/net/AsioSocketUtils.cpp b/libminifi/src/utils/net/AsioSocketUtils.cpp index 15141fa0a..e5200183d 100644 --- a/libminifi/src/utils/net/AsioSocketUtils.cpp +++ b/libminifi/src/utils/net/AsioSocketUtils.cpp @@ -30,9 +30,9 @@ asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket, asio:: co_return co_await asyncOperationWithTimeout(socket.async_handshake(HandshakeType::client, use_nothrow_awaitable), timeout_duration); // NOLINT } -asio::ssl::context getClientSslContext(const controllers::SSLContextService& ssl_context_service) { - asio::ssl::context ssl_context(asio::ssl::context::tls_client); - ssl_context.set_options(asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); +asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service, 
asio::ssl::context::method ssl_context_method) { + asio::ssl::context ssl_context(ssl_context_method); + ssl_context.set_options(MINIFI_SSL_OPTIONS); if (const auto& ca_cert = ssl_context_service.getCACertificate(); !ca_cert.empty()) ssl_context.load_verify_file(ssl_context_service.getCACertificate().string()); ssl_context.set_verify_mode(asio::ssl::verify_peer); diff --git a/libminifi/src/utils/net/TcpServer.cpp b/libminifi/src/utils/net/TcpServer.cpp index a6b957bd5..cf2fe24e2 100644 --- a/libminifi/src/utils/net/TcpServer.cpp +++ b/libminifi/src/utils/net/TcpServer.cpp @@ -47,10 +47,18 @@ asio::awaitable<void> TcpServer::readLoop(auto& socket) { if (read_error || bytes_read == 0) co_return; - if (!max_queue_size_ || max_queue_size_ > concurrent_queue_.size()) - concurrent_queue_.enqueue(Message(read_message.substr(0, bytes_read - 1), IpProtocol::TCP, socket.lowest_layer().remote_endpoint().address(), socket.lowest_layer().local_endpoint().port())); - else + if (!max_queue_size_ || max_queue_size_ > concurrent_queue_.size()) { + std::error_code error; + auto remote_address = socket.lowest_layer().remote_endpoint(error).address(); + if (error) + logger_->log_debug("Error during fetching remote endpoint: %s", error.message()); + auto local_port = socket.lowest_layer().local_endpoint(error).port(); + if (error) + logger_->log_debug("Error during fetching local endpoint: %s", error.message()); + concurrent_queue_.enqueue(Message(read_message.substr(0, bytes_read - 1), IpProtocol::TCP, remote_address, local_port)); + } else { logger_->log_warn("Queue is full. 
TCP message ignored."); + } read_message.erase(0, bytes_read); } } @@ -62,7 +70,7 @@ asio::awaitable<void> TcpServer::insecureSession(asio::ip::tcp::socket socket) { namespace { asio::ssl::context setupSslContext(SslServerOptions& ssl_data) { asio::ssl::context ssl_context(asio::ssl::context::tlsv12_server); - ssl_context.set_options(asio::ssl::context::default_workarounds | asio::ssl::context::single_dh_use | asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); + ssl_context.set_options(minifi::utils::net::MINIFI_SSL_OPTIONS); ssl_context.set_password_callback([key_pw = ssl_data.cert_data.key_pw](std::size_t&, asio::ssl::context_base::password_purpose&) { return key_pw; }); ssl_context.use_certificate_file(ssl_data.cert_data.cert_loc.string(), asio::ssl::context::pem); ssl_context.use_private_key_file(ssl_data.cert_data.key_loc.string(), asio::ssl::context::pem); diff --git a/libminifi/test/unit/NetUtilsTest.cpp b/libminifi/test/unit/NetUtilsTest.cpp index 8b41d6e12..314be891c 100644 --- a/libminifi/test/unit/NetUtilsTest.cpp +++ b/libminifi/test/unit/NetUtilsTest.cpp @@ -66,7 +66,7 @@ TEST_CASE("net::reverseDnsLookup", "[net][dns][reverseDnsLookup]") { } } -TEST_CASE("utils::net::getClientSslContext") { +TEST_CASE("utils::net::getSslContext") { TestController controller; auto plan = controller.createPlan(); @@ -101,13 +101,13 @@ TEST_CASE("utils::net::getClientSslContext") { REQUIRE(ssl_context_service->setProperty(minifi::controllers::SSLContextService::CACertificate, (cert_dir / "alice_by_A_with_key.pem").string())); } REQUIRE_NOTHROW(plan->finalize()); - auto ssl_context = utils::net::getClientSslContext(*ssl_context_service); + auto ssl_context = utils::net::getSslContext(*ssl_context_service); asio::error_code verification_error; ssl_context.set_verify_mode(asio::ssl::verify_peer, verification_error); CHECK(!verification_error); } -TEST_CASE("utils::net::getClientSslContext passphrase problems") { +TEST_CASE("utils::net::getSslContext passphrase 
problems") { TestController controller; auto plan = controller.createPlan(); @@ -122,23 +122,23 @@ TEST_CASE("utils::net::getClientSslContext passphrase problems") { SECTION("Missing passphrase") { REQUIRE_NOTHROW(plan->finalize()); - REQUIRE_THROWS_WITH(utils::net::getClientSslContext(*ssl_context_service), "use_private_key_file: bad decrypt (Provider routines)"); + REQUIRE_THROWS_WITH(utils::net::getSslContext(*ssl_context_service), "use_private_key_file: bad decrypt (Provider routines)"); } SECTION("Invalid passphrase") { REQUIRE(ssl_context_service->setProperty(minifi::controllers::SSLContextService::Passphrase, "not_the_correct_passphrase")); REQUIRE_NOTHROW(plan->finalize()); - REQUIRE_THROWS_WITH(utils::net::getClientSslContext(*ssl_context_service), "use_private_key_file: bad decrypt (Provider routines)"); + REQUIRE_THROWS_WITH(utils::net::getSslContext(*ssl_context_service), "use_private_key_file: bad decrypt (Provider routines)"); } SECTION("Invalid passphrase file") { REQUIRE(ssl_context_service->setProperty(minifi::controllers::SSLContextService::Passphrase, (cert_dir / "alice_by_B.pem").string())); REQUIRE_NOTHROW(plan->finalize()); - REQUIRE_THROWS_WITH(utils::net::getClientSslContext(*ssl_context_service), "use_private_key_file: bad decrypt (Provider routines)"); + REQUIRE_THROWS_WITH(utils::net::getSslContext(*ssl_context_service), "use_private_key_file: bad decrypt (Provider routines)"); } } -TEST_CASE("utils::net::getClientSslContext missing CA") { +TEST_CASE("utils::net::getSslContext missing CA") { TestController controller; auto plan = controller.createPlan(); @@ -151,7 +151,7 @@ TEST_CASE("utils::net::getClientSslContext missing CA") { REQUIRE(ssl_context_service->setProperty(minifi::controllers::SSLContextService::PrivateKey, (cert_dir / "alice.key").string())); REQUIRE_NOTHROW(plan->finalize()); - auto ssl_context = utils::net::getClientSslContext(*ssl_context_service); + auto ssl_context = utils::net::getSslContext(*ssl_context_service); 
asio::error_code verification_error; ssl_context.set_verify_mode(asio::ssl::verify_peer, verification_error); CHECK(!verification_error);