This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new e6f0941ad [#2581] fix(spark): Use `SparkContext.getActive` instead of `getOrCreate` to align with method semantics (#2582)
e6f0941ad is described below

commit e6f0941ad9768beb83ec330c964d20a3ce2e55e3
Author: Zhen Wang <[email protected]>
AuthorDate: Thu Aug 14 11:15:12 2025 +0800

    [#2581] fix(spark): Use `SparkContext.getActive` instead of `getOrCreate` to align with method semantics (#2582)
    
    ### What changes were proposed in this pull request?
    
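    Use `SparkContext.getActive` instead of `SparkContext.getOrCreate` in
    `RssSparkShuffleUtils.getActiveSparkContext`. The method is meant to be called on the
    driver and must not create a new SparkContext, but `getOrCreate` would create one if
    none were active. Because `getActive` is declared package-private in Scala, it is
    invoked reflectively through the `SparkContext$` module. If no SparkContext is
    active, an `RssException` is now thrown instead of a context being created.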
    
    ### Why are the changes needed?
    
    Fix #2581
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
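    Added `testGetActiveSparkContext` to `RssSpark2ShuffleUtilsTest` and
    `RssSpark3ShuffleUtilsTest`. It checks that an `RssException` is thrown when no
    SparkContext is active, and that the active context is returned once one exists.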
---
 .../apache/spark/shuffle/RssSparkShuffleUtils.java | 15 ++++++++++----
 .../spark/shuffle/RssSpark2ShuffleUtilsTest.java   | 23 ++++++++++++++++++++++
 .../spark/shuffle/RssSpark3ShuffleUtilsTest.java   | 23 ++++++++++++++++++++++
 3 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index c37a93623..56ec29caa 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle;
 
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
@@ -262,13 +263,19 @@ public class RssSparkShuffleUtils {
    * Get current active {@link SparkContext}. It should be called inside 
Driver since we don't mean
    * to create any new {@link SparkContext} here.
    *
-   * <p>Note: We could use "SparkContext.getActive()" instead of 
"SparkContext.getOrCreate()" if the
-   * "getActive" method is not declared as package private in Scala.
-   *
    * @return Active SparkContext created by Driver.
    */
   public static SparkContext getActiveSparkContext() {
-    return SparkContext.getOrCreate();
+    try {
+      Class<?> clazz = Class.forName("org.apache.spark.SparkContext$");
+      Object module = clazz.getField("MODULE$").get(null);
+      Method getActiveMethod = clazz.getMethod("getActive");
+      Object scOpt = getActiveMethod.invoke(module);
+      return ((Option<SparkContext>) scOpt).get();
+    } catch (Exception e) {
+      LOG.error("Failed to get active SparkContext", e);
+      throw new RssException(e);
+    }
   }
 
   /**
diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/RssSpark2ShuffleUtilsTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/RssSpark2ShuffleUtilsTest.java
index dac10d3db..38da72b4a 100644
--- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/RssSpark2ShuffleUtilsTest.java
+++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/RssSpark2ShuffleUtilsTest.java
@@ -17,13 +17,36 @@
 
 package org.apache.spark.shuffle;
 
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
 import org.junit.jupiter.api.Test;
 
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class RssSpark2ShuffleUtilsTest {
+
+  @Test
+  public void testGetActiveSparkContext() {
+    assertThrows(RssException.class, 
RssSparkShuffleUtils::getActiveSparkContext);
+    SparkConf conf = new SparkConf();
+    conf.setMaster("local[1]");
+    conf.setAppName("test");
+    SparkContext sc = null;
+    try {
+      sc = SparkContext.getOrCreate(conf);
+      assertEquals(sc, RssSparkShuffleUtils.getActiveSparkContext());
+    } finally {
+      if (sc != null) {
+        sc.stop();
+      }
+    }
+  }
+
   @Test
   public void testCreateFetchFailedException() {
     FetchFailedException ffe = 
RssSparkShuffleUtils.createFetchFailedException(0, -1, 10, null);
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssSpark3ShuffleUtilsTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssSpark3ShuffleUtilsTest.java
index adbd57588..e09c941a4 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssSpark3ShuffleUtilsTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssSpark3ShuffleUtilsTest.java
@@ -17,13 +17,36 @@
 
 package org.apache.spark.shuffle;
 
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
 import org.junit.jupiter.api.Test;
 
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class RssSpark3ShuffleUtilsTest {
+
+  @Test
+  public void testGetActiveSparkContext() {
+    assertThrows(RssException.class, 
RssSparkShuffleUtils::getActiveSparkContext);
+    SparkConf conf = new SparkConf();
+    conf.setMaster("local[1]");
+    conf.setAppName("test");
+    SparkContext sc = null;
+    try {
+      sc = SparkContext.getOrCreate(conf);
+      assertEquals(sc, RssSparkShuffleUtils.getActiveSparkContext());
+    } finally {
+      if (sc != null) {
+        sc.stop();
+      }
+    }
+  }
+
   @Test
   public void testCreateFetchFailedException() {
     FetchFailedException ffe = 
RssSparkShuffleUtils.createFetchFailedException(0, -1, 10, null);

Reply via email to