Github user tdas commented on a diff in the pull request: https://github.com/apache/spark/pull/19271#discussion_r139834356 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExecHelper.scala --- @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + + +/** + * Helper object for [[StreamingSymmetricHashJoinExec]]. 
+ */ +object StreamingSymmetricHashJoinExecHelper extends PredicateHelper with Logging { + + sealed trait JoinSide + case object LeftSide extends JoinSide { override def toString(): String = "left" } + case object RightSide extends JoinSide { override def toString(): String = "right" } + + sealed trait JoinStateWatermarkPredicate + case class JoinStateKeyWatermarkPredicate(expr: Expression) extends JoinStateWatermarkPredicate + case class JoinStateValueWatermarkPredicate(expr: Expression) extends JoinStateWatermarkPredicate + + case class JoinStateWatermarkPredicates( + left: Option[JoinStateWatermarkPredicate] = None, + right: Option[JoinStateWatermarkPredicate] = None) + + def getStateWatermarkPredicates( + leftAttributes: Seq[Attribute], + rightAttributes: Seq[Attribute], + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + condition: Option[Expression], + eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates = { + val joinKeyOrdinalForWatermark: Option[Int] = { + leftKeys.zipWithIndex.collectFirst { + case (ne: NamedExpression, index) if ne.metadata.contains(delayKey) => index + } orElse { + rightKeys.zipWithIndex.collectFirst { + case (ne: NamedExpression, index) if ne.metadata.contains(delayKey) => index + } + } + } + + def getOneSideStateWatermarkPredicate( + oneSideInputAttributes: Seq[Attribute], + oneSideJoinKeys: Seq[Expression], + otherSideInputAttributes: Seq[Attribute]): Option[JoinStateWatermarkPredicate] = { + val isWatermarkDefinedOnInput = oneSideInputAttributes.exists(_.metadata.contains(delayKey)) + val isWatermarkDefinedOnJoinKey = joinKeyOrdinalForWatermark.isDefined + + if (isWatermarkDefinedOnJoinKey) { // case 1 and 3 explained in the class docs + val keyExprWithWatermark = BoundReference( + joinKeyOrdinalForWatermark.get, + oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType, + oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable) + val expr = watermarkExpression(Some(keyExprWithWatermark), eventTimeWatermark) + 
expr.map(JoinStateKeyWatermarkPredicate) + + } else if (isWatermarkDefinedOnInput) { // case 2 explained in the class docs + val stateValueWatermark = getStateValueWatermark( + attributesToFindStateWatemarkFor = oneSideInputAttributes, + attributesWithEventWatermark = otherSideInputAttributes, + condition, + eventTimeWatermark) + val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey)) + val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark) + expr.map(JoinStateValueWatermarkPredicate) + + } else { + None + + } + } + + val leftStateWatermarkPredicate = + getOneSideStateWatermarkPredicate(leftAttributes, leftKeys, rightAttributes) + val rightStateWatermarkPredicate = + getOneSideStateWatermarkPredicate(rightAttributes, rightKeys, leftAttributes) + JoinStateWatermarkPredicates(leftStateWatermarkPredicate, rightStateWatermarkPredicate) + } + + /** + * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) + * given the join condition and the event time watermark. This is how it works. + * - The condition is split into conjunctive predicates, and we find the predicates of the + * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). + * - We canoncalize the predicate and solve it with the event time watermark value to find the + * value of the state watermark. 
+ * + * @param attributesToFindStateWatemarkFor attributes of the side whose state watermark + * is to be calculated + * @param attributesWithEventWatermark attributes of the other side which has a watermark column + * @param joinCondition join condition + * @param eventWatermark watermark defined on the input event data + * @return state value watermark in milliseconds + */ + def getStateValueWatermark( + attributesToFindStateWatemarkFor: Seq[Attribute], + attributesWithEventWatermark: Seq[Attribute], + joinCondition: Option[Expression], + eventWatermark: Option[Long]): Option[Long] = { + if (joinCondition.isEmpty || eventWatermark.isEmpty) return None + + def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { + try { + getStateWatemarkFromLessThenPredicate( + l, r, attributesToFindStateWatemarkFor, attributesWithEventWatermark, eventWatermark) + } catch { + case NonFatal(e) => + logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) + None + } + } + + val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => + val stateWatermark = predicate match { + case LessThan(l, r) => getStateWatermarkSafely(l, r) + case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) --- End diff -- The generated state watermark cleanup expression is inclusive of the state watermark. That is, if the state watermark is W, all state where timestamp <= W will be cleaned up. When the canonicalized join condition solves to a strict bound such as leftTime > W, no state with timestamp <= W can ever match again, so that inclusive cleanup is safe. But when the condition solves to leftTime >= W, state with timestamp exactly W can still match, so I don't want to clean up <= W. Instead, I choose to clean up <= W-1.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org