This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 764edaf8b2e  [SPARK-41528][CONNECT] Merge namespace of Spark Connect and PySpark API
764edaf8b2e is described below

commit 764edaf8b2e1c42a32e7bfa058cf8ee26ce02a9e
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Wed Dec 21 09:13:07 2022 +0900

    [SPARK-41528][CONNECT] Merge namespace of Spark Connect and PySpark API

    ### What changes were proposed in this pull request?

    This PR proposes to merge the namespaces of Spark Connect and PySpark by adding a CLI option `--remote` and a `spark.remote` configuration, symmetric to `--master` and `spark.master`.

    ### Why are the changes needed?

    To provide the same user experience to end users. See also the design document attached ([here](https://docs.google.com/document/d/10XJFHnzH8a1cQq9iDf9KORAveK6uy6mBvWA8zZP7rjc/edit?usp=sharing)).

    ### Does this PR introduce _any_ user-facing change?

    Yes, users can now use Spark Connect as below:

    ```
    $ ./bin/pyspark --remote ...
    $ ./bin/pyspark --conf spark.remote ...
    ...
    >>> # **Same as regular PySpark from here**
    ... # Do something with `spark` that is a remote client
    ... spark.range(1)
    ```

    ```
    $ ./bin/spark-submit --remote ... app.py
    $ ./bin/spark-submit --conf spark.remote ... app.py
    ...

    # **Same as regular PySpark from here**
    # app.py
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.getOrCreate()
    # Do something with `spark` that is a remote client
    ```

    See the design document attached ([here](https://docs.google.com/document/d/10XJFHnzH8a1cQq9iDf9KORAveK6uy6mBvWA8zZP7rjc/edit?usp=sharing)).

    ### How was this patch tested?

    Reusing the existing PySpark unit tests for DataFrame and functions.

    Closes #39041 from HyukjinKwon/prototype_merged_pyspark.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/deploy/PythonRunner.scala     |   1 +
 .../org/apache/spark/deploy/SparkSubmit.scala      |  31 ++-
 .../apache/spark/deploy/SparkSubmitArguments.scala |  30 ++-
 dev/sparktestsupport/modules.py                    |   2 +
 .../spark/launcher/AbstractCommandBuilder.java     |   1 +
 .../apache/spark/launcher/AbstractLauncher.java    |  15 ++
 .../spark/launcher/SparkSubmitCommandBuilder.java  |  15 ++
 .../spark/launcher/SparkSubmitOptionParser.java    |   2 +
 .../source/reference/pyspark.sql/spark_session.rst |   1 +
 python/pyspark/context.py                          |   4 +
 python/pyspark/shell.py                            |  65 +++--
 python/pyspark/sql/connect/session.py              |   4 +-
 python/pyspark/sql/functions.py                    | 264 ++++++++++++++++++-
 python/pyspark/sql/observation.py                  |   6 +-
 python/pyspark/sql/session.py                      |  83 +++++-
 .../sql/tests/connect/test_parity_dataframe.py     | 237 +++++++++++++++++
 .../sql/tests/connect/test_parity_functions.py     | 292 +++++++++++++++++++++
 python/pyspark/sql/tests/test_dataframe.py         |   6 +-
 python/pyspark/sql/tests/test_functions.py         |  25 +-
 python/pyspark/sql/utils.py                        |  61 ++++-
 python/pyspark/sql/window.py                       |  14 +-
 python/pyspark/testing/connectutils.py             |   4 +
 22 files changed, 1095 insertions(+), 68 deletions(-)
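To make the merged namespace concrete before the patch body: with this change, `SparkSession.builder.remote(...)` (or a `spark.remote`/`SPARK_REMOTE` setting) yields a Spark Connect client session, and ordinary DataFrame code runs against it unchanged. A minimal sketch, assuming a Spark Connect server is reachable at `sc://localhost` (the builder API shown is the one referenced in the `spark_session.rst` and `session.py` hunks below):

```python
# Minimal sketch, assuming a Spark Connect server listening at sc://localhost.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Builder.remote() records the "spark.remote" option (see session.py below);
# plain getOrCreate() instead picks the URL up from --remote / SPARK_REMOTE.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Same as regular PySpark from here: `spark` is a remote client.
spark.range(3).select((F.col("id") * 2).alias("doubled")).show()
```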
sparkConf.getOption("spark.remote").foreach(url => env.put("SPARK_REMOTE", url)) env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 73acfedd8bc..745836dfbef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -229,15 +229,20 @@ private[spark] class SparkSubmit extends Logging { var childMainClass = "" // Set the cluster manager - val clusterManager: Int = args.master match { - case "yarn" => YARN - case m if m.startsWith("spark") => STANDALONE - case m if m.startsWith("mesos") => MESOS - case m if m.startsWith("k8s") => KUBERNETES - case m if m.startsWith("local") => LOCAL - case _ => - error("Master must either be yarn or start with spark, mesos, k8s, or local") - -1 + val clusterManager: Int = args.maybeMaster match { + case Some(v) => + assert(args.maybeRemote.isEmpty) + v match { + case "yarn" => YARN + case m if m.startsWith("spark") => STANDALONE + case m if m.startsWith("mesos") => MESOS + case m if m.startsWith("k8s") => KUBERNETES + case m if m.startsWith("local") => LOCAL + case _ => + error("Master must either be yarn or start with spark, mesos, k8s, or local") + -1 + } + case None => LOCAL // default master or remote mode. } // Set the deploy mode; default is client mode @@ -259,7 +264,7 @@ private[spark] class SparkSubmit extends Logging { } if (clusterManager == KUBERNETES) { - args.master = Utils.checkAndGetK8sMasterUrl(args.master) + args.maybeMaster = Option(Utils.checkAndGetK8sMasterUrl(args.master)) // Make sure KUBERNETES is included in our build if we're trying to use it if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { error( @@ -597,7 +602,11 @@ private[spark] class SparkSubmit extends Logging { val options = List[OptionAssigner]( // All cluster managers - OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.master"), + OptionAssigner( + if (args.maybeRemote.isDefined) args.maybeMaster.orNull else args.master, + ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.master"), + OptionAssigner( + args.maybeRemote.orNull, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.remote"), OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = SUBMIT_DEPLOY_MODE.key), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.app.name"), diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9a5123f218a..92c67d5156f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -41,7 +41,10 @@ import org.apache.spark.util.Utils */ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) extends SparkSubmitArgumentsParser with Logging { - var master: String = null + var maybeMaster: Option[String] = None + // Global defaults. These should be keep to minimum to avoid confusing behavior. 
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 9a5123f218a..92c67d5156f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -41,7 +41,10 @@ import org.apache.spark.util.Utils
  */
 private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env)
   extends SparkSubmitArgumentsParser with Logging {
-  var master: String = null
+  var maybeMaster: Option[String] = None
+  // Global defaults. These should be kept to a minimum to avoid confusing behavior.
+  def master: String = maybeMaster.getOrElse("local[*]")
+  var maybeRemote: Option[String] = None
   var deployMode: String = null
   var executorMemory: String = null
   var executorCores: String = null
@@ -149,10 +152,13 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
    * Load arguments from environment variables, Spark properties etc.
    */
   private def loadEnvironmentArguments(): Unit = {
-    master = Option(master)
+    maybeMaster = maybeMaster
       .orElse(sparkProperties.get("spark.master"))
       .orElse(env.get("MASTER"))
-      .orNull
+    maybeRemote = maybeRemote
+      .orElse(sparkProperties.get("spark.remote"))
+      .orElse(env.get("SPARK_REMOTE"))
+
     driverExtraClassPath = Option(driverExtraClassPath)
       .orElse(sparkProperties.get(config.DRIVER_CLASS_PATH.key))
       .orNull
@@ -210,9 +216,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
     dynamicAllocationEnabled =
       sparkProperties.get(DYN_ALLOCATION_ENABLED.key).exists("true".equalsIgnoreCase)
 
-    // Global defaults. These should be keep to minimum to avoid confusing behavior.
-    master = Option(master).getOrElse("local[*]")
-
     // In YARN mode, app name can be set via SPARK_YARN_APP_NAME (see SPARK-5222)
     if (master.startsWith("yarn")) {
       name = Option(name).orElse(env.get("SPARK_YARN_APP_NAME")).orNull
@@ -242,6 +245,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
     if (args.length == 0) {
       printUsageAndExit(-1)
     }
+    if (maybeRemote.isDefined && (maybeMaster.isDefined || deployMode != null)) {
+      error("Remote cannot be specified with master and/or deploy mode.")
+    }
     if (primaryResource == null) {
       error("Must specify a primary resource (JAR or Python or R file)")
     }
@@ -299,6 +305,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
   override def toString: String = {
     s"""Parsed arguments:
     |  master                  $master
+    |  remote                  ${maybeRemote.orNull}
     |  deployMode              $deployMode
     |  executorMemory          $executorMemory
     |  executorCores           $executorCores
@@ -338,7 +345,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
         name = value
 
       case MASTER =>
-        master = value
+        maybeMaster = Option(value)
+
+      case REMOTE =>
+        maybeRemote = Option(value)
 
       case CLASS =>
         mainClass = value
@@ -539,6 +549,12 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
         |  --verbose, -v               Print additional debug output.
         |  --version,                  Print the version of current Spark.
         |
+        |  Experimental options:
+        |    --remote CONNECT_URL      URL to connect to the server for Spark Connect, e.g.,
+        |                              sc://host:port. --master and --deploy-mode cannot be set
+        |                              together with this option. This option is experimental, and
+        |                              might change between minor releases.
+        |
         | Cluster deploy mode only:
         |  --driver-cores NUM          Number of cores used by the driver, only in cluster mode
         |                              (Default: 1).
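For readers skimming the Scala above, the new resolution and validation rules fit in a few lines of Python. This is an illustrative sketch of what `loadEnvironmentArguments` and the new check do, not code from the patch (the helper names are hypothetical):

```python
import os
from typing import Optional

def resolve_remote(cli_remote: Optional[str], props: dict) -> Optional[str]:
    # Hypothetical helper mirroring SparkSubmitArguments: the --remote CLI flag
    # wins, then the spark.remote property, then the SPARK_REMOTE env variable.
    return cli_remote or props.get("spark.remote") or os.environ.get("SPARK_REMOTE")

def validate(remote: Optional[str], master: Optional[str], deploy_mode: Optional[str]) -> None:
    # A remote URL is mutually exclusive with --master and --deploy-mode,
    # matching the error raised in validateSubmitArguments.
    if remote is not None and (master is not None or deploy_mode is not None):
        raise ValueError("Remote cannot be specified with master and/or deploy mode.")
```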
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 1f1270c93fe..51f5246d741 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -512,6 +512,8 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_connect_basic",
         "pyspark.sql.tests.connect.test_connect_function",
         "pyspark.sql.tests.connect.test_connect_column",
+        "pyspark.sql.tests.connect.test_parity_functions",
+        "pyspark.sql.tests.connect.test_parity_dataframe",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index c434941a585..b75410e11a5 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -47,6 +47,7 @@ abstract class AbstractCommandBuilder {
   String javaHome;
   String mainClass;
   String master;
+  String remote;
   protected String propertiesFile;
   final List<String> appArgs;
   final List<String> jars;
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java
index a944950cf15..c085d2dae5e 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java
@@ -87,6 +87,19 @@ public abstract class AbstractLauncher<T extends AbstractLauncher<T>> {
     return self();
   }
 
+  /**
+   * Set the Spark remote URL for the application.
+   *
+   * @param remote Spark remote URL.
+   * @return This launcher.
+   */
+  public T setRemote(String remote) {
+    checkNotNull(remote, "remote");
+    builder.remote = remote;
+    return self();
+  }
+
+
   /**
    * Set the deploy mode for the application.
    *
@@ -163,6 +176,8 @@ public abstract class AbstractLauncher<T extends AbstractLauncher<T>> {
       SparkSubmitOptionParser validator = new ArgumentValidator(true);
       if (validator.MASTER.equals(name)) {
         setMaster(value);
+      } else if (validator.REMOTE.equals(name)) {
+        setRemote(value);
       } else if (validator.PROPERTIES_FILE.equals(name)) {
         setPropertiesFile(value);
       } else if (validator.CONF.equals(name)) {
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index 25237da47ce..520a147751d 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -192,6 +192,11 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
       args.add(master);
     }
 
+    if (remote != null) {
+      args.add(parser.REMOTE);
+      args.add(remote);
+    }
+
     if (deployMode != null) {
       args.add(parser.DEPLOY_MODE);
       args.add(deployMode);
@@ -344,6 +349,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
       // pass conf spark.pyspark.python to python by environment variable.
env.put("PYSPARK_PYTHON", conf.get(SparkLauncher.PYSPARK_PYTHON)); } + env.put("SPARK_REMOTE", remote); if (!isEmpty(pyOpts)) { pyargs.addAll(parseOptionString(pyOpts)); } @@ -457,9 +463,18 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { protected boolean handle(String opt, String value) { switch (opt) { case MASTER: + checkArgument(remote == null, + "Both master (%s) and remote (%s) cannot be set together.", master, remote); master = value; break; + case REMOTE: + checkArgument(remote == null, + "Both master (%s) and remote (%s) cannot be set together.", master, remote); + remote = value; + break; case DEPLOY_MODE: + checkArgument(remote == null, + "Both deploy-mode (%s) and remote (%s) cannot be set together.", deployMode, remote); deployMode = value; break; case PROPERTIES_FILE: diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index c57af920294..ea54986daab 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -49,6 +49,7 @@ class SparkSubmitOptionParser { protected final String JARS = "--jars"; protected final String KILL_SUBMISSION = "--kill"; protected final String MASTER = "--master"; + protected final String REMOTE = "--remote"; protected final String NAME = "--name"; protected final String PACKAGES = "--packages"; protected final String PACKAGES_EXCLUDE = "--exclude-packages"; @@ -103,6 +104,7 @@ class SparkSubmitOptionParser { { KEYTAB }, { KILL_SUBMISSION }, { MASTER }, + { REMOTE }, { NAME }, { NUM_EXECUTORS }, { PACKAGES }, diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst index d4fb7270a77..15724306d75 100644 --- a/python/docs/source/reference/pyspark.sql/spark_session.rst +++ b/python/docs/source/reference/pyspark.sql/spark_session.rst @@ -36,6 +36,7 @@ See also :class:`SparkSession`. SparkSession.builder.enableHiveSupport SparkSession.builder.getOrCreate SparkSession.builder.master + SparkSession.builder.remote SparkSession.catalog SparkSession.conf SparkSession.createDataFrame diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 9a7a8f46e84..f6e74493105 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -179,6 +179,10 @@ class SparkContext: udf_profiler_cls: Type[UDFBasicProfiler] = UDFBasicProfiler, memory_profiler_cls: Type[MemoryProfiler] = MemoryProfiler, ): + if "SPARK_REMOTE" in os.environ and "SPARK_TESTING" not in os.environ: + raise RuntimeError( + "Remote client cannot create a SparkContext. Create SparkSession instead." + ) if conf is None or conf.get("spark.executor.allowSparkContext", "false").lower() != "true": # In order to prevent SparkContext from being created in executors. diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 9004a94e340..c1c2d4faacd 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -26,32 +26,50 @@ import os import platform import warnings +import pyspark from pyspark.context import SparkContext from pyspark.sql import SparkSession from pyspark.sql.context import SQLContext +from pyspark.sql.utils import is_remote -if os.environ.get("SPARK_EXECUTOR_URI"): - SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) +if is_remote(): + try: + # Creates pyspark.sql.connect.SparkSession. 
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 9004a94e340..c1c2d4faacd 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -26,32 +26,50 @@ import os
 import platform
 import warnings
 
+import pyspark
 from pyspark.context import SparkContext
 from pyspark.sql import SparkSession
 from pyspark.sql.context import SQLContext
+from pyspark.sql.utils import is_remote
 
-if os.environ.get("SPARK_EXECUTOR_URI"):
-    SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])
+if is_remote():
+    try:
+        # Creates pyspark.sql.connect.SparkSession.
+        spark = SparkSession.builder.getOrCreate()
+    except Exception:
+        import sys
+        import traceback
 
-SparkContext._ensure_initialized()
+        warnings.warn("Failed to initialize Spark session.")
+        traceback.print_exc(file=sys.stderr)
+        sys.exit(1)
+    version = pyspark.__version__
+    sc = None
+else:
+    if os.environ.get("SPARK_EXECUTOR_URI"):
+        SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])
 
-try:
-    spark = SparkSession._create_shell_session()
-except Exception:
-    import sys
-    import traceback
+    SparkContext._ensure_initialized()
 
-    warnings.warn("Failed to initialize Spark session.")
-    traceback.print_exc(file=sys.stderr)
-    sys.exit(1)
+    try:
+        spark = SparkSession._create_shell_session()
+    except Exception:
+        import sys
+        import traceback
 
-sc = spark.sparkContext
-sql = spark.sql
-atexit.register((lambda sc: lambda: sc.stop())(sc))
+        warnings.warn("Failed to initialize Spark session.")
+        traceback.print_exc(file=sys.stderr)
+        sys.exit(1)
+
+    sc = spark.sparkContext
+    atexit.register((lambda sc: lambda: sc.stop())(sc))
 
-# for compatibility
-sqlContext = SQLContext._get_or_create(sc)
-sqlCtx = sqlContext
+    # for compatibility
+    sqlContext = SQLContext._get_or_create(sc)
+    sqlCtx = sqlContext
+    version = sc.version
+
+sql = spark.sql
 
 print(
     r"""Welcome to
@@ -61,14 +79,21 @@ print(
       ____              __
      / __/__  ___ _____/ /__
     _\ \/ _ \/ _ `/ __/  '_/
    /__ / .__/\_,_/_/ /_/\_\   version %s
       /_/
 """
-    % sc.version
+    % version
 )
 print(
     "Using Python version %s (%s, %s)"
     % (platform.python_version(), platform.python_build()[0], platform.python_build()[1])
 )
-print("Spark context Web UI available at %s" % (sc.uiWebUrl))
-print("Spark context available as 'sc' (master = %s, app id = %s)." % (sc.master, sc.applicationId))
+if is_remote():
+    print("Client connected to the Spark Connect server at %s" % (os.environ["SPARK_REMOTE"]))
+else:
+    print("Spark context Web UI available at %s" % (sc.uiWebUrl))  # type: ignore[union-attr]
+    print(
+        "Spark context available as 'sc' (master = %s, app id = %s)."
+        % (sc.master, sc.applicationId)  # type: ignore[union-attr]
+    )
+
 print("SparkSession available as 'spark'.")
 
 # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP,
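The new shell branches on `is_remote()`, which is defined in `python/pyspark/sql/utils.py` — a file this excerpt lists in the diffstat but does not show. A plausible minimal sketch of that check, assuming it keys off the `SPARK_REMOTE` environment variable exported by `--remote` (a hypothetical reconstruction, not the patch's actual body):

```python
import os

def is_remote() -> bool:
    # Hypothetical sketch: treat the session as remote when SPARK_REMOTE was
    # exported by bin/pyspark --remote (see PythonRunner.scala and
    # SparkSubmitCommandBuilder.java above).
    return "SPARK_REMOTE" in os.environ
```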
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 0a3d03110fc..672464cce0e 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -180,14 +180,14 @@ class SparkSession(object):
             return self.config("spark.app.name", name)
 
         def remote(self, location: str = "sc://localhost") -> "SparkSession.Builder":
-            return self.config("spark.connect.location", location)
+            return self.config("spark.remote", location)
 
         def enableHiveSupport(self) -> "SparkSession.Builder":
             raise NotImplementedError("enableHiveSupport not implemented for Spark Connect")
 
         def getOrCreate(self) -> "SparkSession":
             """Creates a new instance."""
-            return SparkSession(connectionString=self._options["spark.connect.location"])
+            return SparkSession(connectionString=self._options["spark.remote"])
 
     _client: SparkConnectClient
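From here on, the patch decorates nearly every function in `python/pyspark/sql/functions.py` with `@try_remote_functions`. The decorator itself lives in `python/pyspark/sql/utils.py` (changed in this commit but not shown in this excerpt). A minimal sketch of the dispatch pattern it plausibly implements — call the same-named function from `pyspark.sql.connect.functions` when running as a remote client, otherwise fall through to the JVM-backed body; names and details here are assumptions, not the patch's code:

```python
import functools
import importlib
import os
from typing import Any, Callable, TypeVar, cast

FuncT = TypeVar("FuncT", bound=Callable[..., Any])

def try_remote_functions(f: FuncT) -> FuncT:
    """Sketch: dispatch to pyspark.sql.connect.functions for remote clients."""

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        # Hypothetical remote check, mirroring the is_remote() sketch above.
        if "SPARK_REMOTE" in os.environ:
            connect_mod = importlib.import_module("pyspark.sql.connect.functions")
            return getattr(connect_mod, f.__name__)(*args, **kwargs)
        return f(*args, **kwargs)

    return cast(FuncT, wrapped)
```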
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 958595d9188..f04b0a61438 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -48,10 +48,7 @@ from pyspark.sql.udf import UserDefinedFunction, _create_udf  # noqa: F401
 # Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType  # noqa: F401
 
-from pyspark.sql.utils import to_str, has_numpy
-
-if has_numpy:
-    import numpy as np
+from pyspark.sql.utils import to_str, has_numpy, try_remote_functions
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import (
@@ -61,6 +58,8 @@ if TYPE_CHECKING:
         UserDefinedFunctionLike,
     )
 
+if has_numpy:
+    import numpy as np
 
 # Note to developers: all of PySpark functions here take string as column names whenever possible.
 # Namely, if columns are referred as arguments, they can always be both Column or string,
@@ -126,6 +125,7 @@ def _options_to_str(options: Optional[Dict[str, Any]] = None) -> Dict[str, Optio
     return {}
 
 
+@try_remote_functions
 def lit(col: Any) -> Column:
     """
     Creates a :class:`~pyspark.sql.Column` of literal value.
@@ -179,6 +179,7 @@ def lit(col: Any) -> Column:
     return _invoke_function("lit", col)
 
 
+@try_remote_functions
 def col(col: str) -> Column:
     """
     Returns a :class:`~pyspark.sql.Column` based on the given column name.
@@ -208,6 +209,7 @@ def col(col: str) -> Column:
 column = col
 
 
+@try_remote_functions
 def asc(col: "ColumnOrName") -> Column:
     """
     Returns a sort expression based on the ascending order of the given column name.
@@ -257,6 +259,7 @@ def asc(col: "ColumnOrName") -> Column:
     return col.asc() if isinstance(col, Column) else _invoke_function("asc", col)
 
 
+@try_remote_functions
 def desc(col: "ColumnOrName") -> Column:
     """
     Returns a sort expression based on the descending order of the given column name.
@@ -291,6 +294,7 @@ def desc(col: "ColumnOrName") -> Column:
     return col.desc() if isinstance(col, Column) else _invoke_function("desc", col)
 
 
+@try_remote_functions
 def sqrt(col: "ColumnOrName") -> Column:
     """
     Computes the square root of the specified float value.
@@ -320,6 +324,7 @@ def sqrt(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("sqrt", col)
 
 
+@try_remote_functions
 def abs(col: "ColumnOrName") -> Column:
     """
     Computes the absolute value.
@@ -349,6 +354,7 @@ def abs(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("abs", col)
 
 
+@try_remote_functions
 def mode(col: "ColumnOrName") -> Column:
     """
     Returns the most frequent value in a group.
@@ -383,6 +389,7 @@ def mode(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("mode", col)
 
 
+@try_remote_functions
 def max(col: "ColumnOrName") -> Column:
     """
     Aggregate function: returns the maximum value of the expression in a group.
@@ -412,6 +419,7 @@ def max(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("max", col)
 
 
+@try_remote_functions
 def min(col: "ColumnOrName") -> Column:
     """
     Aggregate function: returns the minimum value of the expression in a group.
@@ -441,6 +449,7 @@ def min(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("min", col)
 
 
+@try_remote_functions
 def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
     """
     Returns the value associated with the maximum value of ord.
@@ -476,6 +485,7 @@ def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("max_by", col, ord)
 
 
+@try_remote_functions
 def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
     """
     Returns the value associated with the minimum value of ord.
@@ -511,6 +521,7 @@ def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("min_by", col, ord)
 
 
+@try_remote_functions
 def count(col: "ColumnOrName") -> Column:
     """
     Aggregate function: returns the number of items in a group.
@@ -542,6 +553,7 @@ def count(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("count", col)
 
 
+@try_remote_functions
 def sum(col: "ColumnOrName") -> Column:
     """
     Aggregate function: returns the sum of all values in the expression.
@@ -571,6 +583,7 @@ def sum(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("sum", col)
 
 
+@try_remote_functions
 def avg(col: "ColumnOrName") -> Column:
     """
     Aggregate function: returns the average of the values in a group.
@@ -600,6 +613,7 @@ def avg(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("avg", col)
 
 
+@try_remote_functions
 def mean(col: "ColumnOrName") -> Column:
     """
     Aggregate function: returns the average of the values in a group.
@@ -630,6 +644,7 @@ def mean(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("mean", col)
 
 
+@try_remote_functions
 def median(col: "ColumnOrName") -> Column:
     """
     Returns the median of the values in a group.
@@ -664,6 +679,7 @@ def median(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("median", col)
 
 
+@try_remote_functions
 def sumDistinct(col: "ColumnOrName") -> Column:
     """
     Aggregate function: returns the sum of distinct values in the expression.
@@ -677,6 +693,7 @@ def sumDistinct(col: "ColumnOrName") -> Column:
     return sum_distinct(col)
 
 
+@try_remote_functions
 def sum_distinct(col: "ColumnOrName") -> Column:
     """
     Aggregate function: returns the sum of distinct values in the expression.
@@ -706,6 +723,7 @@ def sum_distinct(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("sum_distinct", col)
 
 
+@try_remote_functions
 def product(col: "ColumnOrName") -> Column:
     """
     Aggregate function: returns the product of the values in a group.
@@ -738,6 +756,7 @@ def product(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("product", col)
 
 
+@try_remote_functions
 def acos(col: "ColumnOrName") -> Column:
     """
     Computes inverse cosine of the input column.
@@ -768,6 +787,7 @@ def acos(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("acos", col) +@try_remote_functions def acosh(col: "ColumnOrName") -> Column: """ Computes inverse hyperbolic cosine of the input column. @@ -798,6 +818,7 @@ def acosh(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("acosh", col) +@try_remote_functions def asin(col: "ColumnOrName") -> Column: """ Computes inverse sine of the input column. @@ -828,6 +849,7 @@ def asin(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("asin", col) +@try_remote_functions def asinh(col: "ColumnOrName") -> Column: """ Computes inverse hyperbolic sine of the input column. @@ -857,6 +879,7 @@ def asinh(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("asinh", col) +@try_remote_functions def atan(col: "ColumnOrName") -> Column: """ Compute inverse tangent of the input column. @@ -886,6 +909,7 @@ def atan(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("atan", col) +@try_remote_functions def atanh(col: "ColumnOrName") -> Column: """ Computes inverse hyperbolic tangent of the input column. @@ -916,6 +940,7 @@ def atanh(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("atanh", col) +@try_remote_functions def cbrt(col: "ColumnOrName") -> Column: """ Computes the cube-root of the given value. @@ -945,6 +970,7 @@ def cbrt(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("cbrt", col) +@try_remote_functions def ceil(col: "ColumnOrName") -> Column: """ Computes the ceiling of the given value. @@ -974,6 +1000,7 @@ def ceil(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("ceil", col) +@try_remote_functions def cos(col: "ColumnOrName") -> Column: """ Computes cosine of the input column. @@ -1000,6 +1027,7 @@ def cos(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("cos", col) +@try_remote_functions def cosh(col: "ColumnOrName") -> Column: """ Computes hyperbolic cosine of the input column. @@ -1025,6 +1053,7 @@ def cosh(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("cosh", col) +@try_remote_functions def cot(col: "ColumnOrName") -> Column: """ Computes cotangent of the input column. @@ -1051,6 +1080,7 @@ def cot(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("cot", col) +@try_remote_functions def csc(col: "ColumnOrName") -> Column: """ Computes cosecant of the input column. @@ -1077,6 +1107,7 @@ def csc(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("csc", col) +@try_remote_functions def exp(col: "ColumnOrName") -> Column: """ Computes the exponential of the given value. @@ -1106,6 +1137,7 @@ def exp(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("exp", col) +@try_remote_functions def expm1(col: "ColumnOrName") -> Column: """ Computes the exponential of the given value minus one. @@ -1131,6 +1163,7 @@ def expm1(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("expm1", col) +@try_remote_functions def floor(col: "ColumnOrName") -> Column: """ Computes the floor of the given value. @@ -1160,6 +1193,7 @@ def floor(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("floor", col) +@try_remote_functions def log(col: "ColumnOrName") -> Column: """ Computes the natural logarithm of the given value. 
@@ -1186,6 +1220,7 @@ def log(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("log", col) +@try_remote_functions def log10(col: "ColumnOrName") -> Column: """ Computes the logarithm of the given value in Base 10. @@ -1215,6 +1250,7 @@ def log10(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("log10", col) +@try_remote_functions def log1p(col: "ColumnOrName") -> Column: """ Computes the natural logarithm of the "given value plus one". @@ -1246,6 +1282,7 @@ def log1p(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("log1p", col) +@try_remote_functions def rint(col: "ColumnOrName") -> Column: """ Returns the double value that is closest in value to the argument and @@ -1283,6 +1320,7 @@ def rint(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("rint", col) +@try_remote_functions def sec(col: "ColumnOrName") -> Column: """ Computes secant of the input column. @@ -1308,6 +1346,7 @@ def sec(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("sec", col) +@try_remote_functions def signum(col: "ColumnOrName") -> Column: """ Computes the signum of the given value. @@ -1344,6 +1383,7 @@ def signum(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("signum", col) +@try_remote_functions def sin(col: "ColumnOrName") -> Column: """ Computes sine of the input column. @@ -1370,6 +1410,7 @@ def sin(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("sin", col) +@try_remote_functions def sinh(col: "ColumnOrName") -> Column: """ Computes hyperbolic sine of the input column. @@ -1396,6 +1437,7 @@ def sinh(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("sinh", col) +@try_remote_functions def tan(col: "ColumnOrName") -> Column: """ Computes tangent of the input column. @@ -1422,6 +1464,7 @@ def tan(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("tan", col) +@try_remote_functions def tanh(col: "ColumnOrName") -> Column: """ Computes hyperbolic tangent of the input column. @@ -1449,6 +1492,7 @@ def tanh(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("tanh", col) +@try_remote_functions def toDegrees(col: "ColumnOrName") -> Column: """ .. versionadded:: 1.4.0 @@ -1460,6 +1504,7 @@ def toDegrees(col: "ColumnOrName") -> Column: return degrees(col) +@try_remote_functions def toRadians(col: "ColumnOrName") -> Column: """ .. versionadded:: 1.4.0 @@ -1471,6 +1516,7 @@ def toRadians(col: "ColumnOrName") -> Column: return radians(col) +@try_remote_functions def bitwiseNOT(col: "ColumnOrName") -> Column: """ Computes bitwise not. @@ -1484,6 +1530,7 @@ def bitwiseNOT(col: "ColumnOrName") -> Column: return bitwise_not(col) +@try_remote_functions def bitwise_not(col: "ColumnOrName") -> Column: """ Computes bitwise not. 
@@ -1519,6 +1566,7 @@ def bitwise_not(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("bitwise_not", col) +@try_remote_functions def asc_nulls_first(col: "ColumnOrName") -> Column: """ Returns a sort expression based on the ascending order of the given @@ -1558,6 +1606,7 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: ) +@try_remote_functions def asc_nulls_last(col: "ColumnOrName") -> Column: """ Returns a sort expression based on the ascending order of the given @@ -1595,6 +1644,7 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: ) +@try_remote_functions def desc_nulls_first(col: "ColumnOrName") -> Column: """ Returns a sort expression based on the descending order of the given @@ -1634,6 +1684,7 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: ) +@try_remote_functions def desc_nulls_last(col: "ColumnOrName") -> Column: """ Returns a sort expression based on the descending order of the given @@ -1673,6 +1724,7 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: ) +@try_remote_functions def stddev(col: "ColumnOrName") -> Column: """ Aggregate function: alias for stddev_samp. @@ -1698,6 +1750,7 @@ def stddev(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("stddev", col) +@try_remote_functions def stddev_samp(col: "ColumnOrName") -> Column: """ Aggregate function: returns the unbiased sample standard deviation of @@ -1724,6 +1777,7 @@ def stddev_samp(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("stddev_samp", col) +@try_remote_functions def stddev_pop(col: "ColumnOrName") -> Column: """ Aggregate function: returns population standard deviation of @@ -1750,6 +1804,7 @@ def stddev_pop(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("stddev_pop", col) +@try_remote_functions def variance(col: "ColumnOrName") -> Column: """ Aggregate function: alias for var_samp @@ -1779,6 +1834,7 @@ def variance(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("variance", col) +@try_remote_functions def var_samp(col: "ColumnOrName") -> Column: """ Aggregate function: returns the unbiased sample variance of @@ -1809,6 +1865,7 @@ def var_samp(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("var_samp", col) +@try_remote_functions def var_pop(col: "ColumnOrName") -> Column: """ Aggregate function: returns the population variance of the values in a group. @@ -1834,6 +1891,7 @@ def var_pop(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("var_pop", col) +@try_remote_functions def skewness(col: "ColumnOrName") -> Column: """ Aggregate function: returns the skewness of the values in a group. @@ -1859,6 +1917,7 @@ def skewness(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("skewness", col) +@try_remote_functions def kurtosis(col: "ColumnOrName") -> Column: """ Aggregate function: returns the kurtosis of the values in a group. @@ -1888,6 +1947,7 @@ def kurtosis(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("kurtosis", col) +@try_remote_functions def collect_list(col: "ColumnOrName") -> Column: """ Aggregate function: returns a list of objects with duplicates. @@ -1918,6 +1978,7 @@ def collect_list(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("collect_list", col) +@try_remote_functions def collect_set(col: "ColumnOrName") -> Column: """ Aggregate function: returns a set of objects with duplicate elements eliminated. 
@@ -1948,6 +2009,7 @@ def collect_set(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("collect_set", col) +@try_remote_functions def degrees(col: "ColumnOrName") -> Column: """ Converts an angle measured in radians to an approximately equivalent angle @@ -1975,6 +2037,7 @@ def degrees(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("degrees", col) +@try_remote_functions def radians(col: "ColumnOrName") -> Column: """ Converts an angle measured in degrees to an approximately equivalent angle @@ -2001,6 +2064,7 @@ def radians(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("radians", col) +@try_remote_functions def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: """ .. versionadded:: 1.4.0 @@ -2030,6 +2094,7 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] return _invoke_binary_math_function("atan2", col1, col2) +@try_remote_functions def hypot(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: """ Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow. @@ -2057,6 +2122,7 @@ def hypot(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] return _invoke_binary_math_function("hypot", col1, col2) +@try_remote_functions def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: """ Returns the value of the first argument raised to the power of the second argument. @@ -2084,6 +2150,7 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) return _invoke_binary_math_function("pow", col1, col2) +@try_remote_functions def pmod(dividend: Union["ColumnOrName", float], divisor: Union["ColumnOrName", float]) -> Column: """ Returns the positive value of dividend mod divisor. @@ -2128,6 +2195,7 @@ def pmod(dividend: Union["ColumnOrName", float], divisor: Union["ColumnOrName", return _invoke_binary_math_function("pmod", dividend, divisor) +@try_remote_functions def row_number() -> Column: """ Window function: returns a sequential number starting at 1 within a window partition. @@ -2156,6 +2224,7 @@ def row_number() -> Column: return _invoke_function("row_number") +@try_remote_functions def dense_rank() -> Column: """ Window function: returns the rank of rows within a window partition, without any gaps. @@ -2195,6 +2264,7 @@ def dense_rank() -> Column: return _invoke_function("dense_rank") +@try_remote_functions def rank() -> Column: """ Window function: returns the rank of rows within a window partition. @@ -2234,6 +2304,7 @@ def rank() -> Column: return _invoke_function("rank") +@try_remote_functions def cume_dist() -> Column: """ Window function: returns the cumulative distribution of values within a window partition, @@ -2265,6 +2336,7 @@ def cume_dist() -> Column: return _invoke_function("cume_dist") +@try_remote_functions def percent_rank() -> Column: """ Window function: returns the relative rank (i.e. percentile) of rows within a window partition. @@ -2296,6 +2368,7 @@ def percent_rank() -> Column: return _invoke_function("percent_rank") +@try_remote_functions def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: """ .. 
versionadded:: 1.3.0 @@ -2307,6 +2380,7 @@ def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Col return approx_count_distinct(col, rsd) +@try_remote_functions def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: """Aggregate function: returns a new :class:`~pyspark.sql.Column` for approximate distinct count of column `col`. @@ -2341,6 +2415,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C return _invoke_function("approx_count_distinct", _to_java_column(col), rsd) +@try_remote_functions def broadcast(df: DataFrame) -> DataFrame: """ Marks a DataFrame as small enough for use in broadcast joins. @@ -2372,6 +2447,7 @@ def broadcast(df: DataFrame) -> DataFrame: return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sparkSession) +@try_remote_functions def coalesce(*cols: "ColumnOrName") -> Column: """Returns the first column that is not null. @@ -2420,6 +2496,7 @@ def coalesce(*cols: "ColumnOrName") -> Column: return _invoke_function_over_seq_of_columns("coalesce", cols) +@try_remote_functions def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """Returns a new :class:`~pyspark.sql.Column` for the Pearson Correlation Coefficient for ``col1`` and ``col2``. @@ -2449,6 +2526,7 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: return _invoke_function_over_columns("corr", col1, col2) +@try_remote_functions def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """Returns a new :class:`~pyspark.sql.Column` for the population covariance of ``col1`` and ``col2``. @@ -2478,6 +2556,7 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: return _invoke_function_over_columns("covar_pop", col1, col2) +@try_remote_functions def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """Returns a new :class:`~pyspark.sql.Column` for the sample covariance of ``col1`` and ``col2``. @@ -2507,6 +2586,7 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: return _invoke_function_over_columns("covar_samp", col1, col2) +@try_remote_functions def countDistinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column: """Returns a new :class:`~pyspark.sql.Column` for distinct count of ``col`` or ``cols``. @@ -2518,6 +2598,7 @@ def countDistinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column: return count_distinct(col, *cols) +@try_remote_functions def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column: """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. @@ -2565,6 +2646,7 @@ def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column: ) +@try_remote_functions def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: """Aggregate function: returns the first value in a group. 
@@ -2615,6 +2697,7 @@ def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: return _invoke_function("first", _to_java_column(col), ignorenulls) +@try_remote_functions def grouping(col: "ColumnOrName") -> Column: """ Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated @@ -2647,6 +2730,7 @@ def grouping(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("grouping", col) +@try_remote_functions def grouping_id(*cols: "ColumnOrName") -> Column: """ Aggregate function: returns the level of grouping, equals to @@ -2691,6 +2775,7 @@ def grouping_id(*cols: "ColumnOrName") -> Column: return _invoke_function_over_seq_of_columns("grouping_id", cols) +@try_remote_functions def input_file_name() -> Column: """ Creates a string column for the file name of the current Spark task. @@ -2713,6 +2798,7 @@ def input_file_name() -> Column: return _invoke_function("input_file_name") +@try_remote_functions def isnan(col: "ColumnOrName") -> Column: """An expression that returns true if the column is NaN. @@ -2742,6 +2828,7 @@ def isnan(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("isnan", col) +@try_remote_functions def isnull(col: "ColumnOrName") -> Column: """An expression that returns true if the column is null. @@ -2771,6 +2858,7 @@ def isnull(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("isnull", col) +@try_remote_functions def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: """Aggregate function: returns the last value in a group. @@ -2821,6 +2909,7 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: return _invoke_function("last", _to_java_column(col), ignorenulls) +@try_remote_functions def monotonically_increasing_id() -> Column: """A column that generates monotonically increasing 64-bit integers. @@ -2853,6 +2942,7 @@ def monotonically_increasing_id() -> Column: return _invoke_function("monotonically_increasing_id") +@try_remote_functions def nanvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -2881,6 +2971,7 @@ def nanvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: return _invoke_function_over_columns("nanvl", col1, col2) +@try_remote_functions def percentile_approx( col: "ColumnOrName", percentage: Union[Column, float, List[float], Tuple[float]], @@ -2955,6 +3046,7 @@ def percentile_approx( return _invoke_function("percentile_approx", _to_java_column(col), percentage, accuracy) +@try_remote_functions def rand(seed: Optional[int] = None) -> Column: """Generates a random column with independent and identically distributed (i.i.d.) samples uniformly distributed in [0.0, 1.0). @@ -2992,6 +3084,7 @@ def rand(seed: Optional[int] = None) -> Column: return _invoke_function("rand") +@try_remote_functions def randn(seed: Optional[int] = None) -> Column: """Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution. 
@@ -3029,6 +3122,7 @@ def randn(seed: Optional[int] = None) -> Column: return _invoke_function("randn") +@try_remote_functions def round(col: "ColumnOrName", scale: int = 0) -> Column: """ Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 @@ -3056,6 +3150,7 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: return _invoke_function("round", _to_java_column(col), scale) +@try_remote_functions def bround(col: "ColumnOrName", scale: int = 0) -> Column: """ Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 @@ -3083,6 +3178,7 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: return _invoke_function("bround", _to_java_column(col), scale) +@try_remote_functions def shiftLeft(col: "ColumnOrName", numBits: int) -> Column: """Shift the given value numBits left. @@ -3095,6 +3191,7 @@ def shiftLeft(col: "ColumnOrName", numBits: int) -> Column: return shiftleft(col, numBits) +@try_remote_functions def shiftleft(col: "ColumnOrName", numBits: int) -> Column: """Shift the given value numBits left. @@ -3120,6 +3217,7 @@ def shiftleft(col: "ColumnOrName", numBits: int) -> Column: return _invoke_function("shiftleft", _to_java_column(col), numBits) +@try_remote_functions def shiftRight(col: "ColumnOrName", numBits: int) -> Column: """(Signed) shift the given value numBits right. @@ -3132,6 +3230,7 @@ def shiftRight(col: "ColumnOrName", numBits: int) -> Column: return shiftright(col, numBits) +@try_remote_functions def shiftright(col: "ColumnOrName", numBits: int) -> Column: """(Signed) shift the given value numBits right. @@ -3157,6 +3256,7 @@ def shiftright(col: "ColumnOrName", numBits: int) -> Column: return _invoke_function("shiftright", _to_java_column(col), numBits) +@try_remote_functions def shiftRightUnsigned(col: "ColumnOrName", numBits: int) -> Column: """Unsigned shift the given value numBits right. @@ -3169,6 +3269,7 @@ def shiftRightUnsigned(col: "ColumnOrName", numBits: int) -> Column: return shiftrightunsigned(col, numBits) +@try_remote_functions def shiftrightunsigned(col: "ColumnOrName", numBits: int) -> Column: """Unsigned shift the given value numBits right. @@ -3195,6 +3296,7 @@ def shiftrightunsigned(col: "ColumnOrName", numBits: int) -> Column: return _invoke_function("shiftrightunsigned", _to_java_column(col), numBits) +@try_remote_functions def spark_partition_id() -> Column: """A column for partition ID. @@ -3218,6 +3320,7 @@ def spark_partition_id() -> Column: return _invoke_function("spark_partition_id") +@try_remote_functions def expr(str: str) -> Column: """Parses the expression string into the column that it represents @@ -3257,6 +3360,7 @@ def struct(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> ... +@try_remote_functions def struct( *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]] ) -> Column: @@ -3287,6 +3391,7 @@ def struct( return _invoke_function_over_seq_of_columns("struct", cols) # type: ignore[arg-type] +@try_remote_functions def greatest(*cols: "ColumnOrName") -> Column: """ Returns the greatest value of the list of column names, skipping null values. @@ -3315,6 +3420,7 @@ def greatest(*cols: "ColumnOrName") -> Column: return _invoke_function_over_seq_of_columns("greatest", cols) +@try_remote_functions def least(*cols: "ColumnOrName") -> Column: """ Returns the least value of the list of column names, skipping null values. 
@@ -3343,6 +3449,7 @@ def least(*cols: "ColumnOrName") -> Column: return _invoke_function_over_seq_of_columns("least", cols) +@try_remote_functions def when(condition: Column, value: Any) -> Column: """Evaluates a list of conditions and returns one of multiple possible result expressions. If :func:`pyspark.sql.Column.otherwise` is not invoked, None is returned for unmatched @@ -3401,6 +3508,7 @@ def log(arg1: float, arg2: "ColumnOrName") -> Column: ... +@try_remote_functions def log(arg1: Union["ColumnOrName", float], arg2: Optional["ColumnOrName"] = None) -> Column: """Returns the first argument-based logarithm of the second argument. @@ -3449,6 +3557,7 @@ def log(arg1: Union["ColumnOrName", float], arg2: Optional["ColumnOrName"] = Non return _invoke_function("log", arg1, _to_java_column(arg2)) +@try_remote_functions def log2(col: "ColumnOrName") -> Column: """Returns the base-2 logarithm of the argument. @@ -3477,6 +3586,7 @@ def log2(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("log2", col) +@try_remote_functions def conv(col: "ColumnOrName", fromBase: int, toBase: int) -> Column: """ Convert a number in a string column from one base to another. @@ -3506,6 +3616,7 @@ def conv(col: "ColumnOrName", fromBase: int, toBase: int) -> Column: return _invoke_function("conv", _to_java_column(col), fromBase, toBase) +@try_remote_functions def factorial(col: "ColumnOrName") -> Column: """ Computes the factorial of the given value. @@ -3534,6 +3645,7 @@ def factorial(col: "ColumnOrName") -> Column: # --------------- Window functions ------------------------ +@try_remote_functions def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> Column: """ Window function: returns the value that is `offset` rows before the current row, and @@ -3611,6 +3723,7 @@ def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> return _invoke_function("lag", _to_java_column(col), offset, default) +@try_remote_functions def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> Column: """ Window function: returns the value that is `offset` rows after the current row, and @@ -3688,6 +3801,7 @@ def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> return _invoke_function("lead", _to_java_column(col), offset, default) +@try_remote_functions def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = False) -> Column: """ Window function: returns the value that is the `offset`\\th row of the window frame @@ -3758,6 +3872,7 @@ def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = Fa return _invoke_function("nth_value", _to_java_column(col), offset, ignoreNulls) +@try_remote_functions def ntile(n: int) -> Column: """ Window function: returns the ntile group id (from 1 to `n` inclusive) @@ -3815,6 +3930,7 @@ def ntile(n: int) -> Column: # ---------------------- Date/Timestamp functions ------------------------------ +@try_remote_functions def current_date() -> Column: """ Returns the current date at the start of query evaluation as a :class:`DateType` column. 
@@ -3840,6 +3956,7 @@ def current_date() -> Column: return _invoke_function("current_date") +@try_remote_functions def current_timestamp() -> Column: """ Returns the current timestamp at the start of query evaluation as a :class:`TimestampType` @@ -3865,6 +3982,7 @@ def current_timestamp() -> Column: return _invoke_function("current_timestamp") +@try_remote_functions def localtimestamp() -> Column: """ Returns the current timestamp without time zone at the start of query evaluation @@ -3891,6 +4009,7 @@ def localtimestamp() -> Column: return _invoke_function("localtimestamp") +@try_remote_functions def date_format(date: "ColumnOrName", format: str) -> Column: """ Converts a date/timestamp/string to a value of string in the format specified by the date @@ -3928,6 +4047,7 @@ def date_format(date: "ColumnOrName", format: str) -> Column: return _invoke_function("date_format", _to_java_column(date), format) +@try_remote_functions def year(col: "ColumnOrName") -> Column: """ Extract the year of a given date/timestamp as integer. @@ -3953,6 +4073,7 @@ def year(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("year", col) +@try_remote_functions def quarter(col: "ColumnOrName") -> Column: """ Extract the quarter of a given date/timestamp as integer. @@ -3978,6 +4099,7 @@ def quarter(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("quarter", col) +@try_remote_functions def month(col: "ColumnOrName") -> Column: """ Extract the month of a given date/timestamp as integer. @@ -4003,6 +4125,7 @@ def month(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("month", col) +@try_remote_functions def dayofweek(col: "ColumnOrName") -> Column: """ Extract the day of the week of a given date/timestamp as integer. @@ -4029,6 +4152,7 @@ def dayofweek(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("dayofweek", col) +@try_remote_functions def dayofmonth(col: "ColumnOrName") -> Column: """ Extract the day of the month of a given date/timestamp as integer. @@ -4054,6 +4178,7 @@ def dayofmonth(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("dayofmonth", col) +@try_remote_functions def dayofyear(col: "ColumnOrName") -> Column: """ Extract the day of the year of a given date/timestamp as integer. @@ -4079,6 +4204,7 @@ def dayofyear(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("dayofyear", col) +@try_remote_functions def hour(col: "ColumnOrName") -> Column: """ Extract the hours of a given timestamp as integer. @@ -4105,6 +4231,7 @@ def hour(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("hour", col) +@try_remote_functions def minute(col: "ColumnOrName") -> Column: """ Extract the minutes of a given timestamp as integer. @@ -4131,6 +4258,7 @@ def minute(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("minute", col) +@try_remote_functions def second(col: "ColumnOrName") -> Column: """ Extract the seconds of a given date as integer. @@ -4157,6 +4285,7 @@ def second(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("second", col) +@try_remote_functions def weekofyear(col: "ColumnOrName") -> Column: """ Extract the week number of a given date as integer. 
@@ -4184,6 +4313,7 @@ def weekofyear(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("weekofyear", col) +@try_remote_functions def make_date(year: "ColumnOrName", month: "ColumnOrName", day: "ColumnOrName") -> Column: """ Returns a column with a date built from the year, month and day columns. @@ -4213,6 +4343,7 @@ def make_date(year: "ColumnOrName", month: "ColumnOrName", day: "ColumnOrName") return _invoke_function_over_columns("make_date", year, month, day) +@try_remote_functions def date_add(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: """ Returns the date that is `days` days after `start`. If `days` is a negative value @@ -4247,6 +4378,7 @@ def date_add(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("date_add", start, days) +@try_remote_functions def date_sub(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: """ Returns the date that is `days` days before `start`. If `days` is a negative value @@ -4281,6 +4413,7 @@ def date_sub(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("date_sub", start, days) +@try_remote_functions def datediff(end: "ColumnOrName", start: "ColumnOrName") -> Column: """ Returns the number of days from `start` to `end`. @@ -4308,6 +4441,7 @@ def datediff(end: "ColumnOrName", start: "ColumnOrName") -> Column: return _invoke_function_over_columns("datediff", end, start) +@try_remote_functions def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Column: """ Returns the date that is `months` months after `start`. If `months` is a negative value @@ -4342,6 +4476,7 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col return _invoke_function_over_columns("add_months", start, months) +@try_remote_functions def months_between(date1: "ColumnOrName", date2: "ColumnOrName", roundOff: bool = True) -> Column: """ Returns number of months between dates date1 and date2. @@ -4379,6 +4514,7 @@ def months_between(date1: "ColumnOrName", date2: "ColumnOrName", roundOff: bool ) +@try_remote_functions def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: """Converts a :class:`~pyspark.sql.Column` into :class:`pyspark.sql.types.DateType` using the optionally specified format. Specify formats according to `datetime pattern`_. @@ -4427,6 +4563,7 @@ def to_timestamp(col: "ColumnOrName", format: str) -> Column: ... +@try_remote_functions def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: """Converts a :class:`~pyspark.sql.Column` into :class:`pyspark.sql.types.TimestampType` using the optionally specified format. Specify formats according to `datetime pattern`_. @@ -4465,6 +4602,7 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: return _invoke_function("to_timestamp", _to_java_column(col), format) +@try_remote_functions def trunc(date: "ColumnOrName", format: str) -> Column: """ Returns date truncated to the unit specified by the format. @@ -4496,6 +4634,7 @@ def trunc(date: "ColumnOrName", format: str) -> Column: return _invoke_function("trunc", _to_java_column(date), format) +@try_remote_functions def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: """ Returns timestamp truncated to the unit specified by the format. 
@@ -4529,6 +4668,7 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: return _invoke_function("date_trunc", format, _to_java_column(timestamp)) +@try_remote_functions def next_day(date: "ColumnOrName", dayOfWeek: str) -> Column: """ Returns the first date which is later than the value of the date column @@ -4558,6 +4698,7 @@ def next_day(date: "ColumnOrName", dayOfWeek: str) -> Column: return _invoke_function("next_day", _to_java_column(date), dayOfWeek) +@try_remote_functions def last_day(date: "ColumnOrName") -> Column: """ Returns the last day of the month which the given date belongs to. @@ -4583,6 +4724,7 @@ def last_day(date: "ColumnOrName") -> Column: return _invoke_function("last_day", _to_java_column(date)) +@try_remote_functions def from_unixtime(timestamp: "ColumnOrName", format: str = "yyyy-MM-dd HH:mm:ss") -> Column: """ Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -4624,6 +4766,7 @@ def unix_timestamp() -> Column: ... +@try_remote_functions def unix_timestamp( timestamp: Optional["ColumnOrName"] = None, format: str = "yyyy-MM-dd HH:mm:ss" ) -> Column: @@ -4661,6 +4804,7 @@ def unix_timestamp( return _invoke_function("unix_timestamp", _to_java_column(timestamp), format) +@try_remote_functions def from_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: """ This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function @@ -4711,6 +4855,7 @@ def from_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: return _invoke_function("from_utc_timestamp", _to_java_column(timestamp), tz) +@try_remote_functions def to_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: """ This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function @@ -4761,6 +4906,7 @@ def to_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: return _invoke_function("to_utc_timestamp", _to_java_column(timestamp), tz) +@try_remote_functions def timestamp_seconds(col: "ColumnOrName") -> Column: """ Converts the number of seconds from the Unix epoch (1970-01-01T00:00:00Z) @@ -4798,6 +4944,7 @@ def timestamp_seconds(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("timestamp_seconds", col) +@try_remote_functions def window( timeColumn: "ColumnOrName", windowDuration: str, @@ -4884,6 +5031,7 @@ def window( return _invoke_function("window", time_col, windowDuration) +@try_remote_functions def window_time( windowColumn: "ColumnOrName", ) -> Column: @@ -4930,6 +5078,7 @@ def window_time( return _invoke_function("window_time", window_col) +@try_remote_functions def session_window(timeColumn: "ColumnOrName", gapDuration: Union[Column, str]) -> Column: """ Generates session window given a timestamp specifying column. @@ -4991,6 +5140,7 @@ def session_window(timeColumn: "ColumnOrName", gapDuration: Union[Column, str]) # ---------------------------- misc functions ---------------------------------- +@try_remote_functions def crc32(col: "ColumnOrName") -> Column: """ Calculates the cyclic redundancy check value (CRC32) of a binary column and @@ -5016,6 +5166,7 @@ def crc32(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("crc32", col) +@try_remote_functions def md5(col: "ColumnOrName") -> Column: """Calculates the MD5 digest and returns the value as a 32 character hex string. 
@@ -5039,6 +5190,7 @@ def md5(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("md5", col) +@try_remote_functions def sha1(col: "ColumnOrName") -> Column: """Returns the hex string result of SHA-1. @@ -5062,6 +5214,7 @@ def sha1(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("sha1", col) +@try_remote_functions def sha2(col: "ColumnOrName", numBits: int) -> Column: """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, and SHA-512). The numBits indicates the desired bit length of the result, which must have a @@ -5096,6 +5249,7 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: return _invoke_function("sha2", _to_java_column(col), numBits) +@try_remote_functions def hash(*cols: "ColumnOrName") -> Column: """Calculates the hash code of given columns, and returns the result as an int column. @@ -5136,6 +5290,7 @@ def hash(*cols: "ColumnOrName") -> Column: return _invoke_function_over_seq_of_columns("hash", cols) +@try_remote_functions def xxhash64(*cols: "ColumnOrName") -> Column: """Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm, and returns the result as a long column. The hash computation uses an initial seed of 42. @@ -5177,6 +5332,7 @@ def xxhash64(*cols: "ColumnOrName") -> Column: return _invoke_function_over_seq_of_columns("xxhash64", cols) +@try_remote_functions def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None) -> Column: """ Returns `null` if the input column is `true`; throws an exception @@ -5221,6 +5377,7 @@ def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None return _invoke_function("assert_true", _to_java_column(col), errMsg) +@try_remote_functions def raise_error(errMsg: Union[Column, str]) -> Column: """ Throws an exception with the provided error message. @@ -5257,6 +5414,7 @@ def raise_error(errMsg: Union[Column, str]) -> Column: # ---------------------- String/Binary functions ------------------------------ +@try_remote_functions def upper(col: "ColumnOrName") -> Column: """ Converts a string expression to upper case. @@ -5288,6 +5446,7 @@ def upper(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("upper", col) +@try_remote_functions def lower(col: "ColumnOrName") -> Column: """ Converts a string expression to lower case. @@ -5319,6 +5478,7 @@ def lower(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("lower", col) +@try_remote_functions def ascii(col: "ColumnOrName") -> Column: """ Computes the numeric value of the first character of the string column. @@ -5350,6 +5510,7 @@ def ascii(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("ascii", col) +@try_remote_functions def base64(col: "ColumnOrName") -> Column: """ Computes the BASE64 encoding of a binary column and returns it as a string column. @@ -5381,6 +5542,7 @@ def base64(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("base64", col) +@try_remote_functions def unbase64(col: "ColumnOrName") -> Column: """ Decodes a BASE64 encoded string column and returns it as a binary column. @@ -5414,6 +5576,7 @@ def unbase64(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("unbase64", col) +@try_remote_functions def ltrim(col: "ColumnOrName") -> Column: """ Trim the spaces from left end for the specified string value. 
@@ -5445,6 +5608,7 @@ def ltrim(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("ltrim", col) +@try_remote_functions def rtrim(col: "ColumnOrName") -> Column: """ Trim the spaces from right end for the specified string value. @@ -5476,6 +5640,7 @@ def rtrim(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("rtrim", col) +@try_remote_functions def trim(col: "ColumnOrName") -> Column: """ Trim the spaces from both ends for the specified string column. @@ -5507,6 +5672,7 @@ def trim(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("trim", col) +@try_remote_functions def concat_ws(sep: str, *cols: "ColumnOrName") -> Column: """ Concatenates multiple input string columns together into a single string column, @@ -5537,6 +5703,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> Column: return _invoke_function("concat_ws", sep, _to_seq(sc, cols, _to_java_column)) +@try_remote_functions def decode(col: "ColumnOrName", charset: str) -> Column: """ Computes the first argument into a string from a binary using the provided character set @@ -5569,6 +5736,7 @@ def decode(col: "ColumnOrName", charset: str) -> Column: return _invoke_function("decode", _to_java_column(col), charset) +@try_remote_functions def encode(col: "ColumnOrName", charset: str) -> Column: """ Computes the first argument into a binary from a string using the provided character set @@ -5601,6 +5769,7 @@ def encode(col: "ColumnOrName", charset: str) -> Column: return _invoke_function("encode", _to_java_column(col), charset) +@try_remote_functions def format_number(col: "ColumnOrName", d: int) -> Column: """ Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places @@ -5626,6 +5795,7 @@ def format_number(col: "ColumnOrName", d: int) -> Column: return _invoke_function("format_number", _to_java_column(col), d) +@try_remote_functions def format_string(format: str, *cols: "ColumnOrName") -> Column: """ Formats the arguments in printf-style and returns the result as a string column. @@ -5655,6 +5825,7 @@ def format_string(format: str, *cols: "ColumnOrName") -> Column: return _invoke_function("format_string", format, _to_seq(sc, cols, _to_java_column)) +@try_remote_functions def instr(str: "ColumnOrName", substr: str) -> Column: """ Locate the position of the first occurrence of substr column in the given string. @@ -5688,6 +5859,7 @@ def instr(str: "ColumnOrName", substr: str) -> Column: return _invoke_function("instr", _to_java_column(str), substr) +@try_remote_functions def overlay( src: "ColumnOrName", replace: "ColumnOrName", @@ -5742,6 +5914,7 @@ def overlay( return _invoke_function("overlay", _to_java_column(src), _to_java_column(replace), pos, len) +@try_remote_functions def sentences( string: "ColumnOrName", language: Optional["ColumnOrName"] = None, @@ -5792,6 +5965,7 @@ def sentences( return _invoke_function_over_columns("sentences", string, language, country) +@try_remote_functions def substring(str: "ColumnOrName", pos: int, len: int) -> Column: """ Substring starts at `pos` and is of length `len` when str is String type or @@ -5827,6 +6001,7 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: return _invoke_function("substring", _to_java_column(str), pos, len) +@try_remote_functions def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column: """ Returns the substring from string str before count occurrences of the delimiter delim. 
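A note on the two invocation helpers that alternate through these hunks: column-only functions go through `_invoke_function_over_columns`, while functions that also take plain Python values (`format_number`, `instr`, `locate`, ...) call `_invoke_function` and convert columns explicitly with `_to_java_column`. Roughly, a sketch condensed from the helpers already defined at the top of `functions.py`:
```
from pyspark import SparkContext
from pyspark.sql.column import Column, _to_java_column


def _invoke_function(name, *args):
    # Look the function up on org.apache.spark.sql.functions via the Py4J
    # gateway and wrap the resulting JVM Column. (Condensed sketch.)
    sc = SparkContext._active_spark_context
    assert sc is not None and sc._jvm is not None
    return Column(getattr(sc._jvm.functions, name)(*args))


def _invoke_function_over_columns(name, *cols):
    # Column-only variant: coerce each argument to a JVM Column first.
    return _invoke_function(name, *(_to_java_column(c) for c in cols))
```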
@@ -5861,6 +6036,7 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column: return _invoke_function("substring_index", _to_java_column(str), delim, count) +@try_remote_functions def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column: """Computes the Levenshtein distance of the two given strings. @@ -5887,6 +6063,7 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column: return _invoke_function_over_columns("levenshtein", left, right) +@try_remote_functions def locate(substr: str, str: "ColumnOrName", pos: int = 1) -> Column: """ Locate the position of the first occurrence of substr in a string column, after position pos. @@ -5921,6 +6098,7 @@ def locate(substr: str, str: "ColumnOrName", pos: int = 1) -> Column: return _invoke_function("locate", substr, _to_java_column(str), pos) +@try_remote_functions def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: """ Left-pad the string column to width `len` with `pad`. @@ -5950,6 +6128,7 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: return _invoke_function("lpad", _to_java_column(col), len, pad) +@try_remote_functions def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: """ Right-pad the string column to width `len` with `pad`. @@ -5979,6 +6158,7 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: return _invoke_function("rpad", _to_java_column(col), len, pad) +@try_remote_functions def repeat(col: "ColumnOrName", n: int) -> Column: """ Repeats a string column n times, and returns it as a new string column. @@ -6006,6 +6186,7 @@ def repeat(col: "ColumnOrName", n: int) -> Column: return _invoke_function("repeat", _to_java_column(col), n) +@try_remote_functions def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: """ Splits str around matches of the given pattern. @@ -6047,6 +6228,7 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: return _invoke_function("split", _to_java_column(str), pattern, limit) +@try_remote_functions def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: r"""Extract a specific group matched by a Java regex, from the specified string column. If the regex did not match, or the specified group did not match, an empty string is returned. @@ -6082,6 +6264,7 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: return _invoke_function("regexp_extract", _to_java_column(str), pattern, idx) +@try_remote_functions def regexp_replace( string: "ColumnOrName", pattern: Union[str, Column], replacement: Union[str, Column] ) -> Column: @@ -6122,6 +6305,7 @@ def regexp_replace( return _invoke_function("regexp_replace", _to_java_column(string), pattern_col, replacement_col) +@try_remote_functions def initcap(col: "ColumnOrName") -> Column: """Translate the first letter of each word to upper case in the sentence. @@ -6145,6 +6329,7 @@ def initcap(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("initcap", col) +@try_remote_functions def soundex(col: "ColumnOrName") -> Column: """ Returns the SoundEx encoding for a string @@ -6170,6 +6355,7 @@ def soundex(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("soundex", col) +@try_remote_functions def bin(col: "ColumnOrName") -> Column: """Returns the string representation of the binary value of the given column. 
@@ -6194,6 +6380,7 @@ def bin(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("bin", col) +@try_remote_functions def hex(col: "ColumnOrName") -> Column: """Computes hex value of the given column, which could be :class:`pyspark.sql.types.StringType`, :class:`pyspark.sql.types.BinaryType`, :class:`pyspark.sql.types.IntegerType` or @@ -6219,6 +6406,7 @@ def hex(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("hex", col) +@try_remote_functions def unhex(col: "ColumnOrName") -> Column: """Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representation of number. @@ -6243,6 +6431,7 @@ def unhex(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("unhex", col) +@try_remote_functions def length(col: "ColumnOrName") -> Column: """Computes the character length of string data or number of bytes of binary data. The length of character data includes the trailing spaces. The length of binary data @@ -6268,6 +6457,7 @@ def length(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("length", col) +@try_remote_functions def octet_length(col: "ColumnOrName") -> Column: """ Calculates the byte length for the specified string column. @@ -6294,6 +6484,7 @@ def octet_length(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("octet_length", col) +@try_remote_functions def bit_length(col: "ColumnOrName") -> Column: """ Calculates the bit length for the specified string column. @@ -6320,6 +6511,7 @@ def bit_length(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("bit_length", col) +@try_remote_functions def translate(srcCol: "ColumnOrName", matching: str, replace: str) -> Column: """A function translate any character in the `srcCol` by a character in `matching`. The characters in `replace` is corresponding to the characters in `matching`. @@ -6365,6 +6557,7 @@ def create_map(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]] ... +@try_remote_functions def create_map( *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]] ) -> Column: @@ -6391,6 +6584,7 @@ def create_map( return _invoke_function_over_seq_of_columns("map", cols) # type: ignore[arg-type] +@try_remote_functions def map_from_arrays(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """Creates a new map from two arrays. @@ -6437,6 +6631,7 @@ def array(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> ... 
+@try_remote_functions def array( *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]] ) -> Column: @@ -6472,6 +6667,7 @@ def array( return _invoke_function_over_seq_of_columns("array", cols) # type: ignore[arg-type] +@try_remote_functions def array_contains(col: "ColumnOrName", value: Any) -> Column: """ Collection function: returns null if the array is null, true if the array contains the @@ -6503,6 +6699,7 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: return _invoke_function("array_contains", _to_java_column(col), value) +@try_remote_functions def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: """ Collection function: returns true if the arrays contain any common non-null element; if not, @@ -6525,6 +6722,7 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: return _invoke_function_over_columns("arrays_overlap", a1, a2) +@try_remote_functions def slice( x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int] ) -> Column: @@ -6560,6 +6758,7 @@ def slice( return _invoke_function_over_columns("slice", x, start, length) +@try_remote_functions def array_join( col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None ) -> Column: @@ -6599,6 +6798,7 @@ def array_join( return _invoke_function("array_join", _to_java_column(col), delimiter, null_replacement) +@try_remote_functions def concat(*cols: "ColumnOrName") -> Column: """ Concatenates multiple input columns together into a single column. @@ -6639,6 +6839,7 @@ def concat(*cols: "ColumnOrName") -> Column: return _invoke_function_over_seq_of_columns("concat", cols) +@try_remote_functions def array_position(col: "ColumnOrName", value: Any) -> Column: """ Collection function: Locates the position of the first occurrence of the given value @@ -6672,6 +6873,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: return _invoke_function("array_position", _to_java_column(col), value) +@try_remote_functions def element_at(col: "ColumnOrName", extraction: Any) -> Column: """ Collection function: Returns element of array at given index in `extraction` if col is array. @@ -6716,6 +6918,7 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column: return _invoke_function_over_columns("element_at", col, lit(extraction)) +@try_remote_functions def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: """ Collection function: Returns element of array at given (0-based) index. @@ -6787,6 +6990,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +@try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: """ Collection function: Remove all elements that equal to element from the given array. @@ -6814,6 +7018,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: return _invoke_function("array_remove", _to_java_column(col), element) +@try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ Collection function: removes duplicate values from the array. 
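For the overloaded builders (`create_map`, `array`, and later `map_concat`, `transform`, `filter`), the decorator lands only on the real implementation; the `@overload` stubs are typing-only and stay untouched. Schematically, with signatures abbreviated:
```
from typing import List, Tuple, Union, overload

from pyspark.sql.column import Column
from pyspark.sql.utils import try_remote_functions


@overload
def array(*cols: "ColumnOrName") -> Column:
    ...


@overload
def array(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> Column:
    ...


@try_remote_functions
def array(*cols):
    # Only this definition exists at runtime; the @overload stubs above are
    # typing-only and never execute, so they need no remote dispatch.
    ...
```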
@@ -6839,6 +7044,7 @@ def array_distinct(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("array_distinct", col) +@try_remote_functions def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """ Collection function: returns an array of the elements in the intersection of col1 and col2, @@ -6868,6 +7074,7 @@ def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: return _invoke_function_over_columns("array_intersect", col1, col2) +@try_remote_functions def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """ Collection function: returns an array of the elements in the union of col1 and col2, @@ -6897,6 +7104,7 @@ def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: return _invoke_function_over_columns("array_union", col1, col2) +@try_remote_functions def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """ Collection function: returns an array of the elements in col1 but not in col2, @@ -6926,6 +7134,7 @@ def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: return _invoke_function_over_columns("array_except", col1, col2) +@try_remote_functions def explode(col: "ColumnOrName") -> Column: """ Returns a new row for each element in the given array or map. @@ -6967,6 +7176,7 @@ def explode(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("explode", col) +@try_remote_functions def posexplode(col: "ColumnOrName") -> Column: """ Returns a new row for each element with position in the given array or map. @@ -7002,6 +7212,7 @@ def posexplode(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("posexplode", col) +@try_remote_functions def inline(col: "ColumnOrName") -> Column: """ Explodes an array of structs into a table. @@ -7037,6 +7248,7 @@ def inline(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("inline", col) +@try_remote_functions def explode_outer(col: "ColumnOrName") -> Column: """ Returns a new row for each element in the given array or map. @@ -7084,6 +7296,7 @@ def explode_outer(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("explode_outer", col) +@try_remote_functions def posexplode_outer(col: "ColumnOrName") -> Column: """ Returns a new row for each element with position in the given array or map. @@ -7130,6 +7343,7 @@ def posexplode_outer(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("posexplode_outer", col) +@try_remote_functions def inline_outer(col: "ColumnOrName") -> Column: """ Explodes an array of structs into a table. @@ -7171,6 +7385,7 @@ def inline_outer(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("inline_outer", col) +@try_remote_functions def get_json_object(col: "ColumnOrName", path: str) -> Column: """ Extracts json object from a json string based on json `path` specified, and returns json string @@ -7201,6 +7416,7 @@ def get_json_object(col: "ColumnOrName", path: str) -> Column: return _invoke_function("get_json_object", _to_java_column(col), path) +@try_remote_functions def json_tuple(col: "ColumnOrName", *fields: str) -> Column: """Creates a new row for a json column according to the given field names. 
@@ -7230,6 +7446,7 @@ def json_tuple(col: "ColumnOrName", *fields: str) -> Column: return _invoke_function("json_tuple", _to_java_column(col), _to_seq(sc, fields)) +@try_remote_functions def from_json( col: "ColumnOrName", schema: Union[ArrayType, StructType, Column, str], @@ -7295,6 +7512,7 @@ def from_json( return _invoke_function("from_json", _to_java_column(col), schema, _options_to_str(options)) +@try_remote_functions def to_json(col: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Column: """ Converts a column containing a :class:`StructType`, :class:`ArrayType` or a :class:`MapType` @@ -7349,6 +7567,7 @@ def to_json(col: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Co return _invoke_function("to_json", _to_java_column(col), _options_to_str(options)) +@try_remote_functions def schema_of_json(json: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Column: """ Parses a JSON string and infers its schema in DDL format. @@ -7393,6 +7612,7 @@ def schema_of_json(json: "ColumnOrName", options: Optional[Dict[str, str]] = Non return _invoke_function("schema_of_json", col, _options_to_str(options)) +@try_remote_functions def schema_of_csv(csv: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Column: """ Parses a CSV string and infers its schema in DDL format. @@ -7433,6 +7653,7 @@ def schema_of_csv(csv: "ColumnOrName", options: Optional[Dict[str, str]] = None) return _invoke_function("schema_of_csv", col, _options_to_str(options)) +@try_remote_functions def to_csv(col: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Column: """ Converts a column containing a :class:`StructType` into a CSV string. @@ -7468,6 +7689,7 @@ def to_csv(col: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Col return _invoke_function("to_csv", _to_java_column(col), _options_to_str(options)) +@try_remote_functions def size(col: "ColumnOrName") -> Column: """ Collection function: returns the length of the array or map stored in the column. @@ -7493,6 +7715,7 @@ def size(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("size", col) +@try_remote_functions def array_min(col: "ColumnOrName") -> Column: """ Collection function: returns the minimum value of the array. @@ -7518,6 +7741,7 @@ def array_min(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("array_min", col) +@try_remote_functions def array_max(col: "ColumnOrName") -> Column: """ Collection function: returns the maximum value of the array. @@ -7543,6 +7767,7 @@ def array_max(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("array_max", col) +@try_remote_functions def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: """ Collection function: sorts the input array in ascending or descending order according @@ -7576,6 +7801,7 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: return _invoke_function("sort_array", _to_java_column(col), asc) +@try_remote_functions def array_sort( col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None ) -> Column: @@ -7621,6 +7847,7 @@ def array_sort( return _invoke_higher_order_function("ArraySort", [col], [comparator]) +@try_remote_functions def shuffle(col: "ColumnOrName") -> Column: """ Collection function: Generates a random permutation of the given array. 
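`array_sort` is the one function in this run that routes through `_invoke_higher_order_function` when a `comparator` is supplied, because the Python lambda must be compiled into a Catalyst lambda expression. For reference, standard usage (assuming an active `spark` session):
```
from pyspark.sql.functions import array_sort

df = spark.createDataFrame([([3, 1, 2],)], ["xs"])

# The comparator returns a negative/zero/positive integer column.
df.select(array_sort("xs")).show()                      # [1, 2, 3]
df.select(array_sort("xs", lambda x, y: y - x)).show()  # [3, 2, 1]
```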
@@ -7650,6 +7877,7 @@ def shuffle(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("shuffle", col) +@try_remote_functions def reverse(col: "ColumnOrName") -> Column: """ Collection function: returns a reversed string or an array with reverse order of elements. @@ -7678,6 +7906,7 @@ def reverse(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("reverse", col) +@try_remote_functions def flatten(col: "ColumnOrName") -> Column: """ Collection function: creates a single array from an array of arrays. @@ -7717,6 +7946,7 @@ def flatten(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("flatten", col) +@try_remote_functions def map_contains_key(col: "ColumnOrName", value: Any) -> Column: """ Returns true if the map contains the key. @@ -7755,6 +7985,7 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column: return _invoke_function("map_contains_key", _to_java_column(col), value) +@try_remote_functions def map_keys(col: "ColumnOrName") -> Column: """ Collection function: Returns an unordered array containing the keys of the map. @@ -7785,6 +8016,7 @@ def map_keys(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("map_keys", col) +@try_remote_functions def map_values(col: "ColumnOrName") -> Column: """ Collection function: Returns an unordered array containing the values of the map. @@ -7815,6 +8047,7 @@ def map_values(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("map_values", col) +@try_remote_functions def map_entries(col: "ColumnOrName") -> Column: """ Collection function: Returns an unordered array of all entries in the given map. @@ -7852,6 +8085,7 @@ def map_entries(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("map_entries", col) +@try_remote_functions def map_from_entries(col: "ColumnOrName") -> Column: """ Collection function: Converts an array of entries (key value struct types) to a map @@ -7883,6 +8117,7 @@ def map_from_entries(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("map_from_entries", col) +@try_remote_functions def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column: """ Collection function: creates an array containing a column repeated count times. @@ -7912,6 +8147,7 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu return _invoke_function_over_columns("array_repeat", col, count) +@try_remote_functions def arrays_zip(*cols: "ColumnOrName") -> Column: """ Collection function: Returns a merged array of structs in which the N-th struct contains all @@ -7962,6 +8198,7 @@ def map_concat(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]] ... +@try_remote_functions def map_concat( *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]] ) -> Column: @@ -7995,6 +8232,7 @@ def map_concat( return _invoke_function_over_seq_of_columns("map_concat", cols) # type: ignore[arg-type] +@try_remote_functions def sequence( start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None ) -> Column: @@ -8034,6 +8272,7 @@ def sequence( return _invoke_function_over_columns("sequence", start, stop, step) +@try_remote_functions def from_csv( col: "ColumnOrName", schema: Union[Column, str], @@ -8207,6 +8446,7 @@ def transform(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Col ... 
+@try_remote_functions def transform( col: "ColumnOrName", f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]], @@ -8260,6 +8500,7 @@ def transform( return _invoke_higher_order_function("ArrayTransform", [col], [f]) +@try_remote_functions def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: """ Returns whether a predicate holds for one or more elements in the array. @@ -8297,6 +8538,7 @@ def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: return _invoke_higher_order_function("ArrayExists", [col], [f]) +@try_remote_functions def forall(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: """ Returns whether a predicate holds for every element in the array. @@ -8348,6 +8590,7 @@ def filter(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column ... +@try_remote_functions def filter( col: "ColumnOrName", f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]], @@ -8400,6 +8643,7 @@ def filter( return _invoke_higher_order_function("ArrayFilter", [col], [f]) +@try_remote_functions def aggregate( col: "ColumnOrName", initialValue: "ColumnOrName", @@ -8471,6 +8715,7 @@ def aggregate( return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], [merge]) +@try_remote_functions def zip_with( left: "ColumnOrName", right: "ColumnOrName", @@ -8522,6 +8767,7 @@ def zip_with( return _invoke_higher_order_function("ZipWith", [left, right], [f]) +@try_remote_functions def transform_keys(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column: """ Applies a function to every key-value pair in a map and returns @@ -8561,6 +8807,7 @@ def transform_keys(col: "ColumnOrName", f: Callable[[Column, Column], Column]) - return _invoke_higher_order_function("TransformKeys", [col], [f]) +@try_remote_functions def transform_values(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column: """ Applies a function to every key-value pair in a map and returns @@ -8600,6 +8847,7 @@ def transform_values(col: "ColumnOrName", f: Callable[[Column, Column], Column]) return _invoke_higher_order_function("TransformValues", [col], [f]) +@try_remote_functions def map_filter(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column: """ Returns a map whose key-value pairs satisfy a predicate. 
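All of the lambda-taking APIs above (`transform`, `exists`, `forall`, `filter`, `aggregate`, `zip_with`, `transform_keys`, `transform_values`, `map_filter`) funnel into `_invoke_higher_order_function` with a Catalyst expression name (`ArrayTransform`, `ArrayFilter`, ...) plus the user callback. Their call shape is unchanged by this patch; for reference, assuming an active `spark` session:
```
from pyspark.sql.functions import aggregate, filter, lit, transform
# Note: this `filter` intentionally shadows the builtin in the example scope.

df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values"))

df.select(
    transform("values", lambda x: x * 2).alias("doubled"),
    filter("values", lambda x: x % 2 == 0).alias("evens"),
    aggregate("values", lit(0), lambda acc, x: acc + x).alias("total"),
).show()
```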
@@ -8637,6 +8885,7 @@ def map_filter(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Co return _invoke_higher_order_function("MapFilter", [col], [f]) +@try_remote_functions def map_zip_with( col1: "ColumnOrName", col2: "ColumnOrName", @@ -8687,6 +8936,7 @@ def map_zip_with( # ---------------------- Partition transform functions -------------------------------- +@try_remote_functions def years(col: "ColumnOrName") -> Column: """ Partition transform function: A transform for timestamps and dates @@ -8720,6 +8970,7 @@ def years(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("years", col) +@try_remote_functions def months(col: "ColumnOrName") -> Column: """ Partition transform function: A transform for timestamps and dates @@ -8753,6 +9004,7 @@ def months(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("months", col) +@try_remote_functions def days(col: "ColumnOrName") -> Column: """ Partition transform function: A transform for timestamps and dates @@ -8786,6 +9038,7 @@ def days(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("days", col) +@try_remote_functions def hours(col: "ColumnOrName") -> Column: """ Partition transform function: A transform for timestamps @@ -8819,6 +9072,7 @@ def hours(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("hours", col) +@try_remote_functions def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column: """ Partition transform function: A transform for any type that partitions @@ -8862,6 +9116,7 @@ def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column: return _invoke_function("bucket", numBuckets, _to_java_column(col)) +@try_remote_functions def call_udf(udfName: str, *cols: "ColumnOrName") -> Column: """ Call an user-defined function. @@ -8909,6 +9164,7 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column: return _invoke_function("call_udf", udfName, _to_seq(sc, cols, _to_java_column)) +@try_remote_functions def unwrap_udt(col: "ColumnOrName") -> Column: """ Unwrap UDT data type column into its underlying type. diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py index 5031fd2f1c1..67bb1f36305 100644 --- a/python/pyspark/sql/observation.py +++ b/python/pyspark/sql/observation.py @@ -21,6 +21,7 @@ from py4j.java_gateway import JavaObject, JVMView from pyspark.sql import column from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame +from pyspark.sql.utils import try_remote_observation __all__ = ["Observation"] @@ -83,6 +84,7 @@ class Observation: self._jvm: Optional[JVMView] = None self._jo: Optional[JavaObject] = None + @try_remote_observation def _on(self, df: DataFrame, *exprs: Column) -> DataFrame: """Attaches this observation to the given :class:`DataFrame` to observe aggregations. @@ -109,7 +111,9 @@ class Observation: ) return DataFrame(observed_df, df.sparkSession) - @property + # Note that decorated property only works with Python 3.9+ which Spark Connect requires. + @property # type: ignore[misc] + @try_remote_observation def get(self) -> Dict[str, Any]: """Get the observed metrics. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index ebad3224f02..e96eef79b87 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import os import sys import warnings from collections.abc import Sized @@ -250,15 +250,43 @@ class SparkSession(SparkConversionMixin): ... map={"spark.some.config.number": 123, "spark.some.config.float": 0.123}) <pyspark.sql.session.SparkSession.Builder... """ + + def check_startup_urls(k: str, v: str) -> None: + if k == "spark.master": + if "spark.remote" in self._options or "SPARK_REMOTE" in os.environ: + raise RuntimeError( + "Spark master cannot be configured with Spark Connect server; " + "however, found URL for Spark Connect [%s]" + % self._options.get("spark.remote", os.environ.get("SPARK_REMOTE")) + ) + elif k == "spark.remote": + if "spark.master" in self._options or "MASTER" in os.environ: + raise RuntimeError( + "Spark Connect server cannot be configured with Spark master; " + "however, found URL for Spark master [%s]" + % self._options.get("spark.master", os.environ.get("MASTER")) + ) + + if "SPARK_REMOTE" in os.environ and os.environ["SPARK_REMOTE"] != v: + raise RuntimeError( + "Only one Spark Connect client URL can be set; however, got a different " + "URL [%s] from the existing [%s]" % (os.environ["SPARK_REMOTE"], v) + ) + with self._lock: if conf is not None: for (k, v) in conf.getAll(): + check_startup_urls(k, v) self._options[k] = v elif map is not None: for k, v in map.items(): # type: ignore[assignment] - self._options[k] = to_str(v) + v = to_str(v) # type: ignore[assignment] + check_startup_urls(k, v) + self._options[k] = v else: - self._options[cast(str, key)] = to_str(value) + value = to_str(value) + check_startup_urls(key, value) # type: ignore[arg-type] + self._options[cast(str, key)] = value return self def master(self, master: str) -> "SparkSession.Builder": @@ -284,6 +312,28 @@ class SparkSession(SparkConversionMixin): """ return self.config("spark.master", master) + def remote(self, url: str) -> "SparkSession.Builder": + """Sets the Spark remote URL to connect to, such as "sc://host:port" to run + it via Spark Connect server. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + url : str + URL to Spark Connect server + + Returns + ------- + :class:`SparkSession.Builder` + + Examples + -------- + >>> SparkSession.builder.remote("sc://localhost") # doctest: +SKIP + <pyspark.sql.session.SparkSession.Builder... + """ + return self.config("spark.remote", url) + def appName(self, name: str) -> "SparkSession.Builder": """Sets a name for the application, which will be shown in the Spark web UI. @@ -361,8 +411,33 @@ class SparkSession(SparkConversionMixin): >>> s1.conf.get("k2") == s2.conf.get("k2") == "v2" True """ + from pyspark.context import SparkContext + + opts = dict(self._options) + if "SPARK_REMOTE" in os.environ or "spark.remote" in opts: + with SparkContext._lock: + from pyspark.sql.connect.session import SparkSession as RemoteSparkSession + + if ( + SparkContext._active_spark_context is None + and SparkSession._instantiatedSession is None + ): + url = opts.get("spark.remote", os.environ.get("SPARK_REMOTE")) + os.environ["SPARK_REMOTE"] = url + opts["spark.remote"] = url + return cast( + SparkSession, RemoteSparkSession.builder.config(map=opts).getOrCreate() + ) + elif "SPARK_TESTING" not in os.environ: + raise RuntimeError( + "Cannot start a remote Spark session because there " + "is a regular Spark session already running." + ) + + # Cannot reach here in production. Test-only.
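One behavioral consequence of `check_startup_urls` above: `spark.master` and `spark.remote` are now mutually exclusive, enforced eagerly at `config()` time rather than at session construction. A hypothetical interaction:
```
from pyspark.sql import SparkSession

builder = SparkSession.builder.remote("sc://localhost")

# Raises RuntimeError: "Spark master cannot be configured with Spark Connect
# server; however, found URL for Spark Connect [sc://localhost]"
builder.master("local[4]")
```
The symmetric check fires if `.remote(...)` is called after `.master(...)` (or with `MASTER` set), and a second, different Connect URL is rejected against an existing `SPARK_REMOTE`.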
+ return cast(SparkSession, RemoteSparkSession.builder.config(map=opts).getOrCreate()) + with self._lock: - from pyspark.context import SparkContext from pyspark.conf import SparkConf session = SparkSession._instantiatedSession diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py new file mode 100644 index 00000000000..ccb5dd45b54 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import os + +from pyspark.sql import SparkSession +from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin +from pyspark.testing.connectutils import should_test_connect, connect_requirement_message +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class DataFrameParityTests(DataFrameTestsMixin, ReusedSQLTestCase): + @classmethod + def setUpClass(cls): + super(DataFrameParityTests, cls).setUpClass() + cls._spark = cls.spark # Assign existing Spark session to run the server + # Sets the remote address. Now, we create a remote Spark Session. + # Note that this is only allowed in testing. + os.environ["SPARK_REMOTE"] = "sc://localhost" + cls.spark = SparkSession.builder.remote("sc://localhost").getOrCreate() + + @classmethod + def tearDownClass(cls): + # TODO(SPARK-41529): Implement stop in RemoteSparkSession. + # Stop the regular Spark session (server) too. 
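The parity suites added below reuse the existing PySpark test bodies wholesale; only the session fixture changes. Condensed, with an illustrative class name, the choreography is:
```
import os
import unittest

from pyspark.sql import SparkSession
from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.sqlutils import ReusedSQLTestCase


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ExampleParityTests(DataFrameTestsMixin, ReusedSQLTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()          # starts the regular JVM session...
        cls._spark = cls.spark        # ...and keeps it alive as the server side
        os.environ["SPARK_REMOTE"] = "sc://localhost"  # test-only double session
        cls.spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark = cls._spark        # hand the JVM session back for cleanup
        super().tearDownClass()
        del os.environ["SPARK_REMOTE"]
```
Because `SPARK_REMOTE` stays set while the tests run, every inherited `test_*` method exercises the Connect code path via `is_remote()`; the long list of `@unittest.skip` overrides that follows tracks which of those paths do not work yet.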
+ cls.spark = cls._spark + super(DataFrameParityTests, cls).tearDownClass() + del os.environ["SPARK_REMOTE"] + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_cache(self): + super().test_cache() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_create_dataframe_from_array_of_long(self): + super().test_create_dataframe_from_array_of_long() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_create_dataframe_from_pandas_with_day_time_interval(self): + super().test_create_dataframe_from_pandas_with_day_time_interval() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_create_dataframe_from_pandas_with_dst(self): + super().test_create_dataframe_from_pandas_with_dst() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_create_dataframe_from_pandas_with_timestamp(self): + super().test_create_dataframe_from_pandas_with_timestamp() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_create_dataframe_required_pandas_not_found(self): + super().test_create_dataframe_required_pandas_not_found() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_create_nan_decimal_dataframe(self): + super().test_create_nan_decimal_dataframe() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_df_show(self): + super().test_df_show() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_drop(self): + super().test_drop() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_drop_duplicates(self): + super().test_drop_duplicates() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_dropna(self): + super().test_dropna() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_duplicated_column_names(self): + super().test_duplicated_column_names() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_extended_hint_types(self): + super().test_extended_hint_types() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_fillna(self): + super().test_fillna() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_freqItems(self): + super().test_freqItems() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_generic_hints(self): + super().test_generic_hints() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_help_command(self): + super().test_help_command() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_input_files(self): + super().test_input_files() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_invalid_join_method(self): + super().test_invalid_join_method() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_join_without_on(self): + super().test_join_without_on() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_observe(self): + super().test_observe() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_observe_str(self): + super().test_observe_str() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_pandas_api(self): + super().test_pandas_api() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_repartitionByRange_dataframe(self): + super().test_repartitionByRange_dataframe() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_replace(self): + super().test_replace() + + @unittest.skip("Fails in Spark Connect, should enable.") + def 
test_repr_behaviors(self): + super().test_repr_behaviors() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_require_cross(self): + super().test_require_cross() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_same_semantics_error(self): + super().test_same_semantics_error() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_sample(self): + super().test_sample() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to(self): + super().test_to() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_toDF_with_schema_string(self): + super().test_toDF_with_schema_string() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_local_iterator(self): + super().test_to_local_iterator() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_local_iterator_not_fully_consumed(self): + super().test_to_local_iterator_not_fully_consumed() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_local_iterator_prefetch(self): + super().test_to_local_iterator_prefetch() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_pandas(self): + super().test_to_pandas() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_pandas_avoid_astype(self): + super().test_to_pandas_avoid_astype() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_pandas_for_array_of_struct(self): + super().test_to_pandas_for_array_of_struct() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_pandas_from_empty_dataframe(self): + super().test_to_pandas_from_empty_dataframe() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_pandas_from_mixed_dataframe(self): + super().test_to_pandas_from_mixed_dataframe() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_pandas_from_null_dataframe(self): + super().test_to_pandas_from_null_dataframe() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_pandas_on_cross_join(self): + super().test_to_pandas_on_cross_join() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_pandas_required_pandas_not_found(self): + super().test_to_pandas_required_pandas_not_found() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_to_pandas_with_duplicated_column_names(self): + super().test_to_pandas_with_duplicated_column_names() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_unpivot(self): + super().test_unpivot() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_with_columns_renamed(self): + super().test_with_columns_renamed() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_dataframe import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py new file mode 100644 index 00000000000..0b8c2bd036b --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_functions.py @@ -0,0 +1,292 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import os + +from pyspark.sql import SparkSession +from pyspark.sql.tests.test_functions import FunctionsTestsMixin +from pyspark.testing.connectutils import should_test_connect, connect_requirement_message +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class FunctionsParityTests(ReusedSQLTestCase, FunctionsTestsMixin): + @classmethod + def setUpClass(cls): + from pyspark.sql.connect.session import SparkSession as RemoteSparkSession + + super(FunctionsParityTests, cls).setUpClass() + cls._spark = cls.spark # Assign existing Spark session to run the server + # Sets the remote address. Now, we create a remote Spark Session. + # Note that this is only allowed in testing. + os.environ["SPARK_REMOTE"] = "sc://localhost" + cls.spark = SparkSession.builder.remote("sc://localhost").getOrCreate() + assert isinstance(cls.spark, RemoteSparkSession) + + @classmethod + def tearDownClass(cls): + # TODO(SPARK-41529): Implement stop in RemoteSparkSession. + # Stop the regular Spark session (server) too. 
+ cls.spark = cls._spark + super(FunctionsParityTests, cls).tearDownClass() + del os.environ["SPARK_REMOTE"] + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_add_months_function(self): + super().test_add_months_function() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_array_contains_function(self): + super().test_array_contains_function() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_array_repeat(self): + super().test_array_repeat() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_assert_true(self): + super().test_assert_true() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_basic_functions(self): + super().test_basic_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_between_function(self): + super().test_between_function() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_bit_length_function(self): + super().test_bit_length_function() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_collect_functions(self): + super().test_collect_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_date_add_function(self): + super().test_date_add_function() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_date_sub_function(self): + super().test_date_sub_function() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_datetime_functions(self): + super().test_datetime_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_dayofweek(self): + super().test_dayofweek() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_explode(self): + super().test_explode() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_expr(self): + super().test_expr() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_first_last_ignorenulls(self): + super().test_first_last_ignorenulls() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_function_parity(self): + super().test_function_parity() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_functions_broadcast(self): + super().test_functions_broadcast() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_higher_order_function_failures(self): + super().test_higher_order_function_failures() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_inline(self): + super().test_inline() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_input_file_name_reset_for_rdd(self): + super().test_input_file_name_reset_for_rdd() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_input_file_name_udf(self): + super().test_input_file_name_udf() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_inverse_trig_functions(self): + super().test_inverse_trig_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_least(self): + super().test_least() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_lit_list(self): + super().test_lit_list() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_lit_np_scalar(self): + super().test_lit_np_scalar() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_make_date(self): + super().test_make_date() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_map_functions(self): + 
super().test_map_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_math_functions(self): + super().test_math_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_ndarray_input(self): + super().test_ndarray_input() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_nested_higher_order_function(self): + super().test_nested_higher_order_function() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_np_scalar_input(self): + super().test_np_scalar_input() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_nth_value(self): + super().test_nth_value() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_octet_length_function(self): + super().test_octet_length_function() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_overlay(self): + super().test_overlay() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_percentile_approx(self): + super().test_percentile_approx() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_raise_error(self): + super().test_raise_error() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_regexp_replace(self): + super().test_regexp_replace() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_shiftleft(self): + super().test_shiftleft() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_shiftright(self): + super().test_shiftright() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_shiftrightunsigned(self): + super().test_shiftrightunsigned() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_slice(self): + super().test_slice() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_sort_with_nulls_order(self): + super().test_sort_with_nulls_order() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_sorting_functions_with_column(self): + super().test_sorting_functions_with_column() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_string_functions(self): + super().test_string_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_sum_distinct(self): + super().test_sum_distinct() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_window_functions(self): + super().test_window_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_window_functions_cumulative_sum(self): + super().test_window_functions_cumulative_sum() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_window_functions_without_partitionBy(self): + super().test_window_functions_without_partitionBy() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_window_time(self): + super().test_window_time() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_binary_math_function(self): + super().test_binary_math_function() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_corr(self): + super().test_corr() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_cov(self): + super().test_cov() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_crosstab(self): + super().test_crosstab() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_lit_day_time_interval(self): + super().test_lit_day_time_interval() + + @unittest.skip("Fails in Spark 
Connect, should enable.") + def test_rand_functions(self): + super().test_rand_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_reciprocal_trig_functions(self): + super().test_reciprocal_trig_functions() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_sampleby(self): + super().test_sampleby() + + @unittest.skip("Fails in Spark Connect, should enable.") + def test_approxQuantile(self): + super().test_approxQuantile() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_functions import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index a8321c58050..a9011ec1f95 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -53,7 +53,7 @@ from pyspark.testing.sqlutils import ( from pyspark.testing.utils import QuietTest -class DataFrameTests(ReusedSQLTestCase): +class DataFrameTestsMixin: def test_range(self): self.assertEqual(self.spark.range(1, 1).count(), 0) self.assertEqual(self.spark.range(1, 0, -1).count(), 1) @@ -1568,6 +1568,10 @@ class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): ) +class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase): + pass + + if __name__ == "__main__": from pyspark.sql.tests.test_dataframe import * # noqa: F401 diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 94cb3c4f1ee..7eb05034839 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -72,7 +72,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils from pyspark.testing.utils import have_numpy -class FunctionsTests(ReusedSQLTestCase): +class FunctionsTestsMixin: def test_function_parity(self): # This test compares the available list of functions in pyspark.sql.functions with those # available in the Scala/Java DataFrame API in org.apache.spark.sql.functions. 
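The enabling refactor on the classic side is visible just above: the test bodies move out of a `ReusedSQLTestCase` subclass into a bare mixin, so the same methods can bind either to the JVM-backed fixture or to the Connect fixture shown earlier. In miniature:
```
from pyspark.testing.sqlutils import ReusedSQLTestCase


class FunctionsTestsMixin:
    # Test bodies only: no base class, no fixtures. The methods assume a
    # `self.spark` supplied by whichever concrete class mixes them in.
    def test_range_count(self):  # illustrative stand-in for the real tests
        self.assertEqual(self.spark.range(10).count(), 10)


class FunctionsTests(FunctionsTestsMixin, ReusedSQLTestCase):
    pass  # the classic JVM-session suite, behaviorally unchanged
```
The rewrites that follow, replacing `self.sc.parallelize(...).toDF()` with `self.spark.createDataFrame(...)`, serve the same goal: Spark Connect exposes no `SparkContext` or RDD API, so shared test bodies must stay DataFrame-only.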
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 94cb3c4f1ee..7eb05034839 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -72,7 +72,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
 from pyspark.testing.utils import have_numpy
 
 
-class FunctionsTests(ReusedSQLTestCase):
+class FunctionsTestsMixin:
     def test_function_parity(self):
         # This test compares the available list of functions in pyspark.sql.functions with those
         # available in the Scala/Java DataFrame API in org.apache.spark.sql.functions.
@@ -130,8 +130,7 @@ class FunctionsTests(ReusedSQLTestCase):
             Row(a=1, intlist=[], mapfield={}),
             Row(a=1, intlist=None, mapfield=None),
         ]
-        rdd = self.sc.parallelize(d)
-        data = self.spark.createDataFrame(rdd)
+        data = self.spark.createDataFrame(d)
         result = data.select(explode(data.intlist).alias("a")).select("a").collect()
         self.assertEqual(result[0][0], 1)
@@ -194,22 +193,22 @@ class FunctionsTests(ReusedSQLTestCase):
     def test_corr(self):
         import math
 
-        df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
+        df = self.spark.createDataFrame([Row(a=i, b=math.sqrt(i)) for i in range(10)])
         corr = df.stat.corr("a", "b")
         self.assertTrue(abs(corr - 0.95734012) < 1e-6)
 
     def test_sampleby(self):
-        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(100)]).toDF()
+        df = self.spark.createDataFrame([Row(a=i, b=(i % 3)) for i in range(100)])
         sampled = df.stat.sampleBy("b", fractions={0: 0.5, 1: 0.5}, seed=0)
         self.assertTrue(sampled.count() == 35)
 
     def test_cov(self):
-        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
+        df = self.spark.createDataFrame([Row(a=i, b=2 * i) for i in range(10)])
         cov = df.stat.cov("a", "b")
         self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
 
     def test_crosstab(self):
-        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
+        df = self.spark.createDataFrame([Row(a=i % 3, b=i % 2) for i in range(1, 7)])
         ct = df.stat.crosstab("a", "b").collect()
         ct = sorted(ct, key=lambda x: x[0])
         for i, row in enumerate(ct):
@@ -218,7 +217,7 @@ class FunctionsTests(ReusedSQLTestCase):
             self.assertTrue(row[2], 1)
 
     def test_math_functions(self):
-        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
+        df = self.spark.createDataFrame([Row(a=i, b=2 * i) for i in range(10)])
         from pyspark.sql import functions
 
         SQLTestUtils.assert_close(
@@ -361,9 +360,9 @@ class FunctionsTests(ReusedSQLTestCase):
         self.assertEqual([Row(b=True), Row(b=False)], actual)
 
     def test_between_function(self):
-        df = self.sc.parallelize(
+        df = self.spark.createDataFrame(
             [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]
-        ).toDF()
+        )
         self.assertEqual(
             [Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)).collect()
         )
@@ -478,7 +477,7 @@ class FunctionsTests(ReusedSQLTestCase):
         self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
 
     def test_approxQuantile(self):
-        df = self.sc.parallelize([Row(a=i, b=i + 10) for i in range(10)]).toDF()
+        df = self.spark.createDataFrame([Row(a=i, b=i + 10) for i in range(10)])
         for f in ["a", "a"]:
             aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
             self.assertTrue(isinstance(aq, list))
@@ -1151,6 +1150,10 @@ class FunctionsTests(ReusedSQLTestCase):
         self.assertEqual(expected, actual["from_items"])
 
 
+class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
+    pass
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.sql.tests.test_functions import *  # noqa: F401
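The `sc.parallelize(...).toDF()` rewrites above matter because a Spark Connect client has no `SparkContext`, so any RDD detour fails remotely, while `createDataFrame` on a local list goes through the session and works in both modes. A small sketch of the equivalent call, assuming any active session:

```
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()

# Builds the DataFrame from local Rows without touching an RDD,
# so the same code also runs when `spark` is a Spark Connect client.
df = spark.createDataFrame([Row(a=i, b=2 * i) for i in range(10)])
print(df.count())
```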
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 6176ceedd99..05d6f7ebea9 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast
+import functools
+import os
+from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar
 
 import py4j
 from py4j.java_collections import JavaArray
@@ -29,6 +31,10 @@ from py4j.protocol import Py4JJavaError
 from pyspark import SparkContext
 from pyspark.find_spark_home import _find_spark_home
 
+if TYPE_CHECKING:
+    from pyspark.sql.session import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+
 has_numpy = False
 try:
     import numpy as np  # noqa: F401
@@ -38,9 +44,7 @@ except ImportError:
     pass
 
 
-if TYPE_CHECKING:
-    from pyspark.sql.session import SparkSession
-    from pyspark.sql.dataframe import DataFrame
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
 
 
 class CapturedException(Exception):
@@ -307,3 +311,52 @@ def is_timestamp_ntz_preferred() -> bool:
     """
     jvm = SparkContext._jvm
     return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred()
+
+
+def is_remote() -> bool:
+    """
+    Returns if the current running environment is for Spark Connect.
+    """
+    return "SPARK_REMOTE" in os.environ
+
+
+def try_remote_functions(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote():
+            from pyspark.sql.connect import functions
+
+            return getattr(functions, f.__name__)(*args, **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
+def try_remote_window(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+        # TODO(SPARK-41292): Support Window functions
+        if is_remote():
+            raise NotImplementedError()
+        return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
+def try_remote_observation(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+        # TODO(SPARK-41527): Add the support of Observation.
+        if is_remote():
+            raise NotImplementedError()
+        return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
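`try_remote_functions` above is a thin dispatch layer: the classic function body stays untouched, and when `SPARK_REMOTE` is set the call is forwarded by name to the Connect implementation. A self-contained sketch of the same mechanism outside pyspark (all names here are hypothetical stand-ins):

```
import functools
import os
from typing import Any, Callable, TypeVar, cast

FuncT = TypeVar("FuncT", bound=Callable[..., Any])


def _remote_greet(name: str) -> str:
    # Stand-in for the remote (Spark Connect) implementation.
    return "remote: hello, %s" % name


def try_remote(f: FuncT) -> FuncT:
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        # Same check as is_remote() above: an env var set by the launcher.
        if "SPARK_REMOTE" in os.environ:
            return _remote_greet(*args, **kwargs)
        return f(*args, **kwargs)

    return cast(FuncT, wrapped)


@try_remote
def greet(name: str) -> str:
    # Stand-in for the classic JVM-backed implementation.
    return "classic: hello, %s" % name


print(greet("spark"))  # implementation chosen by SPARK_REMOTE at call time
```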
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index ce91d6d4514..1b920a2c2fb 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -14,14 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 import sys
 from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union
 
+from py4j.java_gateway import JavaObject
+
 from pyspark import SparkContext
 from pyspark.sql.column import _to_seq, _to_java_column
-
-from py4j.java_gateway import JavaObject
+from pyspark.sql.utils import try_remote_window
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName, ColumnOrName_
@@ -70,6 +70,7 @@ class Window:
     currentRow: int = 0
 
     @staticmethod
+    @try_remote_window
     def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
         """
         Creates a :class:`WindowSpec` with the partitioning defined.
@@ -125,6 +126,7 @@ class Window:
         return WindowSpec(jspec)
 
     @staticmethod
+    @try_remote_window
     def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
         """
         Creates a :class:`WindowSpec` with the ordering defined.
@@ -180,6 +182,7 @@ class Window:
         return WindowSpec(jspec)
 
     @staticmethod
+    @try_remote_window
     def rowsBetween(start: int, end: int) -> "WindowSpec":
         """
         Creates a :class:`WindowSpec` with the frame boundaries defined,
@@ -263,6 +266,7 @@ class Window:
         return WindowSpec(jspec)
 
     @staticmethod
+    @try_remote_window
     def rangeBetween(start: int, end: int) -> "WindowSpec":
         """
         Creates a :class:`WindowSpec` with the frame boundaries defined,
@@ -362,6 +366,7 @@ class WindowSpec:
     def __init__(self, jspec: JavaObject) -> None:
         self._jspec = jspec
 
+    @try_remote_window
     def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
         """
         Defines the partitioning columns in a :class:`WindowSpec`.
@@ -375,6 +380,7 @@ class WindowSpec:
         """
         return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols)))
 
+    @try_remote_window
    def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
         """
         Defines the ordering columns in a :class:`WindowSpec`.
@@ -388,6 +394,7 @@ class WindowSpec:
         """
         return WindowSpec(self._jspec.orderBy(_to_java_cols(cols)))
 
+    @try_remote_window
     def rowsBetween(self, start: int, end: int) -> "WindowSpec":
         """
         Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
@@ -419,6 +426,7 @@ class WindowSpec:
             end = Window.unboundedFollowing
         return WindowSpec(self._jspec.rowsBetween(start, end))
 
+    @try_remote_window
     def rangeBetween(self, start: int, end: int) -> "WindowSpec":
         """
         Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index efc118b572d..21bdf35c6f3 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -43,9 +43,13 @@ if have_pandas and have_pyarrow and have_grpc:
     connect_jar = search_jar("connector/connect/server", "spark-connect-assembly-", "spark-connect")
 
     existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+    connect_url = "--remote sc://localhost"
     jars_args = "--jars %s" % connect_jar
     plugin_args = "--conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin"
     os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args])
+    os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join(
+        [connect_url, jars_args, plugin_args, existing_args]
+    )
 else:
     connect_not_compiled_message = (
         "Skipping all Spark Connect Python tests as the optional Spark Connect project was "

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org