This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a100e11936bc [SPARK-41811][PYTHON][CONNECT] Implement `SQLStringFormatter` with `WithRelations` a100e11936bc is described below commit a100e11936bcd92ac091abe94221c1b669811efa Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Apr 11 09:06:23 2024 +0900 [SPARK-41811][PYTHON][CONNECT] Implement `SQLStringFormatter` with `WithRelations` ### What changes were proposed in this pull request? Implement `SQLStringFormatter` for Python Client ### Why are the changes needed? for parity ### Does this PR introduce _any_ user-facing change? yes, new feature ``` In [1]: mydf = spark.range(10) In [2]: spark.sql("SELECT {col} FROM {mydf} WHERE id IN {x}", col=mydf.id, mydf=mydf, x=tuple(range(4))).show() +---+ | id| +---+ | 0| | 1| | 2| | 3| +---+ ``` ### How was this patch tested? enabled doc tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #45614 from zhengruifeng/connect_sql_str_fmt_with_relations. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../src/main/protobuf/spark/connect/commands.proto | 9 +- .../main/protobuf/spark/connect/relations.proto | 18 ++ .../sql/connect/planner/SparkConnectPlanner.scala | 149 ++++++++--- python/pyspark/sql/connect/plan.py | 56 +++- python/pyspark/sql/connect/proto/commands_pb2.py | 190 ++++++------- python/pyspark/sql/connect/proto/commands_pb2.pyi | 10 + python/pyspark/sql/connect/proto/relations_pb2.py | 298 +++++++++++---------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 48 ++++ python/pyspark/sql/connect/session.py | 23 +- python/pyspark/sql/{ => connect}/sql_formatter.py | 45 ++-- python/pyspark/sql/session.py | 4 +- python/pyspark/sql/sql_formatter.py | 4 +- .../pyspark/sql/tests/connect/test_connect_plan.py | 2 +- python/pyspark/sql/utils.py | 7 + 14 files changed, 539 insertions(+), 324 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto index e0ccf01fe92e..acff0a2089e9 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto @@ -61,7 +61,7 @@ message Command { // almost oblivious to the server-side behavior. message SqlCommand { // (Required) SQL Query. - string sql = 1; + string sql = 1 [deprecated=true]; // (Optional) A map of parameter names to literal expressions. map<string, Expression.Literal> args = 2 [deprecated=true]; @@ -71,11 +71,14 @@ message SqlCommand { // (Optional) A map of parameter names to expressions. // It cannot coexist with `pos_arguments`. - map<string, Expression> named_arguments = 4; + map<string, Expression> named_arguments = 4 [deprecated=true]; // (Optional) A sequence of expressions for positional parameters in the SQL query text. // It cannot coexist with `named_arguments`. 
- repeated Expression pos_arguments = 5; + repeated Expression pos_arguments = 5 [deprecated=true]; + + // (Optional) The relation that this SQL command will be built on. + Relation input = 6; } // A command that can create DataFrame global temp view or local temp view. diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index 4d4324ed340b..5cbe6459d226 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -75,6 +75,7 @@ message Relation { CommonInlineUserDefinedTableFunction common_inline_user_defined_table_function = 38; AsOfJoin as_of_join = 39; CommonInlineUserDefinedDataSource common_inline_user_defined_data_source = 40; + WithRelations with_relations = 41; // NA functions NAFill fill_na = 90; @@ -133,6 +134,23 @@ message SQL { repeated Expression pos_arguments = 5; } +// Relation of type [[WithRelations]]. +// +// This relation contains a root plan, and one or more references that are used by the root plan. +// There are two ways of referencing a relation, by name (through a subquery alias), or by plan_id +// (using RelationCommon.plan_id). +// +// This relation can be used to implement CTEs, describe DAGs, or to reduce tree depth. +message WithRelations { + // (Required) Plan at the root of the query tree. This plan is expected to contain one or more + // references. Those references get expanded later on by the engine. + Relation root = 1; + + // (Required) Plans referenced by the root plan. Relations in this list are also allowed to + // contain references to other relations in this list, as long as they do not form cycles. + repeated Relation references = 2; +} + // Relation that reads from a file / table or other data source. Does not have additional // inputs. 
message Read { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 0813b0a57671..690f2bfded3b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -43,7 +43,7 @@ import org.apache.spark.ml.{functions => MLFunctions} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest} import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession} import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro} -import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier} +import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder @@ -135,6 +135,9 @@ class SparkConnectPlanner( case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop) case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) + case proto.Relation.RelTypeCase.WITH_RELATIONS + if isValidSQLWithRefs(rel.getWithRelations) => + transformSqlWithRefs(rel.getWithRelations) case 
proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation) case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample) @@ -308,6 +311,13 @@ class SparkConnectPlanner( } } + private def transformSqlWithRefs(query: proto.WithRelations): LogicalPlan = { + if (!isValidSQLWithRefs(query)) { + throw InvalidPlanInput(s"$query is not a valid relation for SQL with references") + } + executeSQLWithRefs(query).logicalPlan + } + private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = { val aliasIdentifier = if (alias.getQualifierCount > 0) { @@ -2553,34 +2563,38 @@ class SparkConnectPlanner( } private def handleSqlCommand( - getSqlCommand: SqlCommand, + command: SqlCommand, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { - // Eagerly execute commands of the provided SQL string. - val args = getSqlCommand.getArgsMap - val namedArguments = getSqlCommand.getNamedArgumentsMap - val posArgs = getSqlCommand.getPosArgsList - val posArguments = getSqlCommand.getPosArgumentsList val tracker = executeHolder.eventsManager.createQueryPlanningTracker() - val df = if (!namedArguments.isEmpty) { - session.sql( - getSqlCommand.getSql, - namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))), - tracker) - } else if (!posArguments.isEmpty) { - session.sql( - getSqlCommand.getSql, - posArguments.asScala.map(e => Column(transformExpression(e))).toArray, - tracker) - } else if (!args.isEmpty) { - session.sql( - getSqlCommand.getSql, - args.asScala.toMap.transform((_, v) => transformLiteral(v)), - tracker) - } else if (!posArgs.isEmpty) { - session.sql(getSqlCommand.getSql, posArgs.asScala.map(transformLiteral).toArray, tracker) + + val relation = if (command.hasInput) { + command.getInput } else { - session.sql(getSqlCommand.getSql, Map.empty[String, Any], tracker) + // for backward compatibility + proto.Relation + .newBuilder() + .setSql( + proto.SQL + .newBuilder() + 
.setQuery(command.getSql) + .putAllArgs(command.getArgsMap) + .putAllNamedArguments(command.getNamedArgumentsMap) + .addAllPosArgs(command.getPosArgsList) + .addAllPosArguments(command.getPosArgumentsList) + .build()) + .build() } + + val df = relation.getRelTypeCase match { + case proto.Relation.RelTypeCase.SQL => + executeSQL(relation.getSql, tracker) + case proto.Relation.RelTypeCase.WITH_RELATIONS => + executeSQLWithRefs(relation.getWithRelations, tracker) + case other => + throw InvalidPlanInput( + s"SQL command expects either a SQL or a WithRelations, but got $other") + } + // Check if commands have been executed. val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult] val rows = df.logicalPlan match { @@ -2631,17 +2645,7 @@ class SparkConnectPlanner( } else { // No execution triggered for relations. Manually set ready tracker.setReadyForExecution() - result.setRelation( - proto.Relation - .newBuilder() - .setSql( - proto.SQL - .newBuilder() - .setQuery(getSqlCommand.getSql) - .putAllNamedArguments(getSqlCommand.getNamedArgumentsMap) - .addAllPosArguments(getSqlCommand.getPosArgumentsList) - .putAllArgs(getSqlCommand.getArgsMap) - .addAllPosArgs(getSqlCommand.getPosArgsList))) + result.setRelation(relation) } executeHolder.eventsManager.postFinished(Some(rows.size)) // Exactly one SQL Command Result Batch @@ -2666,6 +2670,79 @@ class SparkConnectPlanner( } } + private def isValidSQLWithRefs(query: proto.WithRelations): Boolean = { + query.getRoot.getRelTypeCase match { + case proto.Relation.RelTypeCase.SQL => + case _ => return false + } + if (query.getReferencesCount == 0) { + return false + } + query.getReferencesList.iterator().asScala.foreach { ref => + ref.getRelTypeCase match { + case proto.Relation.RelTypeCase.SUBQUERY_ALIAS => + case _ => return false + } + } + true + } + + private def executeSQLWithRefs( + query: proto.WithRelations, + tracker: QueryPlanningTracker = new QueryPlanningTracker) = { + if 
(!isValidSQLWithRefs(query)) { + throw InvalidPlanInput(s"$query is not a valid relation for SQL with references") + } + + // Eagerly execute commands of the provided SQL string, with given references. + val sql = query.getRoot.getSql + this.synchronized { + try { + query.getReferencesList.asScala.foreach { ref => + Dataset + .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput)) + .createOrReplaceTempView(ref.getSubqueryAlias.getAlias) + } + executeSQL(sql, tracker) + } finally { + // drop all temporary views + query.getReferencesList.asScala.foreach { ref => + session.catalog.dropTempView(ref.getSubqueryAlias.getAlias) + } + } + } + } + + private def executeSQL( + sql: proto.SQL, + tracker: QueryPlanningTracker = new QueryPlanningTracker) = { + // Eagerly execute commands of the provided SQL string. + val args = sql.getArgsMap + val namedArguments = sql.getNamedArgumentsMap + val posArgs = sql.getPosArgsList + val posArguments = sql.getPosArgumentsList + if (!namedArguments.isEmpty) { + session.sql( + sql.getQuery, + namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))), + tracker) + } else if (!posArguments.isEmpty) { + session.sql( + sql.getQuery, + posArguments.asScala.map(e => Column(transformExpression(e))).toArray, + tracker) + } else if (!args.isEmpty) { + session.sql( + sql.getQuery, + args.asScala.toMap.transform((_, v) => transformLiteral(v)), + tracker) + } else if (!posArgs.isEmpty) { + session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker) + } else { + session.sql(sql.getQuery, Map.empty[String, Any], tracker) + } + } + private def handleRegisterUserDefinedFunction( fun: proto.CommonInlineUserDefinedFunction): Unit = { fun.getFunctionCase match { diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 7751e42466aa..72b8372c8039 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -69,14 +69,17 @@ class 
LogicalPlan: def __init__(self, child: Optional["LogicalPlan"]) -> None: self._child = child + self._plan_id = LogicalPlan._fresh_plan_id() + @staticmethod + def _fresh_plan_id() -> int: plan_id: Optional[int] = None with LogicalPlan._lock: plan_id = LogicalPlan._nextPlanId LogicalPlan._nextPlanId += 1 assert plan_id is not None - self._plan_id = plan_id + return plan_id def _create_proto_relation(self) -> proto.Relation: plan = proto.Relation() @@ -1115,12 +1118,33 @@ class SubqueryAlias(LogicalPlan): return plan +class WithRelations(LogicalPlan): + def __init__( + self, + child: Optional["LogicalPlan"], + references: Sequence["LogicalPlan"], + ) -> None: + super().__init__(child) + assert references is not None and len(references) > 0 + assert all(isinstance(ref, LogicalPlan) for ref in references) + self._references = references + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + plan = self._create_proto_relation() + if self._child is not None: + plan.with_relations.root.CopyFrom(self._child.plan(session)) + for ref in self._references: + plan.with_relations.references.append(ref.plan(session)) + return plan + + class SQL(LogicalPlan): def __init__( self, query: str, args: Optional[List[Column]] = None, named_args: Optional[Dict[str, Column]] = None, + views: Optional[Sequence[SubqueryAlias]] = None, ) -> None: super().__init__(None) @@ -1134,9 +1158,17 @@ class SQL(LogicalPlan): assert isinstance(k, str) assert isinstance(arg, Column) + if views is not None: + assert isinstance(views, List) + assert all(isinstance(v, SubqueryAlias) for v in views) + if len(views) > 0: + # reserved plan id for WithRelations + self._plan_id_with_rel = LogicalPlan._fresh_plan_id() + self._query = query self._args = args self._named_args = named_args + self._views = views def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() @@ -1147,17 +1179,25 @@ class SQL(LogicalPlan): if self._named_args is not None and 
len(self._named_args) > 0: for k, arg in self._named_args.items(): plan.sql.named_arguments[k].CopyFrom(arg.to_plan(session)) + + if self._views is not None and len(self._views) > 0: + # build new plan like + # with_relations [id 10] + # root: sql [id 9] + # reference: + # view#1: [id 8] + # view#2: [id 5] + sql_plan = plan + plan = proto.Relation() + plan.common.plan_id = self._plan_id_with_rel + plan.with_relations.root.CopyFrom(sql_plan) + plan.with_relations.references.extend([v.plan(session) for v in self._views]) + return plan def command(self, session: "SparkConnectClient") -> proto.Command: cmd = proto.Command() - cmd.sql_command.sql = self._query - - if self._args is not None and len(self._args) > 0: - cmd.sql_command.pos_arguments.extend([arg.to_plan(session) for arg in self._args]) - if self._named_args is not None and len(self._named_args) > 0: - for k, arg in self._named_args.items(): - cmd.sql_command.named_arguments[k].CopyFrom(arg.to_plan(session)) + cmd.sql_command.input.CopyFrom(self.plan(session)) return cmd diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index eba96d28eb68..118d8200393a 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xd5\n\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18 [...] 
+ b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xd5\n\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18 [...] ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -49,10 +49,16 @@ if _descriptor._USE_C_DESCRIPTORS == False: _SQLCOMMAND_ARGSENTRY._serialized_options = b"8\001" _SQLCOMMAND_NAMEDARGUMENTSENTRY._options = None _SQLCOMMAND_NAMEDARGUMENTSENTRY._serialized_options = b"8\001" + _SQLCOMMAND.fields_by_name["sql"]._options = None + _SQLCOMMAND.fields_by_name["sql"]._serialized_options = b"\030\001" _SQLCOMMAND.fields_by_name["args"]._options = None _SQLCOMMAND.fields_by_name["args"]._serialized_options = b"\030\001" _SQLCOMMAND.fields_by_name["pos_args"]._options = None _SQLCOMMAND.fields_by_name["pos_args"]._serialized_options = b"\030\001" + _SQLCOMMAND.fields_by_name["named_arguments"]._options = None + _SQLCOMMAND.fields_by_name["named_arguments"]._serialized_options = b"\030\001" + _SQLCOMMAND.fields_by_name["pos_arguments"]._options = None + _SQLCOMMAND.fields_by_name["pos_arguments"]._serialized_options = b"\030\001" _WRITEOPERATION_OPTIONSENTRY._options = None _WRITEOPERATION_OPTIONSENTRY._serialized_options = b"8\001" _WRITEOPERATIONV2_OPTIONSENTRY._options = None @@ -63,98 +69,98 @@ if _descriptor._USE_C_DESCRIPTORS == False: _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_options = b"8\001" _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._options = None _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_options = b"8\001" - _STREAMINGQUERYEVENTTYPE._serialized_start = 10021 - _STREAMINGQUERYEVENTTYPE._serialized_end = 10154 + 
_STREAMINGQUERYEVENTTYPE._serialized_start = 10080 + _STREAMINGQUERYEVENTTYPE._serialized_end = 10213 _COMMAND._serialized_start = 167 _COMMAND._serialized_end = 1532 _SQLCOMMAND._serialized_start = 1535 - _SQLCOMMAND._serialized_end = 2030 - _SQLCOMMAND_ARGSENTRY._serialized_start = 1846 - _SQLCOMMAND_ARGSENTRY._serialized_end = 1936 - _SQLCOMMAND_NAMEDARGUMENTSENTRY._serialized_start = 1938 - _SQLCOMMAND_NAMEDARGUMENTSENTRY._serialized_end = 2030 - _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 2033 - _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 2183 - _WRITEOPERATION._serialized_start = 2186 - _WRITEOPERATION._serialized_end = 3284 - _WRITEOPERATION_OPTIONSENTRY._serialized_start = 2708 - _WRITEOPERATION_OPTIONSENTRY._serialized_end = 2766 - _WRITEOPERATION_SAVETABLE._serialized_start = 2769 - _WRITEOPERATION_SAVETABLE._serialized_end = 3027 - _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 2903 - _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 3027 - _WRITEOPERATION_BUCKETBY._serialized_start = 3029 - _WRITEOPERATION_BUCKETBY._serialized_end = 3120 - _WRITEOPERATION_SAVEMODE._serialized_start = 3123 - _WRITEOPERATION_SAVEMODE._serialized_end = 3260 - _WRITEOPERATIONV2._serialized_start = 3287 - _WRITEOPERATIONV2._serialized_end = 4147 - _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 2708 - _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 2766 - _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 3906 - _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 3972 - _WRITEOPERATIONV2_MODE._serialized_start = 3975 - _WRITEOPERATIONV2_MODE._serialized_end = 4134 - _WRITESTREAMOPERATIONSTART._serialized_start = 4150 - _WRITESTREAMOPERATIONSTART._serialized_end = 4950 - _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_start = 2708 - _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_end = 2766 - _STREAMINGFOREACHFUNCTION._serialized_start = 4953 - _STREAMINGFOREACHFUNCTION._serialized_end = 5132 - 
_WRITESTREAMOPERATIONSTARTRESULT._serialized_start = 5135 - _WRITESTREAMOPERATIONSTARTRESULT._serialized_end = 5347 - _STREAMINGQUERYINSTANCEID._serialized_start = 5349 - _STREAMINGQUERYINSTANCEID._serialized_end = 5414 - _STREAMINGQUERYCOMMAND._serialized_start = 5417 - _STREAMINGQUERYCOMMAND._serialized_end = 6049 - _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 5916 - _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 5960 - _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_start = 5962 - _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_end = 6038 - _STREAMINGQUERYCOMMANDRESULT._serialized_start = 6052 - _STREAMINGQUERYCOMMANDRESULT._serialized_end = 7193 - _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 6635 - _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 6805 - _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 6807 - _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 6879 - _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 6881 - _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 6920 - _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_start = 6923 - _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_end = 7120 - _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 7122 - _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 7178 - _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 7196 - _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 8025 - _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 7727 - _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 7806 - _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 7809 - _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 8014 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 8028 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 9104 
- _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 8636 - _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 8763 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 8765 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 8880 - _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 8882 - _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 8941 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 8943 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 9018 - _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 9020 - _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 9089 - _STREAMINGQUERYLISTENERBUSCOMMAND._serialized_start = 9107 - _STREAMINGQUERYLISTENERBUSCOMMAND._serialized_end = 9280 - _STREAMINGQUERYLISTENEREVENT._serialized_start = 9283 - _STREAMINGQUERYLISTENEREVENT._serialized_end = 9414 - _STREAMINGQUERYLISTENEREVENTSRESULT._serialized_start = 9417 - _STREAMINGQUERYLISTENEREVENTSRESULT._serialized_end = 9621 - _GETRESOURCESCOMMAND._serialized_start = 9623 - _GETRESOURCESCOMMAND._serialized_end = 9644 - _GETRESOURCESCOMMANDRESULT._serialized_start = 9647 - _GETRESOURCESCOMMANDRESULT._serialized_end = 9859 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 9763 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 9859 - _CREATERESOURCEPROFILECOMMAND._serialized_start = 9861 - _CREATERESOURCEPROFILECOMMAND._serialized_end = 9949 - _CREATERESOURCEPROFILECOMMANDRESULT._serialized_start = 9951 - _CREATERESOURCEPROFILECOMMANDRESULT._serialized_end = 10018 + _SQLCOMMAND._serialized_end = 2089 + _SQLCOMMAND_ARGSENTRY._serialized_start = 1905 + _SQLCOMMAND_ARGSENTRY._serialized_end = 1995 + _SQLCOMMAND_NAMEDARGUMENTSENTRY._serialized_start = 1997 + 
_SQLCOMMAND_NAMEDARGUMENTSENTRY._serialized_end = 2089 + _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 2092 + _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 2242 + _WRITEOPERATION._serialized_start = 2245 + _WRITEOPERATION._serialized_end = 3343 + _WRITEOPERATION_OPTIONSENTRY._serialized_start = 2767 + _WRITEOPERATION_OPTIONSENTRY._serialized_end = 2825 + _WRITEOPERATION_SAVETABLE._serialized_start = 2828 + _WRITEOPERATION_SAVETABLE._serialized_end = 3086 + _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 2962 + _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 3086 + _WRITEOPERATION_BUCKETBY._serialized_start = 3088 + _WRITEOPERATION_BUCKETBY._serialized_end = 3179 + _WRITEOPERATION_SAVEMODE._serialized_start = 3182 + _WRITEOPERATION_SAVEMODE._serialized_end = 3319 + _WRITEOPERATIONV2._serialized_start = 3346 + _WRITEOPERATIONV2._serialized_end = 4206 + _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 2767 + _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 2825 + _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 3965 + _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 4031 + _WRITEOPERATIONV2_MODE._serialized_start = 4034 + _WRITEOPERATIONV2_MODE._serialized_end = 4193 + _WRITESTREAMOPERATIONSTART._serialized_start = 4209 + _WRITESTREAMOPERATIONSTART._serialized_end = 5009 + _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_start = 2767 + _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_end = 2825 + _STREAMINGFOREACHFUNCTION._serialized_start = 5012 + _STREAMINGFOREACHFUNCTION._serialized_end = 5191 + _WRITESTREAMOPERATIONSTARTRESULT._serialized_start = 5194 + _WRITESTREAMOPERATIONSTARTRESULT._serialized_end = 5406 + _STREAMINGQUERYINSTANCEID._serialized_start = 5408 + _STREAMINGQUERYINSTANCEID._serialized_end = 5473 + _STREAMINGQUERYCOMMAND._serialized_start = 5476 + _STREAMINGQUERYCOMMAND._serialized_end = 6108 + _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 5975 + 
_STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 6019 + _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_start = 6021 + _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_end = 6097 + _STREAMINGQUERYCOMMANDRESULT._serialized_start = 6111 + _STREAMINGQUERYCOMMANDRESULT._serialized_end = 7252 + _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 6694 + _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 6864 + _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 6866 + _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 6938 + _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 6940 + _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 6979 + _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_start = 6982 + _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_end = 7179 + _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 7181 + _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 7237 + _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 7255 + _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 8084 + _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 7786 + _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 7865 + _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 7868 + _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 8073 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 8087 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 9163 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 8695 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 8822 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 8824 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 8939 + 
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 8941 + _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 9000 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 9002 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 9077 + _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 9079 + _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 9148 + _STREAMINGQUERYLISTENERBUSCOMMAND._serialized_start = 9166 + _STREAMINGQUERYLISTENERBUSCOMMAND._serialized_end = 9339 + _STREAMINGQUERYLISTENEREVENT._serialized_start = 9342 + _STREAMINGQUERYLISTENEREVENT._serialized_end = 9473 + _STREAMINGQUERYLISTENEREVENTSRESULT._serialized_start = 9476 + _STREAMINGQUERYLISTENEREVENTSRESULT._serialized_end = 9680 + _GETRESOURCESCOMMAND._serialized_start = 9682 + _GETRESOURCESCOMMAND._serialized_end = 9703 + _GETRESOURCESCOMMANDRESULT._serialized_start = 9706 + _GETRESOURCESCOMMANDRESULT._serialized_end = 9918 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 9822 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 9918 + _CREATERESOURCEPROFILECOMMAND._serialized_start = 9920 + _CREATERESOURCEPROFILECOMMAND._serialized_end = 10008 + _CREATERESOURCEPROFILECOMMANDRESULT._serialized_start = 10010 + _CREATERESOURCEPROFILECOMMANDRESULT._serialized_end = 10077 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index b57a2a6c4d68..f86ae653508e 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ -313,6 +313,7 @@ class SqlCommand(google.protobuf.message.Message): POS_ARGS_FIELD_NUMBER: builtins.int NAMED_ARGUMENTS_FIELD_NUMBER: builtins.int POS_ARGUMENTS_FIELD_NUMBER: builtins.int + INPUT_FIELD_NUMBER: 
builtins.int sql: builtins.str """(Required) SQL Query.""" @property @@ -347,6 +348,9 @@ class SqlCommand(google.protobuf.message.Message): """(Optional) A sequence of expressions for positional parameters in the SQL query text. It cannot coexist with `named_arguments`. """ + @property + def input(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: + """(Optional) The relation that this SQL command will be built on.""" def __init__( self, *, @@ -367,12 +371,18 @@ class SqlCommand(google.protobuf.message.Message): pyspark.sql.connect.proto.expressions_pb2.Expression ] | None = ..., + input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["input", b"input"] + ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ "args", b"args", + "input", + b"input", "named_arguments", b"named_arguments", "pos_args", diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index ff01d5fe346a..82208d00485b 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import common_pb2 as spark_dot_connect_dot_common DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\xa2\x1a\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.Project [...] 
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\xe9\x1a\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.Project [...] ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -65,151 +65,153 @@ if _descriptor._USE_C_DESCRIPTORS == False: _PARSE_OPTIONSENTRY._options = None _PARSE_OPTIONSENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 193 - _RELATION._serialized_end = 3555 - _UNKNOWN._serialized_start = 3557 - _UNKNOWN._serialized_end = 3566 - _RELATIONCOMMON._serialized_start = 3568 - _RELATIONCOMMON._serialized_end = 3659 - _SQL._serialized_start = 3662 - _SQL._serialized_end = 4140 - _SQL_ARGSENTRY._serialized_start = 3956 - _SQL_ARGSENTRY._serialized_end = 4046 - _SQL_NAMEDARGUMENTSENTRY._serialized_start = 4048 - _SQL_NAMEDARGUMENTSENTRY._serialized_end = 4140 - _READ._serialized_start = 4143 - _READ._serialized_end = 4806 - _READ_NAMEDTABLE._serialized_start = 4321 - _READ_NAMEDTABLE._serialized_end = 4513 - _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 4455 - _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4513 - _READ_DATASOURCE._serialized_start = 4516 - _READ_DATASOURCE._serialized_end = 4793 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 4455 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4513 - _PROJECT._serialized_start = 4808 - _PROJECT._serialized_end = 4925 - _FILTER._serialized_start = 4927 - _FILTER._serialized_end = 5039 - _JOIN._serialized_start = 5042 - _JOIN._serialized_end = 5703 - _JOIN_JOINDATATYPE._serialized_start = 5381 - _JOIN_JOINDATATYPE._serialized_end = 5473 - _JOIN_JOINTYPE._serialized_start = 5476 - 
_JOIN_JOINTYPE._serialized_end = 5684 - _SETOPERATION._serialized_start = 5706 - _SETOPERATION._serialized_end = 6185 - _SETOPERATION_SETOPTYPE._serialized_start = 6022 - _SETOPERATION_SETOPTYPE._serialized_end = 6136 - _LIMIT._serialized_start = 6187 - _LIMIT._serialized_end = 6263 - _OFFSET._serialized_start = 6265 - _OFFSET._serialized_end = 6344 - _TAIL._serialized_start = 6346 - _TAIL._serialized_end = 6421 - _AGGREGATE._serialized_start = 6424 - _AGGREGATE._serialized_end = 7190 - _AGGREGATE_PIVOT._serialized_start = 6839 - _AGGREGATE_PIVOT._serialized_end = 6950 - _AGGREGATE_GROUPINGSETS._serialized_start = 6952 - _AGGREGATE_GROUPINGSETS._serialized_end = 7028 - _AGGREGATE_GROUPTYPE._serialized_start = 7031 - _AGGREGATE_GROUPTYPE._serialized_end = 7190 - _SORT._serialized_start = 7193 - _SORT._serialized_end = 7353 - _DROP._serialized_start = 7356 - _DROP._serialized_end = 7497 - _DEDUPLICATE._serialized_start = 7500 - _DEDUPLICATE._serialized_end = 7740 - _LOCALRELATION._serialized_start = 7742 - _LOCALRELATION._serialized_end = 7831 - _CACHEDLOCALRELATION._serialized_start = 7833 - _CACHEDLOCALRELATION._serialized_end = 7905 - _CACHEDREMOTERELATION._serialized_start = 7907 - _CACHEDREMOTERELATION._serialized_end = 7962 - _SAMPLE._serialized_start = 7965 - _SAMPLE._serialized_end = 8238 - _RANGE._serialized_start = 8241 - _RANGE._serialized_end = 8386 - _SUBQUERYALIAS._serialized_start = 8388 - _SUBQUERYALIAS._serialized_end = 8502 - _REPARTITION._serialized_start = 8505 - _REPARTITION._serialized_end = 8647 - _SHOWSTRING._serialized_start = 8650 - _SHOWSTRING._serialized_end = 8792 - _HTMLSTRING._serialized_start = 8794 - _HTMLSTRING._serialized_end = 8908 - _STATSUMMARY._serialized_start = 8910 - _STATSUMMARY._serialized_end = 9002 - _STATDESCRIBE._serialized_start = 9004 - _STATDESCRIBE._serialized_end = 9085 - _STATCROSSTAB._serialized_start = 9087 - _STATCROSSTAB._serialized_end = 9188 - _STATCOV._serialized_start = 9190 - _STATCOV._serialized_end = 
9286 - _STATCORR._serialized_start = 9289 - _STATCORR._serialized_end = 9426 - _STATAPPROXQUANTILE._serialized_start = 9429 - _STATAPPROXQUANTILE._serialized_end = 9593 - _STATFREQITEMS._serialized_start = 9595 - _STATFREQITEMS._serialized_end = 9720 - _STATSAMPLEBY._serialized_start = 9723 - _STATSAMPLEBY._serialized_end = 10032 - _STATSAMPLEBY_FRACTION._serialized_start = 9924 - _STATSAMPLEBY_FRACTION._serialized_end = 10023 - _NAFILL._serialized_start = 10035 - _NAFILL._serialized_end = 10169 - _NADROP._serialized_start = 10172 - _NADROP._serialized_end = 10306 - _NAREPLACE._serialized_start = 10309 - _NAREPLACE._serialized_end = 10605 - _NAREPLACE_REPLACEMENT._serialized_start = 10464 - _NAREPLACE_REPLACEMENT._serialized_end = 10605 - _TODF._serialized_start = 10607 - _TODF._serialized_end = 10695 - _WITHCOLUMNSRENAMED._serialized_start = 10698 - _WITHCOLUMNSRENAMED._serialized_end = 11080 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10942 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 11009 - _WITHCOLUMNSRENAMED_RENAME._serialized_start = 11011 - _WITHCOLUMNSRENAMED_RENAME._serialized_end = 11080 - _WITHCOLUMNS._serialized_start = 11082 - _WITHCOLUMNS._serialized_end = 11201 - _WITHWATERMARK._serialized_start = 11204 - _WITHWATERMARK._serialized_end = 11338 - _HINT._serialized_start = 11341 - _HINT._serialized_end = 11473 - _UNPIVOT._serialized_start = 11476 - _UNPIVOT._serialized_end = 11803 - _UNPIVOT_VALUES._serialized_start = 11733 - _UNPIVOT_VALUES._serialized_end = 11792 - _TOSCHEMA._serialized_start = 11805 - _TOSCHEMA._serialized_end = 11911 - _REPARTITIONBYEXPRESSION._serialized_start = 11914 - _REPARTITIONBYEXPRESSION._serialized_end = 12117 - _MAPPARTITIONS._serialized_start = 12120 - _MAPPARTITIONS._serialized_end = 12352 - _GROUPMAP._serialized_start = 12355 - _GROUPMAP._serialized_end = 12990 - _COGROUPMAP._serialized_start = 12993 - _COGROUPMAP._serialized_end = 13519 - _APPLYINPANDASWITHSTATE._serialized_start 
= 13522 - _APPLYINPANDASWITHSTATE._serialized_end = 13879 - _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13882 - _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 14126 - _PYTHONUDTF._serialized_start = 14129 - _PYTHONUDTF._serialized_end = 14306 - _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_start = 14309 - _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_end = 14460 - _PYTHONDATASOURCE._serialized_start = 14462 - _PYTHONDATASOURCE._serialized_end = 14537 - _COLLECTMETRICS._serialized_start = 14540 - _COLLECTMETRICS._serialized_end = 14676 - _PARSE._serialized_start = 14679 - _PARSE._serialized_end = 15067 - _PARSE_OPTIONSENTRY._serialized_start = 4455 - _PARSE_OPTIONSENTRY._serialized_end = 4513 - _PARSE_PARSEFORMAT._serialized_start = 14968 - _PARSE_PARSEFORMAT._serialized_end = 15056 - _ASOFJOIN._serialized_start = 15070 - _ASOFJOIN._serialized_end = 15545 + _RELATION._serialized_end = 3626 + _UNKNOWN._serialized_start = 3628 + _UNKNOWN._serialized_end = 3637 + _RELATIONCOMMON._serialized_start = 3639 + _RELATIONCOMMON._serialized_end = 3730 + _SQL._serialized_start = 3733 + _SQL._serialized_end = 4211 + _SQL_ARGSENTRY._serialized_start = 4027 + _SQL_ARGSENTRY._serialized_end = 4117 + _SQL_NAMEDARGUMENTSENTRY._serialized_start = 4119 + _SQL_NAMEDARGUMENTSENTRY._serialized_end = 4211 + _WITHRELATIONS._serialized_start = 4213 + _WITHRELATIONS._serialized_end = 4330 + _READ._serialized_start = 4333 + _READ._serialized_end = 4996 + _READ_NAMEDTABLE._serialized_start = 4511 + _READ_NAMEDTABLE._serialized_end = 4703 + _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 4645 + _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4703 + _READ_DATASOURCE._serialized_start = 4706 + _READ_DATASOURCE._serialized_end = 4983 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 4645 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4703 + _PROJECT._serialized_start = 4998 + _PROJECT._serialized_end = 5115 + _FILTER._serialized_start = 5117 + _FILTER._serialized_end 
= 5229 + _JOIN._serialized_start = 5232 + _JOIN._serialized_end = 5893 + _JOIN_JOINDATATYPE._serialized_start = 5571 + _JOIN_JOINDATATYPE._serialized_end = 5663 + _JOIN_JOINTYPE._serialized_start = 5666 + _JOIN_JOINTYPE._serialized_end = 5874 + _SETOPERATION._serialized_start = 5896 + _SETOPERATION._serialized_end = 6375 + _SETOPERATION_SETOPTYPE._serialized_start = 6212 + _SETOPERATION_SETOPTYPE._serialized_end = 6326 + _LIMIT._serialized_start = 6377 + _LIMIT._serialized_end = 6453 + _OFFSET._serialized_start = 6455 + _OFFSET._serialized_end = 6534 + _TAIL._serialized_start = 6536 + _TAIL._serialized_end = 6611 + _AGGREGATE._serialized_start = 6614 + _AGGREGATE._serialized_end = 7380 + _AGGREGATE_PIVOT._serialized_start = 7029 + _AGGREGATE_PIVOT._serialized_end = 7140 + _AGGREGATE_GROUPINGSETS._serialized_start = 7142 + _AGGREGATE_GROUPINGSETS._serialized_end = 7218 + _AGGREGATE_GROUPTYPE._serialized_start = 7221 + _AGGREGATE_GROUPTYPE._serialized_end = 7380 + _SORT._serialized_start = 7383 + _SORT._serialized_end = 7543 + _DROP._serialized_start = 7546 + _DROP._serialized_end = 7687 + _DEDUPLICATE._serialized_start = 7690 + _DEDUPLICATE._serialized_end = 7930 + _LOCALRELATION._serialized_start = 7932 + _LOCALRELATION._serialized_end = 8021 + _CACHEDLOCALRELATION._serialized_start = 8023 + _CACHEDLOCALRELATION._serialized_end = 8095 + _CACHEDREMOTERELATION._serialized_start = 8097 + _CACHEDREMOTERELATION._serialized_end = 8152 + _SAMPLE._serialized_start = 8155 + _SAMPLE._serialized_end = 8428 + _RANGE._serialized_start = 8431 + _RANGE._serialized_end = 8576 + _SUBQUERYALIAS._serialized_start = 8578 + _SUBQUERYALIAS._serialized_end = 8692 + _REPARTITION._serialized_start = 8695 + _REPARTITION._serialized_end = 8837 + _SHOWSTRING._serialized_start = 8840 + _SHOWSTRING._serialized_end = 8982 + _HTMLSTRING._serialized_start = 8984 + _HTMLSTRING._serialized_end = 9098 + _STATSUMMARY._serialized_start = 9100 + _STATSUMMARY._serialized_end = 9192 + 
_STATDESCRIBE._serialized_start = 9194 + _STATDESCRIBE._serialized_end = 9275 + _STATCROSSTAB._serialized_start = 9277 + _STATCROSSTAB._serialized_end = 9378 + _STATCOV._serialized_start = 9380 + _STATCOV._serialized_end = 9476 + _STATCORR._serialized_start = 9479 + _STATCORR._serialized_end = 9616 + _STATAPPROXQUANTILE._serialized_start = 9619 + _STATAPPROXQUANTILE._serialized_end = 9783 + _STATFREQITEMS._serialized_start = 9785 + _STATFREQITEMS._serialized_end = 9910 + _STATSAMPLEBY._serialized_start = 9913 + _STATSAMPLEBY._serialized_end = 10222 + _STATSAMPLEBY_FRACTION._serialized_start = 10114 + _STATSAMPLEBY_FRACTION._serialized_end = 10213 + _NAFILL._serialized_start = 10225 + _NAFILL._serialized_end = 10359 + _NADROP._serialized_start = 10362 + _NADROP._serialized_end = 10496 + _NAREPLACE._serialized_start = 10499 + _NAREPLACE._serialized_end = 10795 + _NAREPLACE_REPLACEMENT._serialized_start = 10654 + _NAREPLACE_REPLACEMENT._serialized_end = 10795 + _TODF._serialized_start = 10797 + _TODF._serialized_end = 10885 + _WITHCOLUMNSRENAMED._serialized_start = 10888 + _WITHCOLUMNSRENAMED._serialized_end = 11270 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 11132 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 11199 + _WITHCOLUMNSRENAMED_RENAME._serialized_start = 11201 + _WITHCOLUMNSRENAMED_RENAME._serialized_end = 11270 + _WITHCOLUMNS._serialized_start = 11272 + _WITHCOLUMNS._serialized_end = 11391 + _WITHWATERMARK._serialized_start = 11394 + _WITHWATERMARK._serialized_end = 11528 + _HINT._serialized_start = 11531 + _HINT._serialized_end = 11663 + _UNPIVOT._serialized_start = 11666 + _UNPIVOT._serialized_end = 11993 + _UNPIVOT_VALUES._serialized_start = 11923 + _UNPIVOT_VALUES._serialized_end = 11982 + _TOSCHEMA._serialized_start = 11995 + _TOSCHEMA._serialized_end = 12101 + _REPARTITIONBYEXPRESSION._serialized_start = 12104 + _REPARTITIONBYEXPRESSION._serialized_end = 12307 + _MAPPARTITIONS._serialized_start = 12310 + 
_MAPPARTITIONS._serialized_end = 12542 + _GROUPMAP._serialized_start = 12545 + _GROUPMAP._serialized_end = 13180 + _COGROUPMAP._serialized_start = 13183 + _COGROUPMAP._serialized_end = 13709 + _APPLYINPANDASWITHSTATE._serialized_start = 13712 + _APPLYINPANDASWITHSTATE._serialized_end = 14069 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 14072 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 14316 + _PYTHONUDTF._serialized_start = 14319 + _PYTHONUDTF._serialized_end = 14496 + _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_start = 14499 + _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_end = 14650 + _PYTHONDATASOURCE._serialized_start = 14652 + _PYTHONDATASOURCE._serialized_end = 14727 + _COLLECTMETRICS._serialized_start = 14730 + _COLLECTMETRICS._serialized_end = 14866 + _PARSE._serialized_start = 14869 + _PARSE._serialized_end = 15257 + _PARSE_OPTIONSENTRY._serialized_start = 4645 + _PARSE_OPTIONSENTRY._serialized_end = 4703 + _PARSE_PARSEFORMAT._serialized_start = 15158 + _PARSE_PARSEFORMAT._serialized_end = 15246 + _ASOFJOIN._serialized_start = 15260 + _ASOFJOIN._serialized_end = 15735 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index db9609eebb85..5dfb47da67a9 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -102,6 +102,7 @@ class Relation(google.protobuf.message.Message): COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION_FIELD_NUMBER: builtins.int AS_OF_JOIN_FIELD_NUMBER: builtins.int COMMON_INLINE_USER_DEFINED_DATA_SOURCE_FIELD_NUMBER: builtins.int + WITH_RELATIONS_FIELD_NUMBER: builtins.int FILL_NA_FIELD_NUMBER: builtins.int DROP_NA_FIELD_NUMBER: builtins.int REPLACE_FIELD_NUMBER: builtins.int @@ -201,6 +202,8 @@ class Relation(google.protobuf.message.Message): self, ) -> global___CommonInlineUserDefinedDataSource: ... 
@property + def with_relations(self) -> global___WithRelations: ... + @property def fill_na(self) -> global___NAFill: """NA functions""" @property @@ -279,6 +282,7 @@ class Relation(google.protobuf.message.Message): as_of_join: global___AsOfJoin | None = ..., common_inline_user_defined_data_source: global___CommonInlineUserDefinedDataSource | None = ..., + with_relations: global___WithRelations | None = ..., fill_na: global___NAFill | None = ..., drop_na: global___NADrop | None = ..., replace: global___NAReplace | None = ..., @@ -405,6 +409,8 @@ class Relation(google.protobuf.message.Message): b"with_columns", "with_columns_renamed", b"with_columns_renamed", + "with_relations", + b"with_relations", "with_watermark", b"with_watermark", ], @@ -520,6 +526,8 @@ class Relation(google.protobuf.message.Message): b"with_columns", "with_columns_renamed", b"with_columns_renamed", + "with_relations", + b"with_relations", "with_watermark", b"with_watermark", ], @@ -567,6 +575,7 @@ class Relation(google.protobuf.message.Message): "common_inline_user_defined_table_function", "as_of_join", "common_inline_user_defined_data_source", + "with_relations", "fill_na", "drop_na", "replace", @@ -755,6 +764,45 @@ class SQL(google.protobuf.message.Message): global___SQL = SQL +class WithRelations(google.protobuf.message.Message): + """Relation of type [[WithRelations]]. + + This relation contains a root plan, and one or more references that are used by the root plan. + There are two ways of referencing a relation, by name (through a subquery alias), or by plan_id + (using RelationCommon.plan_id). + + This relation can be used to implement CTEs, describe DAGs, or to reduce tree depth. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ROOT_FIELD_NUMBER: builtins.int + REFERENCES_FIELD_NUMBER: builtins.int + @property + def root(self) -> global___Relation: + """(Required) Plan at the root of the query tree. This plan is expected to contain one or more + references. 
Those references get expanded later on by the engine. + """ + @property + def references( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Relation]: + """(Required) Plans referenced by the root plan. Relations in this list are also allowed to + contain references to other relations in this list, as long they do not form cycles. + """ + def __init__( + self, + *, + root: global___Relation | None = ..., + references: collections.abc.Iterable[global___Relation] | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["root", b"root"]) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["references", b"references", "root", b"root"] + ) -> None: ... + +global___WithRelations = WithRelations + class Read(google.protobuf.message.Message): """Relation that reads from a file / table or other data source. Does not have additional inputs. diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 40a8076698bf..07fe8a62f082 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -60,6 +60,7 @@ from pyspark.sql.connect.plan import ( CachedLocalRelation, CachedRelation, CachedRemoteRelation, + SubqueryAlias, ) from pyspark.sql.connect.functions import builtin as F from pyspark.sql.connect.profiler import ProfilerCollector @@ -619,7 +620,12 @@ class SparkSession: createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__ - def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> "DataFrame": + def sql( + self, + sqlQuery: str, + args: Optional[Union[Dict[str, Any], List]] = None, + **kwargs: Any, + ) -> "DataFrame": _args = [] _named_args = {} if args is not None: @@ -635,7 +641,17 @@ class SparkSession: message_parameters={"arg_name": "args", "arg_type": type(args).__name__}, ) - cmd = SQL(sqlQuery, _args, _named_args) + _views: List[SubqueryAlias] = [] + if 
len(kwargs) > 0: + from pyspark.sql.connect.sql_formatter import SQLStringFormatter + + formatter = SQLStringFormatter(self) + sqlQuery = formatter.format(sqlQuery, **kwargs) + + for df, name in formatter._temp_views: + _views.append(SubqueryAlias(df._plan, name)) + + cmd = SQL(sqlQuery, _args, _named_args, _views) data, properties = self.client.execute_command(cmd.command(self._client)) if "sql_command_result" in properties: return DataFrame(CachedRelation(properties["sql_command_result"]), self) @@ -1042,9 +1058,6 @@ def _test() -> None: pyspark.sql.connect.session.SparkSession.__doc__ = None del pyspark.sql.connect.session.SparkSession.Builder.master.__doc__ - # TODO(SPARK-41811): Implement SparkSession.sql's string formatter - del pyspark.sql.connect.session.SparkSession.sql.__doc__ - (failure_count, test_count) = doctest.testmod( pyspark.sql.connect.session, globs=globs, diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/connect/sql_formatter.py similarity index 64% copy from python/pyspark/sql/sql_formatter.py copy to python/pyspark/sql/connect/sql_formatter.py index 6d37821e5374..38b94bbaf205 100644 --- a/python/pyspark/sql/sql_formatter.py +++ b/python/pyspark/sql/connect/sql_formatter.py @@ -20,11 +20,12 @@ import typing from typing import Any, Optional, List, Tuple, Sequence, Mapping import uuid -if typing.TYPE_CHECKING: - from pyspark.sql import SparkSession, DataFrame -from pyspark.sql.functions import lit from pyspark.errors import PySparkValueError +if typing.TYPE_CHECKING: + from pyspark.sql.connect.session import SparkSession + from pyspark.sql.connect.dataframe import DataFrame + class SQLStringFormatter(string.Formatter): """ @@ -45,41 +46,31 @@ class SQLStringFormatter(string.Formatter): """ Converts the given value into a SQL string. 
""" - from py4j.java_gateway import is_instance_of - - from pyspark import SparkContext - from pyspark.sql import Column, DataFrame + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.connect.column import Column + from pyspark.sql.connect.expressions import ColumnReference + from pyspark.sql.utils import get_lit_sql_str if isinstance(val, Column): - assert SparkContext._gateway is not None - - gw = SparkContext._gateway - jexpr = val._jc.expr() - if is_instance_of( - gw, jexpr, "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute" - ) or is_instance_of( - gw, jexpr, "org.apache.spark.sql.catalyst.expressions.AttributeReference" - ): - return jexpr.sql() + expr = val._expr + if isinstance(expr, ColumnReference): + return expr._unparsed_identifier else: raise PySparkValueError( - error_class="VALUE_NOT_PLAIN_COLUMN_REFERENCE", - message_parameters={"val": str(val), "field_name": field_name}, + "%s in %s should be a plain column reference such as `df.col` " + "or `col('column')`" % (val, field_name) ) elif isinstance(val, DataFrame): for df, n in self._temp_views: if df is val: return n - df_name = "_pyspark_%s" % str(uuid.uuid4()).replace("-", "") - self._temp_views.append((val, df_name)) - val.createOrReplaceTempView(df_name) - return df_name + name = "_pyspark_connect_temp_view_%s" % str(uuid.uuid4()).replace("-", "") + self._temp_views.append((val, name)) + return name elif isinstance(val, str): - return lit(val)._jc.expr().sql() # for escaped characters. 
+ return get_lit_sql_str(val) else: return val def clear(self) -> None: - for _, n in self._temp_views: - self._session.catalog.dropTempView(n) - self._temp_views = [] + pass diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index c187122cdb40..f1666a9f575c 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -1695,7 +1695,7 @@ class SparkSession(SparkConversionMixin): And substitute named parameters with the `:` prefix by SQL literals. - >>> from pyspark.sql.functions import create_map + >>> from pyspark.sql.functions import create_map, lit >>> spark.sql( ... "SELECT *, element_at(:m, 'a') AS C FROM {df} WHERE {df[B]} > :minB", ... {"minB" : 5, "m" : create_map(lit('a'), lit(1))}, df=mydf).show() @@ -1707,7 +1707,7 @@ class SparkSession(SparkConversionMixin): Or positional parameters marked by `?` in the SQL query by SQL literals. - >>> from pyspark.sql.functions import array + >>> from pyspark.sql.functions import array, lit >>> spark.sql( ... "SELECT *, element_at(?, 1) AS C FROM {df} WHERE {df[B]} > ? and ? < {df[A]}", ... args=[array(lit(1), lit(2), lit(3)), 5, 2], df=mydf).show() diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py index 6d37821e5374..abb75f88f385 100644 --- a/python/pyspark/sql/sql_formatter.py +++ b/python/pyspark/sql/sql_formatter.py @@ -22,7 +22,7 @@ import uuid if typing.TYPE_CHECKING: from pyspark.sql import SparkSession, DataFrame -from pyspark.sql.functions import lit +from pyspark.sql.utils import get_lit_sql_str from pyspark.errors import PySparkValueError @@ -75,7 +75,7 @@ class SQLStringFormatter(string.Formatter): val.createOrReplaceTempView(df_name) return df_name elif isinstance(val, str): - return lit(val)._jc.expr().sql() # for escaped characters. 
+ return get_lit_sql_str(val) else: return val diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py index 52911a506c83..3a221cacedb2 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan.py @@ -756,7 +756,7 @@ class SparkConnectPlanTests(PlanOnlyTestFixture): # SPARK-41717: test print self.assertEqual( self.connect.sql("SELECT 1")._plan.print().strip(), - "<SQL query='SELECT 1', args='None', named_args='None'>", + "<SQL query='SELECT 1', args='None', named_args='None', views='None'>", ) self.assertEqual( self.connect.range(1, 10)._plan.print().strip(), diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 09ad959e2b8e..be4620366571 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -357,3 +357,10 @@ def get_window_class() -> Type["Window"]: return ConnectWindow # type: ignore[return-value] else: return PySparkWindow + + +def get_lit_sql_str(val: str) -> str: + # Equivalent to `lit(val)._jc.expr().sql()` for string typed val + # See `sql` definition in `sql/catalyst/src/main/scala/org/apache/spark/ + # sql/catalyst/expressions/literals.scala` + return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org