pitrou commented on code in PR #13492:
URL: https://github.com/apache/arrow/pull/13492#discussion_r969881049


##########
cpp/src/arrow/flight/sql/example/sqlite_server.cc:
##########
@@ -627,28 +674,80 @@ class SQLiteFlightSqlServer::Impl {
     return DoGetSQLiteQuery(db_, query, SqlSchema::GetCrossReferenceSchema());
   }
 
-  Status ExecuteSql(const std::string& sql) {
+  Status ExecuteSql(const std::string& sql) { return ExecuteSql(db_, sql); }
+
+  Status ExecuteSql(sqlite3* db, const std::string& sql) {
     char* err_msg = nullptr;
-    int rc = sqlite3_exec(db_, sql.c_str(), nullptr, nullptr, &err_msg);
+    int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &err_msg);
     if (rc != SQLITE_OK) {
       std::string error_msg;
       if (err_msg != nullptr) {
         error_msg = err_msg;
+        sqlite3_free(err_msg);
       }
-      sqlite3_free(err_msg);
-      return Status::ExecutionError(error_msg);
+      return Status::IOError(error_msg);
     }
+    if (err_msg) sqlite3_free(err_msg);
     return Status::OK();
   }
+
+  arrow::Result<ActionBeginTransactionResult> BeginTransaction(
+      const ServerCallContext& context, const ActionBeginTransactionRequest& 
request) {
+    std::string handle = GenerateRandomString();
+    sqlite3* new_db = nullptr;
+    if (sqlite3_open_v2(db_uri_.c_str(), &new_db,
+                        SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | 
SQLITE_OPEN_URI,
+                        /*zVfs=*/nullptr) != SQLITE_OK) {
+      std::string error_message = "Can't open new connection: ";
+      if (new_db) {
+        error_message += sqlite3_errmsg(new_db);
+        sqlite3_close(new_db);
+      }
+      return Status::Invalid(error_message);
+    }
+
+    ARROW_RETURN_NOT_OK(ExecuteSql(new_db, "BEGIN TRANSACTION"));
+
+    open_transactions_[handle] = new_db;

Review Comment:
   Should this be mutex-protected?



##########
cpp/src/arrow/flight/integration_tests/test_integration.cc:
##########
@@ -785,10 +1037,292 @@ class FlightSqlScenario : public Scenario {
         AssertEq(kUpdatePreparedStatementExpectedRows, updated_rows,
                  "Wrong number of updated rows for prepared statement 
ExecuteUpdate"));
     ARROW_RETURN_NOT_OK(update_prepared_statement->Close());
+    return Status::OK();
+  }
+};
+
+/// \brief Integration test scenario for validating the Substrait and
+///    transaction extensions to Flight SQL.
+class FlightSqlExtensionScenario : public FlightSqlScenario {
+ public:
+  Status RunClient(std::unique_ptr<FlightClient> client) override {
+    sql::FlightSqlClient sql_client(std::move(client));
+    Status status;
+    if (!(status = ValidateMetadataRetrieval(&sql_client)).ok()) {
+      return status.WithMessage("MetadataRetrieval failed: ", 
status.message());
+    }
+    if (!(status = ValidateStatementExecution(&sql_client)).ok()) {
+      return status.WithMessage("StatementExecution failed: ", 
status.message());
+    }
+    if (!(status = ValidatePreparedStatementExecution(&sql_client)).ok()) {
+      return status.WithMessage("PreparedStatementExecution failed: ", 
status.message());
+    }
+    if (!(status = ValidateTransactions(&sql_client)).ok()) {
+      return status.WithMessage("Transactions failed: ", status.message());
+    }
+    return Status::OK();
+  }
+
+  Status ValidateMetadataRetrieval(sql::FlightSqlClient* sql_client) {
+    std::unique_ptr<FlightInfo> info;
+    std::vector<int32_t> sql_info = {
+        sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SQL,
+        sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT,
+        sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION,
+        sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION,
+        sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION,
+        sql::SqlInfoOptions::FLIGHT_SQL_SERVER_CANCEL,
+        sql::SqlInfoOptions::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT,
+        sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT,
+    };
+    ARROW_ASSIGN_OR_RAISE(info, sql_client->GetSqlInfo({}, sql_info));
+    ARROW_ASSIGN_OR_RAISE(auto reader,
+                          sql_client->DoGet({}, info->endpoints()[0].ticket));
+
+    ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema());
+    if (!sql::SqlSchema::GetSqlInfoSchema()->Equals(*actual_schema,
+                                                    /*check_metadata=*/true)) {
+      return Status::Invalid("Schemas did not match. Expected:\n",
+                             *sql::SqlSchema::GetSqlInfoSchema(), 
"\nActual:\n",
+                             *actual_schema);
+    }
+
+    sql::SqlInfoResultMap info_values;
+    while (true) {
+      ARROW_ASSIGN_OR_RAISE(auto chunk, reader->Next());
+      if (!chunk.data) break;
+
+      const UInt32Array& info_name =
+          static_cast<const UInt32Array&>(*chunk.data->column(0));

Review Comment:
   Use `checked_cast` here and below?



##########
cpp/src/arrow/flight/sql/example/acero_server.cc:
##########
@@ -0,0 +1,296 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/example/acero_server.h"
+
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <unordered_map>
+
+#include "arrow/engine/substrait/serde.h"
+#include "arrow/flight/sql/types.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace flight {
+namespace sql {
+namespace acero_example {
+
+namespace {
+class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer {
+ public:
+  Status Init(const std::shared_ptr<Schema>& schema,
+              compute::BackpressureControl*) override {
+    schema_ = schema;
+    return Status::OK();
+  }
+  Status Consume(compute::ExecBatch exec_batch) override { return 
Status::OK(); }
+  Future<> Finish() override { return Status::OK(); }
+
+  const std::shared_ptr<Schema>& schema() const { return schema_; }
+
+ private:
+  std::shared_ptr<Schema> schema_;
+};
+
+class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer {
+ public:
+  QueuingSinkNodeConsumer() : schema_(nullptr), finished_(false) {}
+
+  Status Init(const std::shared_ptr<Schema>& schema,
+              compute::BackpressureControl*) override {
+    schema_ = schema;
+    return Status::OK();
+  }
+
+  Status Consume(compute::ExecBatch exec_batch) override {
+    {
+      std::lock_guard<std::mutex> guard(mutex_);
+      batches_.push_back(std::move(exec_batch));
+      batches_added_.notify_all();
+    }
+
+    return Status::OK();
+  }
+
+  Future<> Finish() override {
+    {
+      std::lock_guard<std::mutex> guard(mutex_);
+      finished_ = true;
+      batches_added_.notify_all();
+    }
+
+    return Status::OK();
+  }
+
+  const std::shared_ptr<Schema>& schema() const { return schema_; }
+
+  arrow::Result<std::shared_ptr<RecordBatch>> Next() {
+    compute::ExecBatch batch;
+    {
+      std::unique_lock<std::mutex> guard(mutex_);
+      batches_added_.wait(guard, [this] { return !batches_.empty() || 
finished_; });
+
+      if (finished_ && batches_.empty()) {
+        return nullptr;
+      }
+      batch = std::move(batches_.front());
+      batches_.pop_front();
+    }
+
+    return batch.ToRecordBatch(schema_);
+  }
+
+ private:
+  std::mutex mutex_;
+  std::condition_variable batches_added_;
+  std::deque<compute::ExecBatch> batches_;
+  std::shared_ptr<Schema> schema_;
+  bool finished_;
+};
+
+class ConsumerBasedRecordBatchReader : public RecordBatchReader {
+ public:
+  explicit ConsumerBasedRecordBatchReader(
+      std::shared_ptr<compute::ExecPlan> plan,
+      std::shared_ptr<QueuingSinkNodeConsumer> consumer)
+      : plan_(std::move(plan)), consumer_(std::move(consumer)) {}
+
+  std::shared_ptr<Schema> schema() const override { return 
consumer_->schema(); }
+
+  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+    return consumer_->Next().Value(batch);
+  }
+
+  // TODO(ARROW-17242): FlightDataStream needs to call Close()
+  Status Close() override { return plan_->finished().status(); }
+
+ private:
+  std::shared_ptr<compute::ExecPlan> plan_;
+  std::shared_ptr<QueuingSinkNodeConsumer> consumer_;
+};
+
+class AceroFlightSqlServer : public FlightSqlServerBase {
+ public:
+  AceroFlightSqlServer() {
+    RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT,
+                    SqlInfoResult(true));
+    
RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION,
+                    SqlInfoResult(std::string("0.6.0")));
+    
RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION,
+                    SqlInfoResult(std::string("0.6.0")));
+    RegisterSqlInfo(
+        SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION,
+        SqlInfoResult(
+            
SqlInfoOptions::SqlSupportedTransaction::SQL_SUPPORTED_TRANSACTION_NONE));
+    RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL,
+                    SqlInfoResult(false));
+  }
+
+  arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoSubstraitPlan(
+      const ServerCallContext& context, const StatementSubstraitPlan& command,
+      const FlightDescriptor& descriptor) override {
+    if (!command.transaction_id.empty()) {
+      return Status::NotImplemented("Transactions are unsupported");
+    }
+
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Schema> output_schema,
+                          GetPlanSchema(command.plan.plan));
+
+    ARROW_LOG(INFO) << "GetFlightInfoSubstraitPlan: preparing plan with output 
schema "
+                    << *output_schema;
+
+    ARROW_ASSIGN_OR_RAISE(auto ticket, 
CreateStatementQueryTicket(command.plan.plan));
+    std::vector<FlightEndpoint> endpoints{
+        FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}};
+    ARROW_ASSIGN_OR_RAISE(
+        auto info, FlightInfo::Make(*output_schema, descriptor, 
std::move(endpoints),
+                                    /*total_records=*/-1, /*total_bytes=*/-1));
+    return std::unique_ptr<FlightInfo>(new FlightInfo(std::move(info)));

Review Comment:
   Nit, but the logic to create the `FlightInfo` from an encoded Substrait plan 
could perhaps be factored out into a dedicated helper method (since 
`GetFlightInfoPreparedStatement` has the same logic inside)?



##########
cpp/src/arrow/flight/sql/example/acero_server.cc:
##########
@@ -0,0 +1,296 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/example/acero_server.h"
+
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <unordered_map>
+
+#include "arrow/engine/substrait/serde.h"
+#include "arrow/flight/sql/types.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace flight {
+namespace sql {
+namespace acero_example {
+
+namespace {
+class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer {

Review Comment:
   Would be nice to add docstrings/comments explaining each non-trivial helper 
class here.



##########
cpp/src/arrow/flight/sql/example/sqlite_server.cc:
##########
@@ -237,23 +226,82 @@ int32_t GetSqlTypeFromTypeName(const char* sqlite_type) {
 }
 
 class SQLiteFlightSqlServer::Impl {
+ private:
   sqlite3* db_;
-  std::map<std::string, std::shared_ptr<SqliteStatement>> prepared_statements_;
+  std::string db_uri_;

Review Comment:
   Nit
   ```suggestion
     const std::string db_uri_;
   ```



##########
cpp/src/arrow/flight/sql/example/sqlite_server.cc:
##########
@@ -237,23 +226,82 @@ int32_t GetSqlTypeFromTypeName(const char* sqlite_type) {
 }
 
 class SQLiteFlightSqlServer::Impl {
+ private:
   sqlite3* db_;
-  std::map<std::string, std::shared_ptr<SqliteStatement>> prepared_statements_;
+  std::string db_uri_;
+  std::mutex mutex_;
+  std::unordered_map<std::string, std::shared_ptr<SqliteStatement>> 
prepared_statements_;

Review Comment:
   I see that accesses to `prepared_statements_` are never mutex-protected; is 
that right?



##########
cpp/src/arrow/flight/sql/example/sqlite_server.cc:
##########
@@ -627,28 +674,80 @@ class SQLiteFlightSqlServer::Impl {
     return DoGetSQLiteQuery(db_, query, SqlSchema::GetCrossReferenceSchema());
   }
 
-  Status ExecuteSql(const std::string& sql) {
+  Status ExecuteSql(const std::string& sql) { return ExecuteSql(db_, sql); }
+
+  Status ExecuteSql(sqlite3* db, const std::string& sql) {
     char* err_msg = nullptr;
-    int rc = sqlite3_exec(db_, sql.c_str(), nullptr, nullptr, &err_msg);
+    int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &err_msg);
     if (rc != SQLITE_OK) {
       std::string error_msg;
       if (err_msg != nullptr) {
         error_msg = err_msg;
+        sqlite3_free(err_msg);
       }
-      sqlite3_free(err_msg);
-      return Status::ExecutionError(error_msg);
+      return Status::IOError(error_msg);
     }
+    if (err_msg) sqlite3_free(err_msg);
     return Status::OK();
   }
+
+  arrow::Result<ActionBeginTransactionResult> BeginTransaction(
+      const ServerCallContext& context, const ActionBeginTransactionRequest& 
request) {
+    std::string handle = GenerateRandomString();
+    sqlite3* new_db = nullptr;
+    if (sqlite3_open_v2(db_uri_.c_str(), &new_db,
+                        SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | 
SQLITE_OPEN_URI,
+                        /*zVfs=*/nullptr) != SQLITE_OK) {
+      std::string error_message = "Can't open new connection: ";
+      if (new_db) {
+        error_message += sqlite3_errmsg(new_db);
+        sqlite3_close(new_db);
+      }
+      return Status::Invalid(error_message);
+    }
+
+    ARROW_RETURN_NOT_OK(ExecuteSql(new_db, "BEGIN TRANSACTION"));
+
+    open_transactions_[handle] = new_db;
+    return ActionBeginTransactionResult{std::move(handle)};
+  }
+
+  Status EndTransaction(const ServerCallContext& context,
+                        const ActionEndTransactionRequest& request) {
+    std::lock_guard<std::mutex> guard(mutex_);
+    auto it = open_transactions_.find(request.transaction_id);
+    if (it == open_transactions_.end()) {
+      return Status::KeyError("Unknown transaction ID: ", 
request.transaction_id);
+    }
+
+    Status status;
+    if (request.action == ActionEndTransactionRequest::kCommit) {
+      status = ExecuteSql(it->second, "COMMIT");
+    } else {
+      status = ExecuteSql(it->second, "ROLLBACK");
+    }
+    sqlite3_close(it->second);

Review Comment:
   Not sure how efficient you want this to be, but you might release the lock 
around these lines (and call `open_transactions_.erase` beforehand?).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to