dhruv9vats commented on a change in pull request #12368:
URL: https://github.com/apache/arrow/pull/12368#discussion_r805961512



##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
   }
 }
 
+// Matches a scalar `arg` that compares Equals() to at least one element of
+// `arrow_array`; on a mismatch the offending scalar is printed to the listener.
+MATCHER_P(AnyOfScalar, arrow_array, "") {
+  int64_t idx = 0;
+  while (idx < arrow_array->length()) {
+    const auto candidate = arrow_array->GetScalar(idx).ValueOrDie();
+    if (candidate->Equals(arg)) {
+      return true;
+    }
+    ++idx;
+  }
+  // No candidate compared equal: explain which value was rejected.
+  *result_listener << "Argument scalar: '" << arg->ToString()
+                   << "' matches no input scalar.";
+  return false;
+}
+
+// Matches an array `arg` against `unique_list`, a ListArray that supplies one
+// list of candidate values per index: the match succeeds when, for every i,
+// the scalar arg[i] compares Equals() to some element of the i-th list.
+MATCHER_P(AnyOfScalarFromUniques, unique_list, "") {
+  // Every arg[i] needs an i-th list to be compared against; bail out instead
+  // of reading list offsets past the end.
+  if (arg->length() > unique_list->length()) {
+    *result_listener << "Argument has " << arg->length() << " values but only "
+                     << unique_list->length() << " candidate lists were given.";
+    return false;
+  }
+  // value_offset()/value_length() index directly into values(). Flatten()
+  // re-bases the values and skips null slots, so absolute offsets only line
+  // up with it when the list is unsliced and has no null slots.
+  const auto values = unique_list->values();
+
+  for (int64_t i = 0; i < arg->length(); ++i) {
+    const auto group_hash_one = arg->GetScalar(i).ValueOrDie();
+    const int64_t start = unique_list->value_offset(i);
+    const int64_t end = start + unique_list->value_length(i);
+    bool match_found = false;
+    for (int64_t j = start; j < end && !match_found; ++j) {
+      match_found = values->GetScalar(j).ValueOrDie()->Equals(group_hash_one);
+    }
+    if (!match_found) {
+      *result_listener << "Argument scalar: '" << group_hash_one->ToString()
+                       << "' matches no input scalar.";
+      return false;
+    }
+  }
+  return true;
+}
+
+// Runs the "hash_one" aggregation over a chunked table, first with an int64
+// and then with a utf8 argument column, and compares the key-sorted output
+// against an exact expected struct array.
+TEST(GroupBy, One) {
+  // Case 1: int64 argument column.
+  {
+    const auto in_schema = schema({field("argument", int64()), field("key", int64())});
+    auto table = TableFromJSON(in_schema, {R"([
+    [99,  1],
+    [99,  1]
+])",
+                                           R"([
+    [77,  2],
+    [null,   3],
+    [null,   3]
+])",
+                                           R"([
+    [null,   4],
+    [null,   4]
+])",
+                                           R"([
+    [88,  null],
+    [99,  3]
+])",
+                                           R"([
+    [77,  2],
+    [76, 2]
+])",
+                                           R"([
+    [75, null],
+    [74,  3]
+  ])",
+                                           R"([
+    [73,    null],
+    [72,    null]
+  ])"});
+
+    ASSERT_OK_AND_ASSIGN(
+        auto aggregated_and_grouped,
+        internal::GroupBy({table->GetColumnByName("argument")},
+                          {table->GetColumnByName("key")}, {{"hash_one", nullptr}},
+                          /*use_threads=*/false));
+    ValidateOutput(aggregated_and_grouped);
+    // Sort by key so the rows can be compared positionally.
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    const auto expected_type = struct_({
+        field("hash_one", int64()),
+        field("key_0", int64()),
+    });
+    const auto expected = ArrayFromJSON(expected_type, R"([
+      [99, 1],
+      [77, 2],
+      [null,  3],
+      [null,  4],
+      [88, null]
+    ])");
+    AssertDatumsEqual(expected, aggregated_and_grouped, /*verbose=*/true);
+  }
+  // Case 2: utf8 argument column.
+  {
+    const auto in_schema = schema({field("argument", utf8()), field("key", int64())});
+    auto table = TableFromJSON(in_schema, {R"([
+     ["foo",  1],
+     ["foo",  1]
+ ])",
+                                           R"([
+     ["bar",  2],
+     [null,   3],
+     [null,   3]
+ ])",
+                                           R"([
+     [null,   4],
+     [null,   4]
+ ])",
+                                           R"([
+     ["baz",  null],
+     ["foo",  3]
+ ])",
+                                           R"([
+     ["bar",  2],
+     ["spam", 2]
+ ])",
+                                           R"([
+     ["eggs", null],
+     ["ham",  3]
+   ])",
+                                           R"([
+     ["a",    null],
+     ["b",    null]
+   ])"});
+
+    ASSERT_OK_AND_ASSIGN(
+        auto aggregated_and_grouped,
+        internal::GroupBy({table->GetColumnByName("argument")},
+                          {table->GetColumnByName("key")}, {{"hash_one", nullptr}},
+                          /*use_threads=*/false));
+    ValidateOutput(aggregated_and_grouped);
+    // Sort by key so the rows can be compared positionally.
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    const auto expected_type = struct_({
+        field("hash_one", utf8()),
+        field("key_0", int64()),
+    });
+    const auto expected = ArrayFromJSON(expected_type, R"([
+       ["foo", 1],
+       ["bar", 2],
+       [null,  3],
+       [null,  4],
+       ["baz", null]
+     ])");
+    AssertDatumsEqual(expected, aggregated_and_grouped, /*verbose=*/true);
+  }
+}
+
+TEST(GroupBy, OneOnly) {
+  auto in_schema = schema({
+      field("argument0", float64()),
+      field("argument1", null()),
+      field("argument2", boolean()),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {false, true}) {
+    for (bool use_threads : {false, true}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+      auto table = TableFromJSON(in_schema, {R"([
+    [1.0,   null, true, 1],
+    [null,  null, true, 1]
+])",
+                                             R"([
+    [0.0,   null, false, 2],
+    [null,  null, false, 3],
+    [4.0,   null, null,  null],
+    [3.25,  null, true,  1],
+    [0.125, null, false, 2]
+])",
+                                             R"([
+    [-0.25, null, false, 2],
+    [0.75,  null, true,  null],
+    [null,  null, true,  3]
+])"});
+
+      ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                           GroupByTest(
+                               {
+                                   table->GetColumnByName("argument0"),
+                                   table->GetColumnByName("argument1"),
+                                   table->GetColumnByName("argument2"),
+                               },
+                               {table->GetColumnByName("key")},
+                               {
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                               },
+                               use_threads, use_exec_plan));
+      ValidateOutput(aggregated_and_grouped);
+      SortBy({"key_0"}, &aggregated_and_grouped);
+
+      //      AssertDatumsEqual(ArrayFromJSON(struct_({
+      //                                          field("hash_one", float64()),
+      //                                          field("hash_one", null()),
+      //                                          field("hash_one", boolean()),
+      //                                          field("key_0", int64()),
+      //                                      }),
+      //                                      R"([
+      //          [1.0,  null, true,  1],
+      //          [0.0,  null, false, 2],
+      //          [null, null, false, 3],
+      //          [4.0,  null, null,  null]
+      //        ])"),
+      //                        aggregated_and_grouped,
+      //                        /*verbose=*/true);
+
+      const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+      //  Check the key column
+      AssertDatumsEqual(ArrayFromJSON(int64(), "[1, 2, 3, null]"), 
struct_arr->field(3));
+
+      auto type_col_0 = float64();
+      auto group_one_col_0 =
+          AnyOfScalar(ArrayFromJSON(type_col_0, R"([1.0, null, 3.25])"));
+      auto group_two_col_0 =
+          AnyOfScalar(ArrayFromJSON(type_col_0, R"([0.0, 0.125, -0.25])"));
+      auto group_three_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, 
R"([null])"));
+      auto group_null_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, R"([4.0, 
0.75])"));
+
+      //  Check values individually
+      const auto& col0 = struct_arr->field(0);
+      ASSERT_OK_AND_ASSIGN(const auto g_one, col0->GetScalar(0));
+      EXPECT_THAT(g_one, group_one_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_two, col0->GetScalar(1));
+      EXPECT_THAT(g_two, group_two_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_three, col0->GetScalar(2));
+      EXPECT_THAT(g_three, group_three_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_null, col0->GetScalar(3));
+      EXPECT_THAT(g_null, group_null_col_0);

Review comment:
       This is what the `EXPECT_THAT` and `AnyOf` approach might look like, for 
testing _just one_ column.

##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
   }
 }
 
+// Matches a scalar `arg` that compares Equals() to at least one element of
+// `arrow_array`; on a mismatch the offending scalar is printed to the listener.
+MATCHER_P(AnyOfScalar, arrow_array, "") {
+  int64_t idx = 0;
+  while (idx < arrow_array->length()) {
+    const auto candidate = arrow_array->GetScalar(idx).ValueOrDie();
+    if (candidate->Equals(arg)) {
+      return true;
+    }
+    ++idx;
+  }
+  // No candidate compared equal: explain which value was rejected.
+  *result_listener << "Argument scalar: '" << arg->ToString()
+                   << "' matches no input scalar.";
+  return false;
+}
+
+// Matches an array `arg` against `unique_list`, a ListArray that supplies one
+// list of candidate values per index: the match succeeds when, for every i,
+// the scalar arg[i] compares Equals() to some element of the i-th list.
+MATCHER_P(AnyOfScalarFromUniques, unique_list, "") {
+  // Every arg[i] needs an i-th list to be compared against; bail out instead
+  // of reading list offsets past the end.
+  if (arg->length() > unique_list->length()) {
+    *result_listener << "Argument has " << arg->length() << " values but only "
+                     << unique_list->length() << " candidate lists were given.";
+    return false;
+  }
+  // value_offset()/value_length() index directly into values(). Flatten()
+  // re-bases the values and skips null slots, so absolute offsets only line
+  // up with it when the list is unsliced and has no null slots.
+  const auto values = unique_list->values();
+
+  for (int64_t i = 0; i < arg->length(); ++i) {
+    const auto group_hash_one = arg->GetScalar(i).ValueOrDie();
+    const int64_t start = unique_list->value_offset(i);
+    const int64_t end = start + unique_list->value_length(i);
+    bool match_found = false;
+    for (int64_t j = start; j < end && !match_found; ++j) {
+      match_found = values->GetScalar(j).ValueOrDie()->Equals(group_hash_one);
+    }
+    if (!match_found) {
+      *result_listener << "Argument scalar: '" << group_hash_one->ToString()
+                       << "' matches no input scalar.";
+      return false;
+    }
+  }
+  return true;
+}
+
+// Runs the "hash_one" aggregation over a chunked table, first with an int64
+// and then with a utf8 argument column, and compares the key-sorted output
+// against an exact expected struct array.
+TEST(GroupBy, One) {
+  // Case 1: int64 argument column.
+  {
+    const auto in_schema = schema({field("argument", int64()), field("key", int64())});
+    auto table = TableFromJSON(in_schema, {R"([
+    [99,  1],
+    [99,  1]
+])",
+                                           R"([
+    [77,  2],
+    [null,   3],
+    [null,   3]
+])",
+                                           R"([
+    [null,   4],
+    [null,   4]
+])",
+                                           R"([
+    [88,  null],
+    [99,  3]
+])",
+                                           R"([
+    [77,  2],
+    [76, 2]
+])",
+                                           R"([
+    [75, null],
+    [74,  3]
+  ])",
+                                           R"([
+    [73,    null],
+    [72,    null]
+  ])"});
+
+    ASSERT_OK_AND_ASSIGN(
+        auto aggregated_and_grouped,
+        internal::GroupBy({table->GetColumnByName("argument")},
+                          {table->GetColumnByName("key")}, {{"hash_one", nullptr}},
+                          /*use_threads=*/false));
+    ValidateOutput(aggregated_and_grouped);
+    // Sort by key so the rows can be compared positionally.
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    const auto expected_type = struct_({
+        field("hash_one", int64()),
+        field("key_0", int64()),
+    });
+    const auto expected = ArrayFromJSON(expected_type, R"([
+      [99, 1],
+      [77, 2],
+      [null,  3],
+      [null,  4],
+      [88, null]
+    ])");
+    AssertDatumsEqual(expected, aggregated_and_grouped, /*verbose=*/true);
+  }
+  // Case 2: utf8 argument column.
+  {
+    const auto in_schema = schema({field("argument", utf8()), field("key", int64())});
+    auto table = TableFromJSON(in_schema, {R"([
+     ["foo",  1],
+     ["foo",  1]
+ ])",
+                                           R"([
+     ["bar",  2],
+     [null,   3],
+     [null,   3]
+ ])",
+                                           R"([
+     [null,   4],
+     [null,   4]
+ ])",
+                                           R"([
+     ["baz",  null],
+     ["foo",  3]
+ ])",
+                                           R"([
+     ["bar",  2],
+     ["spam", 2]
+ ])",
+                                           R"([
+     ["eggs", null],
+     ["ham",  3]
+   ])",
+                                           R"([
+     ["a",    null],
+     ["b",    null]
+   ])"});
+
+    ASSERT_OK_AND_ASSIGN(
+        auto aggregated_and_grouped,
+        internal::GroupBy({table->GetColumnByName("argument")},
+                          {table->GetColumnByName("key")}, {{"hash_one", nullptr}},
+                          /*use_threads=*/false));
+    ValidateOutput(aggregated_and_grouped);
+    // Sort by key so the rows can be compared positionally.
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    const auto expected_type = struct_({
+        field("hash_one", utf8()),
+        field("key_0", int64()),
+    });
+    const auto expected = ArrayFromJSON(expected_type, R"([
+       ["foo", 1],
+       ["bar", 2],
+       [null,  3],
+       [null,  4],
+       ["baz", null]
+     ])");
+    AssertDatumsEqual(expected, aggregated_and_grouped, /*verbose=*/true);
+  }
+}
+
+TEST(GroupBy, OneOnly) {
+  auto in_schema = schema({
+      field("argument0", float64()),
+      field("argument1", null()),
+      field("argument2", boolean()),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {false, true}) {
+    for (bool use_threads : {false, true}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+      auto table = TableFromJSON(in_schema, {R"([
+    [1.0,   null, true, 1],
+    [null,  null, true, 1]
+])",
+                                             R"([
+    [0.0,   null, false, 2],
+    [null,  null, false, 3],
+    [4.0,   null, null,  null],
+    [3.25,  null, true,  1],
+    [0.125, null, false, 2]
+])",
+                                             R"([
+    [-0.25, null, false, 2],
+    [0.75,  null, true,  null],
+    [null,  null, true,  3]
+])"});
+
+      ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                           GroupByTest(
+                               {
+                                   table->GetColumnByName("argument0"),
+                                   table->GetColumnByName("argument1"),
+                                   table->GetColumnByName("argument2"),
+                               },
+                               {table->GetColumnByName("key")},
+                               {
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                               },
+                               use_threads, use_exec_plan));
+      ValidateOutput(aggregated_and_grouped);
+      SortBy({"key_0"}, &aggregated_and_grouped);
+
+      //      AssertDatumsEqual(ArrayFromJSON(struct_({
+      //                                          field("hash_one", float64()),
+      //                                          field("hash_one", null()),
+      //                                          field("hash_one", boolean()),
+      //                                          field("key_0", int64()),
+      //                                      }),
+      //                                      R"([
+      //          [1.0,  null, true,  1],
+      //          [0.0,  null, false, 2],
+      //          [null, null, false, 3],
+      //          [4.0,  null, null,  null]
+      //        ])"),
+      //                        aggregated_and_grouped,
+      //                        /*verbose=*/true);
+
+      const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+      //  Check the key column
+      AssertDatumsEqual(ArrayFromJSON(int64(), "[1, 2, 3, null]"), 
struct_arr->field(3));
+
+      auto type_col_0 = float64();
+      auto group_one_col_0 =
+          AnyOfScalar(ArrayFromJSON(type_col_0, R"([1.0, null, 3.25])"));
+      auto group_two_col_0 =
+          AnyOfScalar(ArrayFromJSON(type_col_0, R"([0.0, 0.125, -0.25])"));
+      auto group_three_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, 
R"([null])"));
+      auto group_null_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, R"([4.0, 
0.75])"));
+
+      //  Check values individually
+      const auto& col0 = struct_arr->field(0);
+      ASSERT_OK_AND_ASSIGN(const auto g_one, col0->GetScalar(0));
+      EXPECT_THAT(g_one, group_one_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_two, col0->GetScalar(1));
+      EXPECT_THAT(g_two, group_two_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_three, col0->GetScalar(2));
+      EXPECT_THAT(g_three, group_three_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_null, col0->GetScalar(3));
+      EXPECT_THAT(g_null, group_null_col_0);
+
+      CountOptions all(CountOptions::ALL);
+      ASSERT_OK_AND_ASSIGN(
+          auto distinct_out,
+          internal::GroupBy(
+              {
+                  table->GetColumnByName("argument0"),
+                  table->GetColumnByName("argument1"),
+                  table->GetColumnByName("argument2"),
+              },
+              {
+                  table->GetColumnByName("key"),
+              },
+              {{"hash_distinct", &all}, {"hash_distinct", &all}, 
{"hash_distinct", &all}},
+              use_threads));
+      ValidateOutput(distinct_out);
+      SortBy({"key_0"}, &distinct_out);
+
+      const auto& struct_arr_distinct = distinct_out.array_as<StructArray>();
+      // Walk the aggregate columns only; the last struct field is the key
+      // column, which is not a ListArray. The bound has to be the number of
+      // struct fields: length() is the number of rows (groups) and only
+      // happens to equal the field count for this particular table.
+      for (int col = 0; col < struct_arr_distinct->num_fields() - 1; ++col) {
+        const auto matcher = AnyOfScalarFromUniques(
+            checked_pointer_cast<ListArray>(struct_arr_distinct->field(col)));
+        EXPECT_THAT(struct_arr->field(col), matcher);
+      }

Review comment:
       This, on the other hand, tests _all_ the columns (the key column is not a
`ListArray`, so it will have to be tested manually, but that's trivial). So if
using other kernels to write tests is not strictly discouraged, this is a
rather clean way of doing this. @lidavidm 

##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
   }
 }
 
+// Matches a scalar `arg` that compares Equals() to at least one element of
+// `arrow_array`; on a mismatch the offending scalar is printed to the listener.
+MATCHER_P(AnyOfScalar, arrow_array, "") {
+  int64_t idx = 0;
+  while (idx < arrow_array->length()) {
+    const auto candidate = arrow_array->GetScalar(idx).ValueOrDie();
+    if (candidate->Equals(arg)) {
+      return true;
+    }
+    ++idx;
+  }
+  // No candidate compared equal: explain which value was rejected.
+  *result_listener << "Argument scalar: '" << arg->ToString()
+                   << "' matches no input scalar.";
+  return false;
+}
+
+// Matches an array `arg` against `unique_list`, a ListArray that supplies one
+// list of candidate values per index: the match succeeds when, for every i,
+// the scalar arg[i] compares Equals() to some element of the i-th list.
+MATCHER_P(AnyOfScalarFromUniques, unique_list, "") {
+  // Every arg[i] needs an i-th list to be compared against; bail out instead
+  // of reading list offsets past the end.
+  if (arg->length() > unique_list->length()) {
+    *result_listener << "Argument has " << arg->length() << " values but only "
+                     << unique_list->length() << " candidate lists were given.";
+    return false;
+  }
+  // value_offset()/value_length() index directly into values(). Flatten()
+  // re-bases the values and skips null slots, so absolute offsets only line
+  // up with it when the list is unsliced and has no null slots.
+  const auto values = unique_list->values();
+
+  for (int64_t i = 0; i < arg->length(); ++i) {
+    const auto group_hash_one = arg->GetScalar(i).ValueOrDie();
+    const int64_t start = unique_list->value_offset(i);
+    const int64_t end = start + unique_list->value_length(i);
+    bool match_found = false;
+    for (int64_t j = start; j < end && !match_found; ++j) {
+      match_found = values->GetScalar(j).ValueOrDie()->Equals(group_hash_one);
+    }
+    if (!match_found) {
+      *result_listener << "Argument scalar: '" << group_hash_one->ToString()
+                       << "' matches no input scalar.";
+      return false;
+    }
+  }
+  return true;
+}
+
+// Runs the "hash_one" aggregation over a chunked table, first with an int64
+// and then with a utf8 argument column, and compares the key-sorted output
+// against an exact expected struct array.
+TEST(GroupBy, One) {
+  // Case 1: int64 argument column.
+  {
+    const auto in_schema = schema({field("argument", int64()), field("key", int64())});
+    auto table = TableFromJSON(in_schema, {R"([
+    [99,  1],
+    [99,  1]
+])",
+                                           R"([
+    [77,  2],
+    [null,   3],
+    [null,   3]
+])",
+                                           R"([
+    [null,   4],
+    [null,   4]
+])",
+                                           R"([
+    [88,  null],
+    [99,  3]
+])",
+                                           R"([
+    [77,  2],
+    [76, 2]
+])",
+                                           R"([
+    [75, null],
+    [74,  3]
+  ])",
+                                           R"([
+    [73,    null],
+    [72,    null]
+  ])"});
+
+    ASSERT_OK_AND_ASSIGN(
+        auto aggregated_and_grouped,
+        internal::GroupBy({table->GetColumnByName("argument")},
+                          {table->GetColumnByName("key")}, {{"hash_one", nullptr}},
+                          /*use_threads=*/false));
+    ValidateOutput(aggregated_and_grouped);
+    // Sort by key so the rows can be compared positionally.
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    const auto expected_type = struct_({
+        field("hash_one", int64()),
+        field("key_0", int64()),
+    });
+    const auto expected = ArrayFromJSON(expected_type, R"([
+      [99, 1],
+      [77, 2],
+      [null,  3],
+      [null,  4],
+      [88, null]
+    ])");
+    AssertDatumsEqual(expected, aggregated_and_grouped, /*verbose=*/true);
+  }
+  // Case 2: utf8 argument column.
+  {
+    const auto in_schema = schema({field("argument", utf8()), field("key", int64())});
+    auto table = TableFromJSON(in_schema, {R"([
+     ["foo",  1],
+     ["foo",  1]
+ ])",
+                                           R"([
+     ["bar",  2],
+     [null,   3],
+     [null,   3]
+ ])",
+                                           R"([
+     [null,   4],
+     [null,   4]
+ ])",
+                                           R"([
+     ["baz",  null],
+     ["foo",  3]
+ ])",
+                                           R"([
+     ["bar",  2],
+     ["spam", 2]
+ ])",
+                                           R"([
+     ["eggs", null],
+     ["ham",  3]
+   ])",
+                                           R"([
+     ["a",    null],
+     ["b",    null]
+   ])"});
+
+    ASSERT_OK_AND_ASSIGN(
+        auto aggregated_and_grouped,
+        internal::GroupBy({table->GetColumnByName("argument")},
+                          {table->GetColumnByName("key")}, {{"hash_one", nullptr}},
+                          /*use_threads=*/false));
+    ValidateOutput(aggregated_and_grouped);
+    // Sort by key so the rows can be compared positionally.
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    const auto expected_type = struct_({
+        field("hash_one", utf8()),
+        field("key_0", int64()),
+    });
+    const auto expected = ArrayFromJSON(expected_type, R"([
+       ["foo", 1],
+       ["bar", 2],
+       [null,  3],
+       [null,  4],
+       ["baz", null]
+     ])");
+    AssertDatumsEqual(expected, aggregated_and_grouped, /*verbose=*/true);
+  }
+}
+
+TEST(GroupBy, OneOnly) {
+  auto in_schema = schema({
+      field("argument0", float64()),
+      field("argument1", null()),
+      field("argument2", boolean()),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {false, true}) {
+    for (bool use_threads : {false, true}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+      auto table = TableFromJSON(in_schema, {R"([
+    [1.0,   null, true, 1],
+    [null,  null, true, 1]
+])",
+                                             R"([
+    [0.0,   null, false, 2],
+    [null,  null, false, 3],
+    [4.0,   null, null,  null],
+    [3.25,  null, true,  1],
+    [0.125, null, false, 2]
+])",
+                                             R"([
+    [-0.25, null, false, 2],
+    [0.75,  null, true,  null],
+    [null,  null, true,  3]
+])"});
+
+      ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                           GroupByTest(
+                               {
+                                   table->GetColumnByName("argument0"),
+                                   table->GetColumnByName("argument1"),
+                                   table->GetColumnByName("argument2"),
+                               },
+                               {table->GetColumnByName("key")},
+                               {
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                               },
+                               use_threads, use_exec_plan));
+      ValidateOutput(aggregated_and_grouped);
+      SortBy({"key_0"}, &aggregated_and_grouped);
+
+      //      AssertDatumsEqual(ArrayFromJSON(struct_({
+      //                                          field("hash_one", float64()),
+      //                                          field("hash_one", null()),
+      //                                          field("hash_one", boolean()),
+      //                                          field("key_0", int64()),
+      //                                      }),
+      //                                      R"([
+      //          [1.0,  null, true,  1],
+      //          [0.0,  null, false, 2],
+      //          [null, null, false, 3],
+      //          [4.0,  null, null,  null]
+      //        ])"),
+      //                        aggregated_and_grouped,
+      //                        /*verbose=*/true);
+
+      const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+      //  Check the key column
+      AssertDatumsEqual(ArrayFromJSON(int64(), "[1, 2, 3, null]"), 
struct_arr->field(3));
+
+      auto type_col_0 = float64();
+      auto group_one_col_0 =
+          AnyOfScalar(ArrayFromJSON(type_col_0, R"([1.0, null, 3.25])"));
+      auto group_two_col_0 =
+          AnyOfScalar(ArrayFromJSON(type_col_0, R"([0.0, 0.125, -0.25])"));
+      auto group_three_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, 
R"([null])"));
+      auto group_null_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, R"([4.0, 
0.75])"));
+
+      //  Check values individually
+      const auto& col0 = struct_arr->field(0);
+      ASSERT_OK_AND_ASSIGN(const auto g_one, col0->GetScalar(0));
+      EXPECT_THAT(g_one, group_one_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_two, col0->GetScalar(1));
+      EXPECT_THAT(g_two, group_two_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_three, col0->GetScalar(2));
+      EXPECT_THAT(g_three, group_three_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_null, col0->GetScalar(3));
+      EXPECT_THAT(g_null, group_null_col_0);
+
+      CountOptions all(CountOptions::ALL);
+      ASSERT_OK_AND_ASSIGN(
+          auto distinct_out,
+          internal::GroupBy(
+              {
+                  table->GetColumnByName("argument0"),
+                  table->GetColumnByName("argument1"),
+                  table->GetColumnByName("argument2"),
+              },
+              {
+                  table->GetColumnByName("key"),
+              },
+              {{"hash_distinct", &all}, {"hash_distinct", &all}, 
{"hash_distinct", &all}},
+              use_threads));
+      ValidateOutput(distinct_out);
+      SortBy({"key_0"}, &distinct_out);
+
+      const auto& struct_arr_distinct = distinct_out.array_as<StructArray>();
+      // Walk the aggregate columns only; the last struct field is the key
+      // column, which is not a ListArray. The bound has to be the number of
+      // struct fields: length() is the number of rows (groups) and only
+      // happens to equal the field count for this particular table.
+      for (int col = 0; col < struct_arr_distinct->num_fields() - 1; ++col) {
+        const auto matcher = AnyOfScalarFromUniques(
+            checked_pointer_cast<ListArray>(struct_arr_distinct->field(col)));
+        EXPECT_THAT(struct_arr->field(col), matcher);
+      }

Review comment:
       This, on the other hand, tests _all_ the columns (the key column is not a
`ListArray`, so it will have to be tested manually, but that's trivial). So if
using other kernels to write tests is not strictly discouraged, this is a
rather clean way of doing this. @lidavidm 

##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
   }
 }
 
+MATCHER_P(AnyOfScalar, arrow_array, "") {
+  for (int64_t i = 0; i < arrow_array->length(); ++i) {
+    auto scalar = arrow_array->GetScalar(i).ValueOrDie();
+    if (scalar->Equals(arg)) return true;
+  }
+  *result_listener << "Argument scalar: '" << arg->ToString()
+                   << "' matches no input scalar.";
+  return false;
+}
+
+MATCHER_P(AnyOfScalarFromUniques, unique_list, "") {
+  const auto& flatten = unique_list->Flatten().ValueOrDie();
+  const auto& offsets = 
std::dynamic_pointer_cast<Int32Array>(unique_list->offsets());
+
+  for (int64_t i = 0; i < arg->length(); ++i) {
+    bool match_found = false;
+    const auto group_hash_one = arg->GetScalar(i).ValueOrDie();
+    int64_t start = offsets->Value(i);
+    int64_t end = offsets->Value(i + 1);
+    for (int64_t j = start; j < end; ++j) {
+      auto s = flatten->GetScalar(j).ValueOrDie();
+      if (s->Equals(group_hash_one)) {
+        match_found = true;
+        break;
+      }
+    }
+    if (!match_found) {
+      *result_listener << "Argument scalar: '" << group_hash_one->ToString()
+                       << "' matches no input scalar.";
+      return false;
+    }
+  }
+  return true;
+}
+
+TEST(GroupBy, One) {
+  {
+    auto table =
+        TableFromJSON(schema({field("argument", int64()), field("key", 
int64())}), {R"([
+    [99,  1],
+    [99,  1]
+])",
+                                                                               
     R"([
+    [77,  2],
+    [null,   3],
+    [null,   3]
+])",
+                                                                               
     R"([
+    [null,   4],
+    [null,   4]
+])",
+                                                                               
   R"([
+    [88,  null],
+    [99,  3]
+])",
+                                                                               
   R"([
+    [77,  2],
+    [76, 2]
+])",
+                                                                               
   R"([
+    [75, null],
+    [74,  3]
+  ])",
+                                                                               
   R"([
+    [73,    null],
+    [72,    null]
+  ])"});
+
+  ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+                       internal::GroupBy(
+                           {
+                               table->GetColumnByName("argument"),
+                           },
+                           {
+                               table->GetColumnByName("key"),
+                           },
+                           {
+                               {"hash_one", nullptr},
+                           },
+                           false));
+  ValidateOutput(aggregated_and_grouped);
+  SortBy({"key_0"}, &aggregated_and_grouped);
+
+  AssertDatumsEqual(ArrayFromJSON(struct_({
+                                      field("hash_one", int64()),
+                                      field("key_0", int64()),
+                                  }),
+                                  R"([
+      [99, 1],
+      [77, 2],
+      [null,  3],
+      [null,  4],
+      [88, null]
+    ])"),
+                    aggregated_and_grouped,
+                    /*verbose=*/true);
+  }
+  {
+    auto table =
+        TableFromJSON(schema({field("argument", utf8()), field("key", 
int64())}), {R"([
+     ["foo",  1],
+     ["foo",  1]
+ ])",
+                                                                               
    R"([
+     ["bar",  2],
+     [null,   3],
+     [null,   3]
+ ])",
+                                                                               
    R"([
+     [null,   4],
+     [null,   4]
+ ])",
+                                                                               
    R"([
+     ["baz",  null],
+     ["foo",  3]
+ ])",
+                                                                               
    R"([
+     ["bar",  2],
+     ["spam", 2]
+ ])",
+                                                                               
    R"([
+     ["eggs", null],
+     ["ham",  3]
+   ])",
+                                                                               
    R"([
+     ["a",    null],
+     ["b",    null]
+   ])"});
+
+    ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_one", nullptr},
+                             },
+                             false));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_one", utf8()),
+                                        field("key_0", int64()),
+                                    }),
+                                    R"([
+       ["foo", 1],
+       ["bar", 2],
+       [null,  3],
+       [null,  4],
+       ["baz", null]
+     ])"),
+                      aggregated_and_grouped,
+                      /*verbose=*/true);
+  }
+}
+
+TEST(GroupBy, OneOnly) {
+  auto in_schema = schema({
+      field("argument0", float64()),
+      field("argument1", null()),
+      field("argument2", boolean()),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {false, true}) {
+    for (bool use_threads : {false, true}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+      auto table = TableFromJSON(in_schema, {R"([
+    [1.0,   null, true, 1],
+    [null,  null, true, 1]
+])",
+                                             R"([
+    [0.0,   null, false, 2],
+    [null,  null, false, 3],
+    [4.0,   null, null,  null],
+    [3.25,  null, true,  1],
+    [0.125, null, false, 2]
+])",
+                                             R"([
+    [-0.25, null, false, 2],
+    [0.75,  null, true,  null],
+    [null,  null, true,  3]
+])"});
+
+      ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                           GroupByTest(
+                               {
+                                   table->GetColumnByName("argument0"),
+                                   table->GetColumnByName("argument1"),
+                                   table->GetColumnByName("argument2"),
+                               },
+                               {table->GetColumnByName("key")},
+                               {
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                               },
+                               use_threads, use_exec_plan));
+      ValidateOutput(aggregated_and_grouped);
+      SortBy({"key_0"}, &aggregated_and_grouped);
+
+      //      AssertDatumsEqual(ArrayFromJSON(struct_({
+      //                                          field("hash_one", float64()),
+      //                                          field("hash_one", null()),
+      //                                          field("hash_one", boolean()),
+      //                                          field("key_0", int64()),
+      //                                      }),
+      //                                      R"([
+      //          [1.0,  null, true,  1],
+      //          [0.0,  null, false, 2],
+      //          [null, null, false, 3],
+      //          [4.0,  null, null,  null]
+      //        ])"),
+      //                        aggregated_and_grouped,
+      //                        /*verbose=*/true);
+
+      const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+      //  Check the key column
+      AssertDatumsEqual(ArrayFromJSON(int64(), "[1, 2, 3, null]"), 
struct_arr->field(3));
+
+      auto type_col_0 = float64();
+      auto group_one_col_0 =
+          AnyOfScalar(ArrayFromJSON(type_col_0, R"([1.0, null, 3.25])"));
+      auto group_two_col_0 =
+          AnyOfScalar(ArrayFromJSON(type_col_0, R"([0.0, 0.125, -0.25])"));
+      auto group_three_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, 
R"([null])"));
+      auto group_null_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, R"([4.0, 
0.75])"));
+
+      //  Check values individually
+      const auto& col0 = struct_arr->field(0);
+      ASSERT_OK_AND_ASSIGN(const auto g_one, col0->GetScalar(0));
+      EXPECT_THAT(g_one, group_one_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_two, col0->GetScalar(1));
+      EXPECT_THAT(g_two, group_two_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_three, col0->GetScalar(2));
+      EXPECT_THAT(g_three, group_three_col_0);
+      ASSERT_OK_AND_ASSIGN(const auto g_null, col0->GetScalar(3));
+      EXPECT_THAT(g_null, group_null_col_0);
+
+      CountOptions all(CountOptions::ALL);
+      ASSERT_OK_AND_ASSIGN(
+          auto distinct_out,
+          internal::GroupBy(
+              {
+                  table->GetColumnByName("argument0"),
+                  table->GetColumnByName("argument1"),
+                  table->GetColumnByName("argument2"),
+              },
+              {
+                  table->GetColumnByName("key"),
+              },
+              {{"hash_distinct", &all}, {"hash_distinct", &all}, 
{"hash_distinct", &all}},
+              use_threads));
+      ValidateOutput(distinct_out);
+      SortBy({"key_0"}, &distinct_out);
+
+      const auto& struct_arr_distinct = distinct_out.array_as<StructArray>();
+      for (int64_t col = 0; col < struct_arr_distinct->length() - 1; ++col) {

Review comment:
       ```suggestion
         for (int64_t col = 0; col < struct_arr_distinct->num_fields() - 1; ++col) {
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to