lidavidm commented on a change in pull request #12368:
URL: https://github.com/apache/arrow/pull/12368#discussion_r808483788



##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,294 @@ TEST(GroupBy, Distinct) {
   }
 }
 
+TEST(GroupBy, OneMiscTypes) {
+  auto in_schema = schema({
+      field("floats", float64()),
+      field("nulls", null()),
+      field("booleans", boolean()),
+      field("decimal128", decimal128(3, 2)),
+      field("decimal256", decimal256(3, 2)),
+      field("fixed_binary", fixed_size_binary(3)),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {true, false}) {
+    for (bool use_threads : {true, false}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+      auto table = TableFromJSON(in_schema, {R"([
+    [null, null, true,   null,    null,    null,  1],
+    [1.0,  null, true,   "1.01",  "1.01",  "aaa", 1]
+])",
+                                             R"([
+    [0.0,   null, false, "0.00",  "0.00",  "bac", 2],
+    [null,  null, false, null,    null,    null,  3],
+    [4.0,   null, null,  "4.01",  "4.01",  "234", null],
+    [3.25,  null, true,  "3.25",  "3.25",  "ddd", 1],
+    [0.125, null, false, "0.12",  "0.12",  "bcd", 2]
+])",
+                                             R"([
+    [-0.25, null, false, "-0.25", "-0.25", "bab", 2],
+    [0.75,  null, true,  "0.75",  "0.75",  "123", null],
+    [null,  null, true,  null,    null,    null,  3]
+])"});
+
+      ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                           GroupByTest(
+                               {
+                                   table->GetColumnByName("floats"),
+                                   table->GetColumnByName("nulls"),
+                                   table->GetColumnByName("booleans"),
+                                   table->GetColumnByName("decimal128"),
+                                   table->GetColumnByName("decimal256"),
+                                   table->GetColumnByName("fixed_binary"),
+                               },
+                               {table->GetColumnByName("key")},
+                               {
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                               },
+                               use_threads, use_exec_plan));
+      ValidateOutput(aggregated_and_grouped);
+      SortBy({"key_0"}, &aggregated_and_grouped);
+
+      const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+      //  Check the key column
+      AssertDatumsEqual(ArrayFromJSON(int64(), R"([1, 2, 3, null])"),
+                        struct_arr->field(struct_arr->num_fields() - 1));
+
+      //  Check values individually
+      auto col_0_type = float64();
+      const auto& col_0 = struct_arr->field(0);
+      EXPECT_THAT(col_0->GetScalar(0),
+                  ResultWith(AnyOfJSON(col_0_type, R"([1.0, 3.25])")));
+      EXPECT_THAT(col_0->GetScalar(1),
+                  ResultWith(AnyOfJSON(col_0_type, R"([0.0, 0.125, -0.25])")));
+      EXPECT_THAT(col_0->GetScalar(2), ResultWith(AnyOfJSON(col_0_type, R"([null])")));
+      EXPECT_THAT(col_0->GetScalar(3),
+                  ResultWith(AnyOfJSON(col_0_type, R"([4.0, 0.75])")));
+
+      auto col_1_type = null();
+      const auto& col_1 = struct_arr->field(1);
+      EXPECT_THAT(col_1->GetScalar(0), ResultWith(AnyOfJSON(col_1_type, R"([null])")));
+      EXPECT_THAT(col_1->GetScalar(1), ResultWith(AnyOfJSON(col_1_type, R"([null])")));
+      EXPECT_THAT(col_1->GetScalar(2), ResultWith(AnyOfJSON(col_1_type, R"([null])")));
+      EXPECT_THAT(col_1->GetScalar(3), ResultWith(AnyOfJSON(col_1_type, R"([null])")));
+
+      auto col_2_type = boolean();
+      const auto& col_2 = struct_arr->field(2);
+      EXPECT_THAT(col_2->GetScalar(0), ResultWith(AnyOfJSON(col_2_type, R"([true])")));
+      EXPECT_THAT(col_2->GetScalar(1), ResultWith(AnyOfJSON(col_2_type, R"([false])")));
+      EXPECT_THAT(col_2->GetScalar(2),
+                  ResultWith(AnyOfJSON(col_2_type, R"([true, false])")));
+      EXPECT_THAT(col_2->GetScalar(3),
+                  ResultWith(AnyOfJSON(col_2_type, R"([true, null])")));
+
+      auto col_3_type = decimal128(3, 2);
+      const auto& col_3 = struct_arr->field(3);
+      EXPECT_THAT(col_3->GetScalar(0),
+                  ResultWith(AnyOfJSON(col_3_type, R"(["1.01", "3.25"])")));
+      EXPECT_THAT(col_3->GetScalar(1),
+                  ResultWith(AnyOfJSON(col_3_type, R"(["0.00", "0.12", "-0.25"])")));
+      EXPECT_THAT(col_3->GetScalar(2), ResultWith(AnyOfJSON(col_3_type, R"([null])")));
+      EXPECT_THAT(col_3->GetScalar(3),
+                  ResultWith(AnyOfJSON(col_3_type, R"(["4.01", "0.75"])")));
+
+      auto col_4_type = decimal256(3, 2);
+      const auto& col_4 = struct_arr->field(4);
+      EXPECT_THAT(col_4->GetScalar(0),
+                  ResultWith(AnyOfJSON(col_4_type, R"(["1.01", "3.25"])")));
+      EXPECT_THAT(col_4->GetScalar(1),
+                  ResultWith(AnyOfJSON(col_4_type, R"(["0.00", "0.12", "-0.25"])")));
+      EXPECT_THAT(col_4->GetScalar(2), ResultWith(AnyOfJSON(col_4_type, R"([null])")));
+      EXPECT_THAT(col_4->GetScalar(3),
+                  ResultWith(AnyOfJSON(col_4_type, R"(["4.01", "0.75"])")));
+
+      auto col_5_type = fixed_size_binary(3);
+      const auto& col_5 = struct_arr->field(5);
+      EXPECT_THAT(col_5->GetScalar(0),
+                  ResultWith(AnyOfJSON(col_5_type, R"(["aaa", "ddd"])")));
+      EXPECT_THAT(col_5->GetScalar(1),
+                  ResultWith(AnyOfJSON(col_5_type, R"(["bab", "bcd", "bac"])")));
+      EXPECT_THAT(col_5->GetScalar(2), ResultWith(AnyOfJSON(col_5_type, R"([null])")));
+      EXPECT_THAT(col_5->GetScalar(3),
+                  ResultWith(AnyOfJSON(col_5_type, R"(["123", "234"])")));
+    }
+  }
+}
+
+TEST(GroupBy, OneNumericTypes) {
+  std::vector<std::shared_ptr<DataType>> types;
+  types.insert(types.end(), NumericTypes().begin(), NumericTypes().end());
+  types.insert(types.end(), TemporalTypes().begin(), TemporalTypes().end());
+  types.push_back(month_interval());
+
+  const std::vector<std::string> numeric_table_json = {R"([
+      [null, 1],
+      [1,    1]
+    ])",
+                                                       R"([
+      [0,    2],
+      [null, 3],
+      [3,    4],
+      [5,    4],
+      [4,    null],
+      [3,    1],
+      [0,    2]
+    ])",
+                                                       R"([
+      [0,    2],
+      [1,    null],
+      [null, 3]
+    ])"};
+
+  const std::vector<std::string> temporal_table_json = {R"([
+      [null,      1],
+      [86400000,  1]
+    ])",
+                                                        R"([
+      [0,         2],
+      [null,      3],
+      [259200000, 4],
+      [432000000, 4],
+      [345600000, null],
+      [259200000, 1],
+      [0,         2]
+    ])",
+                                                        R"([
+      [0,         2],
+      [86400000,  null],
+      [null,      3]
+    ])"};
+
+  for (const auto& type : types) {
+    for (bool use_exec_plan : {true, false}) {
+      for (bool use_threads : {true, false}) {
+        SCOPED_TRACE(type->ToString());
+        auto in_schema = schema({field("argument0", type), field("key", int64())});
+        auto table =
+            TableFromJSON(in_schema, (type->name() == "date64") ? temporal_table_json
+                                                                : numeric_table_json);
+        ASSERT_OK_AND_ASSIGN(
+            Datum aggregated_and_grouped,
+            GroupByTest({table->GetColumnByName("argument0")},
+                        {table->GetColumnByName("key")}, {{"hash_one", 
nullptr}},
+                        use_threads, use_exec_plan));
+        ValidateOutput(aggregated_and_grouped);
+        SortBy({"key_0"}, &aggregated_and_grouped);
+
+        const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+        //  Check the key column
+        AssertDatumsEqual(ArrayFromJSON(int64(), R"([1, 2, 3, 4, null])"),
+                          struct_arr->field(struct_arr->num_fields() - 1));
+
+        //  Check values individually
+        const auto& col = struct_arr->field(0);
+        if (type->name() == "date64") {
+          EXPECT_THAT(col->GetScalar(0),
+                      ResultWith(AnyOfJSON(type, R"([86400000, 259200000])")));
+          EXPECT_THAT(col->GetScalar(1), ResultWith(AnyOfJSON(type, R"([0])")));
+          EXPECT_THAT(col->GetScalar(2), ResultWith(AnyOfJSON(type, R"([null])")));
+          EXPECT_THAT(col->GetScalar(3),
+                      ResultWith(AnyOfJSON(type, R"([259200000, 432000000])")));
+          EXPECT_THAT(col->GetScalar(4),
+                      ResultWith(AnyOfJSON(type, R"([345600000, 86400000])")));
+        } else {
+          EXPECT_THAT(col->GetScalar(0), ResultWith(AnyOfJSON(type, R"([1, 3])")));
+          EXPECT_THAT(col->GetScalar(1), ResultWith(AnyOfJSON(type, R"([0])")));
+          EXPECT_THAT(col->GetScalar(2), ResultWith(AnyOfJSON(type, R"([null])")));
+          EXPECT_THAT(col->GetScalar(3), ResultWith(AnyOfJSON(type, R"([3, 5])")));
+          EXPECT_THAT(col->GetScalar(4), ResultWith(AnyOfJSON(type, R"([4, 1])")));
+        }
+      }
+    }
+  }
+}
+
+TEST(GroupBy, OneBinaryTypes) {
+  for (bool use_exec_plan : {true, false}) {
+    for (bool use_threads : {true, false}) {
+      for (const auto& type : BaseBinaryTypes()) {
+        SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+        auto table = TableFromJSON(schema({
+                                       field("argument0", type),
+                                       field("key", int64()),
+                                   }),
+                                   {R"([
+    [null,   1],
+    ["aaaa", 1]
+])",
+                                    R"([
+    ["babcd",2],
+    [null,   3],
+    ["2",    null],
+    ["d",    1],
+    ["bc",   2]
+])",
+                                    R"([
+    ["bcd", 2],
+    ["123", null],
+    [null,  3]
+])"});
+
+        ASSERT_OK_AND_ASSIGN(
+            Datum aggregated_and_grouped,
+            GroupByTest({table->GetColumnByName("argument0")},
+                        {table->GetColumnByName("key")}, {{"hash_one", 
nullptr}},
+                        use_threads, use_exec_plan));
+        ValidateOutput(aggregated_and_grouped);
+        SortBy({"key_0"}, &aggregated_and_grouped);
+
+        const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+        //  Check the key column
+        AssertDatumsEqual(ArrayFromJSON(int64(), R"([1, 2, 3, null])"),
+                          struct_arr->field(struct_arr->num_fields() - 1));
+
+        const auto& col = struct_arr->field(0);
+        EXPECT_THAT(col->GetScalar(0), ResultWith(AnyOfJSON(type, R"(["aaaa", "d"])")));
+        EXPECT_THAT(col->GetScalar(1),
+                    ResultWith(AnyOfJSON(type, R"(["bcd", "bc", "babcd"])")));
+        EXPECT_THAT(col->GetScalar(2), ResultWith(AnyOfJSON(type, R"([null])")));
+        EXPECT_THAT(col->GetScalar(3), ResultWith(AnyOfJSON(type, R"(["2", "123"])")));
+      }
+    }
+  }
+}
+
+TEST(GroupBy, OneScalar) {

Review comment:
       Hmm, this needs to be something like `ValueDescr::Scalar(type)` so the argument is declared as a scalar input rather than an array column. See, for instance, CountScalar:
       https://github.com/apache/arrow/blob/7236f48d7c534802ebd84daa709aeaba070d6780/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc#L846-L884
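
       For reference, the scalar tests there build their input batches directly with `ValueDescr::Scalar(...)` descriptors instead of table columns. A minimal sketch along those lines, assuming the `ExecBatchFromJSON`/`BatchesWithSchema` helpers that CountScalar uses (the types and values here are illustrative, not the final test body):

```cpp
BatchesWithSchema input;
input.batches = {
    // "argument" is a scalar broadcast across each batch; "key" is an ordinary array column
    ExecBatchFromJSON({ValueDescr::Scalar(int32()), ValueDescr::Array(int64())},
                      "[[1, 1], [1, 1], [1, 2], [1, 3]]"),
    ExecBatchFromJSON({ValueDescr::Scalar(int32()), ValueDescr::Array(int64())},
                      "[[null, 1], [null, 2]]"),
};
input.schema = schema({field("argument", int32()), field("key", int64())});
```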

##########
File path: cpp/src/arrow/testing/matchers.h
##########
@@ -61,6 +61,65 @@ class PointeesEqualMatcher {
 // Useful in conjunction with other googletest matchers.
 inline PointeesEqualMatcher PointeesEqual() { return {}; }
 
+class AnyOfJSONMatcher {

Review comment:
       Thanks! I think it would technically compose a little better if it took the array instead of type + string, but that's not a big deal (we can see if it's useful elsewhere first).
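
       Concretely, the array-taking variant (call it `AnyOfArray`; this is a hypothetical name, not an existing helper) would let call sites compose with `ArrayFromJSON` or any other array source:

```cpp
// Today: the matcher parses type + JSON itself
EXPECT_THAT(col->GetScalar(0), ResultWith(AnyOfJSON(type, R"([1, 3])")));

// Hypothetical AnyOfArray: the caller supplies the candidate array directly,
// so the matcher no longer needs to know about JSON at all
EXPECT_THAT(col->GetScalar(0),
            ResultWith(AnyOfArray(ArrayFromJSON(type, R"([1, 3])"))));
```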



