This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 96bf373e9002 [SPARK-46430][CONNECT][TESTS] Add test for `ProtoUtils.abbreviate`
96bf373e9002 is described below

commit 96bf373e90026dac8ef5020fe3032107c11df73f
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sun Dec 17 13:27:09 2023 -0800

    [SPARK-46430][CONNECT][TESTS] Add test for `ProtoUtils.abbreviate`
    
    ### What changes were proposed in this pull request?
    Add test for `ProtoUtils.abbreviate`
    
    ### Why are the changes needed?
    `ProtoUtils.abbreviate` is not tested. Since we are going to improve this
    functionality, we should first protect its current behavior with tests.
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    Added unit tests (`AbbreviateSuite`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44383 from zhengruifeng/proto_utils_test.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/connect/messages/AbbreviateSuite.scala     | 121 +++++++++++++++++++++
 1 file changed, 121 insertions(+)

diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
new file mode 100644
index 000000000000..9a712e9b7bf1
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.messages
+
+import scala.jdk.CollectionConverters._
+
+import com.google.protobuf.ByteString
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.common.{ProtoDataTypes, ProtoUtils}
+
+/**
+ * Tests for `ProtoUtils.abbreviate`. They pin how it truncates long string and bytes fields,
+ * both at the top level of a message and inside nested messages, so that later changes to the
+ * abbreviation logic cannot silently alter this behavior.
+ */
+class AbbreviateSuite extends SparkFunSuite {
+
+  // Length of the SQL text used by the string tests. Thresholds below it must truncate the
+  // text; thresholds at or above it must leave the text unchanged.
+  private val queryLength = 1024
+
+  // The all-"x" SQL text that the string tests abbreviate.
+  private val originalQuery = "x" * queryLength
+
+  // String-size thresholds passed to `ProtoUtils.abbreviate` by the string tests.
+  private val stringThresholds = Seq(1, 16, 256, 512, 1024, 2048)
+
+  // NOTE(review): the number of leading bytes that `ProtoUtils.abbreviate` keeps for a bytes
+  // field when called with its default arguments. It is not a parameter of `abbreviate`, so it
+  // must be kept in sync by hand with the limit inside `ProtoUtils`.
+  private val maxBytesSize = 8
+
+  /**
+   * Asserts that `query` is the correct abbreviation of `originalQuery` under `threshold`.
+   *
+   * When `threshold` is below the original length, the first `threshold` characters must be
+   * kept and immediately followed by a "[truncated" marker. Otherwise the text must be
+   * returned unchanged.
+   */
+  private def assertQueryAbbreviated(query: String, threshold: Int): Unit = {
+    if (threshold < queryLength) {
+      assert(query.indexOf("[truncated") === threshold)
+      assert(query.take(threshold) === originalQuery.take(threshold))
+    } else {
+      assert(query.indexOf("[truncated") === -1)
+      assert(query === originalQuery)
+    }
+  }
+
+  test("truncate string: simple SQL text") {
+    val message = proto.SQL.newBuilder().setQuery(originalQuery).build()
+
+    stringThresholds.foreach { threshold =>
+      ProtoUtils.abbreviate(message, threshold) match {
+        case truncatedSQL: proto.SQL =>
+          assertQueryAbbreviated(truncatedSQL.getQuery, threshold)
+        case other =>
+          fail(s"Expected a proto.SQL but got ${other.getClass.getName}")
+      }
+    }
+  }
+
+  test("truncate string: nested message") {
+    // Build Limit(Drop(Sql(query))) so the long string sits two messages below the root.
+    val sql = proto.Relation
+      .newBuilder()
+      .setSql(
+        proto.SQL
+          .newBuilder()
+          .setQuery(originalQuery)
+          .build())
+      .build()
+    val drop = proto.Relation
+      .newBuilder()
+      .setDrop(
+        proto.Drop
+          .newBuilder()
+          .setInput(sql)
+          .addAllColumnNames(Seq("a", "b").asJava)
+          .build())
+      .build()
+    val limit = proto.Relation
+      .newBuilder()
+      .setLimit(
+        proto.Limit
+          .newBuilder()
+          .setInput(drop)
+          .setLimit(100)
+          .build())
+      .build()
+
+    stringThresholds.foreach { threshold =>
+      ProtoUtils.abbreviate(limit, threshold) match {
+        case truncated: proto.Relation =>
+          // The other fields at each nesting level must be preserved.
+          val truncatedLimit = truncated.getLimit
+          assert(truncatedLimit.getLimit === 100)
+
+          val truncatedDrop = truncatedLimit.getInput.getDrop
+          assert(truncatedDrop.getColumnNamesList.asScala.toSeq === Seq("a", "b"))
+
+          // The long string two levels down must be abbreviated like a top-level one.
+          assertQueryAbbreviated(truncatedDrop.getInput.getSql.getQuery, threshold)
+        case other =>
+          fail(s"Expected a proto.Relation but got ${other.getClass.getName}")
+      }
+    }
+  }
+
+  test("truncate bytes: simple python udf") {
+    Seq(1, 8, 16, 64, 256).foreach { numBytes =>
+      val bytes = Array.ofDim[Byte](numBytes)
+      val message = proto.PythonUDF
+        .newBuilder()
+        .setEvalType(1)
+        .setOutputType(ProtoDataTypes.BinaryType)
+        .setCommand(ByteString.copyFrom(bytes))
+        .setPythonVer("3.12")
+        .build()
+
+      ProtoUtils.abbreviate(message) match {
+        case truncatedUDF: proto.PythonUDF =>
+          // Fields other than the bytes command must be preserved.
+          assert(truncatedUDF.getEvalType === 1)
+          assert(truncatedUDF.getOutputType === ProtoDataTypes.BinaryType)
+          assert(truncatedUDF.getPythonVer === "3.12")
+
+          val command = truncatedUDF.getCommand
+          if (numBytes <= maxBytesSize) {
+            // Short payloads are returned byte-for-byte unchanged.
+            assert(command === ByteString.copyFrom(bytes))
+          } else {
+            // Long payloads keep only their leading bytes, followed by a truncation marker.
+            // Checking only the size would also accept empty or corrupted output.
+            assert(command.size() !== numBytes)
+            val expectedPrefix = ByteString.copyFrom(bytes, 0, maxBytesSize)
+            assert(command.substring(0, maxBytesSize) === expectedPrefix)
+            assert(command.substring(maxBytesSize).toStringUtf8.startsWith("[truncated"))
+          }
+        case other =>
+          fail(s"Expected a proto.PythonUDF but got ${other.getClass.getName}")
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to