Repository: spark
Updated Branches:
  refs/heads/master e4bd50412 -> f77f11c67
[SPARK-14345][SQL] Decouple deserializer expression resolution from ObjectOperator

## What changes were proposed in this pull request?

This PR decouples deserializer expression resolution from `ObjectOperator`, so that deserializer expressions can be used in normal operators. This is needed by #12061 and #12067; I abstracted the logic out and put it in this PR to reduce the amount of code change needed later.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #12131 from cloud-fan/separate.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f77f11c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f77f11c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f77f11c6

Branch: refs/heads/master
Commit: f77f11c67125fdac2e6849a4d45d9286fc872ed9
Parents: e4bd504
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Apr 5 10:53:54 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Apr 5 10:53:54 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 183 ++++++++++---------
 .../sql/catalyst/analysis/unresolved.scala      |  22 +++
 .../catalyst/encoders/ExpressionEncoder.scala   |   8 +-
 .../sql/catalyst/expressions/objects.scala      |  14 +-
 .../sql/catalyst/plans/logical/object.scala     |  52 ++----
 5 files changed, 153 insertions(+), 126 deletions(-)
----------------------------------------------------------------------
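The core idea of the change is to replace the old ObjectOperator-specific handling with a generic placeholder that any operator can embed, plus a dedicated analyzer rule that later binds it. A minimal, self-contained sketch of that pattern (toy types only, not Spark's actual classes; real deserializers are full expression trees and the real rule delegates to resolveExpression):

    object DeserializerResolutionDemo extends App {
      sealed trait Expr { def resolved: Boolean }
      case class Attribute(name: String) extends Expr { val resolved = true }
      case class BoundReference(ordinal: Int) extends Expr { val resolved = false }

      // Placeholder: carries the raw deserializer plus the attributes it must be resolved against.
      // An empty attribute list means "resolve against the child operator's output".
      case class UnresolvedDeserializer(deserializer: Expr, inputAttributes: Seq[Attribute])
        extends Expr { val resolved = false }

      // The dedicated "rule": bind the placeholder's ordinals to concrete attributes.
      def resolve(e: Expr, childOutput: Seq[Attribute]): Expr = e match {
        case UnresolvedDeserializer(BoundReference(i), attrs) =>
          val inputs = if (attrs.isEmpty) childOutput else attrs
          inputs(i)
        case other => other
      }

      val childOutput = Seq(Attribute("key"), Attribute("value"))
      println(resolve(UnresolvedDeserializer(BoundReference(1), Nil), childOutput))                  // Attribute(value)
      println(resolve(UnresolvedDeserializer(BoundReference(0), Seq(Attribute("k"))), childOutput))  // Attribute(k)
    }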
http://git-wip-us.apache.org/repos/asf/spark/blob/f77f11c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a6e317e..3e0a6d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,8 +17,6 @@
 package org.apache.spark.sql.catalyst.analysis
 
-import java.lang.reflect.Modifier
-
 import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 
@@ -87,9 +85,11 @@ class Analyzer(
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
       ResolveReferences ::
+      ResolveDeserializer ::
+      ResolveNewInstance ::
+      ResolveUpCast ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
-      ResolveUpCast ::
       ResolveOrdinalInOrderByAndGroupBy ::
       ResolveSortReferences ::
       ResolveGenerate ::
@@ -499,18 +499,9 @@ class Analyzer(
           Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
       }
 
-      // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
-      // should be resolved by their corresponding attributes instead of children's output.
-      case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
-        val deserializerToAttributes = o.deserializers.map {
-          case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
-        }.toMap
-
-        o.transformExpressions {
-          case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
-            resolveDeserializer(expr, attributes)
-          }.getOrElse(expr)
-        }
+      // Skips plans which contain deserializer expressions, as they should be resolved by another
+      // rule: ResolveDeserializer.
+      case plan if containsDeserializer(plan.expressions) => plan
 
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
@@ -526,38 +517,6 @@ class Analyzer(
       }
     }
 
-    private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
-      exprs.exists { expr =>
-        !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
-      }
-    }
-
-    def resolveDeserializer(
-        deserializer: Expression,
-        attributes: Seq[Attribute]): Expression = {
-      val unbound = deserializer transform {
-        case b: BoundReference => attributes(b.ordinal)
-      }
-
-      resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
-        case n: NewInstance
-          // If this is an inner class of another class, register the outer object in `OuterScopes`.
-          // Note that static inner classes (e.g., inner classes within Scala objects) don't need
-          // outer pointer registration.
-          if n.outerPointer.isEmpty &&
-             n.cls.isMemberClass &&
-             !Modifier.isStatic(n.cls.getModifiers) =>
-          val outer = OuterScopes.getOuterScope(n.cls)
-          if (outer == null) {
-            throw new AnalysisException(
-              s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
-                "access to the scope that this class was defined in.\n" +
-                "Try moving this class out of its parent class.")
-          }
-          n.copy(outerPointer = Some(outer))
-      }
-    }
-
     def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
       expressions.map {
         case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
@@ -623,6 +582,10 @@ class Analyzer(
     }
   }
 
+  private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
+    exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
+  }
+
   protected[sql] def resolveExpression(
       expr: Expression,
       plan: LogicalPlan,
@@ -1475,7 +1438,94 @@ class Analyzer(
         Project(projectList, Join(left, right, joinType, newCondition))
     }
 
+  /**
+   * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
+   * to the given input attributes.
+   */
+  object ResolveDeserializer extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p if !p.childrenResolved => p
+      case p if p.resolved => p
+      case p => p transformExpressions {
+        case UnresolvedDeserializer(deserializer, inputAttributes) =>
+          val inputs = if (inputAttributes.isEmpty) {
+            p.children.flatMap(_.output)
+          } else {
+            inputAttributes
+          }
+          val unbound = deserializer transform {
+            case b: BoundReference => inputs(b.ordinal)
+          }
+          resolveExpression(unbound, LocalRelation(inputs), throws = true)
+      }
+    }
+  }
+
+  /**
+   * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
+   * constructed is an inner class.
+   */
+  object ResolveNewInstance extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p if !p.childrenResolved => p
+      case p if p.resolved => p
+
+      case p => p transformExpressions {
+        case n: NewInstance if n.childrenResolved && !n.resolved =>
+          val outer = OuterScopes.getOuterScope(n.cls)
+          if (outer == null) {
+            throw new AnalysisException(
+              s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
+                "access to the scope that this class was defined in.\n" +
+                "Try moving this class out of its parent class.")
+          }
+          n.copy(outerPointer = Some(outer))
+      }
+    }
+  }
+
+  /**
+   * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
+   */
+  object ResolveUpCast extends Rule[LogicalPlan] {
+    private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
+      throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
+        s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
+        "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
+        "You can either add an explicit cast to the input data or choose a higher precision " +
+        "type of the field in the target object")
+    }
+
+    private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
+      val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
+      val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
+      toPrecedence > 0 && fromPrecedence > toPrecedence
+    }
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p if !p.childrenResolved => p
+      case p if p.resolved => p
+
+      case p => p transformExpressions {
+        case u @ UpCast(child, _, _) if !child.resolved => u
+
+        case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
+          case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+            fail(child, to, walkedTypePath)
+          case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+            fail(child, to, walkedTypePath)
+          case (from, to) if illegalNumericPrecedence(from, to) =>
+            fail(child, to, walkedTypePath)
+          case (TimestampType, DateType) =>
+            fail(child, DateType, walkedTypePath)
+          case (StringType, to: NumericType) =>
+            fail(child, to, walkedTypePath)
+          case _ => Cast(child, dataType.asNullable)
+        }
+      }
+    }
+  }
 }
 
 /**
@@ -1560,45 +1610,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
 }
 
 /**
- * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
- */
-object ResolveUpCast extends Rule[LogicalPlan] {
-  private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
-    throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
-      s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
-      "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
-      "You can either add an explicit cast to the input data or choose a higher precision " +
-      "type of the field in the target object")
-  }
-
-  private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
-    val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
-    val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
-    toPrecedence > 0 && fromPrecedence > toPrecedence
-  }
-
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    plan transformAllExpressions {
-      case u @ UpCast(child, _, _) if !child.resolved => u
-
-      case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
-        case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
-          fail(child, to, walkedTypePath)
-        case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
-          fail(child, to, walkedTypePath)
-        case (from, to) if illegalNumericPrecedence(from, to) =>
-          fail(child, to, walkedTypePath)
-        case (TimestampType, DateType) =>
-          fail(child, DateType, walkedTypePath)
-        case (StringType, to: NumericType) =>
-          fail(child, to, walkedTypePath)
-        case _ => Cast(child, dataType.asNullable)
-      }
-    }
-  }
-}
-
-/**
  * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
 * figure out how many windows a time column can map to, we over-estimate the number of windows and
 * filter out the rows where the time column is not inside the time window.
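The up-cast check that moved into the Analyzer above rejects casts that may lose information. A small self-contained illustration of the numeric-precedence part of that check (type names here stand in for Catalyst's DataType objects; the precedence list mirrors the narrowest-to-widest ordering of the real HiveTypeCoercion.numericPrecedence, simplified and without the DecimalType cases):

    object UpCastCheckDemo extends App {
      // Narrowest to widest, analogous to HiveTypeCoercion.numericPrecedence.
      val numericPrecedence = Seq("ByteType", "ShortType", "IntegerType", "LongType", "FloatType", "DoubleType")

      // An up cast walking "down" the precedence list may truncate and must fail analysis.
      def illegalNumericPrecedence(from: String, to: String): Boolean = {
        val fromPrecedence = numericPrecedence.indexOf(from)
        val toPrecedence = numericPrecedence.indexOf(to)
        toPrecedence > 0 && fromPrecedence > toPrecedence
      }

      println(illegalNumericPrecedence("LongType", "IntegerType"))  // true  -> AnalysisException in ResolveUpCast
      println(illegalNumericPrecedence("IntegerType", "LongType"))  // false -> safely rewritten to Cast
    }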
http://git-wip-us.apache.org/repos/asf/spark/blob/f77f11c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index e73d367..fbbf630 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -307,3 +307,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
 
   override lazy val resolved = false
 }
+
+/**
+ * Holds the deserializer expression and the attributes that are available during its resolution.
+ * A deserializer expression is a special kind of expression that is not always resolved by the
+ * children's output, but by given attributes, e.g. the `keyDeserializer` in `MapGroups` should be
+ * resolved by `groupingAttributes` instead of the children's output.
+ *
+ * @param deserializer The unresolved deserializer expression
+ * @param inputAttributes The input attributes used to resolve the deserializer expression; can be
+ *                        empty if we want to resolve the deserializer by the children's output.
+ */
+case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
+  extends UnaryExpression with Unevaluable with NonSQLExpression {
+  // The input attributes used to resolve the deserializer expression must all be resolved.
+  require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")
+
+  override def child: Expression = deserializer
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override lazy val resolved = false
+}
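Concretely, the two resolution modes described in the scaladoc correspond to the call sites added later in this patch (see the object.scala hunks below): an empty attribute list means the deserializer is resolved against the child's output, while MapGroups pins its key and value deserializers to specific attribute sets:

    UnresolvedDeserializer(encoderFor[T].deserializer, Nil)                   // MapPartitions: use the child's output
    UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes)    // MapGroups: key deserializer
    UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes)        // MapGroups: value deserializer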
http://git-wip-us.apache.org/repos/asf/spark/blob/f77f11c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 1c712fd..56d29cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
 import org.apache.spark.sql.{AnalysisException, Encoder}
 import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
@@ -317,11 +317,11 @@ case class ExpressionEncoder[T](
   def resolve(
       schema: Seq[Attribute],
       outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
-    val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema)
-
     // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
     // analysis, go through optimizer, etc.
-    val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema))
+    val plan = Project(
+      Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
+      LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
     SimpleAnalyzer.checkAnalysis(analyzedPlan)
     copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
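With the diff markers stripped, the encoder's resolve method now reads as follows (assembled from the hunk above; the enclosing case class and imports are omitted):

    def resolve(
        schema: Seq[Attribute],
        outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
      // Make a fake plan to wrap the deserializer, so that we can go through the whole analyzer,
      // check analysis, go through optimizer, etc.
      val plan = Project(
        Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
        LocalRelation(schema))
      val analyzedPlan = SimpleAnalyzer.execute(plan)
      SimpleAnalyzer.checkAnalysis(analyzedPlan)
      copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
    }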
http://git-wip-us.apache.org/repos/asf/spark/blob/f77f11c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 07b67a0..eebd43d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.lang.reflect.Modifier
+
 import scala.annotation.tailrec
 import scala.language.existentials
 import scala.reflect.ClassTag
@@ -112,7 +114,7 @@ case class Invoke(
     arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
 
   override def nullable: Boolean = true
-  override def children: Seq[Expression] = arguments.+:(targetObject)
+  override def children: Seq[Expression] = targetObject +: arguments
 
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -214,6 +216,16 @@ case class NewInstance(
 
   override def children: Seq[Expression] = arguments
 
+  override lazy val resolved: Boolean = {
+    // If the class to construct is an inner class, we need to get its outer pointer, or this
+    // expression should be regarded as unresolved.
+    // Note that static inner classes (e.g., inner classes within Scala objects) don't need
+    // outer pointer registration.
+    val needOuterPointer =
+      outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)
+    childrenResolved && !needOuterPointer
+  }
+
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
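The new resolved definition on NewInstance hinges on a plain reflection test: only non-static member classes need an outer pointer. A self-contained way to see that test in isolation (toy classes; the expected output assumes the usual scalac encoding of nested classes, which is also what the comment in the hunk above relies on):

    import java.lang.reflect.Modifier

    class Outer { class Inner }       // Inner captures an Outer instance
    object Holder { class Nested }    // nested in an object: compiled as a static member class

    object OuterPointerCheckDemo {
      // Same condition NewInstance.resolved uses to decide whether an outer pointer is required.
      def needsOuterPointer(cls: Class[_]): Boolean =
        cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)

      def main(args: Array[String]): Unit = {
        println(needsOuterPointer(classOf[Outer#Inner]))    // expected: true
        println(needsOuterPointer(classOf[Holder.Nested]))  // expected: false
      }
    }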
http://git-wip-us.apache.org/repos/asf/spark/blob/f77f11c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 058fb6b..58313c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{ObjectType, StructType}
@@ -33,13 +34,6 @@ trait ObjectOperator extends LogicalPlan {
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 
   /**
-   * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
-   * It must also provide the attributes that are available during the resolution of each
-   * deserializer.
-   */
-  def deserializers: Seq[(Expression, Seq[Attribute])]
-
-  /**
    * The object type that is produced by the user defined function. Note that the return type here
    * is the same whether or not the operator is output serialized data.
    */
@@ -71,7 +65,7 @@ object MapPartitions {
       child: LogicalPlan): MapPartitions = {
     MapPartitions(
       func.asInstanceOf[Iterator[Any] => Iterator[Any]],
-      encoderFor[T].deserializer,
+      UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
       encoderFor[U].namedExpressions,
       child)
   }
@@ -87,9 +81,7 @@ case class MapPartitions(
     func: Iterator[Any] => Iterator[Any],
     deserializer: Expression,
     serializer: Seq[NamedExpression],
-    child: LogicalPlan) extends UnaryNode with ObjectOperator {
-  override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
-}
+    child: LogicalPlan) extends UnaryNode with ObjectOperator
 
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumns {
@@ -98,7 +90,7 @@ object AppendColumns {
     child: LogicalPlan): AppendColumns = {
     new AppendColumns(
       func.asInstanceOf[Any => Any],
-      encoderFor[T].deserializer,
+      UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
       encoderFor[U].namedExpressions,
       child)
   }
@@ -120,8 +112,6 @@ case class AppendColumns(
   override def output: Seq[Attribute] = child.output ++ newColumns
 
   def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
-
-  override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
 }
 
 /** Factory for constructing new `MapGroups` nodes. */
@@ -133,8 +123,8 @@ object MapGroups {
     child: LogicalPlan): MapGroups = {
     new MapGroups(
       func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
-      encoderFor[K].deserializer,
-      encoderFor[T].deserializer,
+      UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+      UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes),
       encoderFor[U].namedExpressions,
       groupingAttributes,
       dataAttributes,
@@ -158,11 +148,7 @@ case class MapGroups(
     serializer: Seq[NamedExpression],
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode with ObjectOperator {
-
-  override def deserializers: Seq[(Expression, Seq[Attribute])] =
-    Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
-}
+    child: LogicalPlan) extends UnaryNode with ObjectOperator
 
 /** Factory for constructing new `CoGroup` nodes. */
 object CoGroup {
@@ -170,22 +156,24 @@ object CoGroup {
     func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
     leftGroup: Seq[Attribute],
     rightGroup: Seq[Attribute],
-    leftData: Seq[Attribute],
-    rightData: Seq[Attribute],
+    leftAttr: Seq[Attribute],
+    rightAttr: Seq[Attribute],
     left: LogicalPlan,
     right: LogicalPlan): CoGroup = {
     require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
 
     CoGroup(
       func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
-      encoderFor[Key].deserializer,
-      encoderFor[Left].deserializer,
-      encoderFor[Right].deserializer,
+      // The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to
+      // resolve the `keyDeserializer` based on either of them; here we pick the left one.
+      UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup),
+      UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr),
+      UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr),
       encoderFor[Result].namedExpressions,
       leftGroup,
       rightGroup,
-      leftData,
-      rightData,
+      leftAttr,
+      rightAttr,
       left,
       right)
   }
@@ -206,10 +194,4 @@ case class CoGroup(
     leftAttr: Seq[Attribute],
     rightAttr: Seq[Attribute],
     left: LogicalPlan,
-    right: LogicalPlan) extends BinaryNode with ObjectOperator {
-
-  override def deserializers: Seq[(Expression, Seq[Attribute])] =
-    // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve
-    // the `keyDeserializer` based on either of them, here we pick the left one.
-    Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
-}
+    right: LogicalPlan) extends BinaryNode with ObjectOperator


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org