This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 70904dffef ARROW-16703: [R] Refactor map_batches() so it can stream results (#13650)
70904dffef is described below
commit 70904dffef25a8c883a1a829c66a1d30a7d9c249
Author: Dewey Dunnington <[email protected]>
AuthorDate: Sat Jul 23 09:09:44 2022 -0300
ARROW-16703: [R] Refactor map_batches() so it can stream results (#13650)
Authored-by: Dewey Dunnington <[email protected]>
Signed-off-by: Dewey Dunnington <[email protected]>
---
r/NAMESPACE | 1 +
r/R/arrowExports.R | 4 ++
r/R/dataset-scan.R | 68 ++++++++++++++++++-----
r/R/record-batch-reader.R | 9 +++
r/man/as_record_batch_reader.Rd | 6 ++
r/man/map_batches.Rd | 9 ++-
r/src/arrowExports.cpp | 10 ++++
r/src/recordbatchreader.cpp | 45 +++++++++++++++
r/src/safe-call-into-r-impl.cpp | 2 +-
r/src/safe-call-into-r.h | 4 +-
r/tests/testthat/test-dataset-dplyr.R | 4 ++
r/tests/testthat/test-dataset.R | 86 +++++++++++++++++++++++++++++
r/tests/testthat/test-record-batch-reader.R | 33 +++++++++++
13 files changed, 264 insertions(+), 17 deletions(-)
diff --git a/r/NAMESPACE b/r/NAMESPACE
index a8c8a974d6..733261f33c 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -67,6 +67,7 @@ S3method(as_record_batch,arrow_dplyr_query)
S3method(as_record_batch,data.frame)
S3method(as_record_batch,pyarrow.lib.RecordBatch)
S3method(as_record_batch,pyarrow.lib.Table)
+S3method(as_record_batch_reader,"function")
S3method(as_record_batch_reader,Dataset)
S3method(as_record_batch_reader,RecordBatch)
S3method(as_record_batch_reader,RecordBatchReader)
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 7cd2c5dbfc..ab3358d666 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -1740,6 +1740,10 @@ RecordBatchReader__from_batches <- function(batches,
schema_sxp) {
.Call(`_arrow_RecordBatchReader__from_batches`, batches, schema_sxp)
}
+RecordBatchReader__from_function <- function(fun_sexp, schema) {
+ .Call(`_arrow_RecordBatchReader__from_function`, fun_sexp, schema)
+}
+
RecordBatchReader__from_Table <- function(table) {
.Call(`_arrow_RecordBatchReader__from_Table`, table)
}
diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R
index cca92b676f..53fe7078c2 100644
--- a/r/R/dataset-scan.R
+++ b/r/R/dataset-scan.R
@@ -183,11 +183,16 @@ tail_from_batches <- function(batches, n) {
#' @param FUN A function or `purrr`-style lambda expression to apply to each
#' batch. It must return a RecordBatch or something coercible to one via
#' `as_record_batch()`.
+#' @param .schema An optional [schema()]. If NULL, the schema will be inferred
+#' from the first batch.
+#' @param .lazy Use `TRUE` to evaluate `FUN` lazily as batches are read from
+#' the result; use `FALSE` to evaluate `FUN` on all batches before returning
+#' the reader.
#' @param ... Additional arguments passed to `FUN`
#' @param .data.frame Deprecated argument, ignored
#' @return An `arrow_dplyr_query`.
#' @export
-map_batches <- function(X, FUN, ..., .data.frame = NULL) {
+map_batches <- function(X, FUN, ..., .schema = NULL, .lazy = FALSE,
+                        .data.frame = NULL) {
if (!is.null(.data.frame)) {
warning(
"The .data.frame argument is deprecated. ",
@@ -197,25 +202,60 @@ map_batches <- function(X, FUN, ..., .data.frame = NULL) {
}
FUN <- as_mapper(FUN)
reader <- as_record_batch_reader(X)
+ dots <- rlang::list2(...)
- # TODO: for future consideration
- # * Move eval to C++ and make it a generator so it can stream, not block
- # * Accept an output schema argument: with that, we could make this lazy (via collapse)
- batch <- reader$read_next_batch()
- res <- vector("list", 1024)
- i <- 0L
- while (!is.null(batch)) {
- i <- i + 1L
- res[[i]] <- as_record_batch(FUN(batch, ...))
+ # If no schema is supplied, we have to evaluate the first batch here
+ if (is.null(.schema)) {
batch <- reader$read_next_batch()
+ if (is.null(batch)) {
+ abort("Can't infer schema from a RecordBatchReader with zero batches")
+ }
+
+ first_result <- as_record_batch(do.call(FUN, c(list(batch), dots)))
+ .schema <- first_result$schema
+ fun <- function() {
+ if (!is.null(first_result)) {
+ result <- first_result
+ first_result <<- NULL
+ result
+ } else {
+ batch <- reader$read_next_batch()
+ if (is.null(batch)) {
+ NULL
+ } else {
+ as_record_batch(
+ do.call(FUN, c(list(batch), dots)),
+ schema = .schema
+ )
+ }
+ }
+ }
+ } else {
+ fun <- function() {
+ batch <- reader$read_next_batch()
+ if (is.null(batch)) {
+ return(NULL)
+ }
+
+ as_record_batch(
+ do.call(FUN, c(list(batch), dots)),
+ schema = .schema
+ )
+ }
}
- # Trim list back
- if (i < length(res)) {
- res <- res[seq_len(i)]
+ reader_out <- as_record_batch_reader(fun, schema = .schema)
+
+ # TODO(ARROW-17178) because there are some restrictions on evaluating
+ # reader_out in some ExecPlans, the default .lazy is FALSE for now.
+ if (!.lazy) {
+ reader_out <- RecordBatchReader$create(
+ batches = reader_out$batches(),
+ schema = .schema
+ )
}
- RecordBatchReader$create(batches = res)
+ reader_out
}
#' @usage NULL
diff --git a/r/R/record-batch-reader.R b/r/R/record-batch-reader.R
index 8f6a600dfb..3a985d8abc 100644
--- a/r/R/record-batch-reader.R
+++ b/r/R/record-batch-reader.R
@@ -191,6 +191,8 @@ RecordBatchFileReader$create <- function(file) {
#' Convert an object to an Arrow RecordBatchReader
#'
#' @param x An object to convert to a [RecordBatchReader]
+#' @param schema The [schema()] that must match the schema returned by each
+#' call to `x` when `x` is a function.
#' @param ... Passed to S3 methods
#'
#' @return A [RecordBatchReader]
@@ -234,6 +236,13 @@ as_record_batch_reader.Dataset <- function(x, ...) {
Scanner$create(x)$ToRecordBatchReader()
}
+#' @rdname as_record_batch_reader
+#' @export
+as_record_batch_reader.function <- function(x, ..., schema) {
+ assert_that(inherits(schema, "Schema"))
+ RecordBatchReader__from_function(x, schema)
+}
+
#' @rdname as_record_batch_reader
#' @export
as_record_batch_reader.arrow_dplyr_query <- function(x, ...) {
diff --git a/r/man/as_record_batch_reader.Rd b/r/man/as_record_batch_reader.Rd
index e635c0b98b..2ed5435476 100644
--- a/r/man/as_record_batch_reader.Rd
+++ b/r/man/as_record_batch_reader.Rd
@@ -7,6 +7,7 @@
\alias{as_record_batch_reader.RecordBatch}
\alias{as_record_batch_reader.data.frame}
\alias{as_record_batch_reader.Dataset}
+\alias{as_record_batch_reader.function}
\alias{as_record_batch_reader.arrow_dplyr_query}
\alias{as_record_batch_reader.Scanner}
\title{Convert an object to an Arrow RecordBatchReader}
@@ -23,6 +24,8 @@ as_record_batch_reader(x, ...)
\method{as_record_batch_reader}{Dataset}(x, ...)
+\method{as_record_batch_reader}{`function`}(x, ..., schema)
+
\method{as_record_batch_reader}{arrow_dplyr_query}(x, ...)
\method{as_record_batch_reader}{Scanner}(x, ...)
@@ -31,6 +34,9 @@ as_record_batch_reader(x, ...)
\item{x}{An object to convert to a \link{RecordBatchReader}}
\item{...}{Passed to S3 methods}
+
+\item{schema}{The \code{\link[=schema]{schema()}} that must match the schema returned by each
+call to \code{x} when \code{x} is a function.}
}
\value{
A \link{RecordBatchReader}
diff --git a/r/man/map_batches.Rd b/r/man/map_batches.Rd
index eaeab6013a..0e4d48e024 100644
--- a/r/man/map_batches.Rd
+++ b/r/man/map_batches.Rd
@@ -4,7 +4,7 @@
\alias{map_batches}
\title{Apply a function to a stream of RecordBatches}
\usage{
-map_batches(X, FUN, ..., .data.frame = NULL)
+map_batches(X, FUN, ..., .schema = NULL, .lazy = FALSE, .data.frame = NULL)
}
\arguments{
\item{X}{A \code{Dataset} or \code{arrow_dplyr_query} object, as returned by the
@@ -16,6 +16,13 @@ batch. It must return a RecordBatch or something coercible to one via
\item{...}{Additional arguments passed to \code{FUN}}
+\item{.schema}{An optional \code{\link[=schema]{schema()}}. If NULL, the schema will be inferred
+from the first batch.}
+
+\item{.lazy}{Use \code{TRUE} to evaluate \code{FUN} lazily as batches are read from
+the result; use \code{FALSE} to evaluate \code{FUN} on all batches before returning
+the reader.}
+
\item{.data.frame}{Deprecated argument, ignored}
}
\value{
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index dc96af41d4..adb6636e9e 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -4491,6 +4491,15 @@ BEGIN_CPP11
END_CPP11
}
// recordbatchreader.cpp
+std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_function(cpp11::sexp fun_sexp, const std::shared_ptr<arrow::Schema>& schema);
+extern "C" SEXP _arrow_RecordBatchReader__from_function(SEXP fun_sexp_sexp, SEXP schema_sexp){
+BEGIN_CPP11
+ arrow::r::Input<cpp11::sexp>::type fun_sexp(fun_sexp_sexp);
+ arrow::r::Input<const std::shared_ptr<arrow::Schema>&>::type schema(schema_sexp);
+ return cpp11::as_sexp(RecordBatchReader__from_function(fun_sexp, schema));
+END_CPP11
+}
+// recordbatchreader.cpp
std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_Table(const std::shared_ptr<arrow::Table>& table);
extern "C" SEXP _arrow_RecordBatchReader__from_Table(SEXP table_sexp){
BEGIN_CPP11
@@ -5611,6 +5620,7 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_RecordBatchReader__ReadNext", (DL_FUNC)
&_arrow_RecordBatchReader__ReadNext, 1},
{ "_arrow_RecordBatchReader__batches", (DL_FUNC)
&_arrow_RecordBatchReader__batches, 1},
{ "_arrow_RecordBatchReader__from_batches", (DL_FUNC)
&_arrow_RecordBatchReader__from_batches, 2},
+ { "_arrow_RecordBatchReader__from_function", (DL_FUNC) &_arrow_RecordBatchReader__from_function, 2},
{ "_arrow_RecordBatchReader__from_Table", (DL_FUNC)
&_arrow_RecordBatchReader__from_Table, 1},
{ "_arrow_Table__from_RecordBatchReader", (DL_FUNC)
&_arrow_Table__from_RecordBatchReader, 1},
{ "_arrow_RecordBatchReader__Head", (DL_FUNC)
&_arrow_RecordBatchReader__Head, 2},
diff --git a/r/src/recordbatchreader.cpp b/r/src/recordbatchreader.cpp
index fb173825f3..c571d282da 100644
--- a/r/src/recordbatchreader.cpp
+++ b/r/src/recordbatchreader.cpp
@@ -16,6 +16,7 @@
// under the License.
#include "./arrow_types.h"
+#include "./safe-call-into-r.h"
#include <arrow/ipc/reader.h>
#include <arrow/table.h>
@@ -54,6 +55,50 @@ std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_batches(
}
}
+class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
+ public:
+ RFunctionRecordBatchReader(cpp11::sexp fun,
+ const std::shared_ptr<arrow::Schema>& schema)
+ : fun_(fun), schema_(schema) {}
+
+ std::shared_ptr<arrow::Schema> schema() const { return schema_; }
+
+ arrow::Status ReadNext(std::shared_ptr<arrow::RecordBatch>* batch_out) {
+ auto batch = SafeCallIntoR<std::shared_ptr<arrow::RecordBatch>>([&]() {
+ cpp11::sexp result_sexp = fun_();
+ if (result_sexp == R_NilValue) {
+ return std::shared_ptr<arrow::RecordBatch>(nullptr);
+ } else if (!Rf_inherits(result_sexp, "RecordBatch")) {
+ cpp11::stop("Expected fun() to return an arrow::RecordBatch");
+ }
+
+ return cpp11::as_cpp<std::shared_ptr<arrow::RecordBatch>>(result_sexp);
+ });
+
+ RETURN_NOT_OK(batch);
+
+ if (batch.ValueUnsafe().get() != nullptr &&
+ !batch.ValueUnsafe()->schema()->Equals(schema_)) {
+ return arrow::Status::Invalid("Expected fun() to return batch with schema '",
+ schema_->ToString(), "' but got batch with schema '",
+ batch.ValueUnsafe()->schema()->ToString(), "'");
+ }
+
+ *batch_out = batch.ValueUnsafe();
+ return arrow::Status::OK();
+ }
+
+ private:
+ cpp11::function fun_;
+ std::shared_ptr<arrow::Schema> schema_;
+};
+
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_function(
+ cpp11::sexp fun_sexp, const std::shared_ptr<arrow::Schema>& schema) {
+ return std::make_shared<RFunctionRecordBatchReader>(fun_sexp, schema);
+}
+
// [[arrow::export]]
std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_Table(
const std::shared_ptr<arrow::Table>& table) {
diff --git a/r/src/safe-call-into-r-impl.cpp b/r/src/safe-call-into-r-impl.cpp
index 7318c81bb5..4eec3a85df 100644
--- a/r/src/safe-call-into-r-impl.cpp
+++ b/r/src/safe-call-into-r-impl.cpp
@@ -38,7 +38,7 @@ bool CanRunWithCapturedR() {
on_old_windows = on_old_windows_fun();
}
- return !on_old_windows;
+ return !on_old_windows && GetMainRThread().Executor() == nullptr;
#else
return false;
#endif
diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h
index 937163a05d..08e8a8c11b 100644
--- a/r/src/safe-call-into-r.h
+++ b/r/src/safe-call-into-r.h
@@ -31,7 +31,9 @@
// and crash R in older versions (ARROW-16201). Crashes also occur
// on 32-bit R builds on R 3.6 and lower. Implementation provided
// in safe-call-into-r-impl.cpp so that we can skip some tests
-// when this feature is not provided.
+// when this feature is not provided. This also checks that there
+// is not already an event loop registered (via MainRThread::Executor()),
+// because only one of these can exist at any given time.
bool CanRunWithCapturedR();
// The MainRThread class keeps track of the thread on which it is safe
diff --git a/r/tests/testthat/test-dataset-dplyr.R b/r/tests/testthat/test-dataset-dplyr.R
index b6982939ee..b09b549d59 100644
--- a/r/tests/testthat/test-dataset-dplyr.R
+++ b/r/tests/testthat/test-dataset-dplyr.R
@@ -70,6 +70,8 @@ test_that("filter() with %in%", {
})
test_that("filter() on timestamp columns", {
+ skip_if_not_available("re2")
+
ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8()))
expect_equal(
ds %>%
@@ -116,6 +118,8 @@ test_that("filter() on date32 columns", {
1L
)
+ skip_if_not_available("re2")
+
# Also with timestamp scalar
expect_equal(
open_dataset(tmp) %>%
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index 5c826d6dff..3bcdd8bcde 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -689,6 +689,92 @@ test_that("map_batches", {
)
})
+test_that("map_batches with explicit schema", {
+ fun_with_dots <- function(batch, first_col, first_col_val) {
+ record_batch(
+ !! first_col := first_col_val,
+ b = batch$a$cast(float64())
+ )
+ }
+
+ empty_reader <- RecordBatchReader$create(
+ batches = list(),
+ schema = schema(a = int32())
+ )
+ expect_equal(
+ map_batches(
+ empty_reader,
+ fun_with_dots,
+ "first_col_name",
+ "first_col_value",
+ .schema = schema(first_col_name = string(), b = float64())
+ )$read_table(),
+ arrow_table(first_col_name = character(), b = double())
+ )
+
+ reader <- RecordBatchReader$create(
+ batches = list(
+ record_batch(a = 1, b = "two"),
+ record_batch(a = 2, b = "three")
+ )
+ )
+ expect_equal(
+ map_batches(
+ reader,
+ fun_with_dots,
+ "first_col_name",
+ "first_col_value",
+ .schema = schema(first_col_name = string(), b = float64())
+ )$read_table(),
+ arrow_table(
+ first_col_name = c("first_col_value", "first_col_value"),
+ b = as.numeric(1:2)
+ )
+ )
+})
+
+test_that("map_batches without explicit schema", {
+ fun_with_dots <- function(batch, first_col, first_col_val) {
+ record_batch(
+ !! first_col := first_col_val,
+ b = batch$a$cast(float64())
+ )
+ }
+
+ empty_reader <- RecordBatchReader$create(
+ batches = list(),
+ schema = schema(a = int32())
+ )
+ expect_error(
+ map_batches(
+ empty_reader,
+ fun_with_dots,
+ "first_col_name",
+ "first_col_value"
+ )$read_table(),
+ "Can't infer schema"
+ )
+
+ reader <- RecordBatchReader$create(
+ batches = list(
+ record_batch(a = 1, b = "two"),
+ record_batch(a = 2, b = "three")
+ )
+ )
+ expect_equal(
+ map_batches(
+ reader,
+ fun_with_dots,
+ "first_col_name",
+ "first_col_value"
+ )$read_table(),
+ arrow_table(
+ first_col_name = c("first_col_value", "first_col_value"),
+ b = as.numeric(1:2)
+ )
+ )
+})
+
test_that("head/tail", {
# head/tail with no query are still deterministic order
ds <- open_dataset(dataset_dir)
diff --git a/r/tests/testthat/test-record-batch-reader.R b/r/tests/testthat/test-record-batch-reader.R
index 597187da45..3cd856de66 100644
--- a/r/tests/testthat/test-record-batch-reader.R
+++ b/r/tests/testthat/test-record-batch-reader.R
@@ -236,3 +236,36 @@ test_that("as_record_batch_reader() works for data.frame", {
reader <- as_record_batch_reader(df)
expect_equal(reader$read_next_batch(), record_batch(a = 1, b = "two"))
})
+
+test_that("as_record_batch_reader() works for function", {
+ batches <- list(
+ record_batch(a = 1, b = "two"),
+ record_batch(a = 2, b = "three")
+ )
+
+ i <- 0
+ fun <- function() {
+ i <<- i + 1
+ if (i > length(batches)) NULL else batches[[i]]
+ }
+
+ reader <- as_record_batch_reader(fun, schema = batches[[1]]$schema)
+ expect_equal(reader$read_next_batch(), batches[[1]])
+ expect_equal(reader$read_next_batch(), batches[[2]])
+ expect_null(reader$read_next_batch())
+
+ # check invalid returns
+ fun_bad_type <- function() "not a record batch"
+ reader <- as_record_batch_reader(fun_bad_type, schema = schema())
+ expect_error(
+ reader$read_next_batch(),
+ "Expected fun\\(\\) to return an arrow::RecordBatch"
+ )
+
+ fun_bad_schema <- function() record_batch(a = 1)
+  reader <- as_record_batch_reader(fun_bad_schema, schema = schema(a = string()))
+ expect_error(
+ reader$read_next_batch(),
+ "Expected fun\\(\\) to return batch with schema 'a: string'"
+ )
+})