This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 0c4ad508476 [SPARK-42908][PYTHON] Raise RuntimeError when SparkContext is required but not initialized
0c4ad508476 is described below

commit 0c4ad508476b54d7d3acd303ff686310dd198a3d
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Tue Mar 28 14:35:38 2023 +0900

    [SPARK-42908][PYTHON] Raise RuntimeError when SparkContext is required but not initialized

    ### What changes were proposed in this pull request?
    Raise RuntimeError when SparkContext is required but not initialized.

    ### Why are the changes needed?
    Error improvement.

    ### Does this PR introduce _any_ user-facing change?
    Error type and message change. Raise a RuntimeError with a clear message (rather than an AssertionError) when SparkContext is required but not initialized yet.

    ### How was this patch tested?
    Unit test.

    Closes #40534 from xinrong-meng/err_msg.

    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 70f6206dbcd3c5ff0f4618cf179b7fcf75ae672c)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/avro/functions.py     | 21 ++++++-----
 python/pyspark/sql/column.py             | 32 +++++++---------
 python/pyspark/sql/dataframe.py          | 10 +++--
 python/pyspark/sql/functions.py          | 65 +++++++++++++++-----------------
 python/pyspark/sql/protobuf/functions.py | 21 ++++++-----
 python/pyspark/sql/tests/test_udf.py     |  7 ++++
 python/pyspark/sql/types.py              | 17 ++++-----
 python/pyspark/sql/udf.py                |  4 +-
 python/pyspark/sql/utils.py              |  9 +++++
 python/pyspark/sql/window.py             | 40 +++++++++++---------
 10 files changed, 120 insertions(+), 106 deletions(-)

diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index cf6676c8ab1..080e45934e6 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -20,9 +20,12 @@
 A collections of builtin avro functions
 """

-from typing import Dict, Optional, TYPE_CHECKING
-from pyspark import SparkContext
+from typing import Dict, Optional, TYPE_CHECKING, cast
+
+from py4j.java_gateway import JVMView
+
 from pyspark.sql.column import Column, _to_java_column
+from pyspark.sql.utils import get_active_spark_context
 from pyspark.util import _print_missing_jar

 if TYPE_CHECKING:
@@ -73,10 +76,9 @@ def from_avro(
     [Row(value=Row(avro=Row(age=2, name='Alice')))]
     """

-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     try:
-        jc = sc._jvm.org.apache.spark.sql.avro.functions.from_avro(
+        jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.from_avro(
             _to_java_column(data), jsonFormatSchema, options or {}
         )
     except TypeError as e:
@@ -119,13 +121,14 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
     [Row(suite=bytearray(b'\\x02\\x00'))]
     """

-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     try:
         if jsonFormatSchema == "":
-            jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(_to_java_column(data))
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.to_avro(
+                _to_java_column(data)
+            )
         else:
-            jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.to_avro(
                 _to_java_column(data), jsonFormatSchema
             )
     except TypeError as e:
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index d13d3954bca..0a18930b8eb 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -31,12 +31,13 @@ from typing import (
     Union,
 )

-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject, JVMView

 from pyspark import copy_func
 from pyspark.context import SparkContext
 from pyspark.errors import PySparkTypeError
 from pyspark.sql.types import DataType
+from pyspark.sql.utils import get_active_spark_context

 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral
@@ -46,15 +47,13 @@ __all__ = ["Column"]


 def _create_column_from_literal(literal: Union["LiteralType", "DecimalLiteral"]) -> "Column":
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
-    return sc._jvm.functions.lit(literal)
+    sc = get_active_spark_context()
+    return cast(JVMView, sc._jvm).functions.lit(literal)


 def _create_column_from_name(name: str) -> "Column":
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
-    return sc._jvm.functions.col(name)
+    sc = get_active_spark_context()
+    return cast(JVMView, sc._jvm).functions.col(name)


 def _to_java_column(col: "ColumnOrName") -> JavaObject:
@@ -122,9 +121,8 @@ def _unary_op(

 def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
     def _(self: "Column") -> "Column":
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jc = getattr(sc._jvm.functions, name)(self._jc)
+        sc = get_active_spark_context()
+        jc = getattr(cast(JVMView, sc._jvm).functions, name)(self._jc)
         return Column(jc)

     _.__doc__ = doc
@@ -137,9 +135,8 @@ def _bin_func_op(
     doc: str = "binary function",
 ) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"]:
     def _(self: "Column", other: Union["Column", "LiteralType", "DecimalLiteral"]) -> "Column":
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        fn = getattr(sc._jvm.functions, name)
+        sc = get_active_spark_context()
+        fn = getattr(cast(JVMView, sc._jvm).functions, name)
         jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
         njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
         return Column(njc)
@@ -633,8 +630,7 @@ class Column:
         +--------------+
         """
-        sc = SparkContext._active_spark_context
-        assert sc is not None
+        sc = get_active_spark_context()
         jc = self._jc.dropFields(_to_seq(sc, fieldNames))
         return Column(jc)
@@ -962,8 +958,7 @@ class Column:
                 Tuple,
                 [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols],
             )
-        sc = SparkContext._active_spark_context
-        assert sc is not None
+        sc = get_active_spark_context()
         jc = getattr(self._jc, "isin")(_to_seq(sc, cols))
         return Column(jc)
@@ -1144,8 +1139,7 @@ class Column:
         metadata = kwargs.pop("metadata", None)
         assert not kwargs, "Unexpected kwargs where passed: %s" % kwargs

-        sc = SparkContext._active_spark_context
-        assert sc is not None
+        sc = get_active_spark_context()
         if len(alias) == 1:
             if metadata:
                 assert sc._jvm is not None
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cc5d264bd34..518bc9867d7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -39,7 +39,7 @@ from typing import (
     TYPE_CHECKING,
 )

-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject, JVMView

 from pyspark import copy_func, _NoValue
 from pyspark._globals import _NoValueType
@@ -61,6 +61,7 @@ from pyspark.sql.types import (
     Row,
     _parse_datatype_json_string,
 )
+from pyspark.sql.utils import get_active_spark_context
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
@@ -4899,9 +4900,10 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 error_class="NOT_DICT",
                 message_parameters={"arg_name": "metadata", "arg_type": type(metadata).__name__},
             )
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(json.dumps(metadata))
+        sc = get_active_spark_context()
+        jmeta = cast(JVMView, sc._jvm).org.apache.spark.sql.types.Metadata.fromJson(
+            json.dumps(metadata)
+        )
         return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession)

     @overload
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bb5a1a559be..ab099554293 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -37,6 +37,8 @@ from typing import (
     ValuesView,
 )

+from py4j.java_gateway import JVMView
+
 from pyspark import SparkContext
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.rdd import PythonEvalType
@@ -49,7 +51,12 @@ from pyspark.sql.udf import UserDefinedFunction, _create_py_udf  # noqa: F401

 # Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType  # noqa: F401
-from pyspark.sql.utils import to_str, has_numpy, try_remote_functions
+from pyspark.sql.utils import (
+    to_str,
+    has_numpy,
+    try_remote_functions,
+    get_active_spark_context,
+)

 if TYPE_CHECKING:
     from pyspark.sql._typing import (
@@ -101,8 +108,7 @@ def _invoke_function_over_seq_of_columns(name: str, cols: "Iterable[ColumnOrName
     Invokes unary JVM function identified by name with
     and wraps the result with :class:`~pyspark.sql.Column`.
""" - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None + sc = get_active_spark_context() return _invoke_function(name, _to_seq(sc, cols, _to_java_column)) @@ -2676,9 +2682,8 @@ def broadcast(df: DataFrame) -> DataFrame: +-----+---+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None - return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sparkSession) + sc = get_active_spark_context() + return DataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession) @try_remote_functions @@ -2891,8 +2896,7 @@ def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column: | 4| +----------------------------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None + sc = get_active_spark_context() return _invoke_function( "count_distinct", _to_java_column(col), _to_seq(sc, cols, _to_java_column) ) @@ -3304,8 +3308,7 @@ def percentile_approx( |-- key: long (nullable = true) |-- median: double (nullable = true) """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None + sc = get_active_spark_context() if isinstance(percentage, (list, tuple)): # A local list @@ -6226,8 +6229,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> Column: >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() [Row(s='abcd-123')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None + sc = get_active_spark_context() return _invoke_function("concat_ws", sep, _to_seq(sc, cols, _to_java_column)) @@ -6360,8 +6362,7 @@ def format_string(format: str, *cols: "ColumnOrName") -> Column: >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect() [Row(v='5 hello')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None + sc = get_active_spark_context() return _invoke_function("format_string", format, _to_seq(sc, cols, _to_java_column)) @@ -7419,8 +7420,7 @@ def array_join( >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() [Row(joined='a,b,c'), Row(joined='a,NULL')] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None + get_active_spark_context() if null_replacement is None: return _invoke_function("array_join", _to_java_column(col), delimiter) else: @@ -8229,8 +8229,7 @@ def json_tuple(col: "ColumnOrName", *fields: str) -> Column: >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() [Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None + sc = get_active_spark_context() return _invoke_function("json_tuple", _to_java_column(col), _to_seq(sc, fields)) @@ -9182,8 +9181,7 @@ def from_csv( [Row(csv=Row(s='abc'))] """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None + get_active_spark_context() if isinstance(schema, str): schema = _create_column_from_literal(schema) elif isinstance(schema, Column): @@ -9209,11 +9207,12 @@ def _unresolved_named_lambda_variable(*name_parts: Any) -> Column: ---------- name_parts : str """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not None + sc = get_active_spark_context() name_parts_seq = _to_seq(sc, name_parts) - expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions - return 
+    expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions
+    return Column(
+        cast(JVMView, sc._jvm).Column(expressions.UnresolvedNamedLambdaVariable(name_parts_seq))
+    )


 def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
@@ -9258,9 +9257,8 @@ def _create_lambda(f: Callable) -> Callable:
     """
     parameters = _get_lambda_parameters(f)

-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
-    expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
+    sc = get_active_spark_context()
+    expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions

     argnames = ["x", "y", "z"]
     args = [
@@ -9300,15 +9298,14 @@ def _invoke_higher_order_function(

     :return: a Column
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
-    expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
+    sc = get_active_spark_context()
+    expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions
     expr = getattr(expressions, name)

     jcols = [_to_java_column(col).expr() for col in cols]
     jfuns = [_create_lambda(f) for f in funs]

-    return Column(sc._jvm.Column(expr(*jcols + jfuns)))
+    return Column(cast(JVMView, sc._jvm).Column(expr(*jcols + jfuns)))


 @overload
@@ -10017,8 +10014,7 @@ def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column:
             message_parameters={"arg_name": "numBuckets", "arg_type": type(numBuckets).__name__},
         )

-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    get_active_spark_context()
     numBuckets = (
         _create_column_from_literal(numBuckets)
         if isinstance(numBuckets, int)
@@ -10070,8 +10066,7 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
     |         cc|
     +-----------+
     """
-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     return _invoke_function("call_udf", udfName, _to_seq(sc, cols, _to_java_column))
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index 1fed9cfda66..a303cf91493 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -20,9 +20,12 @@
 A collections of builtin protobuf functions
 """

-from typing import Dict, Optional, TYPE_CHECKING
-from pyspark import SparkContext
+from typing import Dict, Optional, TYPE_CHECKING, cast
+
+from py4j.java_gateway import JVMView
+
 from pyspark.sql.column import Column, _to_java_column
+from pyspark.sql.utils import get_active_spark_context
 from pyspark.util import _print_missing_jar

 if TYPE_CHECKING:
@@ -117,15 +120,14 @@ def from_protobuf(
     +------------------+
     """

-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     try:
         if descFilePath is not None:
-            jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.from_protobuf(
                 _to_java_column(data), messageName, descFilePath, options or {}
             )
         else:
-            jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.from_protobuf(
                 _to_java_column(data), messageName, options or {}
             )
     except TypeError as e:
@@ -212,15 +214,14 @@ def to_protobuf(
     +----------------------------+
     """

-    sc = SparkContext._active_spark_context
-    assert sc is not None and sc._jvm is not None
+    sc = get_active_spark_context()
     try:
         if descFilePath is not None:
-            jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.to_protobuf(
                 _to_java_column(data), messageName, descFilePath, options or {}
             )
         else:
-            jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf(
+            jc = cast(JVMView, sc._jvm).org.apache.spark.sql.protobuf.functions.to_protobuf(
                 _to_java_column(data), messageName, options or {}
             )
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 1b414baeec3..d8a464b006f 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -858,6 +858,13 @@ class UDFInitializationTests(unittest.TestCase):
             "SparkSession shouldn't be initialized when UserDefinedFunction is created.",
         )

+    def test_err_parse_type_when_no_sc(self):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "SparkContext or SparkSession should be created first",
+        ):
+            udf(lambda x: x, "integer")
+

 if __name__ == "__main__":
     from pyspark.sql.tests.test_udf import *  # noqa: F401
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 9cb17e85540..ff43e4b00e9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -46,10 +46,10 @@ from typing import (
 )

 from py4j.protocol import register_input_converter
-from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject
+from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject, JVMView

 from pyspark.serializers import CloudPickleSerializer
-from pyspark.sql.utils import has_numpy
+from pyspark.sql.utils import has_numpy, get_active_spark_context

 if has_numpy:
     import numpy as np
@@ -1208,21 +1208,18 @@ def _parse_datatype_string(s: str) -> DataType:
     ...
     ParseException:...
""" - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - assert sc is not None + sc = get_active_spark_context() def from_ddl_schema(type_str: str) -> DataType: - assert sc is not None and sc._jvm is not None return _parse_datatype_json_string( - sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json() + cast(JVMView, sc._jvm).org.apache.spark.sql.types.StructType.fromDDL(type_str).json() ) def from_ddl_datatype(type_str: str) -> DataType: - assert sc is not None and sc._jvm is not None return _parse_datatype_json_string( - sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json() + cast(JVMView, sc._jvm) + .org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str) + .json() ) try: diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index e79d04141ae..c1fa3d187fe 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -40,6 +40,7 @@ from pyspark.sql.types import ( StructType, _parse_datatype_string, ) +from pyspark.sql.utils import get_active_spark_context from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version @@ -334,8 +335,7 @@ class UserDefinedFunction: return judf def __call__(self, *cols: "ColumnOrName") -> Column: - sc = SparkContext._active_spark_context - assert sc is not None + sc = get_active_spark_context() profiler: Optional[Profiler] = None memory_profiler: Optional[Profiler] = None if sc.profiler_collector: diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index b9b045541a6..b5d17e38b87 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -193,6 +193,15 @@ def try_remote_windowspec(f: FuncT) -> FuncT: return cast(FuncT, wrapped) +def get_active_spark_context() -> SparkContext: + """Raise RuntimeError if SparkContext is not initialized, + otherwise, returns the active SparkContext.""" + sc = SparkContext._active_spark_context + if sc is None or sc._jvm is None: + raise RuntimeError("SparkContext or SparkSession should be created first.") + return sc + + def try_remote_observation(f: FuncT) -> FuncT: """Mark API supported from Spark Connect.""" diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 92b251ba63f..ca05cb0cc7f 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -17,11 +17,14 @@ import sys from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union -from py4j.java_gateway import JavaObject +from py4j.java_gateway import JavaObject, JVMView -from pyspark import SparkContext from pyspark.sql.column import _to_seq, _to_java_column -from pyspark.sql.utils import try_remote_window, try_remote_windowspec +from pyspark.sql.utils import ( + try_remote_window, + try_remote_windowspec, + get_active_spark_context, +) if TYPE_CHECKING: from pyspark.sql._typing import ColumnOrName, ColumnOrName_ @@ -30,10 +33,9 @@ __all__ = ["Window", "WindowSpec"] def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> JavaObject: - sc = SparkContext._active_spark_context if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] # type: ignore[assignment] - assert sc is not None + sc = get_active_spark_context() return _to_seq(sc, cast(Iterable["ColumnOrName"], cols), _to_java_column) @@ -123,9 +125,10 @@ class Window: | 3| b| 3| +---+--------+----------+ """ - sc = SparkContext._active_spark_context - assert sc is not None and sc._jvm is not 
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols))
+        sc = get_active_spark_context()
+        jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.partitionBy(
+            _to_java_cols(cols)
+        )
         return WindowSpec(jspec)

     @staticmethod
@@ -179,9 +182,10 @@ class Window:
         |  3|       b|         1|
         +---+--------+----------+
         """
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
+        sc = get_active_spark_context()
+        jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.orderBy(
+            _to_java_cols(cols)
+        )
         return WindowSpec(jspec)

     @staticmethod
@@ -263,9 +267,10 @@ class Window:
             start = Window.unboundedPreceding
         if end >= Window._FOLLOWING_THRESHOLD:
             end = Window.unboundedFollowing
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rowsBetween(start, end)
+        sc = get_active_spark_context()
+        jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.rowsBetween(
+            start, end
+        )
         return WindowSpec(jspec)

     @staticmethod
@@ -350,9 +355,10 @@ class Window:
             start = Window.unboundedPreceding
         if end >= Window._FOLLOWING_THRESHOLD:
             end = Window.unboundedFollowing
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._jvm is not None
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end)
+        sc = get_active_spark_context()
+        jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.rangeBetween(
+            start, end
+        )
         return WindowSpec(jspec)
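
A quick note for readers skimming the diff: the user-facing effect is easiest to see in the new
test_err_parse_type_when_no_sc test above. Below is a minimal sketch of the same behavior,
assuming a bare Python process in which no SparkContext or SparkSession has been created:

    # No SparkContext/SparkSession has been started in this process.
    from pyspark.sql.functions import udf

    try:
        # Parsing the "integer" return type string needs the JVM, so an active
        # SparkContext is required here.
        udf(lambda x: x, "integer")
    except RuntimeError as e:
        # Before this patch the same situation surfaced as a bare AssertionError;
        # it now raises RuntimeError with the message printed below.
        print(e)  # SparkContext or SparkSession should be created first.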