This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new f22d2bc1f16 Save Logical Type and Coder Registry on cloudpickle save main session (#36271)
f22d2bc1f16 is described below
commit f22d2bc1f1638a19639483ef18b56b4cf9fb8504
Author: Yi Hu <[email protected]>
AuthorDate: Fri Dec 5 16:16:42 2025 -0500
Save Logical Type and Coder Registry on cloudpickle save main session (#36271)
* Save Logical Type Registry and Coder Registry on cloudpickle save main session
fix naming
* Track custom_urn set in logical type registry
* Fix, add tests
* Set save_main_session default to true for cloudpickle and introduce overwrite flag
* Fix test; trigger postcommits
* Fix test as Dataflow runner submission now stages a main session file
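
For context, a minimal sketch (not part of this commit) of the scenario the change addresses. MyType, MyCoder, and the pipeline below are hypothetical user code; the point is that registrations made in __main__ were previously dropped on cloudpickle workers because dump_session was a no-op:

import apache_beam as beam
from apache_beam.coders import coders
from apache_beam.options.pipeline_options import PipelineOptions


class MyType(object):
  def __init__(self, value):
    self.value = value


class MyCoder(coders.Coder):
  # Hypothetical coder for MyType.
  def encode(self, o):
    return o.value.encode('utf-8')

  def decode(self, b):
    return MyType(b.decode('utf-8'))

  def is_deterministic(self):
    return True


# Registered in __main__; with this change the registration is captured in the
# saved main session and replayed on the worker.
beam.coders.registry.register_coder(MyType, MyCoder)

if __name__ == '__main__':
  # save_main_session now defaults to True when cloudpickle is the pickler.
  with beam.Pipeline(options=PipelineOptions()) as p:
    _ = (
        p
        | beam.Create(['a', 'b'])
        | beam.Map(MyType).with_output_types(MyType))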
---
.github/trigger_files/beam_PostCommit_Python.json | 1 +
.../beam_PostCommit_Python_Xlang_Gcp_Dataflow.json | 2 +-
.../beam_PostCommit_Python_Xlang_IO_Dataflow.json | 2 +-
CHANGES.md | 1 +
sdks/python/apache_beam/coders/typecoders.py | 25 +++++---
.../apache_beam/internal/cloudpickle_pickler.py | 35 +++++++++--
.../internal/cloudpickle_pickler_test.py | 20 +++++++
sdks/python/apache_beam/internal/pickler.py | 11 +++-
.../python/apache_beam/options/pipeline_options.py | 29 ++++++++-
.../runners/dataflow/dataflow_runner_test.py | 63 ++++++++++++-------
.../apache_beam/runners/portability/stager.py | 1 -
.../apache_beam/runners/portability/stager_test.py | 7 ++-
.../apache_beam/runners/worker/sdk_worker_main.py | 70 ++++++++++++----------
sdks/python/apache_beam/typehints/schemas.py | 59 ++++++++++++++----
14 files changed, 239 insertions(+), 87 deletions(-)
diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json
index c6ec17f4841..d6818d275f1 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,6 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to
run.",
+ "pr": "36271",
"modification": 35
}
diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
index 99a8fc8ff6d..b60f5c4cc3c 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to
run",
- "modification": 14
+ "modification": 0
}
diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json
index f1ba03a243e..b60f5c4cc3c 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to
run",
- "modification": 5
+ "modification": 0
}
diff --git a/CHANGES.md b/CHANGES.md
index bab7182539e..dfad320a694 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -85,6 +85,7 @@
## Bugfixes
* Fixed FirestoreV1 Beam connectors allow configuring inconsistent project/database IDs between RPC requests and routing headers #36895 (Java) ([#36895](https://github.com/apache/beam/issues/36895)).
+* Logical type and coder registries are now saved with the pipeline when the default pickler is used. This fixes a side effect of switching to cloudpickle as the default pickler in Beam 2.65.0 (Python) ([#35738](https://github.com/apache/beam/issues/35738)).
## Known Issues
diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py
index 779c65dc772..ef75a21ce9e 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -114,6 +114,14 @@ class CoderRegistry(object):
typehint_coder_class: Type[coders.Coder]) -> None:
self._coders[typehint_type] = typehint_coder_class
+ @staticmethod
+ def _normalize_typehint_type(typehint_type):
+ if typehint_type.__module__ == '__main__':
+ # See https://github.com/apache/beam/issues/21541
+ # TODO(robertwb): Remove once all runners are portable.
+ return getattr(typehint_type, '__name__', str(typehint_type))
+ return typehint_type
+
def register_coder(
self, typehint_type: Any,
typehint_coder_class: Type[coders.Coder]) -> None:
@@ -123,11 +131,8 @@ class CoderRegistry(object):
'Received %r instead.' % typehint_coder_class)
if typehint_type not in self.custom_types:
self.custom_types.append(typehint_type)
- if typehint_type.__module__ == '__main__':
- # See https://github.com/apache/beam/issues/21541
- # TODO(robertwb): Remove once all runners are portable.
- typehint_type = getattr(typehint_type, '__name__', str(typehint_type))
- self._register_coder_internal(typehint_type, typehint_coder_class)
+ self._register_coder_internal(
+ self._normalize_typehint_type(typehint_type), typehint_coder_class)
def get_coder(self, typehint: Any) -> coders.Coder:
if typehint and typehint.__module__ == '__main__':
@@ -170,9 +175,15 @@ class CoderRegistry(object):
coder = self._fallback_coder
return coder.from_type_hint(typehint, self)
- def get_custom_type_coder_tuples(self, types):
+ def get_custom_type_coder_tuples(self, types=None):
"""Returns type/coder tuples for all custom types passed in."""
- return [(t, self._coders[t]) for t in types if t in self.custom_types]
+ return [(t, self._coders[self._normalize_typehint_type(t)])
+ for t in self.custom_types if (types is None or t in types)]
+
+ def load_custom_type_coder_tuples(self, type_coder):
+ """Load type/coder tuples into coder registry."""
+ for t, c in type_coder:
+ self.register_coder(t, c)
def verify_deterministic(self, key_coder, op_name, silent=True):
if not key_coder.is_deterministic():
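
A quick usage sketch of the two registry methods added above (assuming they land as shown in this diff):

from apache_beam.coders import typecoders

# Submission side: snapshot every custom type/coder registration, including
# types defined in __main__ (normalized to their names).
snapshot = typecoders.registry.get_custom_type_coder_tuples()

# Worker side: replay the snapshot into the freshly initialized registry.
typecoders.registry.load_custom_type_coder_tuples(snapshot)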
diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
index 199294f1731..acdcc46cd40 100644
--- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py
+++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
@@ -252,12 +252,35 @@ def _lock_reducer(obj):
def dump_session(file_path):
- # It is possible to dump session with cloudpickle. However, since references
- # are saved it should not be necessary. See https://s.apache.org/beam-picklers
- pass
+ # Since references are saved (https://s.apache.org/beam-picklers), we only
+ # dump supported Beam registries (currently the coder and logical type
+ # registries).
+ from apache_beam.coders import typecoders
+ from apache_beam.typehints import schemas
+
+ with _pickle_lock, open(file_path, 'wb') as file:
+ coder_reg = typecoders.registry.get_custom_type_coder_tuples()
+ logical_type_reg = schemas.LogicalType._known_logical_types.copy_custom()
+
+ pickler = cloudpickle.CloudPickler(file)
+ # TODO(https://github.com/apache/beam/issues/18500) add file system registry
+ # once implemented
+ pickler.dump({"coder": coder_reg, "logical_type": logical_type_reg})
def load_session(file_path):
- # It is possible to load_session with cloudpickle. However, since references
- # are saved it should not be necessary. See https://s.apache.org/beam-picklers
- pass
+ from apache_beam.coders import typecoders
+ from apache_beam.typehints import schemas
+
+ with _pickle_lock, open(file_path, 'rb') as file:
+ registries = cloudpickle.load(file)
+ if type(registries) != dict:
+ raise ValueError(
+ "Faled loading session: expected dict, got {}", type(registries))
+ if "coder" in registries:
+ typecoders.registry.load_custom_type_coder_tuples(registries["coder"])
+ else:
+ _LOGGER.warning('No coder registry found in saved session')
+ if "logical_type" in registries:
+ schemas.LogicalType._known_logical_types.load(registries["logical_type"])
+ else:
+ _LOGGER.warning('No logical type registry found in saved session')
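
Assuming the module keeps the shape shown above, a round trip looks roughly like this (the temp path is illustrative):

import os
import tempfile

from apache_beam.internal import cloudpickle_pickler

with tempfile.TemporaryDirectory() as tmp:
  session_file = os.path.join(tmp, 'pickled_main_session')
  # Writes a dict {"coder": [...], "logical_type": LogicalTypeRegistry}.
  cloudpickle_pickler.dump_session(session_file)
  # Re-registers the saved coders and logical types into the live registries.
  cloudpickle_pickler.load_session(session_file)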
diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py b/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py
index 4a51c56c24b..99fbb03ac2e 100644
--- a/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py
+++ b/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py
@@ -20,6 +20,7 @@
# pytype: skip-file
import os
+import tempfile
import threading
import types
import unittest
@@ -31,6 +32,7 @@ from apache_beam.internal import code_object_pickler
from apache_beam.internal import module_test
from apache_beam.internal.cloudpickle_pickler import dumps
from apache_beam.internal.cloudpickle_pickler import loads
+from apache_beam.typehints.schemas import LogicalTypeRegistry
from apache_beam.utils import shared
GLOBAL_DICT_REF = module_test.GLOBAL_DICT
@@ -244,6 +246,24 @@ self.assertEqual(DataClass(datum='abc'),
loads(dumps(DataClass(datum='abc'))))
unpickled_filename = os.path.abspath(unpickled_code.co_filename)
self.assertEqual(unpickled_filename, original_filename)
+ @mock.patch(
+ "apache_beam.coders.typecoders.registry.load_custom_type_coder_tuples")
+ @mock.patch(
+ "apache_beam.typehints.schemas.LogicalType._known_logical_types.load")
+ def test_dump_load_session(self, logicaltype_mock, coder_mock):
+ session_file = 'pickled'
+
+ with tempfile.TemporaryDirectory() as tmp_dirname:
+ pickled_session_file = os.path.join(tmp_dirname, session_file)
+ beam_cloudpickle.dump_session(pickled_session_file)
+ beam_cloudpickle.load_session(pickled_session_file)
+ load_logical_types = logicaltype_mock.call_args.args
+ load_coders = coder_mock.call_args.args
+ self.assertEqual(len(load_logical_types), 1)
+ self.assertEqual(len(load_coders), 1)
+ self.assertTrue(isinstance(load_logical_types[0], LogicalTypeRegistry))
+ self.assertTrue(isinstance(load_coders[0], list))
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/internal/pickler.py b/sdks/python/apache_beam/internal/pickler.py
index 0af3b16ec05..3626b599a5c 100644
--- a/sdks/python/apache_beam/internal/pickler.py
+++ b/sdks/python/apache_beam/internal/pickler.py
@@ -91,6 +91,14 @@ def load_session(file_path):
return desired_pickle_lib.load_session(file_path)
+def is_currently_dill():
+ return desired_pickle_lib == dill_pickler
+
+
+def is_currently_cloudpickle():
+ return desired_pickle_lib == cloudpickle_pickler
+
+
def set_library(selected_library=DEFAULT_PICKLE_LIB):
""" Sets pickle library that will be used. """
global desired_pickle_lib
@@ -108,12 +116,11 @@ def set_library(selected_library=DEFAULT_PICKLE_LIB):
"Pipeline option pickle_library=dill_unsafe is set, but dill is not "
"installed. Install dill in job submission and runtime environments.")
- is_currently_dill = (desired_pickle_lib == dill_pickler)
dill_is_requested = (
selected_library == USE_DILL or selected_library == USE_DILL_UNSAFE)
# If switching to or from dill, update the pickler hook overrides.
- if is_currently_dill != dill_is_requested:
+ if is_currently_dill() != dill_is_requested:
dill_pickler.override_pickler_hooks(selected_library == USE_DILL)
if dill_is_requested:
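
The new predicates simply report which pickle library is currently selected; a small illustrative check (not taken from the commit's tests):

from apache_beam.internal import pickler

pickler.set_library('cloudpickle')
assert pickler.is_currently_cloudpickle()
assert not pickler.is_currently_dill()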
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index f2addf6f9d5..170ade224c1 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -64,7 +64,10 @@ _LOGGER = logging.getLogger(__name__)
# Map defined with option names to flag names for boolean options
# that have a destination(dest) in parser.add_argument() different
# from the flag name and whose default value is `None`.
-_FLAG_THAT_SETS_FALSE_VALUE = {'use_public_ips': 'no_use_public_ips'}
+_FLAG_THAT_SETS_FALSE_VALUE = {
+ 'use_public_ips': 'no_use_public_ips',
+ 'save_main_session': 'no_save_main_session'
+}
# Set of options which should not be overriden when applying options from a
# different language. This is relevant when using x-lang transforms where the
# expansion service is started up with some pipeline options, and will
@@ -1672,7 +1675,7 @@ class SetupOptions(PipelineOptions):
choices=['cloudpickle', 'default', 'dill', 'dill_unsafe'])
parser.add_argument(
'--save_main_session',
- default=False,
+ default=None,
action='store_true',
help=(
'Save the main session state so that pickled functions and classes '
@@ -1680,6 +1683,15 @@ class SetupOptions(PipelineOptions):
'Some workflows do not need the session state if for instance all '
'their functions/classes are defined in proper modules '
'(not __main__) and the modules are importable in the worker. '))
+ parser.add_argument(
+ '--no_save_main_session',
+ default=None,
+ action='store_false',
+ dest='save_main_session',
+ help=(
+ 'Disable saving the main session state. By default it is enabled '
+ 'for cloudpickle and disabled for dill. See "save_main_session".'))
+
parser.add_argument(
'--sdk_location',
default='default',
@@ -1780,10 +1792,23 @@ class SetupOptions(PipelineOptions):
'If not specified, the default Maven Central repository will be '
'used.'))
+ def _handle_load_main_session(self, validator):
+ save_main_session = getattr(self, 'save_main_session')
+ if save_main_session is None:
+ # save_main_session defaults to False for dill and to True for
+ # cloudpickle.
+ pickle_library = getattr(self, 'pickle_library')
+ if pickle_library in ['default', 'cloudpickle']:
+ setattr(self, 'save_main_session', True)
+ else:
+ setattr(self, 'save_main_session', False)
+ return []
+
def validate(self, validator):
errors = []
errors.extend(validator.validate_container_prebuilding_options(self))
errors.extend(validator.validate_pickle_library(self))
+ errors.extend(self._handle_load_main_session(validator))
return errors
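
To illustrate how the deferred default and the new flag interact (a sketch based on the changes above, not code from this commit):

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

# With no flag given, save_main_session stays None until validation resolves
# it: True for cloudpickle/default, False for dill.
unset = PipelineOptions([]).view_as(SetupOptions)
print(unset.save_main_session)  # None

# The new --no_save_main_session flag opts out explicitly.
disabled = PipelineOptions(['--no_save_main_session']).view_as(SetupOptions)
print(disabled.save_main_session)  # False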
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index 178a75ec41d..d5d8ba662f0 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -42,6 +42,7 @@ from apache_beam.runners.dataflow.dataflow_runner import DataflowRuntimeException
from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_options
from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_streaming_options
from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
+from apache_beam.runners.internal import names
from apache_beam.runners.runner import PipelineState
from apache_beam.testing.extra_assertions import ExtraAssertionsMixin
from apache_beam.testing.test_pipeline import TestPipeline
@@ -243,6 +244,18 @@ class DataflowRunnerTest(unittest.TestCase, ExtraAssertionsMixin):
self.assertTrue(
isinstance(create_runner('TestDataflowRunner'), TestDataflowRunner))
+ @staticmethod
+ def dependency_proto_from_main_session_file(serialized_path):
+ return [
+ beam_runner_api_pb2.ArtifactInformation(
+ type_urn=common_urns.artifact_types.FILE.urn,
+ type_payload=serialized_path,
+ role_urn=common_urns.artifact_roles.STAGING_TO.urn,
+ role_payload=beam_runner_api_pb2.ArtifactStagingToRolePayload(
+ staged_name=names.PICKLED_MAIN_SESSION_FILE).SerializeToString(
+ ))
+ ]
+
def test_environment_override_translation_legacy_worker_harness_image(self):
self.default_properties.append('--experiments=beam_fn_api')
self.default_properties.append('--worker_harness_container_image=LEGACY')
@@ -256,17 +269,22 @@ class DataflowRunnerTest(unittest.TestCase, ExtraAssertionsMixin):
| 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
| ptransform.GroupByKey())
+ actual = list(remote_runner.proto_pipeline.components.environments.values())
+ self.assertEqual(len(actual), 1)
+ actual = actual[0]
+ file_path = actual.dependencies[0].type_payload
+ # The dependency payload references a main session file in a transient
+ # temp directory, so derive the expected value from the actual payload.
+ main_session_dep = self.dependency_proto_from_main_session_file(file_path)
self.assertEqual(
- list(remote_runner.proto_pipeline.components.environments.values()),
- [
- beam_runner_api_pb2.Environment(
- urn=common_urns.environments.DOCKER.urn,
- payload=beam_runner_api_pb2.DockerPayload(
- container_image='LEGACY').SerializeToString(),
- capabilities=environments.python_sdk_docker_capabilities(),
- dependencies=environments.python_sdk_dependencies(
- options=options))
- ])
+ actual,
+ beam_runner_api_pb2.Environment(
+ urn=common_urns.environments.DOCKER.urn,
+ payload=beam_runner_api_pb2.DockerPayload(
+ container_image='LEGACY').SerializeToString(),
+ capabilities=environments.python_sdk_docker_capabilities(),
+ dependencies=environments.python_sdk_dependencies(options=options) +
+ main_session_dep))
def test_environment_override_translation_sdk_container_image(self):
self.default_properties.append('--experiments=beam_fn_api')
@@ -281,17 +299,22 @@ class DataflowRunnerTest(unittest.TestCase, ExtraAssertionsMixin):
| 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
| ptransform.GroupByKey())
+ actual = list(remote_runner.proto_pipeline.components.environments.values())
+ self.assertEqual(len(actual), 1)
+ actual = actual[0]
+ file_path = actual.dependencies[0].type_payload
+ # The dependency payload references a main session file in a transient
+ # temp directory, so derive the expected value from the actual payload.
+ main_session_dep = self.dependency_proto_from_main_session_file(file_path)
self.assertEqual(
- list(remote_runner.proto_pipeline.components.environments.values()),
- [
- beam_runner_api_pb2.Environment(
- urn=common_urns.environments.DOCKER.urn,
- payload=beam_runner_api_pb2.DockerPayload(
- container_image='FOO').SerializeToString(),
- capabilities=environments.python_sdk_docker_capabilities(),
- dependencies=environments.python_sdk_dependencies(
- options=options))
- ])
+ actual,
+ beam_runner_api_pb2.Environment(
+ urn=common_urns.environments.DOCKER.urn,
+ payload=beam_runner_api_pb2.DockerPayload(
+ container_image='FOO').SerializeToString(),
+ capabilities=environments.python_sdk_docker_capabilities(),
+ dependencies=environments.python_sdk_dependencies(options=options) +
+ main_session_dep))
def test_remote_runner_translation(self):
remote_runner = DataflowRunner()
diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py
index 9147410c246..aa03082f0d5 100644
--- a/sdks/python/apache_beam/runners/portability/stager.py
+++ b/sdks/python/apache_beam/runners/portability/stager.py
@@ -376,7 +376,6 @@ class Stager(object):
pickled_session_file = os.path.join(
temp_dir, names.PICKLED_MAIN_SESSION_FILE)
pickler.dump_session(pickled_session_file)
- # for pickle_library: cloudpickle, dump_session is no op
if os.path.exists(pickled_session_file):
resources.append(
Stager._create_file_stage_to_artifact(
diff --git a/sdks/python/apache_beam/runners/portability/stager_test.py b/sdks/python/apache_beam/runners/portability/stager_test.py
index 22a41e592c2..233e0c3dcea 100644
--- a/sdks/python/apache_beam/runners/portability/stager_test.py
+++ b/sdks/python/apache_beam/runners/portability/stager_test.py
@@ -200,7 +200,7 @@ class StagerTest(unittest.TestCase):
# (https://github.com/apache/beam/issues/21457): Remove the decorator once
# cloudpickle is default pickle library
@pytest.mark.no_xdist
- def test_main_session_not_staged_when_using_cloudpickle(self):
+ def test_main_session_staged_when_using_cloudpickle(self):
staging_dir = self.make_temp_dir()
options = PipelineOptions()
@@ -209,7 +209,10 @@ class StagerTest(unittest.TestCase):
# session is saved when pickle_library==cloudpickle.
options.view_as(SetupOptions).pickle_library = pickler.USE_CLOUDPICKLE
self.update_options(options)
- self.assertEqual([stager.SUBMISSION_ENV_DEPENDENCIES_FILE],
+ self.assertEqual([
+ names.PICKLED_MAIN_SESSION_FILE,
+ stager.SUBMISSION_ENV_DEPENDENCIES_FILE
+ ],
self.stager.create_and_stage_job_resources(
options, staging_location=staging_dir)[1])
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index cdb807e8dbc..e4dd6cc2121 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -113,26 +113,25 @@ def create_harness(environment, dry_run=False):
_LOGGER.info('semi_persistent_directory: %s', semi_persistent_directory)
_worker_id = environment.get('WORKER_ID', None)
- if pickle_library != pickler.USE_CLOUDPICKLE:
- try:
- _load_main_session(semi_persistent_directory)
- except LoadMainSessionException:
- exception_details = traceback.format_exc()
- _LOGGER.error(
- 'Could not load main session: %s', exception_details, exc_info=True)
- raise
- except Exception: # pylint: disable=broad-except
- summary = (
- "Could not load main session. Inspect which external dependencies "
- "are used in the main module of your pipeline. Verify that "
- "corresponding packages are installed in the pipeline runtime "
- "environment and their installed versions match the versions used in
"
- "pipeline submission environment. For more information, see:
https://"
- "beam.apache.org/documentation/sdks/python-pipeline-dependencies/")
- _LOGGER.error(summary, exc_info=True)
- exception_details = traceback.format_exc()
- deferred_exception = LoadMainSessionException(
- f"{summary} {exception_details}")
+ try:
+ _load_main_session(semi_persistent_directory)
+ except LoadMainSessionException:
+ exception_details = traceback.format_exc()
+ _LOGGER.error(
+ 'Could not load main session: %s', exception_details, exc_info=True)
+ raise
+ except Exception: # pylint: disable=broad-except
+ summary = (
+ "Could not load main session. Inspect which external dependencies "
+ "are used in the main module of your pipeline. Verify that "
+ "corresponding packages are installed in the pipeline runtime "
+ "environment and their installed versions match the versions used in "
+ "pipeline submission environment. For more information, see: https://"
+ "beam.apache.org/documentation/sdks/python-pipeline-dependencies/")
+ _LOGGER.error(summary, exc_info=True)
+ exception_details = traceback.format_exc()
+ deferred_exception = LoadMainSessionException(
+ f"{summary} {exception_details}")
_LOGGER.info(
'Pipeline_options: %s',
@@ -356,6 +355,14 @@ class LoadMainSessionException(Exception):
def _load_main_session(semi_persistent_directory):
"""Loads a pickled main session from the path specified."""
+ if pickler.is_currently_dill():
+ warn_msg = ' Functions defined in __main__ (interactive session) may fail.'
+ err_msg = ' Functions defined in __main__ (interactive session) will ' \
+ 'almost certainly fail.'
+ elif pickler.is_currently_cloudpickle():
+ warn_msg = ' User registered objects (e.g. schema, logical type) through ' \
+ 'registries may not be effective.'
+ err_msg = ''
if semi_persistent_directory:
session_file = os.path.join(
semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE)
@@ -365,21 +372,18 @@ def _load_main_session(semi_persistent_directory):
# This can happen if the worker fails to download the main session.
# Raise a fatal error and crash this worker, forcing a restart.
if os.path.getsize(session_file) == 0:
- # Potenitally transient error, unclear if still happening.
- raise LoadMainSessionException(
- 'Session file found, but empty: %s. Functions defined in __main__ '
- '(interactive session) will almost certainly fail.' %
- (session_file, ))
- pickler.load_session(session_file)
+ if pickler.is_currently_dill():
+ # Potentially transient error, unclear if still happening.
+ raise LoadMainSessionException(
+ 'Session file found, but empty: %s.%s' % (session_file, err_msg))
+ else:
+ _LOGGER.warning('Empty session file: %s.%s', session_file, warn_msg)
+ else:
+ pickler.load_session(session_file)
else:
- _LOGGER.warning(
- 'No session file found: %s. Functions defined in __main__ '
- '(interactive session) may fail.',
- session_file)
+ _LOGGER.warning('No session file found: %s.%s', session_file, warn_msg)
else:
- _LOGGER.warning(
- 'No semi_persistent_directory found: Functions defined in __main__ '
- '(interactive session) may fail.')
+ _LOGGER.warning('No semi_persistent_directory found: %s', warn_msg)
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index c21dde426fc..e9674fa5bc2 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -684,12 +684,17 @@ class LogicalTypeRegistry(object):
self.by_urn = {}
self.by_logical_type = {}
self.by_language_type = {}
+ self._custom_urns = set()
- def add(self, urn, logical_type):
+ def _add_internal(self, urn, logical_type):
self.by_urn[urn] = logical_type
self.by_logical_type[logical_type] = urn
self.by_language_type[logical_type.language_type()] = logical_type
+ def add(self, urn, logical_type):
+ self._add_internal(urn, logical_type)
+ self._custom_urns.add(urn)
+
def get_logical_type_by_urn(self, urn):
return self.by_urn.get(urn, None)
@@ -704,8 +709,25 @@ class LogicalTypeRegistry(object):
copy.by_urn.update(self.by_urn)
copy.by_logical_type.update(self.by_logical_type)
copy.by_language_type.update(self.by_language_type)
+ copy._custom_urns.update(self._custom_urns)
return copy
+ def copy_custom(self):
+ copy = LogicalTypeRegistry()
+ for urn in self._custom_urns:
+ logical_type = self.by_urn[urn]
+ copy.by_urn[urn] = logical_type
+ copy.by_logical_type[logical_type] = urn
+ copy.by_language_type[logical_type.language_type()] = logical_type
+ copy._custom_urns.add(urn)
+ return copy
+
+ def load(self, another):
+ self.by_urn.update(another.by_urn)
+ self.by_logical_type.update(another.by_logical_type)
+ self.by_language_type.update(another.by_language_type)
+ self._custom_urns.update(another._custom_urns)
+
LanguageT = TypeVar('LanguageT')
RepresentationT = TypeVar('RepresentationT')
@@ -768,6 +790,19 @@ class LogicalType(Generic[LanguageT, RepresentationT, ArgT]):
"""Convert an instance of RepresentationT to LanguageT."""
raise NotImplementedError()
+ @classmethod
+ def _register_internal(cls, logical_type_cls):
+ """
+ Register an implementation of LogicalType.
+
+ The types registered using this decorator are not pickled on pipeline
+ submission; they rely on module import to be re-registered on worker
+ initialization. Only use this within the schemas module, in a static context.
+ """
+ cls._known_logical_types._add_internal(
+ logical_type_cls.urn(), logical_type_cls)
+ return logical_type_cls
+
@classmethod
def register_logical_type(cls, logical_type_cls):
"""Register an implementation of LogicalType."""
@@ -884,7 +919,7 @@ MicrosInstantRepresentation = NamedTuple(
('micros', np.int64)])
[email protected]_logical_type
+@LogicalType._register_internal
class MillisInstant(NoArgumentLogicalType[Timestamp, np.int64]):
"""Millisecond-precision instant logical type handles values consistent with
that encoded by ``InstantCoder`` in the Java SDK.
@@ -928,7 +963,7 @@ class MillisInstant(NoArgumentLogicalType[Timestamp, np.int64]):
# Make sure MicrosInstant is registered after MillisInstant so that it
# overwrites the mapping of Timestamp language type representation choice and
# thus does not lose microsecond precision inside python sdk.
[email protected]_logical_type
+@LogicalType._register_internal
class MicrosInstant(NoArgumentLogicalType[Timestamp,
MicrosInstantRepresentation]):
"""Microsecond-precision instant logical type that handles ``Timestamp``."""
@@ -955,7 +990,7 @@ class MicrosInstant(NoArgumentLogicalType[Timestamp,
return Timestamp(seconds=int(value.seconds), micros=int(value.micros))
[email protected]_logical_type
+@LogicalType._register_internal
class PythonCallable(NoArgumentLogicalType[PythonCallableWithSource, str]):
"""A logical type for PythonCallableSource objects."""
@classmethod
@@ -1011,7 +1046,7 @@ class DecimalLogicalType(NoArgumentLogicalType[decimal.Decimal, bytes]):
return decimal.Decimal(value.decode())
[email protected]_logical_type
+@LogicalType._register_internal
class FixedPrecisionDecimalLogicalType(
LogicalType[decimal.Decimal,
DecimalLogicalType,
@@ -1063,10 +1098,10 @@ class FixedPrecisionDecimalLogicalType(
# TODO(yathu,BEAM-10722): Investigate and resolve conflicts in logical type
# registration when more than one logical types sharing the same language type
-LogicalType.register_logical_type(DecimalLogicalType)
+LogicalType._register_internal(DecimalLogicalType)
[email protected]_logical_type
+@LogicalType._register_internal
class FixedBytes(PassThroughLogicalType[bytes, np.int32]):
"""A logical type for fixed-length bytes."""
@classmethod
@@ -1099,7 +1134,7 @@ class FixedBytes(PassThroughLogicalType[bytes, np.int32]):
return self.length
[email protected]_logical_type
+@LogicalType._register_internal
class VariableBytes(PassThroughLogicalType[bytes, np.int32]):
"""A logical type for variable-length bytes with specified maximum length."""
@classmethod
@@ -1129,7 +1164,7 @@ class VariableBytes(PassThroughLogicalType[bytes, np.int32]):
return self.max_length
[email protected]_logical_type
+@LogicalType._register_internal
class FixedString(PassThroughLogicalType[str, np.int32]):
"""A logical type for fixed-length string."""
@classmethod
@@ -1162,7 +1197,7 @@ class FixedString(PassThroughLogicalType[str, np.int32]):
return self.length
[email protected]_logical_type
+@LogicalType._register_internal
class VariableString(PassThroughLogicalType[str, np.int32]):
"""A logical type for variable-length string with specified maximum
length."""
@classmethod
@@ -1195,7 +1230,7 @@ class VariableString(PassThroughLogicalType[str, np.int32]):
# TODO: A temporary fix for missing jdbc logical types.
# See the discussion in https://github.com/apache/beam/issues/35738 for
# more detail.
[email protected]_logical_type
+@LogicalType._register_internal
class JdbcDateType(LogicalType[datetime.date, MillisInstant, str]):
"""
For internal use only; no backwards-compatibility guarantees.
@@ -1238,7 +1273,7 @@ class JdbcDateType(LogicalType[datetime.date, MillisInstant, str]):
return cls()
[email protected]_logical_type
+@LogicalType._register_internal
class JdbcTimeType(LogicalType[datetime.time, MillisInstant, str]):
"""
For internal use only; no backwards-compatibility guarantees.