This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f22d2bc1f16 Save Logical Type and Coder Registry on cloudpickle save 
main session (#36271)
f22d2bc1f16 is described below

commit f22d2bc1f1638a19639483ef18b56b4cf9fb8504
Author: Yi Hu <[email protected]>
AuthorDate: Fri Dec 5 16:16:42 2025 -0500

    Save Logical Type and Coder Registry on cloudpickle save main session 
(#36271)
    
    * Save Logical Type Registry and Coder Registry on cloudpickle save main 
session
    
    fix naming
    
    * Track custom_urn set in logical type registry
    
    * Fix, add tests
    
    * Set save_main_session default to true for cloudpickle and introduce 
overwrite flag
    
    * Fix test; trigger postcommits
    
    * Fix test as Dataflow runner submission now stages a main session file
---
 .github/trigger_files/beam_PostCommit_Python.json  |  1 +
 .../beam_PostCommit_Python_Xlang_Gcp_Dataflow.json |  2 +-
 .../beam_PostCommit_Python_Xlang_IO_Dataflow.json  |  2 +-
 CHANGES.md                                         |  1 +
 sdks/python/apache_beam/coders/typecoders.py       | 25 +++++---
 .../apache_beam/internal/cloudpickle_pickler.py    | 35 +++++++++--
 .../internal/cloudpickle_pickler_test.py           | 20 +++++++
 sdks/python/apache_beam/internal/pickler.py        | 11 +++-
 .../python/apache_beam/options/pipeline_options.py | 29 ++++++++-
 .../runners/dataflow/dataflow_runner_test.py       | 63 ++++++++++++-------
 .../apache_beam/runners/portability/stager.py      |  1 -
 .../apache_beam/runners/portability/stager_test.py |  7 ++-
 .../apache_beam/runners/worker/sdk_worker_main.py  | 70 ++++++++++++----------
 sdks/python/apache_beam/typehints/schemas.py       | 59 ++++++++++++++----
 14 files changed, 239 insertions(+), 87 deletions(-)

diff --git a/.github/trigger_files/beam_PostCommit_Python.json 
b/.github/trigger_files/beam_PostCommit_Python.json
index c6ec17f4841..d6818d275f1 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,6 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to 
run.",
+  "pr": "36271",
   "modification": 35
 }
 
diff --git 
a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json 
b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
index 99a8fc8ff6d..b60f5c4cc3c 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to 
run",
-  "modification": 14
+  "modification": 0
 }
diff --git 
a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json 
b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json
index f1ba03a243e..b60f5c4cc3c 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to 
run",
-  "modification": 5
+  "modification": 0
 }
diff --git a/CHANGES.md b/CHANGES.md
index bab7182539e..dfad320a694 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -85,6 +85,7 @@
 ## Bugfixes
 
 * Fixed FirestoreV1 Beam connectors allow configuring inconsistent 
project/database IDs between RPC requests and routing headers #36895 (Java) 
([#36895](https://github.com/apache/beam/issues/36895)).
+* Logical type and coder registries are now saved for pipelines that use the 
default pickler. This fixes a side effect of switching to cloudpickle as the 
default pickler in Beam 2.65.0 (Python) 
([#35738](https://github.com/apache/beam/issues/35738)).
 
 ## Known Issues
 
diff --git a/sdks/python/apache_beam/coders/typecoders.py 
b/sdks/python/apache_beam/coders/typecoders.py
index 779c65dc772..ef75a21ce9e 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -114,6 +114,14 @@ class CoderRegistry(object):
       typehint_coder_class: Type[coders.Coder]) -> None:
     self._coders[typehint_type] = typehint_coder_class
 
+  @staticmethod
+  def _normalize_typehint_type(typehint_type):
+    if typehint_type.__module__ == '__main__':
+      # See https://github.com/apache/beam/issues/21541
+      # TODO(robertwb): Remove once all runners are portable.
+      return getattr(typehint_type, '__name__', str(typehint_type))
+    return typehint_type
+
   def register_coder(
       self, typehint_type: Any,
       typehint_coder_class: Type[coders.Coder]) -> None:
@@ -123,11 +131,8 @@ class CoderRegistry(object):
           'Received %r instead.' % typehint_coder_class)
     if typehint_type not in self.custom_types:
       self.custom_types.append(typehint_type)
-    if typehint_type.__module__ == '__main__':
-      # See https://github.com/apache/beam/issues/21541
-      # TODO(robertwb): Remove once all runners are portable.
-      typehint_type = getattr(typehint_type, '__name__', str(typehint_type))
-    self._register_coder_internal(typehint_type, typehint_coder_class)
+    self._register_coder_internal(
+        self._normalize_typehint_type(typehint_type), typehint_coder_class)
 
   def get_coder(self, typehint: Any) -> coders.Coder:
     if typehint and typehint.__module__ == '__main__':
@@ -170,9 +175,15 @@ class CoderRegistry(object):
       coder = self._fallback_coder
     return coder.from_type_hint(typehint, self)
 
-  def get_custom_type_coder_tuples(self, types):
+  def get_custom_type_coder_tuples(self, types=None):
     """Returns type/coder tuples for all custom types passed in."""
-    return [(t, self._coders[t]) for t in types if t in self.custom_types]
+    return [(t, self._coders[self._normalize_typehint_type(t)])
+            for t in self.custom_types if (types is None or t in types)]
+
+  def load_custom_type_coder_tuples(self, type_coder):
+    """Load type/coder tuples into coder registry."""
+    for t, c in type_coder:
+      self.register_coder(t, c)
 
   def verify_deterministic(self, key_coder, op_name, silent=True):
     if not key_coder.is_deterministic():
diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py 
b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
index 199294f1731..acdcc46cd40 100644
--- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py
+++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
@@ -252,12 +252,35 @@ def _lock_reducer(obj):


 def dump_session(file_path):
-  # It is possible to dump session with cloudpickle. However, since references
-  # are saved it should not be necessary. See https://s.apache.org/beam-picklers
-  pass
+  # Since references are saved (https://s.apache.org/beam-picklers), only the
+  # supported Beam registries (coder and logical type) are dumped.
+  from apache_beam.coders import typecoders
+  from apache_beam.typehints import schemas
+
+  with _pickle_lock, open(file_path, 'wb') as file:
+    coder_reg = typecoders.registry.get_custom_type_coder_tuples()
+    logical_type_reg = schemas.LogicalType._known_logical_types.copy_custom()
+
+    pickler = cloudpickle.CloudPickler(file)
+    # TODO(https://github.com/apache/beam/issues/18500) add file system registry
+    # once implemented
+    pickler.dump({"coder": coder_reg, "logical_type": logical_type_reg})


 def load_session(file_path):
-  # It is possible to load_session with cloudpickle. However, since references
-  # are saved it should not be necessary. See https://s.apache.org/beam-picklers
-  pass
+  from apache_beam.coders import typecoders
+  from apache_beam.typehints import schemas
+
+  with _pickle_lock, open(file_path, 'rb') as file:
+    registries = cloudpickle.load(file)
+    if not isinstance(registries, dict):
+      raise ValueError(
+          "Failed loading session: expected dict, got %s" % type(registries))
+    if "coder" in registries:
+      typecoders.registry.load_custom_type_coder_tuples(registries["coder"])
+    else:
+      _LOGGER.warning('No coder registry found in saved session')
+    if "logical_type" in registries:
+      schemas.LogicalType._known_logical_types.load(registries["logical_type"])
+    else:
+      _LOGGER.warning('No logical type registry found in saved session')
diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py 
b/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py
index 4a51c56c24b..99fbb03ac2e 100644
--- a/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py
+++ b/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py
@@ -20,6 +20,7 @@
 # pytype: skip-file
 
 import os
+import tempfile
 import threading
 import types
 import unittest
@@ -31,6 +32,7 @@ from apache_beam.internal import code_object_pickler
 from apache_beam.internal import module_test
 from apache_beam.internal.cloudpickle_pickler import dumps
 from apache_beam.internal.cloudpickle_pickler import loads
+from apache_beam.typehints.schemas import LogicalTypeRegistry
 from apache_beam.utils import shared
 
 GLOBAL_DICT_REF = module_test.GLOBAL_DICT
@@ -244,6 +246,24 @@ self.assertEqual(DataClass(datum='abc'), 
loads(dumps(DataClass(datum='abc'))))
     unpickled_filename = os.path.abspath(unpickled_code.co_filename)
     self.assertEqual(unpickled_filename, original_filename)
 
+  @mock.patch(
+      "apache_beam.coders.typecoders.registry.load_custom_type_coder_tuples")
+  @mock.patch(
+      "apache_beam.typehints.schemas.LogicalType._known_logical_types.load")
+  def test_dump_load_session(self, logicaltype_mock, coder_mock):
+    session_file = 'pickled'
+
+    with tempfile.TemporaryDirectory() as tmp_dirname:
+      pickled_session_file = os.path.join(tmp_dirname, session_file)
+      beam_cloudpickle.dump_session(pickled_session_file)
+      beam_cloudpickle.load_session(pickled_session_file)
+    load_logical_types = logicaltype_mock.call_args.args
+    load_coders = coder_mock.call_args.args
+    self.assertEqual(len(load_logical_types), 1)
+    self.assertEqual(len(load_coders), 1)
+    self.assertTrue(isinstance(load_logical_types[0], LogicalTypeRegistry))
+    self.assertTrue(isinstance(load_coders[0], list))
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/internal/pickler.py 
b/sdks/python/apache_beam/internal/pickler.py
index 0af3b16ec05..3626b599a5c 100644
--- a/sdks/python/apache_beam/internal/pickler.py
+++ b/sdks/python/apache_beam/internal/pickler.py
@@ -91,6 +91,14 @@ def load_session(file_path):
   return desired_pickle_lib.load_session(file_path)
 
 
+def is_currently_dill():
+  return desired_pickle_lib == dill_pickler
+
+
+def is_currently_cloudpickle():
+  return desired_pickle_lib == cloudpickle_pickler
+
+
 def set_library(selected_library=DEFAULT_PICKLE_LIB):
   """ Sets pickle library that will be used. """
   global desired_pickle_lib
@@ -108,12 +116,11 @@ def set_library(selected_library=DEFAULT_PICKLE_LIB):
         "Pipeline option pickle_library=dill_unsafe is set, but dill is not "
         "installed. Install dill in job submission and runtime environments.")
 
-  is_currently_dill = (desired_pickle_lib == dill_pickler)
   dill_is_requested = (
       selected_library == USE_DILL or selected_library == USE_DILL_UNSAFE)
 
   # If switching to or from dill, update the pickler hook overrides.
-  if is_currently_dill != dill_is_requested:
+  if is_currently_dill() != dill_is_requested:
     dill_pickler.override_pickler_hooks(selected_library == USE_DILL)
 
   if dill_is_requested:
diff --git a/sdks/python/apache_beam/options/pipeline_options.py 
b/sdks/python/apache_beam/options/pipeline_options.py
index f2addf6f9d5..170ade224c1 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -64,7 +64,10 @@ _LOGGER = logging.getLogger(__name__)
 # Map defined with option names to flag names for boolean options
 # that have a destination(dest) in parser.add_argument() different
 # from the flag name and whose default value is `None`.
-_FLAG_THAT_SETS_FALSE_VALUE = {'use_public_ips': 'no_use_public_ips'}
+_FLAG_THAT_SETS_FALSE_VALUE = {
+    'use_public_ips': 'no_use_public_ips',
+    'save_main_session': 'no_save_main_session'
+}
 # Set of options which should not be overriden when applying options from a
 # different language. This is relevant when using x-lang transforms where the
 # expansion service is started up with some pipeline options, and will
@@ -1672,7 +1675,7 @@ class SetupOptions(PipelineOptions):
         choices=['cloudpickle', 'default', 'dill', 'dill_unsafe'])
     parser.add_argument(
         '--save_main_session',
-        default=False,
+        default=None,
         action='store_true',
         help=(
             'Save the main session state so that pickled functions and classes 
'
@@ -1680,6 +1683,15 @@ class SetupOptions(PipelineOptions):
             'Some workflows do not need the session state if for instance all '
             'their functions/classes are defined in proper modules '
             '(not __main__) and the modules are importable in the worker. '))
+    parser.add_argument(
+        '--no_save_main_session',
+        default=None,
+        action='store_false',
+        dest='save_main_session',
+        help=(
+            'Disable saving the main session state. It is enabled/disabled by '
+            'default for cloudpickle/dill pickler. See "save_main_session".'))
+
     parser.add_argument(
         '--sdk_location',
         default='default',
@@ -1780,10 +1792,23 @@ class SetupOptions(PipelineOptions):
             'If not specified, the default Maven Central repository will be '
             'used.'))
 
+  def _handle_load_main_session(self, validator):
+    save_main_session = getattr(self, 'save_main_session')
+    if save_main_session is None:
+      # save_main_session default to False for dill, while default to true
+      # for cloudpickle
+      pickle_library = getattr(self, 'pickle_library')
+      if pickle_library in ['default', 'cloudpickle']:
+        setattr(self, 'save_main_session', True)
+      else:
+        setattr(self, 'save_main_session', False)
+    return []
+
   def validate(self, validator):
     errors = []
     errors.extend(validator.validate_container_prebuilding_options(self))
     errors.extend(validator.validate_pickle_library(self))
+    errors.extend(self._handle_load_main_session(validator))
     return errors
 
 
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py 
b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index 178a75ec41d..d5d8ba662f0 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -42,6 +42,7 @@ from apache_beam.runners.dataflow.dataflow_runner import 
DataflowRuntimeExceptio
 from apache_beam.runners.dataflow.dataflow_runner import 
_check_and_add_missing_options
 from apache_beam.runners.dataflow.dataflow_runner import 
_check_and_add_missing_streaming_options
 from apache_beam.runners.dataflow.internal.clients import dataflow as 
dataflow_api
+from apache_beam.runners.internal import names
 from apache_beam.runners.runner import PipelineState
 from apache_beam.testing.extra_assertions import ExtraAssertionsMixin
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -243,6 +244,18 @@ class DataflowRunnerTest(unittest.TestCase, 
ExtraAssertionsMixin):
     self.assertTrue(
         isinstance(create_runner('TestDataflowRunner'), TestDataflowRunner))
 
+  @staticmethod
+  def dependency_proto_from_main_session_file(serialized_path):
+    return [
+        beam_runner_api_pb2.ArtifactInformation(
+            type_urn=common_urns.artifact_types.FILE.urn,
+            type_payload=serialized_path,
+            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
+            role_payload=beam_runner_api_pb2.ArtifactStagingToRolePayload(
+                staged_name=names.PICKLED_MAIN_SESSION_FILE).SerializeToString(
+                ))
+    ]
+
   def test_environment_override_translation_legacy_worker_harness_image(self):
     self.default_properties.append('--experiments=beam_fn_api')
     self.default_properties.append('--worker_harness_container_image=LEGACY')
@@ -256,17 +269,22 @@ class DataflowRunnerTest(unittest.TestCase, 
ExtraAssertionsMixin):
           | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
           | ptransform.GroupByKey())
 
+    actual = 
list(remote_runner.proto_pipeline.components.environments.values())
+    self.assertEqual(len(actual), 1)
+    actual = actual[0]
+    file_path = actual.dependencies[0].type_payload
+    # Dependency payload contains main_session from a transient temp directory
+    # Use actual for expected value.
+    main_session_dep = self.dependency_proto_from_main_session_file(file_path)
     self.assertEqual(
-        list(remote_runner.proto_pipeline.components.environments.values()),
-        [
-            beam_runner_api_pb2.Environment(
-                urn=common_urns.environments.DOCKER.urn,
-                payload=beam_runner_api_pb2.DockerPayload(
-                    container_image='LEGACY').SerializeToString(),
-                capabilities=environments.python_sdk_docker_capabilities(),
-                dependencies=environments.python_sdk_dependencies(
-                    options=options))
-        ])
+        actual,
+        beam_runner_api_pb2.Environment(
+            urn=common_urns.environments.DOCKER.urn,
+            payload=beam_runner_api_pb2.DockerPayload(
+                container_image='LEGACY').SerializeToString(),
+            capabilities=environments.python_sdk_docker_capabilities(),
+            dependencies=environments.python_sdk_dependencies(options=options) 
+
+            main_session_dep))
 
   def test_environment_override_translation_sdk_container_image(self):
     self.default_properties.append('--experiments=beam_fn_api')
@@ -281,17 +299,22 @@ class DataflowRunnerTest(unittest.TestCase, 
ExtraAssertionsMixin):
           | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
           | ptransform.GroupByKey())
 
+    actual = 
list(remote_runner.proto_pipeline.components.environments.values())
+    self.assertEqual(len(actual), 1)
+    actual = actual[0]
+    file_path = actual.dependencies[0].type_payload
+    # Dependency payload contains main_session from a transient temp directory
+    # Use actual for expected value.
+    main_session_dep = self.dependency_proto_from_main_session_file(file_path)
     self.assertEqual(
-        list(remote_runner.proto_pipeline.components.environments.values()),
-        [
-            beam_runner_api_pb2.Environment(
-                urn=common_urns.environments.DOCKER.urn,
-                payload=beam_runner_api_pb2.DockerPayload(
-                    container_image='FOO').SerializeToString(),
-                capabilities=environments.python_sdk_docker_capabilities(),
-                dependencies=environments.python_sdk_dependencies(
-                    options=options))
-        ])
+        actual,
+        beam_runner_api_pb2.Environment(
+            urn=common_urns.environments.DOCKER.urn,
+            payload=beam_runner_api_pb2.DockerPayload(
+                container_image='FOO').SerializeToString(),
+            capabilities=environments.python_sdk_docker_capabilities(),
+            dependencies=environments.python_sdk_dependencies(options=options) 
+
+            main_session_dep))
 
   def test_remote_runner_translation(self):
     remote_runner = DataflowRunner()
diff --git a/sdks/python/apache_beam/runners/portability/stager.py 
b/sdks/python/apache_beam/runners/portability/stager.py
index 9147410c246..aa03082f0d5 100644
--- a/sdks/python/apache_beam/runners/portability/stager.py
+++ b/sdks/python/apache_beam/runners/portability/stager.py
@@ -376,7 +376,6 @@ class Stager(object):
       pickled_session_file = os.path.join(
           temp_dir, names.PICKLED_MAIN_SESSION_FILE)
       pickler.dump_session(pickled_session_file)
-      # for pickle_library: cloudpickle, dump_session is no op
       if os.path.exists(pickled_session_file):
         resources.append(
             Stager._create_file_stage_to_artifact(
diff --git a/sdks/python/apache_beam/runners/portability/stager_test.py 
b/sdks/python/apache_beam/runners/portability/stager_test.py
index 22a41e592c2..233e0c3dcea 100644
--- a/sdks/python/apache_beam/runners/portability/stager_test.py
+++ b/sdks/python/apache_beam/runners/portability/stager_test.py
@@ -200,7 +200,7 @@ class StagerTest(unittest.TestCase):
   # (https://github.com/apache/beam/issues/21457): Remove the decorator once
   # cloudpickle is default pickle library
   @pytest.mark.no_xdist
-  def test_main_session_not_staged_when_using_cloudpickle(self):
+  def test_main_session_staged_when_using_cloudpickle(self):
     staging_dir = self.make_temp_dir()
     options = PipelineOptions()
 
@@ -209,7 +209,10 @@ class StagerTest(unittest.TestCase):
     # session is saved when pickle_library==cloudpickle.
     options.view_as(SetupOptions).pickle_library = pickler.USE_CLOUDPICKLE
     self.update_options(options)
-    self.assertEqual([stager.SUBMISSION_ENV_DEPENDENCIES_FILE],
+    self.assertEqual([
+        names.PICKLED_MAIN_SESSION_FILE,
+        stager.SUBMISSION_ENV_DEPENDENCIES_FILE
+    ],
                      self.stager.create_and_stage_job_resources(
                          options, staging_location=staging_dir)[1])
 
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py 
b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index cdb807e8dbc..e4dd6cc2121 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -113,26 +113,25 @@ def create_harness(environment, dry_run=False):
   _LOGGER.info('semi_persistent_directory: %s', semi_persistent_directory)
   _worker_id = environment.get('WORKER_ID', None)
 
-  if pickle_library != pickler.USE_CLOUDPICKLE:
-    try:
-      _load_main_session(semi_persistent_directory)
-    except LoadMainSessionException:
-      exception_details = traceback.format_exc()
-      _LOGGER.error(
-          'Could not load main session: %s', exception_details, exc_info=True)
-      raise
-    except Exception:  # pylint: disable=broad-except
-      summary = (
-          "Could not load main session. Inspect which external dependencies "
-          "are used in the main module of your pipeline. Verify that "
-          "corresponding packages are installed in the pipeline runtime "
-          "environment and their installed versions match the versions used in "
-          "pipeline submission environment. For more information, see: https://"
-          "beam.apache.org/documentation/sdks/python-pipeline-dependencies/")
-      _LOGGER.error(summary, exc_info=True)
-      exception_details = traceback.format_exc()
-      deferred_exception = LoadMainSessionException(
-          f"{summary} {exception_details}")
+  try:
+    _load_main_session(semi_persistent_directory)
+  except LoadMainSessionException:
+    exception_details = traceback.format_exc()
+    _LOGGER.error(
+        'Could not load main session: %s', exception_details, exc_info=True)
+    raise
+  except Exception:  # pylint: disable=broad-except
+    summary = (
+        "Could not load main session. Inspect which external dependencies "
+        "are used in the main module of your pipeline. Verify that "
+        "corresponding packages are installed in the pipeline runtime "
+        "environment and their installed versions match the versions used in "
+        "pipeline submission environment. For more information, see: https://"
+        "beam.apache.org/documentation/sdks/python-pipeline-dependencies/")
+    _LOGGER.error(summary, exc_info=True)
+    exception_details = traceback.format_exc()
+    deferred_exception = LoadMainSessionException(
+        f"{summary} {exception_details}")
 
   _LOGGER.info(
       'Pipeline_options: %s',
@@ -356,6 +355,14 @@ class LoadMainSessionException(Exception):
 
 def _load_main_session(semi_persistent_directory):
   """Loads a pickled main session from the path specified."""
+  if pickler.is_currently_dill():
+    warn_msg = ' Functions defined in __main__ (interactive session) may fail.'
+    err_msg = ' Functions defined in __main__ (interactive session) will ' \
+      'almost certainly fail.'
+  elif pickler.is_currently_cloudpickle():
+    warn_msg = ' User registered objects (e.g. schema, logical type) ' \
+        'through registries may not be effective.'
+    err_msg = ''
   if semi_persistent_directory:
     session_file = os.path.join(
         semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE)
@@ -365,21 +372,18 @@ def _load_main_session(semi_persistent_directory):
       # This can happen if the worker fails to download the main session.
       # Raise a fatal error and crash this worker, forcing a restart.
       if os.path.getsize(session_file) == 0:
-        # Potenitally transient error, unclear if still happening.
-        raise LoadMainSessionException(
-            'Session file found, but empty: %s. Functions defined in __main__ '
-            '(interactive session) will almost certainly fail.' %
-            (session_file, ))
-      pickler.load_session(session_file)
+        if pickler.is_currently_dill():
+          # Potentially transient error, unclear if still happening.
+          raise LoadMainSessionException(
+              'Session file found, but empty: %s.%s' % (session_file, err_msg))
+        else:
+          _LOGGER.warning('Empty session file: %s.%s', session_file, warn_msg)
+      else:
+        pickler.load_session(session_file)
     else:
-      _LOGGER.warning(
-          'No session file found: %s. Functions defined in __main__ '
-          '(interactive session) may fail.',
-          session_file)
+      _LOGGER.warning('No session file found: %s.%s', session_file, warn_msg)
   else:
-    _LOGGER.warning(
-        'No semi_persistent_directory found: Functions defined in __main__ '
-        '(interactive session) may fail.')
+    _LOGGER.warning('No semi_persistent_directory found:%s', warn_msg)
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/typehints/schemas.py 
b/sdks/python/apache_beam/typehints/schemas.py
index c21dde426fc..e9674fa5bc2 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -684,12 +684,17 @@ class LogicalTypeRegistry(object):
     self.by_urn = {}
     self.by_logical_type = {}
     self.by_language_type = {}
+    self._custom_urns = set()
 
-  def add(self, urn, logical_type):
+  def _add_internal(self, urn, logical_type):
     self.by_urn[urn] = logical_type
     self.by_logical_type[logical_type] = urn
     self.by_language_type[logical_type.language_type()] = logical_type
 
+  def add(self, urn, logical_type):
+    self._add_internal(urn, logical_type)
+    self._custom_urns.add(urn)
+
   def get_logical_type_by_urn(self, urn):
     return self.by_urn.get(urn, None)
 
@@ -704,8 +709,25 @@ class LogicalTypeRegistry(object):
     copy.by_urn.update(self.by_urn)
     copy.by_logical_type.update(self.by_logical_type)
     copy.by_language_type.update(self.by_language_type)
+    copy._custom_urns.update(self._custom_urns)
     return copy
 
+  def copy_custom(self):
+    copy = LogicalTypeRegistry()
+    for urn in self._custom_urns:
+      logical_type = self.by_urn[urn]
+      copy.by_urn[urn] = logical_type
+      copy.by_logical_type[logical_type] = urn
+      copy.by_language_type[logical_type.language_type()] = logical_type
+      copy._custom_urns.add(urn)
+    return copy
+
+  def load(self, another):
+    self.by_urn.update(another.by_urn)
+    self.by_logical_type.update(another.by_logical_type)
+    self.by_language_type.update(another.by_language_type)
+    self._custom_urns.update(another._custom_urns)
+
 
 LanguageT = TypeVar('LanguageT')
 RepresentationT = TypeVar('RepresentationT')
@@ -768,6 +790,19 @@ class LogicalType(Generic[LanguageT, RepresentationT, 
ArgT]):
     """Convert an instance of RepresentationT to LanguageT."""
     raise NotImplementedError()
 
+  @classmethod
+  def _register_internal(cls, logical_type_cls):
+    """
+    Register an implementation of LogicalType.
+
+    The types registered using this decorator are not pickled on pipeline
+    submission, as they rely on module import to be registered on worker
+    initialization. Should only be used within the schemas module, statically.
+    """
+    cls._known_logical_types._add_internal(
+        logical_type_cls.urn(), logical_type_cls)
+    return logical_type_cls
+
   @classmethod
   def register_logical_type(cls, logical_type_cls):
     """Register an implementation of LogicalType."""
@@ -884,7 +919,7 @@ MicrosInstantRepresentation = NamedTuple(
                                     ('micros', np.int64)])
 
 
[email protected]_logical_type
+@LogicalType._register_internal
 class MillisInstant(NoArgumentLogicalType[Timestamp, np.int64]):
   """Millisecond-precision instant logical type handles values consistent with
   that encoded by ``InstantCoder`` in the Java SDK.
@@ -928,7 +963,7 @@ class MillisInstant(NoArgumentLogicalType[Timestamp, 
np.int64]):
 # Make sure MicrosInstant is registered after MillisInstant so that it
 # overwrites the mapping of Timestamp language type representation choice and
 # thus does not lose microsecond precision inside python sdk.
[email protected]_logical_type
+@LogicalType._register_internal
 class MicrosInstant(NoArgumentLogicalType[Timestamp,
                                           MicrosInstantRepresentation]):
   """Microsecond-precision instant logical type that handles ``Timestamp``."""
@@ -955,7 +990,7 @@ class MicrosInstant(NoArgumentLogicalType[Timestamp,
     return Timestamp(seconds=int(value.seconds), micros=int(value.micros))
 
 
[email protected]_logical_type
+@LogicalType._register_internal
 class PythonCallable(NoArgumentLogicalType[PythonCallableWithSource, str]):
   """A logical type for PythonCallableSource objects."""
   @classmethod
@@ -1011,7 +1046,7 @@ class 
DecimalLogicalType(NoArgumentLogicalType[decimal.Decimal, bytes]):
     return decimal.Decimal(value.decode())
 
 
[email protected]_logical_type
+@LogicalType._register_internal
 class FixedPrecisionDecimalLogicalType(
     LogicalType[decimal.Decimal,
                 DecimalLogicalType,
@@ -1063,10 +1098,10 @@ class FixedPrecisionDecimalLogicalType(
 
 # TODO(yathu,BEAM-10722): Investigate and resolve conflicts in logical type
 # registration when more than one logical types sharing the same language type
-LogicalType.register_logical_type(DecimalLogicalType)
+LogicalType._register_internal(DecimalLogicalType)
 
 
[email protected]_logical_type
+@LogicalType._register_internal
 class FixedBytes(PassThroughLogicalType[bytes, np.int32]):
   """A logical type for fixed-length bytes."""
   @classmethod
@@ -1099,7 +1134,7 @@ class FixedBytes(PassThroughLogicalType[bytes, np.int32]):
     return self.length
 
 
[email protected]_logical_type
+@LogicalType._register_internal
 class VariableBytes(PassThroughLogicalType[bytes, np.int32]):
   """A logical type for variable-length bytes with specified maximum length."""
   @classmethod
@@ -1129,7 +1164,7 @@ class VariableBytes(PassThroughLogicalType[bytes, 
np.int32]):
     return self.max_length
 
 
[email protected]_logical_type
+@LogicalType._register_internal
 class FixedString(PassThroughLogicalType[str, np.int32]):
   """A logical type for fixed-length string."""
   @classmethod
@@ -1162,7 +1197,7 @@ class FixedString(PassThroughLogicalType[str, np.int32]):
     return self.length
 
 
[email protected]_logical_type
+@LogicalType._register_internal
 class VariableString(PassThroughLogicalType[str, np.int32]):
   """A logical type for variable-length string with specified maximum 
length."""
   @classmethod
@@ -1195,7 +1230,7 @@ class VariableString(PassThroughLogicalType[str, 
np.int32]):
 # TODO: A temporary fix for missing jdbc logical types.
 # See the discussion in https://github.com/apache/beam/issues/35738 for
 # more detail.
[email protected]_logical_type
+@LogicalType._register_internal
 class JdbcDateType(LogicalType[datetime.date, MillisInstant, str]):
   """
   For internal use only; no backwards-compatibility guarantees.
@@ -1238,7 +1273,7 @@ class JdbcDateType(LogicalType[datetime.date, 
MillisInstant, str]):
     return cls()
 
 
[email protected]_logical_type
+@LogicalType._register_internal
 class JdbcTimeType(LogicalType[datetime.time, MillisInstant, str]):
   """
   For internal use only; no backwards-compatibility guarantees.

Reply via email to