This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 12a576f  [SPARK-34892][SS] Introduce MergingSortWithSessionWindowStateIterator sorting input rows and rows in state efficiently
12a576f is described below

commit 12a576f1759aabdcf37c5c841638428e1e33a591
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Wed Jul 14 18:47:44 2021 +0900

    [SPARK-34892][SS] Introduce MergingSortWithSessionWindowStateIterator sorting input rows and rows in state efficiently
    
    Introduction: this PR is a part of SPARK-10816 (EventTime based sessionization (session window)). Please refer to #31937 to see the overall view of the code change. (Note that the code diff may diverge a bit.)
    
    ### What changes were proposed in this pull request?
    
    This PR introduces MergingSortWithSessionWindowStateIterator, which performs a "merge sort" between input rows and sessions in state based on the group key and the session's start time.
    
    Note that the iterator merge-sorts input rows and sessions grouped by the grouping key. The iterator doesn't provide sessions in state whose keys don't exist in input rows. For input rows, the iterator will provide all rows regardless of the existence of matching sessions in state.
    
    MergingSortWithSessionWindowStateIterator works on the precondition that the given iterator is sorted by "group keys + start time of session window", and the output of the iterator retains the same sort order.
    
    ### Why are the changes needed?
    
    This is one of the parts required to implement SPARK-10816.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New UT added.
    
    Closes #33077 from HeartSaVioR/SPARK-34892-SPARK-10816-PR-31570-part-4.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 ...MergingSortWithSessionWindowStateIterator.scala | 168 +++++++++++++++
 .../state/HDFSBackedStateStoreProvider.scala       |   2 +-
 ...ngSortWithSessionWindowStateIteratorSuite.scala | 231 +++++++++++++++++++++
 3 files changed, 400 insertions(+), 1 deletion(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIterator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIterator.scala
new file mode 100644
index 0000000..a923ebd
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIterator.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, 
UnsafeRow}
+import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, 
StreamingSessionWindowStateManager}
+
+/**
+ * This class technically does the merge sort between input rows and existing sessions in state,
+ * to optimize the cost of sort on "input rows + existing sessions". This is based on the
+ * precondition that input rows are sorted by "group keys + start time of session window".
+ *
+ * This only materializes the existing sessions into memory, which tend to be few per group key.
+ * The cost of sorting existing sessions would also be minor based on that assumption.
+ *
+ * Existing sessions are only provided for group keys which appear in the input rows; every input
+ * row is provided regardless of whether matching sessions exist in state.
+ *
+ * The output rows are sorted with "group keys + start time of session window", which is same as
+ * the sort condition on input rows. When an input row and a session in state have the same group
+ * keys and the same session start, the session in state is provided first.
+ *
+ * @param iter input rows, sorted by "group keys + start time of session window"
+ * @param stateManager state manager used to read the existing sessions of a group key
+ * @param store state store to read existing sessions from
+ * @param groupWithoutSessionExpressions group key attributes, excluding the session window
+ * @param sessionExpression the session window attribute (a struct of start and end)
+ * @param inputSchema schema used to project both input rows and rows in state
+ */
+class MergingSortWithSessionWindowStateIterator(
+    iter: Iterator[InternalRow],
+    stateManager: StreamingSessionWindowStateManager,
+    store: ReadStateStore,
+    groupWithoutSessionExpressions: Seq[Attribute],
+    sessionExpression: Attribute,
+    inputSchema: Seq[Attribute]) extends Iterator[InternalRow] with Logging {
+
+  private val keysProjection: UnsafeProjection = GenerateUnsafeProjection.generate(
+    groupWithoutSessionExpressions, inputSchema)
+  private val sessionProjection: UnsafeProjection =
+    GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema)
+
+  /** A row along with its extracted group keys and session window boundaries. */
+  private case class SessionRowInformation(
+      keys: UnsafeRow,
+      sessionStart: Long,
+      sessionEnd: Long,
+      row: InternalRow)
+
+  private object SessionRowInformation {
+    def of(row: InternalRow): SessionRowInformation = {
+      // The projected keys are retained beyond the next projection call, hence copied.
+      val keys = keysProjection(row).copy()
+      val session = sessionProjection(row).copy()
+      val sessionRow = session.getStruct(0, 2)
+      val sessionStart = sessionRow.getLong(0)
+      val sessionEnd = sessionRow.getLong(1)
+
+      SessionRowInformation(keys, sessionStart, sessionEnd, row)
+    }
+  }
+
+  // Holds the latest fetched row from input side iterator.
+  private var currentRowFromInput: SessionRowInformation = _
+
+  // Holds the latest fetched row from state side iterator.
+  private var currentRowFromState: SessionRowInformation = _
+
+  // Holds the sessions in state for the current session key, already sorted by session start.
+  private var sessionIterFromState: Iterator[SessionRowInformation] = _
+
+  // Holds the group key whose sessions were last loaded from state.
+  private var currentSessionKey: UnsafeRow = _
+
+  /** There are more rows if either side holds a fetched row or has more rows to fetch. */
+  override def hasNext: Boolean = {
+    currentRowFromInput != null || currentRowFromState != null ||
+      (sessionIterFromState != null && sessionIterFromState.hasNext) || iter.hasNext
+  }
+
+  /**
+   * Returns the smaller of the current input row and the current session in state, ordered by
+   * group keys and then session start.
+   *
+   * @throws IllegalStateException if neither side has a row to provide
+   */
+  override def next(): InternalRow = {
+    if (currentRowFromInput == null) {
+      mayFillCurrentRow()
+    }
+
+    if (currentRowFromState == null) {
+      mayFillCurrentStateRow()
+    }
+
+    if (currentRowFromInput == null && currentRowFromState == null) {
+      throw new IllegalStateException("No Row to provide in next() which should not happen!")
+    }
+
+    // return current row vs current state row, should return smaller key, earlier session start
+    val returnCurrentRow: Boolean = {
+      if (currentRowFromInput == null) {
+        false
+      } else if (currentRowFromState == null) {
+        true
+      } else if (currentRowFromInput.keys != currentRowFromState.keys) {
+        // Sessions are only loaded for the key of an input row, and input is sorted by keys, so
+        // a state row with a different key belongs to a preceding key and must come first.
+        false
+      } else {
+        // On ties of session start, the session in state comes first.
+        currentRowFromInput.sessionStart < currentRowFromState.sessionStart
+      }
+    }
+
+    if (returnCurrentRow) {
+      val ret = currentRowFromInput
+      currentRowFromInput = null
+      ret.row
+    } else {
+      val ret = currentRowFromState
+      currentRowFromState = null
+      ret.row
+    }
+  }
+
+  /** Fetches the next input row into `currentRowFromInput`, if the input has any left. */
+  private def mayFillCurrentRow(): Unit = {
+    if (iter.hasNext) {
+      currentRowFromInput = SessionRowInformation.of(iter.next())
+    }
+  }
+
+  /**
+   * Fetches the next session in state into `currentRowFromState`. Once the sessions of the
+   * current key are exhausted and the current input row carries a different group key, the
+   * sessions in state for that key are loaded and sorted by session start first.
+   */
+  private def mayFillCurrentStateRow(): Unit = {
+    if (sessionIterFromState != null && sessionIterFromState.hasNext) {
+      currentRowFromState = sessionIterFromState.next()
+    } else {
+      sessionIterFromState = null
+
+      if (currentRowFromInput != null && currentRowFromInput.keys != currentSessionKey) {
+        // We expect a small number of sessions per group key, so materializing them
+        // and sorting wouldn't hurt much. The important thing is that we shouldn't buffer input
+        // rows to sort with existing sessions.
+        // Keys and session boundaries are extracted once per session up front, instead of
+        // re-projecting both rows on every comparison of the sort. `sortBy` is stable, so
+        // sessions with the same start keep the order returned by the state manager.
+        val sortedSessions = stateManager.getSessions(store, currentRowFromInput.keys)
+          .map(session => SessionRowInformation.of(session.copy()))
+          .toList
+          .sortBy(_.sessionStart)
+        sessionIterFromState = sortedSessions.iterator
+
+        currentSessionKey = currentRowFromInput.keys
+        if (sessionIterFromState.hasNext) {
+          currentRowFromState = sessionIterFromState.next()
+        }
+      }
+    }
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index c604021..75b7dae 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -70,7 +70,7 @@ import org.apache.spark.util.{SizeEstimator, Utils}
  * to ensure re-executed RDD operations re-apply updates on the correct past 
version of the
  * store.
  */
-private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider 
with Logging {
+private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider 
with Logging {
 
   class HDFSBackedReadStateStore(val version: Long, map: 
HDFSBackedStateStoreMap)
     extends ReadStateStore {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
new file mode 100644
index 0000000..81f1a3f
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.UUID
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection, UnsafeRow}
+import 
org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, 
RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, 
StateStoreProviderId, StreamingSessionWindowStateManager}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, 
StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Tests for [[MergingSortWithSessionWindowStateIterator]], run against every combination of
+ * state store provider and session window state format version.
+ */
+class MergingSortWithSessionWindowStateIteratorSuite extends StreamTest with BeforeAndAfter {
+
+  private val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType)
+    .add("session", new StructType().add("start", LongType).add("end", LongType))
+    .add("value", LongType)
+  private val rowAttributes = rowSchema.toAttributes
+
+  private val keysWithoutSessionAttributes = rowAttributes.filter { attr =>
+    List("key1", "key2").contains(attr.name)
+  }
+
+  private val sessionAttribute = rowAttributes.filter(_.name == "session").head
+
+  private val inputValueGen = UnsafeProjection.create(rowAttributes.map(_.dataType).toArray)
+  private val inputKeyGen = UnsafeProjection.create(
+    keysWithoutSessionAttributes.map(_.dataType).toArray)
+
+  before {
+    SparkSession.setActiveSession(spark)
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+  }
+
+  private val providerOptions = Seq(
+    classOf[HDFSBackedStateStoreProvider].getCanonicalName,
+    classOf[RocksDBStateStoreProvider].getCanonicalName).map { value =>
+    (SQLConf.STATE_STORE_PROVIDER_CLASS.key, value.stripSuffix("$"))
+  }
+
+  private val availableOptions = for (
+    opt1 <- providerOptions;
+    opt2 <- StreamingSessionWindowStateManager.supportedVersions
+  ) yield (opt1, opt2)
+
+  // `withSQLConf` must wrap the test body rather than the `test(...)` call: `test` only registers
+  // the body, which runs later, after a registration-time `withSQLConf` would already have
+  // restored the previous conf.
+  availableOptions.foreach { case (providerOpt, version) =>
+    test(s"MergingSortWithSessionWindowStateIterator " +
+      s"provider ${providerOpt._2} state version v${version} - rows only in state") {
+      withSQLConf(providerOpt) {
+        testRowsOnlyInState(version)
+      }
+    }
+
+    test(s"MergingSortWithSessionWindowStateIterator " +
+      s"provider ${providerOpt._2} state version v${version} - rows in both input and state") {
+      withSQLConf(providerOpt) {
+        testRowsInBothInputAndState(version)
+      }
+    }
+
+    test(s"MergingSortWithSessionWindowStateIterator " +
+      s"provider ${providerOpt._2} state version v${version} - rows only in input") {
+      withSQLConf(providerOpt) {
+        testRowsOnlyInInput(version)
+      }
+    }
+  }
+
+  /** Sessions exist in state but the input is empty: nothing should be provided. */
+  private def testRowsOnlyInState(stateFormatVersion: Int): Unit = {
+    withStateManager(stateFormatVersion) { case (stateManager, store) =>
+      val key = createKeyRow("a", 1)
+      val values = Seq(
+        createRow("a", 1, 100, 110, 1),
+        createRow("a", 1, 120, 130, 2),
+        createRow("a", 1, 140, 150, 3))
+
+      stateManager.updateSessions(store, key, values)
+
+      val iter = new MergingSortWithSessionWindowStateIterator(
+        Iterator.empty,
+        stateManager,
+        store,
+        keysWithoutSessionAttributes,
+        sessionAttribute,
+        rowAttributes)
+
+      val actual = iter.map(_.copy()).toList
+      assert(actual.isEmpty)
+    }
+  }
+
+  /**
+   * Input rows and sessions in state share some keys: output is merge-sorted per key, and
+   * sessions of keys absent from input are not provided.
+   */
+  private def testRowsInBothInputAndState(stateFormatVersion: Int): Unit = {
+    withStateManager(stateFormatVersion) { case (stateManager, store) =>
+      val key1 = createKeyRow("a", 1)
+      val key1Values = Seq(
+        createRow("a", 1, 100, 110, 1),
+        createRow("a", 1, 120, 130, 2),
+        createRow("a", 1, 140, 150, 3))
+
+      // This is to ensure sessions will not be populated if the input doesn't have such group key
+      val key2 = createKeyRow("a", 2)
+      val key2Values = Seq(
+        createRow("a", 2, 100, 110, 1),
+        createRow("a", 2, 120, 130, 2),
+        createRow("a", 2, 140, 150, 3))
+
+      val key3 = createKeyRow("b", 1)
+      val key3Values = Seq(
+        createRow("b", 1, 100, 110, 1),
+        createRow("b", 1, 120, 130, 2),
+        createRow("b", 1, 140, 150, 3))
+
+      stateManager.updateSessions(store, key1, key1Values)
+      stateManager.updateSessions(store, key2, key2Values)
+      stateManager.updateSessions(store, key3, key3Values)
+
+      val inputsForKey1 = Seq(
+        createRow("a", 1, 90, 100, 1),
+        createRow("a", 1, 125, 135, 2))
+      val inputsForKey3 = Seq(
+        createRow("b", 1, 150, 160, 3)
+      )
+      val inputs = inputsForKey1 ++ inputsForKey3
+
+      val iter = new MergingSortWithSessionWindowStateIterator(
+        inputs.iterator,
+        stateManager,
+        store,
+        keysWithoutSessionAttributes,
+        sessionAttribute,
+        rowAttributes)
+
+      val actual = iter.map(_.copy()).toList
+      val expected = (key1Values ++ inputsForKey1).sortBy(getSessionStart) ++
+        (key3Values ++ inputsForKey3).sortBy(getSessionStart)
+      assert(actual === expected.toList)
+    }
+  }
+
+  /** Input rows have no matching sessions in state: exactly the input rows are provided. */
+  private def testRowsOnlyInInput(stateFormatVersion: Int): Unit = {
+    withStateManager(stateFormatVersion) { case (stateManager, store) =>
+      // This is to ensure sessions will not be populated if the input doesn't have such group key
+      val key1 = createKeyRow("a", 1)
+      val key1Values = Seq(
+        createRow("a", 1, 100, 110, 1),
+        createRow("a", 1, 120, 130, 2),
+        createRow("a", 1, 140, 150, 3))
+
+      stateManager.updateSessions(store, key1, key1Values)
+
+      val inputs = Seq(
+        createRow("b", 1, 100, 110, 1),
+        createRow("b", 1, 120, 130, 2),
+        createRow("b", 1, 140, 150, 3))
+
+      val iter = new MergingSortWithSessionWindowStateIterator(
+        inputs.iterator,
+        stateManager,
+        store,
+        keysWithoutSessionAttributes,
+        sessionAttribute,
+        rowAttributes)
+
+      val actual = iter.map(_.copy()).toList
+      assert(actual === inputs.toList)
+    }
+  }
+
+  /** Builds a row of `rowSchema`, copied out of the reused projection buffer. */
+  private def createRow(
+      key1: String,
+      key2: Int,
+      sessionStart: Long,
+      sessionEnd: Long,
+      value: Long): UnsafeRow = {
+    val sessionRow = new GenericInternalRow(Array[Any](sessionStart, sessionEnd))
+    val row = new GenericInternalRow(
+      Array[Any](UTF8String.fromString(key1), key2, sessionRow, value))
+    inputValueGen.apply(row).copy()
+  }
+
+  /** Builds a group key row (key1, key2), copied out of the reused projection buffer. */
+  private def createKeyRow(key1: String, key2: Int): UnsafeRow = {
+    val row = new GenericInternalRow(Array[Any](UTF8String.fromString(key1), key2))
+    inputKeyGen.apply(row).copy()
+  }
+
+  /** Reads the session start out of a row of `rowSchema` (session struct at ordinal 2). */
+  private def getSessionStart(row: UnsafeRow): Long = {
+    row.getStruct(2, 2).getLong(0)
+  }
+
+  /**
+   * Creates a session window state manager of the given format version along with a state store
+   * in a temporary directory, runs `f` with them, and aborts the store afterwards if needed.
+   */
+  private def withStateManager(
+      stateFormatVersion: Int)(
+      f: (StreamingSessionWindowStateManager, StateStore) => Unit): Unit = {
+    withTempDir { file =>
+      // Build the store conf from the session conf, so that the state store provider class set
+      // via `withSQLConf` is honored instead of the default provider.
+      val storeConf = new StateStoreConf(spark.sessionState.conf)
+      val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
+
+      val manager = StreamingSessionWindowStateManager.createStateManager(
+        keysWithoutSessionAttributes,
+        sessionAttribute,
+        rowAttributes,
+        stateFormatVersion)
+
+      val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME)
+      val store = StateStore.get(
+        storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema,
+        manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration)
+
+      try {
+        f(manager, store)
+      } finally {
+        manager.abortIfNeeded(store)
+      }
+    }
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to