[ https://issues.apache.org/jira/browse/SPARK-16418?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15375680#comment-15375680 ]
Dongjoon Hyun commented on SPARK-16418:
---------------------------------------

Hi, [~erikwright]. I think the following is what you are trying to do. If I missed something, please let me know. (The following is the result on the current master, 2.1.0.)

{code}
>>> from pyspark.sql import types
>>> from pyspark.sql import Window
>>> from pyspark.sql import functions
>>>
>>> schema = types.StructType([types.StructField('id', types.IntegerType(), False),
...                            types.StructField('state', types.StringType(), True),
...                            types.StructField('seq', types.IntegerType(), False)])
>>> df = spark.createDataFrame([(1, 'hello', 1), (1, 'world', 2), (1, 'world', 3)], schema)
>>> df2 = df.withColumn("c", functions.lag('state').over(Window.partitionBy('id').orderBy('seq').rowsBetween(-1, -1)))
>>> df2.filter(df2["c"] != 'world').show()
+---+-----+---+-----+
| id|state|seq|    c|
+---+-----+---+-----+
|  1|world|  2|hello|
+---+-----+---+-----+
{code}
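If the helper column is not wanted afterwards, it can be dropped once the filter has run, exactly as in the original report. A minimal sketch continuing the session above (df3 is just a new name for the result; "c" is only the temporary column from the snippet):

{code}
>>> # Filter on the materialized window column, then drop the helper
>>> # column to recover the original schema (id, state, seq).
>>> df3 = df2.filter(df2["c"] != 'world').drop("c")
>>> df3.show()
+---+-----+---+
| id|state|seq|
+---+-----+---+
|  1|world|  2|
+---+-----+---+
{code}

This matches the result of In [10]–In [12] in the report below.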
> DataFrame.filter fails if it references a window function
> ----------------------------------------------------------
>
>                 Key: SPARK-16418
>                 URL: https://issues.apache.org/jira/browse/SPARK-16418
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 2.0.0
>            Reporter: Erik Wright
>
> I'm using Data Frames in Python. If I build up a column expression that includes a window function, then filter on it, the resulting Data Frame cannot be evaluated.
> If I first add that column expression to the Data Frame as a column (or add the sub-expression that includes the window function as a column), the filter works. This works even if I later drop the added column.
> It seems like this shouldn't be required. In the worst case, the platform should be able to do this for me under the hood when/if necessary.
> {code:none}
> In [1]: from pyspark.sql import types
> In [2]: from pyspark.sql import Window
> In [3]: from pyspark.sql import functions
> In [4]: schema = types.StructType([types.StructField('id', types.IntegerType(), False),
>    ...:                            types.StructField('state', types.StringType(), True),
>    ...:                            types.StructField('seq', types.IntegerType(), False)])
> In [5]: original_data_frame = sc.sql.createDataFrame([(1, 'hello', 1), (1, 'world', 2), (1, 'world', 3)], schema)
> In [6]: previous_state = functions.lag('state').over(Window.partitionBy('id').orderBy('seq').rowsBetween(-1, -1))
> In [7]: filter_condition = (original_data_frame['state'] == 'world') & (previous_state != 'world')
> In [8]: data_frame = original_data_frame.withColumn('filter_condition', filter_condition)
> In [9]: data_frame.show()
> +---+-----+---+----------------+
> | id|state|seq|filter_condition|
> +---+-----+---+----------------+
> |  1|hello|  1|           false|
> |  1|world|  2|            true|
> |  1|world|  3|           false|
> +---+-----+---+----------------+
> In [10]: data_frame = data_frame.filter(data_frame['filter_condition']).drop('filter_condition')
> In [11]: data_frame.explain()
> == Physical Plan ==
> WholeStageCodegen
> :  +- Project [id#0,state#1,seq#2]
> :     +- Filter (((isnotnull(state#1) && isnotnull(_we0#6)) && (state#1 = world)) && NOT (_we0#6 = world))
> :        +- INPUT
> +- Window [lag(state#1, 1, null) windowspecdefinition(id#0, seq#2 ASC, ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) AS _we0#6], [id#0], [seq#2 ASC]
>    +- WholeStageCodegen
>       :  +- Sort [id#0 ASC,seq#2 ASC], false, 0
>       :     +- INPUT
>       +- Exchange hashpartitioning(id#0, 200), None
>          +- Scan ExistingRDD[id#0,state#1,seq#2]
> In [12]: data_frame.show()
> +---+-----+---+
> | id|state|seq|
> +---+-----+---+
> |  1|world|  2|
> +---+-----+---+
> In [13]: data_frame = original_data_frame.withColumn('previous_state', previous_state)
> In [14]: data_frame.show()
> +---+-----+---+--------------+
> | id|state|seq|previous_state|
> +---+-----+---+--------------+
> |  1|hello|  1|          null|
> |  1|world|  2|         hello|
> |  1|world|  3|         world|
> +---+-----+---+--------------+
> In [15]: filter_condition = (data_frame['state'] == 'world') & (data_frame['previous_state'] != 'world')
> In [16]: data_frame = data_frame.filter(filter_condition).drop('previous_state')
> In [17]: data_frame.explain()
> == Physical Plan ==
> WholeStageCodegen
> :  +- Project [id#0,state#1,seq#2]
> :     +- Filter (((isnotnull(state#1) && isnotnull(previous_state#12)) && (state#1 = world)) && NOT (previous_state#12 = world))
> :        +- INPUT
> +- Window [lag(state#1, 1, null) windowspecdefinition(id#0, seq#2 ASC, ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) AS previous_state#12], [id#0], [seq#2 ASC]
>    +- WholeStageCodegen
>       :  +- Sort [id#0 ASC,seq#2 ASC], false, 0
>       :     +- INPUT
>       +- Exchange hashpartitioning(id#0, 200), None
>          +- Scan ExistingRDD[id#0,state#1,seq#2]
> In [18]: data_frame.show()
> +---+-----+---+
> | id|state|seq|
> +---+-----+---+
> |  1|world|  2|
> +---+-----+---+
> In [19]: filter_condition = (original_data_frame['state'] == 'world') & (previous_state != 'world')
> In [20]: data_frame = original_data_frame.filter(filter_condition)
> In [21]: data_frame.explain()
> == Physical Plan ==
> WholeStageCodegen
> :  +- Filter ((isnotnull(state#1) && (state#1 = world)) && NOT (lag(state#1, 1, null) windowspecdefinition(id#0, seq#2 ASC, ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) = world))
> :     +- INPUT
> +- Scan ExistingRDD[id#0,state#1,seq#2]
> In [22]: data_frame.show()
> ---------------------------------------------------------------------------
> Py4JJavaError                             Traceback (most recent call last)
> /Users/erikwright/src/starscream/bin/starscream in <module>()
> ----> 1 data_frame.show()
> /Users/erikwright/spark-2.0.0-preview-bin-hadoop2.7/python/pyspark/sql/dataframe.pyc in show(self, n, truncate)
>     271         +---+-----+
>     272         """
> --> 273         print(self._jdf.showString(n, truncate))
>     274
>     275     def __repr__(self):
> /Users/erikwright/spark-2.0.0-preview-bin-hadoop2.7/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py in __call__(self, *args)
>     931         answer = self.gateway_client.send_command(command)
>     932         return_value = get_return_value(
> --> 933             answer, self.gateway_client, self.target_id, self.name)
>     934
>     935         for temp_arg in temp_args:
> /Users/erikwright/spark-2.0.0-preview-bin-hadoop2.7/python/pyspark/sql/utils.pyc in deco(*a, **kw)
>      55     def deco(*a, **kw):
>      56         try:
> ---> 57             return f(*a, **kw)
>      58         except py4j.protocol.Py4JJavaError as e:
>      59             s = e.java_exception.toString()
> /Users/erikwright/spark-2.0.0-preview-bin-hadoop2.7/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
>     310                 raise Py4JJavaError(
>     311                     "An error occurred while calling {0}{1}{2}.\n".
> --> 312                     format(target_id, ".", name), value)
>     313             else:
>     314                 raise Py4JError(
> Py4JJavaError: An error occurred while calling o128.showString.
> : java.lang.UnsupportedOperationException: Cannot evaluate expression: lag(input[1, string], 1, null) windowspecdefinition(input[0, int], input[2, int] ASC, ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING)
>   at org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:220)
>   at org.apache.spark.sql.catalyst.expressions.WindowExpression.doGenCode(windowExpressions.scala:288)
>   at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
>   at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:102)
>   at scala.Option.getOrElse(Option.scala:121)
>   at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:102)
>   at org.apache.spark.sql.catalyst.expressions.BinaryExpression.nullSafeCodeGen(Expression.scala:452)
>   at org.apache.spark.sql.catalyst.expressions.BinaryExpression.defineCodeGen(Expression.scala:435)
>   at org.apache.spark.sql.catalyst.expressions.EqualTo.doGenCode(predicates.scala:429)
>   at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
>   at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:102)
>   at scala.Option.getOrElse(Option.scala:121)
>   at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:102)
>   at org.apache.spark.sql.catalyst.expressions.UnaryExpression.nullSafeCodeGen(Expression.scala:363)
>   at org.apache.spark.sql.catalyst.expressions.UnaryExpression.defineCodeGen(Expression.scala:347)
>   at org.apache.spark.sql.catalyst.expressions.Not.doGenCode(predicates.scala:103)
>   at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
>   at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:102)
>   at scala.Option.getOrElse(Option.scala:121)
>   at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:102)
>   at org.apache.spark.sql.execution.FilterExec.org$apache$spark$sql$execution$FilterExec$$genPredicate$1(basicPhysicalOperators.scala:127)
>   at org.apache.spark.sql.execution.FilterExec$$anonfun$12.apply(basicPhysicalOperators.scala:169)
>   at org.apache.spark.sql.execution.FilterExec$$anonfun$12.apply(basicPhysicalOperators.scala:153)
>   at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
>   at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
>   at scala.collection.immutable.List.foreach(List.scala:381)
>   at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
>   at scala.collection.immutable.List.map(List.scala:285)
>   at org.apache.spark.sql.execution.FilterExec.doConsume(basicPhysicalOperators.scala:153)
>   at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:153)
>   at org.apache.spark.sql.execution.InputAdapter.consume(WholeStageCodegenExec.scala:218)
>   at org.apache.spark.sql.execution.InputAdapter.doProduce(WholeStageCodegenExec.scala:244)
>   at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
>   at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:78)
>   at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:136)
>   at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
>   at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:133)
>   at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:78)
>   at org.apache.spark.sql.execution.InputAdapter.produce(WholeStageCodegenExec.scala:218)
>   at org.apache.spark.sql.execution.FilterExec.doProduce(basicPhysicalOperators.scala:113)
>   at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
>   at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:78)
>   at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:136)
>   at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
>   at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:133)
>   at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:78)
>   at org.apache.spark.sql.execution.FilterExec.produce(basicPhysicalOperators.scala:79)
>   at org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:304)
>   at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:343)
>   at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:115)
>   at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:115)
>   at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:136)
>   at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
>   at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:133)
>   at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:114)
>   at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:240)
>   at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:323)
>   at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
>   at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2122)
>   at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
>   at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2436)
>   at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2121)
>   at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2128)
>   at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:1862)
>   at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:1861)
>   at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2466)
>   at org.apache.spark.sql.Dataset.head(Dataset.scala:1861)
>   at org.apache.spark.sql.Dataset.take(Dataset.scala:2078)
>   at org.apache.spark.sql.Dataset.showString(Dataset.scala:240)
>   at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
>   at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
>   at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
>   at java.lang.reflect.Method.invoke(Method.java:497)
>   at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:237)
>   at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
>   at py4j.Gateway.invoke(Gateway.java:280)
>   at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128)
>   at py4j.commands.CallCommand.execute(CallCommand.java:79)
>   at py4j.GatewayConnection.run(GatewayConnection.java:211)
>   at java.lang.Thread.run(Thread.java:745)
> In [23]:
> {code}
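Judging from the traceback, the window expression ends up inside FilterExec's code generation, where WindowExpression is Unevaluable and doGenCode throws; when the expression is first materialized with withColumn, the planner instead inserts a Window operator (visible in the In [11] and In [17] plans) and the filter only references its output column. A self-contained sketch of both patterns (the local SparkSession setup is an assumption for running standalone, not part of the report):

{code}
from pyspark.sql import SparkSession, Window, functions

spark = SparkSession.builder.master("local[*]").getOrCreate()

df = spark.createDataFrame(
    [(1, 'hello', 1), (1, 'world', 2), (1, 'world', 3)],
    ['id', 'state', 'seq'])

previous_state = functions.lag('state').over(
    Window.partitionBy('id').orderBy('seq').rowsBetween(-1, -1))

# Fails on 2.0.0 with UnsupportedOperationException: the window
# expression reaches FilterExec's codegen without a Window operator.
# df.filter((df['state'] == 'world') & (previous_state != 'world')).show()

# Works: materialize the window expression as a column first, so a
# Window operator runs before the Filter, then drop the helper column.
result = (df.withColumn('previous_state', previous_state)
            .filter((df['state'] == 'world') &
                    (functions.col('previous_state') != 'world'))
            .drop('previous_state'))
result.show()  # expected: the single row (1, 'world', 2), as in In [12]
{code}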