Repository: spark
Updated Branches:
  refs/heads/master 4cdd8ecd6 -> 687c8c371


[SPARK-9372] [SQL] Filter nulls in join keys

This PR adds an optimization rule, `FilterNullsInJoinKey`, which inserts a `Filter` 
before join operators to filter out rows whose join keys contain null values.

This optimization is guarded by a new SQL conf, 
`spark.sql.advancedOptimization`.
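
A minimal sketch of toggling it (assuming an existing `sqlContext`; this uses the
standard string-based conf setter):

    // Disable the null-filtering rewrite for this SQLContext.
    sqlContext.setConf("spark.sql.advancedOptimization", "false")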

The code in this PR was authored by yhuai; I'm opening this PR to factor out 
this change from #7685, a larger pull request which contains two other 
optimizations.

Author: Yin Huai <yh...@databricks.com>
Author: Josh Rosen <joshro...@databricks.com>

Closes #7768 from JoshRosen/filter-nulls-in-join-key and squashes the following 
commits:

c02fc3f [Yin Huai] Address Josh's comments.
0a8e096 [Yin Huai] Update comments.
ea7d5a6 [Yin Huai] Make sure we do not keep adding filters.
be88760 [Yin Huai] Make it clear that FilterNullsInJoinKeySuite.scala is used 
to test FilterNullsInJoinKey.
8bb39ad [Yin Huai] Fix non-deterministic tests.
303236b [Josh Rosen] Revert changes that are unrelated to null join key 
filtering
40eeece [Josh Rosen] Merge remote-tracking branch 'origin/master' into 
filter-nulls-in-join-key
c57a954 [Yin Huai] Bug fix.
d3d2e64 [Yin Huai] First round of cleanup.
f9516b0 [Yin Huai] Style
c6667e7 [Yin Huai] Add PartitioningCollection.
e616d3b [Yin Huai] wip
7c2d2d8 [Yin Huai] Bug fix and refactoring.
69bb072 [Yin Huai] Introduce NullSafeHashPartitioning and 
NullUnsafePartitioning.
d5b84c3 [Yin Huai] Do not add unnessary filters.
2201129 [Yin Huai] Filter out rows that will not be joined in equal joins early.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/687c8c37
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/687c8c37
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/687c8c37

Branch: refs/heads/master
Commit: 687c8c37150f4c93f8e57d86bb56321a4891286b
Parents: 4cdd8ec
Author: Yin Huai <yh...@databricks.com>
Authored: Sun Aug 2 23:32:09 2015 -0700
Committer: Josh Rosen <joshro...@databricks.com>
Committed: Sun Aug 2 23:32:09 2015 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/nullFunctions.scala    |  48 +++-
 .../sql/catalyst/optimizer/Optimizer.scala      |  64 +++--
 .../catalyst/plans/logical/basicOperators.scala |  32 ++-
 .../expressions/ExpressionEvalHelper.scala      |   4 +-
 .../expressions/MathFunctionsSuite.scala        |   3 +-
 .../expressions/NullFunctionsSuite.scala        |  49 +++-
 .../apache/spark/sql/DataFrameNaFunctions.scala |   2 +-
 .../scala/org/apache/spark/sql/SQLConf.scala    |   6 +
 .../scala/org/apache/spark/sql/SQLContext.scala |   5 +-
 .../extendedOperatorOptimizations.scala         | 160 +++++++++++++
 .../optimizer/FilterNullsInJoinKeySuite.scala   | 236 +++++++++++++++++++
 11 files changed, 572 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 287718f..d58c475 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -210,14 +210,58 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
   }
 }
 
+/**
+ * A predicate that is evaluated to be true if there are at least `n` null values.
+ */
+case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate {
+  override def nullable: Boolean = false
+  override def foldable: Boolean = children.forall(_.foldable)
+  override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})"
+
+  private[this] val childrenArray = children.toArray
+
+  override def eval(input: InternalRow): Boolean = {
+    var numNulls = 0
+    var i = 0
+    while (i < childrenArray.length && numNulls < n) {
+      val evalC = childrenArray(i).eval(input)
+      if (evalC == null) {
+        numNulls += 1
+      }
+      i += 1
+    }
+    numNulls >= n
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val numNulls = ctx.freshName("numNulls")
+    val code = children.map { e =>
+      val eval = e.gen(ctx)
+      s"""
+        if ($numNulls < $n) {
+          ${eval.code}
+          if (${eval.isNull}) {
+            $numNulls += 1;
+          }
+        }
+      """
+    }.mkString("\n")
+    s"""
+      int $numNulls = 0;
+      $code
+      boolean ${ev.isNull} = false;
+      boolean ${ev.primitive} = $numNulls >= $n;
+     """
+  }
+}
 
 /**
  * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
  */
-case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
+case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate {
   override def nullable: Boolean = false
   override def foldable: Boolean = children.forall(_.foldable)
-  override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"
+  override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})"
 
   private[this] val childrenArray = children.toArray
 

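For intuition, a small usage sketch of the new predicate (not part of the diff;
it assumes the literal values below and the usual catalyst expression imports):

    // Two of the three children evaluate to null, so a threshold of 2 is met
    // while a threshold of 3 is not.
    val children = Seq(Literal.create(null, StringType), Literal.create(null, IntegerType), Literal("x"))
    AtLeastNNulls(2, children).eval(EmptyRow)  // true
    AtLeastNNulls(3, children).eval(EmptyRow)  // false
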
http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 29d706d..e4b6294 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -31,8 +31,14 @@ import org.apache.spark.sql.types._
 
 abstract class Optimizer extends RuleExecutor[LogicalPlan]
 
-object DefaultOptimizer extends Optimizer {
-  val batches =
+class DefaultOptimizer extends Optimizer {
+
+  /**
+   * Override to provide additional rules for the "Operator Optimizations" batch.
+   */
+  val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
+
+  lazy val batches =
     // SubQueries are only needed for analysis and can be removed before execution.
     Batch("Remove SubQueries", FixedPoint(100),
       EliminateSubQueries) ::
@@ -41,26 +47,27 @@ object DefaultOptimizer extends Optimizer {
       RemoveLiteralFromGroupExpressions) ::
     Batch("Operator Optimizations", FixedPoint(100),
       // Operator push down
-      SetOperationPushDown,
-      SamplePushDown,
-      PushPredicateThroughJoin,
-      PushPredicateThroughProject,
-      PushPredicateThroughGenerate,
-      ColumnPruning,
+      SetOperationPushDown ::
+      SamplePushDown ::
+      PushPredicateThroughJoin ::
+      PushPredicateThroughProject ::
+      PushPredicateThroughGenerate ::
+      ColumnPruning ::
       // Operator combine
-      ProjectCollapsing,
-      CombineFilters,
-      CombineLimits,
+      ProjectCollapsing ::
+      CombineFilters ::
+      CombineLimits ::
       // Constant folding
-      NullPropagation,
-      OptimizeIn,
-      ConstantFolding,
-      LikeSimplification,
-      BooleanSimplification,
-      RemovePositive,
-      SimplifyFilters,
-      SimplifyCasts,
-      SimplifyCaseConversionExpressions) ::
+      NullPropagation ::
+      OptimizeIn ::
+      ConstantFolding ::
+      LikeSimplification ::
+      BooleanSimplification ::
+      RemovePositive ::
+      SimplifyFilters ::
+      SimplifyCasts ::
+      SimplifyCaseConversionExpressions ::
+      extendedOperatorOptimizationRules.toList : _*) ::
     Batch("Decimal Optimizations", FixedPoint(100),
       DecimalAggregates) ::
     Batch("LocalRelation", FixedPoint(100),
@@ -222,12 +229,18 @@ object ColumnPruning extends Rule[LogicalPlan] {
   }
 
   /** Applies a projection only when the child is producing unnecessary attributes */
-  private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
+  private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = {
     if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
-      Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+      // We need to preserve the nullability of c's output.
+      // So, we first create an outputMap and, if a reference comes from the output
+      // of c, we use that output attribute from c.
+      val outputMap = AttributeMap(c.output.map(attr => (attr, attr)))
+      val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq
+      Project(projectList, c)
     } else {
       c
     }
+  }
 }
 
 /**
@@ -517,6 +530,13 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
  */
 object CombineFilters extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) =>
+      // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and Not(AtLeastNNulls(1, e2))
+      // (these are used to make sure there are no nulls in the results of e1 and e2, and
+      // they are added by the FilterNullsInJoinKey optimization rule), we can
+      // just create a single Not(AtLeastNNulls(1, (e1 ++ e2).distinct)).
+      Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild)
+    case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild)
   }
 }

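To illustrate the new CombineFilters case, here is a sketch with made-up
attributes a and b (not code from the patch): two adjacent null-check Filters
introduced by FilterNullsInJoinKey collapse into a single Filter over the
distinct expressions.

    // Before: Filter(Not(AtLeastNNulls(1, Seq(a))), Filter(Not(AtLeastNNulls(1, Seq(a, b))), child))
    // After:  Filter(Not(AtLeastNNulls(1, Seq(a, b))), child)
    // because (Seq(a) ++ Seq(a, b)).distinct == Seq(a, b)
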
http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index aacfc86..54b5f49 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -86,7 +86,37 @@ case class Generate(
 }
 
 case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
-  override def output: Seq[Attribute] = child.output
+  /**
+   * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children
+   * have at least one null value and atLeastNNulls.children are all attributes.
+   */
+  private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = {
+    val expressions = atLeastNNulls.children
+    val n = atLeastNNulls.n
+    if (n != 1) {
+      // AtLeastNNulls is not used to check if atLeastNNulls.children have
+      // at least one null value.
+      false
+    } else {
+      // AtLeastNNulls is used to check if atLeastNNulls.children have
+      // at least one null value. We need to make sure all atLeastNNulls.children
+      // are attributes.
+      expressions.forall(_.isInstanceOf[Attribute])
+    }
+  }
+
+  override def output: Seq[Attribute] = condition match {
+    case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) =>
+      // The condition is used to make sure that there is no null value in
+      // a.children.
+      val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]])
+      child.output.map {
+        case attr if nonNullableAttributes.contains(attr) =>
+          attr.withNullability(false)
+        case attr => attr
+      }
+    case _ => child.output
+  }
 }
 
 case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {

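A sketch of the resulting nullability propagation (assumed relation and
attribute names; not code from the patch): once the null-check Filter is in
place, downstream operators see the checked attributes as non-nullable.

    // rel has a nullable attribute 'a.
    // Filter(Not(AtLeastNNulls(1, Seq(a))), rel).output reports 'a with
    // nullable = false, which is how FilterNullsInJoinKey.needsFilter avoids
    // adding the same Filter again on later optimizer iterations.
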
http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index a41185b..3e55151 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
 trait ExpressionEvalHelper {
   self: SparkFunSuite =>
 
+  protected val defaultOptimizer = new DefaultOptimizer
+
   protected def create_row(values: Any*): InternalRow = {
     InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
   }
@@ -186,7 +188,7 @@ trait ExpressionEvalHelper {
       expected: Any,
       inputRow: InternalRow = EmptyRow): Unit = {
     val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
-    val optimizedPlan = DefaultOptimizer.execute(plan)
+    val optimizedPlan = defaultOptimizer.execute(plan)
     checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 9fcb548..649a5b4 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -23,7 +23,6 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
 import org.apache.spark.sql.types._
 
@@ -149,7 +148,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     expression: Expression,
     inputRow: InternalRow = EmptyRow): Unit = {
     val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
-    val optimizedPlan = DefaultOptimizer.execute(plan)
+    val optimizedPlan = defaultOptimizer.execute(plan)
     checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index ace6c15..bf19712 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
-  test("AtLeastNNonNulls") {
+  test("AtLeastNNonNullNans") {
     val mix = Seq(Literal("x"),
       Literal.create(null, StringType),
       Literal.create(null, DoubleType),
@@ -96,11 +96,46 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       Literal(Float.MaxValue),
       Literal(false))
 
-    checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow)
-    checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow)
-    checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow)
-    checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow)
-    checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow)
-    checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow)
+    checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow)
+    checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow)
+    checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow)
+    checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow)
+    checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow)
+    checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow)
+    checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow)
+    checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow)
+    checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow)
+  }
+
+  test("AtLeastNNull") {
+    val mix = Seq(Literal("x"),
+      Literal.create(null, StringType),
+      Literal.create(null, DoubleType),
+      Literal(Double.NaN),
+      Literal(5f))
+
+    val nanOnly = Seq(Literal("x"),
+      Literal(10.0),
+      Literal(Float.NaN),
+      Literal(math.log(-2)),
+      Literal(Double.MaxValue))
+
+    val nullOnly = Seq(Literal("x"),
+      Literal.create(null, DoubleType),
+      Literal.create(null, DecimalType.USER_DEFAULT),
+      Literal(Float.MaxValue),
+      Literal(false))
+
+    checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow)
+    checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow)
+    checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow)
+    checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow)
+    checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow)
+    checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow)
+    checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow)
+    checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow)
+    checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow)
+    checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow)
+    checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index a4fd4cf..ea85f06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
   def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
     // Filtering condition:
     // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
-    val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
+    val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name)))
     df.filter(Column(predicate))
   }
 

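The public na.drop behavior is unchanged; only the underlying expression was
renamed. A usage sketch (assuming a DataFrame df with columns "a" and "b"):

    // Keep rows with at least 2 non-null, non-NaN values among columns a and b.
    df.na.drop(2, Seq("a", "b"))
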
http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 6644e85..387960c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -413,6 +413,10 @@ private[spark] object SQLConf {
     "spark.sql.useSerializer2",
     defaultValue = Some(true), isPublic = false)
 
+  val ADVANCED_SQL_OPTIMIZATION = booleanConf(
+    "spark.sql.advancedOptimization",
+    defaultValue = Some(true), isPublic = false)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -484,6 +488,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
 
   private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
 
+  private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION)
+
   private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
 
   private[spark] def defaultSizeInBytes: Long =

http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dbb2a09..31e2b50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.optimizer.FilterNullsInJoinKey
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -156,7 +157,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
     }
 
   @transient
-  protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer
+  protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer {
+    override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil
+  }
 
   @transient
   protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_))

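The override above is also the general extension point: any rule can be
appended to the "Operator Optimizations" batch the same way. A hedged sketch
with a hypothetical extra rule named MyExtraRule:

    protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer {
      override val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
        FilterNullsInJoinKey(self) :: MyExtraRule :: Nil
    }
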
http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
new file mode 100644
index 0000000..5a4dde5
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.optimizer
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * An optimization rule that inserts Filters to remove rows whose equi-join keys
+ * contain at least one null value. Such rows cannot contribute to the result of an
+ * equi-join because a null never equals another null, so we can filter them out
+ * before shuffling join input rows. For example, given two tables
+ *
+ * table1(key String, value Int)
+ * "str1"|1
+ * null  |2
+ *
+ * table2(key String, value Int)
+ * "str1"|3
+ * null  |4
+ *
+ * the result of an inner equi-join is
+ * "str1"|1|"str1"|3
+ *
+ * The two rows with a null key cannot contribute to the result, so we can filter
+ * them out early.
+ *
+ * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false.
+ *
+ */
+case class FilterNullsInJoinKey(
+    sqlContext: SQLContext)
+  extends Rule[LogicalPlan] {
+
+  /**
+   * Checks if we need to add a Filter operator. We add a Filter when there is any
+   * attribute in `keys` whose corresponding attribute in `plan.output` is still
+   * nullable (its `nullable` field is `true`).
+   */
+  private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = {
+    val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute]))
+    plan.output.filter(keyAttributeSet.contains).exists(_.nullable)
+  }
+
+  /**
+   * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable.
+   */
+  private def addFilterIfNecessary(
+      keys: Seq[Expression],
+      child: LogicalPlan): LogicalPlan = {
+    // We get all attributes from keys.
+    val attributes = keys.filter(_.isInstanceOf[Attribute])
+
+    // Then, we create a Filter to make sure these attributes are non-nullable.
+    val filter =
+      if (attributes.nonEmpty) {
+        Filter(Not(AtLeastNNulls(1, attributes)), child)
+      } else {
+        child
+      }
+
+    filter
+  }
+
+  /**
+   * We reconstruct the join condition.
+   */
+  private def reconstructJoinCondition(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      otherPredicate: Option[Expression]): Expression = {
+    // First, we rewrite the equal condition part. When we extract those keys,
+    // we use splitConjunctivePredicates. So, it is safe to use .reduce(And).
+    val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map {
+      case (l, r) => EqualTo(l, r)
+    }.reduce(And)
+
+    // Then, we add otherPredicate. When we extract the equi-join condition part,
+    // we use splitConjunctivePredicates. So, it is safe to use
+    // And(rewrittenEqualJoinCondition, c).
+    val rewrittenJoinCondition = otherPredicate
+      .map(c => And(rewrittenEqualJoinCondition, c))
+      .getOrElse(rewrittenEqualJoinCondition)
+
+    rewrittenJoinCondition
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!sqlContext.conf.advancedSqlOptimizations) {
+      plan
+    } else {
+      plan transform {
+        case join: Join => join match {
+          // For an inner join with an equi-join condition, we can add filters
+          // to both sides of the join operator.
+          case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+            if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) =>
+            val withLeftFilter = addFilterIfNecessary(leftKeys, left)
+            val withRightFilter = addFilterIfNecessary(rightKeys, right)
+            val rewrittenJoinCondition =
+              reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+            Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition))
+
+          // For a left outer join with an equi-join condition, we can add a filter
+          // to the right side of the join operator.
+          case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
+            if needsFilter(rightKeys, right) =>
+            val withRightFilter = addFilterIfNecessary(rightKeys, right)
+            val rewrittenJoinCondition =
+              reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+            Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition))
+
+          // For a right outer join with an equi-join condition, we can add a filter
+          // to the left side of the join operator.
+          case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
+            if needsFilter(leftKeys, left) =>
+            val withLeftFilter = addFilterIfNecessary(leftKeys, left)
+            val rewrittenJoinCondition =
+              reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+            Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition))
+
+          // For a left semi join with an equi-join condition, we can add filters
+          // to both sides of the join operator.
+          case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
+            if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) =>
+            val withLeftFilter = addFilterIfNecessary(leftKeys, left)
+            val withRightFilter = addFilterIfNecessary(rightKeys, right)
+            val rewrittenJoinCondition =
+              reconstructJoinCondition(leftKeys, rightKeys, condition)
+
+            Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition))
+
+          case other => other
+        }
+      }
+    }
+  }
+}

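End to end, the rule's effect on a simple equi-join looks roughly like this
(a sketch using the catalyst test DSL with assumed relations t1('k1, 'v1) and
t2('k2, 'v2); not code from the patch):

    // Before: t1.join(t2, Inner, Some('k1 === 'k2))
    // After (with spark.sql.advancedOptimization enabled):
    //   t1.where(!AtLeastNNulls(1, 'k1.expr :: Nil))
    //     .join(t2.where(!AtLeastNNulls(1, 'k2.expr :: Nil)), Inner, Some('k1 === 'k2))
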
http://git-wip-us.apache.org/repos/asf/spark/blob/687c8c37/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
new file mode 100644
index 0000000..f98e4ac
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls}
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.test.TestSQLContext
+
+/** This is the test suite for the FilterNullsInJoinKey optimization rule. */
+class FilterNullsInJoinKeySuite extends PlanTest {
+
+  // We add predicate pushdown rules here to make sure we do not
+  // create redundant Filter operators. Also, because the attribute ordering of
+  // the Project operator added by ColumnPruning may not be deterministic
+  // (the ordering may depend on the testing environment),
+  // we first construct the plan with the expected Filter operators and then
+  // run the optimizer to add the Project for column pruning.
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subqueries", Once,
+        EliminateSubQueries) ::
+      Batch("Operator Optimizations", FixedPoint(100),
+        FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite.
+        CombineFilters,
+        PushPredicateThroughProject,
+        BooleanSimplification,
+        PushPredicateThroughJoin,
+        PushPredicateThroughGenerate,
+        ColumnPruning,
+        ProjectCollapsing) :: Nil
+  }
+
+  val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int)
+
+  val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int)
+
+  test("inner join") {
+    val joinCondition =
+      ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+    val joinedPlan =
+      leftRelation
+        .join(rightRelation, Inner, Some(joinCondition))
+        .select('a, 'f, 'd, 'h)
+
+    val optimized = Optimize.execute(joinedPlan.analyze)
+
+    // For an inner join, FilterNullsInJoinKey adds filters to both sides.
+    val correctLeft =
+      leftRelation
+        .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
+
+    val correctRight =
+      rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+    val correctAnswer =
+      correctLeft
+        .join(correctRight, Inner, Some(joinCondition))
+        .select('a, 'f, 'd, 'h)
+
+    comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+  }
+
+  test("make sure we do not keep adding filters") {
+    val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int)
+    val joinedPlan =
+      leftRelation
+        .join(rightRelation, Inner, Some('a === 'e))
+        .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j))
+
+    val optimized = Optimize.execute(joinedPlan.analyze)
+    val conditions = optimized.collect {
+      case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs
+    }
+
+    // Make sure that we have three Not(AtLeastNNulls(1, exprs)), one for each of the three tables.
+    assert(conditions.length === 3)
+
+    // Make sure the attributes are indeed a, b, e, i, and j.
+    assert(
+      conditions.flatMap(exprs => exprs).toSet ===
+        joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet)
+  }
+
+  test("inner join (partially optimized)") {
+    val joinCondition =
+      ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+    val joinedPlan =
+      leftRelation
+        .join(rightRelation, Inner, Some(joinCondition))
+        .select('a, 'f, 'd, 'h)
+
+    val optimized = Optimize.execute(joinedPlan.analyze)
+
+    // We cannot extract attribute from the left join key.
+    val correctRight =
+      rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+    val correctAnswer =
+      leftRelation
+        .join(correctRight, Inner, Some(joinCondition))
+        .select('a, 'f, 'd, 'h)
+
+    comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+  }
+
+  test("inner join (not optimized)") {
+    val nonOptimizedJoinConditions =
+      Some('c - 100 + 'd === 'g + 1 - 'h) ::
+        Some('d > 'h || 'c === 'g) ::
+        Some('d + 'g + 'c > 'd - 'h) :: Nil
+
+    nonOptimizedJoinConditions.foreach { joinCondition =>
+      val joinedPlan =
+        leftRelation
+          .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition)
+          .select('a, 'c, 'f, 'd, 'h, 'g)
+
+      val optimized = Optimize.execute(joinedPlan.analyze)
+
+      comparePlans(optimized, Optimize.execute(joinedPlan.analyze))
+    }
+  }
+
+  test("left outer join") {
+    val joinCondition =
+      ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+    val joinedPlan =
+      leftRelation
+        .join(rightRelation, LeftOuter, Some(joinCondition))
+        .select('a, 'f, 'd, 'h)
+
+    val optimized = Optimize.execute(joinedPlan.analyze)
+
+    // For a left outer join, FilterNullsInJoinKey adds a filter to the right side.
+    val correctRight =
+      rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+    val correctAnswer =
+      leftRelation
+        .join(correctRight, LeftOuter, Some(joinCondition))
+        .select('a, 'f, 'd, 'h)
+
+    comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+  }
+
+  test("right outer join") {
+    val joinCondition =
+      ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+    val joinedPlan =
+      leftRelation
+        .join(rightRelation, RightOuter, Some(joinCondition))
+        .select('a, 'f, 'd, 'h)
+
+    val optimized = Optimize.execute(joinedPlan.analyze)
+
+    // For a right outer join, FilterNullsInJoinKey adds a filter to the left side.
+    val correctLeft =
+      leftRelation
+        .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
+
+    val correctAnswer =
+      correctLeft
+        .join(rightRelation, RightOuter, Some(joinCondition))
+        .select('a, 'f, 'd, 'h)
+
+
+    comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+  }
+
+  test("full outer join") {
+    val joinCondition =
+      ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+    val joinedPlan =
+      leftRelation
+        .join(rightRelation, FullOuter, Some(joinCondition))
+        .select('a, 'f, 'd, 'h)
+
+    // FilterNullsInJoinKey does not fire for a full outer join.
+    val optimized = Optimize.execute(joinedPlan.analyze)
+
+    comparePlans(optimized, Optimize.execute(joinedPlan.analyze))
+  }
+
+  test("left semi join") {
+    val joinCondition =
+      ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g)
+
+    val joinedPlan =
+      leftRelation
+        .join(rightRelation, LeftSemi, Some(joinCondition))
+        .select('a, 'd)
+
+    val optimized = Optimize.execute(joinedPlan.analyze)
+
+    // For a left semi join, FilterNullsInJoinKey adds filters to both sides.
+    val correctLeft =
+      leftRelation
+        .where(!(AtLeastNNulls(1, 'a.expr :: Nil)))
+
+    val correctRight =
+      rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil)))
+
+    val correctAnswer =
+      correctLeft
+        .join(correctRight, LeftSemi, Some(joinCondition))
+        .select('a, 'd)
+
+    comparePlans(optimized, Optimize.execute(correctAnswer.analyze))
+  }
+}

