This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 9cce5991 fix(r/adbcdrivermanager): Improve handling of integer and
character list inputs (#1205)
9cce5991 is described below
commit 9cce5991a233141f6012b35c8623f59cbfbcacc2
Author: Dewey Dunnington <[email protected]>
AuthorDate: Tue Oct 17 09:23:07 2023 -0300
fix(r/adbcdrivermanager): Improve handling of integer and character list
inputs (#1205)
Closes #1127. This also introduces better error messages and fixes some
inputs that might have segfaulted (`NA_character_` input as a table
type):
``` r
library(adbcdrivermanager)
con <- adbc_driver_void() |>
adbc_database_init() |>
adbc_connection_init()
adbc_connection_get_objects(con, table_name = 5L)
#> Error in adbc_connection_get_objects(con, table_name = 5L): Expected character(1) for conversion to const char*
adbc_connection_get_objects(con, table_name = NA_character_)
#> Error in adbc_connection_get_objects(con, table_name = NA_character_): Can't convert NA_character_ to const char*
adbc_connection_get_objects(con, NA_integer_)
#> Error in adbc_connection_get_objects(con, NA_integer_): Can't convert NA_integer_ to int
adbc_connection_get_objects(con, NA_real_)
#> Error in adbc_connection_get_objects(con, NA_real_): Can't convert NA_real_ to int
```
<sup>Created on 2023-10-16 with [reprex
v2.0.2](https://reprex.tidyverse.org)</sup>
---
r/adbcdrivermanager/R/adbc.R | 4 +-
.../man/adbc_connection_get_info.Rd | 2 +-
r/adbcdrivermanager/src/radbc.cc | 31 +++-----
r/adbcdrivermanager/src/radbc.h | 89 +++++++++++++++++++++-
r/adbcdrivermanager/tests/testthat/test-radbc.R | 61 ++++++++++++++-
5 files changed, 157 insertions(+), 30 deletions(-)
diff --git a/r/adbcdrivermanager/R/adbc.R b/r/adbcdrivermanager/R/adbc.R
index 53756f41..85f9676b 100644
--- a/r/adbcdrivermanager/R/adbc.R
+++ b/r/adbcdrivermanager/R/adbc.R
@@ -210,13 +210,13 @@ adbc_connection_release <- function(connection) {
#' # (not implemented by the void driver)
#' try(adbc_connection_get_info(con, 0))
#'
-adbc_connection_get_info <- function(connection, info_codes) {
+adbc_connection_get_info <- function(connection, info_codes = NULL) {
error <- adbc_allocate_error()
out_stream <- nanoarrow::nanoarrow_allocate_array_stream()
status <- .Call(
RAdbcConnectionGetInfo,
connection,
- as.integer(info_codes),
+ info_codes,
out_stream,
error
)
diff --git a/r/adbcdrivermanager/man/adbc_connection_get_info.Rd
b/r/adbcdrivermanager/man/adbc_connection_get_info.Rd
index 12f5afa5..92acb78f 100644
--- a/r/adbcdrivermanager/man/adbc_connection_get_info.Rd
+++ b/r/adbcdrivermanager/man/adbc_connection_get_info.Rd
@@ -12,7 +12,7 @@
\alias{adbc_connection_rollback}
\title{Connection methods}
\usage{
-adbc_connection_get_info(connection, info_codes)
+adbc_connection_get_info(connection, info_codes = NULL)
adbc_connection_get_objects(
connection,
diff --git a/r/adbcdrivermanager/src/radbc.cc b/r/adbcdrivermanager/src/radbc.cc
index f5afdd9d..c10f05c5 100644
--- a/r/adbcdrivermanager/src/radbc.cc
+++ b/r/adbcdrivermanager/src/radbc.cc
@@ -20,6 +20,7 @@
#include <Rinternals.h>
#include <string.h>
+#include <utility>
#include <adbc.h>
#include "adbc_driver_manager.h"
@@ -280,10 +281,13 @@ extern "C" SEXP RAdbcConnectionGetInfo(SEXP
connection_xptr, SEXP info_codes_sex
auto connection = adbc_from_xptr<AdbcConnection>(connection_xptr);
auto error = adbc_from_xptr<AdbcError>(error_xptr);
auto out_stream = adbc_from_xptr<ArrowArrayStream>(out_stream_xptr);
- auto info_codes = reinterpret_cast<uint32_t*>(INTEGER(info_codes_sexp));
+ std::pair<SEXP, int*> info_codes = adbc_as_int_list(info_codes_sexp);
+ PROTECT(info_codes.first);
size_t info_codes_length = Rf_xlength(info_codes_sexp);
int status =
- AdbcConnectionGetInfo(connection, info_codes, info_codes_length,
out_stream, error);
+ AdbcConnectionGetInfo(connection,
reinterpret_cast<uint32_t*>(info_codes.second),
+ info_codes_length, out_stream, error);
+ UNPROTECT(1);
return adbc_wrap_status(status);
}
@@ -297,25 +301,8 @@ extern "C" SEXP RAdbcConnectionGetObjects(SEXP
connection_xptr, SEXP depth_sexp,
const char* catalog = adbc_as_const_char(catalog_sexp, true);
const char* db_schema = adbc_as_const_char(db_schema_sexp, true);
const char* table_name = adbc_as_const_char(table_name_sexp, true);
-
- // Build the null-terminated const char** used to filter by table type
- int table_type_length = Rf_length(table_type_sexp);
- SEXP table_type_shelter =
- PROTECT(Rf_allocVector(RAWSXP, (table_type_length + 1) * sizeof(const
char*)));
- auto table_type = reinterpret_cast<const char**>(RAW(table_type_shelter));
- for (int i = 0; i < table_type_length; i++) {
- table_type[i] = Rf_translateCharUTF8(STRING_ELT(table_type_sexp, i));
- }
- table_type[table_type_length] = nullptr;
-
- // Ensure that R_NilValue maps to null and not a null-termianted const char**
- // of length 0.
- const char** table_type_maybe_null;
- if (table_type_sexp == R_NilValue) {
- table_type_maybe_null = nullptr;
- } else {
- table_type_maybe_null = table_type;
- }
+ std::pair<SEXP, const char**> table_type =
adbc_as_const_char_list(table_type_sexp);
+ PROTECT(table_type.first);
const char* column_name = adbc_as_const_char(column_name_sexp, true);
auto out_stream = adbc_from_xptr<ArrowArrayStream>(out_stream_xptr);
@@ -323,7 +310,7 @@ extern "C" SEXP RAdbcConnectionGetObjects(SEXP
connection_xptr, SEXP depth_sexp,
int status =
AdbcConnectionGetObjects(connection, depth, catalog, db_schema,
table_name,
- table_type_maybe_null, column_name, out_stream,
error);
+ table_type.second, column_name, out_stream,
error);
UNPROTECT(1);
return adbc_wrap_status(status);
}
diff --git a/r/adbcdrivermanager/src/radbc.h b/r/adbcdrivermanager/src/radbc.h
index 9c20686d..fa9fb5ff 100644
--- a/r/adbcdrivermanager/src/radbc.h
+++ b/r/adbcdrivermanager/src/radbc.h
@@ -20,6 +20,8 @@
#include <R.h>
#include <Rinternals.h>
+#include <climits>
+#include <utility>
+
template <typename T>
static inline const char* adbc_xptr_class();
@@ -151,16 +153,95 @@ static inline const char* adbc_as_const_char(SEXP sexp,
bool nullable = false) {
static inline int adbc_as_int(SEXP sexp) {
if (Rf_length(sexp) == 1) {
switch (TYPEOF(sexp)) {
- case REALSXP:
- return REAL(sexp)[0];
- case INTSXP:
- return INTEGER(sexp)[0];
+ case REALSXP: {
+ double value = REAL(sexp)[0];
+ if (ISNA(value) || ISNAN(value)) {
+ Rf_error("Can't convert NA_real_ to int");
+ }
+
+ return value;
+ }
+
+ case INTSXP: {
+ int value = INTEGER(sexp)[0];
+ if (value == NA_INTEGER) {
+ Rf_error("Can't convert NA_integer_ to int");
+ }
+
+ return value;
+ }
}
}
Rf_error("Expected integer(1) or double(1) for conversion to int");
}
+// Convert a character vector (or NULL) to a nullptr-terminated const char**.
+// NULL maps to {R_NilValue, nullptr}, so callers can tell "no filter" apart
+// from an empty list. Otherwise the array lives inside a RAWSXP returned as
+// .first. That SEXP is returned unprotected, so the caller must PROTECT it
+// for as long as the array is used. NA elements and non-character inputs
+// raise an R error.
+static inline std::pair<SEXP, const char**> adbc_as_const_char_list(SEXP sexp) {
+  if (TYPEOF(sexp) == NILSXP) {
+    return {R_NilValue, nullptr};
+  }
+
+  if (TYPEOF(sexp) != STRSXP) {
+    Rf_error("Expected character() for conversion to const char**");
+  }
+
+  int n_items = Rf_length(sexp);
+  // One extra slot holds the terminating nullptr.
+  SEXP storage =
+      PROTECT(Rf_allocVector(RAWSXP, (n_items + 1) * sizeof(const char*)));
+  const char** items = reinterpret_cast<const char**>(RAW(storage));
+
+  for (int i = 0; i < n_items; i++) {
+    SEXP element = STRING_ELT(sexp, i);
+    if (element == NA_STRING) {
+      Rf_error("Can't convert NA_character_ element to const char*");
+    }
+
+    items[i] = Rf_translateCharUTF8(element);
+  }
+
+  items[n_items] = nullptr;
+  UNPROTECT(1);
+  return {storage, items};
+}
+
+// Convert an integer or double vector (or NULL) to an int array.
+// NULL maps to {R_NilValue, nullptr}. An integer vector is returned as-is:
+// .second points into its own data. A double vector is copied into a newly
+// allocated INTSXP that is returned unprotected as .first, so the caller must
+// PROTECT .first while .second is in use. NA/NaN elements, doubles outside the
+// range of int, and other input types raise an R error.
+static inline std::pair<SEXP, int*> adbc_as_int_list(SEXP sexp) {
+  int result_length = Rf_length(sexp);
+
+  switch (TYPEOF(sexp)) {
+    case NILSXP:
+      return {R_NilValue, nullptr};
+
+    case INTSXP: {
+      int* result = INTEGER(sexp);
+      for (int i = 0; i < result_length; i++) {
+        if (result[i] == NA_INTEGER) {
+          Rf_error("Can't convert NA_integer_ element to int");
+        }
+      }
+
+      return {sexp, result};
+    }
+
+    case REALSXP: {
+      SEXP result_shelter = PROTECT(Rf_allocVector(INTSXP, result_length));
+      int* result = INTEGER(result_shelter);
+      const double* values = REAL(sexp);
+      for (int i = 0; i < result_length; i++) {
+        double item = values[i];
+        if (ISNA(item) || ISNAN(item)) {
+          Rf_error("Can't convert NA_real_ or NaN element to int");
+        }
+
+        // Converting a double whose value does not fit in an int (including
+        // +/-Inf) is undefined behaviour, so reject it explicitly.
+        if (item > INT_MAX || item < INT_MIN) {
+          Rf_error("Can't convert out-of-range double element to int");
+        }
+
+        result[i] = static_cast<int>(item);
+      }
+
+      UNPROTECT(1);
+      return {result_shelter, result};
+    }
+
+    default:
+      Rf_error("Expected integer() or double() for conversion to int*");
+  }
+}
+
static inline SEXP adbc_wrap_status(AdbcStatusCode code) {
return Rf_ScalarInteger(code);
}
diff --git a/r/adbcdrivermanager/tests/testthat/test-radbc.R
b/r/adbcdrivermanager/tests/testthat/test-radbc.R
index 1ee83660..3d6e6522 100644
--- a/r/adbcdrivermanager/tests/testthat/test-radbc.R
+++ b/r/adbcdrivermanager/tests/testthat/test-radbc.R
@@ -39,6 +39,23 @@ test_that("connection methods work for the void driver", {
"NOT_IMPLEMENTED"
)
+ expect_error(
+ adbc_connection_get_info(con, double()),
+ "NOT_IMPLEMENTED"
+ )
+
+ expect_error(
+ adbc_connection_get_info(con, NULL),
+ "NOT_IMPLEMENTED"
+ )
+
+ # With defaults of NULL/0L
+ expect_error(
+ adbc_connection_get_objects(con),
+ "NOT_IMPLEMENTED"
+ )
+
+ # With explicit args
expect_error(
adbc_connection_get_objects(
con, 0,
@@ -155,7 +172,7 @@ test_that("invalid parameter types generate errors", {
expect_error(
adbc_connection_get_objects(
- con, NULL,
+ con, character(),
"catalog", "db_schema",
"table_name", "table_type", "column_name"
),
@@ -163,6 +180,33 @@ test_that("invalid parameter types generate errors", {
fixed = TRUE
)
+ expect_error(
+ adbc_connection_get_objects(
+ con, NA_integer_,
+ "catalog", "db_schema",
+ "table_name", "table_type", "column_name"
+ ),
+ "Can't convert NA_integer_"
+ )
+
+ expect_error(
+ adbc_connection_get_objects(
+ con, NA_real_,
+ "catalog", "db_schema",
+ "table_name", "table_type", "column_name"
+ ),
+ "Can't convert NA_real_"
+ )
+
+ expect_error(
+ adbc_connection_get_objects(
+ con, 0L,
+ "catalog", "db_schema",
+ "table_name", c("table_type1", NA_character_), "column_name"
+ ),
+ "Can't convert NA_character_ element"
+ )
+
expect_error(
adbc_statement_set_sql_query(stmt, NULL),
"Expected character(1)",
@@ -174,6 +218,21 @@ test_that("invalid parameter types generate errors", {
"Can't convert NA_character_"
)
+ expect_error(
+ adbc_connection_get_info(con, NA_integer_),
+ "Can't convert NA_integer_ element"
+ )
+
+ expect_error(
+ adbc_connection_get_info(con, NA_real_),
+ "Can't convert NA_real_ or NaN element"
+ )
+
+ expect_error(
+ adbc_connection_get_info(con, NaN),
+ "Can't convert NA_real_ or NaN element"
+ )
+
# (makes a NULL xptr)
stmt2 <- unserialize(serialize(stmt, NULL))
expect_error(