This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d30796021174 [SPARK-46655][SQL] Skip query context catching in `DataFrame` methods
d30796021174 is described below

commit d30796021174d8dc5595054d00c83ccdf0eb5b38
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Thu Jan 11 17:08:29 2024 +0300

    [SPARK-46655][SQL] Skip query context catching in `DataFrame` methods
    
    ### What changes were proposed in this pull request?
    In the PR, I propose not to catch the DataFrame query context in `DataFrame` methods but to leave that close to the `Column` functions.
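
    For illustration, a minimal before/after sketch of the pattern applied throughout `Dataset.scala` (excerpted from the `limit` hunk in the diff below):

    ```scala
    // Before: each DataFrame method wrapped its plan construction in `withOrigin`,
    // capturing the call site of the DataFrame method itself.
    def limit(n: Int): Dataset[T] = withOrigin {
      withTypedPlan {
        Limit(Literal(n), logicalPlan)
      }
    }

    // After: the method only builds the plan; capturing the query context is left
    // to the `Column` functions that produce the expressions it consumes.
    def limit(n: Int): Dataset[T] = withTypedPlan {
      Limit(Literal(n), logicalPlan)
    }
    ```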
    
    ### Why are the changes needed?
    To improve the user experience with the Spark DataFrame/Dataset APIs and to provide a more precise error context.
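
    As a hypothetical illustration (the setup below is assumed, not taken from the patch), the effect shows up in the updated test suites: errors raised while analyzing a `Dataset` method no longer carry a query context pointing at the method call itself.

    ```scala
    import spark.implicits._  // assumes an active SparkSession named `spark`

    // Assumed setup: a DataFrame with a map-typed column `m` (MAP<STRING, BIGINT>).
    val df = Seq((1, Map("k" -> 1L))).toDF("id", "m")

    // Set operations are unsupported on map types, so this fails with
    // UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE. Previously the error also
    // carried a DataFrame context with fragment = "distinct" pointing at this
    // call site; with this change, Dataset methods no longer record such a
    // context, which is why DataFrameSetOperationsSuite drops its ExpectedContext.
    df.distinct()
    ```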
    
    ### Does this PR introduce _any_ user-facing change?
    No, since the feature hasn't been released yet.
    
    ### How was this patch tested?
    By running the modified test suites.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44501 from MaxGekk/exclude-funcs-withOrigin.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 359 +++++++++------------
 .../main/scala/org/apache/spark/sql/package.scala  |   7 +-
 .../spark/sql/DataFrameSetOperationsSuite.scala    |   6 +-
 .../apache/spark/sql/streaming/StreamSuite.scala   |   4 +-
 4 files changed, 160 insertions(+), 216 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ff1bd8c73e6f..d792cdbcf865 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -512,11 +512,9 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 3.4.0
    */
-  def to(schema: StructType): DataFrame = withOrigin {
-    withPlan {
-      val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
-      Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf)
-    }
+  def to(schema: StructType): DataFrame = withPlan {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+    Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf)
   }
 
   /**
@@ -776,13 +774,12 @@ class Dataset[T] private[sql](
    */
   // We only accept an existing column name, not a derived column here as a watermark that is
   // defined on a derived column cannot referenced elsewhere in the plan.
-  def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withOrigin {
-    withTypedPlan {
-      val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold)
-      require(!IntervalUtils.isNegative(parsedDelay),
-        s"delay threshold ($delayThreshold) should not be negative.")
-      EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)
-    }
+  def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan {
+    val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold)
+    require(!IntervalUtils.isNegative(parsedDelay),
+      s"delay threshold ($delayThreshold) should not be negative.")
+    EliminateEventTimeWatermark(
+      EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan))
   }
 
   /**
@@ -954,10 +951,8 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.0.0
    */
-  def join(right: Dataset[_]): DataFrame = withOrigin {
-    withPlan {
-      Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE)
-    }
+  def join(right: Dataset[_]): DataFrame = withPlan {
+    Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE)
   }
 
   /**
@@ -1090,23 +1085,22 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.0.0
    */
-  def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame =
-    withOrigin {
-      // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
-      // by creating a new instance for one of the branch.
-      val joined = sparkSession.sessionState.executePlan(
-        Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE))
-        .analyzed.asInstanceOf[Join]
-
-      withPlan {
-        Join(
-          joined.left,
-          joined.right,
-          UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq),
-          None,
-          JoinHint.NONE)
-      }
+  def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = {
+    // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
+    // by creating a new instance for one of the branch.
+    val joined = sparkSession.sessionState.executePlan(
+      Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE))
+      .analyzed.asInstanceOf[Join]
+
+    withPlan {
+      Join(
+        joined.left,
+        joined.right,
+        UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq),
+        None,
+        JoinHint.NONE)
     }
+  }
 
   /**
    * Inner join with another `DataFrame`, using the given join expression.
@@ -1187,7 +1181,7 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.0.0
    */
-  def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = withOrigin {
+  def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = {
     withPlan {
       resolveSelfJoinCondition(right, Some(joinExprs), joinType)
     }
@@ -1203,10 +1197,8 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.1.0
    */
-  def crossJoin(right: Dataset[_]): DataFrame = withOrigin {
-    withPlan {
-      Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE)
-    }
+  def crossJoin(right: Dataset[_]): DataFrame = withPlan {
+    Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE)
   }
 
   /**
@@ -1230,28 +1222,27 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] =
-    withOrigin {
-      // Creates a Join node and resolve it first, to get join condition resolved, self-join
-      // resolved, etc.
-      val joined = sparkSession.sessionState.executePlan(
-        Join(
-          this.logicalPlan,
-          other.logicalPlan,
-          JoinType(joinType),
-          Some(condition.expr),
-          JoinHint.NONE)).analyzed.asInstanceOf[Join]
-
-      implicit val tuple2Encoder: Encoder[(T, U)] =
-        ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
-
-      withTypedPlan(JoinWith.typedJoinWith(
-        joined,
-        sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity,
-        sparkSession.sessionState.analyzer.resolver,
-        this.exprEnc.isSerializedAsStructForTopLevel,
-        other.exprEnc.isSerializedAsStructForTopLevel))
-    }
+  def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
+    // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved,
+    // etc.
+    val joined = sparkSession.sessionState.executePlan(
+      Join(
+        this.logicalPlan,
+        other.logicalPlan,
+        JoinType(joinType),
+        Some(condition.expr),
+        JoinHint.NONE)).analyzed.asInstanceOf[Join]
+
+    implicit val tuple2Encoder: Encoder[(T, U)] =
+      ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
+
+    withTypedPlan(JoinWith.typedJoinWith(
+      joined,
+      sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity,
+      sparkSession.sessionState.analyzer.resolver,
+      this.exprEnc.isSerializedAsStructForTopLevel,
+      other.exprEnc.isSerializedAsStructForTopLevel))
+  }
 
   /**
    * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair
@@ -1429,16 +1420,14 @@ class Dataset[T] private[sql](
    * @since 2.2.0
    */
   @scala.annotation.varargs
-  def hint(name: String, parameters: Any*): Dataset[T] = withOrigin {
-    withTypedPlan {
-      val exprs = parameters.map {
-        case c: Column => c.expr
-        case s: Symbol => Column(s.name).expr
-        case e: Expression => e
-        case literal => Literal(literal)
-      }
-      UnresolvedHint(name, exprs, logicalPlan)
-    }
+  def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan {
+    val exprs = parameters.map {
+      case c: Column => c.expr
+      case s: Symbol => Column(s.name).expr
+      case e: Expression => e
+      case literal => Literal(literal)
+    }.toSeq
+    UnresolvedHint(name, exprs, logicalPlan)
   }
 
   /**
@@ -1514,10 +1503,8 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def as(alias: String): Dataset[T] = withOrigin {
-    withTypedPlan {
-      SubqueryAlias(alias, logicalPlan)
-    }
+  def as(alias: String): Dataset[T] = withTypedPlan {
+    SubqueryAlias(alias, logicalPlan)
   }
 
   /**
@@ -1554,28 +1541,25 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @scala.annotation.varargs
-  def select(cols: Column*): DataFrame = withOrigin {
-    withPlan {
-      val untypedCols = cols.map {
-        case typedCol: TypedColumn[_, _] =>
-          // Checks if a `TypedColumn` has been inserted with
-          // specific input type and schema by `withInputType`.
-          val needInputType = typedCol.expr.exists {
-            case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true
-            case _ => false
-          }
+  def select(cols: Column*): DataFrame = withPlan {
+    val untypedCols = cols.map {
+      case typedCol: TypedColumn[_, _] =>
+        // Checks if a `TypedColumn` has been inserted with
+        // specific input type and schema by `withInputType`.
+        val needInputType = typedCol.expr.exists {
+          case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true
+          case _ => false
+        }
 
-          if (!needInputType) {
-            typedCol
-          } else {
-            throw
-              QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString)
-          }
+        if (!needInputType) {
+          typedCol
+        } else {
+          throw QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString)
+        }
 
-        case other => other
-      }
-      Project(untypedCols.map(_.named), logicalPlan)
+      case other => other
     }
+    Project(untypedCols.map(_.named), logicalPlan)
   }
 
   /**
@@ -1592,9 +1576,7 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @scala.annotation.varargs
-  def select(col: String, cols: String*): DataFrame = withOrigin {
-    select((col +: cols).map(Column(_)) : _*)
-  }
+  def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*)
 
   /**
    * Selects a set of SQL expressions. This is a variant of `select` that accepts
@@ -1610,12 +1592,10 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @scala.annotation.varargs
-  def selectExpr(exprs: String*): DataFrame = withOrigin {
-    sparkSession.withActive {
-      select(exprs.map { expr =>
-        Column(sparkSession.sessionState.sqlParser.parseExpression(expr))
-      }: _*)
-    }
+  def selectExpr(exprs: String*): DataFrame = sparkSession.withActive {
+    select(exprs.map { expr =>
+      Column(sparkSession.sessionState.sqlParser.parseExpression(expr))
+    }: _*)
   }
 
   /**
@@ -1629,7 +1609,7 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = withOrigin {
+  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
     implicit val encoder: ExpressionEncoder[U1] = c1.encoder
     val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)
 
@@ -1713,10 +1693,8 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def filter(condition: Column): Dataset[T] = withOrigin {
-    withTypedPlan {
-      Filter(condition.expr, logicalPlan)
-    }
+  def filter(condition: Column): Dataset[T] = withTypedPlan {
+    Filter(condition.expr, logicalPlan)
   }
 
   /**
@@ -2102,17 +2080,15 @@ class Dataset[T] private[sql](
       ids: Array[Column],
       values: Array[Column],
       variableColumnName: String,
-      valueColumnName: String): DataFrame = withOrigin {
-    withPlan {
-      Unpivot(
-        Some(ids.map(_.named).toImmutableArraySeq),
-        Some(values.map(v => Seq(v.named)).toImmutableArraySeq),
-        None,
-        variableColumnName,
-        Seq(valueColumnName),
-        logicalPlan
-      )
-    }
+      valueColumnName: String): DataFrame = withPlan {
+    Unpivot(
+      Some(ids.map(_.named).toImmutableArraySeq),
+      Some(values.map(v => Seq(v.named)).toImmutableArraySeq),
+      None,
+      variableColumnName,
+      Seq(valueColumnName),
+      logicalPlan
+    )
   }
 
   /**
@@ -2135,17 +2111,15 @@ class Dataset[T] private[sql](
   def unpivot(
       ids: Array[Column],
       variableColumnName: String,
-      valueColumnName: String): DataFrame = withOrigin {
-    withPlan {
-      Unpivot(
-        Some(ids.map(_.named).toImmutableArraySeq),
-        None,
-        None,
-        variableColumnName,
-        Seq(valueColumnName),
-        logicalPlan
-      )
-    }
+      valueColumnName: String): DataFrame = withPlan {
+    Unpivot(
+      Some(ids.map(_.named).toImmutableArraySeq),
+      None,
+      None,
+      variableColumnName,
+      Seq(valueColumnName),
+      logicalPlan
+    )
   }
 
   /**
@@ -2262,10 +2236,8 @@ class Dataset[T] private[sql](
   * @since 3.0.0
   */
   @varargs
-  def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withOrigin {
-    withTypedPlan {
-      CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
-    }
+  def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan {
+    CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
   }
 
   /**
@@ -2302,10 +2274,8 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 2.0.0
    */
-  def limit(n: Int): Dataset[T] = withOrigin {
-    withTypedPlan {
-      Limit(Literal(n), logicalPlan)
-    }
+  def limit(n: Int): Dataset[T] = withTypedPlan {
+    Limit(Literal(n), logicalPlan)
   }
 
   /**
@@ -2314,10 +2284,8 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 3.4.0
    */
-  def offset(n: Int): Dataset[T] = withOrigin {
-    withTypedPlan {
-      Offset(Literal(n), logicalPlan)
-    }
+  def offset(n: Int): Dataset[T] = withTypedPlan {
+    Offset(Literal(n), logicalPlan)
   }
 
   // This breaks caching, but it's usually ok because it addresses a very specific use case:
@@ -2727,20 +2695,20 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0")
-  def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame =
-    withOrigin {
-      val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
-      val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema)
-
-      val rowFunction =
-        f.andThen(_.map(convert(_).asInstanceOf[InternalRow]))
-      val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr))
-
-      withPlan {
-        Generate(generator, unrequiredChildIndex = Nil, outer = false,
-          qualifier = None, generatorOutput = Nil, logicalPlan)
-      }
+  def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = {
+    val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+
+    val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema)
+
+    val rowFunction =
+      f.andThen(_.map(convert(_).asInstanceOf[InternalRow]))
+    val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr))
+
+    withPlan {
+      Generate(generator, unrequiredChildIndex = Nil, outer = false,
+        qualifier = None, generatorOutput = Nil, logicalPlan)
     }
+  }
 
   /**
    * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero
@@ -2765,7 +2733,7 @@ class Dataset[T] private[sql](
    */
   @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0")
   def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => IterableOnce[B])
-    : DataFrame = withOrigin {
+    : DataFrame = {
     val dataType = ScalaReflection.schemaFor[B].dataType
     val attributes = AttributeReference(outputColumn, dataType)() :: Nil
     // TODO handle the metadata?
@@ -2922,14 +2890,14 @@ class Dataset[T] private[sql](
    * @since 3.4.0
    */
   @throws[AnalysisException]
-  def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin {
+  def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = {
     val (colNames, newColNames) = colsMap.toSeq.unzip
     withColumnsRenamed(colNames, newColNames)
   }
 
   private[spark] def withColumnsRenamed(
     colNames: Seq[String],
-    newColNames: Seq[String]): DataFrame = withOrigin {
+    newColNames: Seq[String]): DataFrame = {
     require(colNames.size == newColNames.size,
       s"The size of existing column names: ${colNames.size} isn't equal to " +
         s"the size of new column names: ${newColNames.size}")
@@ -3104,10 +3072,8 @@ class Dataset[T] private[sql](
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def drop(col: Column, cols: Column*): DataFrame = withOrigin {
-    withPlan {
-      DataFrameDropColumns((col +: cols).map(_.expr), logicalPlan)
-    }
+  def drop(col: Column, cols: Column*): DataFrame = withPlan {
+    DataFrameDropColumns((col +: cols).map(_.expr), logicalPlan)
   }
 
   /**
@@ -3138,11 +3104,9 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 2.0.0
    */
-  def dropDuplicates(colNames: Seq[String]): Dataset[T] = withOrigin {
-    withTypedPlan {
-      val groupCols = groupColsFromDropDuplicates(colNames)
-      Deduplicate(groupCols, logicalPlan)
-    }
+  def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
+    val groupCols = groupColsFromDropDuplicates(colNames)
+    Deduplicate(groupCols, logicalPlan)
   }
 
   /**
@@ -3219,12 +3183,10 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 3.5.0
    */
-  def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withOrigin {
-    withTypedPlan {
-      val groupCols = groupColsFromDropDuplicates(colNames)
-      // UnsupportedOperationChecker will fail the query if this is called with batch Dataset.
-      DeduplicateWithinWatermark(groupCols, logicalPlan)
-    }
+  def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withTypedPlan {
+    val groupCols = groupColsFromDropDuplicates(colNames)
+    // UnsupportedOperationChecker will fail the query if this is called with batch Dataset.
+    DeduplicateWithinWatermark(groupCols, logicalPlan)
   }
 
   /**
@@ -3448,7 +3410,7 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def filter(func: T => Boolean): Dataset[T] = withOrigin {
+  def filter(func: T => Boolean): Dataset[T] = {
     withTypedPlan(TypedFilter(func, logicalPlan))
   }
 
@@ -3459,7 +3421,7 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def filter(func: FilterFunction[T]): Dataset[T] = withOrigin {
+  def filter(func: FilterFunction[T]): Dataset[T] = {
     withTypedPlan(TypedFilter(func, logicalPlan))
   }
 
@@ -3470,7 +3432,7 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def map[U : Encoder](func: T => U): Dataset[U] = withOrigin {
+  def map[U : Encoder](func: T => U): Dataset[U] = {
     withTypedPlan {
       MapElements[T, U](func, logicalPlan)
     }
@@ -3483,7 +3445,7 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = withOrigin {
+  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
     implicit val uEnc: Encoder[U] = encoder
     withTypedPlan(MapElements[T, U](func, logicalPlan))
   }
@@ -3646,9 +3608,8 @@ class Dataset[T] private[sql](
    * @group action
    * @since 3.0.0
    */
-  def tail(n: Int): Array[T] = withOrigin {
-    withAction("tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan)
-  }
+  def tail(n: Int): Array[T] = withAction(
+    "tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan)
 
   /**
    * Returns the first `n` rows in the Dataset as a list.
@@ -3712,10 +3673,8 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def count(): Long = withOrigin {
-    withAction("count", groupBy().count().queryExecution) { plan =>
-      plan.executeCollect().head.getLong(0)
-    }
+  def count(): Long = withAction("count", groupBy().count().queryExecution) { plan =>
+    plan.executeCollect().head.getLong(0)
   }
 
   /**
@@ -3724,15 +3683,13 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def repartition(numPartitions: Int): Dataset[T] = withOrigin {
-    withTypedPlan {
-      Repartition(numPartitions, shuffle = true, logicalPlan)
-    }
+  def repartition(numPartitions: Int): Dataset[T] = withTypedPlan {
+    Repartition(numPartitions, shuffle = true, logicalPlan)
   }
 
   private def repartitionByExpression(
       numPartitions: Option[Int],
-      partitionExprs: Seq[Column]): Dataset[T] = withOrigin {
+      partitionExprs: Seq[Column]): Dataset[T] = {
     // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
     // However, we don't want to complicate the semantics of this API method.
     // Instead, let's give users a friendly error message, pointing them to the new method.
@@ -3777,7 +3734,7 @@ class Dataset[T] private[sql](
 
   private def repartitionByRange(
       numPartitions: Option[Int],
-      partitionExprs: Seq[Column]): Dataset[T] = withOrigin {
+      partitionExprs: Seq[Column]): Dataset[T] = {
     require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
     val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
       case expr: SortOrder => expr
@@ -3849,10 +3806,8 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def coalesce(numPartitions: Int): Dataset[T] = withOrigin {
-    withTypedPlan {
-      Repartition(numPartitions, shuffle = false, logicalPlan)
-    }
+  def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan {
+    Repartition(numPartitions, shuffle = false, logicalPlan)
   }
 
   /**
@@ -3996,10 +3951,8 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @throws[AnalysisException]
-  def createTempView(viewName: String): Unit = withOrigin {
-    withPlan {
-      createTempViewCommand(viewName, replace = false, global = false)
-    }
+  def createTempView(viewName: String): Unit = withPlan {
+    createTempViewCommand(viewName, replace = false, global = false)
   }
 
 
@@ -4011,10 +3964,8 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 2.0.0
    */
-  def createOrReplaceTempView(viewName: String): Unit = withOrigin {
-    withPlan {
-      createTempViewCommand(viewName, replace = true, global = false)
-    }
+  def createOrReplaceTempView(viewName: String): Unit = withPlan {
+    createTempViewCommand(viewName, replace = true, global = false)
   }
 
   /**
@@ -4032,10 +3983,8 @@ class Dataset[T] private[sql](
    * @since 2.1.0
    */
   @throws[AnalysisException]
-  def createGlobalTempView(viewName: String): Unit = withOrigin {
-    withPlan {
-      createTempViewCommand(viewName, replace = false, global = true)
-    }
+  def createGlobalTempView(viewName: String): Unit = withPlan {
+    createTempViewCommand(viewName, replace = false, global = true)
   }
 
   /**
@@ -4475,7 +4424,7 @@ class Dataset[T] private[sql](
     plan.executeCollect().map(fromRow)
   }
 
-  private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = withOrigin {
+  private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
     val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
       col.expr match {
         case expr: SortOrder =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 877d9906a1cf..9831ce62801a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -111,9 +111,10 @@ package object sql {
     }
   }
 
-  private val sparkCodePattern = Pattern.compile("org\\.apache\\.spark\\.sql\\." +
-      "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions)" +
-      "(?:|\\..*|\\$.*)")
+  private val sparkCodePattern = Pattern.compile("(org\\.apache\\.spark\\.sql\\." +
+      "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder)" +
+      "(?:|\\..*|\\$.*))" +
+      "|(scala\\.collection\\..*)")
 
   private def sparkCode(ste: StackTraceElement): Boolean = {
     sparkCodePattern.matcher(ste.getClassName).matches()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index e2dd029d4b10..bbb1561bb695 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
@@ -374,11 +374,7 @@ class DataFrameSetOperationsSuite extends QueryTest
       errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE",
       parameters = Map(
         "colName" -> "`m`",
-        "dataType" -> "\"MAP<STRING, BIGINT>\""),
-      context = ExpectedContext(
-        fragment = "distinct",
-        callSitePattern = getCurrentClassCallSitePattern)
-    )
+        "dataType" -> "\"MAP<STRING, BIGINT>\""))
     withTempView("v") {
       df.createOrReplaceTempView("v")
       checkError(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index cd0bbfd47b2b..b0e54737d104 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -713,9 +713,7 @@ class StreamSuite extends StreamTest {
         "columnName" -> "`rn_col`",
         "windowSpec" ->
           ("(PARTITION BY COL1 ORDER BY COL2 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING " +
-          "AND CURRENT ROW)")),
-      queryContext = Array(
-        ExpectedContext(fragment = "withColumn", callSitePattern = getCurrentClassCallSitePattern)))
+          "AND CURRENT ROW)")))
   }
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
