GitHub user brkyvz commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19271#discussion_r139850114
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala ---
    @@ -0,0 +1,405 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.streaming.state
    +
    +import scala.reflect.ClassTag
    +
    +import org.apache.hadoop.conf.Configuration
    +
    +import org.apache.spark.{Partition, SparkContext, TaskContext}
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2}
    +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BindReferences, Expression, LessThanOrEqual, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow}
    +import org.apache.spark.sql.catalyst.expressions.codegen.Predicate
    +import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec}
    +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
    +import org.apache.spark.sql.types.{LongType, StructField, StructType}
    +import org.apache.spark.util.NextIterator
    +
    +/**
    + * Helper class to manage state required by a single side of [[StreamingSymmetricHashJoinExec]].
    + * The interface of this class is basically that of a multi-map:
    + * - Get: Returns an iterator of multiple values for a given key
    + * - Append: Appends a new value to the given key
    + * - Remove data by predicate: Drops any state using a predicate condition on keys or values
    + *
    + * @param joinSide              Defines the join side
    + * @param inputValueAttributes  Attributes of the input row which will be stored as value
    + * @param joinKeys              Expressions to find the join key that will be used to key the value rows
    + * @param stateInfo             Information about how to retrieve the correct version of state
    + * @param storeConf             Configuration for the state store
    + * @param hadoopConf            Hadoop configuration for reading state data from storage
    + *
    + * Internally, the key -> multiple values mapping is stored in two [[StateStore]]s.
    + * - Store 1 ([[KeyToNumValuesStore]]) maintains mapping between key -> number of values
    + * - Store 2 ([[KeyWithIndexToValueStore]]) maintains mapping between (key, index) -> value
    + * - Put:   update count in KeyToNumValuesStore,
    + *          insert new (key, count) -> value in KeyWithIndexToValueStore
    + * - Get:   read count from KeyToNumValuesStore,
    + *          read each of the n values in KeyWithIndexToValueStore
    + * - Remove state by predicate on keys:
    + *          scan all keys in KeyToNumValuesStore to find keys that match the predicate,
    + *          delete the key from KeyToNumValuesStore and delete its values in KeyWithIndexToValueStore
    + * - Remove state by condition on values:
    + *          scan all [(key, index) -> value] in KeyWithIndexToValueStore to find values that match
    + *          the predicate, delete the corresponding (key, indexToDelete) from KeyWithIndexToValueStore
    + *          by overwriting it with the value of (key, maxIndex), remove (key, maxIndex), and
    + *          decrement the corresponding num values in KeyToNumValuesStore
    + */
    +class SymmetricHashJoinStateManager(
    +    val joinSide: JoinSide,
    +    val inputValueAttributes: Seq[Attribute],
    +    joinKeys: Seq[Expression],
    +    stateInfo: Option[StatefulOperatorStateInfo],
    +    storeConf: StateStoreConf,
    +    hadoopConf: Configuration) extends Logging {
    +
    +  import SymmetricHashJoinStateManager._
    +
    +  // Clean up any state store resources if necessary at the end of the task
    +  Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ => abortIfNeeded() } }
    +
    +  /*
    +  =====================================================
    +                  Public methods
    +  =====================================================
    +   */
    +
    +  /** Get all the values of a key */
    +  def get(key: UnsafeRow): Iterator[UnsafeRow] = {
    +    val numValues = keyToNumValues.get(key)
    +    keyWithIndexToValue.getAll(key, numValues)
    +  }
    +
    +  /** Append a new value to the key */
    +  def append(key: UnsafeRow, value: UnsafeRow): Unit = {
    +    val numExistingValues = keyToNumValues.get(key)
    +    keyWithIndexToValue.put(key, numExistingValues, value)
    +    keyToNumValues.put(key, numExistingValues + 1)
    +  }
    +
    +  /**
    +   * Remove using a predicate on keys. See class docs for more context and implementation details.
    +   */
    +  def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
    +    val allKeyToNumValues = keyToNumValues.iterator
    +
    +    while (allKeyToNumValues.hasNext) {
    +      val keyToNumValue = allKeyToNumValues.next
    +      if (condition(keyToNumValue.key)) {
    +        keyToNumValues.remove(keyToNumValue.key)
    +        keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue)
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Remove using a predicate on values. See class docs for more context and implementation details.
    +   */
    +  def removeByPredicateOnValues(condition: UnsafeRow => Boolean): Unit = {
    +    val allKeyToNumValues = keyToNumValues.iterator
    +
    +    var numValues: Long = 0L
    +    var index: Long = 0L
    +    var valueRemoved = false
    +    var valueForIndex: UnsafeRow = null
    +
    +    while (allKeyToNumValues.hasNext) {
    +      val keyToNumValue = allKeyToNumValues.next
    +      val key = keyToNumValue.key
    +
    +      numValues = keyToNumValue.numValue
    +      index = 0L
    +      valueRemoved = false
    +      valueForIndex = null
    +
    +      while (index < numValues) {
    +        if (valueForIndex == null) {
    +          valueForIndex = keyWithIndexToValue.get(key, index)
    +        }
    +        if (condition(valueForIndex)) {
    +          if (numValues > 1) {
    +            val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1)
    +            keyWithIndexToValue.put(key, index, valueAtMaxIndex)
    +            keyWithIndexToValue.remove(key, numValues - 1)
    +            valueForIndex = valueAtMaxIndex
    +          } else {
    +            keyWithIndexToValue.remove(key, 0)
    +            valueForIndex = null
    +          }
    +          numValues -= 1
    +          valueRemoved = true
    +        } else {
    +          valueForIndex = null
    +          index += 1
    +        }
    +      }
    +      if (valueRemoved) {
    +        if (numValues >= 1) {
    +          keyToNumValues.put(key, numValues)
    +        } else {
    +          keyToNumValues.remove(key)
    +        }
    +      }
    +    }
    +  }
    +
    +  def iterator(): Iterator[UnsafeRowPair] = {
    +    val pair = new UnsafeRowPair()
    +    keyWithIndexToValue.iterator.map { x =>
    +      pair.withRows(x.key, x.value)
    +    }
    +  }
    +
    +  /** Commit all the changes to all the state stores */
    +  def commit(): Unit = {
    +    keyToNumValues.commit()
    +    keyWithIndexToValue.commit()
    +  }
    +
    +  /** Abort any changes to the state stores if needed */
    +  def abortIfNeeded(): Unit = {
    +    keyWithIndexToValue.abortIfNeeded()
    +    keyWithIndexToValue.abortIfNeeded()
    --- End diff --
    
    You're calling `abortIfNeeded()` twice on the same state store; one of these calls should be on `keyToNumValues`.
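
    Presumably the intent is to abort both stores, e.g. something like:

        /** Abort any changes to the state stores if needed */
        def abortIfNeeded(): Unit = {
          keyToNumValues.abortIfNeeded()       // abort changes to the key -> count store
          keyWithIndexToValue.abortIfNeeded()  // abort changes to the (key, index) -> value store
        }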

