This is an automated email from the ASF dual-hosted git repository.
BewareMyPower pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pulsar-client-python.git
The following commit(s) were added to refs/heads/main by this push:
new 7e03a4d feat: implement ProtobufNative schema (#299)
7e03a4d is described below
commit 7e03a4dd09ef1bc890d87d524aa11dff227edcbd
Author: Yunze Xu <[email protected]>
AuthorDate: Tue May 12 09:49:16 2026 +0800
feat: implement ProtobufNative schema (#299)
---
.github/workflows/ci-pr-validation.yaml | 2 +-
pulsar/schema/__init__.py | 1 +
pulsar/schema/schema_protobuf.py | 145 +++++++++++++++++++++++++++++++
setup.py | 11 ++-
src/enums.cc | 3 +-
tests/schema_test.py | 147 ++++++++++++++++++++++++++++++++
6 files changed, 305 insertions(+), 4 deletions(-)
diff --git a/.github/workflows/ci-pr-validation.yaml b/.github/workflows/ci-pr-validation.yaml
index df2417d..dfe6d67 100644
--- a/.github/workflows/ci-pr-validation.yaml
+++ b/.github/workflows/ci-pr-validation.yaml
@@ -78,7 +78,7 @@ jobs:
python3 -m pip install -U pip setuptools wheel requests
python3 setup.py bdist_wheel
WHEEL=$(find dist -name '*.whl')
- pip3 install ${WHEEL}[avro]
+ pip3 install ${WHEEL}[avro,protobuf]
- name: Run Oauth2 tests
run: |
diff --git a/pulsar/schema/__init__.py b/pulsar/schema/__init__.py
index efa6806..e3fa49e 100644
--- a/pulsar/schema/__init__.py
+++ b/pulsar/schema/__init__.py
@@ -22,3 +22,4 @@ from .definition import Record, Field, Null, Boolean, Integer, Long, \
from .schema import Schema, BytesSchema, StringSchema, JsonSchema
from .schema_avro import AvroSchema
+from .schema_protobuf import ProtobufNativeSchema
diff --git a/pulsar/schema/schema_protobuf.py b/pulsar/schema/schema_protobuf.py
new file mode 100644
index 0000000..1852cd9
--- /dev/null
+++ b/pulsar/schema/schema_protobuf.py
@@ -0,0 +1,145 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import base64
+import _pulsar
+
+from .schema import Schema
+
+try:
+ from google.protobuf import descriptor_pb2
+ from google.protobuf.message import Message as ProtobufMessage
+ HAS_PROTOBUF = True
+except ImportError:
+ HAS_PROTOBUF = False
+
+
+def _collect_file_descriptors(file_descriptor, visited, file_descriptor_set):
+ """Recursively collect all FileDescriptorProto objects into
file_descriptor_set."""
+ if file_descriptor.name in visited:
+ return
+ for dep in file_descriptor.dependencies:
+ _collect_file_descriptors(dep, visited, file_descriptor_set)
+ visited.add(file_descriptor.name)
+ proto = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.CopyToProto(proto)
+ file_descriptor_set.file.append(proto)
+
+
def _build_schema_definition(descriptor):
    """
    Build the schema-definition mapping used by Java's ``ProtobufNativeSchemaData``.

    The result has three keys:

    .. code-block:: text

        fileDescriptorSet
        rootMessageTypeName
        rootFileDescriptorName

    ``fileDescriptorSet`` is the base64 text of a serialized ``FileDescriptorSet``
    that holds the descriptor's file plus all of its transitive imports. The
    other two keys name the root message (fully qualified) and the file that
    defines it. This mirrors ``ProtobufNativeSchemaUtils.serialize()`` in the
    Java client.

    Parameters
    ----------
    descriptor:
        The protobuf message ``Descriptor`` of the root message type.
    """
    descriptor_set = descriptor_pb2.FileDescriptorSet()
    _collect_file_descriptors(descriptor.file, set(), descriptor_set)
    encoded_set = base64.b64encode(descriptor_set.SerializeToString()).decode('utf-8')
    schema_definition = dict(
        fileDescriptorSet=encoded_set,
        rootMessageTypeName=descriptor.full_name,
        rootFileDescriptorName=descriptor.file.name,
    )
    return schema_definition
+
+
if HAS_PROTOBUF:
    class ProtobufNativeSchema(Schema):
        """
        Schema for protobuf messages using the native protobuf binary encoding.

        The schema definition is stored as a JSON-encoded ProtobufNativeSchemaData
        (fileDescriptorSet, rootMessageTypeName, rootFileDescriptorName), which is
        compatible with the Java client's ProtobufNativeSchema.

        Parameters
        ----------
        record_cls:
            A generated protobuf message class (subclass of
            google.protobuf.message.Message).

        Raises
        ------
        TypeError
            If ``record_cls`` is not a class, does not derive from
            ``google.protobuf.message.Message``, or has no message descriptor
            (``DESCRIPTOR`` missing or ``None``).

        Example
        -------
        .. code-block:: python

            import pulsar
            from pulsar.schema import ProtobufNativeSchema
            from my_proto_pb2 import MyMessage

            client = pulsar.Client('pulsar://localhost:6650')
            schema = ProtobufNativeSchema(MyMessage)
            producer = client.create_producer('my-topic', schema=schema)
            consumer = client.subscribe('my-topic', 'my-sub', schema=schema)

            message = MyMessage()
            message.field = 'value'
            producer.send(message)

            received = consumer.receive(timeout_millis=5000)
            typed_value = received.value()
            consumer.acknowledge(received)

            assert isinstance(typed_value, MyMessage)
            assert typed_value.field == 'value'

            consumer.close()
            producer.close()
            client.close()
        """

        def __init__(self, record_cls):
            if not (isinstance(record_cls, type) and issubclass(record_cls, ProtobufMessage)):
                raise TypeError(
                    f'record_cls must be a protobuf Message subclass, got {record_cls!r}'
                )
            # A class can pass the subclass check yet carry no message descriptor
            # (e.g. the abstract Message base). Reject it with the documented
            # TypeError rather than an AttributeError from deep inside
            # _build_schema_definition.
            descriptor = getattr(record_cls, 'DESCRIPTOR', None)
            if descriptor is None:
                raise TypeError(
                    'record_cls must be a generated protobuf message class with a '
                    f'DESCRIPTOR, got {record_cls!r}'
                )
            schema_definition = _build_schema_definition(descriptor)
            super().__init__(
                record_cls, _pulsar.SchemaType.PROTOBUF_NATIVE, schema_definition,
                'PROTOBUF_NATIVE'
            )

        def encode(self, obj):
            """Serialize ``obj`` to protobuf wire-format bytes.

            ``obj`` is first checked by ``Schema._validate_object_type``, which
            rejects objects that are not instances of the schema's record class.
            """
            self._validate_object_type(obj)
            return obj.SerializeToString()

        def decode(self, data):
            """Parse protobuf wire-format ``data`` into a new ``record_cls`` instance."""
            return self._record_cls.FromString(data)

        def __str__(self):
            return f'ProtobufNativeSchema({self._record_cls.__name__})'

else:
    class ProtobufNativeSchema(Schema):
        """Placeholder used when the ``protobuf`` package cannot be imported.

        Instantiating it always raises ``ImportError`` (a subclass of
        ``Exception``), pointing the user at the missing dependency.
        """

        def __init__(self, *args, **kwargs):
            # Accept any call shape, including ``record_cls=...`` by keyword, so
            # callers always get the missing-dependency error rather than a
            # signature mismatch.
            raise ImportError(
                "protobuf library support was not found. "
                "Install it with: pip install protobuf"
            )

        def encode(self, obj):
            pass

        def decode(self, data):
            pass
diff --git a/setup.py b/setup.py
index 76d929b..370205d 100755
--- a/setup.py
+++ b/setup.py
@@ -76,14 +76,21 @@ dependencies = [
extras_require = {}

# protobuf schema dependencies (the "functions" extra reuses them below)
extras_require["protobuf"] = sorted({"protobuf>=6.33.6"})

# functions dependencies: the function runtime needs its own packages plus
# everything required by the protobuf schema extra
extras_require["functions"] = sorted(
    {
        "grpcio>=1.59.3",
        "apache-bookkeeper-client>=4.16.1",
        "prometheus_client",
        "ratelimit",
    }.union(extras_require["protobuf"])
)
diff --git a/src/enums.cc b/src/enums.cc
index 447d013..7ee28ea 100644
--- a/src/enums.cc
+++ b/src/enums.cc
@@ -115,7 +115,8 @@ void export_enums(py::module_& m) {
.value("AVRO", pulsar::AVRO)
.value("AUTO_CONSUME", pulsar::AUTO_CONSUME)
.value("AUTO_PUBLISH", pulsar::AUTO_PUBLISH)
- .value("KEY_VALUE", pulsar::KEY_VALUE);
+ .value("KEY_VALUE", pulsar::KEY_VALUE)
+ .value("PROTOBUF_NATIVE", pulsar::PROTOBUF_NATIVE);
enum_<InitialPosition>(m, "InitialPosition", "Supported initial position")
.value("Latest", InitialPositionLatest)
diff --git a/tests/schema_test.py b/tests/schema_test.py
index 9d031d1..42b6590 100755
--- a/tests/schema_test.py
+++ b/tests/schema_test.py
@@ -18,6 +18,7 @@
# under the License.
#
+import base64
import math
import requests
from typing import List
@@ -29,6 +30,67 @@ from pulsar.schema import *
from enum import Enum
import json
from fastavro.schema import load_schema
+from google.protobuf import descriptor_pb2, descriptor_pool, message_factory
+
+
def _add_protobuf_field(message, name, number, field_type, type_name=None):
    """Append an optional field to the ``DescriptorProto`` ``message``.

    ``type_name`` (e.g. ``'.test.TestInner'``) is only recorded when truthy;
    it is needed for message- or enum-typed fields.
    """
    new_field = message.field.add(
        name=name,
        number=number,
        label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
        type=field_type,
    )
    if type_name:
        new_field.type_name = type_name
+
+
def _get_message_classes(pool, message_names):
    """Return a tuple of message classes for ``message_names`` looked up in ``pool``.

    Newer protobuf exposes ``message_factory.GetMessageClass``; when it is
    absent, fall back to the older ``MessageFactory(pool).GetPrototype`` API.
    """
    if hasattr(message_factory, 'GetMessageClass'):
        make_class = message_factory.GetMessageClass
    else:
        make_class = message_factory.MessageFactory(pool).GetPrototype
    return tuple(
        make_class(pool.FindMessageTypeByName(full_name))
        for full_name in message_names
    )
+
+
def _build_protobuf_test_messages():
    """Build the dynamic protobuf message classes used by the schema tests.

    A proto3 file ``test_schema.proto`` in package ``test`` is assembled by hand
    and registered in a private ``DescriptorPool``. The file defines
    ``TestMessage``, ``TestMessageWithNested`` (which embeds ``TestInner``) and
    ``TestInner``. The generated classes are returned in that order.
    """
    field_proto = descriptor_pb2.FieldDescriptorProto
    file_proto = descriptor_pb2.FileDescriptorProto(
        name='test_schema.proto',
        package='test',
        syntax='proto3',
    )

    # (message name, ((field name, number, field type, type_name), ...))
    message_specs = (
        ('TestMessage', (
            ('name', 1, field_proto.TYPE_STRING, None),
            ('value', 2, field_proto.TYPE_INT32, None),
        )),
        ('TestMessageWithNested', (
            ('str_field', 1, field_proto.TYPE_STRING, None),
            ('int_field', 2, field_proto.TYPE_INT32, None),
            ('double_field', 3, field_proto.TYPE_DOUBLE, None),
            ('nested', 4, field_proto.TYPE_MESSAGE, '.test.TestInner'),
        )),
        ('TestInner', (
            ('inner_str', 1, field_proto.TYPE_STRING, None),
            ('inner_int', 2, field_proto.TYPE_INT64, None),
        )),
    )
    for message_name, fields in message_specs:
        message_proto = file_proto.message_type.add(name=message_name)
        for field_name, number, field_type, type_name in fields:
            _add_protobuf_field(message_proto, field_name, number, field_type, type_name)

    pool = descriptor_pool.DescriptorPool()
    pool.AddSerializedFile(file_proto.SerializeToString())
    return _get_message_classes(
        pool,
        tuple(f'{file_proto.package}.{message_name}' for message_name, _ in message_specs),
    )


# Module-level message classes shared by ProtobufNativeSchemaTest.
TestMessage, TestMessageWithNested, TestInner = _build_protobuf_test_messages()
+
class ExampleRecord(Record):
str_field = String()
@@ -1404,5 +1466,90 @@ class SchemaTest(TestCase):
client.close()
class ProtobufNativeSchemaTest(TestCase):
    """Unit tests for ProtobufNativeSchema (no Pulsar broker required).

    The message classes under test (``TestMessage``, ``TestMessageWithNested``
    and ``TestInner``) are built at import time from a hand-written
    ``FileDescriptorProto`` named ``test_schema.proto`` in package ``test``.
    """

    def test_schema_type(self):
        """Schema type must be PROTOBUF_NATIVE."""
        import _pulsar
        schema = ProtobufNativeSchema(TestMessage)
        self.assertEqual(schema.schema_info().schema_type(), _pulsar.SchemaType.PROTOBUF_NATIVE)

    def test_schema_definition_keys(self):
        """Schema definition JSON must contain the three required keys."""
        schema = ProtobufNativeSchema(TestMessage)
        schema_def = json.loads(schema.schema_info().schema())
        self.assertIn('fileDescriptorSet', schema_def)
        self.assertIn('rootMessageTypeName', schema_def)
        self.assertIn('rootFileDescriptorName', schema_def)

    def test_schema_definition_values(self):
        """rootMessageTypeName and rootFileDescriptorName must match the descriptor."""
        schema = ProtobufNativeSchema(TestMessage)
        schema_def = json.loads(schema.schema_info().schema())
        self.assertEqual(schema_def['rootMessageTypeName'], 'test.TestMessage')
        self.assertEqual(schema_def['rootFileDescriptorName'], 'test_schema.proto')

    def test_file_descriptor_set_is_valid_base64_proto(self):
        """fileDescriptorSet must be valid base64-encoded FileDescriptorSet bytes.

        The decoded set must include the root file, and that file must define
        the root message type.
        """
        # descriptor_pb2 comes from the module-level google.protobuf import.
        schema = ProtobufNativeSchema(TestMessage)
        schema_def = json.loads(schema.schema_info().schema())
        raw = base64.b64decode(schema_def['fileDescriptorSet'])
        fds = descriptor_pb2.FileDescriptorSet.FromString(raw)
        files_by_name = {f.name: f for f in fds.file}
        self.assertIn('test_schema.proto', files_by_name)
        root_file = files_by_name['test_schema.proto']
        self.assertIn('TestMessage', [m.name for m in root_file.message_type])

    def test_encode_decode_roundtrip(self):
        """encode then decode must reproduce the original message."""
        schema = ProtobufNativeSchema(TestMessage)
        original = TestMessage(name='hello', value=42)
        encoded = schema.encode(original)
        decoded = schema.decode(encoded)
        self.assertEqual(decoded.name, 'hello')
        self.assertEqual(decoded.value, 42)

    def test_encode_produces_protobuf_binary(self):
        """Encoded bytes must be valid protobuf binary (parseable by the class directly)."""
        schema = ProtobufNativeSchema(TestMessage)
        msg = TestMessage(name='pulsar', value=100)
        encoded = schema.encode(msg)
        # Verify with protobuf's own parser
        reparsed = TestMessage.FromString(encoded)
        self.assertEqual(reparsed, msg)

    def test_encode_decode_nested_message(self):
        """encode/decode round-trip works for messages containing nested message fields."""
        schema = ProtobufNativeSchema(TestMessageWithNested)
        original = TestMessageWithNested(
            str_field='test',
            int_field=7,
            double_field=3.14,
            nested=TestInner(inner_str='inner', inner_int=999),
        )
        decoded = schema.decode(schema.encode(original))
        self.assertEqual(decoded.str_field, 'test')
        self.assertEqual(decoded.int_field, 7)
        self.assertAlmostEqual(decoded.double_field, 3.14)
        self.assertEqual(decoded.nested.inner_str, 'inner')
        self.assertEqual(decoded.nested.inner_int, 999)

    def test_wrong_type_raises(self):
        """Encoding an object of the wrong type must raise TypeError."""
        schema = ProtobufNativeSchema(TestMessage)
        with self.assertRaises(TypeError):
            schema.encode("not a protobuf message")

    def test_non_message_class_raises(self):
        """Constructing with a non-Message class must raise TypeError."""
        with self.assertRaises(TypeError):
            ProtobufNativeSchema(str)

    def test_schema_name(self):
        """Schema name must be 'PROTOBUF_NATIVE'."""
        schema = ProtobufNativeSchema(TestMessage)
        self.assertEqual(schema.schema_info().name(), 'PROTOBUF_NATIVE')

    def test_str(self):
        """str() must name the wrapped message class."""
        schema = ProtobufNativeSchema(TestMessage)
        self.assertEqual(str(schema), 'ProtobufNativeSchema(TestMessage)')
+
+
# When this file is executed directly as a script, run its tests via main().
if __name__ == '__main__':
    main()