This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 0df8d3c4 fix(c/driver/postgresql): reset transaction after rollback (#1159)
0df8d3c4 is described below

commit 0df8d3c48270e2935785a7e98cff50e1b77cca6a
Author: David Li <[email protected]>
AuthorDate: Wed Oct 4 14:14:06 2023 -0400

    fix(c/driver/postgresql): reset transaction after rollback (#1159)
    
    Fixes #1158.
---
 c/driver/postgresql/connection.cc                  |  2 +-
 c/validation/adbc_validation.cc                    | 49 ++++++++++++++++++++++
 c/validation/adbc_validation.h                     |  3 ++
 .../recipe/postgresql_create_append_table.py       |  2 +
 python/adbc_driver_postgresql/tests/test_dbapi.py  |  4 ++
 5 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc
index a9f74058..d389a66c 100644
--- a/c/driver/postgresql/connection.cc
+++ b/c/driver/postgresql/connection.cc
@@ -629,7 +629,7 @@ AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) {
     return ADBC_STATUS_INVALID_STATE;
   }
 
-  PGresult* result = PQexec(conn_, "COMMIT");
+  PGresult* result = PQexec(conn_, "COMMIT; BEGIN TRANSACTION");
   if (PQresultStatus(result) != PGRES_COMMAND_OK) {
     AdbcStatusCode code = SetError(error, result, "%s%s",
                                    "[libpq] Failed to commit: ", 
PQerrorMessage(conn_));
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index cae3598d..0dedb8f6 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -3077,6 +3077,55 @@ void StatementTest::TestSqlQueryStrings() {
   ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
 }
 
+void StatementTest::TestSqlQueryInsertRollback() {
+  if (!quirks()->supports_transactions()) {
+    GTEST_SKIP();
+  }
+
+  ASSERT_THAT(quirks()->DropTable(&connection, "rollbacktest", &error),
+              IsOkStatus(&error));
+
+  ASSERT_THAT(AdbcConnectionSetOption(&connection, 
ADBC_CONNECTION_OPTION_AUTOCOMMIT,
+                                      ADBC_OPTION_VALUE_DISABLED, &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error));
+
+  ASSERT_THAT(
+      AdbcStatementSetSqlQuery(&statement, "CREATE TABLE rollbacktest (a 
INT)", &error),
+      IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+              IsOkStatus(&error));
+
+  ASSERT_THAT(AdbcConnectionCommit(&connection, &error), IsOkStatus(&error));
+
+  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement,
+                                       "INSERT INTO rollbacktest (a) VALUES 
(1)", &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+              IsOkStatus(&error));
+
+  ASSERT_THAT(AdbcConnectionRollback(&connection, &error), IsOkStatus(&error));
+
+  adbc_validation::StreamReader reader;
+  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM 
rollbacktest", &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                        &reader.rows_affected, &error),
+              IsOkStatus(&error));
+
+  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+
+  int64_t total_rows = 0;
+  while (true) {
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    if (!reader.array->release) break;
+
+    total_rows += reader.array->length;
+  }
+
+  ASSERT_EQ(0, total_rows);
+}
+
 void StatementTest::TestSqlQueryCancel() {
   if (!quirks()->supports_cancel()) {
     GTEST_SKIP();
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index d1c23a03..df5e17ff 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -332,6 +332,8 @@ class StatementTest {
   void TestSqlQueryFloats();
   void TestSqlQueryStrings();
 
+  void TestSqlQueryInsertRollback();
+
   void TestSqlQueryCancel();
   void TestSqlQueryErrors();
 
@@ -420,6 +422,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); }                        
         \
   TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); }                    
         \
   TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); }                  
         \
+  TEST_F(FIXTURE, SqlQueryInsertRollback) { TestSqlQueryInsertRollback(); }    
         \
   TEST_F(FIXTURE, SqlQueryCancel) { TestSqlQueryCancel(); }                    
         \
   TEST_F(FIXTURE, SqlQueryErrors) { TestSqlQueryErrors(); }                    
         \
   TEST_F(FIXTURE, SqlSchemaInts) { TestSqlSchemaInts(); }                      
         \
diff --git a/docs/source/python/recipe/postgresql_create_append_table.py b/docs/source/python/recipe/postgresql_create_append_table.py
index a2f6258c..54331ba0 100644
--- a/docs/source/python/recipe/postgresql_create_append_table.py
+++ b/docs/source/python/recipe/postgresql_create_append_table.py
@@ -68,6 +68,8 @@ with conn.cursor() as cur:
     else:
         raise RuntimeError("Should have failed!")
 
+conn.rollback()
+
 #: Instead, we can append to the table.
 with conn.cursor() as cur:
     cur.adbc_ingest("example", data, mode="append")
diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py
index 9b4b7451..1e317472 100644
--- a/python/adbc_driver_postgresql/tests/test_dbapi.py
+++ b/python/adbc_driver_postgresql/tests/test_dbapi.py
@@ -118,12 +118,16 @@ def test_query_cancel(postgres: dbapi.Connection) -> None:
         with pytest.raises(postgres.OperationalError, match="canceling 
statement"):
             cur.fetchone()
 
+    postgres.rollback()
+
     with postgres.cursor() as cur:
         cur.execute("SELECT * FROM test_batch_size")
         cur.adbc_cancel()
         with pytest.raises(postgres.OperationalError, match="canceling 
statement"):
             cur.fetch_arrow_table()
 
+    postgres.rollback()
+
     with postgres.cursor() as cur:
         cur.execute("SELECT * FROM test_batch_size")
         cur.adbc_cancel()

Reply via email to