This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 71b76dc3e66a [SPARK-43117][CONNECT] Make `ProtoUtils.abbreviate` 
support repeated fields
71b76dc3e66a is described below

commit 71b76dc3e66a9fdd99f961876c503776e8085325
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Feb 8 15:23:34 2024 +0800

    [SPARK-43117][CONNECT] Make `ProtoUtils.abbreviate` support repeated fields
    
    ### What changes were proposed in this pull request?
    Make `ProtoUtils.abbreviate` support repeated fields
    
    ### Why are the changes needed?
    The existing implementation does not work for repeated fields (strings/messages).
    
    We don't have `repeated bytes` in Spark Connect for now, so we leave it alone.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    added UTs
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #45056 from zhengruifeng/proto_abbr_repeat.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../spark/sql/connect/common/ProtoUtils.scala      | 34 ++++++++---
 .../sql/connect/messages/AbbreviateSuite.scala     | 71 ++++++++++++++++++++++
 2 files changed, 98 insertions(+), 7 deletions(-)

diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
index 2f31b63acf87..66146698b701 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
@@ -42,7 +42,17 @@ private[connect] object ProtoUtils {
         val size = string.length
         val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE)
         if (size > threshold) {
-          builder.setField(field, createString(string.take(threshold), size))
+          builder.setField(field, truncateString(string, threshold))
+        }
+
+      case (field: FieldDescriptor, strings: java.lang.Iterable[_])
+          if field.getJavaType == FieldDescriptor.JavaType.STRING && 
field.isRepeated
+            && strings != null =>
+        val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE)
+        strings.iterator().asScala.zipWithIndex.foreach {
+          case (string: String, i) if string != null && string.length > 
threshold =>
+            builder.setRepeatedField(field, i, truncateString(string, 
threshold))
+          case _ =>
         }
 
       case (field: FieldDescriptor, byteString: ByteString)
@@ -69,23 +79,33 @@ private[connect] object ProtoUtils {
               .concat(createTruncatedByteString(size)))
         }
 
-      // TODO(SPARK-43117): should also support 1, repeated msg; 2, map<xxx, 
msg>
+      // TODO(SPARK-46988): should support map<xxx, msg>
       case (field: FieldDescriptor, msg: Message)
-          if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != 
null =>
+          if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && 
!field.isRepeated
+            && msg != null =>
         builder.setField(field, abbreviate(msg, thresholds))
 
+      case (field: FieldDescriptor, msgs: java.lang.Iterable[_])
+          if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && 
field.isRepeated
+            && msgs != null =>
+        msgs.iterator().asScala.zipWithIndex.foreach {
+          case (msg: Message, i) if msg != null =>
+            builder.setRepeatedField(field, i, abbreviate(msg, thresholds))
+          case _ =>
+        }
+
       case _ =>
     }
 
     builder.build()
   }
 
-  private def createTruncatedByteString(size: Int): ByteString = {
-    ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]")
+  private def truncateString(string: String, threshold: Int): String = {
+    
s"${string.take(threshold)}[truncated(size=${format.format(string.length)})]"
   }
 
-  private def createString(prefix: String, size: Int): String = {
-    s"$prefix[truncated(size=${format.format(size)})]"
+  private def createTruncatedByteString(size: Int): ByteString = {
+    ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]")
   }
 
   // Because Spark Connect operation tags are also set as SparkContext Job 
tags, they cannot contain
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
index 6dca2c1e8907..0b7104f6c67e 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
@@ -92,6 +92,77 @@ class AbbreviateSuite extends SparkFunSuite {
     }
   }
 
+  test("truncate repeated strings") {
+    val sql = proto.Relation
+      .newBuilder()
+      .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM T"))
+      .build()
+    val names = Seq.range(0, 10).map(i => i.toString * 1024)
+    val drop = 
proto.Drop.newBuilder().setInput(sql).addAllColumnNames(names.asJava).build()
+
+    Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold =>
+      val truncated = ProtoUtils.abbreviate(drop, threshold)
+      assert(drop.isInstanceOf[proto.Drop])
+
+      val truncatedNames = 
truncated.asInstanceOf[proto.Drop].getColumnNamesList.asScala.toSeq
+      assert(truncatedNames.length === 10)
+
+      if (threshold < 1024) {
+        truncatedNames.foreach { truncatedName =>
+          assert(truncatedName.indexOf("[truncated") === threshold)
+        }
+      } else {
+        truncatedNames.foreach { truncatedName =>
+          assert(truncatedName.indexOf("[truncated") === -1)
+          assert(truncatedName.length === 1024)
+        }
+      }
+
+    }
+  }
+
+  test("truncate repeated messages") {
+    val sql = proto.Relation
+      .newBuilder()
+      .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM T"))
+      .build()
+
+    val cols = Seq.range(0, 10).map { i =>
+      proto.Expression
+        .newBuilder()
+        .setUnresolvedAttribute(
+          proto.Expression.UnresolvedAttribute
+            .newBuilder()
+            .setUnparsedIdentifier(i.toString * 1024)
+            .build())
+        .build()
+    }
+    val drop = 
proto.Drop.newBuilder().setInput(sql).addAllColumns(cols.asJava).build()
+
+    Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold =>
+      val truncated = ProtoUtils.abbreviate(drop, threshold)
+      assert(drop.isInstanceOf[proto.Drop])
+
+      val truncatedCols = 
truncated.asInstanceOf[proto.Drop].getColumnsList.asScala.toSeq
+      assert(truncatedCols.length === 10)
+
+      if (threshold < 1024) {
+        truncatedCols.foreach { truncatedCol =>
+          assert(truncatedCol.isInstanceOf[proto.Expression])
+          val truncatedName = 
truncatedCol.getUnresolvedAttribute.getUnparsedIdentifier
+          assert(truncatedName.indexOf("[truncated") === threshold)
+        }
+      } else {
+        truncatedCols.foreach { truncatedCol =>
+          assert(truncatedCol.isInstanceOf[proto.Expression])
+          val truncatedName = 
truncatedCol.getUnresolvedAttribute.getUnparsedIdentifier
+          assert(truncatedName.indexOf("[truncated") === -1)
+          assert(truncatedName.length === 1024)
+        }
+      }
+    }
+  }
+
   test("truncate bytes: simple python udf") {
     Seq(1, 8, 16, 64, 256).foreach { numBytes =>
       val bytes = Array.ofDim[Byte](numBytes)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to