This is an automated email from the ASF dual-hosted git repository. weichenxu123 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new fba4c8c20e52 [SPARK-48970][PYTHON][ML] Avoid using SparkSession.getActiveSession in spark ML reader/writer fba4c8c20e52 is described below commit fba4c8c20e523c9a441f007442efd616320e7be4 Author: Weichen Xu <weichen...@databricks.com> AuthorDate: Tue Jul 23 19:19:28 2024 +0800 [SPARK-48970][PYTHON][ML] Avoid using SparkSession.getActiveSession in spark ML reader/writer ### What changes were proposed in this pull request? `SparkSession.getActiveSession` returns a thread-local session, but the Spark ML reader/writer might be executed in a different thread, which causes `SparkSession.getActiveSession` to return `None`. This PR stops relying on the active session there: the Scala reader obtains the session via `SparkSession.builder().sparkContext(sc).getOrCreate()`, and the Python reader/writer use `SparkSession._getActiveSessionOrCreate()`. ### Why are the changes needed? It fixes bugs like: ``` spark = SparkSession.getActiveSession() > spark.createDataFrame( # type: ignore[union-attr] [(metadataJson,)], schema=["value"] ).coalesce(1).write.text(metadataPath) E AttributeError: 'NoneType' object has no attribute 'createDataFrame' ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47453 from WeichenXu123/SPARK-48970. 
Authored-by: Weichen Xu <weichen...@databricks.com> Signed-off-by: Weichen Xu <weichen...@databricks.com> --- .../src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 2 +- python/pyspark/ml/util.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 021595f76c24..c127575e1470 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -588,7 +588,7 @@ private[ml] object DefaultParamsReader { */ def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { val metadataPath = new Path(path, "metadata").toString - val spark = SparkSession.getActiveSession.get + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadataStr = spark.read.text(metadataPath).first().getString(0) parseMetadata(metadataStr, expectedClassName) } diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 5e7965554d82..89e2f9631564 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -464,10 +464,10 @@ class DefaultParamsWriter(MLWriter): metadataJson = DefaultParamsWriter._get_metadata_to_save( instance, sc, extraMetadata, paramMap ) - spark = SparkSession.getActiveSession() - spark.createDataFrame( # type: ignore[union-attr] - [(metadataJson,)], schema=["value"] - ).coalesce(1).write.text(metadataPath) + spark = SparkSession._getActiveSessionOrCreate() + spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text( + metadataPath + ) @staticmethod def _get_metadata_to_save( @@ -580,8 +580,8 @@ class DefaultParamsReader(MLReader[RL]): If non empty, this is checked against the loaded metadata. 
""" metadataPath = os.path.join(path, "metadata") - spark = SparkSession.getActiveSession() - metadataStr = spark.read.text(metadataPath).first()[0] # type: ignore[union-attr,index] + spark = SparkSession._getActiveSessionOrCreate() + metadataStr = spark.read.text(metadataPath).first()[0] # type: ignore[index] loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName) return loadedVals --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org