This is an automated email from the ASF dual-hosted git repository.

ningk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 8c0601c [BEAM-10708] Enable submit beam_sql built jobs to Dataflow new bed6bee Merge pull request #15647 from KevinGG/beam_sql_on_df 8c0601c is described below commit 8c0601cca39f8350be51084bb98ea53f90c5466a Author: KevinGG <kawai...@gmail.com> AuthorDate: Fri Oct 1 11:51:14 2021 -0700 [BEAM-10708] Enable submit beam_sql built jobs to Dataflow 1. Added an additional beam_sql option to specify a runner. 2. Added a sql_chain module to track chained beam_sql magics and produce pipelines for execution when a non-direct runner is specified. 3. Added logic to load schemas defined in main session without relying on save_main_session that might fail. 4. Added a OptionsForm class and DataflowOptionsForm subclass to guide users through pipeline options configuration in notebooks. 5. Removed is_namedtuple utility and honor the Beam common utility match_is_named_tuple. Note dill does not preserve __annotations__ across multiple main sessions. Added a workaround until cloudpickle replaces dill in Beam. --- .../runners/interactive/interactive_environment.py | 17 + .../interactive/interactive_environment_test.py | 32 ++ .../runners/interactive/sql/beam_sql_magics.py | 187 ++++++++--- .../interactive/sql/beam_sql_magics_test.py | 28 +- .../runners/interactive/sql/sql_chain.py | 226 +++++++++++++ .../runners/interactive/sql/sql_chain_test.py | 109 +++++++ .../apache_beam/runners/interactive/sql/utils.py | 354 +++++++++++++++++++-- .../runners/interactive/sql/utils_test.py | 50 ++- .../apache_beam/runners/interactive/utils.py | 22 ++ .../apache_beam/runners/interactive/utils_test.py | 8 + sdks/python/setup.py | 1 + 11 files changed, 941 insertions(+), 93 deletions(-) diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py index fe10ab1..9f3d66e 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py @@ -36,6 +36,7 @@ from apache_beam.runners import runner from apache_beam.runners.interactive import cache_manager as cache from apache_beam.runners.interactive.messaging.interactive_environment_inspector import InteractiveEnvironmentInspector from apache_beam.runners.interactive.recording_manager import RecordingManager +from apache_beam.runners.interactive.sql.sql_chain import SqlChain from apache_beam.runners.interactive.user_pipeline_tracker import UserPipelineTracker from apache_beam.runners.interactive.utils import register_ipython_log_handler from apache_beam.utils.interactive_utils import is_in_ipython @@ -206,6 +207,8 @@ class InteractiveEnvironment(object): self._inspector_with_synthetic = InteractiveEnvironmentInspector( ignore_synthetic=False) + self.sql_chain = {} + @property def options(self): """A reference to the global interactive options. @@ -651,3 +654,17 @@ class InteractiveEnvironment(object): Javascript(_HTML_IMPORT_TEMPLATE.format(hrefs=html_hrefs))) except ImportError: pass # NOOP if dependencies are not available. 
+ + def get_sql_chain(self, pipeline, set_user_pipeline=False): + if pipeline not in self.sql_chain: + self.sql_chain[pipeline] = SqlChain() + chain = self.sql_chain[pipeline] + if set_user_pipeline: + if chain.user_pipeline and chain.user_pipeline is not pipeline: + raise ValueError( + 'The beam_sql magic tries to query PCollections from multiple ' + 'pipelines: %s and %s', + chain.user_pipeline, + pipeline) + chain.user_pipeline = pipeline + return chain diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py index f08db01..4e0293d 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py @@ -27,6 +27,7 @@ from apache_beam.runners import runner from apache_beam.runners.interactive import cache_manager as cache from apache_beam.runners.interactive import interactive_environment as ie from apache_beam.runners.interactive.recording_manager import RecordingManager +from apache_beam.runners.interactive.sql.sql_chain import SqlNode # The module name is also a variable in module. _module_name = 'apache_beam.runners.interactive.interactive_environment_test' @@ -303,6 +304,37 @@ class InteractiveEnvironmentTest(unittest.TestCase): expected_description = {p1: rm1.describe(), p2: rm2.describe()} self.assertDictEqual(description, expected_description) + def test_get_empty_sql_chain(self): + env = ie.InteractiveEnvironment() + p = beam.Pipeline() + chain = env.get_sql_chain(p) + self.assertIsNotNone(chain) + self.assertEqual(chain.nodes, {}) + + def test_get_sql_chain_with_nodes(self): + env = ie.InteractiveEnvironment() + p = beam.Pipeline() + chain_with_node = env.get_sql_chain(p).append( + SqlNode(output_name='name', source=p, query="query")) + chain_got = env.get_sql_chain(p) + self.assertIs(chain_with_node, chain_got) + + def test_get_sql_chain_setting_user_pipeline(self): + env = ie.InteractiveEnvironment() + p = beam.Pipeline() + chain = env.get_sql_chain(p, set_user_pipeline=True) + self.assertIs(chain.user_pipeline, p) + + def test_get_sql_chain_None_when_setting_multiple_user_pipelines(self): + env = ie.InteractiveEnvironment() + p = beam.Pipeline() + chain = env.get_sql_chain(p, set_user_pipeline=True) + p2 = beam.Pipeline() + # Set the chain for a different pipeline. 
+ env.sql_chain[p2] = chain + with self.assertRaises(ValueError): + env.get_sql_chain(p2, set_user_pipeline=True) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py index bd40f13..d27fc61 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py +++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py @@ -32,24 +32,27 @@ from typing import Union import apache_beam as beam from apache_beam.pvalue import PValue -from apache_beam.runners.interactive import interactive_beam as ib from apache_beam.runners.interactive import interactive_environment as ie from apache_beam.runners.interactive.background_caching_job import has_source_to_cache from apache_beam.runners.interactive.caching.cacheable import CacheKey from apache_beam.runners.interactive.caching.reify import reify_to_cache from apache_beam.runners.interactive.caching.reify import unreify_from_cache from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll +from apache_beam.runners.interactive.sql.sql_chain import SqlChain +from apache_beam.runners.interactive.sql.sql_chain import SqlNode +from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm from apache_beam.runners.interactive.sql.utils import find_pcolls -from apache_beam.runners.interactive.sql.utils import is_namedtuple from apache_beam.runners.interactive.sql.utils import pformat_namedtuple from apache_beam.runners.interactive.sql.utils import register_coder_for_schema from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token +from apache_beam.runners.interactive.utils import create_var_in_main from apache_beam.runners.interactive.utils import obfuscate from apache_beam.runners.interactive.utils import pcoll_by_name from apache_beam.runners.interactive.utils import progress_indicated from apache_beam.testing import test_stream from apache_beam.testing.test_stream_service import TestStreamServiceController from apache_beam.transforms.sql import SqlTransform +from apache_beam.typehints.native_type_compatibility import match_is_named_tuple from IPython.core.magic import Magics from IPython.core.magic import line_cell_magic from IPython.core.magic import magics_class @@ -58,11 +61,11 @@ _LOGGER = logging.getLogger(__name__) _EXAMPLE_USAGE = """beam_sql magic to execute Beam SQL in notebooks --------------------------------------------------------- -%%beam_sql [-o OUTPUT_NAME] query +%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query --------------------------------------------------------- Or --------------------------------------------------------- -%%%%beam_sql [-o OUTPUT_NAME] query-line#1 +%%%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query-line#1 query-line#2 ... query-line#N @@ -82,6 +85,8 @@ _NOT_SUPPORTED_MSG = """The query was valid and successfully applied. to build Beam pipelines in a non-interactive manner. """ +_SUPPORTED_RUNNERS = ['DirectRunner', 'DataflowRunner'] + class BeamSqlParser: """A parser to parse beam_sql inputs.""" @@ -100,6 +105,14 @@ class BeamSqlParser: action='store_true', help='Display more details about the magic execution.') self._parser.add_argument( + '-r', + '--runner', + dest='runner', + help=( + 'The runner to run the query. Supported runners are %s. If not ' + 'provided, DirectRunner is used and results can be inspected ' + 'locally.' 
% _SUPPORTED_RUNNERS)) + self._parser.add_argument( 'query', type=str, nargs='*', @@ -157,8 +170,9 @@ class BeamSqlMagics(Magics): cell: everything else in the same notebook cell as a string. If None, beam_sql is used as line magic. Otherwise, cell magic. - Returns None if running into an error, otherwise a PValue as if a - SqlTransform is applied. + Returns None if running into an error or waiting for user input (running on + a selected runner remotely), otherwise a PValue as if a SqlTransform is + applied. """ input_str = line if cell: @@ -170,6 +184,7 @@ class BeamSqlMagics(Magics): output_name = parsed.output_name verbose = parsed.verbose query = parsed.query + runner = parsed.runner if output_name and not output_name.isidentifier() or keyword.iskeyword( output_name): @@ -181,11 +196,18 @@ class BeamSqlMagics(Magics): if not query: on_error('Please supply the SQL query to be executed.') return + if runner and runner not in _SUPPORTED_RUNNERS: + on_error( + 'Runner "%s" is not supported. Supported runners are %s.', + runner, + _SUPPORTED_RUNNERS) query = ' '.join(query) found = find_pcolls(query, pcoll_by_name(), verbose=verbose) + schemas = set() + main_session = importlib.import_module('__main__') for _, pcoll in found.items(): - if not is_namedtuple(pcoll.element_type): + if not match_is_named_tuple(pcoll.element_type): on_error( 'PCollection %s of type %s is not a NamedTuple. See ' 'https://beam.apache.org/documentation/programming-guide/#schemas ' @@ -194,45 +216,93 @@ class BeamSqlMagics(Magics): pcoll.element_type) return register_coder_for_schema(pcoll.element_type, verbose=verbose) + # Only care about schemas defined by the user in the main module. + if hasattr(main_session, pcoll.element_type.__name__): + schemas.add(pcoll.element_type) + + if runner in ('DirectRunner', None): + collect_data_for_local_run(query, found) + output_name, output, chain = apply_sql(query, output_name, found) + chain.current.schemas = schemas + cache_output(output_name, output) + return output + + output_name, current_node, chain = apply_sql( + query, output_name, found, False) + current_node.schemas = schemas + # TODO(BEAM-10708): Move the options setup and result handling to a + # separate module when more runners are supported. + if runner == 'DataflowRunner': + _ = chain.to_pipeline() + _ = DataflowOptionsForm( + output_name, pcoll_by_name()[output_name], + verbose).display_for_input() + return None + else: + raise ValueError('Unsupported runner %s.', runner) - output_name, output = apply_sql(query, output_name, found) - cache_output(output_name, output) - return output + +@progress_indicated +def collect_data_for_local_run(query: str, found: Dict[str, beam.PCollection]): + from apache_beam.runners.interactive import interactive_beam as ib + for name, pcoll in found.items(): + try: + _ = ib.collect(pcoll) + except (KeyboardInterrupt, SystemExit): + raise + except: + _LOGGER.error( + 'Cannot collect data for PCollection %s. Please make sure the ' + 'PCollections queried in the sql "%s" are all from a single ' + 'pipeline using an InteractiveRunner. 
Make sure there is no ' + 'ambiguity, for example, same named PCollections from multiple ' + 'pipelines or notebook re-executions.', + name, + query) + raise @progress_indicated def apply_sql( - query: str, output_name: Optional[str], - found: Dict[str, beam.PCollection]) -> Tuple[str, PValue]: + query: str, + output_name: Optional[str], + found: Dict[str, beam.PCollection], + run: bool = True) -> Tuple[str, Union[PValue, SqlNode], SqlChain]: """Applies a SqlTransform with the given sql and queried PCollections. Args: query: The SQL query executed in the magic. output_name: (optional) The output variable name in __main__ module. found: The PCollections with variable names found to be used in the query. + run: Whether to prepare the SQL pipeline for a local run or not. Returns: - A Tuple[str, PValue]. First str value is the output variable name in - __main__ module (auto-generated if not provided). Second PValue is - most likely a PCollection, depending on the query. + A tuple of values. First str value is the output variable name in + __main__ module, auto-generated if not provided. Second value: if run, + it's a PValue; otherwise, a SqlNode tracks the SQL without applying it or + executing it. Third value: SqlChain is a chain of SqlNodes that have been + applied. """ output_name = _generate_output_name(output_name, query, found) - query, sql_source = _build_query_components(query, found) - try: - output = sql_source | SqlTransform(query) - # Declare a variable with the output_name and output value in the - # __main__ module so that the user can use the output smoothly. - setattr(importlib.import_module('__main__'), output_name, output) - ib.watch({output_name: output}) - _LOGGER.info( - "The output PCollection variable is %s with element_type %s", - output_name, - pformat_namedtuple(output.element_type)) - return output_name, output - except (KeyboardInterrupt, SystemExit): - raise - except Exception as e: - on_error('Error when applying the Beam SQL: %s', e) + query, sql_source, chain = _build_query_components( + query, found, output_name, run) + if run: + try: + output = sql_source | SqlTransform(query) + # Declare a variable with the output_name and output value in the + # __main__ module so that the user can use the output smoothly. + output_name, output = create_var_in_main(output_name, output) + _LOGGER.info( + "The output PCollection variable is %s with element_type %s", + output_name, + pformat_namedtuple(output.element_type)) + return output_name, output, chain + except (KeyboardInterrupt, SystemExit): + raise + except Exception as e: + on_error('Error when applying the Beam SQL: %s', e) + else: + return output_name, chain.current, chain def pcolls_from_streaming_cache( @@ -304,19 +374,26 @@ def _generate_output_name( def _build_query_components( - query: str, found: Dict[str, beam.PCollection] + query: str, + found: Dict[str, beam.PCollection], + output_name: str, + run: bool = True ) -> Tuple[str, - Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline]]: + Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline], + SqlChain]: """Builds necessary components needed to apply the SqlTransform. Args: query: The SQL query to be executed by the magic. found: The PCollections with variable names found to be used by the query. + output_name: The output variable name in __main__ module. + run: Whether to prepare components for a local run or not. 
Returns: - The processed query to be executed by the magic and a source to apply the + The processed query to be executed by the magic; a source to apply the SqlTransform to: a dictionary of tagged PCollections, or a single - PCollection, or the pipeline to execute the query. + PCollection, or the pipeline to execute the query; the chain of applied + beam_sql magics this one belongs to. """ if found: user_pipeline = ie.current_env().user_pipeline( @@ -324,26 +401,38 @@ def _build_query_components( sql_pipeline = beam.Pipeline(options=user_pipeline._options) ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline) sql_source = {} - if has_source_to_cache(user_pipeline): - sql_source = pcolls_from_streaming_cache( - user_pipeline, sql_pipeline, found) + if run: + if has_source_to_cache(user_pipeline): + sql_source = pcolls_from_streaming_cache( + user_pipeline, sql_pipeline, found) + else: + cache_manager = ie.current_env().get_cache_manager( + user_pipeline, create_if_absent=True) + for pcoll_name, pcoll in found.items(): + cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str() + sql_source[pcoll_name] = unreify_from_cache( + pipeline=sql_pipeline, + cache_key=cache_key, + cache_manager=cache_manager, + element_type=pcoll.element_type) else: - cache_manager = ie.current_env().get_cache_manager( - user_pipeline, create_if_absent=True) - for pcoll_name, pcoll in found.items(): - cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str() - sql_source[pcoll_name] = unreify_from_cache( - pipeline=sql_pipeline, - cache_key=cache_key, - cache_manager=cache_manager, - element_type=pcoll.element_type) + sql_source = found if len(sql_source) == 1: query = replace_single_pcoll_token(query, next(iter(sql_source.keys()))) sql_source = next(iter(sql_source.values())) - else: + + node = SqlNode( + output_name=output_name, source=set(found.keys()), query=query) + chain = ie.current_env().get_sql_chain( + user_pipeline, set_user_pipeline=True).append(node) + else: # does not query any existing PCollection sql_source = beam.Pipeline() ie.current_env().add_user_pipeline(sql_source) - return query, sql_source + + # The node should be the root node of the chain created below. 
+ node = SqlNode(output_name=output_name, source=sql_source, query=query) + chain = ie.current_env().get_sql_chain(sql_source).append(node) + return query, sql_source, chain @progress_indicated diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py index 538abbb..3d843a0 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py +++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py @@ -59,9 +59,13 @@ class BeamSqlMagicsTest(unittest.TestCase): query = """SELECT CAST(1 AS INT) AS `id`, CAST('foo' AS VARCHAR) AS `str`, CAST(3.14 AS DOUBLE) AS `flt`""" - processed_query, sql_source = _build_query_components(query, {}) + processed_query, sql_source, chain = _build_query_components( + query, {}, 'output') self.assertEqual(processed_query, query) self.assertIsInstance(sql_source, beam.Pipeline) + self.assertIsInstance(chain.current.source, beam.Pipeline) + self.assertEqual('output', chain.current.output_name) + self.assertEqual(query, chain.current.query) def test_build_query_components_when_single_pcoll_queried(self): p = beam.Pipeline() @@ -76,10 +80,14 @@ class BeamSqlMagicsTest(unittest.TestCase): cache_key, cache_manager, element_type: target): - processed_query, sql_source = _build_query_components(query, found) - - self.assertEqual(processed_query, 'SELECT * FROM PCOLLECTION where a=1') + processed_query, sql_source, chain = _build_query_components( + query, found, 'output') + expected_query = 'SELECT * FROM PCOLLECTION where a=1' + self.assertEqual(expected_query, processed_query) self.assertIsInstance(sql_source, beam.PCollection) + self.assertIn('target', chain.current.source) + self.assertEqual(expected_query, chain.current.query) + self.assertEqual('output', chain.current.output_name) def test_build_query_components_when_multiple_pcolls_queried(self): p = beam.Pipeline() @@ -95,12 +103,17 @@ class BeamSqlMagicsTest(unittest.TestCase): cache_key, cache_manager, element_type: pcoll_1): - processed_query, sql_source = _build_query_components(query, found) + processed_query, sql_source, chain = _build_query_components( + query, found, 'output') self.assertEqual(processed_query, query) self.assertIsInstance(sql_source, dict) self.assertIn('pcoll_1', sql_source) self.assertIn('pcoll_2', sql_source) + self.assertIn('pcoll_1', chain.current.source) + self.assertIn('pcoll_2', chain.current.source) + self.assertEqual(query, chain.current.query) + self.assertEqual('output', chain.current.output_name) def test_build_query_components_when_unbounded_pcolls_queried(self): p = beam.Pipeline() @@ -115,8 +128,11 @@ class BeamSqlMagicsTest(unittest.TestCase): lambda a, b, c: found): - _, sql_source = _build_query_components(query, found) + _, sql_source, chain = _build_query_components(query, found, 'output') self.assertIs(sql_source, pcoll) + self.assertIn('pcoll', chain.current.source) + self.assertEqual('SELECT * FROM PCOLLECTION', chain.current.query) + self.assertEqual('output', chain.current.output_name) def test_cache_output(self): p_cache_output = beam.Pipeline() diff --git a/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py b/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py new file mode 100644 index 0000000..a6f4866 --- /dev/null +++ b/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py @@ -0,0 +1,226 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Module for tracking a chain of beam_sql magics applied. + +For internal use only; no backwards-compatibility guarantees. +""" + +# pytype: skip-file + +import importlib +import logging +from dataclasses import dataclass +from typing import Any +from typing import Dict +from typing import Optional +from typing import Set +from typing import Union + +import apache_beam as beam +from apache_beam.internal import pickler +from apache_beam.runners.interactive.sql.utils import register_coder_for_schema +from apache_beam.runners.interactive.utils import create_var_in_main +from apache_beam.runners.interactive.utils import pcoll_by_name +from apache_beam.runners.interactive.utils import progress_indicated +from apache_beam.transforms.sql import SqlTransform +from apache_beam.utils.interactive_utils import is_in_ipython + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class SqlNode: + """Each SqlNode represents a beam_sql magic applied. + + Attributes: + output_name: the watched unique name of the beam_sql output. Can be used as + an identifier. + source: the inputs consumed by this node. Can be a pipeline or a set of + PCollections represented by their variable names watched. When it's a + pipeline, the node computes from raw values in the query, so the output + can be consumed by any SqlNode in any SqlChain. + query: the SQL query applied by this node. + schemas: the schemas (NamedTuple classes) used by this node. + evaluated: the pipelines this node has been evaluated for. + next: the next SqlNode applied chronologically. + execution_count: the execution count if in an IPython env. + """ + output_name: str + source: Union[beam.Pipeline, Set[str]] + query: str + schemas: Set[Any] = None + evaluated: Set[beam.Pipeline] = None + next: Optional['SqlNode'] = None + execution_count: int = 0 + + def __post_init__(self): + if not self.schemas: + self.schemas = set() + if not self.evaluated: + self.evaluated = set() + if is_in_ipython(): + from IPython import get_ipython + self.execution_count = get_ipython().execution_count + + def __hash__(self): + return hash( + (self.output_name, self.source, self.query, self.execution_count)) + + def to_pipeline(self, pipeline: Optional[beam.Pipeline]) -> beam.Pipeline: + """Converts the chain into an executable pipeline.""" + if pipeline not in self.evaluated: + # The whole chain should form a single pipeline. 
+ source = self.source + if isinstance(self.source, beam.Pipeline): + if pipeline: # use the known pipeline + source = pipeline + else: # use the source pipeline + pipeline = self.source + else: + name_to_pcoll = pcoll_by_name() + if len(self.source) == 1: + source = name_to_pcoll.get(next(iter(self.source))) + else: + source = {s: name_to_pcoll.get(s) for s in self.source} + if isinstance(source, beam.Pipeline): + output = source | 'beam_sql_{}_{}'.format( + self.output_name, self.execution_count) >> SqlTransform(self.query) + else: + output = source | 'schema_loaded_beam_sql_{}_{}'.format( + self.output_name, self.execution_count + ) >> SchemaLoadedSqlTransform( + self.output_name, self.query, self.schemas, self.execution_count) + _ = create_var_in_main(self.output_name, output) + self.evaluated.add(pipeline) + if self.next: + return self.next.to_pipeline(pipeline) + else: + return pipeline + + +class SchemaLoadedSqlTransform(beam.PTransform): + """PTransform that loads schema before executing SQL. + + When submitting a pipeline to remote runner for execution, schemas defined in + the main module are not available without save_main_session. However, + save_main_session might fail when there is anything unpicklable. This DoFn + makes sure only the schemas needed are pickled locally and restored later on + workers. + """ + def __init__(self, output_name, query, schemas, execution_count): + self.output_name = output_name + self.query = query + self.schemas = schemas + self.execution_count = execution_count + # TODO(BEAM-8123): clean up this attribute or the whole wrapper PTransform. + # Dill does not preserve everything. On the other hand, save_main_session + # is not stable. Until cloudpickle replaces dill in Beam, we work around + # it by explicitly pickling annotations and load schemas in remote main + # sessions. + self.schema_annotations = [s.__annotations__ for s in self.schemas] + + class _SqlTransformDoFn(beam.DoFn): + """The DoFn yields all its input without any transform but a setup to + configure the main session.""" + def __init__(self, schemas, annotations): + self.pickled_schemas = [pickler.dumps(s) for s in schemas] + self.pickled_annotations = [pickler.dumps(a) for a in annotations] + + def setup(self): + main_session = importlib.import_module('__main__') + for pickled_schema, pickled_annotation in zip( + self.pickled_schemas, self.pickled_annotations): + schema = pickler.loads(pickled_schema) + schema.__annotations__ = pickler.loads(pickled_annotation) + if not hasattr(main_session, schema.__name__) or not hasattr( + getattr(main_session, schema.__name__), '__annotations__'): + # Restore the schema in the main session on the [remote] worker. + setattr(main_session, schema.__name__, schema) + register_coder_for_schema(schema) + + def process(self, e): + yield e + + def expand(self, source): + """Applies the SQL transform. 
If a PCollection uses a schema defined in + the main session, use the additional DoFn to restore it on the worker.""" + if isinstance(source, dict): + schema_loaded = { + tag: pcoll | 'load_schemas_{}_tag_{}_{}'.format( + self.output_name, tag, self.execution_count) >> beam.ParDo( + self._SqlTransformDoFn(self.schemas, self.schema_annotations)) + if pcoll.element_type in self.schemas else pcoll + for tag, + pcoll in source.items() + } + elif isinstance(source, beam.pvalue.PCollection): + schema_loaded = source | 'load_schemas_{}_{}'.format( + self.output_name, self.execution_count) >> beam.ParDo( + self._SqlTransformDoFn(self.schemas, self.schema_annotations) + ) if source.element_type in self.schemas else source + else: + raise ValueError( + '{} should be either a single PCollection or a dict of named ' + 'PCollections.'.format(source)) + return schema_loaded | 'beam_sql_{}_{}'.format( + self.output_name, self.execution_count) >> SqlTransform(self.query) + + +@dataclass +class SqlChain: + """A chain of SqlNodes. + + Attributes: + nodes: all nodes by their output_names. + root: the first SqlNode applied chronologically. + current: the last node applied. + user_pipeline: the user defined pipeline this chain originates from. If + None, the whole chain just computes from raw values in queries. + Otherwise, at least some of the nodes in chain has queried against + PCollections. + """ + nodes: Dict[str, SqlNode] = None + root: Optional[SqlNode] = None + current: Optional[SqlNode] = None + user_pipeline: Optional[beam.Pipeline] = None + + def __post_init__(self): + if not self.nodes: + self.nodes = {} + + @progress_indicated + def to_pipeline(self) -> beam.Pipeline: + """Converts the chain into a beam pipeline.""" + pipeline_to_execute = self.root.to_pipeline(self.user_pipeline) + # The pipeline definitely contains external transform: SqlTransform. + pipeline_to_execute.contains_external_transforms = True + return pipeline_to_execute + + def append(self, node: SqlNode) -> 'SqlChain': + """Appends a node to the chain.""" + if self.current: + self.current.next = node + else: + self.root = node + self.current = node + self.nodes[node.output_name] = node + return self + + def get(self, output_name: str) -> Optional[SqlNode]: + """Gets a node from the chain based on the given output_name.""" + return self.nodes.get(output_name, None) diff --git a/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py b/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py new file mode 100644 index 0000000..42d0804 --- /dev/null +++ b/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Tests for sql_chain module.""" + +# pytype: skip-file + +import unittest +from unittest.mock import patch + +import pytest + +import apache_beam as beam +from apache_beam.runners.interactive import interactive_environment as ie +from apache_beam.runners.interactive.sql.sql_chain import SqlChain +from apache_beam.runners.interactive.sql.sql_chain import SqlNode +from apache_beam.runners.interactive.testing.mock_ipython import mock_get_ipython + + +class SqlChainTest(unittest.TestCase): + def test_init(self): + chain = SqlChain() + self.assertEqual({}, chain.nodes) + self.assertIsNone(chain.root) + self.assertIsNone(chain.current) + self.assertIsNone(chain.user_pipeline) + + def test_append_first_node(self): + node = SqlNode(output_name='first', source='a', query='q1') + chain = SqlChain().append(node) + self.assertIs(node, chain.get(node.output_name)) + self.assertIs(node, chain.root) + self.assertIs(node, chain.current) + + def test_append_non_root_node(self): + chain = SqlChain().append( + SqlNode(output_name='root', source='root', query='q1')) + self.assertIsNone(chain.root.next) + node = SqlNode(output_name='next_node', source='root', query='q2') + chain.append(node) + self.assertIs(node, chain.root.next) + self.assertIs(node, chain.get(node.output_name)) + + @patch( + 'apache_beam.runners.interactive.sql.sql_chain.SchemaLoadedSqlTransform.' + '__rrshift__') + def test_to_pipeline_only_evaluate_once_per_pipeline_and_node( + self, mocked_sql_transform): + p = beam.Pipeline() + ie.current_env().watch({'p': p}) + pcoll_1 = p | 'create pcoll_1' >> beam.Create([1, 2, 3]) + pcoll_2 = p | 'create pcoll_2' >> beam.Create([4, 5, 6]) + ie.current_env().watch({'pcoll_1': pcoll_1, 'pcoll_2': pcoll_2}) + node = SqlNode( + output_name='root', source={'pcoll_1', 'pcoll_2'}, query='q1') + chain = SqlChain(user_pipeline=p).append(node) + _ = chain.to_pipeline() + mocked_sql_transform.assert_called_once() + _ = chain.to_pipeline() + mocked_sql_transform.assert_called_once() + + @unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') + @pytest.mark.skipif( + not ie.current_env().is_interactive_ready, + reason='[interactive] dependency is not installed.') + @patch( + 'apache_beam.runners.interactive.sql.sql_chain.SchemaLoadedSqlTransform.' 
+ '__rrshift__') + def test_nodes_with_same_outputs(self, mocked_sql_transform): + p = beam.Pipeline() + ie.current_env().watch({'p_nodes_with_same_output': p}) + pcoll = p | 'create pcoll' >> beam.Create([1, 2, 3]) + ie.current_env().watch({'pcoll': pcoll}) + chain = SqlChain(user_pipeline=p) + output_name = 'output' + + with patch('IPython.get_ipython', new_callable=mock_get_ipython) as cell: + with cell: + node_cell_1 = SqlNode(output_name, source='pcoll', query='q1') + chain.append(node_cell_1) + _ = chain.to_pipeline() + mocked_sql_transform.assert_called_with( + 'schema_loaded_beam_sql_output_1') + with cell: + node_cell_2 = SqlNode(output_name, source='pcoll', query='q2') + chain.append(node_cell_2) + _ = chain.to_pipeline() + mocked_sql_transform.assert_called_with( + 'schema_loaded_beam_sql_output_2') + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils.py b/sdks/python/apache_beam/runners/interactive/sql/utils.py index fb4e57d..b2e75c8 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/utils.py +++ b/sdks/python/apache_beam/runners/interactive/sql/utils.py @@ -23,29 +23,39 @@ For internal use only; no backward-compatibility guarantees. # pytype: skip-file import logging +import os +import tempfile +from dataclasses import dataclass +from typing import Any +from typing import Callable from typing import Dict from typing import NamedTuple +from typing import Optional +from typing import Type +from typing import Union import apache_beam as beam -from apache_beam.runners.interactive import interactive_beam as ib +from apache_beam.io import WriteToText +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import WorkerOptions +from apache_beam.runners.interactive.utils import create_var_in_main +from apache_beam.runners.interactive.utils import progress_indicated +from apache_beam.runners.runner import create_runner +from apache_beam.typehints.native_type_compatibility import match_is_named_tuple +from apache_beam.utils.interactive_utils import is_in_ipython _LOGGER = logging.getLogger(__name__) -def is_namedtuple(cls: type) -> bool: - """Determines if a class is built from typing.NamedTuple.""" - return ( - isinstance(cls, type) and issubclass(cls, tuple) and - hasattr(cls, '_fields') and hasattr(cls, '__annotations__')) - - def register_coder_for_schema( schema: NamedTuple, verbose: bool = False) -> None: """Registers a RowCoder for the given schema if hasn't. Notifies the user of what code has been implicitly executed. """ - assert is_namedtuple(schema), ( + assert match_is_named_tuple(schema), ( 'Schema %s is not a typing.NamedTuple.' % schema) coder = beam.coders.registry.get_coder(schema) if not isinstance(coder, beam.coders.RowCoder): @@ -77,21 +87,6 @@ def find_pcolls( if verbose: _LOGGER.info('Found PCollections used in the magic: %s.', found) _LOGGER.info('Collecting data...') - for name, pcoll in found.items(): - try: - _ = ib.collect(pcoll) - except (KeyboardInterrupt, SystemExit): - raise - except: - _LOGGER.error( - 'Cannot collect data for PCollection %s. Please make sure the ' - 'PCollections queried in the sql "%s" are all from a single ' - 'pipeline using an InteractiveRunner. 
Make sure there is no ' - 'ambiguity, for example, same named PCollections from multiple ' - 'pipelines or notebook re-executions.', - name, - sql) - raise return found @@ -123,3 +118,314 @@ def pformat_namedtuple(schema: NamedTuple) -> str: '{}: {}'.format(k, v.__name__) for k, v in schema.__annotations__.items() ])) + + +def pformat_dict(raw_input: Dict[str, Any]) -> str: + return '{{\n{}\n}}'.format( + ',\n'.join(['{}: {}'.format(k, v) for k, v in raw_input.items()])) + + +@dataclass +class OptionsEntry: + """An entry of PipelineOptions that can be visualized through ipywidgets to + take inputs in IPython notebooks interactively. + + Attributes: + label: The value of the Label widget. + help: The help message of the entry, usually the same to the help in + PipelineOptions. + cls: The PipelineOptions class/subclass the options belong to. + arg_builder: Builds the argument/option. If it's a str, this entry + assigns the input ipywidget's value directly to the argument. If it's a + Dict, use the corresponding Callable to assign the input value to each + argument. If Callable is None, fallback to assign the input value + directly. This allows building multiple similar PipelineOptions + arguments from a single input, such as staging_location and + temp_location in GoogleCloudOptions. + default: The default value of the entry, None if absent. + """ + label: str + help: str + cls: Type[PipelineOptions] + arg_builder: Union[str, Dict[str, Optional[Callable]]] + default: Optional[str] = None + + def __post_init__(self): + # The attribute holds an ipywidget, currently only supports Text. + # The str value can be accessed by self.input.value. + self.input = None + + +class OptionsForm: + """A form visualized to take inputs from users in IPython Notebooks and + generate PipelineOptions to run pipelines. + """ + def __init__(self): + self.options = PipelineOptions() + self.entries = [] + + def add(self, entry: OptionsEntry) -> 'OptionsForm': + """Adds an OptionsEntry to the form. + """ + self.entries.append(entry) + return self + + def to_options(self) -> PipelineOptions: + """Builds the PipelineOptions based on user inputs. + + Can only be invoked after display_for_input. + """ + for entry in self.entries: + assert entry.input, ( + 'to_options invoked before display_for_input. 
' + 'Wrong usage.') + view = self.options.view_as(entry.cls) + if isinstance(entry.arg_builder, str): + setattr(view, entry.arg_builder, entry.input.value) + else: + for arg, builder in entry.arg_builder.items(): + if builder: + setattr(view, arg, builder(entry.input.value)) + else: + setattr(view, arg, entry.input.value) + self.additional_options() + return self.options + + def additional_options(self): + """Alters the self.options with additional config.""" + pass + + def display_for_input(self) -> 'OptionsForm': + """Displays the widgets to take user inputs.""" + from IPython.display import display + from ipywidgets import GridBox + from ipywidgets import Label + from ipywidgets import Layout + from ipywidgets import Text + widgets = [] + for entry in self.entries: + text_label = Label(value=entry.label) + text_input = entry.input if entry.input else Text( + value=entry.default if entry.default else '') + text_help = Label(value=entry.help) + entry.input = text_input + widgets.append(text_label) + widgets.append(text_input) + widgets.append(text_help) + grid = GridBox(widgets, layout=Layout(grid_template_columns='1fr 2fr 6fr')) + display(grid) + self.display_actions() + return self + + def display_actions(self): + """Displays actionable widgets to utilize the options, run pipelines and + etc.""" + pass + + +class DataflowOptionsForm(OptionsForm): + """A form to take inputs from users in IPython Notebooks to build + PipelineOptions to run pipelines on Dataflow. + + Only contains minimum fields needed. + """ + @staticmethod + def _build_default_project() -> str: + """Builds a default project id.""" + try: + # pylint: disable=c-extension-no-member + import google.auth + return google.auth.default()[1] + except (KeyboardInterrupt, SystemExit): + raise + except Exception as e: + _LOGGER.warning('There is some issue with your gcloud auth: %s', e) + return 'your-project-id' + + @staticmethod + def _build_req_file_from_pkgs(pkgs) -> Optional[str]: + """Builds a requirements file that contains all additional PYPI packages + needed.""" + if pkgs: + deps = pkgs.split(',') + req_file = os.path.join( + tempfile.mkdtemp(prefix='beam-sql-dataflow-'), 'req.txt') + with open(req_file, 'a') as f: + for dep in deps: + f.write(dep.strip() + '\n') + return req_file + return None + + def __init__( + self, + output_name: str, + output_pcoll: beam.PCollection, + verbose: bool = False): + """Inits the OptionsForm for setting up Dataflow jobs.""" + super().__init__() + self.p = output_pcoll.pipeline + self.output_name = output_name + self.output_pcoll = output_pcoll + self.verbose = verbose + self.notice_shown = False + self.add( + OptionsEntry( + label='Project Id', + help='Name of the Cloud project owning the Dataflow job.', + cls=GoogleCloudOptions, + arg_builder='project', + default=DataflowOptionsForm._build_default_project()) + ).add( + OptionsEntry( + label='Region', + help='The Google Compute Engine region for creating Dataflow job.', + cls=GoogleCloudOptions, + arg_builder='region', + default='us-central1') + ).add( + OptionsEntry( + label='GCS Bucket', + help=( + 'GCS path to stage code packages needed by workers and save ' + 'temporary workflow jobs.'), + cls=GoogleCloudOptions, + arg_builder={ + 'staging_location': lambda x: x + '/staging', + 'temp_location': lambda x: x + '/temp' + }, + default='gs://YOUR_GCS_BUCKET_HERE') + ).add( + OptionsEntry( + label='Additional Packages', + help=( + 'PYPI packages installed, comma-separated. 
If None, leave ' + 'this field empty.'), + cls=SetupOptions, + arg_builder={ + 'requirements_file': lambda x: DataflowOptionsForm. + _build_req_file_from_pkgs(x) + }, + default='')) + + def additional_options(self): + # Use the latest Java SDK by default. + sdk_overrides = self.options.view_as( + WorkerOptions).sdk_harness_container_image_overrides + override = '.*java.*,apache/beam_java11_sdk:latest' + if sdk_overrides and override not in sdk_overrides: + sdk_overrides.append(override) + else: + self.options.view_as( + WorkerOptions).sdk_harness_container_image_overrides = [override] + + def display_actions(self): + from IPython.display import HTML + from IPython.display import display + from ipywidgets import Button + from ipywidgets import GridBox + from ipywidgets import Layout + from ipywidgets import Output + options_output_area = Output() + run_output_area = Output() + run_btn = Button( + description='Run on Dataflow', + button_style='success', + tooltip=( + 'Submit to Dataflow for execution with the configured options. The ' + 'output PCollection\'s data will be written to the GCS bucket you ' + 'configure.')) + show_options_btn = Button( + description='Show Options', + button_style='info', + tooltip='Show current pipeline options configured.') + + def _run_on_dataflow(btn): + with run_output_area: + run_output_area.clear_output() + + @progress_indicated + def _inner(): + options = self.to_options() + # Caches the output_pcoll to a GCS bucket. + try: + execution_count = 0 + if is_in_ipython(): + from IPython import get_ipython + execution_count = get_ipython().execution_count + output_location = '{}/{}'.format( + options.view_as(GoogleCloudOptions).staging_location, + self.output_name) + _ = self.output_pcoll | 'WriteOuput{}_{}ToGCS'.format( + self.output_name, + execution_count) >> WriteToText(output_location) + _LOGGER.info( + 'Data of output PCollection %s will be written to %s', + self.output_name, + output_location) + except (KeyboardInterrupt, SystemExit): + raise + except: # pylint: disable=bare-except + # The transform has been added before, noop. + pass + if self.verbose: + _LOGGER.info( + 'Running the pipeline on Dataflow with pipeline options %s.', + pformat_dict(options.display_data())) + result = create_runner('DataflowRunner').run_pipeline(self.p, options) + cloud_options = options.view_as(GoogleCloudOptions) + url = ( + 'https://console.cloud.google.com/dataflow/jobs/%s/%s?project=%s' + % (cloud_options.region, result.job_id(), cloud_options.project)) + display( + HTML( + 'Click <a href="%s" target="_new">here</a> for the details ' + 'of your Dataflow job.' % url)) + result_name = 'result_{}'.format(self.output_name) + create_var_in_main(result_name, result) + if self.verbose: + _LOGGER.info( + 'The pipeline result of the run can be accessed from variable ' + '%s. 
The current status is %s.', + result_name, + result) + + try: + btn.disabled = True + _inner() + finally: + btn.disabled = False + + run_btn.on_click(_run_on_dataflow) + + def _show_options(btn): + with options_output_area: + options_output_area.clear_output() + options = self.to_options() + options_name = 'options_{}'.format(self.output_name) + create_var_in_main(options_name, options) + _LOGGER.info( + 'The pipeline options configured is: %s.', + pformat_dict(options.display_data())) + + show_options_btn.on_click(_show_options) + grid = GridBox([run_btn, show_options_btn], + layout=Layout(grid_template_columns='repeat(2, 200px)')) + display(grid) + + # Implicitly initializes the options variable before 1st time showing + # options. + options_name_inited, _ = create_var_in_main('options_{}'.format( + self.output_name), self.to_options()) + if not self.notice_shown: + _LOGGER.info( + 'The pipeline options can be configured through variable %s. You ' + 'may also add additional options or sink transforms such as write ' + 'to BigQuery in other notebook cells. Come back to click "Run on ' + 'Dataflow" button once you complete additional configurations. ' + 'Optionally, you can chain more beam_sql magics with DataflowRunner ' + 'and click "Run on Dataflow" in their outputs.', + options_name_inited) + self.notice_shown = True + + display(options_output_area) + display(run_output_area) diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils_test.py b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py index 01a54c3..16d03f5 100644 --- a/sdks/python/apache_beam/runners/interactive/sql/utils_test.py +++ b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py @@ -23,9 +23,15 @@ import unittest from typing import NamedTuple from unittest.mock import patch +import pytest + import apache_beam as beam +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.interactive import interactive_environment as ie +from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm from apache_beam.runners.interactive.sql.utils import find_pcolls -from apache_beam.runners.interactive.sql.utils import is_namedtuple +from apache_beam.runners.interactive.sql.utils import pformat_dict from apache_beam.runners.interactive.sql.utils import pformat_namedtuple from apache_beam.runners.interactive.sql.utils import register_coder_for_schema from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token @@ -37,19 +43,6 @@ class ANamedTuple(NamedTuple): class UtilsTest(unittest.TestCase): - def test_is_namedtuple(self): - class AType: - pass - - a_type = AType - a_tuple = type((1, 2, 3)) - - a_namedtuple = ANamedTuple - - self.assertTrue(is_namedtuple(a_namedtuple)) - self.assertFalse(is_namedtuple(a_type)) - self.assertFalse(is_namedtuple(a_tuple)) - def test_register_coder_for_schema(self): self.assertNotIsInstance( beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder) @@ -80,6 +73,35 @@ class UtilsTest(unittest.TestCase): self.assertEqual( 'ANamedTuple(a: int, b: str)', pformat_namedtuple(ANamedTuple)) + def test_pformat_dict(self): + self.assertEqual('{\na: 1,\nb: 2\n}', pformat_dict({'a': 1, 'b': '2'})) + + +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') +@pytest.mark.skipif( + not ie.current_env().is_interactive_ready, + reason='[interactive] dependency is not installed.') +class 
OptionsFormTest(unittest.TestCase): + def test_dataflow_options_form(self): + p = beam.Pipeline() + pcoll = p | beam.Create([1, 2, 3]) + with patch('google.auth') as ga: + ga.default = lambda: ['', 'default_project_id'] + df_form = DataflowOptionsForm('pcoll', pcoll) + df_form.display_for_input() + df_form.entries[2].input.value = 'gs://test-bucket' + df_form.entries[3].input.value = 'a-pkg' + options = df_form.to_options() + cloud_options = options.view_as(GoogleCloudOptions) + self.assertEqual(cloud_options.project, 'default_project_id') + self.assertEqual(cloud_options.region, 'us-central1') + self.assertEqual( + cloud_options.staging_location, 'gs://test-bucket/staging') + self.assertEqual(cloud_options.temp_location, 'gs://test-bucket/temp') + self.assertIsNotNone(options.view_as(SetupOptions).requirements_file) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py index 49b87ba..957b65c 100644 --- a/sdks/python/apache_beam/runners/interactive/utils.py +++ b/sdks/python/apache_beam/runners/interactive/utils.py @@ -20,9 +20,12 @@ import functools import hashlib +import importlib import json import logging +from typing import Any from typing import Dict +from typing import Tuple import pandas as pd @@ -405,3 +408,22 @@ def unbounded_sources(pipeline): v = CheckUnboundednessVisitor() pipeline.visit(v) return v.unbounded_sources + + +def create_var_in_main(name: str, + value: Any, + watch: bool = True) -> Tuple[str, Any]: + """Declares a variable in the main module. + + Args: + name: the variable name in the main module. + value: the value of the variable. + watch: whether to watch it in the interactive environment. + Returns: + A 2-entry tuple of the variable name and value. + """ + setattr(importlib.import_module('__main__'), name, value) + if watch: + from apache_beam.runners.interactive import interactive_environment as ie + ie.current_env().watch({name: value}) + return name, value diff --git a/sdks/python/apache_beam/runners/interactive/utils_test.py b/sdks/python/apache_beam/runners/interactive/utils_test.py index 784081e..0915ff2 100644 --- a/sdks/python/apache_beam/runners/interactive/utils_test.py +++ b/sdks/python/apache_beam/runners/interactive/utils_test.py @@ -15,6 +15,7 @@ # limitations under the License. # +import importlib import json import logging import tempfile @@ -318,6 +319,13 @@ class GeneralUtilTest(unittest.TestCase): }) self.assertEqual('pcoll_test_find_pcoll_name', utils.find_pcoll_name(pcoll)) + def test_create_var_in_main(self): + name = 'test_create_var_in_main' + value = Record(0, 0, 0) + _ = utils.create_var_in_main(name, value) + main_session = importlib.import_module('__main__') + self.assertIs(getattr(main_session, name, None), value) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 514f4e7..d16618d 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -206,6 +206,7 @@ INTERACTIVE_BEAM = [ 'facets-overview>=1.0.0,<2', 'ipython>=7,<8', 'ipykernel>=5.2.0,<6', + 'ipywidgets>=7.6.5,<8', # Skip version 6.1.13 due to # https://github.com/jupyter/jupyter_client/issues/637 'jupyter-client>=6.1.11,<6.1.13',
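
For anyone trying the new runner flag from a notebook, the following is a minimal, hypothetical sketch. The pipeline, PCollection, and schema names (p, words, Word) are illustrative assumptions; only the %%beam_sql syntax, the -o/-v/-r flags, the DirectRunner/DataflowRunner choices, and the DataflowOptionsForm behavior come from this commit.

    # Assumes an IPython kernel with the [interactive] extras installed
    # (which now also pull in ipywidgets) and the beam_sql magic loaded.
    from typing import NamedTuple

    import apache_beam as beam
    from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

    class Word(NamedTuple):  # a schema defined in the main session
        word: str
        cnt: int

    p = beam.Pipeline(InteractiveRunner())
    words = p | beam.Create(
        [Word('beam', 2), Word('sql', 1)]).with_output_types(Word)

    # In a separate notebook cell, run locally (DirectRunner is the default
    # when -r is omitted) and get back a PValue named `counted`:
    #
    #   %%beam_sql -o counted
    #   SELECT word, cnt FROM words
    #
    # Or target Dataflow with the new -r flag; instead of returning a
    # PCollection, the magic builds the chained SQL pipeline and renders the
    # DataflowOptionsForm (project, region, GCS bucket, extra PyPI packages)
    # with a "Run on Dataflow" button:
    #
    #   %%beam_sql -o counted -r DataflowRunner
    #   SELECT word, cnt FROM words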