dtenedor commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1391840056


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,53 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION)
+                """
+            ),
+            [Row(total=4)],
+        )
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the two partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT id / 10 AS id_divided_by_ten, total
+                FROM test_udtf(TABLE(t) PARTITION BY id / 10)
+                ORDER BY ALL
+                """
+            ),
+            [Row(id_divided_by_ten=0, total=4), Row(id_divided_by_ten=1, total=4)],

Review Comment:
   I thought so as well, but apparently the `range` function treats its second argument as an exclusive upper bound, so `range(1, 21)` returns 1 through 20 :)
   
   ```
   > SELECT id FROM range(1, 21)
   
   id
   1
   2
   3
   4
   5
   6
   7
   8
   9
   10
   11
   12
   13
   14
   15
   16
   17
   18
   19
   20
   ```
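
For reference, here is a minimal plain-Python sketch of why the single-partition case yields `total=4`. It needs no Spark session. `SkipRest`, `CountingUDTF`, and `run_single_partition` are hypothetical stand-ins written for illustration, not PySpark APIs, and the driver loop only approximates how the runtime is expected to react to `SkipRestOfInputTableException` (stop feeding rows, then still call `terminate`). Python's built-in `range` happens to share the half-open bounds shown above.

```python
# Python's built-in range is half-open, like the SQL range(1, 21) call above:
# the start is included and the end is excluded.
ids = list(range(1, 21))


class SkipRest(Exception):
    """Hypothetical local stand-in for SkipRestOfInputTableException."""


class CountingUDTF:
    """Mirrors the shape of TestUDTF in the diff: count rows, stop at the fourth."""

    def __init__(self):
        self._total = 0

    def eval(self, row):
        self._total += 1
        if self._total >= 4:
            raise SkipRest("Stop at self._total >= 4")

    def terminate(self):
        yield (self._total,)


def run_single_partition(rows):
    """Hypothetical driver: feed rows to eval until SkipRest, then call terminate."""
    udtf = CountingUDTF()
    for row in rows:
        try:
            udtf.eval(row)
        except SkipRest:
            break  # the remaining rows are skipped; terminate still runs below
    return list(udtf.terminate())


result = run_single_partition(ids)
print(len(ids), ids[0], ids[-1], result)
```

Under these assumptions, the 20 input ids are 1 through 20 and only the first four are consumed before the counter is emitted.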



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

