This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 606a99f4f2d [SPARK-39046][SQL] Return an empty context string if 
TreeNode.origin is wrongly set
606a99f4f2d is described below

commit 606a99f4f2d91ea30c81285d6c95ee566e80577f
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Thu Apr 28 09:59:17 2022 +0800

    [SPARK-39046][SQL] Return an empty context string if TreeNode.origin is 
wrongly set
    
    ### What changes were proposed in this pull request?
    
    For the query context `TreeNode.origin.context`, this PR proposes to return 
an empty context string if
    * the query text, the start index, or the stop index is missing
    * the start index is less than 0
    * the stop index is larger than or equal to the length of the query text
    * the start index is larger than the stop index
    
    ### Why are the changes needed?
    
    There are downstream projects that depend on Spark. There is no guarantee 
for the correctness of TreeNode.origin. Developers may create a plan/expression 
with an Origin containing a wrong startIndex/stopIndex/sqlText.
    Thus, to avoid errors in calling `String.substring` or showing misleading 
debug information, I suggest returning an empty context string if 
TreeNode.origin is wrongly set. The query context is just for better error 
messages and we should handle it cautiously.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, the context framework is not released yet.
    
    ### How was this patch tested?
    
    UT
    
    Closes #36379 from gengliangwang/safeContext.
    
    Authored-by: Gengliang Wang <gengli...@apache.org>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
    (cherry picked from commit 7fe2759e9f81ec267e92e1c6f8a48f42042db791)
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala | 126 +++++++++++----------
 .../expressions/ArithmeticExpressionSuite.scala    |   6 +
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   |  37 ++++++
 3 files changed, 110 insertions(+), 59 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 00690abf18f..079abd3f2e0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -72,71 +72,79 @@ case class Origin(
    * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
    *                          ^^^^^^^^^^^^^^^
    */
-  lazy val context: String = sqlText.map { text =>
-    val positionContext = if (line.isDefined && startPosition.isDefined) {
-      s"(line ${line.get}, position ${startPosition.get})"
-    } else {
+  lazy val context: String = {
+    // If the query context is missing or incorrect, simply return an empty 
string.
+    if (sqlText.isEmpty || startIndex.isEmpty || stopIndex.isEmpty ||
+      startIndex.get < 0 || stopIndex.get >= sqlText.get.length || 
startIndex.get > stopIndex.get) {
       ""
-    }
-    val objectContext = if (objectType.isDefined && objectName.isDefined) {
-      s" of ${objectType.get} ${objectName.get}"
     } else {
-      ""
-    }
-    val builder = new StringBuilder
-    builder ++= s"\n== SQL$objectContext$positionContext ==\n"
-
-    val start = startIndex.getOrElse(0)
-    val stop = stopIndex.getOrElse(sqlText.get.length - 1)
-    // Ideally we should show all the lines which contains the SQL text 
context of the current node:
-    // [additional text] [current tree node] [additional text]
-    // However, we need to truncate the additional text in case it is too 
long. The following
-    // variable is to define the max length of additional text.
-    val maxExtraContextLength = 32
-    val truncatedText = "..."
-    var lineStartIndex = start
-    // Collect the SQL text within the starting line of current Node.
-    // The text is truncated if it is too long.
-    while(lineStartIndex >= 0 &&
-      start - lineStartIndex <= maxExtraContextLength &&
-      text.charAt(lineStartIndex) != '\n') {
-      lineStartIndex -= 1
-    }
-    val startTruncated = start - lineStartIndex > maxExtraContextLength
-    var currentIndex = lineStartIndex
-    if (startTruncated) {
-      currentIndex -= truncatedText.length
-    }
+      val positionContext = if (line.isDefined && startPosition.isDefined) {
+        s"(line ${line.get}, position ${startPosition.get})"
+      } else {
+        ""
+      }
+      val objectContext = if (objectType.isDefined && objectName.isDefined) {
+        s" of ${objectType.get} ${objectName.get}"
+      } else {
+        ""
+      }
+      val builder = new StringBuilder
+      builder ++= s"\n== SQL$objectContext$positionContext ==\n"
+
+      val text = sqlText.get
+      val start = math.max(startIndex.get, 0)
+      val stop = math.min(stopIndex.getOrElse(text.length - 1), text.length - 
1)
+      // Ideally we should show all the lines which contains the SQL text 
context of the current
+      // node:
+      // [additional text] [current tree node] [additional text]
+      // However, we need to truncate the additional text in case it is too 
long. The following
+      // variable is to define the max length of additional text.
+      val maxExtraContextLength = 32
+      val truncatedText = "..."
+      var lineStartIndex = start
+      // Collect the SQL text within the starting line of current Node.
+      // The text is truncated if it is too long.
+      while (lineStartIndex >= 0 &&
+        start - lineStartIndex <= maxExtraContextLength &&
+        text.charAt(lineStartIndex) != '\n') {
+        lineStartIndex -= 1
+      }
+      val startTruncated = start - lineStartIndex > maxExtraContextLength
+      var currentIndex = lineStartIndex
+      if (startTruncated) {
+        currentIndex -= truncatedText.length
+      }
 
-    var lineStopIndex = stop
-    // Collect the SQL text within the ending line of current Node.
-    // The text is truncated if it is too long.
-    while(lineStopIndex < text.length &&
-      lineStopIndex - stop <= maxExtraContextLength &&
-      text.charAt(lineStopIndex) != '\n') {
-      lineStopIndex += 1
-    }
-    val stopTruncated = lineStopIndex - stop > maxExtraContextLength
-
-    val subText = (if (startTruncated) truncatedText else "") +
-      text.substring(lineStartIndex + 1, lineStopIndex) +
-      (if (stopTruncated) truncatedText else "")
-    val lines = subText.split("\n")
-    lines.foreach { lineText =>
-      builder ++= lineText + "\n"
-      currentIndex += 1
-      (0 until lineText.length).foreach { _ =>
-        if (currentIndex < start) {
-          builder ++= " "
-        } else if (currentIndex >= start && currentIndex <= stop) {
-          builder ++= "^"
-        }
+      var lineStopIndex = stop
+      // Collect the SQL text within the ending line of current Node.
+      // The text is truncated if it is too long.
+      while (lineStopIndex < text.length &&
+        lineStopIndex - stop <= maxExtraContextLength &&
+        text.charAt(lineStopIndex) != '\n') {
+        lineStopIndex += 1
+      }
+      val stopTruncated = lineStopIndex - stop > maxExtraContextLength
+
+      val truncatedSubText = (if (startTruncated) truncatedText else "") +
+        text.substring(lineStartIndex + 1, lineStopIndex) +
+        (if (stopTruncated) truncatedText else "")
+      val lines = truncatedSubText.split("\n")
+      lines.foreach { lineText =>
+        builder ++= lineText + "\n"
         currentIndex += 1
+        (0 until lineText.length).foreach { _ =>
+          if (currentIndex < start) {
+            builder ++= " "
+          } else if (currentIndex >= start && currentIndex <= stop) {
+            builder ++= "^"
+          }
+          currentIndex += 1
+        }
+        builder ++= "\n"
       }
-      builder ++= "\n"
+      builder.result()
     }
-    builder.result()
-  }.getOrElse("")
+  }
 }
 
 /**
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 7624828dd08..b167b5f7a16 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -93,6 +93,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
         line = Some(1),
         startPosition = Some(7),
         startIndex = Some(7),
+        stopIndex = Some(7 + query.length -1),
         sqlText = Some(s"select $query"))
       withOrigin(o) {
         val expr = Add(maxValue, maxValue, failOnError = true)
@@ -177,6 +178,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
         line = Some(1),
         startPosition = Some(7),
         startIndex = Some(7),
+        stopIndex = Some(7 + query.length -1),
         sqlText = Some(s"select $query"))
       withOrigin(o) {
         val expr = Subtract(minValue, maxValue, failOnError = true)
@@ -215,6 +217,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
         line = Some(1),
         startPosition = Some(7),
         startIndex = Some(7),
+        stopIndex = Some(7 + query.length -1),
         sqlText = Some(s"select $query"))
       withOrigin(o) {
         val expr = Multiply(maxValue, maxValue, failOnError = true)
@@ -259,6 +262,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
       line = Some(1),
       startPosition = Some(7),
       startIndex = Some(7),
+      stopIndex = Some(7 + query.length -1),
       sqlText = Some(s"select $query"))
     withOrigin(o) {
       val expr = Divide(Literal(1234.5, DoubleType), Literal(0.0, DoubleType), 
failOnError = true)
@@ -312,6 +316,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
         line = Some(1),
         startPosition = Some(7),
         startIndex = Some(7),
+        stopIndex = Some(7 + query.length -1),
         sqlText = Some(s"select $query"))
       withOrigin(o) {
         val expr =
@@ -362,6 +367,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
           line = Some(1),
           startPosition = Some(7),
           startIndex = Some(7),
+          stopIndex = Some(7 + query.length -1),
           sqlText = Some(s"select $query"))
         withOrigin(o) {
           checkExceptionInExpression[ArithmeticException](expr, EmptyRow, 
query)
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 4bdb3f715bc..ffbc5d89bdb 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -888,4 +888,41 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
 
     assert(origin.context == expected)
   }
+
+  test("SPARK-39046: Return an empty context string if TreeNode.origin is 
wrongly set") {
+    val text = Some("select a + b")
+    // missing stop index
+    val origin1 = Origin(
+      startIndex = Some(7),
+      stopIndex = None,
+      sqlText = text)
+    // missing start index
+    val origin2 = Origin(
+      startIndex = None,
+      stopIndex = Some(11),
+      sqlText = text)
+    // missing text
+    val origin3 = Origin(
+      startIndex = Some(7),
+      stopIndex = Some(11),
+      sqlText = None)
+    // negative start index
+    val origin4 = Origin(
+      startIndex = Some(-1),
+      stopIndex = Some(11),
+      sqlText = text)
+    // stop index >= text.length
+    val origin5 = Origin(
+      startIndex = Some(-1),
+      stopIndex = Some(text.get.length),
+      sqlText = text)
+    // start index > stop index
+    val origin6 = Origin(
+      startIndex = Some(2),
+      stopIndex = Some(1),
+      sqlText = text)
+    Seq(origin1, origin2, origin3, origin4, origin5, origin6).foreach { origin 
=>
+      assert(origin.context.isEmpty)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to