Repository: spark Updated Branches: refs/heads/branch-1.3 33ccad20e -> 2d7786ed1
[SPARK-5873][SQL] Allow viewing of partially analyzed plans in queryExecution Author: Michael Armbrust <mich...@databricks.com> Closes #4684 from marmbrus/explainAnalysis and squashes the following commits: afbaa19 [Michael Armbrust] fix python d93278c [Michael Armbrust] fix hive e5fa0a4 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis 52119f2 [Michael Armbrust] more tests 82a5431 [Michael Armbrust] fix tests 25753d2 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis aee1e6a [Michael Armbrust] fix hive b23a844 [Michael Armbrust] newline de8dc51 [Michael Armbrust] more comments acf620a [Michael Armbrust] [SPARK-5873][SQL] Show partially analyzed plans in query execution (cherry picked from commit 1ed57086d402c38d95cda6c3d9d7aea806609bf9) Signed-off-by: Michael Armbrust <mich...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2d7786ed Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2d7786ed Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2d7786ed Branch: refs/heads/branch-1.3 Commit: 2d7786ed1e008b33e8b171a8f2ea30e19426ba1f Parents: 33ccad2 Author: Michael Armbrust <mich...@databricks.com> Authored: Mon Feb 23 17:34:54 2015 -0800 Committer: Michael Armbrust <mich...@databricks.com> Committed: Mon Feb 23 17:35:04 2015 -0800 ---------------------------------------------------------------------- python/pyspark/sql/context.py | 30 +++--- .../apache/spark/sql/catalyst/SqlParser.scala | 2 + .../spark/sql/catalyst/analysis/Analyzer.scala | 83 --------------- .../sql/catalyst/analysis/CheckAnalysis.scala | 105 +++++++++++++++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 35 ++++--- .../scala/org/apache/spark/sql/DataFrame.scala | 2 +- .../scala/org/apache/spark/sql/SQLConf.scala | 5 +- .../scala/org/apache/spark/sql/SQLContext.scala | 14 ++- .../org/apache/spark/sql/sources/rules.scala | 10 +- .../spark/sql/sources/DataSourceTest.scala | 1 - .../apache/spark/sql/sources/InsertSuite.scala | 2 +- .../org/apache/spark/sql/hive/HiveContext.scala | 1 - 12 files changed, 164 insertions(+), 126 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/python/pyspark/sql/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 313f15e..125933c 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -267,20 +267,20 @@ class SQLContext(object): ... StructField("byte2", ByteType(), False), ... StructField("short1", ShortType(), False), ... StructField("short2", ShortType(), False), - ... StructField("int", IntegerType(), False), - ... StructField("float", FloatType(), False), - ... StructField("date", DateType(), False), - ... StructField("time", TimestampType(), False), - ... StructField("map", + ... StructField("int1", IntegerType(), False), + ... StructField("float1", FloatType(), False), + ... StructField("date1", DateType(), False), + ... StructField("time1", TimestampType(), False), + ... StructField("map1", ... MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", + ... StructField("struct1", ... StructType([StructField("b", ShortType(), False)]), False), - ... StructField("list", ArrayType(ByteType(), False), False), - ... StructField("null", DoubleType(), True)]) + ... StructField("list1", ArrayType(ByteType(), False), False), + ... StructField("null1", DoubleType(), True)]) >>> df = sqlCtx.applySchema(rdd, schema) >>> results = df.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, - ... x.time, x.map["a"], x.struct.b, x.list, x.null)) + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, + ... x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) @@ -288,20 +288,20 @@ class SQLContext(object): >>> df.registerTempTable("table2") >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.5 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] + ... "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + ... "float1 + 1.5 as float1 FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int1=2147483646, float1=2.5)] >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), ... {"a": 1}, (2,), [1, 2, 3])]) - >>> abstract = "byte short float time map{} struct(b) list[]" + >>> abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" >>> schema = _parse_schema_abstract(abstract) >>> typedSchema = _infer_schema_type(rdd.first(), schema) >>> df = sqlCtx.applySchema(rdd, typedSchema) >>> df.collect() - [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] + [Row(byte1=127, short1=-32768, float1=1.0, time1=..., list1=[1, 2, 3])] """ if isinstance(rdd, DataFrame): http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 124f083..b16aff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -78,6 +78,7 @@ class SqlParser extends AbstractSparkSQLParser { protected val IF = Keyword("IF") protected val IN = Keyword("IN") protected val INNER = Keyword("INNER") + protected val INT = Keyword("INT") protected val INSERT = Keyword("INSERT") protected val INTERSECT = Keyword("INTERSECT") protected val INTO = Keyword("INTO") @@ -394,6 +395,7 @@ class SqlParser extends AbstractSparkSQLParser { | fixedDecimalType | DECIMAL ^^^ DecimalType.Unlimited | DATE ^^^ DateType + | INT ^^^ IntegerType ) protected lazy val fixedDecimalType: Parser[DataType] = http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fc37b8c..e4e5425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -52,12 +52,6 @@ class Analyzer(catalog: Catalog, */ val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil - /** - * Override to provide additional rules for the "Check Analysis" batch. - * These rules will be evaluated after our built-in check rules. - */ - val extendedCheckRules: Seq[Rule[LogicalPlan]] = Nil - lazy val batches: Seq[Batch] = Seq( Batch("Resolution", fixedPoint, ResolveRelations :: @@ -71,88 +65,11 @@ class Analyzer(catalog: Catalog, TrimGroupingAliases :: typeCoercionRules ++ extendedResolutionRules : _*), - Batch("Check Analysis", Once, - CheckResolution +: - extendedCheckRules: _*), Batch("Remove SubQueries", fixedPoint, EliminateSubQueries) ) /** - * Makes sure all attributes and logical plans have been resolved. - */ - object CheckResolution extends Rule[LogicalPlan] { - def failAnalysis(msg: String) = { throw new AnalysisException(msg) } - - def apply(plan: LogicalPlan): LogicalPlan = { - // We transform up and order the rules so as to catch the first possible failure instead - // of the result of cascading resolution failures. - plan.foreachUp { - case operator: LogicalPlan => - operator transformExpressionsUp { - case a: Attribute if !a.resolved => - val from = operator.inputSet.map(_.name).mkString(", ") - a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") - - case c: Cast if !c.resolved => - failAnalysis( - s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - - case b: BinaryExpression if !b.resolved => - failAnalysis( - s"invalid expression ${b.prettyString} " + - s"between ${b.left.simpleString} and ${b.right.simpleString}") - } - - operator match { - case f: Filter if f.condition.dataType != BooleanType => - failAnalysis( - s"filter expression '${f.condition.prettyString}' " + - s"of type ${f.condition.dataType.simpleString} is not a boolean.") - - case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => - def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK - case e: Attribute if !groupingExprs.contains(e) => - failAnalysis( - s"expression '${e.prettyString}' is neither present in the group by, " + - s"nor is it an aggregate function. " + - "Add to group by or wrap in first() if you don't care which value you get.") - case e if groupingExprs.contains(e) => // OK - case e if e.references.isEmpty => // OK - case e => e.children.foreach(checkValidAggregateExpression) - } - - val cleaned = aggregateExprs.map(_.transform { - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - case Alias(g, _) => g - }) - - cleaned.foreach(checkValidAggregateExpression) - - case o if o.children.nonEmpty && - !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) => - val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",") - val input = o.inputSet.map(_.prettyString).mkString(",") - - failAnalysis(s"resolved attributes $missingAttributes missing from $input") - - // Catch all - case o if !o.resolved => - failAnalysis( - s"unresolved operator ${operator.simpleString}") - - case _ => // Analysis successful! - } - } - - plan - } - } - - /** * Removes no-op Alias expressions from the plan. */ object TrimGroupingAliases extends Rule[LogicalPlan] { http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala new file mode 100644 index 0000000..4e8fc89 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ + +/** + * Throws user facing errors when passed invalid queries that fail to analyze. + */ +class CheckAnalysis { + + /** + * Override to provide additional checks for correct analysis. + * These rules will be evaluated after our built-in check rules. + */ + val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil + + def failAnalysis(msg: String) = { + throw new AnalysisException(msg) + } + + def apply(plan: LogicalPlan): Unit = { + // We transform up and order the rules so as to catch the first possible failure instead + // of the result of cascading resolution failures. + plan.foreachUp { + case operator: LogicalPlan => + operator transformExpressionsUp { + case a: Attribute if !a.resolved => + val from = operator.inputSet.map(_.name).mkString(", ") + a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + + case c: Cast if !c.resolved => + failAnalysis( + s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + + case b: BinaryExpression if !b.resolved => + failAnalysis( + s"invalid expression ${b.prettyString} " + + s"between ${b.left.simpleString} and ${b.right.simpleString}") + } + + operator match { + case f: Filter if f.condition.dataType != BooleanType => + failAnalysis( + s"filter expression '${f.condition.prettyString}' " + + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + + case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) => + def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK + case e: Attribute if !groupingExprs.contains(e) => + failAnalysis( + s"expression '${e.prettyString}' is neither present in the group by, " + + s"nor is it an aggregate function. " + + "Add to group by or wrap in first() if you don't care which value you get.") + case e if groupingExprs.contains(e) => // OK + case e if e.references.isEmpty => // OK + case e => e.children.foreach(checkValidAggregateExpression) + } + + val cleaned = aggregateExprs.map(_.transform { + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + case Alias(g, _) => g + }) + + cleaned.foreach(checkValidAggregateExpression) + + case o if o.children.nonEmpty && + !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) => + val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",") + val input = o.inputSet.map(_.prettyString).mkString(",") + + failAnalysis(s"resolved attributes $missingAttributes missing from $input") + + // Catch all + case o if !o.resolved => + failAnalysis( + s"unresolved operator ${operator.simpleString}") + + case _ => // Analysis successful! + } + } + extendedCheckRules.foreach(_(plan)) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index aec7847..c1dd5aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -30,11 +30,21 @@ import org.apache.spark.sql.catalyst.dsl.plans._ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) - val caseSensitiveAnalyze = + + val caseSensitiveAnalyzer = new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) - val caseInsensitiveAnalyze = + val caseInsensitiveAnalyzer = new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) + val checkAnalysis = new CheckAnalysis + + + def caseSensitiveAnalyze(plan: LogicalPlan) = + checkAnalysis(caseSensitiveAnalyzer(plan)) + + def caseInsensitiveAnalyze(plan: LogicalPlan) = + checkAnalysis(caseInsensitiveAnalyzer(plan)) + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val testRelation2 = LocalRelation( AttributeReference("a", StringType)(), @@ -55,7 +65,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) } - assert(caseInsensitiveAnalyze(plan).resolved) + assert(caseInsensitiveAnalyzer(plan).resolved) } test("check project's resolved") { @@ -71,11 +81,11 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { test("analyze project") { assert( - caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === + caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) === Project(testRelation.output, testRelation)) assert( - caseSensitiveAnalyze( + caseSensitiveAnalyzer( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -88,13 +98,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage().toLowerCase.contains("cannot resolve")) assert( - caseInsensitiveAnalyze( + caseInsensitiveAnalyzer( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) assert( - caseInsensitiveAnalyze( + caseInsensitiveAnalyzer( Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -107,16 +117,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage == "Table Not Found: tAbLe") assert( - caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === - testRelation) + caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) === - testRelation) + caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === - testRelation) + caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } def errorTest( @@ -177,7 +184,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) - val plan = caseInsensitiveAnalyze( + val plan = caseInsensitiveAnalyzer( testRelation2.select( 'a / Literal(2) as 'div1, 'a / 'b as 'div2, http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 69e5f6a..27ac398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -117,7 +117,7 @@ class DataFrame protected[sql]( this(sqlContext, { val qe = sqlContext.executePlan(logicalPlan) if (sqlContext.conf.dataFrameEagerAnalysis) { - qe.analyzed // This should force analysis and throw errors if there are any + qe.assertAnalyzed() // This should force analysis and throw errors if there are any } qe }) http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 39f6c2f..a08c0f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -52,8 +52,9 @@ private[spark] object SQLConf { // This is used to set the default data source val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default" - // Whether to perform eager analysis on a DataFrame. - val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis" + // Whether to perform eager analysis when constructing a dataframe. + // Set to false when debugging requires the ability to look at invalid query plans. + val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4bdaa02..ce800e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -114,7 +114,6 @@ class SQLContext(@transient val sparkContext: SparkContext) new Analyzer(catalog, functionRegistry, caseSensitive = true) { override val extendedResolutionRules = ExtractPythonUdfs :: - sources.PreWriteCheck(catalog) :: sources.PreInsertCastAndRename :: Nil } @@ -1057,6 +1056,13 @@ class SQLContext(@transient val sparkContext: SparkContext) Batch("Add exchange", Once, AddExchange(self)) :: Nil } + @transient + protected[sql] lazy val checkAnalysis = new CheckAnalysis { + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) + } + /** * :: DeveloperApi :: * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -1064,9 +1070,13 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi protected[sql] class QueryExecution(val logical: LogicalPlan) { + def assertAnalyzed(): Unit = checkAnalysis(analyzed) lazy val analyzed: LogicalPlan = analyzer(logical) - lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed) + lazy val withCachedData: LogicalPlan = { + assertAnalyzed + cacheManager.useCachedData(analyzed) + } lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) // TODO: Don't just pick the first one... http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index 36a9c0b..8440581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -78,10 +78,10 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { /** * A rule to do various checks before inserting into or writing to a data source table. */ -private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan] { +private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) { def failAnalysis(msg: String) = { throw new AnalysisException(msg) } - def apply(plan: LogicalPlan): LogicalPlan = { + def apply(plan: LogicalPlan): Unit = { plan.foreach { case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => @@ -93,7 +93,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan val srcRelations = query.collect { case LogicalRelation(src: BaseRelation) => src } - if (srcRelations.exists(src => src == t)) { + if (srcRelations.contains(t)) { failAnalysis( "Cannot insert overwrite into table that is also being read from.") } else { @@ -119,7 +119,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan val srcRelations = query.collect { case LogicalRelation(src: BaseRelation) => src } - if (srcRelations.exists(src => src == dest)) { + if (srcRelations.contains(dest)) { failAnalysis( s"Cannot overwrite table $tableName that is also being read from.") } else { @@ -134,7 +134,5 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan case _ => // OK } - - plan } } http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 0ec6881..91c6367 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -30,7 +30,6 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter { override protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = false) { override val extendedResolutionRules = - PreWriteCheck(catalog) :: PreInsertCastAndRename :: Nil } http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 5682e5a..b5b16f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -205,7 +205,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { val message = intercept[AnalysisException] { sql( s""" - |INSERT OVERWRITE TABLE oneToTen SELECT a FROM jt + |INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt """.stripMargin) }.getMessage assert( http://git-wip-us.apache.org/repos/asf/spark/blob/2d7786ed/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2e205e6..c439dfe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -268,7 +268,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.PreInsertionCasts :: ExtractPythonUdfs :: ResolveUdtfsAlias :: - sources.PreWriteCheck(catalog) :: sources.PreInsertCastAndRename :: Nil } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org