This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 70904dffef ARROW-16703: [R] Refactor map_batches() so it can stream results (#13650)
70904dffef is described below

commit 70904dffef25a8c883a1a829c66a1d30a7d9c249
Author: Dewey Dunnington <[email protected]>
AuthorDate: Sat Jul 23 09:09:44 2022 -0300

    ARROW-16703: [R] Refactor map_batches() so it can stream results (#13650)
    
    Authored-by: Dewey Dunnington <[email protected]>
    Signed-off-by: Dewey Dunnington <[email protected]>
---
 r/NAMESPACE                                 |  1 +
 r/R/arrowExports.R                          |  4 ++
 r/R/dataset-scan.R                          | 68 ++++++++++++++++++-----
 r/R/record-batch-reader.R                   |  9 +++
 r/man/as_record_batch_reader.Rd             |  6 ++
 r/man/map_batches.Rd                        |  9 ++-
 r/src/arrowExports.cpp                      | 10 ++++
 r/src/recordbatchreader.cpp                 | 45 +++++++++++++++
 r/src/safe-call-into-r-impl.cpp             |  2 +-
 r/src/safe-call-into-r.h                    |  4 +-
 r/tests/testthat/test-dataset-dplyr.R       |  4 ++
 r/tests/testthat/test-dataset.R             | 86 +++++++++++++++++++++++++++++
 r/tests/testthat/test-record-batch-reader.R | 33 +++++++++++
 13 files changed, 264 insertions(+), 17 deletions(-)

diff --git a/r/NAMESPACE b/r/NAMESPACE
index a8c8a974d6..733261f33c 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -67,6 +67,7 @@ S3method(as_record_batch,arrow_dplyr_query)
 S3method(as_record_batch,data.frame)
 S3method(as_record_batch,pyarrow.lib.RecordBatch)
 S3method(as_record_batch,pyarrow.lib.Table)
+S3method(as_record_batch_reader,"function")
 S3method(as_record_batch_reader,Dataset)
 S3method(as_record_batch_reader,RecordBatch)
 S3method(as_record_batch_reader,RecordBatchReader)
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 7cd2c5dbfc..ab3358d666 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -1740,6 +1740,10 @@ RecordBatchReader__from_batches <- function(batches, schema_sxp) {
   .Call(`_arrow_RecordBatchReader__from_batches`, batches, schema_sxp)
 }
 
+RecordBatchReader__from_function <- function(fun_sexp, schema) {
+  .Call(`_arrow_RecordBatchReader__from_function`, fun_sexp, schema)
+}
+
 RecordBatchReader__from_Table <- function(table) {
   .Call(`_arrow_RecordBatchReader__from_Table`, table)
 }
diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R
index cca92b676f..53fe7078c2 100644
--- a/r/R/dataset-scan.R
+++ b/r/R/dataset-scan.R
@@ -183,11 +183,16 @@ tail_from_batches <- function(batches, n) {
 #' @param FUN A function or `purrr`-style lambda expression to apply to each
 #' batch. It must return a RecordBatch or something coercible to one via
 #' `as_record_batch()'.
+#' @param .schema An optional [schema()]. If NULL, the schema will be inferred
+#'   from the first batch.
+#' @param .lazy Use `TRUE` to evaluate `FUN` lazily as batches are read from
+#'   the result; use `FALSE` to evaluate `FUN` on all batches before returning
+#'   the reader.
 #' @param ... Additional arguments passed to `FUN`
 #' @param .data.frame Deprecated argument, ignored
 #' @return An `arrow_dplyr_query`.
 #' @export
-map_batches <- function(X, FUN, ..., .data.frame = NULL) {
+map_batches <- function(X, FUN, ..., .schema = NULL, .lazy = FALSE, .data.frame = NULL) {
   if (!is.null(.data.frame)) {
     warning(
       "The .data.frame argument is deprecated. ",
@@ -197,25 +202,60 @@ map_batches <- function(X, FUN, ..., .data.frame = NULL) {
   }
   FUN <- as_mapper(FUN)
   reader <- as_record_batch_reader(X)
+  dots <- rlang::list2(...)
 
-  # TODO: for future consideration
-  # * Move eval to C++ and make it a generator so it can stream, not block
-  # * Accept an output schema argument: with that, we could make this lazy (via collapse)
-  batch <- reader$read_next_batch()
-  res <- vector("list", 1024)
-  i <- 0L
-  while (!is.null(batch)) {
-    i <- i + 1L
-    res[[i]] <- as_record_batch(FUN(batch, ...))
+  # If no schema is supplied, we have to evaluate the first batch here
+  if (is.null(.schema)) {
     batch <- reader$read_next_batch()
+    if (is.null(batch)) {
+      abort("Can't infer schema from a RecordBatchReader with zero batches")
+    }
+
+    first_result <- as_record_batch(do.call(FUN, c(list(batch), dots)))
+    .schema <- first_result$schema
+    fun <- function() {
+      if (!is.null(first_result)) {
+        result <- first_result
+        first_result <<- NULL
+        result
+      } else {
+        batch <- reader$read_next_batch()
+        if (is.null(batch)) {
+          NULL
+        } else {
+          as_record_batch(
+            do.call(FUN, c(list(batch), dots)),
+            schema = .schema
+          )
+        }
+      }
+    }
+  } else {
+    fun <- function() {
+      batch <- reader$read_next_batch()
+      if (is.null(batch)) {
+        return(NULL)
+      }
+
+      as_record_batch(
+        do.call(FUN, c(list(batch), dots)),
+        schema = .schema
+      )
+    }
   }
 
-  # Trim list back
-  if (i < length(res)) {
-    res <- res[seq_len(i)]
+  reader_out <- as_record_batch_reader(fun, schema = .schema)
+
+  # TODO(ARROW-17178) because there are some restrictions on evaluating
+  # reader_out in some ExecPlans, the default .lazy is FALSE for now.
+  if (!.lazy) {
+    reader_out <- RecordBatchReader$create(
+      batches = reader_out$batches(),
+      schema = .schema
+    )
   }
 
-  RecordBatchReader$create(batches = res)
+  reader_out
 }
 
 #' @usage NULL
diff --git a/r/R/record-batch-reader.R b/r/R/record-batch-reader.R
index 8f6a600dfb..3a985d8abc 100644
--- a/r/R/record-batch-reader.R
+++ b/r/R/record-batch-reader.R
@@ -191,6 +191,8 @@ RecordBatchFileReader$create <- function(file) {
 #' Convert an object to an Arrow RecordBatchReader
 #'
 #' @param x An object to convert to a [RecordBatchReader]
+#' @param schema The [schema()] that must match the schema returned by each
+#'   call to `x` when `x` is a function.
 #' @param ... Passed to S3 methods
 #'
 #' @return A [RecordBatchReader]
@@ -234,6 +236,13 @@ as_record_batch_reader.Dataset <- function(x, ...) {
   Scanner$create(x)$ToRecordBatchReader()
 }
 
+#' @rdname as_record_batch_reader
+#' @export
+as_record_batch_reader.function <- function(x, ..., schema) {
+  assert_that(inherits(schema, "Schema"))
+  RecordBatchReader__from_function(x, schema)
+}
+
 #' @rdname as_record_batch_reader
 #' @export
 as_record_batch_reader.arrow_dplyr_query <- function(x, ...) {
diff --git a/r/man/as_record_batch_reader.Rd b/r/man/as_record_batch_reader.Rd
index e635c0b98b..2ed5435476 100644
--- a/r/man/as_record_batch_reader.Rd
+++ b/r/man/as_record_batch_reader.Rd
@@ -7,6 +7,7 @@
 \alias{as_record_batch_reader.RecordBatch}
 \alias{as_record_batch_reader.data.frame}
 \alias{as_record_batch_reader.Dataset}
+\alias{as_record_batch_reader.function}
 \alias{as_record_batch_reader.arrow_dplyr_query}
 \alias{as_record_batch_reader.Scanner}
 \title{Convert an object to an Arrow RecordBatchReader}
@@ -23,6 +24,8 @@ as_record_batch_reader(x, ...)
 
 \method{as_record_batch_reader}{Dataset}(x, ...)
 
+\method{as_record_batch_reader}{`function`}(x, ..., schema)
+
 \method{as_record_batch_reader}{arrow_dplyr_query}(x, ...)
 
 \method{as_record_batch_reader}{Scanner}(x, ...)
@@ -31,6 +34,9 @@ as_record_batch_reader(x, ...)
 \item{x}{An object to convert to a \link{RecordBatchReader}}
 
 \item{...}{Passed to S3 methods}
+
+\item{schema}{The \code{\link[=schema]{schema()}} that must match the schema returned by each
+call to \code{x} when \code{x} is a function.}
 }
 \value{
 A \link{RecordBatchReader}
diff --git a/r/man/map_batches.Rd b/r/man/map_batches.Rd
index eaeab6013a..0e4d48e024 100644
--- a/r/man/map_batches.Rd
+++ b/r/man/map_batches.Rd
@@ -4,7 +4,7 @@
 \alias{map_batches}
 \title{Apply a function to a stream of RecordBatches}
 \usage{
-map_batches(X, FUN, ..., .data.frame = NULL)
+map_batches(X, FUN, ..., .schema = NULL, .lazy = FALSE, .data.frame = NULL)
 }
 \arguments{
 \item{X}{A \code{Dataset} or \code{arrow_dplyr_query} object, as returned by the
@@ -16,6 +16,13 @@ batch. It must return a RecordBatch or something coercible to one via
 
 \item{...}{Additional arguments passed to \code{FUN}}
 
+\item{.schema}{An optional \code{\link[=schema]{schema()}}. If NULL, the schema will be inferred
+from the first batch.}
+
+\item{.lazy}{Use \code{TRUE} to evaluate \code{FUN} lazily as batches are read from
+the result; use \code{FALSE} to evaluate \code{FUN} on all batches before returning
+the reader.}
+
 \item{.data.frame}{Deprecated argument, ignored}
 }
 \value{
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index dc96af41d4..adb6636e9e 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -4491,6 +4491,15 @@ BEGIN_CPP11
 END_CPP11
 }
 // recordbatchreader.cpp
+std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_function(cpp11::sexp fun_sexp, const std::shared_ptr<arrow::Schema>& schema);
+extern "C" SEXP _arrow_RecordBatchReader__from_function(SEXP fun_sexp_sexp, SEXP schema_sexp){
+BEGIN_CPP11
+       arrow::r::Input<cpp11::sexp>::type fun_sexp(fun_sexp_sexp);
+       arrow::r::Input<const std::shared_ptr<arrow::Schema>&>::type schema(schema_sexp);
+       return cpp11::as_sexp(RecordBatchReader__from_function(fun_sexp, schema));
+END_CPP11
+}
+// recordbatchreader.cpp
 std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_Table(const std::shared_ptr<arrow::Table>& table);
 extern "C" SEXP _arrow_RecordBatchReader__from_Table(SEXP table_sexp){
 BEGIN_CPP11
@@ -5611,6 +5620,7 @@ static const R_CallMethodDef CallEntries[] = {
                { "_arrow_RecordBatchReader__ReadNext", (DL_FUNC) &_arrow_RecordBatchReader__ReadNext, 1}, 
                { "_arrow_RecordBatchReader__batches", (DL_FUNC) &_arrow_RecordBatchReader__batches, 1}, 
                { "_arrow_RecordBatchReader__from_batches", (DL_FUNC) &_arrow_RecordBatchReader__from_batches, 2}, 
+               { "_arrow_RecordBatchReader__from_function", (DL_FUNC) &_arrow_RecordBatchReader__from_function, 2}, 
                { "_arrow_RecordBatchReader__from_Table", (DL_FUNC) &_arrow_RecordBatchReader__from_Table, 1}, 
                { "_arrow_Table__from_RecordBatchReader", (DL_FUNC) &_arrow_Table__from_RecordBatchReader, 1}, 
                { "_arrow_RecordBatchReader__Head", (DL_FUNC) &_arrow_RecordBatchReader__Head, 2}, 
diff --git a/r/src/recordbatchreader.cpp b/r/src/recordbatchreader.cpp
index fb173825f3..c571d282da 100644
--- a/r/src/recordbatchreader.cpp
+++ b/r/src/recordbatchreader.cpp
@@ -16,6 +16,7 @@
 // under the License.
 
 #include "./arrow_types.h"
+#include "./safe-call-into-r.h"
 
 #include <arrow/ipc/reader.h>
 #include <arrow/table.h>
@@ -54,6 +55,50 @@ std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_batches(
   }
 }
 
+class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
+ public:
+  RFunctionRecordBatchReader(cpp11::sexp fun,
+                             const std::shared_ptr<arrow::Schema>& schema)
+      : fun_(fun), schema_(schema) {}
+
+  std::shared_ptr<arrow::Schema> schema() const { return schema_; }
+
+  arrow::Status ReadNext(std::shared_ptr<arrow::RecordBatch>* batch_out) {
+    auto batch = SafeCallIntoR<std::shared_ptr<arrow::RecordBatch>>([&]() {
+      cpp11::sexp result_sexp = fun_();
+      if (result_sexp == R_NilValue) {
+        return std::shared_ptr<arrow::RecordBatch>(nullptr);
+      } else if (!Rf_inherits(result_sexp, "RecordBatch")) {
+        cpp11::stop("Expected fun() to return an arrow::RecordBatch");
+      }
+
+      return cpp11::as_cpp<std::shared_ptr<arrow::RecordBatch>>(result_sexp);
+    });
+
+    RETURN_NOT_OK(batch);
+
+    if (batch.ValueUnsafe().get() != nullptr &&
+        !batch.ValueUnsafe()->schema()->Equals(schema_)) {
+      return arrow::Status::Invalid("Expected fun() to return batch with schema '",
+                                    schema_->ToString(), "' but got batch with schema '",
+                                    batch.ValueUnsafe()->schema()->ToString(), "'");
+    }
+
+    *batch_out = batch.ValueUnsafe();
+    return arrow::Status::OK();
+  }
+
+ private:
+  cpp11::function fun_;
+  std::shared_ptr<arrow::Schema> schema_;
+};
+
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_function(
+    cpp11::sexp fun_sexp, const std::shared_ptr<arrow::Schema>& schema) {
+  return std::make_shared<RFunctionRecordBatchReader>(fun_sexp, schema);
+}
+
 // [[arrow::export]]
 std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_Table(
     const std::shared_ptr<arrow::Table>& table) {
diff --git a/r/src/safe-call-into-r-impl.cpp b/r/src/safe-call-into-r-impl.cpp
index 7318c81bb5..4eec3a85df 100644
--- a/r/src/safe-call-into-r-impl.cpp
+++ b/r/src/safe-call-into-r-impl.cpp
@@ -38,7 +38,7 @@ bool CanRunWithCapturedR() {
     on_old_windows = on_old_windows_fun();
   }
 
-  return !on_old_windows;
+  return !on_old_windows && GetMainRThread().Executor() == nullptr;
 #else
   return false;
 #endif
diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h
index 937163a05d..08e8a8c11b 100644
--- a/r/src/safe-call-into-r.h
+++ b/r/src/safe-call-into-r.h
@@ -31,7 +31,9 @@
 // and crash R in older versions (ARROW-16201). Crashes also occur
 // on 32-bit R builds on R 3.6 and lower. Implementation provided
 // in safe-call-into-r-impl.cpp so that we can skip some tests
-// when this feature is not provided.
+// when this feature is not provided. This also checks that there
+// is not already an event loop registered (via MainRThread::Executor()),
+// because only one of these can exist at any given time.
 bool CanRunWithCapturedR();
 
 // The MainRThread class keeps track of the thread on which it is safe
diff --git a/r/tests/testthat/test-dataset-dplyr.R b/r/tests/testthat/test-dataset-dplyr.R
index b6982939ee..b09b549d59 100644
--- a/r/tests/testthat/test-dataset-dplyr.R
+++ b/r/tests/testthat/test-dataset-dplyr.R
@@ -70,6 +70,8 @@ test_that("filter() with %in%", {
 })
 
 test_that("filter() on timestamp columns", {
+  skip_if_not_available("re2")
+
   ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8()))
   expect_equal(
     ds %>%
@@ -116,6 +118,8 @@ test_that("filter() on date32 columns", {
     1L
   )
 
+  skip_if_not_available("re2")
+
   # Also with timestamp scalar
   expect_equal(
     open_dataset(tmp) %>%
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index 5c826d6dff..3bcdd8bcde 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -689,6 +689,92 @@ test_that("map_batches", {
   )
 })
 
+test_that("map_batches with explicit schema", {
+  fun_with_dots <- function(batch, first_col, first_col_val) {
+    record_batch(
+      !! first_col := first_col_val,
+      b = batch$a$cast(float64())
+    )
+  }
+
+  empty_reader <- RecordBatchReader$create(
+    batches = list(),
+    schema = schema(a = int32())
+  )
+  expect_equal(
+    map_batches(
+      empty_reader,
+      fun_with_dots,
+      "first_col_name",
+      "first_col_value",
+      .schema = schema(first_col_name = string(), b = float64())
+    )$read_table(),
+    arrow_table(first_col_name = character(), b = double())
+  )
+
+  reader <- RecordBatchReader$create(
+    batches = list(
+      record_batch(a = 1, b = "two"),
+      record_batch(a = 2, b = "three")
+    )
+  )
+  expect_equal(
+    map_batches(
+      reader,
+      fun_with_dots,
+      "first_col_name",
+      "first_col_value",
+      .schema = schema(first_col_name = string(), b = float64())
+    )$read_table(),
+    arrow_table(
+      first_col_name = c("first_col_value", "first_col_value"),
+      b = as.numeric(1:2)
+    )
+  )
+})
+
+test_that("map_batches without explicit schema", {
+  fun_with_dots <- function(batch, first_col, first_col_val) {
+    record_batch(
+      !! first_col := first_col_val,
+      b = batch$a$cast(float64())
+    )
+  }
+
+  empty_reader <- RecordBatchReader$create(
+    batches = list(),
+    schema = schema(a = int32())
+  )
+  expect_error(
+    map_batches(
+      empty_reader,
+      fun_with_dots,
+      "first_col_name",
+      "first_col_value"
+    )$read_table(),
+    "Can't infer schema"
+  )
+
+  reader <- RecordBatchReader$create(
+    batches = list(
+      record_batch(a = 1, b = "two"),
+      record_batch(a = 2, b = "three")
+    )
+  )
+  expect_equal(
+    map_batches(
+      reader,
+      fun_with_dots,
+      "first_col_name",
+      "first_col_value"
+    )$read_table(),
+    arrow_table(
+      first_col_name = c("first_col_value", "first_col_value"),
+      b = as.numeric(1:2)
+    )
+  )
+})
+
 test_that("head/tail", {
   # head/tail with no query are still deterministic order
   ds <- open_dataset(dataset_dir)
diff --git a/r/tests/testthat/test-record-batch-reader.R b/r/tests/testthat/test-record-batch-reader.R
index 597187da45..3cd856de66 100644
--- a/r/tests/testthat/test-record-batch-reader.R
+++ b/r/tests/testthat/test-record-batch-reader.R
@@ -236,3 +236,36 @@ test_that("as_record_batch_reader() works for data.frame", {
   reader <- as_record_batch_reader(df)
   expect_equal(reader$read_next_batch(), record_batch(a = 1, b = "two"))
 })
+
+test_that("as_record_batch_reader() works for function", {
+  batches <- list(
+    record_batch(a = 1, b = "two"),
+    record_batch(a = 2, b = "three")
+  )
+
+  i <- 0
+  fun <- function() {
+    i <<- i + 1
+    if (i > length(batches)) NULL else batches[[i]]
+  }
+
+  reader <- as_record_batch_reader(fun, schema = batches[[1]]$schema)
+  expect_equal(reader$read_next_batch(), batches[[1]])
+  expect_equal(reader$read_next_batch(), batches[[2]])
+  expect_null(reader$read_next_batch())
+
+  # check invalid returns
+  fun_bad_type <- function() "not a record batch"
+  reader <- as_record_batch_reader(fun_bad_type, schema = schema())
+  expect_error(
+    reader$read_next_batch(),
+    "Expected fun\\(\\) to return an arrow::RecordBatch"
+  )
+
+  fun_bad_schema <- function() record_batch(a = 1)
+  reader <- as_record_batch_reader(fun_bad_schema, schema = schema(a = string()))
+  expect_error(
+    reader$read_next_batch(),
+    "Expected fun\\(\\) to return batch with schema 'a: string'"
+  )
+})

Reply via email to