lidavidm commented on code in PR #34817:
URL: https://github.com/apache/arrow/pull/34817#discussion_r1381801937
##########
format/Flight.proto:
##########
@@ -503,3 +504,100 @@ message FlightData {
message PutResult {
bytes app_metadata = 1;
}
+
+/*
+ * Request message for the "Close Session" action.
+ *
+ * The exiting session is referenced via a cookie header.
+ */
+message CloseSessionRequest {
+ option (experimental) = true;
+}
+
+/*
+ * The result of closing a session.
+ *
+ * The result should be wrapped in a google.protobuf.Any message.
+ */
+message CloseSessionResult {
+ option (experimental) = true;
+
+ enum Status {
+ // The session close status is unknown. Servers should avoid using
+ // this value (send a NOT_FOUND error if the requested query is
+ // not known). Clients can retry the request.
+ CLOSE_RESULT_UNSPECIFIED = 0;
+ // The session close request is complete. Subsequent requests with
+ // a NOT_FOUND error.
+ CLOSE_RESULT_CLOSED = 1;
+ // The session close request is in progress. The client may retry
+ // the close request.
+ CLOSE_RESULT_CLOSING = 2;
+ // The session is not closeable. The client should not retry the
+ // close request.
+ CLOSE_RESULT_NOT_CLOSEABLE = 3;
+ }
+
+ Status status = 1;
+}
+
+message SessionOptionValue {
+ option (experimental) = true;
+
+ message StringListValue {
+ repeated string values = 1;
+ }
+
+ oneof option_value {
+ string string_value = 1;
+ bool bool_value = 2;
+ sfixed32 int32_value = 3;
+ sfixed64 int64_value = 4;
+ float float_value = 5;
+ double double_value = 6;
+ StringListValue string_list_value = 7;
+ }
+}
+
+message SetSessionOptionsRequest {
+ option (experimental) = true;
+
+ map<string, SessionOptionValue> session_options = 1;
+}
+
+message SetSessionOptionsResult {
+ option (experimental) = true;
+
+ enum Status {
+ // The status of setting the option is unknown. Servers should avoid using
+ // this value (send a NOT_FOUND error if the requested query is
+ // not known). Clients can retry the request.
+ SET_SESSION_OPTION_RESULT_UNSPECIFIED = 0;
+ // The session option setting completed successfully.
+ SET_SESSION_OPTION_RESULT_OK = 1;
+ // The given session option name was an alias for another option name.
+ SET_SESSION_OPTION_RESULT_OK_MAPPED = 2;
Review Comment:
What's the use case for `SET_SESSION_OPTION_RESULT_OK_MAPPED`?
##########
format/Flight.proto:
##########
@@ -503,3 +504,100 @@ message FlightData {
message PutResult {
bytes app_metadata = 1;
}
+
+/*
+ * Request message for the "Close Session" action.
+ *
+ * The exiting session is referenced via a cookie header.
+ */
+message CloseSessionRequest {
+ option (experimental) = true;
+}
+
+/*
+ * The result of closing a session.
+ *
+ * The result should be wrapped in a google.protobuf.Any message.
+ */
+message CloseSessionResult {
+ option (experimental) = true;
+
+ enum Status {
+ // The session close status is unknown. Servers should avoid using
+ // this value (send a NOT_FOUND error if the requested query is
+ // not known). Clients can retry the request.
+ CLOSE_RESULT_UNSPECIFIED = 0;
+ // The session close request is complete. Subsequent requests with
+ // a NOT_FOUND error.
+ CLOSE_RESULT_CLOSED = 1;
+ // The session close request is in progress. The client may retry
+ // the close request.
+ CLOSE_RESULT_CLOSING = 2;
+ // The session is not closeable. The client should not retry the
+ // close request.
+ CLOSE_RESULT_NOT_CLOSEABLE = 3;
+ }
+
+ Status status = 1;
+}
+
+message SessionOptionValue {
+ option (experimental) = true;
+
+ message StringListValue {
+ repeated string values = 1;
+ }
+
+ oneof option_value {
+ string string_value = 1;
+ bool bool_value = 2;
+ sfixed32 int32_value = 3;
+ sfixed64 int64_value = 4;
+ float float_value = 5;
+ double double_value = 6;
+ StringListValue string_list_value = 7;
+ }
+}
+
+message SetSessionOptionsRequest {
+ option (experimental) = true;
+
+ map<string, SessionOptionValue> session_options = 1;
+}
+
+message SetSessionOptionsResult {
+ option (experimental) = true;
+
+ enum Status {
+ // The status of setting the option is unknown. Servers should avoid using
+ // this value (send a NOT_FOUND error if the requested query is
+ // not known). Clients can retry the request.
+ SET_SESSION_OPTION_RESULT_UNSPECIFIED = 0;
+ // The session option setting completed successfully.
+ SET_SESSION_OPTION_RESULT_OK = 1;
+ // The given session option name was an alias for another option name.
+ SET_SESSION_OPTION_RESULT_OK_MAPPED = 2;
+ // The given session option name is invalid.
+ SET_SESSION_OPTION_RESULT_INVALID_NAME = 3;
+ // The session cannot be set to the given value.
+ SET_SESSION_OPTION_RESULT_INVALID_VALUE = 4;
+ // The session cannot be set.
+ SET_SESSION_OPTION_RESULT_ERROR = 5;
+ }
+
+ map<string, Status> statuses = 1;
Review Comment:
Would it make sense to use a full message as the map's value type, so that we
can also include an error message?
##########
cpp/src/gandiva/gdv_hash_function_stubs.cc:
##########
Review Comment:
Did you mean to touch these files?
##########
cpp/src/arrow/flight/types.h:
##########
@@ -742,6 +746,164 @@ struct ARROW_FLIGHT_EXPORT CancelFlightInfoRequest {
static arrow::Result<CancelFlightInfoRequest> Deserialize(std::string_view
serialized);
};
+/// \brief Variant supporting all possible value types for
{Set,Get}SessionOptions
+using SessionOptionValue = std::variant<std::string, bool, int32_t, int64_t,
float,
+ double, std::vector<std::string>>;
+
+/// \brief The result of setting a session option.
+enum class SetSessionOptionStatus : int8_t {
+ kUnspecified,
Review Comment:
It would be good to have the enum cases documented, too.
##########
cpp/src/arrow/flight/types.cc:
##########
@@ -463,6 +463,318 @@ arrow::Result<CancelFlightInfoRequest>
CancelFlightInfoRequest::Deserialize(
return out;
}
+std::ostream& operator<<(std::ostream& os, const SetSessionOptionStatus& r) {
+ os << SetSessionOptionStatusNames[static_cast<int>(r)];
+ return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const CloseSessionStatus& r) {
+ os << CloseSessionStatusNames[static_cast<int>(r)];
+ return os;
+}
+
+// Helpers for stringifying maps containing various types
+std::ostream& operator<<(std::ostream& os, std::vector<std::string> v) {
+ os << '[';
+ std::string sep = "";
+ for (const auto& x : v) {
+ os << sep << '"' << x << '"';
+ sep = ", ";
+ }
+ os << ']';
+
+ return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const SessionOptionValue& v) {
+ std::visit([&](const auto& x) { os << x; }, v);
+ return os;
+}
+
+template <typename T>
+std::ostream& operator<<(std::ostream& os, std::map<std::string, T> m) {
+ os << '{';
+ std::string sep = "";
+ for (const auto& [k, v] : m) {
+ os << sep << '[' << k << "]: '" << v;
+ sep = ", ";
+ }
+ os << '}';
+
+ return os;
+}
+
+namespace {
+static bool CompareSessionOptionMaps(const std::map<std::string,
SessionOptionValue>& a,
+ const std::map<std::string,
SessionOptionValue>& b) {
+ if (a.size() != b.size()) {
+ return false;
+ }
+ for (const auto& [k, v] : a) {
+ if (!b.count(k)) {
+ return false;
+ }
+ try {
+ const auto& b_v = b.at(k);
+ if (v.index() != b_v.index()) {
+ return false;
+ }
+ if (v != b_v) {
+ return false;
+ }
+ } catch (const std::out_of_range& e) {
+ return false;
+ }
Review Comment:
You could use `find` to check whether the key is in the map and then use the
returned iterator, instead of `count` followed by `at` inside a try-catch (the
try-catch should be redundant anyway, since `count` has already confirmed that
the key exists).
##########
cpp/src/arrow/flight/client.cc:
##########
@@ -713,6 +713,49 @@ arrow::Result<FlightClient::DoExchangeResult>
FlightClient::DoExchange(
return result;
}
+::arrow::Result<SetSessionOptionsResult> FlightClient::SetSessionOptions(
+ const FlightCallOptions& options, const SetSessionOptionsRequest& request)
{
+ RETURN_NOT_OK(CheckOpen());
+ RETURN_NOT_OK(CheckOpen());
Review Comment:
`RETURN_NOT_OK(CheckOpen())` is duplicated (here and below).
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
+ result.emplace_back(tok.substr(0, val_pos),
+ tok.substr(val_pos + pair_sep_len,
std::string::npos));
+ }
+
+ return result;
+ }
+
+ public:
+ Status StartCall(const CallInfo&, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ std::string session_id;
+
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+ headers_it_pr = incoming_headers.equal_range("cookie");
+ for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) {
+ const std::string_view& cookie_header = itr->second;
+ const std::vector<std::pair<std::string, std::string>> cookies =
+ ParseCookieString(cookie_header);
+ for (const std::pair<std::string, std::string>& cookie : cookies) {
+ if (cookie.first == kSessionCookieName) {
+ session_id = cookie.second;
+ if (session_id.empty())
+ return Status::Invalid("Empty " +
+
static_cast<std::string>(kSessionCookieName) +
+ " cookie value.");
+ }
+ }
+ if (!session_id.empty()) break;
+ }
+
+ if (session_id.empty()) {
+ // No cookie was found
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers));
+ } else {
+ try {
+ const std::shared_lock<std::shared_mutex> l(session_store_lock_);
+ auto session = session_store_.at(session_id);
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers, session,
session_id));
+ } catch (std::out_of_range& e) {
Review Comment:
In general, Arrow avoids using exceptions. You can always use `find` to check
whether an item exists and then access it through the returned iterator.
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
+ result.emplace_back(tok.substr(0, val_pos),
+ tok.substr(val_pos + pair_sep_len,
std::string::npos));
+ }
+
+ return result;
+ }
+
+ public:
+ Status StartCall(const CallInfo&, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ std::string session_id;
+
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+ headers_it_pr = incoming_headers.equal_range("cookie");
+ for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) {
+ const std::string_view& cookie_header = itr->second;
+ const std::vector<std::pair<std::string, std::string>> cookies =
+ ParseCookieString(cookie_header);
+ for (const std::pair<std::string, std::string>& cookie : cookies) {
+ if (cookie.first == kSessionCookieName) {
+ session_id = cookie.second;
+ if (session_id.empty())
+ return Status::Invalid("Empty " +
+
static_cast<std::string>(kSessionCookieName) +
+ " cookie value.");
Review Comment:
```suggestion
return Status::Invalid("Empty ", kSessionCookieName, " cookie
value");
```
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
+ result.emplace_back(tok.substr(0, val_pos),
+ tok.substr(val_pos + pair_sep_len,
std::string::npos));
+ }
+
+ return result;
+ }
+
+ public:
+ Status StartCall(const CallInfo&, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ std::string session_id;
+
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+ headers_it_pr = incoming_headers.equal_range("cookie");
+ for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) {
+ const std::string_view& cookie_header = itr->second;
+ const std::vector<std::pair<std::string, std::string>> cookies =
+ ParseCookieString(cookie_header);
+ for (const std::pair<std::string, std::string>& cookie : cookies) {
+ if (cookie.first == kSessionCookieName) {
+ session_id = cookie.second;
+ if (session_id.empty())
+ return Status::Invalid("Empty " +
+
static_cast<std::string>(kSessionCookieName) +
+ " cookie value.");
+ }
+ }
+ if (!session_id.empty()) break;
+ }
+
+ if (session_id.empty()) {
+ // No cookie was found
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers));
+ } else {
+ try {
+ const std::shared_lock<std::shared_mutex> l(session_store_lock_);
+ auto session = session_store_.at(session_id);
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers, session,
session_id));
+ } catch (std::out_of_range& e) {
+ return Status::Invalid("Invalid or expired " +
+ static_cast<std::string>(kSessionCookieName) +
" cookie.");
+ }
+ }
+
+ return Status::OK();
+ }
+
+ /// \brief Get a new, empty session option map and its id key.
+ std::shared_ptr<FlightSqlSession> GetNewSession(std::string* session_id) {
+ std::string new_id = boost::lexical_cast<std::string>(uuid_generator_());
+ *session_id = new_id;
+ auto session = std::make_shared<FlightSqlSession>();
+
+ const std::unique_lock<std::shared_mutex> l(session_store_lock_);
+ session_store_[new_id] = session;
+
+ return session;
+ }
+};
+
+ServerSessionMiddleware::ServerSessionMiddleware(ServerSessionMiddlewareFactory*
factory,
+ const CallHeaders& headers)
+ : factory_(factory), headers_(headers), existing_session(false) {}
+
+ServerSessionMiddleware::ServerSessionMiddleware(
+ ServerSessionMiddlewareFactory* factory, const CallHeaders& headers,
+ std::shared_ptr<FlightSqlSession> session, std::string session_id)
+ : factory_(factory),
+ headers_(headers),
+ session_(std::move(session)),
+ session_id_(std::move(session_id)),
+ existing_session(true) {}
+
+void ServerSessionMiddleware::SendingHeaders(AddCallHeaders* addCallHeaders) {
+ if (!existing_session && session_) {
+ addCallHeaders->AddHeader(
+ "set-cookie", static_cast<std::string>(kSessionCookieName) + "=" +
session_id_);
+ }
+}
+
+void ServerSessionMiddleware::CallCompleted(const Status&) {}
+
+bool ServerSessionMiddleware::HasSession() const { return
static_cast<bool>(session_); }
+
+std::shared_ptr<FlightSqlSession> ServerSessionMiddleware::GetSession() {
+ if (!session_) session_ = factory_->GetNewSession(&session_id_);
+ return session_;
+}
+
+const CallHeaders& ServerSessionMiddleware::GetCallHeaders() const { return
headers_; }
+
+std::shared_ptr<ServerMiddlewareFactory> MakeServerSessionMiddlewareFactory() {
+ return std::shared_ptr<ServerSessionMiddlewareFactory>(
+ new ServerSessionMiddlewareFactory());
Review Comment:
Use `std::make_shared` instead of `std::shared_ptr<...>(new ...)`.
##########
cpp/src/arrow/flight/sql/client.h:
##########
@@ -18,6 +18,7 @@
#pragma once
#include <cstdint>
+#include <map>
Review Comment:
nit: the `<map>` header isn't directly used here.
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
+ result.emplace_back(tok.substr(0, val_pos),
+ tok.substr(val_pos + pair_sep_len,
std::string::npos));
+ }
+
+ return result;
+ }
+
+ public:
+ Status StartCall(const CallInfo&, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ std::string session_id;
+
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+ headers_it_pr = incoming_headers.equal_range("cookie");
+ for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) {
+ const std::string_view& cookie_header = itr->second;
+ const std::vector<std::pair<std::string, std::string>> cookies =
+ ParseCookieString(cookie_header);
+ for (const std::pair<std::string, std::string>& cookie : cookies) {
+ if (cookie.first == kSessionCookieName) {
+ session_id = cookie.second;
+ if (session_id.empty())
+ return Status::Invalid("Empty " +
+
static_cast<std::string>(kSessionCookieName) +
+ " cookie value.");
+ }
+ }
+ if (!session_id.empty()) break;
+ }
+
+ if (session_id.empty()) {
+ // No cookie was found
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers));
+ } else {
+ try {
+ const std::shared_lock<std::shared_mutex> l(session_store_lock_);
+ auto session = session_store_.at(session_id);
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers, session,
session_id));
+ } catch (std::out_of_range& e) {
+ return Status::Invalid("Invalid or expired " +
+ static_cast<std::string>(kSessionCookieName) +
" cookie.");
+ }
+ }
+
+ return Status::OK();
+ }
+
+ /// \brief Get a new, empty session option map and its id key.
+ std::shared_ptr<FlightSqlSession> GetNewSession(std::string* session_id) {
+ std::string new_id = boost::lexical_cast<std::string>(uuid_generator_());
+ *session_id = new_id;
+ auto session = std::make_shared<FlightSqlSession>();
+
+ const std::unique_lock<std::shared_mutex> l(session_store_lock_);
+ session_store_[new_id] = session;
+
+ return session;
+ }
+};
+
+ServerSessionMiddleware::ServerSessionMiddleware(ServerSessionMiddlewareFactory*
factory,
+ const CallHeaders& headers)
+ : factory_(factory), headers_(headers), existing_session(false) {}
+
+ServerSessionMiddleware::ServerSessionMiddleware(
+ ServerSessionMiddlewareFactory* factory, const CallHeaders& headers,
+ std::shared_ptr<FlightSqlSession> session, std::string session_id)
+ : factory_(factory),
+ headers_(headers),
+ session_(std::move(session)),
+ session_id_(std::move(session_id)),
+ existing_session(true) {}
+
+void ServerSessionMiddleware::SendingHeaders(AddCallHeaders* addCallHeaders) {
+ if (!existing_session && session_) {
+ addCallHeaders->AddHeader(
+ "set-cookie", static_cast<std::string>(kSessionCookieName) + "=" +
session_id_);
+ }
+}
+
+void ServerSessionMiddleware::CallCompleted(const Status&) {}
+
+bool ServerSessionMiddleware::HasSession() const { return
static_cast<bool>(session_); }
+
+std::shared_ptr<FlightSqlSession> ServerSessionMiddleware::GetSession() {
+ if (!session_) session_ = factory_->GetNewSession(&session_id_);
+ return session_;
+}
+
+const CallHeaders& ServerSessionMiddleware::GetCallHeaders() const { return
headers_; }
+
+std::shared_ptr<ServerMiddlewareFactory> MakeServerSessionMiddlewareFactory() {
+ return std::shared_ptr<ServerSessionMiddlewareFactory>(
+ new ServerSessionMiddlewareFactory());
+}
+
+::arrow::Result<SessionOptionValue> FlightSqlSession::GetSessionOption(
Review Comment:
It might be semantically clearer to return `std::optional` if the only possible
error is that the key is not found.
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
Review Comment:
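We should handle the case where `pair_sep` is not found: `find` returns `npos`,
`val_pos + pair_sep_len` then wraps around to 0, and the whole token ends up as
both the cookie name and its value.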
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
+ result.emplace_back(tok.substr(0, val_pos),
+ tok.substr(val_pos + pair_sep_len,
std::string::npos));
+ }
+
+ return result;
+ }
+
+ public:
+ Status StartCall(const CallInfo&, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ std::string session_id;
+
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+ headers_it_pr = incoming_headers.equal_range("cookie");
+ for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) {
+ const std::string_view& cookie_header = itr->second;
+ const std::vector<std::pair<std::string, std::string>> cookies =
+ ParseCookieString(cookie_header);
+ for (const std::pair<std::string, std::string>& cookie : cookies) {
+ if (cookie.first == kSessionCookieName) {
+ session_id = cookie.second;
+ if (session_id.empty())
+ return Status::Invalid("Empty " +
+
static_cast<std::string>(kSessionCookieName) +
+ " cookie value.");
+ }
+ }
+ if (!session_id.empty()) break;
+ }
+
+ if (session_id.empty()) {
+ // No cookie was found
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers));
+ } else {
+ try {
+ const std::shared_lock<std::shared_mutex> l(session_store_lock_);
+ auto session = session_store_.at(session_id);
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers, session,
session_id));
+ } catch (std::out_of_range& e) {
+ return Status::Invalid("Invalid or expired " +
+ static_cast<std::string>(kSessionCookieName) +
" cookie.");
+ }
+ }
+
+ return Status::OK();
+ }
+
+ /// \brief Get a new, empty session option map and its id key.
+ std::shared_ptr<FlightSqlSession> GetNewSession(std::string* session_id) {
+ std::string new_id = boost::lexical_cast<std::string>(uuid_generator_());
+ *session_id = new_id;
+ auto session = std::make_shared<FlightSqlSession>();
+
+ const std::unique_lock<std::shared_mutex> l(session_store_lock_);
+ session_store_[new_id] = session;
+
+ return session;
+ }
+};
+
+ServerSessionMiddleware::ServerSessionMiddleware(ServerSessionMiddlewareFactory*
factory,
+ const CallHeaders& headers)
+ : factory_(factory), headers_(headers), existing_session(false) {}
+
+ServerSessionMiddleware::ServerSessionMiddleware(
+ ServerSessionMiddlewareFactory* factory, const CallHeaders& headers,
+ std::shared_ptr<FlightSqlSession> session, std::string session_id)
+ : factory_(factory),
+ headers_(headers),
+ session_(std::move(session)),
+ session_id_(std::move(session_id)),
+ existing_session(true) {}
+
+void ServerSessionMiddleware::SendingHeaders(AddCallHeaders* addCallHeaders) {
Review Comment:
nit: Arrow uses snake_case for variable/parameter names (e.g.
`add_call_headers` rather than `addCallHeaders`).
##########
cpp/src/arrow/flight/sql/server_session_middleware.h:
##########
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Middleware for handling Flight SQL Sessions including session cookie
handling.
+// Currently experimental.
+
+#pragma once
+
+#include <shared_mutex>
+#include <string_view>
+
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/sql/types.h"
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+class ServerSessionMiddlewareFactory;
+
+static constexpr char const kSessionCookieName[] = "flight_sql_session_id";
+
+class FlightSqlSession {
+ protected:
+ std::map<std::string, SessionOptionValue> map_;
+ std::shared_mutex map_lock_;
+
+ public:
+ /// \brief Get session option by key
+ ::arrow::Result<SessionOptionValue> GetSessionOption(const std::string&);
+ /// \brief Set session option by key to given value
+ void SetSessionOption(const std::string&, const SessionOptionValue&);
Review Comment:
Take the `SessionOptionValue` parameter by value (instead of by const reference) so it can be moved into the map?
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
+ result.emplace_back(tok.substr(0, val_pos),
+ tok.substr(val_pos + pair_sep_len,
std::string::npos));
+ }
+
+ return result;
+ }
+
+ public:
+ Status StartCall(const CallInfo&, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ std::string session_id;
+
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+ headers_it_pr = incoming_headers.equal_range("cookie");
+ for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) {
+ const std::string_view& cookie_header = itr->second;
+ const std::vector<std::pair<std::string, std::string>> cookies =
+ ParseCookieString(cookie_header);
+ for (const std::pair<std::string, std::string>& cookie : cookies) {
+ if (cookie.first == kSessionCookieName) {
+ session_id = cookie.second;
+ if (session_id.empty())
+ return Status::Invalid("Empty " +
+
static_cast<std::string>(kSessionCookieName) +
+ " cookie value.");
+ }
+ }
+ if (!session_id.empty()) break;
+ }
+
+ if (session_id.empty()) {
+ // No cookie was found
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers));
Review Comment:
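Use `std::make_shared<ServerSessionMiddleware>(...)` here instead of `std::shared_ptr<ServerSessionMiddleware>(new ...)`.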
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
+ result.emplace_back(tok.substr(0, val_pos),
+ tok.substr(val_pos + pair_sep_len,
std::string::npos));
+ }
+
+ return result;
+ }
+
+ public:
+ Status StartCall(const CallInfo&, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ std::string session_id;
+
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+ headers_it_pr = incoming_headers.equal_range("cookie");
+ for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) {
+ const std::string_view& cookie_header = itr->second;
+ const std::vector<std::pair<std::string, std::string>> cookies =
+ ParseCookieString(cookie_header);
+ for (const std::pair<std::string, std::string>& cookie : cookies) {
+ if (cookie.first == kSessionCookieName) {
+ session_id = cookie.second;
+ if (session_id.empty())
+ return Status::Invalid("Empty " +
+
static_cast<std::string>(kSessionCookieName) +
+ " cookie value.");
+ }
+ }
+ if (!session_id.empty()) break;
+ }
+
+ if (session_id.empty()) {
+ // No cookie was found
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers));
+ } else {
+ try {
+ const std::shared_lock<std::shared_mutex> l(session_store_lock_);
+ auto session = session_store_.at(session_id);
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers, session,
session_id));
+ } catch (std::out_of_range& e) {
+ return Status::Invalid("Invalid or expired " +
+ static_cast<std::string>(kSessionCookieName) +
" cookie.");
+ }
+ }
+
+ return Status::OK();
+ }
+
+ /// \brief Get a new, empty session option map and its id key.
+ std::shared_ptr<FlightSqlSession> GetNewSession(std::string* session_id) {
+ std::string new_id = boost::lexical_cast<std::string>(uuid_generator_());
+ *session_id = new_id;
+ auto session = std::make_shared<FlightSqlSession>();
+
+ const std::unique_lock<std::shared_mutex> l(session_store_lock_);
+ session_store_[new_id] = session;
+
+ return session;
+ }
+};
+
+ServerSessionMiddleware::ServerSessionMiddleware(ServerSessionMiddlewareFactory*
factory,
+ const CallHeaders& headers)
+ : factory_(factory), headers_(headers), existing_session(false) {}
+
+ServerSessionMiddleware::ServerSessionMiddleware(
+ ServerSessionMiddlewareFactory* factory, const CallHeaders& headers,
+ std::shared_ptr<FlightSqlSession> session, std::string session_id)
+ : factory_(factory),
+ headers_(headers),
+ session_(std::move(session)),
+ session_id_(std::move(session_id)),
+ existing_session(true) {}
+
+void ServerSessionMiddleware::SendingHeaders(AddCallHeaders* addCallHeaders) {
+ if (!existing_session && session_) {
+ addCallHeaders->AddHeader(
+ "set-cookie", static_cast<std::string>(kSessionCookieName) + "=" +
session_id_);
+ }
+}
+
+void ServerSessionMiddleware::CallCompleted(const Status&) {}
+
+bool ServerSessionMiddleware::HasSession() const { return
static_cast<bool>(session_); }
+
+std::shared_ptr<FlightSqlSession> ServerSessionMiddleware::GetSession() {
+ if (!session_) session_ = factory_->GetNewSession(&session_id_);
+ return session_;
+}
+
+const CallHeaders& ServerSessionMiddleware::GetCallHeaders() const { return
headers_; }
+
+std::shared_ptr<ServerMiddlewareFactory> MakeServerSessionMiddlewareFactory() {
+ return std::shared_ptr<ServerSessionMiddlewareFactory>(
+ new ServerSessionMiddlewareFactory());
+}
+
+::arrow::Result<SessionOptionValue> FlightSqlSession::GetSessionOption(
+ const std::string& k) {
+ const std::shared_lock<std::shared_mutex> l(map_lock_);
+ try {
+ return map_.at(k);
+ } catch (const std::out_of_range& e) {
+ return ::arrow::Status::KeyError("Session option key '" + k + "' not
found.");
+ }
Review Comment:
Same here: use a non-throwing lookup (e.g. `map_.find(k)` and compare against `map_.end()`) instead of `map_.at(k)` inside a try/catch.
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
Review Comment:
I don't believe Boost is always available. If you want to use it, we'll have
to adjust the CMakeLists to make sure it's found, and adjust
ThirdpartyToolchain.cmake to reflect the new dependency.
If there's a small/vendorable UUID implementation available, that might be
preferable. CC @pitrou
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
+ result.emplace_back(tok.substr(0, val_pos),
+ tok.substr(val_pos + pair_sep_len,
std::string::npos));
+ }
+
+ return result;
+ }
+
+ public:
+ Status StartCall(const CallInfo&, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ std::string session_id;
+
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+ headers_it_pr = incoming_headers.equal_range("cookie");
+ for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) {
+ const std::string_view& cookie_header = itr->second;
+ const std::vector<std::pair<std::string, std::string>> cookies =
+ ParseCookieString(cookie_header);
+ for (const std::pair<std::string, std::string>& cookie : cookies) {
+ if (cookie.first == kSessionCookieName) {
+ session_id = cookie.second;
+ if (session_id.empty())
+ return Status::Invalid("Empty " +
+
static_cast<std::string>(kSessionCookieName) +
+ " cookie value.");
+ }
+ }
+ if (!session_id.empty()) break;
+ }
+
+ if (session_id.empty()) {
+ // No cookie was found
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers));
+ } else {
+ try {
+ const std::shared_lock<std::shared_mutex> l(session_store_lock_);
+ auto session = session_store_.at(session_id);
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers, session,
session_id));
+ } catch (std::out_of_range& e) {
+ return Status::Invalid("Invalid or expired " +
+ static_cast<std::string>(kSessionCookieName) +
" cookie.");
+ }
+ }
+
+ return Status::OK();
+ }
+
+ /// \brief Get a new, empty session option map and its id key.
+ std::shared_ptr<FlightSqlSession> GetNewSession(std::string* session_id) {
+ std::string new_id = boost::lexical_cast<std::string>(uuid_generator_());
Review Comment:
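Is the generator thread-safe? `uuid_generator_()` is called before `session_store_lock_` is acquired, so concurrent `GetNewSession` calls may invoke it simultaneously.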
##########
cpp/src/arrow/flight/sql/server_session_middleware.cc:
##########
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <shared_mutex>
+#include "arrow/flight/sql/server_session_middleware.h"
+#include <boost/lexical_cast.hpp>
+#include <boost/uuid/uuid.hpp>
+#include <boost/uuid/uuid_generators.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+/// \brief A factory for ServerSessionMiddleware, itself storing session data.
+class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory {
+ protected:
+ std::map<std::string, std::shared_ptr<FlightSqlSession>> session_store_;
+ std::shared_mutex session_store_lock_;
+ boost::uuids::random_generator uuid_generator_;
+
+ std::vector<std::pair<std::string, std::string>> ParseCookieString(
+ const std::string_view& s) {
+ const std::string list_sep = "; ";
+ const std::string pair_sep = "=";
+ const size_t pair_sep_len = pair_sep.length();
+
+ std::vector<std::pair<std::string, std::string>> result;
+
+ size_t cur = 0;
+ while (cur < s.length()) {
+ const size_t end = s.find(list_sep, cur);
+ size_t len;
+ if (end == std::string::npos) {
+ // No (further) list delimiters
+ len = std::string::npos;
+ cur = s.length();
+ } else {
+ len = end - cur;
+ cur = end;
+ }
+ const std::string_view tok = s.substr(cur, len);
+
+ const size_t val_pos = tok.find(pair_sep);
+ result.emplace_back(
+ tok.substr(0, val_pos),
+ tok.substr(val_pos + pair_sep_len, std::string::npos)
+ );
+ }
+
+ return result;
+ }
+
+ public:
+ Status StartCall(const CallInfo &, const CallHeaders &incoming_headers,
+ std::shared_ptr<ServerMiddleware> *middleware) {
+ std::string session_id;
+
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+ headers_it_pr = incoming_headers.equal_range("cookie");
+ for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) {
+ const std::string_view& cookie_header = itr->second;
+ const std::vector<std::pair<std::string, std::string>> cookies =
+ ParseCookieString(cookie_header);
+ for (const std::pair<std::string, std::string>& cookie : cookies) {
+ if (cookie.first == kSessionCookieName) {
+ session_id = cookie.second;
+ if (!session_id.length())
+ return Status::Invalid(
+ "Empty " + static_cast<std::string>(kSessionCookieName)
+ + " cookie value.");
+ }
+ }
+ if (session_id.length()) break;
+ }
+
+ if (!session_id.length()) {
+ // No cookie was found
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers));
+ } else {
+ try {
+ const std::shared_lock<std::shared_mutex> l(session_store_lock_);
+ auto session = session_store_.at(session_id);
+ *middleware = std::shared_ptr<ServerSessionMiddleware>(
+ new ServerSessionMiddleware(this, incoming_headers,
+ session, session_id));
+ } catch (std::out_of_range& e) {
+ return Status::Invalid(
+ "Invalid or expired "
+ + static_cast<std::string>(kSessionCookieName) + " cookie.");
+ }
+ }
+
+ return Status::OK();
+ }
+
+ /// \brief Get a new, empty session option map and its id key.
+ std::shared_ptr<FlightSqlSession> GetNewSession(std::string* session_id) {
+ std::string new_id = boost::lexical_cast<std::string>(uuid_generator_());
+ *session_id = new_id;
Review Comment:
Better yet, return a `std::pair` of the new session and its id instead of using an out parameter.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]