This is an automated email from the ASF dual-hosted git repository.
chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new f3e0f6d4962 Updates YAML SDK to replace Kafka read/write transforms
with equivalent managed transforms (#34755)
f3e0f6d4962 is described below
commit f3e0f6d496218a2025380cdef583f259685b82cf
Author: Chamikara Jayalath <[email protected]>
AuthorDate: Wed Apr 30 12:03:57 2025 -0700
Updates YAML SDK to replace Kafka read/write transforms with equivalent
managed transforms (#34755)
* Updates YAML SDK to replace Kafka read/write transforms with equivalent
managed transforms
* Addressing reviewer comments and adding unit tests
* Resolves conflict
* Fixes a test failure
* Fix lint
* Fixes a test
---
sdks/python/apache_beam/transforms/external.py | 79 +++++++++++++++++++++-
.../python/apache_beam/transforms/external_test.py | 66 +++++++++++++++++-
sdks/python/apache_beam/transforms/managed.py | 25 +++----
sdks/python/apache_beam/yaml/standard_io.yaml | 6 ++
sdks/python/apache_beam/yaml/yaml_provider.py | 37 ++++++++--
sdks/python/gen_managed_doc.py | 8 ++-
6 files changed, 194 insertions(+), 27 deletions(-)
diff --git a/sdks/python/apache_beam/transforms/external.py
b/sdks/python/apache_beam/transforms/external.py
index 3fc58f04a78..9b6b4060cb7 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -31,6 +31,7 @@ from collections import OrderedDict
from collections import namedtuple
import grpc
+import yaml
from apache_beam import pvalue
from apache_beam.coders import RowCoder
@@ -42,10 +43,12 @@ from apache_beam.portability.api import
beam_expansion_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import external_transforms_pb2
from apache_beam.portability.api import schema_pb2
+from apache_beam.portability.common_urns import ManagedTransforms
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import artifact_service
from apache_beam.transforms import environments
from apache_beam.transforms import ptransform
+from apache_beam.transforms.util import is_compat_version_prior_to
from apache_beam.typehints import WithTypeHints
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import row_type
@@ -61,6 +64,25 @@ from apache_beam.utils import transform_service_launcher
DEFAULT_EXPANSION_SERVICE = 'localhost:8097'
+MANAGED_SCHEMA_TRANSFORM_IDENTIFIER = "beam:transform:managed:v1"
+
+_IO_EXPANSION_SERVICE_JAR_TARGET = "sdks:java:io:expansion-service:shadowJar"
+
+_GCP_EXPANSION_SERVICE_JAR_TARGET = (
+ "sdks:java:io:google-cloud-platform:expansion-service:shadowJar")
+
+# A mapping from supported managed transforms URNs to expansion service jars
+# that include the corresponding transforms.
+MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING = {
+ ManagedTransforms.Urns.ICEBERG_READ.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
+ ManagedTransforms.Urns.ICEBERG_WRITE.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
+ ManagedTransforms.Urns.ICEBERG_CDC_READ.urn:
_IO_EXPANSION_SERVICE_JAR_TARGET, # pylint: disable=line-too-long
+ ManagedTransforms.Urns.KAFKA_READ.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
+ ManagedTransforms.Urns.KAFKA_WRITE.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
+ ManagedTransforms.Urns.BIGQUERY_READ.urn:
_GCP_EXPANSION_SERVICE_JAR_TARGET,
+ ManagedTransforms.Urns.BIGQUERY_WRITE.urn:
_GCP_EXPANSION_SERVICE_JAR_TARGET
+}
+
def convert_to_typing_type(type_):
if isinstance(type_, row_type.RowTypeConstraint):
@@ -378,6 +400,10 @@ SchemaTransformsConfig = namedtuple(
'SchemaTransformsConfig',
['identifier', 'configuration_schema', 'inputs', 'outputs', 'description'])
+ManagedReplacement = namedtuple(
+ 'ManagedReplacement',
+ ['underlying_transform_identifier', 'update_compatibility_version'])
+
class SchemaAwareExternalTransform(ptransform.PTransform):
"""A proxy transform for SchemaTransforms implemented in external SDKs.
@@ -396,6 +422,12 @@ class SchemaAwareExternalTransform(ptransform.PTransform):
the configuration.
:param classpath: (Optional) A list of paths to additional jars to place on the
expansion service classpath.
+ :param managed_replacement: (Optional) a 'ManagedReplacement' namedtuple that
+ defines information needed to replace the transform with an equivalent
+ managed transform during the expansion. If an
+ 'updateCompatibilityBeamVersion' pipeline option is provided, we will
+ only replace if the managed transform is update compatible with the
+ provided version.
:kwargs: field name to value mapping for configuring the schema transform.
keys map to the field names of the schema of the SchemaTransform
(in-order).
@@ -406,10 +438,14 @@ class SchemaAwareExternalTransform(ptransform.PTransform):
expansion_service,
rearrange_based_on_discovery=False,
classpath=None,
+ managed_replacement=None,
**kwargs):
self._expansion_service = expansion_service
self._kwargs = kwargs
self._classpath = classpath
+ if managed_replacement:
+ assert isinstance(managed_replacement, ManagedReplacement)
+ self._managed_replacement = managed_replacement
_kwargs = kwargs
if rearrange_based_on_discovery:
@@ -420,16 +456,55 @@ class SchemaAwareExternalTransform(ptransform.PTransform):
named_tuple_to_schema(config.configuration_schema),
**_kwargs)
+ if self._managed_replacement:
+ # We have to do the replacement at expansion time rather than at
+ # construction time,
+ # since we don't have access to the PipelineOptions object at
+ # construction.
+ underlying_transform_id = (
+ self._managed_replacement.underlying_transform_identifier)
+ if not (underlying_transform_id in
+ MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING):
+ raise ValueError(
+ 'Could not find an expansion service jar for the managed ' +
+ 'transform ' + underlying_transform_id)
+ managed_expansion_service_jar = (
+ MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING
+ )[underlying_transform_id]
+ self._managed_expansion_service = BeamJarExpansionService(
+ managed_expansion_service_jar)
+ managed_config = SchemaAwareExternalTransform.discover_config(
+ self._managed_expansion_service,
+ MANAGED_SCHEMA_TRANSFORM_IDENTIFIER)
+
+ yaml_config = yaml.dump(kwargs)
+ self._managed_payload_builder = (
+ ExplicitSchemaTransformPayloadBuilder(
+ MANAGED_SCHEMA_TRANSFORM_IDENTIFIER,
+ named_tuple_to_schema(managed_config.configuration_schema),
+ transform_identifier=underlying_transform_id,
+ config=yaml_config))
else:
self._payload_builder = SchemaTransformPayloadBuilder(
identifier, **_kwargs)
def expand(self, pcolls):
# Expand the transform using the expansion service.
+ payload_builder = self._payload_builder
+ expansion_service = self._expansion_service
+
+ if self._managed_replacement:
+ compat_version_prior_to_current = is_compat_version_prior_to(
+ pcolls.pipeline._options,
+ self._managed_replacement.update_compatibility_version)
+ if not compat_version_prior_to_current:
+ payload_builder = self._managed_payload_builder
+ expansion_service = self._managed_expansion_service
+
return pcolls | self._payload_builder.identifier() >> ExternalTransform(
common_urns.schematransform_based_expand.urn,
- self._payload_builder,
- self._expansion_service)
+ payload_builder,
+ expansion_service)
@classmethod
@functools.lru_cache
diff --git a/sdks/python/apache_beam/transforms/external_test.py
b/sdks/python/apache_beam/transforms/external_test.py
index adf44d2286c..84a7025c0a5 100644
--- a/sdks/python/apache_beam/transforms/external_test.py
+++ b/sdks/python/apache_beam/transforms/external_test.py
@@ -29,18 +29,21 @@ import unittest
import mock
import apache_beam as beam
+from apache_beam import ManagedReplacement
from apache_beam import Pipeline
from apache_beam.coders import RowCoder
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.portability.api import beam_expansion_api_pb2
from apache_beam.portability.api import external_transforms_pb2
from apache_beam.portability.api import schema_pb2
+from apache_beam.portability.common_urns import ManagedTransforms
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import expansion_service
from apache_beam.runners.portability.expansion_service_test import FibTransform
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import external
+from apache_beam.transforms.external import MANAGED_SCHEMA_TRANSFORM_IDENTIFIER
from apache_beam.transforms.external import AnnotationBasedPayloadBuilder
from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
from apache_beam.transforms.external import JavaClassLookupPayloadBuilder
@@ -530,8 +533,28 @@ class SchemaAwareExternalTransformTest(unittest.TestCase):
id="test-id"),
input_pcollection_names=["input"],
output_pcollection_names=["output"])
+
+ test_managed_config = beam_expansion_api_pb2.SchemaTransformConfig(
+ config_schema=schema_pb2.Schema(
+ fields=[
+ schema_pb2.Field(
+ name="transform_identifier",
+ type=schema_pb2.FieldType(atomic_type="STRING")),
+ schema_pb2.Field(
+ name="config_url",
+ type=schema_pb2.FieldType(atomic_type="STRING")),
+ schema_pb2.Field(
+ name="config",
+ type=schema_pb2.FieldType(atomic_type="STRING"))
+ ],
+ id="test-id1"),
+ input_pcollection_names=["input"],
+ output_pcollection_names=["output"])
return beam_expansion_api_pb2.DiscoverSchemaTransformResponse(
- schema_transform_configs={"test_schematransform": test_config})
+ schema_transform_configs={
+ "test_schematransform": test_config,
+ MANAGED_SCHEMA_TRANSFORM_IDENTIFIER: test_managed_config
+ })
@mock.patch("apache_beam.transforms.external.ExternalTransform.service")
def test_discover_one_config(self, mock_service):
@@ -573,6 +596,47 @@ class SchemaAwareExternalTransformTest(unittest.TestCase):
self.assertNotEqual(tuple(kwargs.keys()), external_config_fields)
self.assertEqual(tuple(ordered_fields), external_config_fields)
+ @mock.patch("apache_beam.transforms.external.ExternalTransform.service")
+ def test_managed_replacement_unknown_id(self, mock_service):
+ mock_service.return_value = self.MockDiscoveryService()
+
+ identifier = "test_schematransform"
+ kwargs = {"int_field": 0, "str_field": "str"}
+
+ managed_replacement = ManagedReplacement(
+ underlying_transform_identifier="unknown_id",
+ update_compatibility_version="2.50.0")
+
+ with self.assertRaises(ValueError):
+ beam.SchemaAwareExternalTransform(
+ identifier=identifier,
+ expansion_service=expansion_service,
+ rearrange_based_on_discovery=True,
+ managed_replacement=managed_replacement,
+ **kwargs)
+
+ @mock.patch("apache_beam.transforms.external.ExternalTransform.service")
+ @mock.patch("apache_beam.transforms.external.BeamJarExpansionService")
+ def test_managed_replacement_known_id(
+ self, mock_service, mock_beam_jar_service):
+ mock_service.return_value = self.MockDiscoveryService()
+ mock_beam_jar_service.return_value = self.MockDiscoveryService()
+
+ identifier = "test_schematransform"
+ kwargs = {"int_field": 0, "str_field": "str"}
+
+ managed_replacement = ManagedReplacement(
+
underlying_transform_identifier=ManagedTransforms.Urns.ICEBERG_READ.urn,
+ update_compatibility_version="2.50.0")
+
+ external_transform = beam.SchemaAwareExternalTransform(
+ identifier=identifier,
+ expansion_service=expansion_service,
+ rearrange_based_on_discovery=True,
+ managed_replacement=managed_replacement,
+ **kwargs)
+ self.assertIsNotNone(external_transform._managed_payload_builder)
+
class JavaClassLookupPayloadBuilderTest(unittest.TestCase):
def _verify_row(self, schema, row_payload, expected_values):
diff --git a/sdks/python/apache_beam/transforms/managed.py
b/sdks/python/apache_beam/transforms/managed.py
index 6113f953aed..609a27b3713 100644
--- a/sdks/python/apache_beam/transforms/managed.py
+++ b/sdks/python/apache_beam/transforms/managed.py
@@ -77,6 +77,8 @@ from typing import Optional
import yaml
from apache_beam.portability.common_urns import ManagedTransforms
+from apache_beam.transforms.external import MANAGED_SCHEMA_TRANSFORM_IDENTIFIER
+from apache_beam.transforms.external import
MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING
from apache_beam.transforms.external import BeamJarExpansionService
from apache_beam.transforms.external import SchemaAwareExternalTransform
from apache_beam.transforms.ptransform import PTransform
@@ -87,13 +89,6 @@ ICEBERG = "iceberg"
_ICEBERG_CDC = "iceberg_cdc"
KAFKA = "kafka"
BIGQUERY = "bigquery"
-_MANAGED_IDENTIFIER = "beam:transform:managed:v1"
-_EXPANSION_SERVICE_JAR_TARGETS = {
- "sdks:java:io:expansion-service:shadowJar": [KAFKA, ICEBERG, _ICEBERG_CDC],
- "sdks:java:io:google-cloud-platform:expansion-service:shadowJar": [
- BIGQUERY
- ]
-}
__all__ = ["ICEBERG", "KAFKA", "BIGQUERY", "Read", "Write"]
@@ -131,7 +126,7 @@ class Read(PTransform):
def expand(self, input):
return input | SchemaAwareExternalTransform(
- identifier=_MANAGED_IDENTIFIER,
+ identifier=MANAGED_SCHEMA_TRANSFORM_IDENTIFIER,
expansion_service=self._expansion_service,
rearrange_based_on_discovery=True,
transform_identifier=self._underlying_identifier,
@@ -175,7 +170,7 @@ class Write(PTransform):
def expand(self, input):
return input | SchemaAwareExternalTransform(
- identifier=_MANAGED_IDENTIFIER,
+ identifier=MANAGED_SCHEMA_TRANSFORM_IDENTIFIER,
expansion_service=self._expansion_service,
rearrange_based_on_discovery=True,
transform_identifier=self._underlying_identifier,
@@ -192,13 +187,11 @@ def _resolve_expansion_service(
if expansion_service:
return expansion_service
- default_target = None
- for gradle_target, transforms in _EXPANSION_SERVICE_JAR_TARGETS.items():
- if transform_name.lower() in transforms:
- default_target = gradle_target
- break
- if not default_target:
+ gradle_target = None
+ if identifier in MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING:
+ gradle_target = MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING.get(identifier)
+ if not gradle_target:
raise ValueError(
"No expansion service was specified and could not find a "
f"default expansion service for {transform_name}: '{identifier}'.")
- return BeamJarExpansionService(default_target)
+ return BeamJarExpansionService(gradle_target)
diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml
b/sdks/python/apache_beam/yaml/standard_io.yaml
index be652a300ef..c6713725025 100644
--- a/sdks/python/apache_beam/yaml/standard_io.yaml
+++ b/sdks/python/apache_beam/yaml/standard_io.yaml
@@ -84,6 +84,12 @@
'WriteToKafka': 'beam:schematransform:org.apache.beam:kafka_write:v1'
config:
gradle_target: 'sdks:java:io:expansion-service:shadowJar'
+ managed_replacement:
+ # The following transforms may be replaced with equivalent managed
+ # transforms, unless the pipeline's 'updateCompatibilityBeamVersion'
+ # is prior to the provided version.
+ 'ReadFromKafka': '2.66.0'
+ 'WriteToKafka': '2.66.0'
# PubSub
- type: renaming
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py
b/sdks/python/apache_beam/yaml/yaml_provider.py
index 171f229746a..7537e80164c 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -48,6 +48,7 @@ import apache_beam as beam
import apache_beam.dataframe.io
import apache_beam.io
import apache_beam.transforms.util
+from apache_beam import ManagedReplacement
from apache_beam.io.filesystems import FileSystems
from apache_beam.portability.api import schema_pb2
from apache_beam.runners import pipeline_context
@@ -181,10 +182,20 @@ class ExternalProvider(Provider):
"""A Provider implemented via the cross language transform service."""
_provider_types: dict[str, Callable[..., Provider]] = {}
- def __init__(self, urns, service):
+ def __init__(self, urns, service, managed_replacement=None):
+ """Initializes the ExternalProvider.
+
+ Args:
+ urns: a set of URNs that uniquely identify the transforms supported.
+ service: the gradle target that identifies the expansion service jar.
+ managed_replacement (Optional): a map from transform type to an update
+ compatibility version, naming the transforms that the SDK may replace
+ with an equivalent managed transform.
+ """
self._urns = urns
self._service = service
self._schema_transforms = None
+ self._managed_replacement = managed_replacement
def provided_transforms(self):
return self._urns.keys()
@@ -224,8 +235,18 @@ class ExternalProvider(Provider):
self._service = self._service()
urn = self._urns[type]
if urn in self.schema_transforms():
+ managed_replacement = None
+ if self._managed_replacement and type in self._managed_replacement:
+ managed_replacement = ManagedReplacement(
+ underlying_transform_identifier=urn,
+ update_compatibility_version=self._managed_replacement[type])
+
return external.SchemaAwareExternalTransform(
- urn, self._service, rearrange_based_on_discovery=True, **args)
+ urn,
+ self._service,
+ rearrange_based_on_discovery=True,
+ managed_replacement=managed_replacement,
+ **args)
else:
return type >> self.create_external_transform(urn, args)
@@ -318,14 +339,16 @@ def beam_jar(
urns,
*,
gradle_target,
+ managed_replacement=None,
appendix=None,
version=beam_version,
artifact_id=None):
return ExternalJavaProvider(
urns,
lambda: subprocess_server.JavaJarServer.path_to_beam_jar(
- gradle_target=gradle_target, version=version,
artifact_id=artifact_id)
- )
+ gradle_target=gradle_target, version=version, artifact_id=artifact_id
+ ),
+ managed_replacement=managed_replacement)
@ExternalProvider.register_provider_type('docker')
@@ -357,11 +380,13 @@ class RemoteProvider(ExternalProvider):
class ExternalJavaProvider(ExternalProvider):
- def __init__(self, urns, jar_provider, classpath=None):
+ def __init__(
+ self, urns, jar_provider, managed_replacement=None, classpath=None):
super().__init__(
urns,
lambda: external.JavaJarExpansionService(
- jar_provider(), classpath=classpath))
+ jar_provider(), classpath=classpath),
+ managed_replacement)
self._jar_provider = jar_provider
self._classpath = classpath
diff --git a/sdks/python/gen_managed_doc.py b/sdks/python/gen_managed_doc.py
index d3d3f373d1e..85a7c73679a 100644
--- a/sdks/python/gen_managed_doc.py
+++ b/sdks/python/gen_managed_doc.py
@@ -89,6 +89,7 @@ _DOCUMENTATION_DESTINATION = os.path.join(
def generate_managed_doc(output_location):
+ from apache_beam.transforms.external import
MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING
from apache_beam.transforms.external import BeamJarExpansionService
from apache_beam.transforms.external_transform_provider import
ExternalTransform
from apache_beam.transforms.external_transform_provider import
ExternalTransformProvider
@@ -99,13 +100,16 @@ def generate_managed_doc(output_location):
with open(_MANAGED_CONFIG_ALIASES) as f:
all_config_aliases: dict = yaml.safe_load(f)
- services_and_names = managed._EXPANSION_SERVICE_JAR_TARGETS
+ # Creating a unique list of expansion service jars.
+ expansion_service_jar_targets = list(
+ dict.fromkeys(MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING.values()))
+
read_names_and_identifiers = managed.Read._READ_TRANSFORMS
write_names_and_identifiers = managed.Write._WRITE_TRANSFORMS
all_transforms = {}
- for gradle_target in services_and_names.keys():
+ for gradle_target in expansion_service_jar_targets:
provider =
ExternalTransformProvider(BeamJarExpansionService(gradle_target))
discovered: Dict[str, ExternalTransform] = provider.get_all()