 File path: sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
 @@ -0,0 +1,470 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#    http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module to instrument interactivity to the given pipeline.
+For internal use only; no backwards-compatibility guarantees.
+This module accesses the current interactive environment and analyzes the
+given pipeline to transform the original pipeline into a one-shot pipeline
+with interactivity.
+"""
+from __future__ import absolute_import
+import logging
+import apache_beam as beam
+from apache_beam.pipeline import PipelineVisitor
+from apache_beam.runners.interactive import cache_manager as cache
+from apache_beam.runners.interactive import interactive_environment as ie
+READ_CACHE = "_ReadCache_"
+WRITE_CACHE = "_WriteCache_"
+class PipelineInstrument(object):
+  """A pipeline instrument for pipeline to be executed by interactive runner.
+  This module should never depend on underlying runner that interactive runner
+  delegates. It instruments the original instance of pipeline directly by
+  appending or replacing transforms with help of cache. It provides
+  interfaces to recover states of original pipeline. It's the interactive
+  runner's responsibility to coordinate supported underlying runners to run
+  the pipeline instrumented and recover the original pipeline states if needed.
+  """
+  def __init__(self, pipeline, options=None):
+    self._pipeline = pipeline
+    # The cache manager should be initiated outside of this module and outside
+    # of run_pipeline() from interactive runner so that its lifespan could span
+    # multiple runs in the interactive environment. Owned by
+    # interactive_environment module. Not owned by this module.
+    # TODO(BEAM-7760): change the scope of cache to be owned by runner or
+    # pipeline result instances because a pipeline is not 1:1 correlated to a
+    # running job. Only complete and read-only cache is valid across multiple
+    # jobs. Other cache instances should have their own scopes. Some design
+    # change should support only runner.run(pipeline) pattern rather than
+    # pipeline.run([runner]) and a runner can only run at most one pipeline at
+    # a time. Otherwise, result returned by run() is the only 1:1 anchor.
+    self._cache_manager = ie.current_env().cache_manager()
+    # Invoke a round trip through the runner API. This makes sure the Pipeline
+    # proto is stable. The snapshot of pipeline will not be mutated within this
+    # module and can be used to recover original pipeline if needed.
+    self._pipeline_snap = beam.pipeline.Pipeline.from_runner_api(
+        pipeline.to_runner_api(use_fake_coders=True),
+        pipeline.runner,
+        options)
+    # Snapshot of original pipeline information.
+    (self._original_pipeline_proto,
+     self._original_context) = self._pipeline_snap.to_runner_api(
+         return_context=True, use_fake_coders=True)
+    # All compute-once-against-original-pipeline fields.
+    self._has_unbounded_source = has_unbounded_source(self._pipeline_snap)
+    # TODO(BEAM-7760): once cache scope changed, this is not needed to manage
+    # relationships across pipelines, runners, and jobs.
+    self._pcolls_to_pcoll_id = pcolls_to_pcoll_id(self._pipeline_snap,
+                                                  self._original_context)
+    # A mapping from PCollection id to python id() value in user defined
+    # pipeline instance.
+    (self._pcoll_version_map,
+     self._cacheables) = cacheables(self.pcolls_to_pcoll_id())
+    # A dict from cache key to PCollection that is read from cache.
+    # If exists, caller should reuse the PCollection read. If not, caller
+    # should create new transform and track the PCollection read from cache.
+    # (Dict[str, AppliedPTransform]).
+    self._cached_pcoll_read = {}
+  def instrumented_pipeline_proto(self):
+    """Always returns a new instance of portable instrumented proto."""
+    return self._pipeline.to_runner_api(use_fake_coders=True)
+  def has_unbounded_source(self):
+    """Checks if a given pipeline has any source that is unbounded.
+    The function directly checks the source transform definition instead
+    of pvalues in the pipeline. Thus manually setting is_bounded field of
+    a PCollection or switching streaming mode will not affect this
+    function's result. The result is always deterministic when the source
+    code of a pipeline is defined.
+    """
+    return self._has_unbounded_source
+  def cacheables(self):
+    """Finds cacheable PCollections from the pipeline.
+    The function only treats the result as cacheables since there is no
+    guarantee whether the cache desired PCollection has been cached or
+    not. A PCollection desires caching when it's bound to a user defined
+    variable in source code. Otherwise, the PCollection is not reusable
+    nor introspectable, which nullifies the need for a cache.
+    """
+    return self._cacheables
+  def pcolls_to_pcoll_id(self):
+    """Returns a dict mapping str(PCollection)s to IDs."""
+    return self._pcolls_to_pcoll_id
+  def original_pipeline_proto(self):
+    """Returns the portable proto representation of the pipeline before
+    instrumentation."""
+    return self._original_pipeline_proto
+  def original_pipeline(self):
+    """Returns a snapshot of the pipeline before instrumentation."""
+    return self._pipeline_snap
+  def instrument(self):
+    """Instruments original pipeline with cache.
+    For cacheable output PCollection, if cache for the key doesn't exist, do
+    _write_cache(); for cacheable input PCollection, if cache for the key
+    exists, do _read_cache(). No instrument in any other situation.
+    Modifies:
+      self._pipeline
+    """
+    self._preprocess()
+    cacheable_inputs = set()
+    class InstrumentVisitor(PipelineVisitor):
+      """Visitor utilizes cache to instrument the pipeline."""
+      def __init__(self, pin):
+        self._pin = pin
+      def enter_composite_transform(self, transform_node):
+        self.visit_transform(transform_node)
+      def visit_transform(self, transform_node):
+        cacheable_inputs.update(self._pin._cacheable_inputs(transform_node))
+    v = InstrumentVisitor(self)
+    self._pipeline.visit(v)
+    # Create ReadCache transforms.
+    for cacheable_input in cacheable_inputs:
+      self._read_cache(cacheable_input)
+    # Replace/wire inputs w/ cached PCollections from ReadCache transforms.
+    self._replace_with_cached_inputs()
+    # Write cache for all cacheables.
+    for _, cacheable in self.cacheables().items():
+      self._write_cache(cacheable['pcoll'])
+    # TODO(BEAM-7760): prune sub graphs that don't need to be executed.
+  def _preprocess(self):
+    """Pre-processes the pipeline.
+    Since the pipeline instance in the class might not be the same instance
+    defined in the user code, the pre-process will figure out the relationship
+    of cacheable PCollections between these 2 instances by replacing 'pcoll'
+    fields in the cacheable dictionary with ones from the running instance.
+    """
+    class PreprocessVisitor(PipelineVisitor):
+      def __init__(self, pin):
+        self._pin = pin
+      def enter_composite_transform(self, transform_node):
+        self.visit_transform(transform_node)
+      def visit_transform(self, transform_node):
+        for in_pcoll in transform_node.inputs:
+          self._process(in_pcoll)
+        for out_pcoll in transform_node.outputs.values():
+          self._process(out_pcoll)
+      def _process(self, pcoll):
+        pcoll_id = self._pin.pcolls_to_pcoll_id().get(str(pcoll), '')
+        if pcoll_id in self._pin._pcoll_version_map:
+          cacheable_key = self._pin._cacheable_key(pcoll)
+          if (cacheable_key in self._pin.cacheables() and
+              self._pin.cacheables()[cacheable_key]['pcoll'] != pcoll):
+            self._pin.cacheables()[cacheable_key]['pcoll'] = pcoll
+    v = PreprocessVisitor(self)
+    self._pipeline.visit(v)
+  def _write_cache(self, pcoll):
+    """Caches a cacheable PCollection.
+    Appends a sub-transform that materializes the given PCollection through a
+    sink into the cache implementation. The cache write is not immediate: it
+    happens when the runner runs the transformed pipeline, so by design the
+    cache is not usable within this run. It's the caller's responsibility to
+    make sure the PCollection is indeed cacheable. Otherwise, cache resources
+    might be wasted. If a cache with the corresponding key already exists,
+    this is a noop, since a cache write is only needed when the last cache is
+    invalidated. And if a cache is invalidated, the PCollection's new key is
+    guaranteed to not exist in the current cache.
+    Modifies:
+      self._pipeline
+    """
+    if pcoll.pipeline is not self._pipeline:
+      return
+    key = self.cache_key(pcoll)
+    if not self._cache_manager.exists('full', key):
+      logging.debug('%s write cache: %s', pcoll.producer, key)
+      _ = pcoll | '{}{}'.format(WRITE_CACHE, key) >> cache.WriteCache(
+          self._cache_manager, key)
+  def _read_cache(self, pcoll):
+    """Reads a cached pvalue.
+    A noop will cause the pipeline to execute the transform as
+    it is and cache nothing from this transform for next run.
+    Modifies:
+      self._pipeline
+    """
+    if pcoll.pipeline is not self._pipeline:
+      return
+    key = self.cache_key(pcoll)
+    if self._cache_manager.exists('full', key):
+      if key not in self._cached_pcoll_read:
+        logging.debug('read cache: %s', key)
+        # Mutates the pipeline with cache read transform attached
+        # to root of the pipeline.
+        pcoll_from_cache = (
+            self._pipeline
+            | '{}{}'.format(READ_CACHE, key) >> cache.ReadCache(
+                self._cache_manager, key))
+        self._cached_pcoll_read[key] = pcoll_from_cache
+    # else: NOOP when cache doesn't exist.
+  def _replace_with_cached_inputs(self):
+    """Replace PCollection inputs in the pipeline with cache if possible.
+    For any input PCollection, find out whether there is valid cache. If so,
+    replace the input of the AppliedPTransform with output of the
+    AppliedPTransform that sources pvalue from the cache. If there is no valid
+    cache, noop.
+    """
+    class ReadCacheWireVisitor(PipelineVisitor):
+      """Visitor wires cache read as inputs to replace corresponding original
+      input PCollections in pipeline.
+      """
+      def __init__(self, pin):
+        """Initializes with a PipelineInstrument."""
+        self._pin = pin
+      def enter_composite_transform(self, transform_node):
+        self.visit_transform(transform_node)
+      def visit_transform(self, transform_node):
+        if transform_node.inputs:
+          input_list = list(transform_node.inputs)
+          for i in range(len(input_list)):
+            key = self._pin.cache_key(input_list[i])
+            if key in self._pin._cached_pcoll_read:
+              input_list[i] = self._pin._cached_pcoll_read[key]
+          transform_node.inputs = tuple(input_list)
+    v = ReadCacheWireVisitor(self)
+    self._pipeline.visit(v)
+  def _cacheable_inputs(self, transform):
+    inputs = set()
+    for in_pcoll in transform.inputs:
+      if self._cacheable_key(in_pcoll) in self.cacheables():
+        inputs.add(in_pcoll)
+    return inputs
+  def _cacheable_key(self, pcoll):
+    """Gets the key under which a cacheable PCollection is tracked."""
+    return cacheable_key(pcoll, self.pcolls_to_pcoll_id(),
+                         self._pcoll_version_map)
+  def cache_key(self, pcoll):
+    """Gets the identifier of a cacheable PCollection in cache.
+    If the pcoll is not a cacheable, return ''.
+    The key is what the pcoll would use as identifier if it's materialized in
+    cache. It doesn't mean that there would definitely be such cache already.
+    Also, the pcoll can come from the original user defined pipeline object or
+    an equivalent pcoll from a transformed copy of the original pipeline.
+    """
+    cacheable = self.cacheables().get(self._cacheable_key(pcoll), None)
+    if cacheable:
+      return '_'.join((cacheable['var'],
+                       cacheable['version'],
+                       cacheable['pcoll_id'],
+                       cacheable['producer_version']))
+    return ''
+  def _debug_appliedptransform(self, transform_node):
+    """Debugs AppliedPTransform at debug level logging.
+    Logs structure of an AppliedPTransform instance. Used for testing,
+    debugging and dev purpose.
+    """
+    logging.debug('parent is: %s', transform_node.parent)
+    logging.debug('full label of transform_node: %s',
+                  transform_node.full_label)
+    logging.debug('inputs: ')
+    for inp in transform_node.inputs:
+      logging.debug('  %s', inp)
+    logging.debug('parts: ')
+    for part in transform_node.parts:
+      logging.debug('  %s', part)
+    logging.debug('outputs: ')
+    for output in transform_node.outputs.values():
+      logging.debug('  %s', output)
+  def _debug_pipeline_graph(self):
+    """Debugs the pipeline at debug level logging.
+    The pipeline within the class is being instrumented and mutates. This
+    function logs structure of the pipeline exactly when invoked and can be
+    invoked multiple times at different instrumenting stages for testing,
+    debugging and dev purposes.
+    """
+    class DebugVisitor(PipelineVisitor):
+      def __init__(self, pin):
+        self._pin = pin
+      def enter_composite_transform(self, transform_node):
+        self.visit_transform(transform_node)
+      def visit_transform(self, transform_node):
+        self._pin._debug_appliedptransform(transform_node)
+    v = DebugVisitor(self)
+    self._pipeline.visit(v)
+def pin(pipeline, options=None):
 Review comment:
   The name is short for `pipeline instrument`. As a verb, it also `pin`s the 
`pipeline` at its current state: the original pipeline information is 
snapshotted, and the caching (interactivity) transforms are instrumented once 
by mutating the pipeline object from the state in which it was received.
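
To make "instrumented (once)" concrete, here is a minimal pure-Python sketch of the read-or-write decision that `instrument()` makes for each cacheable PCollection. It does not import Beam. `FakeCacheManager`, `plan_instrumentation`, the `used_as_input` field, and the sample values are hypothetical stand-ins for illustration only. The key layout (`var`, `version`, `pcoll_id`, `producer_version` joined by `_`), the label prefixes, and the `exists('full', key)` call shape are taken from the diff above.

```python
READ_CACHE = "_ReadCache_"
WRITE_CACHE = "_WriteCache_"


def cache_key(cacheable):
  """Joins the four identifying fields the way cache_key() in the diff does."""
  return '_'.join((cacheable['var'],
                   cacheable['version'],
                   cacheable['pcoll_id'],
                   cacheable['producer_version']))


class FakeCacheManager(object):
  """Hypothetical in-memory stand-in for the real cache manager."""

  def __init__(self, existing_keys=()):
    self._keys = set(existing_keys)

  def exists(self, *labels):
    # Mirrors the exists('full', key) call shape; only the key is checked.
    return labels[-1] in self._keys


def plan_instrumentation(cacheables, cache_manager):
  """Returns the sorted labels of the cache transforms that would be added.

  'used_as_input' is a hypothetical flag standing in for the visitor that
  collects cacheable inputs in the diff.
  """
  labels = []
  for cacheable in cacheables:
    key = cache_key(cacheable)
    if cache_manager.exists('full', key):
      # A cache exists: read it only if some transform consumes the
      # PCollection; never write it again.
      if cacheable['used_as_input']:
        labels.append(READ_CACHE + key)
    else:
      # No cache yet: materialize the PCollection during this run.
      labels.append(WRITE_CACHE + key)
  return sorted(labels)


# Hypothetical cacheable entry for a variable named `messages`.
messages = {
    'var': 'messages',
    'version': '1',
    'pcoll_id': 'pc1',
    'producer_version': '2',
    'used_as_input': True,
}
first_run = plan_instrumentation([messages], FakeCacheManager())
second_run = plan_instrumentation(
    [messages], FakeCacheManager([cache_key(messages)]))
```

Under these assumptions the first run only adds a write transform, and a later run with the same key only adds a read transform, which is the behavior the docstring of `instrument()` describes.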
> Interactive Beam Caching PCollections bound to user defined vars in notebook
> ----------------------------------------------------------------------------
>                 Key: BEAM-7760
>                 URL: https://issues.apache.org/jira/browse/BEAM-7760
>             Project: Beam
>          Issue Type: New Feature
>          Components: examples-python
>            Reporter: Ning Kang
>            Assignee: Ning Kang
>            Priority: Major
>          Time Spent: 17h
>  Remaining Estimate: 0h
> Cache only PCollections bound to user defined variables in a pipeline when 
> running pipeline with interactive runner in jupyter notebooks.
> [Interactive 
> Beam|https://github.com/apache/beam/tree/master/sdks/python/apache_beam/runners/interactive]
>  has been caching and using caches of "leaf" PCollections for interactive 
> execution in jupyter notebooks.
> The interactive execution is currently supported so that when appending new 
> transforms to existing pipeline for a new run, executed part of the pipeline 
> doesn't need to be re-executed. 
> A PCollection is "leaf" when it is never used as input in any PTransform in 
> the pipeline.
> The problem with building caches and the pipeline to execute around "leaf" 
> PCollections is that when a PCollection is consumed by a sink with no output, 
> the pipeline built for execution misses the subgraph generating and consuming 
> that PCollection.
> For example, "ReadFromPubSub --> WriteToPubSub" will result in an empty 
> pipeline.
> Caching around PCollections bound to user defined variables and replacing 
> transforms with sources and sinks of caches lets the pipeline to execute be 
> built properly under the interactive execution scenario. Also, a cached 
> PCollection can now be traced back to user code and can be used for data 
> visualization if the user wants it.
> E.g.,
> {code:python}
> # ...
> p = beam.Pipeline(interactive_runner.InteractiveRunner(),
>                   options=pipeline_options)
> messages = p | "Read" >> beam.io.ReadFromPubSub(subscription='...')
> messages | "Write" >> beam.io.WriteToPubSub(topic_path)
> result = p.run()
> # ...
> visualize(messages){code}
>  The interactive runner automatically figures out that the PCollection
> {code:python}
> messages{code}
> created by
> {code:python}
> p | "Read" >> beam.io.ReadFromPubSub(subscription='...'){code}
> should be cached and reused if the notebook user appends more transforms.
>  Once the pipeline has been executed, the user can use any 
> visualize(PCollection) module to visualize the data statically (batch) or 
> dynamically (stream).
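
The "leaf" problem described in the issue can be illustrated with a toy graph in plain Python. This is only a sketch under simplifying assumptions: transforms and PCollections are modeled as strings, and `leaf_pcollections`, `variable_bound_pcollections`, `transforms`, and `user_vars` are hypothetical names, not part of Beam.

```python
def leaf_pcollections(transforms):
  """Returns PCollections that no transform consumes as input."""
  produced = set()
  consumed = set()
  for inputs, outputs in transforms.values():
    consumed.update(inputs)
    produced.update(outputs)
  return produced - consumed


def variable_bound_pcollections(transforms, user_vars):
  """Returns produced PCollections that are bound to a user defined variable."""
  produced = set()
  for _, outputs in transforms.values():
    produced.update(outputs)
  return produced & set(user_vars.values())


# Toy model of "ReadFromPubSub --> WriteToPubSub": each transform maps to
# (inputs, outputs). The sink consumes 'messages' and produces nothing.
transforms = {
    'Read': ((), ('messages',)),
    'Write': (('messages',), ()),
}
# The notebook user wrote: messages = p | "Read" >> ...
user_vars = {'messages': 'messages'}

leaves = leaf_pcollections(transforms)
bound = variable_bound_pcollections(transforms, user_vars)
```

In this model `leaves` is empty because the only PCollection is consumed by the sink, so a pipeline built around leaves has nothing to execute, while `bound` still contains `messages` and gives the instrumentation something to cache.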

