jonkeane commented on a change in pull request #9521: URL: https://github.com/apache/arrow/pull/9521#discussion_r581372997
########## File path: r/tests/testthat/helper-expectation.R ########## @@ -59,3 +59,66 @@ verify_output <- function(...) { } testthat::verify_output(...) } + +expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its start + tbl, # A tbl/df as reference, will make RB/Table with + skip_record_batch = NULL, # Msg, if should skip RB test + skip_table = NULL, # Msg, if should skip Table test + ...) { + expr <- rlang::enquo(expr) + expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) + + skip_msg <- NULL + + if (is.null(skip_record_batch)) { + via_batch <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = record_batch(tbl))) + ) + expect_equivalent(via_batch, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_record_batch) + } + + if (is.null(skip_table)) { + via_table <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = Table$create(tbl))) + ) + expect_equivalent(via_table, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_table) + } + + if (!is.null(skip_msg)) { + skip(paste(skip_msg, collpase = "\n")) + } +} + +expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its start + tbl, # A tbl/df as reference, will make RB/Table with + ...) { + expr <- rlang::enquo(expr) + msg <- tryCatch( + rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))), + error = function (e) conditionMessage(e) + ) + expect_is(msg, "character", label = "dplyr on data.frame did not error") Review comment: This isn't absolutely necessary, but it will mean more work in the future if/when we move to testthat 3e, where `expect_is()` is deprecated. Could we use `expect_type()` here? ########## File path: r/R/dplyr.R ########## @@ -73,6 +80,22 @@ print.arrow_dplyr_query <- function(x, ...) 
{ invisible(x) } +get_field_names <- function(selected_cols) { + if (inherits(selected_cols, "arrow_dplyr_query")) { + selected_cols <- selected_cols$selected_columns + } + map_chr(selected_cols, ~.$field_name %||% .$args$field_name %||% "") Review comment: `.$field_name %||% .$args$field_name %||% ""` evaluates to `.$field_name`; if `.$field_name` is `NULL`, then to `.$args$field_name`; and if `.$args$field_name` is *also* `NULL`, then to `""`, yeah? Maybe a comment would be good to make that super clear / easier to read? ########## File path: r/R/dplyr.R ########## @@ -309,8 +344,27 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { # See dataset.R for Dataset and Scanner(Builder) classes tab <- Scanner$create(x)$ToTable() } else { - # This is a Table/RecordBatch. See record-batch.R for the [ method - tab <- x$.data[x$filtered_rows, x$selected_columns, keep_na = FALSE] + # This is a Table or RecordBatch + + # Filter and select the data referenced in selected columns + if (isTRUE(x$filtered_rows)) { + filter <- TRUE + } else { + filter <- eval_array_expression(x$filtered_rows, x$.data) + } + # TODO: shortcut if identical(names(x$.data), find_array_refs(x$selected_columns))? + tab <- x$.data[filter, find_array_refs(x$selected_columns), keep_na = FALSE] + # Now evaluate those expressions on the filtered table + cols <- lapply(x$selected_columns, eval_array_expression, data = tab) + if (length(cols) == 0) { + tab <- tab[, integer(0)] Review comment: Would this resolve https://issues.apache.org/jira/browse/ARROW-11328 ? ########## File path: r/R/dplyr.R ########## @@ -423,26 +482,115 @@ ungroup.arrow_dplyr_query <- function(x, ...) { } ungroup.Dataset <- ungroup.ArrowTabular <- force -mutate.arrow_dplyr_query <- function(.data, ...) { +mutate.arrow_dplyr_query <- function(.data, + ..., + .keep = c("all", "used", "unused", "none"), + .before = NULL, + .after = NULL) { + call <- match.call() + exprs <- quos(...) 
+ if (length(exprs) == 0) { + # Nothing to do + return(.data) + } + .data <- arrow_dplyr_query(.data) if (query_on_dataset(.data)) { not_implemented_for_dataset("mutate()") } - # TODO: see if we can defer evaluating the expressions and not collect here. - # It's different from filters (as currently implemented) because the basic - # vector transformation functions aren't yet implemented in Arrow C++. - dplyr::mutate(dplyr::collect(.data), ...) + + .keep <- match.arg(.keep) + .before <- enquo(.before) + .after <- enquo(.after) + # Restrict the cases we support for now + if (!quo_is_null(.before) || !quo_is_null(.after)) { + # TODO(ARROW-11701) + return(abandon_ship(call, .data, '.before and .after arguments are not supported in Arrow')) + } else if (length(group_vars(.data)) > 0) { + # mutate() on a grouped dataset does calculations within groups + # This doesn't matter on scalar ops (arithmetic etc.) but it does + # for things with aggregations (e.g. subtracting the mean) + return(abandon_ship(call, .data, 'mutate() on grouped data not supported in Arrow')) Review comment: This should be obvious, but I wonder if we should add something like "remove the grouping with `ungroup()` or remove `group_by()` from your pipeline first" to help folks know what to do if they get this error ########## File path: r/tests/testthat/test-dplyr-mutate.R ########## @@ -0,0 +1,311 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr) +library(stringr) + +tbl <- example_data +# Add some better string data +tbl$verses <- verses[[1]] +# c(" a ", " b ", " c ", ...) increasing padding +# nchar = 3 5 7 9 11 13 15 17 19 21 +tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side = "both") + +test_that("mutate() is lazy", { + expect_is( + tbl %>% record_batch() %>% mutate(int = int + 6L), + "arrow_dplyr_query" + ) Review comment: `expect_is()` > `expect_s3_class()` here? As odd as the naming is this is the approved method it appears https://github.com/r-lib/testthat/issues/1271 ########## File path: r/tests/testthat/test-RecordBatch.R ########## @@ -416,6 +416,14 @@ test_that("record_batch() handles null type (ARROW-7064)", { expect_equivalent(batch$schema, schema(a = int32(), n = null())) }) +test_that("record_batch() scalar recycling", { + skip("Not implemented (ARROW-11705") Review comment: ```suggestion skip("Not implemented (ARROW-11705)") ``` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org