This is an automated email from the ASF dual-hosted git repository.

npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 838687178f ARROW-15260: [R] open_dataset - add file_name as column (#12826)
838687178f is described below

commit 838687178fda7f82e31668f502e2f94071ce8077
Author: Nic Crane <[email protected]>
AuthorDate: Wed Aug 10 01:19:40 2022 +0100

    ARROW-15260: [R] open_dataset - add file_name as column (#12826)
    
    Authored-by: Nic Crane <[email protected]>
    Signed-off-by: Neal Richardson <[email protected]>
---
 r/DESCRIPTION                   |  1 +
 r/R/dataset.R                   |  1 +
 r/R/dplyr-collect.R             | 11 +++++
 r/R/dplyr-funcs-augmented.R     | 22 ++++++++++
 r/R/dplyr-funcs.R               |  1 +
 r/R/dplyr.R                     |  3 ++
 r/R/util.R                      | 31 +++++++++++++-
 r/src/compute-exec.cpp          |  8 ++--
 r/tests/testthat/test-dataset.R | 94 ++++++++++++++++++++++++++++++++++++++++-
 9 files changed, 164 insertions(+), 8 deletions(-)

diff --git a/r/DESCRIPTION b/r/DESCRIPTION
index 308a7ec3fa..95c1405869 100644
--- a/r/DESCRIPTION
+++ b/r/DESCRIPTION
@@ -98,6 +98,7 @@ Collate:
     'dplyr-distinct.R'
     'dplyr-eval.R'
     'dplyr-filter.R'
+    'dplyr-funcs-augmented.R'
     'dplyr-funcs-conditional.R'
     'dplyr-funcs-datetime.R'
     'dplyr-funcs-math.R'
diff --git a/r/R/dataset.R b/r/R/dataset.R
index 12765fbfc0..d86962cc1d 100644
--- a/r/R/dataset.R
+++ b/r/R/dataset.R
@@ -224,6 +224,7 @@ open_dataset <- function(sources,
     # and not handle_parquet_io_error()
     error = function(e, call = caller_env(n = 4)) {
       handle_parquet_io_error(e, format, call)
+      abort(conditionMessage(e), call = call)
     }
   )
 }
diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R
index 3e83475a8c..8049e46eb5 100644
--- a/r/R/dplyr-collect.R
+++ b/r/R/dplyr-collect.R
@@ -25,6 +25,8 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
     # and not handle_csv_read_error()
     error = function(e, call = caller_env(n = 4)) {
       handle_csv_read_error(e, x$.data$schema, call)
+      handle_augmented_field_misuse(e, call)
+      abort(conditionMessage(e), call = call)
     }
   )
 
@@ -104,10 +106,18 @@ add_suffix <- function(fields, common_cols, suffix) {
 }
 
 implicit_schema <- function(.data) {
+  # Get the source data schema so that we can evaluate expressions to determine
+  # the output schema. Note that we don't use source_data() because we only
+  # want to go one level up (where we may have called implicit_schema() before)
   .data <- ensure_group_vars(.data)
   old_schm <- .data$.data$schema
+  # Add in any augmented fields that may exist in the query but not in the
+  # real data, in case we have FieldRefs to them
+  old_schm[["__filename"]] <- string()
 
   if (is.null(.data$aggregations)) {
+    # .data$selected_columns is a named list of Expressions (FieldRefs or
+    # something more complex). Bind them in order to determine their output type
     new_fields <- map(.data$selected_columns, ~ .$type(old_schm))
     if (!is.null(.data$join) && !(.data$join$type %in% JoinType[1:4])) {
       # Add cols from right side, except for semi/anti joins
@@ -128,6 +138,7 @@ implicit_schema <- function(.data) {
       new_fields <- c(left_fields, right_fields)
     }
   } else {
+    # The output schema is based on the aggregations and any group_by vars
     new_fields <- map(summarize_projection(.data), ~ .$type(old_schm))
     # * Put group_by_vars first (this can't be done by summarize,
     #   they have to be last per the aggregate node signature,
diff --git a/r/R/dplyr-funcs-augmented.R b/r/R/dplyr-funcs-augmented.R
new file mode 100644
index 0000000000..6e751d49f6
--- /dev/null
+++ b/r/R/dplyr-funcs-augmented.R
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+register_bindings_augmented <- function() {
+  register_binding("add_filename", function() {
+    Expression$field_ref("__filename")
+  })
+}
diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R
index c1dcdd1774..4dadff54b4 100644
--- a/r/R/dplyr-funcs.R
+++ b/r/R/dplyr-funcs.R
@@ -151,6 +151,7 @@ create_binding_cache <- function() {
   register_bindings_math()
   register_bindings_string()
   register_bindings_type()
+  register_bindings_augmented()
 
   # We only create the cache for nse_funcs and not agg_funcs
   .cache$functions <- c(as.list(nse_funcs), arrow_funcs)
diff --git a/r/R/dplyr.R b/r/R/dplyr.R
index dd6340c4f5..dffe269199 100644
--- a/r/R/dplyr.R
+++ b/r/R/dplyr.R
@@ -110,6 +110,9 @@ make_field_refs <- function(field_names) {
 #' @export
 print.arrow_dplyr_query <- function(x, ...) {
   schm <- x$.data$schema
+  # If we are using this augmented field, it won't be in the schema
+  schm[["__filename"]] <- string()
+
   types <- map_chr(x$selected_columns, function(expr) {
     name <- expr$field_name
     if (nzchar(name)) {
diff --git a/r/R/util.R b/r/R/util.R
index 55ff29db73..eef69d0244 100644
--- a/r/R/util.R
+++ b/r/R/util.R
@@ -134,6 +134,10 @@ read_compressed_error <- function(e) {
   stop(e)
 }
 
+# This function was refactored in ARROW-15260 to only raise an error if
+# the appropriate string is found, so errors must be raised manually after
+# calling this if no matching error is found
+# TODO: Refactor as part of ARROW-17355 to prevent potential missed errors
 handle_parquet_io_error <- function(e, format, call) {
   msg <- conditionMessage(e)
   if (grepl("Parquet magic bytes not found in footer", msg) && length(format) > 1 && is_character(format)) {
@@ -143,8 +147,8 @@ handle_parquet_io_error <- function(e, format, call) {
       msg,
       i = "Did you mean to specify a 'format' other than the default (parquet)?"
     )
+    abort(msg, call = call)
   }
-  abort(msg, call = call)
 }
 
 as_writable_table <- function(x) {
@@ -205,6 +209,10 @@ repeat_value_as_array <- function(object, n) {
   return(Scalar$create(object)$as_array(n))
 }
 
+# This function was refactored in ARROW-15260 to only raise an error if
+# the appropriate string is found, so errors must be raised manually after
+# calling this if no matching error is found
+# TODO: Refactor as part of ARROW-17355 to prevent potential missed errors
 handle_csv_read_error <- function(e, schema, call) {
   msg <- conditionMessage(e)
 
@@ -217,8 +225,27 @@ handle_csv_read_error <- function(e, schema, call) {
         "header being read in as data."
       )
     )
+    abort(msg, call = call)
+  }
+}
+
+# This function only raises an error if
+# the appropriate string is found, so errors must be raised manually after
+# calling this if no matching error is found
+# TODO: Refactor as part of ARROW-17355 to prevent potential missed errors
+handle_augmented_field_misuse <- function(e, call) {
+  msg <- conditionMessage(e)
+  if (grepl("No match for FieldRef.Name(__filename)", msg, fixed = TRUE)) {
+    msg <- c(
+      msg,
+      i = paste(
+        "`add_filename()` or use of the `__filename` augmented field can only",
+        "be used with with Dataset objects, and can only be added before doing",
+        "an aggregation or a join."
+      )
+    )
+    abort(msg, call = call)
   }
-  abort(msg, call = call)
 }
 
 is_compressed <- function(compression) {
diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp
index 91d646f0a3..f9183a3a10 100644
--- a/r/src/compute-exec.cpp
+++ b/r/src/compute-exec.cpp
@@ -222,8 +222,7 @@ std::shared_ptr<compute::ExecNode> ExecNode_Scan(
 
   options->dataset_schema = dataset->schema();
 
-  // ScanNode needs the filter to do predicate pushdown and skip partitions
-  options->filter = ValueOrStop(filter->Bind(*dataset->schema()));
+  options->filter = *filter;
 
   // ScanNode needs to know which fields to materialize (and which are unnecessary)
   std::vector<compute::Expression> exprs;
@@ -232,9 +231,8 @@ std::shared_ptr<compute::ExecNode> ExecNode_Scan(
   }
 
   options->projection =
-      ValueOrStop(call("make_struct", std::move(exprs),
-                       compute::MakeStructOptions{std::move(materialized_field_names)})
-                      .Bind(*dataset->schema()));
+      call("make_struct", std::move(exprs),
+           compute::MakeStructOptions{std::move(materialized_field_names)});
 
   return MakeExecNodeOrStop("scan", plan.get(), {},
                             ds::ScanNodeOptions{dataset, options});
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index d43bb492d0..d9512ef94f 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -1131,7 +1131,6 @@ test_that("dataset to C-interface to arrow_dplyr_query with proj/filter", {
   delete_arrow_array_stream(stream_ptr)
 })
 
-
 test_that("Filter parquet dataset with is.na ARROW-15312", {
   ds_path <- make_temp_dir()
 
@@ -1349,3 +1348,96 @@ test_that("FileSystemFactoryOptions input validation", {
     fixed = TRUE
   )
 })
+
+test_that("can add in augmented fields", {
+  ds <- open_dataset(hive_dir)
+
+  observed <- ds %>%
+    mutate(file_name = add_filename()) %>%
+    collect()
+
+  expect_named(
+    observed,
+    c("int", "dbl", "lgl", "chr", "fct", "ts", "group", "other", "file_name")
+  )
+
+  expect_equal(
+    sort(unique(observed$file_name)),
+    list.files(hive_dir, full.names = TRUE, recursive = TRUE)
+  )
+
+  error_regex <- paste(
+    "`add_filename()` or use of the `__filename` augmented field can only",
+    "be used with with Dataset objects, and can only be added before doing",
+    "an aggregation or a join."
+  )
+
+  # errors appropriately with ArrowTabular objects
+  expect_error(
+    arrow_table(mtcars) %>%
+      mutate(file = add_filename()) %>%
+      collect(),
+    regexp = error_regex,
+    fixed = TRUE
+  )
+
+  # errors appropriately with aggregation
+  expect_error(
+    ds %>%
+      summarise(max_int = max(int)) %>%
+      mutate(file_name = add_filename()) %>%
+      collect(),
+    regexp = error_regex,
+    fixed = TRUE
+  )
+
+  # joins to tables
+  another_table <- select(example_data, int, dbl2)
+  expect_error(
+    ds %>%
+      left_join(another_table, by = "int") %>%
+      mutate(file = add_filename()) %>%
+      collect(),
+    regexp = error_regex,
+    fixed = TRUE
+  )
+
+  # and on joins to datasets
+  another_dataset <- write_dataset(another_table, "another_dataset")
+  expect_error(
+    ds %>%
+      left_join(open_dataset("another_dataset"), by = "int") %>%
+      mutate(file = add_filename()) %>%
+      collect(),
+    regexp = error_regex,
+    fixed = TRUE
+  )
+
+  # this hits the implicit_schema path by joining afterwards
+  join_after <- ds %>%
+    mutate(file = add_filename()) %>%
+    left_join(open_dataset("another_dataset"), by = "int") %>%
+    collect()
+
+  expect_named(
+    join_after,
+    c("int", "dbl", "lgl", "chr", "fct", "ts", "group", "other", "file", "dbl2")
+  )
+
+  expect_equal(
+    sort(unique(join_after$file)),
+    list.files(hive_dir, full.names = TRUE, recursive = TRUE)
+  )
+
+  # another test on the explicit_schema path
+  summarise_after <- ds %>%
+    mutate(file = add_filename()) %>%
+    group_by(file) %>%
+    summarise(max_int = max(int)) %>%
+    collect()
+
+  expect_equal(
+    sort(summarise_after$file),
+    list.files(hive_dir, full.names = TRUE, recursive = TRUE)
+  )
+})

Reply via email to