This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b81d779 feat(c/driver/postgresql): TimestampTz write (#868)
4b81d779 is described below

commit 4b81d77906024234d4100b2d24d60770e4b62dac
Author: William Ayd <[email protected]>
AuthorDate: Mon Jul 10 05:07:51 2023 -0700

    feat(c/driver/postgresql): TimestampTz write (#868)
    
    closes #867
    
    The lifecycle of setting the session time zone in this change isn't
    great. I think it should probably be handled during connect, but I'm not
    sure Postgres even supports that, based on things like:
    
    https://stackoverflow.com/a/11779621/621736
    
    I think there is also something awry with the release callback for
    timezone schemas; this needs further investigation.
    
    ---------
    
    Co-authored-by: David Li <[email protected]>
---
 c/driver/postgresql/connection.h             |  1 +
 c/driver/postgresql/statement.cc             | 94 +++++++++++++++++++++-------
 c/driver/sqlite/sqlite_test.cc               |  3 +
 c/driver_manager/adbc_driver_manager_test.cc |  3 +
 c/validation/adbc_validation.cc              | 33 +++++++---
 c/validation/adbc_validation.h               |  4 +-
 6 files changed, 109 insertions(+), 29 deletions(-)

diff --git a/c/driver/postgresql/connection.h b/c/driver/postgresql/connection.h
index 99770c21..74315ee0 100644
--- a/c/driver/postgresql/connection.h
+++ b/c/driver/postgresql/connection.h
@@ -54,6 +54,7 @@ class PostgresConnection {
   const std::shared_ptr<PostgresTypeResolver>& type_resolver() const {
     return type_resolver_;
   }
+  bool autocommit() const { return autocommit_; }
 
  private:
   std::shared_ptr<PostgresDatabase> database_;
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 73141362..4cae15b6 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -145,6 +145,9 @@ struct BindStream {
   // XXX: this assumes fixed-length fields only - will need more
   // consideration to deal with variable-length fields
 
+  bool has_tz_field = false;
+  std::string tz_setting;
+
   struct ArrowError na_error;
 
   explicit BindStream(struct ArrowArrayStream&& bind) {
@@ -217,12 +220,6 @@ struct BindStream {
           param_lengths[i] = 0;
           break;
         case ArrowType::NANOARROW_TYPE_TIMESTAMP:
-          if (strcmp("", bind_schema_fields[i].timezone)) {
-            SetError(error, "[libpq] Field #%" PRIi64 "%s%s%s",
-                     static_cast<int64_t>(i + 1), " (\"", 
bind_schema->children[i]->name,
-                     "\") has unsupported type code timestamp with timezone");
-            return ADBC_STATUS_NOT_IMPLEMENTED;
-          }
           type_id = PostgresTypeId::kTimestamp;
           param_lengths[i] = 8;
           break;
@@ -253,8 +250,49 @@ struct BindStream {
     return ADBC_STATUS_OK;
   }
 
-  AdbcStatusCode Prepare(PGconn* conn, const std::string& query,
-                         struct AdbcError* error) {
+  AdbcStatusCode Prepare(PGconn* conn, const std::string& query, struct 
AdbcError* error,
+                         const bool autocommit) {
+    // tz-aware timestamps require special handling to set the timezone to UTC
+    // prior to sending over the binary protocol; must be reset after execute
+    for (int64_t col = 0; col < bind_schema->n_children; col++) {
+      if ((bind_schema_fields[col].type == 
ArrowType::NANOARROW_TYPE_TIMESTAMP) &&
+          (strcmp("", bind_schema_fields[col].timezone))) {
+        has_tz_field = true;
+
+        if (autocommit) {
+          PGresult* begin_result = PQexec(conn, "BEGIN");
+          if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) {
+            SetError(error, "[libpq] Failed to begin transaction for timezone 
data: %s",
+                     PQerrorMessage(conn));
+            PQclear(begin_result);
+            return ADBC_STATUS_IO;
+          }
+          PQclear(begin_result);
+        }
+
+        PGresult* get_tz_result = PQexec(conn, "SELECT 
current_setting('TIMEZONE')");
+        if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) {
+          SetError(error, "[libpq] Could not query current timezone: %s",
+                   PQerrorMessage(conn));
+          PQclear(get_tz_result);
+          return ADBC_STATUS_IO;
+        }
+
+        tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0));
+        PQclear(get_tz_result);
+
+        PGresult* set_utc_result = PQexec(conn, "SET TIME ZONE 'UTC'");
+        if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) {
+          SetError(error, "[libpq] Failed to set time zone to UTC: %s",
+                   PQerrorMessage(conn));
+          PQclear(set_utc_result);
+          return ADBC_STATUS_IO;
+        }
+        PQclear(set_utc_result);
+        break;
+      }
+    }
+
     PGresult* result = PQprepare(conn, /*stmtName=*/"", query.c_str(),
                                  /*nParams=*/bind_schema->n_children, 
param_types.data());
     if (PQresultStatus(result) != PGRES_COMMAND_OK) {
@@ -349,12 +387,6 @@ struct BindStream {
             }
             case ArrowType::NANOARROW_TYPE_TIMESTAMP: {
               int64_t val = 
array_view->children[col]->buffer_views[1].data.as_int64[row];
-              if (strcmp("", bind_schema_fields[col].timezone)) {
-                SetError(error, "[libpq] Column #%" PRIi64 "%s%s%s", col + 1, 
" (\"",
-                         PQfname(result, col),
-                         "\") has unsupported type code timestamp with 
timezone");
-                return ADBC_STATUS_NOT_IMPLEMENTED;
-              }
 
               // 2000-01-01 00:00:00.000000 in microseconds
               constexpr int64_t kPostgresTimestampEpoch = 946684800000000;
@@ -418,6 +450,26 @@ struct BindStream {
         PQclear(result);
       }
       if (rows_affected) *rows_affected += array->length;
+
+      if (has_tz_field) {
+        std::string reset_query = "SET TIME ZONE '" + tz_setting + "'";
+        PGresult* reset_tz_result = PQexec(conn, reset_query.c_str());
+        if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) {
+          SetError(error, "[libpq] Failed to reset time zone: %s", 
PQerrorMessage(conn));
+          PQclear(reset_tz_result);
+          return ADBC_STATUS_IO;
+        }
+        PQclear(reset_tz_result);
+
+        PGresult* commit_result = PQexec(conn, "COMMIT");
+        if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) {
+          SetError(error, "[libpq] Failed to commit transaction: %s",
+                   PQerrorMessage(conn));
+          PQclear(commit_result);
+          return ADBC_STATUS_IO;
+        }
+        PQclear(commit_result);
+      }
     }
     return ADBC_STATUS_OK;
   }
@@ -730,12 +782,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
         break;
       case ArrowType::NANOARROW_TYPE_TIMESTAMP:
         if (strcmp("", source_schema_fields[i].timezone)) {
-          SetError(error, "[libpq] Field #%" PRIi64 "%s%s%s", 
static_cast<int64_t>(i + 1),
-                   " (\"", source_schema.children[i]->name,
-                   "\") has unsupported type for ingestion timestamp with 
timezone");
-          return ADBC_STATUS_NOT_IMPLEMENTED;
+          create += " TIMESTAMPTZ";
+        } else {
+          create += " TIMESTAMP";
         }
-        create += " TIMESTAMP";
         break;
       default:
         SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
@@ -782,7 +832,8 @@ AdbcStatusCode PostgresStatement::ExecutePreparedStatement(
 
   RAISE_ADBC(bind_stream.Begin([&]() { return ADBC_STATUS_OK; }, error));
   RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error));
-  RAISE_ADBC(bind_stream.Prepare(connection_->conn(), query_, error));
+  RAISE_ADBC(
+      bind_stream.Prepare(connection_->conn(), query_, error, 
connection_->autocommit()));
   RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error));
   return ADBC_STATUS_OK;
 }
@@ -933,7 +984,8 @@ AdbcStatusCode 
PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
   }
   insert += ")";
 
-  RAISE_ADBC(bind_stream.Prepare(connection_->conn(), insert, error));
+  RAISE_ADBC(
+      bind_stream.Prepare(connection_->conn(), insert, error, 
connection_->autocommit()));
   RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error));
   return ADBC_STATUS_OK;
 }
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index d245c1da..47df44e4 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -174,6 +174,9 @@ class SqliteStatementTest : public ::testing::Test,
   void TestSqlIngestTimestamp() {
     GTEST_SKIP() << "Cannot ingest TIMESTAMP (not implemented)";
   }
+  void TestSqlIngestTimestampTz() {
+    GTEST_SKIP() << "Cannot ingest TIMESTAMP WITH TIMEZONE (not implemented)";
+  }
 
  protected:
   SqliteQuirks quirks_;
diff --git a/c/driver_manager/adbc_driver_manager_test.cc 
b/c/driver_manager/adbc_driver_manager_test.cc
index d33114ee..4475bd19 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -229,6 +229,9 @@ class SqliteStatementTest : public ::testing::Test,
   void TestSqlIngestTimestamp() {
     GTEST_SKIP() << "Cannot ingest TIMESTAMP (not implemented)";
   }
+  void TestSqlIngestTimestampTz() {
+    GTEST_SKIP() << "Cannot ingest TIMESTAMP WITH TIMEZONE (not implemented)";
+  }
 
  protected:
   SqliteQuirks quirks_;
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index 6a35d6f8..803f14c0 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1096,7 +1096,7 @@ void StatementTest::TestSqlIngestBinary() {
 }
 
 template <enum ArrowTimeUnit TU>
-void StatementTest::TestSqlIngestTemporalType() {
+void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
   if (!quirks()->supports_bulk_ingest()) {
     GTEST_SKIP();
   }
@@ -1113,8 +1113,7 @@ void StatementTest::TestSqlIngestTemporalType() {
   // changes to allow for various time units to be tested
   ArrowSchemaInit(&schema.value);
   ArrowSchemaSetTypeStruct(&schema.value, 1);
-  ArrowSchemaSetTypeDateTime(schema->children[0], NANOARROW_TYPE_TIMESTAMP, TU,
-                             /*timezone=*/nullptr);
+  ArrowSchemaSetTypeDateTime(schema->children[0], NANOARROW_TYPE_TIMESTAMP, 
TU, timezone);
   ArrowSchemaSetName(schema->children[0], "col");
   ASSERT_THAT(MakeBatch<int64_t>(&schema.value, &array.value, &na_error, 
values),
               IsOkErrno());
@@ -1145,6 +1144,10 @@ void StatementTest::TestSqlIngestTemporalType() {
                 ::testing::AnyOf(::testing::Eq(values.size()), 
::testing::Eq(-1)));
 
     ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+
+    // postgres does not receive/store/send the timezone, just the UTC integer
+    // value; we may still want to update CompareSchema to explicitly check 
for UTC
+    // with TIMESTAMP WITH TIMEZONE and naive for TIMESTAMP
     ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value,
                                           {{"col", NANOARROW_TYPE_TIMESTAMP, 
NULLABLE}}));
 
@@ -1168,10 +1171,26 @@ void StatementTest::TestSqlIngestTemporalType() {
 }
 
 void StatementTest::TestSqlIngestTimestamp() {
-  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>());
-  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>());
-  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>());
-  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>());
+  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>(nullptr));
+  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>(nullptr));
+  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>(nullptr));
+  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>(nullptr));
+}
+
+void StatementTest::TestSqlIngestTimestampTz() {
+  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>("UTC"));
+  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>("UTC"));
+  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>("UTC"));
+  
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("UTC"));
+
+  ASSERT_NO_FATAL_FAILURE(
+      
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>("America/Los_Angeles"));
+  ASSERT_NO_FATAL_FAILURE(
+      
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>("America/Los_Angeles"));
+  ASSERT_NO_FATAL_FAILURE(
+      
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>("America/Los_Angeles"));
+  ASSERT_NO_FATAL_FAILURE(
+      
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("America/Los_Angeles"));
 }
 
 void StatementTest::TestSqlIngestAppend() {
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 1650d064..2a5883c0 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -231,6 +231,7 @@ class StatementTest {
 
   // Temporal
   void TestSqlIngestTimestamp();
+  void TestSqlIngestTimestampTz();
 
   // ---- End Type-specific tests ----------------
 
@@ -274,7 +275,7 @@ class StatementTest {
   void TestSqlIngestNumericType(ArrowType type);
 
   template <enum ArrowTimeUnit TU>
-  void TestSqlIngestTemporalType();
+  void TestSqlIngestTemporalType(const char* timezone);
 };
 
 #define ADBCV_TEST_STATEMENT(FIXTURE)                                          
         \
@@ -295,6 +296,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); }                  
         \
   TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); }                  
         \
   TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); }            
         \
+  TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); }        
         \
   TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); }                  
         \
   TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); }                  
         \
   TEST_F(FIXTURE, SqlIngestMultipleConnections) { 
TestSqlIngestMultipleConnections(); } \

Reply via email to