This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 61a1ce00 refactor(c/driver/postgresql): Use Prepared Statement in Result Helper (#714)
61a1ce00 is described below
commit 61a1ce00cad58392dd794924ca3e4747ae8d667f
Author: William Ayd <[email protected]>
AuthorDate: Tue May 30 10:15:28 2023 -0700
refactor(c/driver/postgresql): Use Prepared Statement in Result Helper (#714)
---
c/driver/postgresql/connection.cc | 242 ++++++++++++++++++++------------------
1 file changed, 128 insertions(+), 114 deletions(-)
diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc
index 4730721b..e7fa5911 100644
--- a/c/driver/postgresql/connection.cc
+++ b/c/driver/postgresql/connection.cc
@@ -22,6 +22,7 @@
#include <cstring>
#include <memory>
#include <string>
+#include <utility>
#include <vector>
#include <adbc.h>
@@ -68,16 +69,52 @@ class PqResultRow {
// Helper to manager the lifecycle of a PQResult. The query argument
// will be evaluated as part of the constructor, with the desctructor handling
cleanup
-// Caller is responsible for calling the `Status()` method to ensure results
are
-// as expected prior to iterating
+// Caller must call Prepare then Execute, checking both for an OK
AdbcStatusCode
+// prior to iterating
class PqResultHelper {
public:
- PqResultHelper(PGconn* conn, const char* query) : conn_(conn) {
- query_ = std::string(query);
- result_ = PQexec(conn_, query_.c_str());
+ explicit PqResultHelper(PGconn* conn, std::string query, struct AdbcError*
error)
+ : conn_(conn), query_(std::move(query)), error_(error) {}
+
+ explicit PqResultHelper(PGconn* conn, std::string query,
+ std::vector<std::string> param_values, struct
AdbcError* error)
+ : conn_(conn),
+ query_(std::move(query)),
+ param_values_(param_values),
+ error_(error) {}
+
+ AdbcStatusCode Prepare() {
+ // TODO: make stmtName a unique identifier?
+ PGresult* result =
+ PQprepare(conn_, /*stmtName=*/"", query_.c_str(),
param_values_.size(), NULL);
+ if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+ SetError(error_, "[libpq] Failed to prepare query: %s\nQuery was:%s",
+ PQerrorMessage(conn_), query_.c_str());
+ PQclear(result);
+ return ADBC_STATUS_IO;
+ }
+
+ PQclear(result);
+ return ADBC_STATUS_OK;
}
- ExecStatusType Status() { return PQresultStatus(result_); }
+ AdbcStatusCode Execute() {
+ std::vector<const char*> param_c_strs;
+
+ for (auto index = 0; index < param_values_.size(); index++) {
+ param_c_strs.push_back(param_values_[index].c_str());
+ }
+
+ result_ = PQexecPrepared(conn_, "", param_values_.size(),
param_c_strs.data(), NULL,
+ NULL, 0);
+
+ if (PQresultStatus(result_) != PGRES_TUPLES_OK) {
+ SetError(error_, "[libpq] Failed to execute query: %s",
PQerrorMessage(conn_));
+ return ADBC_STATUS_IO;
+ }
+
+ return ADBC_STATUS_OK;
+ }
~PqResultHelper() {
if (result_ != nullptr) {
@@ -124,6 +161,8 @@ class PqResultHelper {
pg_result* result_ = nullptr;
PGconn* conn_;
std::string query_;
+ std::vector<std::string> param_values_;
+ struct AdbcError* error_;
};
class PqGetObjectsHelper {
@@ -146,15 +185,16 @@ class PqGetObjectsHelper {
}
AdbcStatusCode GetObjects() {
- PqResultHelper curr_db_helper = PqResultHelper{conn_, "SELECT
current_database()"};
- if (curr_db_helper.Status() == PGRES_TUPLES_OK) {
- assert(curr_db_helper.NumRows() == 1);
- auto curr_iter = curr_db_helper.begin();
- PqResultRow db_row = *curr_iter;
- current_db_ = std::string(db_row[0].data);
- } else {
- return ADBC_STATUS_INTERNAL;
- }
+ PqResultHelper curr_db_helper =
+ PqResultHelper{conn_, std::string("SELECT current_database()"),
error_};
+
+ RAISE_ADBC(curr_db_helper.Prepare());
+ RAISE_ADBC(curr_db_helper.Execute());
+
+ assert(curr_db_helper.NumRows() == 1);
+ auto curr_iter = curr_db_helper.begin();
+ PqResultRow db_row = *curr_iter;
+ current_db_ = std::string(db_row[0].data);
RAISE_ADBC(InitArrowArray());
@@ -197,41 +237,33 @@ class PqGetObjectsHelper {
return ADBC_STATUS_INTERNAL;
}
+ std::vector<std::string> params;
if (db_schema_ != NULL) {
- char* schema_name = PQescapeIdentifier(conn_, db_schema_,
strlen(db_schema_));
- if (schema_name == NULL) {
- SetError(error_, "%s%s", "Failed to escape schema: ",
PQerrorMessage(conn_));
+ if (StringBuilderAppend(&query, "%s", " AND nspname = $1")) {
StringBuilderReset(&query);
- return ADBC_STATUS_INVALID_ARGUMENT;
- }
-
- int res =
- StringBuilderAppend(&query, "%s%s%s", " AND nspname ='",
schema_name, "'");
- PQfreemem(schema_name);
- if (res) {
return ADBC_STATUS_INTERNAL;
}
+ params.push_back(db_schema_);
}
- auto result_helper = PqResultHelper{conn_, query.buffer};
+ auto result_helper =
+ PqResultHelper{conn_, std::string(query.buffer), params, error_};
StringBuilderReset(&query);
- if (result_helper.Status() == PGRES_TUPLES_OK) {
- for (PqResultRow row : result_helper) {
- const char* schema_name = row[0].data;
- CHECK_NA(
- INTERNAL,
- ArrowArrayAppendString(db_schema_name_col_,
ArrowCharView(schema_name)),
- error_);
- if (depth_ >= ADBC_OBJECT_DEPTH_TABLES) {
- return ADBC_STATUS_NOT_IMPLEMENTED;
- } else {
- CHECK_NA(INTERNAL, ArrowArrayAppendNull(db_schema_tables_col_, 1),
error_);
- }
- CHECK_NA(INTERNAL,
ArrowArrayFinishElement(catalog_db_schemas_items_), error_);
+ RAISE_ADBC(result_helper.Prepare());
+ RAISE_ADBC(result_helper.Execute());
+
+ for (PqResultRow row : result_helper) {
+ const char* schema_name = row[0].data;
+ CHECK_NA(INTERNAL,
+ ArrowArrayAppendString(db_schema_name_col_,
ArrowCharView(schema_name)),
+ error_);
+ if (depth_ >= ADBC_OBJECT_DEPTH_TABLES) {
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+ } else {
+ CHECK_NA(INTERNAL, ArrowArrayAppendNull(db_schema_tables_col_, 1),
error_);
}
- } else {
- return ADBC_STATUS_NOT_IMPLEMENTED;
+ CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_items_),
error_);
}
}
@@ -247,40 +279,32 @@ class PqGetObjectsHelper {
return ADBC_STATUS_INTERNAL;
}
+ std::vector<std::string> params;
if (catalog_ != NULL) {
- char* catalog_name = PQescapeIdentifier(conn_, catalog_,
strlen(catalog_));
- if (catalog_name == NULL) {
- SetError(error_, "%s%s", "Failed to escape catalog: ",
PQerrorMessage(conn_));
+ if (StringBuilderAppend(&query, "%s", " WHERE datname = $1")) {
StringBuilderReset(&query);
- return ADBC_STATUS_INVALID_ARGUMENT;
- }
-
- int res =
- StringBuilderAppend(&query, "%s%s%s", " WHERE datname = '",
catalog_name, "'");
- PQfreemem(catalog_name);
- if (res) {
return ADBC_STATUS_INTERNAL;
}
+ params.push_back(catalog_);
}
- PqResultHelper result_helper = PqResultHelper{conn_, query.buffer};
+ PqResultHelper result_helper =
+ PqResultHelper{conn_, std::string(query.buffer), params, error_};
StringBuilderReset(&query);
- if (result_helper.Status() == PGRES_TUPLES_OK) {
- for (PqResultRow row : result_helper) {
- const char* db_name = row[0].data;
- CHECK_NA(INTERNAL,
- ArrowArrayAppendString(catalog_name_col_,
ArrowCharView(db_name)),
- error_);
- if (depth_ == ADBC_OBJECT_DEPTH_CATALOGS) {
- CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col_, 1),
error_);
- } else {
- RAISE_ADBC(AppendSchemas(std::string(db_name)));
- }
- CHECK_NA(INTERNAL, ArrowArrayFinishElement(array_), error_);
+ RAISE_ADBC(result_helper.Prepare());
+ RAISE_ADBC(result_helper.Execute());
+
+ for (PqResultRow row : result_helper) {
+ const char* db_name = row[0].data;
+ CHECK_NA(INTERNAL,
+ ArrowArrayAppendString(catalog_name_col_,
ArrowCharView(db_name)), error_);
+ if (depth_ == ADBC_OBJECT_DEPTH_CATALOGS) {
+ CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col_, 1),
error_);
+ } else {
+ RAISE_ADBC(AppendSchemas(std::string(db_name)));
}
- } else {
- return ADBC_STATUS_INTERNAL;
+ CHECK_NA(INTERNAL, ArrowArrayFinishElement(array_), error_);
}
return ADBC_STATUS_OK;
@@ -430,6 +454,7 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const
char* catalog,
struct AdbcError* error) {
AdbcStatusCode final_status = ADBC_STATUS_OK;
struct StringBuilder query = {0};
+ std::vector<std::string> params;
if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return
ADBC_STATUS_INTERNAL;
if (StringBuilderAppend(
@@ -438,67 +463,56 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const
char* catalog,
"FROM pg_catalog.pg_class AS cls "
"INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid =
attr.attrelid "
"INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid "
- "WHERE attr.attnum >= 0 AND cls.oid = '") != 0)
+ "WHERE attr.attnum >= 0 AND cls.oid = ") != 0)
return ADBC_STATUS_INTERNAL;
if (db_schema != nullptr) {
- char* schema = PQescapeIdentifier(conn_, db_schema, strlen(db_schema));
- if (schema == NULL) {
- SetError(error, "%s%s", "Faled to escape schema: ",
PQerrorMessage(conn_));
- return ADBC_STATUS_INVALID_ARGUMENT;
+ if (StringBuilderAppend(&query, "%s", "$1.")) {
+ StringBuilderReset(&query);
+ return ADBC_STATUS_INTERNAL;
}
-
- int ret = StringBuilderAppend(&query, "%s%s", schema, ".");
- PQfreemem(schema);
-
- if (ret != 0) return ADBC_STATUS_INTERNAL;
+ params.push_back(db_schema);
}
- char* table = PQescapeIdentifier(conn_, table_name, strlen(table_name));
- if (table == NULL) {
- SetError(error, "%s%s", "Failed to escape table: ", PQerrorMessage(conn_));
- return ADBC_STATUS_INVALID_ARGUMENT;
+ if (StringBuilderAppend(&query, "%s%" PRIu64 "%s", "$",
+ static_cast<uint64_t>(params.size() + 1),
"::regclass::oid")) {
+ StringBuilderReset(&query);
+ return ADBC_STATUS_INTERNAL;
}
+ params.push_back(table_name);
- int ret = StringBuilderAppend(&query, "%s%s", table, "'::regclass::oid");
- PQfreemem(table);
-
- if (ret != 0) return ADBC_STATUS_INTERNAL;
-
- PqResultHelper result_helper = PqResultHelper{conn_, query.buffer};
+ PqResultHelper result_helper =
+ PqResultHelper{conn_, std::string(query.buffer), params, error};
StringBuilderReset(&query);
- if (result_helper.Status() != PGRES_TUPLES_OK) {
- SetError(error, "%s%s", "Failed to get table schema: ",
PQerrorMessage(conn_));
- final_status = ADBC_STATUS_IO;
- } else {
- auto uschema = nanoarrow::UniqueSchema();
- ArrowSchemaInit(uschema.get());
- CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(),
result_helper.NumRows()),
- error);
+ RAISE_ADBC(result_helper.Prepare());
+ RAISE_ADBC(result_helper.Execute());
- ArrowError na_error;
- int row_counter = 0;
- for (auto row : result_helper) {
- const char* colname = row[0].data;
- const Oid pg_oid = static_cast<uint32_t>(
- std::strtol(row[1].data, /*str_end=*/nullptr, /*base=*/10));
-
- PostgresType pg_type;
- if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
- SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row_counter + 1, "
(\"",
- colname, "\") has unknown type code ", pg_oid);
- final_status = ADBC_STATUS_NOT_IMPLEMENTED;
- goto loopExit;
- }
- CHECK_NA(INTERNAL,
-
pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter]),
- error);
- row_counter++;
+ auto uschema = nanoarrow::UniqueSchema();
+ ArrowSchemaInit(uschema.get());
+ CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(),
result_helper.NumRows()),
+ error);
+
+ ArrowError na_error;
+ int row_counter = 0;
+ for (auto row : result_helper) {
+ const char* colname = row[0].data;
+ const Oid pg_oid =
+ static_cast<uint32_t>(std::strtol(row[1].data, /*str_end=*/nullptr,
/*base=*/10));
+
+ PostgresType pg_type;
+ if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
+ SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row_counter + 1, "
(\"", colname,
+ "\") has unknown type code ", pg_oid);
+ final_status = ADBC_STATUS_NOT_IMPLEMENTED;
+ break;
}
- uschema.move(schema);
+ CHECK_NA(INTERNAL,
+
pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter]),
+ error);
+ row_counter++;
}
-loopExit:
+ uschema.move(schema);
return final_status;
}