This is an automated email from the ASF dual-hosted git repository.

ningk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b7cb10  [BEAM-10708] Added beam_sql magics
     new 5a627eb  Merge pull request #15368 from KevinGG/sql
3b7cb10 is described below

commit 3b7cb10789ab13e2b1ec458b5e0f5cb6e4f867b5
Author: KevinGG <kawai...@gmail.com>
AuthorDate: Fri Aug 20 14:16:23 2021 -0700

    [BEAM-10708] Added beam_sql magics
    
    Added a beam_sql cell magic that applies a SqlTransform based on a
    given Beam SQL query in the notebook.
---
 .../runners/interactive/sql/__init__.py            |  16 ++
 .../runners/interactive/sql/beam_sql_magics.py     | 293 +++++++++++++++++++++
 .../interactive/sql/beam_sql_magics_test.py        | 121 +++++++++
 .../apache_beam/runners/interactive/sql/utils.py   | 125 +++++++++
 .../runners/interactive/sql/utils_test.py          |  90 +++++++
 .../apache_beam/runners/interactive/utils.py       |  10 +
 sdks/python/scripts/generate_pydoc.sh              |   3 +
 sdks/python/setup.py                               |   2 +-
 8 files changed, 659 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/runners/interactive/sql/__init__.py b/sdks/python/apache_beam/runners/interactive/sql/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
new file mode 100644
index 0000000..cee3d34
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
@@ -0,0 +1,293 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module of beam_sql cell magic that executes a Beam SQL.
+
+Only works within an IPython kernel.
+"""
+
+import importlib
+import keyword
+import logging
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import apache_beam as beam
+from apache_beam.pvalue import PValue
+from apache_beam.runners.interactive import cache_manager as cache
+from apache_beam.runners.interactive import interactive_beam as ib
+from apache_beam.runners.interactive import interactive_environment as ie
+from apache_beam.runners.interactive import pipeline_instrument as inst
+from apache_beam.runners.interactive.cache_manager import FileBasedCacheManager
+from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
+from apache_beam.runners.interactive.sql.utils import find_pcolls
+from apache_beam.runners.interactive.sql.utils import is_namedtuple
+from apache_beam.runners.interactive.sql.utils import pcolls_by_name
+from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
+from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
+from apache_beam.runners.interactive.utils import obfuscate
+from apache_beam.runners.interactive.utils import progress_indicated
+from apache_beam.testing import test_stream
+from apache_beam.testing.test_stream_service import TestStreamServiceController
+from apache_beam.transforms.sql import SqlTransform
+from IPython.core.magic import Magics
+from IPython.core.magic import cell_magic
+from IPython.core.magic import magics_class
+
+_LOGGER = logging.getLogger(__name__)
+
+# NOTE: this text is logged as-is, without any %-style arguments, so the
+# logging module does not apply %-formatting to it. The cell magic marker is
+# therefore spelled as a literal "%%" rather than an escaped "%%%%".
+_EXAMPLE_USAGE = """Usage:
+    %%beam_sql [output_name]
+    Calcite SQL statement
+    Syntax: https://beam.apache.org/documentation/dsls/sql/calcite/query-syntax/
+    Please make sure that there is no conflicts between your variable names and
+    the SQL keywords, such as "SELECT", "FROM", "WHERE" and etc.
+
+    output_name is optional. If not supplied, a variable name is automatically
+    assigned to the output of the magic.
+
+    The output of the magic is usually a PCollection or similar PValue,
+    depending on the SQL statement executed.
+"""
+
+
+def on_error(error_msg, *args):
+  """Reports a failure of the magic together with the usage example.
+
+  Args:
+    error_msg: The error message; may be a %-style format string.
+    *args: Values handed to the logger to be interpolated into error_msg.
+
+  The error is logged at ERROR level and the usage example at INFO level.
+  """
+  _LOGGER.log(logging.ERROR, error_msg, *args)
+  _LOGGER.log(logging.INFO, _EXAMPLE_USAGE)
+
+
+@magics_class
+class BeamSqlMagics(Magics):
+  """IPython magics class that provides the beam_sql cell magic."""
+
+  @cell_magic
+  def beam_sql(self, line: str, cell: str) -> Union[None, PValue]:
+    """The beam_sql cell magic that executes a Beam SQL.
+
+    Args:
+      line: (optional) the string on the same line after the beam_sql magic.
+          Used as the output variable name in the __main__ module. Leading
+          and trailing whitespace is ignored; a blank line means that no
+          output name is supplied.
+      cell: everything else in the same notebook cell as a string. Used as a
+          Beam SQL query.
+
+    Returns None if running into an error, otherwise a PValue as if a
+    SqlTransform is applied.
+    """
+    # Normalize the name once so that validation, error reporting and the
+    # variable eventually declared in __main__ all use the same string. A
+    # name carrying stray whitespace would pass validation after stripping
+    # but then be set as an attribute that cannot be referenced.
+    output_name = line.strip() if line else ''
+    if output_name and (not output_name.isidentifier() or
+                        keyword.iskeyword(output_name)):
+      on_error(
+          'The output_name "%s" is not a valid identifier. Please supply a '
+          'valid identifier that is not a Python keyword.',
+          output_name)
+      return None
+    if not cell or cell.isspace():
+      on_error('Please supply the sql to be executed.')
+      return None
+    found = find_pcolls(cell, pcolls_by_name())
+    for pcoll in found.values():
+      # Every queried PCollection must have a NamedTuple element type; the
+      # first one that does not aborts the magic.
+      if not is_namedtuple(pcoll.element_type):
+        on_error(
+            'PCollection %s of type %s is not a NamedTuple. See '
+            'https://beam.apache.org/documentation/programming-guide/#schemas '
+            'for more details.',
+            pcoll,
+            pcoll.element_type)
+        return None
+      register_coder_for_schema(pcoll.element_type)
+
+    # TODO(BEAM-10708): implicitly execute the pipeline and write output into
+    # cache.
+    return apply_sql(cell, output_name, found)
+
+
+@progress_indicated
+def apply_sql(
+    query: str, output_name: Optional[str],
+    found: Dict[str, beam.PCollection]) -> Optional[PValue]:
+  """Applies a SqlTransform with the given sql and queried PCollections.
+
+  Args:
+    query: The SQL query executed in the magic.
+    output_name: (optional) The output variable name in __main__ module.
+    found: The PCollections with variable names found to be used in the query.
+
+  Returns:
+    A PValue, mostly a PCollection, depending on the query. None if applying
+    the SqlTransform or publishing its output raised an Exception; in that
+    case the error is logged through on_error instead of being raised.
+  """
+  output_name = _generate_output_name(output_name, query, found)
+  query, sql_source = _build_query_components(query, found)
+  try:
+    output = sql_source | SqlTransform(query)
+    # Declare a variable with the output_name and output value in the
+    # __main__ module so that the user can use the output smoothly.
+    setattr(importlib.import_module('__main__'), output_name, output)
+    ib.watch({output_name: output})
+    _LOGGER.info(
+        "The output PCollection variable is %s: %s", output_name, output)
+    return output
+  except (KeyboardInterrupt, SystemExit):
+    raise
+  except Exception as e:
+    on_error('Error when applying the Beam SQL: %s', e)
+    # Make the failure result explicit; it matches the Optional return type.
+    return None
+
+
+def pcoll_from_file_cache(
+    query_pipeline: beam.Pipeline,
+    pcoll: beam.PCollection,
+    cache_manager: FileBasedCacheManager,
+    key: str) -> beam.PCollection:
+  """Builds a PCollection in query_pipeline from file based cache.
+
+  Args:
+    query_pipeline: The beam.Pipeline object built by the magic to execute the
+        SQL query.
+    pcoll: The PCollection to read cache for. Its element_type is used as the
+        output type of the returned PCollection.
+    cache_manager: The file based cache manager that holds the PCollection
+        cache.
+    key: The key of the PCollection cache. It is also appended to the labels
+        of the transforms applied here.
+
+  Returns:
+    A PCollection read from the cache.
+  """
+  schema = pcoll.element_type
+
+  class Unreify(beam.DoFn):
+    """Emits the windowed_value attribute of the cached beam.Row elements."""
+    def process(self, e):
+      # Elements that are not a beam.Row, or that carry no windowed_value
+      # attribute, produce no output.
+      if not isinstance(e, beam.Row):
+        return
+      if not hasattr(e, 'windowed_value'):
+        return
+      yield e.windowed_value
+
+  read_label = '{}{}'.format('QuerySource', key)
+  unreify_label = '{}{}'.format('Unreify', key)
+  cached = query_pipeline | read_label >> cache.ReadCache(cache_manager, key)
+  return cached | unreify_label >> beam.ParDo(
+      Unreify()).with_output_types(schema)
+
+
+def pcolls_from_streaming_cache(
+    user_pipeline: beam.Pipeline,
+    query_pipeline: beam.Pipeline,
+    name_to_pcoll: Dict[str, beam.PCollection],
+    instrumentation: inst.PipelineInstrument,
+    cache_manager: StreamingCache) -> Dict[str, beam.PCollection]:
+  """Reads PCollection cache through the TestStream.
+
+  If the current interactive environment holds no test stream service
+  controller for the user_pipeline, one is created, started and registered
+  with the environment before the cache is read.
+
+  Args:
+    user_pipeline: The beam.Pipeline object defined by the user in the
+        notebook.
+    query_pipeline: The beam.Pipeline object built by the magic to execute the
+        SQL query.
+    name_to_pcoll: PCollections with variable names used in the SQL query.
+    instrumentation: A pipeline_instrument.PipelineInstrument that helps
+        calculate the cache key of a given PCollection.
+    cache_manager: The streaming cache manager that holds the PCollection
+        cache.
+
+  Returns:
+    A Dict[str, beam.PCollection], where each PCollection is tagged with
+    their PCollection variable names, read from the cache.
+
+  When the user_pipeline has unbounded sources, we force all cache reads to go
+  through the TestStream even if they are bounded sources.
+  """
+  def exception_handler(e):
+    # Logs the error. NOTE(review): the True return value presumably tells
+    # the service that the exception has been handled -- confirm against
+    # TestStreamServiceController.
+    _LOGGER.error(str(e))
+    return True
+
+  env = ie.current_env()
+  test_stream_service = env.get_test_stream_service_controller(user_pipeline)
+  if not test_stream_service:
+    test_stream_service = TestStreamServiceController(
+        cache_manager, exception_handler=exception_handler)
+    test_stream_service.start()
+    env.set_test_stream_service_controller(user_pipeline, test_stream_service)
+
+  # The TestStream outputs are tagged by cache key; keep the reverse mapping to
+  # hand the outputs back under the PCollection variable names.
+  tag_to_name = {
+      instrumentation.cache_key(pcoll): name
+      for name,
+      pcoll in name_to_pcoll.items()
+  }
+  output_pcolls = query_pipeline | test_stream.TestStream(
+      output_tags=set(tag_to_name),
+      coder=cache_manager._default_pcoder,
+      endpoint=test_stream_service.endpoint)
+  return {tag_to_name[tag]: output for tag, output in output_pcolls.items()}
+
+
+def _generate_output_name(
+    output_name: Optional[str], query: str,
+    found: Dict[str, beam.PCollection]) -> str:
+  """Returns the given output name, or generates one if none is provided.
+
+  A generated output name is 'sql_output_' followed by the first 12 characters
+  of the value obfuscated from the query and the PCollections found to be used
+  in the query.
+  """
+  if output_name:
+    return output_name
+  return 'sql_output_' + obfuscate(query, found)[:12]
+
+
+def _build_query_components(
+    query: str, found: Dict[str, beam.PCollection]
+) -> Tuple[str,
+           Union[Dict[str, beam.PCollection], beam.PCollection,
+                 beam.Pipeline]]:
+  """Builds necessary components needed to apply the SqlTransform.
+
+  Args:
+    query: The SQL query to be executed by the magic.
+    found: The PCollections with variable names found to be used by the query.
+
+  Returns:
+    The processed query to be executed by the magic and a source to apply the
+    SqlTransform to: a dictionary of tagged PCollections, or a single
+    PCollection, or the pipeline to execute the query.
+  """
+  if not found:
+    # No PCollection is queried: the query applies to a fresh pipeline.
+    return query, beam.Pipeline()
+
+  # The pipeline of the first PCollection found is used as the user pipeline.
+  user_pipeline = next(iter(found.values())).pipeline
+  cache_manager = ie.current_env().get_cache_manager(user_pipeline)
+  instrumentation = inst.build_pipeline_instrument(user_pipeline)
+  sql_pipeline = beam.Pipeline(options=user_pipeline._options)
+  ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)
+  if instrumentation.has_unbounded_sources:
+    sql_source = pcolls_from_streaming_cache(
+        user_pipeline, sql_pipeline, found, instrumentation, cache_manager)
+  else:
+    sql_source = {}
+    for pcoll_name in found:
+      queried = found[pcoll_name]
+      sql_source[pcoll_name] = pcoll_from_file_cache(
+          sql_pipeline,
+          queried,
+          cache_manager,
+          instrumentation.cache_key(queried))
+  if len(sql_source) == 1:
+    # A single source is applied directly and must be referred to as
+    # PCOLLECTION in the query instead of by its variable name.
+    (only_name, only_source), = sql_source.items()
+    query = replace_single_pcoll_token(query, only_name)
+    sql_source = only_source
+  return query, sql_source
+
+
+def load_ipython_extension(ipython):
+  """Marks this module as an IPython extension.
+
+  To load this magic in an IPython environment, execute:
+
+    %load_ext apache_beam.runners.interactive.sql.beam_sql_magics
+
+  Args:
+    ipython: The object the BeamSqlMagics class is registered with through
+        its register_magics method (presumably the active IPython shell
+        passed in by %load_ext -- TODO confirm).
+  """
+  ipython.register_magics(BeamSqlMagics)
diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py
new file mode 100644
index 0000000..d35bd46
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for beam_sql_magics module."""
+
+# pytype: skip-file
+
+import unittest
+from unittest.mock import patch
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.runners.interactive import interactive_beam as ib
+from apache_beam.runners.interactive import interactive_environment as ie
+
+try:
+  from apache_beam.runners.interactive.sql.beam_sql_magics import _build_query_components
+  from apache_beam.runners.interactive.sql.beam_sql_magics import _generate_output_name
+except (ImportError, NameError):
+  pass  # The test is to be skipped because [interactive] dep not installed.
+
+
+# The whole test case is skipped, for both the unittest and the pytest runner,
+# when the interactive environment reports that it is not ready.
+@unittest.skipIf(
+    not ie.current_env().is_interactive_ready,
+    '[interactive] dependency is not installed.')
+@pytest.mark.skipif(
+    not ie.current_env().is_interactive_ready,
+    reason='[interactive] dependency is not installed.')
+class BeamSqlMagicsTest(unittest.TestCase):
+  """Tests for the module-level helpers used by the beam_sql magic."""
+  def test_generate_output_name_when_not_provided(self):
+    # Without a given name, the generated one carries the sql_output_ prefix.
+    output_name = None
+    self.assertTrue(
+        _generate_output_name(output_name, '', {}).startswith('sql_output_'))
+
+  def test_use_given_output_name_when_provided(self):
+    output_name = 'output'
+    self.assertEqual(_generate_output_name(output_name, '', {}), output_name)
+
+  def test_build_query_components_when_no_pcoll_queried(self):
+    # With no PCollection queried, the query is returned untouched and the
+    # source to apply the SqlTransform to is a pipeline.
+    query = """SELECT CAST(1 AS INT) AS `id`,
+                      CAST('foo' AS VARCHAR) AS `str`,
+                      CAST(3.14  AS DOUBLE) AS `flt`"""
+    processed_query, sql_source = _build_query_components(query, {})
+    self.assertEqual(processed_query, query)
+    self.assertIsInstance(sql_source, beam.Pipeline)
+
+  def test_build_query_components_when_single_pcoll_queried(self):
+    p = beam.Pipeline()
+    target = p | beam.Create([1, 2, 3])
+    ib.watch(locals())
+    query = 'SELECT * FROM target where a=1'
+    found = {'target': target}
+
+    # pcoll_from_file_cache is patched to return `target`, so no cache is read.
+    with patch('apache_beam.runners.interactive.sql.beam_sql_magics.'
+               'pcoll_from_file_cache',
+               lambda a,
+               b,
+               c,
+               d: target):
+      processed_query, sql_source = _build_query_components(query, found)
+
+      # A single queried PCollection is referred to as PCOLLECTION and is
+      # returned directly rather than in a dict.
+      self.assertEqual(processed_query, 'SELECT * FROM PCOLLECTION where a=1')
+      self.assertIsInstance(sql_source, beam.PCollection)
+
+  def test_build_query_components_when_multiple_pcolls_queried(self):
+    p = beam.Pipeline()
+    pcoll_1 = p | 'Create 1' >> beam.Create([1, 2, 3])
+    pcoll_2 = p | 'Create 2' >> beam.Create([4, 5, 6])
+    ib.watch(locals())
+    query = 'SELECT * FROM pcoll_1 JOIN pcoll_2 USING (a)'
+    found = {'pcoll_1': pcoll_1, 'pcoll_2': pcoll_2}
+
+    # pcoll_from_file_cache is patched to return `pcoll_1` for every read.
+    with patch('apache_beam.runners.interactive.sql.beam_sql_magics.'
+               'pcoll_from_file_cache',
+               lambda a,
+               b,
+               c,
+               d: pcoll_1):
+      processed_query, sql_source = _build_query_components(query, found)
+
+      # With multiple PCollections the query is unchanged and the source is a
+      # dict keyed by the PCollection variable names.
+      self.assertEqual(processed_query, query)
+      self.assertIsInstance(sql_source, dict)
+      self.assertIn('pcoll_1', sql_source)
+      self.assertIn('pcoll_2', sql_source)
+
+  def test_build_query_components_when_unbounded_pcolls_queried(self):
+    p = beam.Pipeline()
+    pcoll = p | beam.io.ReadFromPubSub(
+        subscription='projects/fake-project/subscriptions/fake_sub')
+    ib.watch(locals())
+    query = 'SELECT * FROM pcoll'
+    found = {'pcoll': pcoll}
+
+    # pcolls_from_streaming_cache is patched to return `found` as is, so the
+    # single source handed back must be `pcoll` itself.
+    with patch('apache_beam.runners.interactive.sql.beam_sql_magics.'
+               'pcolls_from_streaming_cache',
+               lambda a,
+               b,
+               c,
+               d,
+               e: found):
+      _, sql_source = _build_query_components(query, found)
+      self.assertIs(sql_source, pcoll)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils.py b/sdks/python/apache_beam/runners/interactive/sql/utils.py
new file mode 100644
index 0000000..355b6e6
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/utils.py
@@ -0,0 +1,125 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module of utilities for SQL magics.
+
+For internal use only; no backward-compatibility guarantees.
+"""
+
+# pytype: skip-file
+
+import logging
+from typing import Dict
+from typing import NamedTuple
+
+import apache_beam as beam
+from apache_beam.runners.interactive import interactive_beam as ib
+from apache_beam.runners.interactive import interactive_environment as ie
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def is_namedtuple(cls: type) -> bool:
+  """Determines if a class is built from typing.NamedTuple.
+
+  Args:
+    cls: The object to check. Anything that is not a class yields False.
+
+  Returns:
+    True if cls is a tuple subclass with a _fields attribute that also carries
+    field type information; False otherwise, including for classes built from
+    collections.namedtuple, which declare no field types.
+  """
+  if not (isinstance(cls, type) and issubclass(cls, tuple) and
+          hasattr(cls, '_fields')):
+    return False
+  # typing.NamedTuple exposed the field types as _field_types only until
+  # Python 3.8; the attribute was removed in Python 3.9. The annotations of
+  # the fields are available on all versions, so fall back to them. The MRO is
+  # searched because a subclass that adds no field of its own may report empty
+  # annotations for itself.
+  return hasattr(cls, '_field_types') or any(
+      getattr(base, '__annotations__', None) for base in cls.__mro__)
+
+
+def register_coder_for_schema(schema: NamedTuple) -> None:
+  """Makes sure a RowCoder is registered for the given schema.
+
+  Nothing happens when the coder registry already resolves the schema to a
+  RowCoder. Otherwise the RowCoder is registered and a warning tells the user
+  what code has been implicitly executed.
+
+  Args:
+    schema: The typing.NamedTuple class to register the RowCoder for.
+  """
+  assert is_namedtuple(schema), (
+      'Schema %s is not a typing.NamedTuple.' % schema)
+  registry = beam.coders.registry
+  if isinstance(registry.get_coder(schema), beam.coders.RowCoder):
+    return
+  schema_name = schema.__name__
+  _LOGGER.warning(
+      'Schema %s has not been registered to use a RowCoder. '
+      'Automatically registering it by running: '
+      'beam.coders.registry.register_coder(%s, '
+      'beam.coders.RowCoder)',
+      schema_name,
+      schema_name)
+  registry.register_coder(schema, beam.coders.RowCoder)
+
+
+def pcolls_by_name() -> Dict[str, beam.PCollection]:
+  """Maps variable names to the PCollections defined in the notebook.
+
+  Only the inspectables of the current interactive environment whose metadata
+  type is 'pcollection' are included, keyed by the name in their metadata.
+  """
+  return {
+      inspectable['metadata']['name']: inspectable['value']
+      for inspectable in ie.current_env().inspector.inspectables.values()
+      if inspectable['metadata']['type'] == 'pcollection'
+  }
+
+
+def find_pcolls(
+    sql: str, pcolls: Dict[str,
+                           beam.PCollection]) -> Dict[str, beam.PCollection]:
+  """Finds all PCollections used in the given sql query.
+
+  It does a simple word by word match and calls ib.collect for each PCollection
+  found. Words are the whitespace-delimited tokens of the query, so a name
+  directly attached to punctuation (for example 'pcoll,' or '(pcoll') is not
+  matched.
+
+  Args:
+    sql: The SQL query to search.
+    pcolls: All known PCollections keyed by their variable names.
+
+  Returns:
+    The subset of pcolls whose names appear as words in the query.
+
+  Raises:
+    Exception: whatever ib.collect raises for a PCollection found; the error
+      is logged with a hint before being re-raised.
+  """
+  found = {word: pcolls[word] for word in sql.split() if word in pcolls}
+  if not found:
+    return found
+  _LOGGER.info('Found PCollections used in the magic: %s.', found)
+  _LOGGER.info('Collecting data...')
+  for name, pcoll in found.items():
+    try:
+      ib.collect(pcoll)
+    # KeyboardInterrupt and SystemExit are not subclasses of Exception, so
+    # they propagate untouched without any logging.
+    except Exception:
+      _LOGGER.error(
+          'Cannot collect data for PCollection %s. Please make sure the '
+          'PCollections queried in the sql "%s" are all from a single '
+          'pipeline using an InteractiveRunner. Make sure there is no '
+          'ambiguity, for example, same named PCollections from multiple '
+          'pipelines or notebook re-executions.',
+          name,
+          sql)
+      raise
+  _LOGGER.info('Done collecting data.')
+  return found
+
+
+def replace_single_pcoll_token(sql: str, pcoll_name: str) -> str:
+  """Replaces the pcoll_name used in the sql with 'PCOLLECTION'.
+
+  For sql query using only a single PCollection, the PCollection needs to be
+  referred to as 'PCOLLECTION' instead of its variable/tag name.
+
+  Only a whitespace-delimited word that directly follows a FROM keyword
+  (matched case-insensitively) and equals pcoll_name is replaced, for every
+  FROM in the query. Other occurrences of the name, such as a column with the
+  same name, are left alone.
+
+  Args:
+    sql: The SQL query.
+    pcoll_name: The variable/tag name of the single PCollection queried.
+
+  Returns:
+    The query with its words joined by single spaces; the original whitespace
+    (including line breaks) is not preserved.
+  """
+  words = sql.split()
+  # Compare each word with its direct predecessor. Deriving the position from
+  # the loop itself keeps it in sync with the word being visited, which
+  # matters for queries with more than one FROM.
+  for position in range(1, len(words)):
+    if (words[position - 1].lower() == 'from' and
+        words[position] == pcoll_name):
+      words[position] = 'PCOLLECTION'
+  return ' '.join(words)
diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils_test.py b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py
new file mode 100644
index 0000000..ed52cad
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for utils module."""
+
+# pytype: skip-file
+
+import unittest
+from typing import NamedTuple
+from unittest.mock import patch
+
+import apache_beam as beam
+from apache_beam.runners.interactive import interactive_beam as ib
+from apache_beam.runners.interactive.sql.utils import find_pcolls
+from apache_beam.runners.interactive.sql.utils import is_namedtuple
+from apache_beam.runners.interactive.sql.utils import pcolls_by_name
+from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
+from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
+
+
+# A minimal typing.NamedTuple schema shared by the tests below.
+class ANamedTuple(NamedTuple):
+  a: int
+  b: str
+
+
+class UtilsTest(unittest.TestCase):
+  """Tests for the helpers of the SQL magics utils module."""
+  def test_is_namedtuple(self):
+    class AType:
+      pass
+
+    a_type = AType
+    a_tuple = type((1, 2, 3))
+
+    a_namedtuple = ANamedTuple
+
+    # Only the typing.NamedTuple class qualifies; an arbitrary class and the
+    # builtin tuple type do not.
+    self.assertTrue(is_namedtuple(a_namedtuple))
+    self.assertFalse(is_namedtuple(a_type))
+    self.assertFalse(is_namedtuple(a_tuple))
+
+  def test_register_coder_for_schema(self):
+    # The schema resolves to something other than a RowCoder until it is
+    # registered. Registration goes into the shared beam.coders.registry.
+    self.assertNotIsInstance(
+        beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)
+    register_coder_for_schema(ANamedTuple)
+    self.assertIsInstance(
+        beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)
+
+  def test_pcolls_by_name(self):
+    p = beam.Pipeline()
+    pcoll = p | beam.Create([1])
+    ib.watch({'p': p, 'pcoll': pcoll})
+
+    name_to_pcoll = pcolls_by_name()
+    self.assertIn('pcoll', name_to_pcoll)
+
+  def test_find_pcolls(self):
+    # ib.collect is patched out so that no data is actually collected.
+    with patch('apache_beam.runners.interactive.interactive_beam.collect',
+               lambda _: None):
+      found = find_pcolls(
+          """SELECT * FROM pcoll_1 JOIN pcoll_2
+          USING (common_column)""", {
+              'pcoll_1': None, 'pcoll_2': None
+          })
+      self.assertIn('pcoll_1', found)
+      self.assertIn('pcoll_2', found)
+
+  def test_replace_single_pcoll_token(self):
+    sql = 'SELECT * FROM abc WHERE a=1 AND b=2'
+    # A name that does not follow FROM leaves the query unchanged.
+    replaced_sql = replace_single_pcoll_token(sql, 'wow')
+    self.assertEqual(replaced_sql, sql)
+    replaced_sql = replace_single_pcoll_token(sql, 'abc')
+    self.assertEqual(
+        replaced_sql, 'SELECT * FROM PCOLLECTION WHERE a=1 AND b=2')
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py
index 3e85145..cb0b7db 100644
--- a/sdks/python/apache_beam/runners/interactive/utils.py
+++ b/sdks/python/apache_beam/runners/interactive/utils.py
@@ -34,6 +34,15 @@ from apache_beam.typehints.schemas import named_fields_from_element_type
 
 _LOGGER = logging.getLogger(__name__)
 
+# Add line breaks to the IPythonLogHandler's HTML output.
+_INTERACTIVE_LOG_STYLE = """
+  <style>
+    div.alert {
+      white-space: pre-line;
+    }
+  </style>
+"""
+
 
 def to_element_list(
     reader,  # type: Generator[Union[TestStreamPayload.Event, WindowedValueHolder]]
@@ -169,6 +178,7 @@ class IPythonLogHandler(logging.Handler):
       from html import escape
       from IPython.core.display import HTML
       from IPython.core.display import display
+      display(HTML(_INTERACTIVE_LOG_STYLE))
       display(
           HTML(
               self.log_template.format(
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index 6b4b344..fb0c415 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -218,6 +218,9 @@ ignore_identifiers = [
   'google.cloud.datastore.batch.Batch',
   'is_in_ipython',
   'doctest.TestResults',
+
+  # IPython Magics py:class reference target not found
+  'IPython.core.magic.Magics',
 ]
 ignore_references = [
   'BeamIOError',
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 170fbff..83826e1 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -201,7 +201,7 @@ GCP_REQUIREMENTS = [
 
 INTERACTIVE_BEAM = [
     'facets-overview>=1.0.0,<2',
-    'ipython>=5.8.0,<8',
+    'ipython>=7,<8',
     'ipykernel>=5.2.0,<6',
     # Skip version 6.1.13 due to
     # https://github.com/jupyter/jupyter_client/issues/637

Reply via email to