Repository: spark Updated Branches: refs/heads/master f47700c9c -> 8ce0d8ffb
[SPARK-20392][SQL] Set barrier to prevent re-entering a tree ## What changes were proposed in this pull request? It has been reported that there is a performance degradation when applying an ML pipeline to a dataset with many columns but few rows. A big part of the performance degradation comes from operations (e.g., `select`) on DataFrame/Dataset that re-create a new DataFrame/Dataset with a new `LogicalPlan`. Normally, this cost is negligible in SQL usage. However, it is not rare to chain dozens of pipeline stages in ML. When the query plan grows incrementally while running those stages, the total cost spent on re-creating the DataFrame grows too. In particular, the `Analyzer` will go through the whole big query plan even when most of it has already been analyzed. By eliminating part of this cost, the time to run the example code locally is reduced from about 1 min to about 30 secs. In particular, the time spent applying the pipeline locally is mostly spent calling `transform` on the 137 `Bucketizer`s. Before the change, each call of `Bucketizer`'s `transform` costs about 0.4 sec, so the total time spent on all `Bucketizer`s' transforms is about 50 secs. After the change, each call only costs about 0.1 sec. <del>We also make `boundEnc` a lazy variable to reduce unnecessary running time.</del> ### Performance improvement The code and datasets provided by Barry Becker to reproduce this issue and run the benchmark can be found on the JIRA. Before this patch: about 1 min After this patch: about 20 secs ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #17770 from viirya/SPARK-20392. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8ce0d8ff Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8ce0d8ff Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8ce0d8ff Branch: refs/heads/master Commit: 8ce0d8ffb68bd9e89c23d3a026308dcc039a1b1d Parents: f47700c Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Fri May 26 13:45:55 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri May 26 13:45:55 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/Analyzer.scala | 75 ++++++++++------ .../catalyst/analysis/DecimalPrecision.scala | 2 +- .../analysis/ResolveTableValuedFunctions.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 22 ++--- .../catalyst/analysis/timeZoneAnalysis.scala | 2 +- .../spark/sql/catalyst/analysis/view.scala | 2 +- .../spark/sql/catalyst/optimizer/subquery.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 35 -------- .../plans/logical/basicLogicalOperators.scala | 9 ++ .../sql/catalyst/analysis/AnalysisSuite.scala | 14 +++ .../sql/catalyst/plans/LogicalPlanSuite.scala | 26 +++--- .../scala/org/apache/spark/sql/Dataset.scala | 92 ++++++++++---------- .../sql/execution/datasources/DataSource.scala | 2 +- .../spark/sql/execution/datasources/rules.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../apache/spark/sql/hive/HiveStrategies.scala | 6 +- 16 files changed, 151 insertions(+), 144 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d130962..85cf8dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -166,14 +166,15 @@ class Analyzer( Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases) + CleanupAliases, + EliminateBarriers) ) /** * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -201,7 +202,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // Lookup WindowSpecDefinitions. This rule works with unresolved children. 
case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -243,7 +244,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -615,7 +616,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -670,7 +671,9 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + private def dedupRight (left: LogicalPlan, oriRight: LogicalPlan): LogicalPlan = { + // Remove analysis barrier if any. + val right = EliminateBarriers(oriRight) val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") @@ -713,7 +716,7 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. 
*/ - right + oriRight case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { @@ -726,7 +729,7 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - newRight + AnalysisBarrier(newRight) } } @@ -787,7 +790,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -961,7 +964,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1017,7 +1020,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1041,11 +1044,13 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // Skip sort with aggregate. 
This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, orgChild) if !s.resolved && orgChild.resolved => + val child = EliminateBarriers(orgChild) try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -1066,7 +1071,8 @@ class Analyzer( case ae: AnalysisException => s } - case f @ Filter(cond, child) if !f.resolved && child.resolved => + case f @ Filter(cond, orgChild) if !f.resolved && orgChild.resolved => + val child = EliminateBarriers(orgChild) try { val newCond = resolveExpressionRecursively(cond, child) val requiredAttrs = newCond.references.filter(_.resolved) @@ -1093,7 +1099,7 @@ class Analyzer( */ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { if (missingAttrs.isEmpty) { - return plan + return AnalysisBarrier(plan) } plan match { case p: Project => @@ -1165,7 +1171,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1504,7 +1510,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. 
case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1519,7 +1525,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1545,7 +1551,9 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => + apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1605,6 +1613,8 @@ class Analyzer( case ae: AnalysisException => filter } + case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => + apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. @@ -1717,7 +1727,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1775,7 +1785,7 @@ class Analyzer( * that wrap the [[Generator]]. 
*/ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2092,7 +2102,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2137,7 +2147,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2202,7 +2212,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2267,7 +2277,7 @@ class Analyzer( * to the given input attributes. 
*/ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2353,7 +2363,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2387,7 +2397,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2441,7 +2451,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2470,6 +2480,13 @@ object CleanupAliases extends Rule[LogicalPlan] { } } +/** Remove the barrier nodes of analysis */ +object EliminateBarriers extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case AnalysisBarrier(child) => child + } +} + /** * Ignore event time watermark in batch query, which is only supported in Structured Streaming. * TODO: add this rule into analyzer rule list. @@ -2519,7 +2536,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. 
*/ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 9c38dd2..ac72bc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { PromotePrecision(Cast(e, dataType)) } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // fix decimal precision for expressions case q => q.transformExpressions( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index dad1340..4067535 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -103,7 +103,7 @@ object 
ResolveTableValuedFunctions extends Rule[LogicalPlan] { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e1dd010..c364517 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -206,7 +206,7 @@ object TypeCoercion { * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // No propagation required for leaf nodes. case q: LogicalPlan if q.children.isEmpty => q @@ -261,7 +261,7 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if p.analyzed => p case s @ SetOperation(left, right) if s.childrenResolved && @@ -335,7 +335,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e @@ -391,7 +391,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -449,7 +449,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -490,7 +490,7 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -580,7 +580,7 @@ object TypeCoercion { * converted to fractional types. */ object Division extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -602,7 +602,7 @@ object TypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. 
*/ object CaseWhenCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => @@ -632,7 +632,7 @@ object TypeCoercion { * Coerces the type of different branches of If statement to a common type. */ object IfCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => @@ -656,7 +656,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -673,7 +673,7 @@ object TypeCoercion { * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index a27aa84..af1f916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.resolveExpressions(transformTimeZoneExprs) + plan.transformAllExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index ea46dd7..3bbe41c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. 
*/ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 2a3e07a..46d1aac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -236,7 +236,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper /** * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case f @ Filter(_, a: Aggregate) => rewriteSubQueries(f, Seq(a, a.child)) // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. 
http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 2ebb2ff..23520eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -46,41 +46,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Returns true if this subtree contains any streaming data sources. */ def isStreaming: Boolean = children.exists(_.isStreaming == true) - /** - * Returns a copy of this node where `rule` has been recursively applied first to all of its - * children and then itself (post-order). When `rule` does not apply to a given node, it is left - * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already - * been marked as analyzed. - * - * @param rule the function use to transform this nodes children - */ - def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - if (!analyzed) { - val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) - if (this fastEquals afterRuleOnChildren) { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[LogicalPlan]) - } - } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) - } - } - } else { - this - } - } - - /** - * Recursively transforms the expressions of a tree, skipping nodes that have already - * been analyzed. 
- */ - def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { - this resolveOperators { - case p => p.transformExpressions(r) - } - } - /** A cache for the estimated statistics, such that it will only be computed once. */ private var statsCache: Option[Statistics] = None http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 6878b6b..b8c2f76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -896,3 +897,11 @@ case class Deduplicate( override def output: Seq[Attribute] = child.output } + +/** A logical plan for setting a barrier of analysis */ +case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { + override def output: Seq[Attribute] = child.output + override def analyzed: Boolean = true + override def isStreaming: Boolean = child.isStreaming + override lazy val canonicalized: LogicalPlan = child.canonicalized +} 
http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 0896cae..3b42897 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -441,6 +441,20 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } + test("analysis barrier") { + // [[AnalysisBarrier]] will be removed after analysis + checkAnalysis( + Project(Seq(UnresolvedAttribute("tbl.a")), + AnalysisBarrier(SubqueryAlias("tbl", testRelation))), + Project(testRelation.output, SubqueryAlias("tbl", testRelation))) + + // Verify we won't go through a plan wrapped in a barrier. + // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. 
+ val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), + SubqueryAlias("tbl", testRelation))) + assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) + } + test("SPARK-20311 range(N) as alias") { def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = { SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames)) http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index cc86f1f..215db84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType /** - * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly - * skips sub-trees that have already been marked as analyzed. + * This suite is used to test [[LogicalPlan]]'s `transformUp` plus analysis barrier and make sure + * it can correctly skip sub-trees that have already been marked as analyzed. 
*/ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 @@ -36,37 +36,35 @@ class LogicalPlanSuite extends SparkFunSuite { private val testRelation = LocalRelation() - test("resolveOperator runs on operators") { + test("transformUp runs on operators") { invocationCount = 0 val plan = Project(Nil, testRelation) - plan resolveOperators function + plan transformUp function assert(invocationCount === 1) } - test("resolveOperator runs on operators recursively") { + test("transformUp runs on operators recursively") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) - plan resolveOperators function + plan transformUp function assert(invocationCount === 2) } - test("resolveOperator skips all ready resolved plans") { + test("transformUp skips all ready resolved plans wrapped in analysis barrier") { invocationCount = 0 - val plan = Project(Nil, Project(Nil, testRelation)) - plan.foreach(_.setAnalyzed()) - plan resolveOperators function + val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) + plan transformUp function assert(invocationCount === 0) } - test("resolveOperator skips partially resolved plans") { + test("transformUp skips partially resolved plans wrapped in analysis barrier") { invocationCount = 0 - val plan1 = Project(Nil, testRelation) + val plan1 = AnalysisBarrier(Project(Nil, testRelation)) val plan2 = Project(Nil, plan1) - plan1.foreach(_.setAnalyzed()) - plan2 resolveOperators function + plan2 transformUp function assert(invocationCount === 1) } http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index cbab029..f9bd8f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -187,6 +187,9 @@ class Dataset[T] private[sql]( } } + // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. + @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) + /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -413,7 +416,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -616,7 +619,7 @@ class Dataset[T] private[sql]( require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier)) } /** @@ -790,7 +793,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) } /** @@ -868,7 +871,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. 
val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -929,7 +932,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) + Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -938,8 +941,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed - val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed + val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed + val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -971,7 +974,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) } /** @@ -1003,8 +1006,8 @@ class Dataset[T] private[sql]( // etc. 
val joined = sparkSession.sessionState.executePlan( Join( - this.logicalPlan, - other.logicalPlan, + this.planWithBarrier, + other.planWithBarrier, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1174,7 +1177,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { - UnresolvedHint(name, parameters, logicalPlan) + UnresolvedHint(name, parameters, planWithBarrier) } /** @@ -1200,7 +1203,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + SubqueryAlias(alias, planWithBarrier) } /** @@ -1238,7 +1241,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), logicalPlan) + Project(cols.map(_.named), planWithBarrier) } /** @@ -1293,8 +1296,8 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, - logicalPlan) + val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, + planWithBarrier) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1312,8 +1315,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, logicalPlan.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) + columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1389,7 +1392,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ 
def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, logicalPlan) + Filter(condition.expr, planWithBarrier) } /** @@ -1566,7 +1569,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = logicalPlan + val inputPlan = planWithBarrier val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1712,7 +1715,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), logicalPlan) + Limit(Literal(n), planWithBarrier) } /** @@ -1741,7 +1744,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)) + CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) } /** @@ -1755,7 +1758,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(logicalPlan, other.logicalPlan) + Intersect(planWithBarrier, other.planWithBarrier) } /** @@ -1769,7 +1772,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(logicalPlan, other.logicalPlan) + Except(planWithBarrier, other.planWithBarrier) } /** @@ -1790,7 +1793,7 @@ class Dataset[T] private[sql]( s"Fraction must be nonnegative, but got ${fraction}") withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + Sample(0.0, fraction, withReplacement, seed, planWithBarrier)() } } @@ -1832,15 +1835,15 @@ class Dataset[T] private[sql]( // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. 
Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. - val sortOrder = logicalPlan.output + val sortOrder = planWithBarrier.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, logicalPlan) + Sort(sortOrder, global = false, planWithBarrier) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - logicalPlan + planWithBarrier } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -1924,7 +1927,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -1965,7 +1968,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -2080,7 +2083,7 @@ class Dataset[T] private[sql]( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.logicalPlan.output + val attrs = this.planWithBarrier.output val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) @@ -2128,7 +2131,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, logicalPlan, isStreaming) + Deduplicate(groupCols, planWithBarrier, isStreaming) } /** @@ -2277,7 +2280,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2291,7 +2294,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - 
withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2305,7 +2308,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + MapElements[T, U](func, planWithBarrier) } /** @@ -2320,7 +2323,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, logicalPlan)) + withTypedPlan(MapElements[T, U](func, planWithBarrier)) } /** @@ -2336,7 +2339,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, logicalPlan), + MapPartitions[T, U](func, planWithBarrier), implicitly[Encoder[U]]) } @@ -2367,7 +2370,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) } /** @@ -2522,7 +2525,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) + Repartition(numPartitions, shuffle = true, planWithBarrier) } /** @@ -2536,7 +2539,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) } /** @@ -2552,7 +2555,8 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def repartition(partitionExprs: Column*): 
Dataset[T] = withTypedPlan { RepartitionByExpression( - partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) + partitionExprs.map(_.expr), planWithBarrier, + sparkSession.sessionState.conf.numShufflePartitions) } /** @@ -2573,7 +2577,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) + Repartition(numPartitions, shuffle = false, planWithBarrier) } /** @@ -2662,7 +2666,7 @@ class Dataset[T] private[sql]( */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserialized = CatalystSerde.deserialize[T](planWithBarrier) sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } @@ -2761,7 +2765,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = logicalPlan, + child = planWithBarrier, allowExisting = false, replace = replace, viewType = viewType) @@ -2932,7 +2936,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, logicalPlan) + Sort(sortOrder, global = global, planWithBarrier) } } http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 14c4060..9fce29b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -424,7 +424,7 @@ case class DataSource( }.head } // For partitioned relation 
r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This + // ordering of data.logicalPlan (partition columns are all moved after data column). This // will be adjusted within InsertIntoHadoopFsRelation. val plan = InsertIntoHadoopFsRelationCommand( http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 3f4a785..5f65898 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -38,7 +38,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 4d155d5..d02c8ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -241,7 +241,7 @@ class PlannerSuite extends SharedSQLContext { test("collapse adjacent repartitions") { val 
doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length - assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) + assert(countRepartitions(doubleRepartitioned.queryExecution.analyzed) === 3) assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2) doubleRepartitioned.queryExecution.optimizedPlan match { case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) => http://git-wip-us.apache.org/repos/asf/spark/blob/8ce0d8ff/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9c60d22..662fc80 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -88,7 +88,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. 
val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -115,7 +115,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -146,7 +146,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. */ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case InsertIntoTable(r: CatalogRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org