This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7fe2759e9f8 [SPARK-39046][SQL] Return an empty context string if TreeNode.origin is wrongly set 7fe2759e9f8 is described below commit 7fe2759e9f81ec267e92e1c6f8a48f42042db791 Author: Gengliang Wang <gengli...@apache.org> AuthorDate: Thu Apr 28 09:59:17 2022 +0800 [SPARK-39046][SQL] Return an empty context string if TreeNode.origin is wrongly set ### What changes were proposed in this pull request? For the query context `TreeNode.origin.context`, this PR proposes to return an empty context string if * the query text, the start index, or the stop index is missing * the start index is less than 0 * the stop index is larger than or equal to the length of the query text * the start index is larger than the stop index ### Why are the changes needed? There are downstream projects that depend on Spark. There is no guarantee of the correctness of TreeNode.origin. Developers may create a plan/expression with an Origin containing wrong startIndex/stopIndex/sqlText. Thus, to avoid errors in calling `String.substring` or showing misleading debug information, I suggest returning an empty context string if TreeNode.origin is wrongly set. The query context is just for better error messages and we should handle it cautiously. ### Does this PR introduce _any_ user-facing change? No, the context framework is not released yet. ### How was this patch tested? UT Closes #36379 from gengliangwang/safeContext. 
Authored-by: Gengliang Wang <gengli...@apache.org> Signed-off-by: Gengliang Wang <gengli...@apache.org> --- .../apache/spark/sql/catalyst/trees/TreeNode.scala | 126 +++++++++++---------- .../expressions/ArithmeticExpressionSuite.scala | 6 + .../spark/sql/catalyst/trees/TreeNodeSuite.scala | 37 ++++++ 3 files changed, 110 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index eed59b9e1bf..0714898e19d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -72,71 +72,79 @@ case class Origin( * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i * ^^^^^^^^^^^^^^^ */ - lazy val context: String = sqlText.map { text => - val positionContext = if (line.isDefined && startPosition.isDefined) { - s"(line ${line.get}, position ${startPosition.get})" - } else { + lazy val context: String = { + // If the query context is missing or incorrect, simply return an empty string. + if (sqlText.isEmpty || startIndex.isEmpty || stopIndex.isEmpty || + startIndex.get < 0 || stopIndex.get >= sqlText.get.length || startIndex.get > stopIndex.get) { "" - } - val objectContext = if (objectType.isDefined && objectName.isDefined) { - s" of ${objectType.get} ${objectName.get}" } else { - "" - } - val builder = new StringBuilder - builder ++= s"\n== SQL$objectContext$positionContext ==\n" - - val start = startIndex.getOrElse(0) - val stop = stopIndex.getOrElse(sqlText.get.length - 1) - // Ideally we should show all the lines which contains the SQL text context of the current node: - // [additional text] [current tree node] [additional text] - // However, we need to truncate the additional text in case it is too long. The following - // variable is to define the max length of additional text. 
- val maxExtraContextLength = 32 - val truncatedText = "..." - var lineStartIndex = start - // Collect the SQL text within the starting line of current Node. - // The text is truncated if it is too long. - while(lineStartIndex >= 0 && - start - lineStartIndex <= maxExtraContextLength && - text.charAt(lineStartIndex) != '\n') { - lineStartIndex -= 1 - } - val startTruncated = start - lineStartIndex > maxExtraContextLength - var currentIndex = lineStartIndex - if (startTruncated) { - currentIndex -= truncatedText.length - } + val positionContext = if (line.isDefined && startPosition.isDefined) { + s"(line ${line.get}, position ${startPosition.get})" + } else { + "" + } + val objectContext = if (objectType.isDefined && objectName.isDefined) { + s" of ${objectType.get} ${objectName.get}" + } else { + "" + } + val builder = new StringBuilder + builder ++= s"\n== SQL$objectContext$positionContext ==\n" + + val text = sqlText.get + val start = math.max(startIndex.get, 0) + val stop = math.min(stopIndex.getOrElse(text.length - 1), text.length - 1) + // Ideally we should show all the lines which contains the SQL text context of the current + // node: + // [additional text] [current tree node] [additional text] + // However, we need to truncate the additional text in case it is too long. The following + // variable is to define the max length of additional text. + val maxExtraContextLength = 32 + val truncatedText = "..." + var lineStartIndex = start + // Collect the SQL text within the starting line of current Node. + // The text is truncated if it is too long. + while (lineStartIndex >= 0 && + start - lineStartIndex <= maxExtraContextLength && + text.charAt(lineStartIndex) != '\n') { + lineStartIndex -= 1 + } + val startTruncated = start - lineStartIndex > maxExtraContextLength + var currentIndex = lineStartIndex + if (startTruncated) { + currentIndex -= truncatedText.length + } - var lineStopIndex = stop - // Collect the SQL text within the ending line of current Node. 
- // The text is truncated if it is too long. - while(lineStopIndex < text.length && - lineStopIndex - stop <= maxExtraContextLength && - text.charAt(lineStopIndex) != '\n') { - lineStopIndex += 1 - } - val stopTruncated = lineStopIndex - stop > maxExtraContextLength - - val subText = (if (startTruncated) truncatedText else "") + - text.substring(lineStartIndex + 1, lineStopIndex) + - (if (stopTruncated) truncatedText else "") - val lines = subText.split("\n") - lines.foreach { lineText => - builder ++= lineText + "\n" - currentIndex += 1 - (0 until lineText.length).foreach { _ => - if (currentIndex < start) { - builder ++= " " - } else if (currentIndex >= start && currentIndex <= stop) { - builder ++= "^" - } + var lineStopIndex = stop + // Collect the SQL text within the ending line of current Node. + // The text is truncated if it is too long. + while (lineStopIndex < text.length && + lineStopIndex - stop <= maxExtraContextLength && + text.charAt(lineStopIndex) != '\n') { + lineStopIndex += 1 + } + val stopTruncated = lineStopIndex - stop > maxExtraContextLength + + val truncatedSubText = (if (startTruncated) truncatedText else "") + + text.substring(lineStartIndex + 1, lineStopIndex) + + (if (stopTruncated) truncatedText else "") + val lines = truncatedSubText.split("\n") + lines.foreach { lineText => + builder ++= lineText + "\n" currentIndex += 1 + (0 until lineText.length).foreach { _ => + if (currentIndex < start) { + builder ++= " " + } else if (currentIndex >= start && currentIndex <= stop) { + builder ++= "^" + } + currentIndex += 1 + } + builder ++= "\n" } - builder ++= "\n" + builder.result() } - builder.result() - }.getOrElse("") + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 818694bf84e..8c0defea230 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -93,6 +93,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper line = Some(1), startPosition = Some(7), startIndex = Some(7), + stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { val expr = Add(maxValue, maxValue, failOnError = true) @@ -177,6 +178,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper line = Some(1), startPosition = Some(7), startIndex = Some(7), + stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { val expr = Subtract(minValue, maxValue, failOnError = true) @@ -215,6 +217,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper line = Some(1), startPosition = Some(7), startIndex = Some(7), + stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { val expr = Multiply(maxValue, maxValue, failOnError = true) @@ -259,6 +262,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper line = Some(1), startPosition = Some(7), startIndex = Some(7), + stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { val expr = Divide(Literal(1234.5, DoubleType), Literal(0.0, DoubleType), failOnError = true) @@ -312,6 +316,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper line = Some(1), startPosition = Some(7), startIndex = Some(7), + stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { val expr = @@ -362,6 +367,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper line = Some(1), startPosition = Some(7), startIndex = Some(7), + stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) 
withOrigin(o) { checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 4bdb3f715bc..ffbc5d89bdb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -888,4 +888,41 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { assert(origin.context == expected) } + + test("SPARK-39046: Return an empty context string if TreeNode.origin is wrongly set") { + val text = Some("select a + b") + // missing start index + val origin1 = Origin( + startIndex = Some(7), + stopIndex = None, + sqlText = text) + // missing stop index + val origin2 = Origin( + startIndex = None, + stopIndex = Some(11), + sqlText = text) + // missing text + val origin3 = Origin( + startIndex = Some(7), + stopIndex = Some(11), + sqlText = None) + // negative start index + val origin4 = Origin( + startIndex = Some(-1), + stopIndex = Some(11), + sqlText = text) + // stop index >= text.length + val origin5 = Origin( + startIndex = Some(-1), + stopIndex = Some(text.get.length), + sqlText = text) + // start index > stop index + val origin6 = Origin( + startIndex = Some(2), + stopIndex = Some(1), + sqlText = text) + Seq(origin1, origin2, origin3, origin4, origin5, origin6).foreach { origin => + assert(origin.context.isEmpty) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org