Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21477#discussion_r195653451
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -1885,6 +1885,263 @@ def test_query_manager_await_termination(self):
                 q.stop()
                 shutil.rmtree(tmpPath)
     
    +    class ForeachWriterTester:
    +
    +        def __init__(self, spark):
    +            self.spark = spark
    +
    +        def write_open_event(self, partitionId, epochId):
    +            self._write_event(
    +                self.open_events_dir,
    +                {'partition': partitionId, 'epoch': epochId})
    +
    +        def write_process_event(self, row):
    +            self._write_event(self.process_events_dir, {'value': 'text'})
    +
    +        def write_close_event(self, error):
    +            self._write_event(self.close_events_dir, {'error': str(error)})
    +
    +        def write_input_file(self):
    +            self._write_event(self.input_dir, "text")
    +
    +        def open_events(self):
    +            return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
    +
    +        def process_events(self):
    +            return self._read_events(self.process_events_dir, 'value STRING')
    +
    +        def close_events(self):
    +            return self._read_events(self.close_events_dir, 'error STRING')
    +
    +        def run_streaming_query_on_writer(self, writer, num_files):
    +            self._reset()
    +            try:
    +                sdf = self.spark.readStream.format('text').load(self.input_dir)
    +                sq = sdf.writeStream.foreach(writer).start()
    +                for i in range(num_files):
    +                    self.write_input_file()
    +                    sq.processAllAvailable()
    +            finally:
    +                self.stop_all()
    +
    +        def assert_invalid_writer(self, writer, msg=None):
    +            self._reset()
    +            try:
    +                sdf = self.spark.readStream.format('text').load(self.input_dir)
    +                sq = sdf.writeStream.foreach(writer).start()
    +                self.write_input_file()
    +                sq.processAllAvailable()
    +                self.fail("invalid writer %s did not fail the query" % str(writer))  # not expected
    +            except Exception as e:
    +                if msg:
    +                    assert(msg in str(e), "%s not in %s" % (msg, str(e)))
    +
    +            finally:
    +                self.stop_all()
    +
    +        def stop_all(self):
    +            for q in self.spark._wrapped.streams.active:
    +                q.stop()
    +
    +        def _reset(self):
    +            self.input_dir = tempfile.mkdtemp()
    +            self.open_events_dir = tempfile.mkdtemp()
    +            self.process_events_dir = tempfile.mkdtemp()
    +            self.close_events_dir = tempfile.mkdtemp()
    +
    +        def _read_events(self, dir, json):
    +            rows = self.spark.read.schema(json).json(dir).collect()
    +            dicts = [row.asDict() for row in rows]
    +            return dicts
    +
    +        def _write_event(self, dir, event):
    +            import uuid
    +            with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
    +                f.write("%s\n" % str(event))
    +
    +        def __getstate__(self):
    +            return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
    +
    +        def __setstate__(self, state):
    +            self.open_events_dir, self.process_events_dir, self.close_events_dir = state
    +
    +    def test_streaming_foreach_with_simple_function(self):
    +        tester = self.ForeachWriterTester(self.spark)
    +
    +        def foreach_func(row):
    +            tester.write_process_event(row)
    +
    +        tester.run_streaming_query_on_writer(foreach_func, 2)
    +        self.assertEqual(len(tester.process_events()), 2)
    +
    +    def test_streaming_foreach_with_basic_open_process_close(self):
    +        tester = self.ForeachWriterTester(self.spark)
    +
    +        class ForeachWriter:
    +            def open(self, partitionId, epochId):
    +                tester.write_open_event(partitionId, epochId)
    +                return True
    +
    +            def process(self, row):
    +                tester.write_process_event(row)
    +
    +            def close(self, error):
    +                tester.write_close_event(error)
    +
    +        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
    +
    +        open_events = tester.open_events()
    +        self.assertEqual(len(open_events), 2)
    +        self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})
    +
    +        self.assertEqual(len(tester.process_events()), 2)
    +
    +        close_events = tester.close_events()
    +        self.assertEqual(len(close_events), 2)
    +        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
    +
    +    def test_streaming_foreach_with_open_returning_false(self):
    +        tester = self.ForeachWriterTester(self.spark)
    +
    +        class ForeachWriter:
    +            def open(self, partition_id, epoch_id):
    +                tester.write_open_event(partition_id, epoch_id)
    +                return False
    +
    +            def process(self, row):
    +                tester.write_process_event(row)
    +
    +            def close(self, error):
    +                tester.write_close_event(error)
    +
    +        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
    +
    +        self.assertEqual(len(tester.open_events()), 2)
    +
    +        self.assertEqual(len(tester.process_events()), 0)   # no row was processed
    --- End diff --
    
    Is that a PEP 8 rule or Spark style?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to