This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 7432057  [SPARK-35659][SS] Avoid write null to StateStore
7432057 is described below

commit 74320575e6a5761815425fa0862e3c1848a504ca
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Tue Jun 8 09:10:19 2021 -0700

    [SPARK-35659][SS] Avoid write null to StateStore
    
    ### What changes were proposed in this pull request?
    
    This patch removes the code paths that put null into StateStore.
    
    ### Why are the changes needed?
    
    According to the `get` method doc in the `StateStore` API, it returns a
    non-null row if the key exists, so we should avoid writing null to
    `StateStore`. A caller cannot distinguish whether a returned null row means
    the key doesn't exist or the stored value is actually null, and given the
    documented behavior of `get`, it is easy to cause an NPE when a caller
    believes the key exists and doesn't expect to get null back.
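    
    To make the ambiguity concrete, here is a minimal standalone Scala sketch
    (hypothetical names, not the Spark API: `NullAmbiguity`, `put`, and `get`
    merely stand in for the `StateStore` contract) showing that a null return
    from a get-style lookup cannot distinguish an absent key from a stored
    null value:
    
        object NullAmbiguity {
          private val backing = scala.collection.mutable.Map.empty[String, String]
    
          def put(key: String, value: String): Unit = backing.put(key, value)
    
          // Like StateStore.get, returns null when the key is absent.
          def get(key: String): String = backing.getOrElse(key, null)
    
          def main(args: Array[String]): Unit = {
            put("a", null)        // the pattern this patch removes
            val row = get("a")    // null: absent key, or stored null value?
            println(row == null)  // prints true either way; the caller cannot tell
            // A caller that "knows" the key exists would dereference row and hit an NPE.
          }
        }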
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added test.
    
    Closes #32796 from viirya/fix-ss-joinstatemanager.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
    (cherry picked from commit 1226b9badd2bc6681e4c533e0dfbc09443a86167)
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 .../streaming/state/HDFSBackedStateStoreProvider.scala      |  1 +
 .../spark/sql/execution/streaming/state/StateStore.scala    |  4 ++--
 .../streaming/state/SymmetricHashJoinStateManager.scala     |  8 ++------
 .../sql/execution/streaming/state/StateStoreSuite.scala     | 13 +++++++++++++
 4 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 5c55034..eb0e7ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -122,6 +122,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     }
 
     override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+      require(value != null, "Cannot put a null value")
       verify(state == UPDATING, "Cannot put after already committed or aborted")
       val keyCopy = key.copy()
       val valueCopy = value.copy()
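
The added check uses Scala's built-in require, which throws
IllegalArgumentException with the message prefixed by "requirement failed:".
A minimal standalone sketch of the contract it enforces (RequireDemo is a
hypothetical stand-in, not the provider itself):

    object RequireDemo {
      def put(key: String, value: String): Unit = {
        // Throws IllegalArgumentException("requirement failed: Cannot put a null value")
        require(value != null, "Cannot put a null value")
        println(s"stored $key -> $value")
      }

      def main(args: Array[String]): Unit = {
        put("k", "v")             // accepted
        try put("k", null)        // rejected by the new require
        catch { case e: IllegalArgumentException => println(e.getMessage) }
      }
    }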
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 7c69e6f..ee4e2ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -99,8 +99,8 @@ trait ReadStateStore {
 trait StateStore extends ReadStateStore {
 
   /**
-   * Put a new value for a non-null key. Implementations must be aware that the UnsafeRows in
-   * the params can be reused, and must make copies of the data as needed for persistence.
+   * Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows
+   * in the params can be reused, and must make copies of the data as needed for persistence.
    */
   def put(key: UnsafeRow, value: UnsafeRow): Unit
 
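
The tightened doc comment above also restates why implementations must copy:
the caller may reuse the same mutable row buffer across calls. A minimal
standalone sketch of that hazard (ReuseHazard and Row are hypothetical, not
UnsafeRow itself):

    object ReuseHazard {
      final class Row(var value: Int) { def copy(): Row = new Row(value) }

      def main(args: Array[String]): Unit = {
        val stored = scala.collection.mutable.ArrayBuffer.empty[Row]
        val reused = new Row(0)   // one buffer, reused for every "row"
        for (v <- 1 to 3) {
          reused.value = v
          stored += reused.copy() // without .copy(), every entry would end up as 3
        }
        println(stored.map(_.value)) // ArrayBuffer(1, 2, 3)
      }
    }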
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index dae771c..915b0ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -269,18 +269,14 @@ class SymmetricHashJoinStateManager(
         // The backing store is arraylike - we as the caller are responsible for filling back in
         // any hole. So we swap the last element into the hole and decrement numValues to shorten.
         // clean
-        if (numValues > 1) {
+        if (index != numValues - 1) {
           val valuePairAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1)
           if (valuePairAtMaxIndex != null) {
             keyWithIndexToValue.put(currentKey, index, valuePairAtMaxIndex.value,
               valuePairAtMaxIndex.matched)
-          } else {
-            keyWithIndexToValue.put(currentKey, index, null, false)
           }
-          keyWithIndexToValue.remove(currentKey, numValues - 1)
-        } else {
-          keyWithIndexToValue.remove(currentKey, 0)
         }
+        keyWithIndexToValue.remove(currentKey, numValues - 1)
         numValues -= 1
         valueRemoved = true
 
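
The restructured branch above is a swap-with-last removal over an array-like
store: if the removed slot is not the last one, move the last value into the
hole, then unconditionally drop the last slot, so no null placeholder is ever
written. A minimal sketch of the same idea over a plain ArrayBuffer (removeAt
is a hypothetical helper, not the state manager API):

    import scala.collection.mutable.ArrayBuffer

    object SwapRemove {
      def removeAt[A](values: ArrayBuffer[A], index: Int): Unit = {
        if (index != values.length - 1) {
          values(index) = values.last    // fill the hole with the last element
        }
        values.remove(values.length - 1) // the last slot is now redundant
      }

      def main(args: Array[String]): Unit = {
        val vs = ArrayBuffer("a", "b", "c", "d")
        removeAt(vs, 1)
        println(vs) // ArrayBuffer(a, d, c)
      }
    }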
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 7e8f955..b82d32e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -1012,6 +1012,19 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     assert(combinedMetrics.customMetrics(customTimingMetric) == 400L)
   }
 
+  test("SPARK-35659: StateStore.put cannot put null value") {
+    val provider = newStoreProvider()
+
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+
+    val store = provider.getStore(0)
+    val err = intercept[IllegalArgumentException] {
+      store.put(stringToRow("key"), null)
+    }
+    assert(err.getMessage.contains("Cannot put a null value"))
+  }
+
   /** Return a new provider with a random id */
   def newStoreProvider(): ProviderClass
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
