This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 71b76dc3e66a [SPARK-43117][CONNECT] Make `ProtoUtils.abbreviate` support repeated fields 71b76dc3e66a is described below commit 71b76dc3e66a9fdd99f961876c503776e8085325 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Feb 8 15:23:34 2024 +0800 [SPARK-43117][CONNECT] Make `ProtoUtils.abbreviate` support repeated fields ### What changes were proposed in this pull request? Make `ProtoUtils.abbreviate` support repeated fields ### Why are the changes needed? The existing implementation does not work for repeated fields (strings/messages). We don't have `repeated bytes` in Spark Connect for now, so that case is left alone. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? added UTs ### Was this patch authored or co-authored using generative AI tooling? no Closes #45056 from zhengruifeng/proto_abbr_repeat. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../spark/sql/connect/common/ProtoUtils.scala | 34 ++++++++--- .../sql/connect/messages/AbbreviateSuite.scala | 71 ++++++++++++++++++++++ 2 files changed, 98 insertions(+), 7 deletions(-) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala index 2f31b63acf87..66146698b701 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -42,7 +42,17 @@ private[connect] object ProtoUtils { val size = string.length val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE) if (size > threshold) { - builder.setField(field, createString(string.take(threshold), size)) + builder.setField(field, truncateString(string, threshold)) + } + + case (field: FieldDescriptor, strings: 
java.lang.Iterable[_]) + if field.getJavaType == FieldDescriptor.JavaType.STRING && field.isRepeated + && strings != null => + val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE) + strings.iterator().asScala.zipWithIndex.foreach { + case (string: String, i) if string != null && string.length > threshold => + builder.setRepeatedField(field, i, truncateString(string, threshold)) + case _ => } case (field: FieldDescriptor, byteString: ByteString) @@ -69,23 +79,33 @@ private[connect] object ProtoUtils { .concat(createTruncatedByteString(size))) } - // TODO(SPARK-43117): should also support 1, repeated msg; 2, map<xxx, msg> + // TODO(SPARK-46988): should support map<xxx, msg> case (field: FieldDescriptor, msg: Message) - if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != null => + if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && !field.isRepeated + && msg != null => builder.setField(field, abbreviate(msg, thresholds)) + case (field: FieldDescriptor, msgs: java.lang.Iterable[_]) + if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && field.isRepeated + && msgs != null => + msgs.iterator().asScala.zipWithIndex.foreach { + case (msg: Message, i) if msg != null => + builder.setRepeatedField(field, i, abbreviate(msg, thresholds)) + case _ => + } + case _ => } builder.build() } - private def createTruncatedByteString(size: Int): ByteString = { - ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]") + private def truncateString(string: String, threshold: Int): String = { + s"${string.take(threshold)}[truncated(size=${format.format(string.length)})]" } - private def createString(prefix: String, size: Int): String = { - s"$prefix[truncated(size=${format.format(size)})]" + private def createTruncatedByteString(size: Int): ByteString = { + ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]") } // Because Spark Connect operation tags are also set as SparkContext Job tags, they cannot contain diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala index 6dca2c1e8907..0b7104f6c67e 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala @@ -92,6 +92,77 @@ class AbbreviateSuite extends SparkFunSuite { } } + test("truncate repeated strings") { + val sql = proto.Relation + .newBuilder() + .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM T")) + .build() + val names = Seq.range(0, 10).map(i => i.toString * 1024) + val drop = proto.Drop.newBuilder().setInput(sql).addAllColumnNames(names.asJava).build() + + Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold => + val truncated = ProtoUtils.abbreviate(drop, threshold) + assert(drop.isInstanceOf[proto.Drop]) + + val truncatedNames = truncated.asInstanceOf[proto.Drop].getColumnNamesList.asScala.toSeq + assert(truncatedNames.length === 10) + + if (threshold < 1024) { + truncatedNames.foreach { truncatedName => + assert(truncatedName.indexOf("[truncated") === threshold) + } + } else { + truncatedNames.foreach { truncatedName => + assert(truncatedName.indexOf("[truncated") === -1) + assert(truncatedName.length === 1024) + } + } + + } + } + + test("truncate repeated messages") { + val sql = proto.Relation + .newBuilder() + .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM T")) + .build() + + val cols = Seq.range(0, 10).map { i => + proto.Expression + .newBuilder() + .setUnresolvedAttribute( + proto.Expression.UnresolvedAttribute + .newBuilder() + .setUnparsedIdentifier(i.toString * 1024) + .build()) + .build() + } + val drop = proto.Drop.newBuilder().setInput(sql).addAllColumns(cols.asJava).build() + + Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold => + val truncated = ProtoUtils.abbreviate(drop, 
threshold) + assert(drop.isInstanceOf[proto.Drop]) + + val truncatedCols = truncated.asInstanceOf[proto.Drop].getColumnsList.asScala.toSeq + assert(truncatedCols.length === 10) + + if (threshold < 1024) { + truncatedCols.foreach { truncatedCol => + assert(truncatedCol.isInstanceOf[proto.Expression]) + val truncatedName = truncatedCol.getUnresolvedAttribute.getUnparsedIdentifier + assert(truncatedName.indexOf("[truncated") === threshold) + } + } else { + truncatedCols.foreach { truncatedCol => + assert(truncatedCol.isInstanceOf[proto.Expression]) + val truncatedName = truncatedCol.getUnresolvedAttribute.getUnparsedIdentifier + assert(truncatedName.indexOf("[truncated") === -1) + assert(truncatedName.length === 1024) + } + } + } + } + test("truncate bytes: simple python udf") { Seq(1, 8, 16, 64, 256).foreach { numBytes => val bytes = Array.ofDim[Byte](numBytes) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org