This is an automated email from the ASF dual-hosted git repository.

npr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new ea314a3f8d GH-41358: [R] Support join "na_matches" argument (#41372)
ea314a3f8d is described below

commit ea314a3f8d9d4446836aa999b66659c07421f7a4
Author: Neal Richardson <neal.p.richard...@gmail.com>
AuthorDate: Fri Apr 26 18:32:32 2024 -0400

    GH-41358: [R] Support join "na_matches" argument (#41372)
    
    ### Rationale for this change
    
    The lack of `na_matches` support in joins was noticed in #41350, and I
    opened #41358 to implement it in C++. It turns out the option already
    existed in Acero; it was just buried a bit.
    
    ### What changes are included in this PR?
    
    `na_matches` is mapped through to the `key_cmp` field in
    `HashJoinNodeOptions`. Acero supports having a different value for this
    for each of the join keys, but dplyr does not, so I kept it constant for
    all key columns to match the dplyr behavior.
    
    ### Are these changes tested?
    
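    Yes. New tests in `r/tests/testthat/test-dplyr-join.R` cover a left join
    with `na_matches = "na"` (the default) and with `na_matches = "never"`.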
    
    ### Are there any user-facing changes?
    
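    Yes. The `na_matches` argument to the `dplyr::*_join()` functions is now
    supported instead of being ignored; the NEWS entry and the documented
    list of supported dplyr methods are updated accordingly.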
    * GitHub Issue: #41358
---
 r/NEWS.md                          |  1 +
 r/R/arrow-package.R                | 12 ++++++------
 r/R/arrowExports.R                 |  4 ++--
 r/R/dplyr-funcs-doc.R              | 12 ++++++------
 r/R/dplyr-join.R                   |  8 +++++---
 r/R/query-engine.R                 |  8 +++++---
 r/man/acero.Rd                     | 12 ++++++------
 r/src/arrowExports.cpp             | 11 ++++++-----
 r/src/compute-exec.cpp             | 18 +++++++++++++-----
 r/tests/testthat/test-dplyr-join.R | 32 ++++++++++++++++++++++++++++++++
 10 files changed, 82 insertions(+), 36 deletions(-)

diff --git a/r/NEWS.md b/r/NEWS.md
index 4ed9f28a28..05f934dac6 100644
--- a/r/NEWS.md
+++ b/r/NEWS.md
@@ -21,6 +21,7 @@
 
 * R functions that users write that use functions that Arrow supports in 
dataset queries now can be used in queries too. Previously, only functions that 
used arithmetic operators worked. For example, `time_hours <- function(mins) 
mins / 60` worked, but `time_hours_rounded <- function(mins) round(mins / 60)` 
did not; now both work. These are automatic translations rather than true 
user-defined functions (UDFs); for UDFs, see `register_scalar_function()`. 
(#41223)
 * `summarize()` supports more complex expressions, and correctly handles cases 
where column names are reused in expressions. 
+* The `na_matches` argument to the `dplyr::*_join()` functions is now 
supported. This argument controls whether `NA` values are considered equal when 
joining. (#41358)
 
 # arrow 16.0.0
 
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index f6977e6262..7087a40c49 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -66,12 +66,12 @@ supported_dplyr_methods <- list(
   compute = NULL,
   collapse = NULL,
   distinct = "`.keep_all = TRUE` not supported",
-  left_join = "the `copy` and `na_matches` arguments are ignored",
-  right_join = "the `copy` and `na_matches` arguments are ignored",
-  inner_join = "the `copy` and `na_matches` arguments are ignored",
-  full_join = "the `copy` and `na_matches` arguments are ignored",
-  semi_join = "the `copy` and `na_matches` arguments are ignored",
-  anti_join = "the `copy` and `na_matches` arguments are ignored",
+  left_join = "the `copy` argument is ignored",
+  right_join = "the `copy` argument is ignored",
+  inner_join = "the `copy` argument is ignored",
+  full_join = "the `copy` argument is ignored",
+  semi_join = "the `copy` argument is ignored",
+  anti_join = "the `copy` argument is ignored",
   count = NULL,
   tally = NULL,
   rename_with = NULL,
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 752d3a266b..62e2182ffc 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -484,8 +484,8 @@ ExecNode_Aggregate <- function(input, options, key_names) {
   .Call(`_arrow_ExecNode_Aggregate`, input, options, key_names)
 }
 
-ExecNode_Join <- function(input, join_type, right_data, left_keys, right_keys, 
left_output, right_output, output_suffix_for_left, output_suffix_for_right) {
-  .Call(`_arrow_ExecNode_Join`, input, join_type, right_data, left_keys, 
right_keys, left_output, right_output, output_suffix_for_left, 
output_suffix_for_right)
+ExecNode_Join <- function(input, join_type, right_data, left_keys, right_keys, 
left_output, right_output, output_suffix_for_left, output_suffix_for_right, 
na_matches) {
+  .Call(`_arrow_ExecNode_Join`, input, join_type, right_data, left_keys, 
right_keys, left_output, right_output, output_suffix_for_left, 
output_suffix_for_right, na_matches)
 }
 
 ExecNode_Union <- function(input, right_data) {
diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R
index 2042f80014..fda77bca83 100644
--- a/r/R/dplyr-funcs-doc.R
+++ b/r/R/dplyr-funcs-doc.R
@@ -36,7 +36,7 @@
 #' which returns an `arrow` [Table], or `collect()`, which pulls the resulting
 #' Table into an R `tibble`.
 #'
-#' * [`anti_join()`][dplyr::anti_join()]: the `copy` and `na_matches` 
arguments are ignored
+#' * [`anti_join()`][dplyr::anti_join()]: the `copy` argument is ignored
 #' * [`arrange()`][dplyr::arrange()]
 #' * [`collapse()`][dplyr::collapse()]
 #' * [`collect()`][dplyr::collect()]
@@ -45,22 +45,22 @@
 #' * [`distinct()`][dplyr::distinct()]: `.keep_all = TRUE` not supported
 #' * [`explain()`][dplyr::explain()]
 #' * [`filter()`][dplyr::filter()]
-#' * [`full_join()`][dplyr::full_join()]: the `copy` and `na_matches` 
arguments are ignored
+#' * [`full_join()`][dplyr::full_join()]: the `copy` argument is ignored
 #' * [`glimpse()`][dplyr::glimpse()]
 #' * [`group_by()`][dplyr::group_by()]
 #' * [`group_by_drop_default()`][dplyr::group_by_drop_default()]
 #' * [`group_vars()`][dplyr::group_vars()]
 #' * [`groups()`][dplyr::groups()]
-#' * [`inner_join()`][dplyr::inner_join()]: the `copy` and `na_matches` 
arguments are ignored
-#' * [`left_join()`][dplyr::left_join()]: the `copy` and `na_matches` 
arguments are ignored
+#' * [`inner_join()`][dplyr::inner_join()]: the `copy` argument is ignored
+#' * [`left_join()`][dplyr::left_join()]: the `copy` argument is ignored
 #' * [`mutate()`][dplyr::mutate()]: window functions (e.g. things that require 
aggregation within groups) not currently supported
 #' * [`pull()`][dplyr::pull()]: the `name` argument is not supported; returns 
an R vector by default but this behavior is deprecated and will return an Arrow 
[ChunkedArray] in a future release. Provide `as_vector = TRUE/FALSE` to control 
this behavior, or set `options(arrow.pull_as_vector)` globally.
 #' * [`relocate()`][dplyr::relocate()]
 #' * [`rename()`][dplyr::rename()]
 #' * [`rename_with()`][dplyr::rename_with()]
-#' * [`right_join()`][dplyr::right_join()]: the `copy` and `na_matches` 
arguments are ignored
+#' * [`right_join()`][dplyr::right_join()]: the `copy` argument is ignored
 #' * [`select()`][dplyr::select()]
-#' * [`semi_join()`][dplyr::semi_join()]: the `copy` and `na_matches` 
arguments are ignored
+#' * [`semi_join()`][dplyr::semi_join()]: the `copy` argument is ignored
 #' * [`show_query()`][dplyr::show_query()]
 #' * [`slice_head()`][dplyr::slice_head()]: slicing within groups not 
supported; Arrow datasets do not have row order, so head is non-deterministic; 
`prop` only supported on queries where `nrow()` is knowable without evaluating
 #' * [`slice_max()`][dplyr::slice_max()]: slicing within groups not supported; 
`with_ties = TRUE` (dplyr default) is not supported; `prop` only supported on 
queries where `nrow()` is knowable without evaluating
diff --git a/r/R/dplyr-join.R b/r/R/dplyr-join.R
index 39237f574b..e76e041a54 100644
--- a/r/R/dplyr-join.R
+++ b/r/R/dplyr-join.R
@@ -25,14 +25,15 @@ do_join <- function(x,
                     suffix = c(".x", ".y"),
                     ...,
                     keep = FALSE,
-                    na_matches,
+                    na_matches = c("na", "never"),
                     join_type) {
   # TODO: handle `copy` arg: ignore?
-  # TODO: handle `na_matches` arg
   x <- as_adq(x)
   y <- as_adq(y)
   by <- handle_join_by(by, x, y)
 
+  na_matches <- match.arg(na_matches)
+
   # For outer joins, we need to output the join keys on both sides so we
   # can coalesce them afterwards.
   left_output <- if (!keep && join_type == "RIGHT_OUTER") {
@@ -54,7 +55,8 @@ do_join <- function(x,
     left_output = left_output,
     right_output = right_output,
     suffix = suffix,
-    keep = keep
+    keep = keep,
+    na_matches = na_matches == "na"
   )
   collapse.arrow_dplyr_query(x)
 }
diff --git a/r/R/query-engine.R b/r/R/query-engine.R
index 0f8a84f9b8..fb48d790fd 100644
--- a/r/R/query-engine.R
+++ b/r/R/query-engine.R
@@ -148,7 +148,8 @@ ExecPlan <- R6Class("ExecPlan",
             left_output = .data$join$left_output,
             right_output = .data$join$right_output,
             left_suffix = .data$join$suffix[[1]],
-            right_suffix = .data$join$suffix[[2]]
+            right_suffix = .data$join$suffix[[2]],
+            na_matches = .data$join$na_matches
           )
         }
 
@@ -307,7 +308,7 @@ ExecNode <- R6Class("ExecNode",
       out$extras$source_schema$metadata[["r"]]$attributes <- NULL
       out
     },
-    Join = function(type, right_node, by, left_output, right_output, 
left_suffix, right_suffix) {
+    Join = function(type, right_node, by, left_output, right_output, 
left_suffix, right_suffix, na_matches = TRUE) {
       self$preserve_extras(
         ExecNode_Join(
           self,
@@ -318,7 +319,8 @@ ExecNode <- R6Class("ExecNode",
           left_output = left_output,
           right_output = right_output,
           output_suffix_for_left = left_suffix,
-          output_suffix_for_right = right_suffix
+          output_suffix_for_right = right_suffix,
+          na_matches = na_matches
         )
       )
     },
diff --git a/r/man/acero.Rd b/r/man/acero.Rd
index 365795d9fc..ca51ef5633 100644
--- a/r/man/acero.Rd
+++ b/r/man/acero.Rd
@@ -23,7 +23,7 @@ the query on the data. To run the query, call either 
\code{compute()},
 which returns an \code{arrow} \link{Table}, or \code{collect()}, which pulls 
the resulting
 Table into an R \code{tibble}.
 \itemize{
-\item \code{\link[dplyr:filter-joins]{anti_join()}}: the \code{copy} and 
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:filter-joins]{anti_join()}}: the \code{copy} argument 
is ignored
 \item \code{\link[dplyr:arrange]{arrange()}}
 \item \code{\link[dplyr:compute]{collapse()}}
 \item \code{\link[dplyr:compute]{collect()}}
@@ -32,22 +32,22 @@ Table into an R \code{tibble}.
 \item \code{\link[dplyr:distinct]{distinct()}}: \code{.keep_all = TRUE} not 
supported
 \item \code{\link[dplyr:explain]{explain()}}
 \item \code{\link[dplyr:filter]{filter()}}
-\item \code{\link[dplyr:mutate-joins]{full_join()}}: the \code{copy} and 
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:mutate-joins]{full_join()}}: the \code{copy} argument 
is ignored
 \item \code{\link[dplyr:glimpse]{glimpse()}}
 \item \code{\link[dplyr:group_by]{group_by()}}
 \item \code{\link[dplyr:group_by_drop_default]{group_by_drop_default()}}
 \item \code{\link[dplyr:group_data]{group_vars()}}
 \item \code{\link[dplyr:group_data]{groups()}}
-\item \code{\link[dplyr:mutate-joins]{inner_join()}}: the \code{copy} and 
\code{na_matches} arguments are ignored
-\item \code{\link[dplyr:mutate-joins]{left_join()}}: the \code{copy} and 
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:mutate-joins]{inner_join()}}: the \code{copy} argument 
is ignored
+\item \code{\link[dplyr:mutate-joins]{left_join()}}: the \code{copy} argument 
is ignored
 \item \code{\link[dplyr:mutate]{mutate()}}: window functions (e.g. things that 
require aggregation within groups) not currently supported
 \item \code{\link[dplyr:pull]{pull()}}: the \code{name} argument is not 
supported; returns an R vector by default but this behavior is deprecated and 
will return an Arrow \link{ChunkedArray} in a future release. Provide 
\code{as_vector = TRUE/FALSE} to control this behavior, or set 
\code{options(arrow.pull_as_vector)} globally.
 \item \code{\link[dplyr:relocate]{relocate()}}
 \item \code{\link[dplyr:rename]{rename()}}
 \item \code{\link[dplyr:rename]{rename_with()}}
-\item \code{\link[dplyr:mutate-joins]{right_join()}}: the \code{copy} and 
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:mutate-joins]{right_join()}}: the \code{copy} argument 
is ignored
 \item \code{\link[dplyr:select]{select()}}
-\item \code{\link[dplyr:filter-joins]{semi_join()}}: the \code{copy} and 
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:filter-joins]{semi_join()}}: the \code{copy} argument 
is ignored
 \item \code{\link[dplyr:explain]{show_query()}}
 \item \code{\link[dplyr:slice]{slice_head()}}: slicing within groups not 
supported; Arrow datasets do not have row order, so head is non-deterministic; 
\code{prop} only supported on queries where \code{nrow()} is knowable without 
evaluating
 \item \code{\link[dplyr:slice]{slice_max()}}: slicing within groups not 
supported; \code{with_ties = TRUE} (dplyr default) is not supported; 
\code{prop} only supported on queries where \code{nrow()} is knowable without 
evaluating
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index a4c4b614d6..d5aec50219 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -1163,8 +1163,8 @@ extern "C" SEXP _arrow_ExecNode_Aggregate(SEXP 
input_sexp, SEXP options_sexp, SE
 
 // compute-exec.cpp
 #if defined(ARROW_R_WITH_ACERO)
-std::shared_ptr<acero::ExecNode> ExecNode_Join(const 
std::shared_ptr<acero::ExecNode>& input, acero::JoinType join_type, const 
std::shared_ptr<acero::ExecNode>& right_data, std::vector<std::string> 
left_keys, std::vector<std::string> right_keys, std::vector<std::string> 
left_output, std::vector<std::string> right_output, std::string 
output_suffix_for_left, std::string output_suffix_for_right);
-extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP join_type_sexp, 
SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP 
left_output_sexp, SEXP right_output_sexp, SEXP output_suffix_for_left_sexp, 
SEXP output_suffix_for_right_sexp){
+std::shared_ptr<acero::ExecNode> ExecNode_Join(const 
std::shared_ptr<acero::ExecNode>& input, acero::JoinType join_type, const 
std::shared_ptr<acero::ExecNode>& right_data, std::vector<std::string> 
left_keys, std::vector<std::string> right_keys, std::vector<std::string> 
left_output, std::vector<std::string> right_output, std::string 
output_suffix_for_left, std::string output_suffix_for_right, bool na_matches);
+extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP join_type_sexp, 
SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP 
left_output_sexp, SEXP right_output_sexp, SEXP output_suffix_for_left_sexp, 
SEXP output_suffix_for_right_sexp, SEXP na_matches_sexp){
 BEGIN_CPP11
        arrow::r::Input<const std::shared_ptr<acero::ExecNode>&>::type 
input(input_sexp);
        arrow::r::Input<acero::JoinType>::type join_type(join_type_sexp);
@@ -1175,11 +1175,12 @@ BEGIN_CPP11
        arrow::r::Input<std::vector<std::string>>::type 
right_output(right_output_sexp);
        arrow::r::Input<std::string>::type 
output_suffix_for_left(output_suffix_for_left_sexp);
        arrow::r::Input<std::string>::type 
output_suffix_for_right(output_suffix_for_right_sexp);
-       return cpp11::as_sexp(ExecNode_Join(input, join_type, right_data, 
left_keys, right_keys, left_output, right_output, output_suffix_for_left, 
output_suffix_for_right));
+       arrow::r::Input<bool>::type na_matches(na_matches_sexp);
+       return cpp11::as_sexp(ExecNode_Join(input, join_type, right_data, 
left_keys, right_keys, left_output, right_output, output_suffix_for_left, 
output_suffix_for_right, na_matches));
 END_CPP11
 }
 #else
-extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP join_type_sexp, 
SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP 
left_output_sexp, SEXP right_output_sexp, SEXP output_suffix_for_left_sexp, 
SEXP output_suffix_for_right_sexp){
+extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP join_type_sexp, 
SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP 
left_output_sexp, SEXP right_output_sexp, SEXP output_suffix_for_left_sexp, 
SEXP output_suffix_for_right_sexp, SEXP na_matches_sexp){
        Rf_error("Cannot call ExecNode_Join(). See 
https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow 
C++ libraries. ");
 }
 #endif
@@ -5790,7 +5791,7 @@ static const R_CallMethodDef CallEntries[] = {
                { "_arrow_ExecNode_Filter", (DL_FUNC) &_arrow_ExecNode_Filter, 
2}, 
                { "_arrow_ExecNode_Project", (DL_FUNC) 
&_arrow_ExecNode_Project, 3}, 
                { "_arrow_ExecNode_Aggregate", (DL_FUNC) 
&_arrow_ExecNode_Aggregate, 3}, 
-               { "_arrow_ExecNode_Join", (DL_FUNC) &_arrow_ExecNode_Join, 9}, 
+               { "_arrow_ExecNode_Join", (DL_FUNC) &_arrow_ExecNode_Join, 10}, 
                { "_arrow_ExecNode_Union", (DL_FUNC) &_arrow_ExecNode_Union, 
2}, 
                { "_arrow_ExecNode_Fetch", (DL_FUNC) &_arrow_ExecNode_Fetch, 
3}, 
                { "_arrow_ExecNode_OrderBy", (DL_FUNC) 
&_arrow_ExecNode_OrderBy, 2}, 
diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp
index e0b3c62c47..d0c50315c2 100644
--- a/r/src/compute-exec.cpp
+++ b/r/src/compute-exec.cpp
@@ -411,10 +411,17 @@ std::shared_ptr<acero::ExecNode> ExecNode_Join(
     const std::shared_ptr<acero::ExecNode>& right_data,
     std::vector<std::string> left_keys, std::vector<std::string> right_keys,
     std::vector<std::string> left_output, std::vector<std::string> 
right_output,
-    std::string output_suffix_for_left, std::string output_suffix_for_right) {
+    std::string output_suffix_for_left, std::string output_suffix_for_right,
+    bool na_matches) {
   std::vector<arrow::FieldRef> left_refs, right_refs, left_out_refs, 
right_out_refs;
+  std::vector<acero::JoinKeyCmp> key_cmps;
   for (auto&& name : left_keys) {
     left_refs.emplace_back(std::move(name));
+    // Populate key_cmps in this loop, one for each key
+    // Note that Acero supports having different values for each key, but dplyr
+    // only supports one value for all keys, so we're only going to support 
that
+    // for now.
+    key_cmps.emplace_back(na_matches ? acero::JoinKeyCmp::IS : 
acero::JoinKeyCmp::EQ);
   }
   for (auto&& name : right_keys) {
     right_refs.emplace_back(std::move(name));
@@ -434,10 +441,11 @@ std::shared_ptr<acero::ExecNode> ExecNode_Join(
 
   return MakeExecNodeOrStop(
       "hashjoin", input->plan(), {input.get(), right_data.get()},
-      acero::HashJoinNodeOptions{
-          join_type, std::move(left_refs), std::move(right_refs),
-          std::move(left_out_refs), std::move(right_out_refs), 
compute::literal(true),
-          std::move(output_suffix_for_left), 
std::move(output_suffix_for_right)});
+      acero::HashJoinNodeOptions{join_type, std::move(left_refs), 
std::move(right_refs),
+                                 std::move(left_out_refs), 
std::move(right_out_refs),
+                                 std::move(key_cmps), compute::literal(true),
+                                 std::move(output_suffix_for_left),
+                                 std::move(output_suffix_for_right)});
 }
 
 // [[acero::export]]
diff --git a/r/tests/testthat/test-dplyr-join.R 
b/r/tests/testthat/test-dplyr-join.R
index e3e1e98cfc..9a1c8b7b80 100644
--- a/r/tests/testthat/test-dplyr-join.R
+++ b/r/tests/testthat/test-dplyr-join.R
@@ -441,3 +441,35 @@ test_that("full joins handle keep", {
     small_dataset_df
   )
 })
+
+left <- tibble::tibble(
+  x = c(1, NA, 3),
+)
+right <- tibble::tibble(
+  x = c(1, NA, 3),
+  y = c("a", "b", "c")
+)
+na_matches_na <- right
+na_matches_never <- tibble::tibble(
+  x = c(1, NA, 3),
+  y = c("a", NA, "c")
+)
+test_that("na_matches argument to join: na (default)", {
+  expect_equal(
+    arrow_table(left) %>%
+      left_join(right, by = "x", na_matches = "na") %>%
+      arrange(x) %>%
+      collect(),
+    na_matches_na %>% arrange(x)
+  )
+})
+
+test_that("na_matches argument to join: never", {
+  expect_equal(
+    arrow_table(left) %>%
+      left_join(right, by = "x", na_matches = "never") %>%
+      arrange(x) %>%
+      collect(),
+    na_matches_never %>% arrange(x)
+  )
+})

Reply via email to