This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 998bce144 [GLUTEN-4830][VL] Support MapType substrait signature (#4833)
998bce144 is described below

commit 998bce14495ed96ca6f811e323db168cc92906a1
Author: WangGuangxin <wangguangxin...@bytedance.com>
AuthorDate: Thu Mar 21 22:38:24 2024 +0800

    [GLUTEN-4830][VL] Support MapType substrait signature (#4833)
---
 .../io/glutenproject/execution/TestOperator.scala  | 53 +++++++++++++++++++
 cpp/velox/substrait/VeloxSubstraitSignature.cc     | 59 ++++++++++++++--------
 cpp/velox/tests/VeloxSubstraitSignatureTest.cc     | 16 ++++++
 .../glutenproject/expression/ConverterUtils.scala  |  9 +++-
 4 files changed, 115 insertions(+), 22 deletions(-)

diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
index 239bec57a..c81a60430 100644
--- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
+++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
@@ -1258,4 +1258,57 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
       }
     }
   }
+
+  test("Support Map type signature") {
+    // map<str,str>: collect_list over a flat map column should use HashAggregateExecTransformer
+    withTempView("t1") {
+      Seq[(Int, Map[String, String])]((1, Map("byte1" -> "aaa")), (2, 
Map("byte2" -> "bbbb")))
+        .toDF("c1", "map_c2")
+        .createTempView("t1")
+      runQueryAndCompare("""
+                           |SELECT c1, collect_list(map_c2) FROM t1 group by 
c1;
+                           |""".stripMargin) {
+        checkOperatorMatch[HashAggregateExecTransformer]
+      }
+    }
+    // map<str,map<str,str>>: the map value is itself a map (nested signature)
+    withTempView("t2") {
+      Seq[(Int, Map[String, Map[String, String]])](
+        (1, Map("byte1" -> Map("test1" -> "aaaa"))),
+        (2, Map("byte2" -> Map("test1" -> "bbbb"))))
+        .toDF("c1", "map_c2")
+        .createTempView("t2")
+      runQueryAndCompare("""
+                           |SELECT c1, collect_list(map_c2) FROM t2 group by 
c1;
+                           |""".stripMargin) {
+        checkOperatorMatch[HashAggregateExecTransformer]
+      }
+    }
+    // map<map<str,str>,map<str,str>>: both the map key and the map value are maps
+    withTempView("t3") {
+      Seq[(Int, Map[Map[String, String], Map[String, String]])](
+        (1, Map(Map("byte1" -> "aaaa") -> Map("test1" -> "aaaa"))),
+        (2, Map(Map("byte2" -> "bbbb") -> Map("test1" -> "bbbb"))))
+        .toDF("c1", "map_c2")
+        .createTempView("t3")
+      runQueryAndCompare("""
+                           |SELECT collect_list(map_c2) FROM t3 group by c1;
+                           |""".stripMargin) {
+        checkOperatorMatch[HashAggregateExecTransformer]
+      }
+    }
+    // map<str,list<str>>: the map value is an array (list inside map signature)
+    withTempView("t4") {
+      Seq[(Int, Map[String, Array[String]])](
+        (1, Map("test1" -> Array("test1", "test2"))),
+        (2, Map("test2" -> Array("test1", "test2"))))
+        .toDF("c1", "map_c2")
+        .createTempView("t4")
+      runQueryAndCompare("""
+                           |SELECT collect_list(map_c2) FROM t4 group by c1;
+                           |""".stripMargin) {
+        checkOperatorMatch[HashAggregateExecTransformer]
+      }
+    }
+  }
 }
diff --git a/cpp/velox/substrait/VeloxSubstraitSignature.cc b/cpp/velox/substrait/VeloxSubstraitSignature.cc
index 2d2432281..34e0df6de 100644
--- a/cpp/velox/substrait/VeloxSubstraitSignature.cc
+++ b/cpp/velox/substrait/VeloxSubstraitSignature.cc
@@ -121,33 +121,25 @@ TypePtr VeloxSubstraitSignature::fromSubstraitSignature(const std::string& signa
     return str.size() >= prefix.size() && str.substr(0, prefix.size()) == 
prefix;
   };
 
-  if (startWith(signature, "dec")) {
-    // Decimal type name is in the format of dec<precision,scale>.
-    auto precisionStart = signature.find_first_of('<');
-    auto tokenIndex = signature.find_first_of(',');
-    auto scaleEnd = signature.find_first_of('>');
-    auto precision = stoi(signature.substr(precisionStart + 1, (tokenIndex - 
precisionStart - 1)));
-    auto scale = stoi(signature.substr(tokenIndex + 1, (scaleEnd - tokenIndex 
- 1)));
-    return DECIMAL(precision, scale);
-  }
-
-  if (startWith(signature, "struct")) {
-    // Struct type name is in the format of struct<T1,T2,...,Tn>.
-    auto structStart = signature.find_first_of('<');
-    auto structEnd = signature.find_last_of('>');
+  auto parseNestedTypeSignature = [&](const std::string& signature) -> 
std::vector<TypePtr> {
+    auto start = signature.find_first_of('<');
+    auto end = signature.find_last_of('>');
     VELOX_CHECK(
-        structEnd - structStart > 1, "Native validation failed due to: more 
information is needed to create RowType");
-    std::string childrenTypes = signature.substr(structStart + 1, structEnd - 
structStart - 1);
+        end - start > 1,
+        "Native validation failed due to: more information is needed to create 
nested type for {}",
+        signature);
+
+    std::string childrenTypes = signature.substr(start + 1, end - start - 1);
 
     // Split the types with delimiter.
     std::string delimiter = ",";
     std::size_t pos;
     std::vector<TypePtr> types;
-    std::vector<std::string> names;
     while ((pos = childrenTypes.find(delimiter)) != std::string::npos) {
       auto typeStr = childrenTypes.substr(0, pos);
       std::size_t endPos = pos;
-      if (startWith(typeStr, "dec") || startWith(typeStr, "struct")) {
+      if (startWith(typeStr, "dec") || startWith(typeStr, "struct") || 
startWith(typeStr, "map") ||
+          startWith(typeStr, "list")) {
         endPos = childrenTypes.find(">") + 1;
         if (endPos > pos) {
           typeStr += childrenTypes.substr(pos, endPos - pos);
@@ -159,16 +151,43 @@ TypePtr VeloxSubstraitSignature::fromSubstraitSignature(const std::string& signa
         }
       }
       types.emplace_back(fromSubstraitSignature(typeStr));
-      names.emplace_back("");
       childrenTypes.erase(0, endPos + delimiter.length());
     }
     if (childrenTypes.size() > 0 && !startWith(childrenTypes, ">")) {
       types.emplace_back(fromSubstraitSignature(childrenTypes));
-      names.emplace_back("");
+    }
+    return types;
+  };
+
+  if (startWith(signature, "dec")) {
+    // Decimal type name is in the format of dec<precision,scale>.
+    auto precisionStart = signature.find_first_of('<');
+    auto tokenIndex = signature.find_first_of(',');
+    auto scaleEnd = signature.find_first_of('>');
+    auto precision = stoi(signature.substr(precisionStart + 1, (tokenIndex - 
precisionStart - 1)));
+    auto scale = stoi(signature.substr(tokenIndex + 1, (scaleEnd - tokenIndex 
- 1)));
+    return DECIMAL(precision, scale);
+  }
+
+  if (startWith(signature, "struct")) {
+    // Struct type name is in the format of struct<T1,T2,...,Tn>.
+    auto types = parseNestedTypeSignature(signature);
+    std::vector<std::string> names(types.size());
+    for (int i = 0; i < types.size(); i++) {
+      names[i] = "";
     }
     return std::make_shared<RowType>(std::move(names), std::move(types));
   }
 
+  if (startWith(signature, "map")) {
+    // Map type name is in the format of map<T1,T2>.
+    auto types = parseNestedTypeSignature(signature);
+    if (types.size() != 2) {
+      VELOX_UNSUPPORTED("Substrait type signature conversion to Velox type not 
supported for {}.", signature);
+    }
+    return MAP(std::move(types)[0], std::move(types)[1]);
+  }
+
   if (startWith(signature, "list")) {
     auto listStart = signature.find_first_of('<');
     auto listEnd = signature.find_last_of('>');
diff --git a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc
index bbc1165ad..d6db661f7 100644
--- a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc
+++ b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc
@@ -139,6 +139,22 @@ TEST_F(VeloxSubstraitSignatureTest, fromSubstraitSignature) {
   type = fromSubstraitSignature("struct<struct<struct<i8,dec<19,2>>>>");
   ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(1)->kind(), 
TypeKind::HUGEINT);
   ASSERT_ANY_THROW(fromSubstraitSignature("other")->kind());
+
+  // Map type test.
+  type = fromSubstraitSignature("map<bool,list<map<str,i32>>>");
+  ASSERT_EQ(type->kind(), TypeKind::MAP);
+  ASSERT_EQ(type->childAt(0)->kind(), TypeKind::BOOLEAN);
+  ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ARRAY);
+  ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::MAP);
+  type = fromSubstraitSignature("struct<map<bool,i8>,list<map<str,i32>>>");
+  ASSERT_EQ(type->kind(), TypeKind::ROW);
+  ASSERT_EQ(type->childAt(0)->kind(), TypeKind::MAP);
+  ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::BOOLEAN);
+  ASSERT_EQ(type->childAt(0)->childAt(1)->kind(), TypeKind::TINYINT);
+  ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ARRAY);
+  ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::MAP);
+  ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(0)->kind(), 
TypeKind::VARCHAR);
+  ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(1)->kind(), 
TypeKind::INTEGER);
 }
 
 } // namespace gluten
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala
index 95f6b2861..3400f3337 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala
@@ -379,8 +379,13 @@ object ConverterUtils extends Logging {
           })
         sigName = sigName.concat(">")
         sigName
-      case MapType(_, _, _) =>
-        "map"
+      case MapType(keyType, valueType, _) =>
+        var sigName = "map<"
+        sigName = sigName.concat(getTypeSigName(keyType))
+        sigName = sigName.concat(",")
+        sigName = sigName.concat(getTypeSigName(valueType))
+        sigName = sigName.concat(">")
+        sigName
       case CharType(_) =>
         "fchar"
       case NullType =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org

Reply via email to