Repository: spark Updated Branches: refs/heads/branch-2.0 b3f145442 -> 68617e1ad
[SPARK-15094][SPARK-14803][SQL] Remove extra Project added in EliminateSerialization ## What changes were proposed in this pull request? We will eliminate the pair of `DeserializeToObject` and `SerializeFromObject` in `Optimizer` and add an extra `Project`. However, when DeserializeToObject's outputObjectType is ObjectType and its cls can't be processed by unsafe project, it will fail. To fix it, we can simply remove the extra `Project` and replace the output attribute of `DeserializeToObject` in another rule. ## How was this patch tested? `DatasetSuite`. Author: Liang-Chi Hsieh <sim...@tw.ibm.com> Closes #12926 from viirya/fix-eliminate-serialization-projection. (cherry picked from commit 470de743ecf3617babd86f50ab203e85aa975d69) Signed-off-by: Yin Huai <yh...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/68617e1a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/68617e1a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/68617e1a Branch: refs/heads/branch-2.0 Commit: 68617e1addc81805d6c27d37a84f5b50644c6a75 Parents: b3f1454 Author: Liang-Chi Hsieh <sim...@tw.ibm.com> Authored: Thu May 12 10:11:12 2016 -0700 Committer: Yin Huai <yh...@databricks.com> Committed: Thu May 12 10:11:26 2016 -0700 ---------------------------------------------------------------------- .../sql/catalyst/optimizer/Optimizer.scala | 60 ++++++++++++++++---- .../org/apache/spark/sql/DatasetSuite.scala | 12 ++++ 2 files changed, 62 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/68617e1a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 350b601..928ba21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -102,7 +102,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyCasts, SimplifyCaseConversionExpressions, RewriteCorrelatedScalarSubquery, - EliminateSerialization) :: + EliminateSerialization, + RemoveAliasOnlyProject) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: Batch("Typed Filter Optimization", fixedPoint, @@ -156,6 +157,49 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** + * Removes the Project only conducting Alias of its child node. + * It is created mainly for removing extra Project added in EliminateSerialization rule, + * but can also benefit other operators. + */ +object RemoveAliasOnlyProject extends Rule[LogicalPlan] { + // Check if projectList in the Project node has the same attribute names and ordering + // as its child node. 
+ private def isAliasOnly( + projectList: Seq[NamedExpression], + childOutput: Seq[Attribute]): Boolean = { + if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) { + return false + } else { + projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) => + a.child match { + case attr: Attribute if a.name == attr.name && attr.semanticEquals(o) => true + case _ => false + } + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + val aliasOnlyProject = plan.find { p => + p match { + case Project(pList, child) if isAliasOnly(pList, child.output) => true + case _ => false + } + } + + aliasOnlyProject.map { case p: Project => + val aliases = p.projectList.map(_.asInstanceOf[Alias]) + val attrMap = AttributeMap(aliases.map(a => (a.toAttribute, a.child))) + plan.transformAllExpressions { + case a: Attribute if attrMap.contains(a) => attrMap(a) + }.transform { + case op: Project if op.eq(p) => op.child + } + }.getOrElse(plan) + } +} + +/** * Removes cases where we are unnecessarily going between the object and serialized (InternalRow) * representation of data item. For example back to back map operations. */ @@ -163,15 +207,11 @@ object EliminateSerialization extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case d @ DeserializeToObject(_, _, s: SerializeFromObject) if d.outputObjectType == s.inputObjectType => - // A workaround for SPARK-14803. Remove this after it is fixed. - if (d.outputObjectType.isInstanceOf[ObjectType] && - d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) { - s.child - } else { - // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. - val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) - Project(objAttr :: Nil, s.child) - } + // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. 
+ // We will remove it later in RemoveAliasOnlyProject rule. + val objAttr = + Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId) + Project(objAttr :: Nil, s.child) case a @ AppendColumns(_, _, _, s: SerializeFromObject) if a.deserializer.dataType == s.inputObjectType => AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) http://git-wip-us.apache.org/repos/asf/spark/blob/68617e1a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 0784041..3b9feae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -661,6 +661,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4) } + test("dataset.rdd with generic case class") { + val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS + val ds2 = ds.map(g => Generic(g.id, g.value)) + assert(ds.rdd.map(r => r.id).count === 2) + assert(ds2.rdd.map(r => r.id).count === 2) + + val ds3 = ds.map(g => new java.lang.Long(g.id)) + assert(ds3.rdd.map(r => r).count === 2) + } + test("runtime null check for RowEncoder") { val schema = new StructType().add("i", IntegerType, nullable = false) val df = sqlContext.range(10).map(l => { @@ -694,6 +704,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } +case class Generic[T](id: T, value: Double) + case class OtherTuple(_1: String, _2: Int) case class TupleClass(data: (Int, String)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org