Github user brkyvz commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19271#discussion_r139850114

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala ---
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BindReferences, Expression, LessThanOrEqual, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.Predicate
+import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec}
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Helper class to manage state required by a single side of [[StreamingSymmetricHashJoinExec]].
+ * The interface of this class is basically that of a multi-map:
+ * - Get: Returns an iterator of multiple values for a given key
+ * - Append: Appends a new value to the given key
+ * - Remove data by predicate: Drops any state using a predicate condition on keys or values
+ *
+ * @param joinSide Defines the join side
+ * @param inputValueAttributes Attributes of the input row which will be stored as value
+ * @param joinKeys Expressions to find the join key that will be used to key the value rows
+ * @param stateInfo Information about how to retrieve the correct version of state
+ * @param storeConf Configuration for the state store
+ * @param hadoopConf Hadoop configuration for reading state data from storage
+ *
+ * Internally, the key -> multiple values mapping is stored in two [[StateStore]]s.
+ * - Store 1 ([[KeyToNumValuesStore]]) maintains the mapping key -> number of values
+ * - Store 2 ([[KeyWithIndexToValueStore]]) maintains the mapping (key, index) -> value
+ * - Put: update the count in KeyToNumValuesStore,
+ *        insert the new (key, count) -> value in KeyWithIndexToValueStore
+ * - Get: read the count n from KeyToNumValuesStore,
+ *        read each of the n values from KeyWithIndexToValueStore
+ * - Remove state by predicate on keys:
+ *   scan all keys in KeyToNumValuesStore to find the keys that match the predicate,
+ *   delete each matching key from KeyToNumValuesStore, and delete its values from
+ *   KeyWithIndexToValueStore
+ * - Remove state by condition on values:
+ *   scan all [(key, index) -> value] pairs in KeyWithIndexToValueStore to find the values
+ *   that match the predicate, delete each matching (key, indexToDelete) from
+ *   KeyWithIndexToValueStore by overwriting it with the value of (key, maxIndex) and
+ *   removing (key, maxIndex), then decrement the corresponding num values in
+ *   KeyToNumValuesStore
+ */
+class SymmetricHashJoinStateManager(
+    val joinSide: JoinSide,
+    val inputValueAttributes: Seq[Attribute],
+    joinKeys: Seq[Expression],
+    stateInfo: Option[StatefulOperatorStateInfo],
+    storeConf: StateStoreConf,
+    hadoopConf: Configuration) extends Logging {
+
+  import SymmetricHashJoinStateManager._
+
+  // Clean up any state store resources if necessary at the end of the task
+  Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ => abortIfNeeded() } }
+
+  /*
+  =====================================================
+                    Public methods
+  =====================================================
+  */
+
+  /** Get all the values of a key */
+  def get(key: UnsafeRow): Iterator[UnsafeRow] = {
+    val numValues = keyToNumValues.get(key)
+    keyWithIndexToValue.getAll(key, numValues)
+  }
+
+  /** Append a new value to the key */
+  def append(key: UnsafeRow, value: UnsafeRow): Unit = {
+    val numExistingValues = keyToNumValues.get(key)
+    keyWithIndexToValue.put(key, numExistingValues, value)
+    keyToNumValues.put(key, numExistingValues + 1)
+  }
+
+  /**
+   * Remove using a predicate on keys. See class docs for more context and implementation details.
+   */
+  def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
+    val allKeyToNumValues = keyToNumValues.iterator
+
+    while (allKeyToNumValues.hasNext) {
+      val keyToNumValue = allKeyToNumValues.next()
+      if (condition(keyToNumValue.key)) {
+        keyToNumValues.remove(keyToNumValue.key)
+        keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue)
+      }
+    }
+  }
+
+  /**
+   * Remove using a predicate on values. See class docs for more context and implementation
+   * details.
+   */
+  def removeByPredicateOnValues(condition: UnsafeRow => Boolean): Unit = {
+    val allKeyToNumValues = keyToNumValues.iterator
+
+    var numValues: Long = 0L
+    var index: Long = 0L
+    var valueRemoved = false
+    var valueForIndex: UnsafeRow = null
+
+    while (allKeyToNumValues.hasNext) {
+      val keyToNumValue = allKeyToNumValues.next()
+      val key = keyToNumValue.key
+
+      numValues = keyToNumValue.numValue
+      index = 0L
+      valueRemoved = false
+      valueForIndex = null
+
+      while (index < numValues) {
+        if (valueForIndex == null) {
+          valueForIndex = keyWithIndexToValue.get(key, index)
+        }
+        if (condition(valueForIndex)) {
+          if (numValues > 1) {
+            val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1)
+            keyWithIndexToValue.put(key, index, valueAtMaxIndex)
+            keyWithIndexToValue.remove(key, numValues - 1)
+            valueForIndex = valueAtMaxIndex
+          } else {
+            keyWithIndexToValue.remove(key, 0)
+            valueForIndex = null
+          }
+          numValues -= 1
+          valueRemoved = true
+        } else {
+          valueForIndex = null
+          index += 1
+        }
+      }
+      if (valueRemoved) {
+        if (numValues >= 1) {
+          keyToNumValues.put(key, numValues)
+        } else {
+          keyToNumValues.remove(key)
+        }
+      }
+    }
+  }
+
+  def iterator(): Iterator[UnsafeRowPair] = {
+    val pair = new UnsafeRowPair()
+    keyWithIndexToValue.iterator.map { x =>
+      pair.withRows(x.key, x.value)
+    }
+  }
+
+  /** Commit all the changes to all the state stores */
+  def commit(): Unit = {
+    keyToNumValues.commit()
+    keyWithIndexToValue.commit()
+  }
+
+  /** Abort any changes to the state stores if needed */
+  def abortIfNeeded(): Unit = {
+    keyWithIndexToValue.abortIfNeeded()
+    keyWithIndexToValue.abortIfNeeded()
--- End diff --

you're aborting twice on the same state store. The first call should be on `keyToNumValues`.
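For reference, a minimal sketch of the fix, assuming the intent is to abort each of the two stores once (hedged: only this hunk is visible here, but `KeyToNumValuesStore` should expose the same `abortIfNeeded()` as `KeyWithIndexToValueStore`):

    /** Abort any changes to the state stores if needed */
    def abortIfNeeded(): Unit = {
      keyToNumValues.abortIfNeeded()    // first abort should target this store, not keyWithIndexToValue again
      keyWithIndexToValue.abortIfNeeded()
    }

That would also mirror `commit()` above, which updates both stores in this order.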