Github user brkyvz commented on a diff in the pull request: https://github.com/apache/spark/pull/19271#discussion_r139815599 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExecHelper.scala --- @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + + +/** + * Helper object for [[StreamingSymmetricHashJoinExec]]. 
+ */ +object StreamingSymmetricHashJoinExecHelper extends PredicateHelper with Logging { + + sealed trait JoinSide + case object LeftSide extends JoinSide { override def toString(): String = "left" } + case object RightSide extends JoinSide { override def toString(): String = "right" } + + sealed trait JoinStateWatermarkPredicate + case class JoinStateKeyWatermarkPredicate(expr: Expression) extends JoinStateWatermarkPredicate + case class JoinStateValueWatermarkPredicate(expr: Expression) extends JoinStateWatermarkPredicate + + case class JoinStateWatermarkPredicates( + left: Option[JoinStateWatermarkPredicate] = None, + right: Option[JoinStateWatermarkPredicate] = None) + + def getStateWatermarkPredicates( + leftAttributes: Seq[Attribute], + rightAttributes: Seq[Attribute], + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + condition: Option[Expression], + eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates = { + val joinKeyOrdinalForWatermark: Option[Int] = { + leftKeys.zipWithIndex.collectFirst { + case (ne: NamedExpression, index) if ne.metadata.contains(delayKey) => index + } orElse { + rightKeys.zipWithIndex.collectFirst { + case (ne: NamedExpression, index) if ne.metadata.contains(delayKey) => index + } + } + } + + def getOneSideStateWatermarkPredicate( + oneSideInputAttributes: Seq[Attribute], + oneSideJoinKeys: Seq[Expression], + otherSideInputAttributes: Seq[Attribute]): Option[JoinStateWatermarkPredicate] = { + val isWatermarkDefinedOnInput = oneSideInputAttributes.exists(_.metadata.contains(delayKey)) + val isWatermarkDefinedOnJoinKey = joinKeyOrdinalForWatermark.isDefined + + if (isWatermarkDefinedOnJoinKey) { // case 1 and 3 explained in the class docs + val keyExprWithWatermark = BoundReference( + joinKeyOrdinalForWatermark.get, + oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType, + oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable) + val expr = watermarkExpression(Some(keyExprWithWatermark), eventTimeWatermark) + 
expr.map(JoinStateKeyWatermarkPredicate) + + } else if (isWatermarkDefinedOnInput) { // case 2 explained in the class docs + val stateValueWatermark = getStateValueWatermark( + attributesToFindStateWatemarkFor = oneSideInputAttributes, + attributesWithEventWatermark = otherSideInputAttributes, + condition, + eventTimeWatermark) + val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey)) + val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark) + expr.map(JoinStateValueWatermarkPredicate) + + } else { + None + + } + } + + val leftStateWatermarkPredicate = + getOneSideStateWatermarkPredicate(leftAttributes, leftKeys, rightAttributes) + val rightStateWatermarkPredicate = + getOneSideStateWatermarkPredicate(rightAttributes, rightKeys, leftAttributes) + JoinStateWatermarkPredicates(leftStateWatermarkPredicate, rightStateWatermarkPredicate) + } + + /** + * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) + * given the join condition and the event time watermark. This is how it works. + * - The condition is split into conjunctive predicates, and we find the predicates of the + * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). + * - We canoncalize the predicate and solve it with the event time watermark value to find the + * value of the state watermark. 
+ * + * @param attributesToFindStateWatemarkFor attributes of the side whose state watermark + * is to be calculated + * @param attributesWithEventWatermark attributes of the other side which has a watermark column + * @param joinCondition join condition + * @param eventWatermark watermark defined on the input event data + * @return state value watermark in milliseconds + */ + def getStateValueWatermark( + attributesToFindStateWatemarkFor: Seq[Attribute], + attributesWithEventWatermark: Seq[Attribute], + joinCondition: Option[Expression], + eventWatermark: Option[Long]): Option[Long] = { + if (joinCondition.isEmpty || eventWatermark.isEmpty) return None + + def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { + try { + getStateWatemarkFromLessThenPredicate( + l, r, attributesToFindStateWatemarkFor, attributesWithEventWatermark, eventWatermark) + } catch { + case NonFatal(e) => + logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) + None + } + } + + val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => + val stateWatermark = predicate match { + case LessThan(l, r) => getStateWatermarkSafely(l, r) + case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) + case GreaterThan(l, r) => getStateWatermarkSafely(r, l) + case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) + case _ => None + } + if (stateWatermark.nonEmpty) { + logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") + } + stateWatermark + } + allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) + } + + /** + * Extract constraint from conditions. For example: if we want to find the constraint for + * leftTime using the watermark on the rightTime. 
Example: + * + * Input: rightTime-with-watermark + c1 < leftTime + c2 + * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 + * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime + * With watermark value: watermark-value + c1 + (-c2) < leftTime + */ + private def getStateWatemarkFromLessThenPredicate( + leftExpr: Expression, + rightExpr: Expression, + attributesToFindStateWatermarkFor: Seq[Attribute], + attributesWithEventWatermark: Seq[Attribute], + eventWatermark: Option[Long]): Option[Long] = { + + def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { + e.collectLeaves().collectFirst { + case a@AttributeReference(_, TimestampType, _, _) + if attributesToFindStateWatermarkFor.contains(a) => a + }.nonEmpty + } + + // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 + val allOnLeftExpr = Subtract(leftExpr, rightExpr) + logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") + + // Canonicalization step 2: extract commutative terms + // rightTime-with-watermark, c1, -leftTime, -c2 + val terms = ExpressionSet(collectTerms(allOnLeftExpr)) + logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) + + // Find the term that has leftTime (i.e. 
the one present in attributesToFindConstraintFor + val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) + + // Verify there is only one correct constraint term and of the correct type + if (constraintTerms.size > 1) { + logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + if (constraintTerms.isEmpty) { + logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + val constraintTerm = constraintTerms.head + if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { + // Incorrect condition. We want the constraint term in canonical form to be `-leftTime` + // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. + // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark + // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about + // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, + // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. + return None + } + + // Replace watermark attribute with watermark value, and generate the resolved expression + // from the other terms. 
That is, + // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) + logDebug(s"Constraint term from join condition:\t$constraintTerm") + val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => + term.transform { + case a@AttributeReference(_, TimestampType, _, metadata) + if attributesWithEventWatermark.contains(a) && a.metadata.contains(delayKey) => + Literal(eventWatermark.get) + } + }.reduceLeft(Add) + + // Calculate the constraint value + logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") + val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] + Some(Double2double(constraintValue).toLong) + } + + /** + * Collect all the terms present in an expression after converting it into the form + * a + b + c + d where each term be either an attribute or a literal casted to long, + * optionally wrapped in a unary minus. + */ + private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { + var invalid = false + + /** Wrap a term with UnaryMinus if its needs to be negated. */ + def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { + if (minus) UnaryMinus(expr) else expr + } + + /** + * Recursively split the expression into its leaf terms contains attributes or literals. + * Returns terms only of the forms: + * Csat(AttributeReference), UnaryMinus(Cast(AttributeReference)), + * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) + * Multiply(Literal), UnaryMinus(Multiply(Literal)) + * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) + * + * Note: + * - If term needs to be negated for making it a commutative term, + * then it will be wrapped in UnaryMinus(...) + * - Each terms will be representing timestamp value or time interval in milliseconds, + * typed as doubles. + */ + def collect(expr: Expression, negate: Boolean): Seq[Expression] = { --- End diff -- fancy stuff!
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org