Github user brkyvz commented on a diff in the pull request: https://github.com/apache/spark/pull/19271#discussion_r140051116 --- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala --- @@ -0,0 +1,585 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.util.UUID + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfter + +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, BoundReference, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} +import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreProviderId, SymmetricHashJoinStateManager} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + + +class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + after { + StateStore.stop() + } + + import testImplicits._ + + test("SymmetricHashJoinStateManager - all operations") { + val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build() + val inputValueSchema = new StructType() + .add(StructField("time", IntegerType, metadata = watermarkMetadata)) + .add(StructField("value", BooleanType)) + val inputValueAttribs = inputValueSchema.toAttributes + val inputValueAttribWithWatermark = inputValueAttribs(0) + val joinKeyExprs = Seq[Expression](Literal(false), 
inputValueAttribWithWatermark, Literal(10.0)) + + val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray) + val joinKeyGen = UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray) + + def toInputValue(i: Int): UnsafeRow = { + inputValueGen.apply(new GenericInternalRow(Array[Any](i, false))) + } + + def toJoinKeyRow(i: Int): UnsafeRow = { + joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0))) + } + + def toKeyInt(joinKeyRow: UnsafeRow): Int = joinKeyRow.getInt(1) + + def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0) + + withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager => + def append(key: Int, value: Int): Unit = { + manager.append(toJoinKeyRow(key), toInputValue(value)) + } + + def get(key: Int): Seq[Int] = manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted + + /** Remove keys (and corresponding values) where `time <= threshold` */ + def removeByKey(threshold: Long): Unit = { + val expr = + LessThanOrEqual( + BoundReference( + 1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable), + Literal(threshold)) + manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + } + + /** Remove values where `time <= threshold` */ + def removeByValue(watermark: Long): Unit = { + val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) + manager.removeByValueCondition( + GeneratePredicate.generate(expr, inputValueAttribs).eval _) + } + + def numRows: Long = { + manager.metrics.numKeys + } + + assert(get(20) === Seq.empty) // initially empty + append(20, 2) + assert(get(20) === Seq(2)) // should first value correctly + assert(numRows === 1) + + append(20, 3) + assert(get(20) === Seq(2, 3)) // should append new values + append(20, 3) + assert(get(20) === Seq(2, 3, 3)) // should append another copy if same value added again + assert(numRows === 3) + + assert(get(30) === Seq.empty) + append(30, 1) + assert(get(30) === Seq(1)) 
+ assert(get(20) === Seq(2, 3, 3)) // add another key-value should not affect existing ones + assert(numRows === 4) + + removeByKey(25) + assert(get(20) === Seq.empty) + assert(get(30) === Seq(1)) // should remove 20, not 30 + assert(numRows === 1) + + removeByKey(30) + assert(get(30) === Seq.empty) // should remove 30 + assert(numRows === 0) + + def appendAndTest(key: Int, values: Int*): Unit = { + values.foreach { value => append(key, value)} + require(get(key) === values) + } + + appendAndTest(40, 100, 200, 300) + appendAndTest(50, 125) + appendAndTest(60, 275) // prepare for testing removeByValue + assert(numRows === 5) + + removeByValue(125) + assert(get(40) === Seq(200, 300)) + assert(get(50) === Seq.empty) + assert(get(60) === Seq(275)) // should remove only some values, not all + assert(numRows === 3) + + append(40, 50) + assert(get(40) === Seq(50, 200, 300)) + assert(numRows === 4) + + removeByValue(200) + assert(get(40) === Seq(300)) + assert(get(60) === Seq(275)) // should remove only some values, not all + assert(numRows === 2) + + removeByValue(300) + assert(get(40) === Seq.empty) + assert(get(60) === Seq.empty) // should remove all values now + assert(numRows === 0) + } + } + + test("stream stream inner join on non-time column") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as "key", ('value * 2) as "leftValue") + val df2 = input2.toDF.select('value as "key", ('value * 3) as "rightValue") + val joined = df1.join(df2, "key") + + testStream(joined)( + AddData(input1, 1), + CheckAnswer(), + AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join + CheckLastBatch((1, 2, 3)), + AddData(input1, 10), // 10 arrived on input2 first, then input1, should join + CheckLastBatch((10, 20, 30)), + AddData(input2, 1), // another 1 in input2 should join with 1 input1 + CheckLastBatch((1, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 1), // multiple 1s should be kept in state 
causing multiple (1, 2, 3) + CheckLastBatch((1, 2, 3), (1, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 100), + AddData(input2, 100), + CheckLastBatch((100, 200, 300)) + ) + } + + + test("stream stream inner join on windows - without watermark") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF + .select('value as "key", 'value.cast("timestamp") as "timestamp", ('value * 2) as "leftValue") + .select('key, window('timestamp, "10 second"), 'leftValue) + + val df2 = input2.toDF + .select('value as "key", 'value.cast("timestamp") as "timestamp", + ('value * 3) as "rightValue") + .select('key, window('timestamp, "10 second"), 'rightValue) + + val joined = df1.join(df2, Seq("key", "window")) + .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(input1, 1), + CheckLastBatch(), + AddData(input2, 1), + CheckLastBatch((1, 10, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 25), + CheckLastBatch(), + StopStream, + StartStream(), + AddData(input2, 25), + CheckLastBatch((25, 30, 50, 75)), + AddData(input1, 1), + CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is not watermark --- End diff -- nit: `there is no watermark`
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org